Skip to main content

entrenar/hf_pipeline/trainer/
config.rs

1//! Distillation training configuration.
2
3use crate::hf_pipeline::distillation::{
4    AttentionTransfer, DistillationLoss, ProgressiveDistillation,
5};
6use crate::hf_pipeline::fine_tune::FineTuneConfig;
7use std::path::PathBuf;
8
/// Checkpoint interval, in training steps, used when the caller does not
/// override `save_every_n_steps`.
const DEFAULT_SAVE_STEPS: usize = 500;
11
12/// Distillation training configuration
13#[derive(Debug, Clone)]
14pub struct TrainerConfig {
15    /// Teacher model ID or path
16    pub teacher_model: String,
17    /// Student model ID or path
18    pub student_model: String,
19    /// Output directory for checkpoints and logs
20    pub output_dir: PathBuf,
21    /// Distillation loss configuration
22    pub distillation_loss: DistillationLoss,
23    /// Progressive distillation (hidden state matching)
24    pub progressive: Option<ProgressiveDistillation>,
25    /// Attention transfer
26    pub attention_transfer: Option<AttentionTransfer>,
27    /// Fine-tuning configuration for student
28    pub fine_tune: FineTuneConfig,
29    /// Number of training epochs
30    pub epochs: usize,
31    /// Steps per epoch (0 = auto-detect from dataset)
32    pub steps_per_epoch: usize,
33    /// Logging frequency (steps)
34    pub log_every_n_steps: usize,
35    /// Checkpoint frequency (steps)
36    pub save_every_n_steps: usize,
37    /// Evaluation frequency (steps)
38    pub eval_every_n_steps: usize,
39    /// Maximum gradient norm for clipping
40    pub max_grad_norm: f32,
41    /// Random seed
42    pub seed: u64,
43}
44
45impl Default for TrainerConfig {
46    fn default() -> Self {
47        Self {
48            teacher_model: String::new(),
49            student_model: String::new(),
50            output_dir: PathBuf::from("./distillation_output"),
51            distillation_loss: DistillationLoss::default(),
52            progressive: None,
53            attention_transfer: None,
54            fine_tune: FineTuneConfig::default(),
55            epochs: 3,
56            steps_per_epoch: 0,
57            log_every_n_steps: 10,
58            save_every_n_steps: DEFAULT_SAVE_STEPS,
59            eval_every_n_steps: 100,
60            max_grad_norm: 1.0,
61            seed: 42,
62        }
63    }
64}
65
66impl TrainerConfig {
67    /// Create new trainer config with teacher and student models
68    #[must_use]
69    pub fn new(teacher: impl Into<String>, student: impl Into<String>) -> Self {
70        Self { teacher_model: teacher.into(), student_model: student.into(), ..Default::default() }
71    }
72
73    /// Set temperature for distillation
74    #[must_use]
75    pub fn temperature(mut self, temp: f32) -> Self {
76        contract_pre_temperature!();
77        self.distillation_loss = DistillationLoss::new(temp, self.distillation_loss.alpha);
78        contract_post_temperature_bounds!(temp);
79        self
80    }
81
82    /// Set alpha for soft vs hard loss weight
83    #[must_use]
84    pub fn alpha(mut self, alpha: f32) -> Self {
85        self.distillation_loss = DistillationLoss::new(self.distillation_loss.temperature, alpha);
86        self
87    }
88
89    /// Enable progressive distillation with layer mapping
90    #[must_use]
91    pub fn with_progressive(mut self, layer_mapping: Vec<(usize, usize)>) -> Self {
92        self.progressive = Some(ProgressiveDistillation::new(layer_mapping));
93        self
94    }
95
96    /// Enable attention transfer
97    #[must_use]
98    pub fn with_attention_transfer(mut self, weight: f32) -> Self {
99        self.attention_transfer = Some(AttentionTransfer::new(weight));
100        self
101    }
102
103    /// Set output directory
104    #[must_use]
105    pub fn output_dir(mut self, path: impl Into<PathBuf>) -> Self {
106        self.output_dir = path.into();
107        self
108    }
109
110    /// Set number of epochs
111    #[must_use]
112    pub fn epochs(mut self, n: usize) -> Self {
113        self.epochs = n;
114        self
115    }
116}