// oxigdal_ml/optimization/distillation/config.rs
1//! Configuration types for knowledge distillation
2
3use crate::error::{MlError, Result};
4
/// Distillation loss function used to match the student's outputs to the teacher's.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)]
pub enum DistillationLoss {
    /// Kullback-Leibler divergence
    #[default]
    KLDivergence,
    /// Mean squared error
    MSE,
    /// Cross-entropy
    CrossEntropy,
    /// Custom weighted combination of distillation and ground-truth losses.
    ///
    /// Weights are stored as integers (`u8`), which keeps the enum `Eq`/`Copy`;
    /// NOTE(review): the intended scale (raw ratio vs. hundredths, as used by
    /// `OptimizerType::AdamW::weight_decay`) is not shown here — confirm
    /// against the loss implementation before relying on specific values.
    Weighted {
        /// Weight for distillation loss
        distill_weight: u8,
        /// Weight for ground truth loss
        ground_truth_weight: u8,
    },
}
23
/// Temperature for softening probability distributions.
///
/// Dividing logits by a temperature above 1.0 flattens the resulting
/// softmax distribution, exposing the teacher's "dark knowledge".
#[derive(Debug, Clone, Copy)]
pub struct Temperature(pub f32);

impl Default for Temperature {
    fn default() -> Self {
        // 2.0 is the conventional starting temperature for distillation.
        Self(2.0)
    }
}

impl Temperature {
    /// Creates a new temperature, clamped to a minimum of 0.1 so that the
    /// division in [`Self::scale_logits`] stays numerically stable.
    #[must_use]
    pub fn new(value: f32) -> Self {
        Self(value.max(0.1))
    }

    /// Applies temperature scaling to a slice of logits, returning the
    /// scaled copies (each logit divided by the temperature).
    #[must_use]
    pub fn scale_logits(&self, logits: &[f32]) -> Vec<f32> {
        let mut scaled = Vec::with_capacity(logits.len());
        for &logit in logits {
            scaled.push(logit / self.0);
        }
        scaled
    }

    /// Returns the raw temperature value.
    #[must_use]
    pub fn value(&self) -> f32 {
        self.0
    }
}
53
/// Optimizer type for training
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)]
pub enum OptimizerType {
    /// Stochastic Gradient Descent
    SGD,
    /// SGD with momentum
    SGDMomentum {
        /// Momentum coefficient, stored as an integer so the enum can derive
        /// `Eq`. NOTE(review): a `u8` cannot hold the customary 0.9 — this is
        /// presumably hundredths (e.g. 90 => 0.90), mirroring the percentage
        /// encoding documented on `AdamW::weight_decay`; confirm against the
        /// optimizer implementation.
        momentum: u8,
    },
    /// Adam optimizer
    #[default]
    Adam,
    /// AdamW with weight decay
    AdamW {
        /// Weight decay coefficient (as percentage, e.g., 1 = 0.01)
        weight_decay: u8,
    },
}
73
/// Learning rate schedule applied over the course of training.
#[derive(Debug, Clone, Copy, PartialEq, Default)]
pub enum LearningRateSchedule {
    /// Constant learning rate for all epochs
    #[default]
    Constant,
    /// Step decay: multiply the learning rate by `decay_factor` every `step_size` steps
    StepDecay {
        /// Multiplicative decay factor applied at each step boundary
        decay_factor: f32,
        /// Steps between decays
        step_size: usize,
    },
    /// Cosine annealing toward a floor learning rate
    CosineAnnealing {
        /// Minimum learning rate (the annealing floor)
        min_lr: f32,
    },
    /// Linear warmup followed by exponential decay
    WarmupDecay {
        /// Warmup epochs
        warmup_epochs: usize,
        /// Decay factor per epoch after warmup
        decay_factor: f32,
    },
}
100
/// Early stopping configuration: halt training when the monitored metric
/// stops improving.
#[derive(Debug, Clone, Copy)]
pub struct EarlyStopping {
    /// Patience (epochs without improvement)
    pub patience: usize,
    /// Minimum delta for improvement
    pub min_delta: f32,
}

