entrenar/prune/trainer_integration/
config.rs1use crate::prune::config::PruningConfig;
7use crate::prune::data_loader::CalibrationDataConfig;
8use serde::{Deserialize, Serialize};
9
10#[derive(Debug, Clone, Serialize, Deserialize)]
12pub struct PruneTrainerConfig {
13 pub pruning: PruningConfig,
15 pub calibration: CalibrationDataConfig,
17 pub finetune_epochs: usize,
19 pub finetune_lr: f32,
21 pub evaluate_pre_post: bool,
23 pub checkpoint_dir: Option<String>,
25 pub save_checkpoints: bool,
27}
28
29impl Default for PruneTrainerConfig {
30 fn default() -> Self {
31 Self {
32 pruning: PruningConfig::default(),
33 calibration: CalibrationDataConfig::default(),
34 finetune_epochs: 1,
35 finetune_lr: 1e-5,
36 evaluate_pre_post: true,
37 checkpoint_dir: None,
38 save_checkpoints: false,
39 }
40 }
41}
42
43impl PruneTrainerConfig {
44 pub fn new() -> Self {
46 Self::default()
47 }
48
49 pub fn with_pruning(mut self, config: PruningConfig) -> Self {
51 self.pruning = config;
52 self
53 }
54
55 pub fn with_calibration(mut self, config: CalibrationDataConfig) -> Self {
57 self.calibration = config;
58 self
59 }
60
61 pub fn with_finetune_epochs(mut self, epochs: usize) -> Self {
63 self.finetune_epochs = epochs;
64 self
65 }
66
67 pub fn with_finetune_lr(mut self, lr: f32) -> Self {
69 self.finetune_lr = lr;
70 self
71 }
72
73 pub fn with_evaluate(mut self, enabled: bool) -> Self {
75 self.evaluate_pre_post = enabled;
76 self
77 }
78
79 pub fn with_checkpoint_dir(mut self, dir: impl Into<String>) -> Self {
81 self.checkpoint_dir = Some(dir.into());
82 self
83 }
84
85 pub fn with_save_checkpoints(mut self, enabled: bool) -> Self {
87 self.save_checkpoints = enabled;
88 self
89 }
90
91 pub fn validate(&self) -> Result<(), String> {
93 self.pruning.validate()?;
94
95 if self.finetune_lr <= 0.0 {
96 return Err("finetune_lr must be positive".to_string());
97 }
98
99 Ok(())
100 }
101}