use crate::prune::config::PruningConfig;
use crate::prune::data_loader::CalibrationDataConfig;
use serde::{Deserialize, Serialize};
/// Top-level settings for the pruning trainer.
///
/// Groups the pruning and calibration-data sections together with the
/// fine-tuning, evaluation and checkpointing options. Build one with
/// `PruneTrainerConfig::new()` / `Default` and the `with_*` builder methods,
/// then call `validate` before use.
#[derive(Debug, Clone)]
#[derive(Serialize, Deserialize)]
pub struct PruneTrainerConfig {
    /// Pruning settings; checked by `validate` through `PruningConfig::validate`.
    pub pruning: PruningConfig,
    /// Calibration-data settings.
    pub calibration: CalibrationDataConfig,
    /// Number of fine-tuning epochs (the `Default` impl uses 1).
    pub finetune_epochs: usize,
    /// Fine-tuning learning rate; `validate` rejects values that are not strictly positive.
    pub finetune_lr: f32,
    /// Flag toggled by `with_evaluate` (default `true`); how the trainer uses it is not
    /// visible here. NOTE(review): name suggests before/after-pruning evaluation — confirm.
    pub evaluate_pre_post: bool,
    /// Optional checkpoint directory; `None` by default.
    pub checkpoint_dir: Option<String>,
    /// Whether checkpoints should be saved; `false` by default.
    /// NOTE(review): nothing in this config requires `checkpoint_dir` to be set when this
    /// is `true` — confirm the trainer handles a missing directory.
    pub save_checkpoints: bool,
}
impl Default for PruneTrainerConfig {
fn default() -> Self {
Self {
pruning: PruningConfig::default(),
calibration: CalibrationDataConfig::default(),
finetune_epochs: 1,
finetune_lr: 1e-5,
evaluate_pre_post: true,
checkpoint_dir: None,
save_checkpoints: false,
}
}
}
impl PruneTrainerConfig {
    /// Creates a configuration populated with the [`Default`] values.
    #[must_use]
    pub fn new() -> Self {
        Self::default()
    }

    /// Replaces the pruning section.
    #[must_use]
    pub fn with_pruning(mut self, config: PruningConfig) -> Self {
        self.pruning = config;
        self
    }

    /// Replaces the calibration-data section.
    #[must_use]
    pub fn with_calibration(mut self, config: CalibrationDataConfig) -> Self {
        self.calibration = config;
        self
    }

    /// Sets the number of fine-tuning epochs.
    #[must_use]
    pub fn with_finetune_epochs(mut self, epochs: usize) -> Self {
        self.finetune_epochs = epochs;
        self
    }

    /// Sets the fine-tuning learning rate.
    ///
    /// The value is not checked here; call [`PruneTrainerConfig::validate`] afterwards.
    #[must_use]
    pub fn with_finetune_lr(mut self, lr: f32) -> Self {
        self.finetune_lr = lr;
        self
    }

    /// Enables or disables `evaluate_pre_post`.
    #[must_use]
    pub fn with_evaluate(mut self, enabled: bool) -> Self {
        self.evaluate_pre_post = enabled;
        self
    }

    /// Sets the checkpoint directory (stored as `Some(dir)`).
    #[must_use]
    pub fn with_checkpoint_dir(mut self, dir: impl Into<String>) -> Self {
        self.checkpoint_dir = Some(dir.into());
        self
    }

    /// Enables or disables checkpoint saving.
    #[must_use]
    pub fn with_save_checkpoints(mut self, enabled: bool) -> Self {
        self.save_checkpoints = enabled;
        self
    }

    /// Checks the configuration for invalid values.
    ///
    /// # Errors
    ///
    /// - Propagates any error returned by `PruningConfig::validate`.
    /// - Returns `"finetune_lr must be finite"` if the learning rate is NaN or infinite.
    /// - Returns `"finetune_lr must be positive"` if the learning rate is zero or negative.
    pub fn validate(&self) -> Result<(), String> {
        self.pruning.validate()?;
        // `NaN <= 0.0` is false, so a plain sign check would let NaN through;
        // reject non-finite values (NaN, ±inf) explicitly first.
        if !self.finetune_lr.is_finite() {
            return Err("finetune_lr must be finite".to_string());
        }
        if self.finetune_lr <= 0.0 {
            return Err("finetune_lr must be positive".to_string());
        }
        Ok(())
    }
}