impl Default for EarlyStopping {
    fn default() -> Self {
        // Stop after 10 epochs showing less than 1e-3 improvement.
        Self {
            min_delta: 1e-3,
            patience: 10,
        }
    }
}
118
/// Knowledge distillation configuration
#[derive(Debug, Clone)]
pub struct DistillationConfig {
    /// Distillation loss function
    pub loss: DistillationLoss,
    /// Temperature for softening
    pub temperature: Temperature,
    /// Number of training epochs
    pub epochs: usize,
    /// Learning rate
    pub learning_rate: f32,
    /// Batch size
    pub batch_size: usize,
    /// Alpha weight for distillation loss (1 - alpha for hard label loss)
    pub alpha: f32,
    /// Optimizer type
    pub optimizer: OptimizerType,
    /// Learning rate schedule
    pub lr_schedule: LearningRateSchedule,
    /// Early stopping configuration
    pub early_stopping: Option<EarlyStopping>,
    /// Gradient clipping threshold (None = no clipping)
    pub gradient_clip: Option<f32>,
    /// Validation split ratio (0.0 to 0.5, matching `validate()` and the builder's clamp)
    pub validation_split: f32,
    /// Number of classes for classification
    pub num_classes: usize,
    /// Random seed for reproducibility
    pub seed: u64,
}
149
150impl Default for DistillationConfig {
151    fn default() -> Self {
152        Self {
153            loss: DistillationLoss::KLDivergence,
154            temperature: Temperature::default(),
155            epochs: 100,
156            learning_rate: 0.001,
157            batch_size: 32,
158            alpha: 0.5,
159            optimizer: OptimizerType::Adam,
160            lr_schedule: LearningRateSchedule::Constant,
161            early_stopping: Some(EarlyStopping::default()),
162            gradient_clip: Some(1.0),
163            validation_split: 0.1,
164            num_classes: 10,
165            seed: 42,
166        }
167    }
168}
169
170impl DistillationConfig {
171    /// Creates a configuration builder
172    #[must_use]
173    pub fn builder() -> DistillationConfigBuilder {
174        DistillationConfigBuilder::default()
175    }
176
177    /// Validates the configuration
178    pub fn validate(&self) -> Result<()> {
179        if self.alpha < 0.0 || self.alpha > 1.0 {
180            return Err(MlError::InvalidConfig(format!(
181                "Alpha must be between 0.0 and 1.0, got {}",
182                self.alpha
183            )));
184        }
185        if self.learning_rate <= 0.0 {
186            return Err(MlError::InvalidConfig(format!(
187                "Learning rate must be positive, got {}",
188                self.learning_rate
189            )));
190        }
191        if self.epochs == 0 {
192            return Err(MlError::InvalidConfig(
193                "Epochs must be at least 1".to_string(),
194            ));
195        }
196        if self.batch_size == 0 {
197            return Err(MlError::InvalidConfig(
198                "Batch size must be at least 1".to_string(),
199            ));
200        }
201        if self.validation_split < 0.0 || self.validation_split > 0.5 {
202            return Err(MlError::InvalidConfig(format!(
203                "Validation split must be between 0.0 and 0.5, got {}",
204                self.validation_split
205            )));
206        }
207        Ok(())
208    }
209}
210
/// Builder for distillation configuration.
///
/// Each field is `None` until explicitly set; `build()` fills unset fields
/// with defaults. Fields that are themselves optional in
/// [`DistillationConfig`] (`early_stopping`, `gradient_clip`) are wrapped in
/// a second `Option` so that "not configured" (use the default) can be
/// distinguished from "explicitly disabled" (`Some(None)`).
#[derive(Debug, Default)]
pub struct DistillationConfigBuilder {
    loss: Option<DistillationLoss>,
    // Raw f32; converted through `Temperature::new` (which clamps) in `build()`.
    temperature: Option<f32>,
    epochs: Option<usize>,
    learning_rate: Option<f32>,
    batch_size: Option<usize>,
    alpha: Option<f32>,
    optimizer: Option<OptimizerType>,
    lr_schedule: Option<LearningRateSchedule>,
    early_stopping: Option<Option<EarlyStopping>>,
    gradient_clip: Option<Option<f32>>,
    validation_split: Option<f32>,
    num_classes: Option<usize>,
    seed: Option<u64>,
}
228
229impl DistillationConfigBuilder {
230    /// Sets the distillation loss
231    #[must_use]
232    pub fn loss(mut self, loss: DistillationLoss) -> Self {
233        self.loss = Some(loss);
234        self
235    }
236
237    /// Sets the temperature
238    #[must_use]
239    pub fn temperature(mut self, temp: f32) -> Self {
240        self.temperature = Some(temp);
241        self
242    }
243
244    /// Sets the number of epochs
245    #[must_use]
246    pub fn epochs(mut self, epochs: usize) -> Self {
247        self.epochs = Some(epochs);
248        self
249    }
250
251    /// Sets the learning rate
252    #[must_use]
253    pub fn learning_rate(mut self, lr: f32) -> Self {
254        self.learning_rate = Some(lr);
255        self
256    }
257
258    /// Sets the batch size
259    #[must_use]
260    pub fn batch_size(mut self, size: usize) -> Self {
261        self.batch_size = Some(size);
262        self
263    }
264
265    /// Sets the alpha weight for distillation loss
266    #[must_use]
267    pub fn alpha(mut self, alpha: f32) -> Self {
268        self.alpha = Some(alpha.clamp(0.0, 1.0));
269        self
270    }
271
272    /// Sets the optimizer type
273    #[must_use]
274    pub fn optimizer(mut self, optimizer: OptimizerType) -> Self {
275        self.optimizer = Some(optimizer);
276        self
277    }
278
279    /// Sets the learning rate schedule
280    #[must_use]
281    pub fn lr_schedule(mut self, schedule: LearningRateSchedule) -> Self {
282        self.lr_schedule = Some(schedule);
283        self
284    }
285
286    /// Sets early stopping configuration
287    #[must_use]
288    pub fn early_stopping(mut self, early_stopping: Option<EarlyStopping>) -> Self {
289        self.early_stopping = Some(early_stopping);
290        self
291    }
292
293    /// Sets gradient clipping threshold
294    #[must_use]
295    pub fn gradient_clip(mut self, clip: Option<f32>) -> Self {
296        self.gradient_clip = Some(clip);
297        self
298    }
299
300    /// Sets validation split ratio
301    #[must_use]
302    pub fn validation_split(mut self, split: f32) -> Self {
303        self.validation_split = Some(split.clamp(0.0, 0.5));
304        self
305    }
306
307    /// Sets number of classes
308    #[must_use]
309    pub fn num_classes(mut self, num: usize) -> Self {
310        self.num_classes = Some(num);
311        self
312    }
313
314    /// Sets random seed
315    #[must_use]
316    pub fn seed(mut self, seed: u64) -> Self {
317        self.seed = Some(seed);
318        self
319    }
320
321    /// Builds the configuration
322    #[must_use]
323    pub fn build(self) -> DistillationConfig {
324        DistillationConfig {
325            loss: self.loss.unwrap_or(DistillationLoss::KLDivergence),
326            temperature: Temperature::new(self.temperature.unwrap_or(2.0)),
327            epochs: self.epochs.unwrap_or(100),
328            learning_rate: self.learning_rate.unwrap_or(0.001),
329            batch_size: self.batch_size.unwrap_or(32),
330            alpha: self.alpha.unwrap_or(0.5),
331            optimizer: self.optimizer.unwrap_or(OptimizerType::Adam),
332            lr_schedule: self.lr_schedule.unwrap_or(LearningRateSchedule::Constant),
333            early_stopping: self
334                .early_stopping
335                .unwrap_or(Some(EarlyStopping::default())),
336            gradient_clip: self.gradient_clip.unwrap_or(Some(1.0)),
337            validation_split: self.validation_split.unwrap_or(0.1),
338            num_classes: self.num_classes.unwrap_or(10),
339            seed: self.seed.unwrap_or(42),
340        }
341    }
342}
343
#[cfg(test)]
mod tests {
    use super::*;

