Skip to main content

entrenar/prune/trainer_integration/
config.rs

1//! Configuration for the prune-finetune trainer.
2//!
3//! # Toyota Way: Kaizen (Continuous Improvement)
4//! Fine-tuning allows the model to recover from pruning-induced accuracy loss.
5
6use crate::prune::config::PruningConfig;
7use crate::prune::data_loader::CalibrationDataConfig;
8use serde::{Deserialize, Serialize};
9
10/// Configuration for the prune-finetune trainer.
11#[derive(Debug, Clone, Serialize, Deserialize)]
12pub struct PruneTrainerConfig {
13    /// Pruning configuration.
14    pub pruning: PruningConfig,
15    /// Calibration data configuration.
16    pub calibration: CalibrationDataConfig,
17    /// Number of fine-tuning epochs after pruning.
18    pub finetune_epochs: usize,
19    /// Learning rate for fine-tuning.
20    pub finetune_lr: f32,
21    /// Whether to evaluate before and after pruning.
22    pub evaluate_pre_post: bool,
23    /// Checkpoint directory.
24    pub checkpoint_dir: Option<String>,
25    /// Whether to save intermediate checkpoints.
26    pub save_checkpoints: bool,
27}
28
29impl Default for PruneTrainerConfig {
30    fn default() -> Self {
31        Self {
32            pruning: PruningConfig::default(),
33            calibration: CalibrationDataConfig::default(),
34            finetune_epochs: 1,
35            finetune_lr: 1e-5,
36            evaluate_pre_post: true,
37            checkpoint_dir: None,
38            save_checkpoints: false,
39        }
40    }
41}
42
43impl PruneTrainerConfig {
44    /// Create a new configuration with default values.
45    pub fn new() -> Self {
46        Self::default()
47    }
48
49    /// Set the pruning configuration.
50    pub fn with_pruning(mut self, config: PruningConfig) -> Self {
51        self.pruning = config;
52        self
53    }
54
55    /// Set the calibration configuration.
56    pub fn with_calibration(mut self, config: CalibrationDataConfig) -> Self {
57        self.calibration = config;
58        self
59    }
60
61    /// Set the number of fine-tuning epochs.
62    pub fn with_finetune_epochs(mut self, epochs: usize) -> Self {
63        self.finetune_epochs = epochs;
64        self
65    }
66
67    /// Set the fine-tuning learning rate.
68    pub fn with_finetune_lr(mut self, lr: f32) -> Self {
69        self.finetune_lr = lr;
70        self
71    }
72
73    /// Enable or disable pre/post evaluation.
74    pub fn with_evaluate(mut self, enabled: bool) -> Self {
75        self.evaluate_pre_post = enabled;
76        self
77    }
78
79    /// Set the checkpoint directory.
80    pub fn with_checkpoint_dir(mut self, dir: impl Into<String>) -> Self {
81        self.checkpoint_dir = Some(dir.into());
82        self
83    }
84
85    /// Enable or disable checkpoint saving.
86    pub fn with_save_checkpoints(mut self, enabled: bool) -> Self {
87        self.save_checkpoints = enabled;
88        self
89    }
90
91    /// Validate the configuration.
92    pub fn validate(&self) -> Result<(), String> {
93        self.pruning.validate()?;
94
95        if self.finetune_lr <= 0.0 {
96            return Err("finetune_lr must be positive".to_string());
97        }
98
99        Ok(())
100    }
101}