    /// Tolerance for f32 comparisons.
    const EPS: f32 = 1e-6;

    #[test]
    fn test_distillation_config_builder() {
        let cfg = DistillationConfig::builder()
            .loss(DistillationLoss::MSE)
            .temperature(3.0)
            .epochs(50)
            .learning_rate(0.01)
            .batch_size(64)
            .alpha(0.7)
            .build();

        assert_eq!(cfg.loss, DistillationLoss::MSE);
        assert!((cfg.temperature.value() - 3.0).abs() < EPS);
        assert_eq!(cfg.epochs, 50);
        assert!((cfg.learning_rate - 0.01).abs() < EPS);
        assert_eq!(cfg.batch_size, 64);
        assert!((cfg.alpha - 0.7).abs() < EPS);
    }

    #[test]
    fn test_config_validation() {
        assert!(DistillationConfig::default().validate().is_ok());

        let bad_alpha = DistillationConfig {
            alpha: 1.5,
            ..Default::default()
        };
        assert!(bad_alpha.validate().is_err());

        let bad_lr = DistillationConfig {
            learning_rate: -0.1,
            ..Default::default()
        };
        assert!(bad_lr.validate().is_err());
    }

    #[test]
    fn test_temperature_scaling() {
        let scaled = Temperature::new(2.0).scale_logits(&[1.0, 2.0, 3.0]);
        for (got, want) in scaled.iter().zip([0.5_f32, 1.0, 1.5]) {
            assert!((got - want).abs() < EPS);
        }
    }

    #[test]
    fn test_temperature_minimum() {
        // Values below the floor are clamped up to 0.1.
        assert!(Temperature::new(0.01).value() >= 0.1);
    }
}