// oxigdal_ml/optimization/distillation/config.rs

use crate::error::{MlError, Result};
/// Loss function used to transfer knowledge from the teacher to the student.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)]
pub enum DistillationLoss {
    /// Kullback-Leibler divergence (the default).
    #[default]
    KLDivergence,
    /// Mean squared error.
    MSE,
    /// Cross-entropy.
    CrossEntropy,
    /// Weighted combination of the distillation term and the
    /// ground-truth term.
    Weighted {
        // NOTE(review): integer weights — presumably relative proportions
        // rather than fractions (`f32` fields would break the `Eq` derive
        // above). Confirm the intended scale with the consuming code.
        distill_weight: u8,
        ground_truth_weight: u8,
    },
}
/// Softmax temperature for logit scaling, wrapped as a newtype.
#[derive(Debug, Clone, Copy)]
pub struct Temperature(pub f32);

impl Default for Temperature {
    /// Default temperature of 2.0.
    fn default() -> Self {
        Self(2.0)
    }
}

impl Temperature {
    /// Creates a temperature, flooring the input at 0.1 so that the
    /// divisions in `scale_logits` cannot blow up near zero.
    /// (A NaN input also becomes 0.1.)
    #[must_use]
    pub fn new(value: f32) -> Self {
        let floored = if value > 0.1 { value } else { 0.1 };
        Self(floored)
    }

    /// Divides every logit by the temperature, returning a new vector.
    #[must_use]
    pub fn scale_logits(&self, logits: &[f32]) -> Vec<f32> {
        let mut scaled = Vec::with_capacity(logits.len());
        for &logit in logits {
            scaled.push(logit / self.0);
        }
        scaled
    }

    /// Returns the raw temperature value.
    #[must_use]
    pub fn value(&self) -> f32 {
        self.0
    }
}
/// Optimizer used when training the student model.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)]
pub enum OptimizerType {
    /// Plain stochastic gradient descent.
    SGD,
    /// SGD with momentum.
    SGDMomentum {
        // NOTE(review): momentum as `u8` — momentum is conventionally a
        // float in [0, 1); this is likely stored scaled (e.g. percent) to
        // keep the `Eq` derive. Confirm against the optimizer code.
        momentum: u8,
    },
    /// Adam (the default).
    #[default]
    Adam,
    /// Adam with decoupled weight decay.
    AdamW {
        // NOTE(review): same `u8` scaling question as `momentum` above.
        weight_decay: u8,
    },
}
/// Learning-rate schedule applied over the course of training.
///
/// Derives `PartialEq` but not `Eq` because several variants carry `f32`.
#[derive(Debug, Clone, Copy, PartialEq, Default)]
pub enum LearningRateSchedule {
    /// Fixed learning rate for all epochs (the default).
    #[default]
    Constant,
    /// Multiply the rate by `decay_factor` every `step_size` epochs.
    StepDecay {
        decay_factor: f32,
        step_size: usize,
    },
    /// Cosine annealing down to `min_lr`.
    CosineAnnealing {
        min_lr: f32,
    },
    /// Warm up for `warmup_epochs`, then decay by `decay_factor`.
    WarmupDecay {
        warmup_epochs: usize,
        decay_factor: f32,
    },
}
/// Early-stopping criterion: halt training once the monitored metric
/// stops improving.
#[derive(Debug, Clone, Copy)]
pub struct EarlyStopping {
    /// Epochs without improvement tolerated before stopping.
    pub patience: usize,
    /// Minimum change that counts as an improvement.
    pub min_delta: f32,
}

impl Default for EarlyStopping {
    /// 10 epochs of patience with a 0.001 improvement threshold.
    fn default() -> Self {
        EarlyStopping {
            min_delta: 0.001,
            patience: 10,
        }
    }
}
/// Complete configuration for a knowledge-distillation training run.
///
/// Construct via [`DistillationConfig::builder`] or `Default`, then check
/// with `validate` before use.
#[derive(Debug, Clone)]
pub struct DistillationConfig {
    /// Loss used to compare student and teacher outputs.
    pub loss: DistillationLoss,
    /// Softmax temperature for logit scaling.
    pub temperature: Temperature,
    /// Number of training epochs.
    pub epochs: usize,
    /// Base learning rate.
    pub learning_rate: f32,
    /// Mini-batch size.
    pub batch_size: usize,
    /// Blend factor between loss terms; `validate` requires [0, 1].
    pub alpha: f32,
    /// Optimizer for the student model.
    pub optimizer: OptimizerType,
    /// Learning-rate schedule applied over epochs.
    pub lr_schedule: LearningRateSchedule,
    /// Optional early-stopping criterion; `None` disables it.
    pub early_stopping: Option<EarlyStopping>,
    /// Optional gradient-clipping threshold; `None` disables clipping.
    pub gradient_clip: Option<f32>,
    /// Fraction of data held out for validation; `validate` requires [0, 0.5].
    pub validation_split: f32,
    /// Number of output classes.
    pub num_classes: usize,
    /// RNG seed for reproducibility.
    pub seed: u64,
}

impl Default for DistillationConfig {
    /// Defaults: KL-divergence at temperature 2.0, Adam, constant LR,
    /// 100 epochs, batch size 32, early stopping and clipping enabled.
    fn default() -> Self {
        Self {
            loss: DistillationLoss::KLDivergence,
            temperature: Temperature::default(),
            epochs: 100,
            learning_rate: 0.001,
            batch_size: 32,
            alpha: 0.5,
            optimizer: OptimizerType::Adam,
            lr_schedule: LearningRateSchedule::Constant,
            early_stopping: Some(EarlyStopping::default()),
            gradient_clip: Some(1.0),
            validation_split: 0.1,
            num_classes: 10,
            seed: 42,
        }
    }
}
170impl DistillationConfig {
171 #[must_use]
173 pub fn builder() -> DistillationConfigBuilder {
174 DistillationConfigBuilder::default()
175 }
176
177 pub fn validate(&self) -> Result<()> {
179 if self.alpha < 0.0 || self.alpha > 1.0 {
180 return Err(MlError::InvalidConfig(format!(
181 "Alpha must be between 0.0 and 1.0, got {}",
182 self.alpha
183 )));
184 }
185 if self.learning_rate <= 0.0 {
186 return Err(MlError::InvalidConfig(format!(
187 "Learning rate must be positive, got {}",
188 self.learning_rate
189 )));
190 }
191 if self.epochs == 0 {
192 return Err(MlError::InvalidConfig(
193 "Epochs must be at least 1".to_string(),
194 ));
195 }
196 if self.batch_size == 0 {
197 return Err(MlError::InvalidConfig(
198 "Batch size must be at least 1".to_string(),
199 ));
200 }
201 if self.validation_split < 0.0 || self.validation_split > 0.5 {
202 return Err(MlError::InvalidConfig(format!(
203 "Validation split must be between 0.0 and 0.5, got {}",
204 self.validation_split
205 )));
206 }
207 Ok(())
208 }
209}
210
/// Fluent builder for [`DistillationConfig`]; any field left unset falls
/// back to a default in `build`.
#[derive(Debug, Default)]
pub struct DistillationConfigBuilder {
    loss: Option<DistillationLoss>,
    // Stored as a raw f32; wrapped via `Temperature::new` in `build`.
    temperature: Option<f32>,
    epochs: Option<usize>,
    learning_rate: Option<f32>,
    batch_size: Option<usize>,
    alpha: Option<f32>,
    optimizer: Option<OptimizerType>,
    lr_schedule: Option<LearningRateSchedule>,
    // The nested `Option`s distinguish "never set" (outer `None`, use the
    // default) from an explicit `None` that disables the feature.
    early_stopping: Option<Option<EarlyStopping>>,
    gradient_clip: Option<Option<f32>>,
    validation_split: Option<f32>,
    num_classes: Option<usize>,
    seed: Option<u64>,
}
229impl DistillationConfigBuilder {
230 #[must_use]
232 pub fn loss(mut self, loss: DistillationLoss) -> Self {
233 self.loss = Some(loss);
234 self
235 }
236
237 #[must_use]
239 pub fn temperature(mut self, temp: f32) -> Self {
240 self.temperature = Some(temp);
241 self
242 }
243
244 #[must_use]
246 pub fn epochs(mut self, epochs: usize) -> Self {
247 self.epochs = Some(epochs);
248 self
249 }
250
251 #[must_use]
253 pub fn learning_rate(mut self, lr: f32) -> Self {
254 self.learning_rate = Some(lr);
255 self
256 }
257
258 #[must_use]
260 pub fn batch_size(mut self, size: usize) -> Self {
261 self.batch_size = Some(size);
262 self
263 }
264
265 #[must_use]
267 pub fn alpha(mut self, alpha: f32) -> Self {
268 self.alpha = Some(alpha.clamp(0.0, 1.0));
269 self
270 }
271
272 #[must_use]
274 pub fn optimizer(mut self, optimizer: OptimizerType) -> Self {
275 self.optimizer = Some(optimizer);
276 self
277 }
278
279 #[must_use]
281 pub fn lr_schedule(mut self, schedule: LearningRateSchedule) -> Self {
282 self.lr_schedule = Some(schedule);
283 self
284 }
285
286 #[must_use]
288 pub fn early_stopping(mut self, early_stopping: Option<EarlyStopping>) -> Self {
289 self.early_stopping = Some(early_stopping);
290 self
291 }
292
293 #[must_use]
295 pub fn gradient_clip(mut self, clip: Option<f32>) -> Self {
296 self.gradient_clip = Some(clip);
297 self
298 }
299
300 #[must_use]
302 pub fn validation_split(mut self, split: f32) -> Self {
303 self.validation_split = Some(split.clamp(0.0, 0.5));
304 self
305 }
306
307 #[must_use]
309 pub fn num_classes(mut self, num: usize) -> Self {
310 self.num_classes = Some(num);
311 self
312 }
313
314 #[must_use]
316 pub fn seed(mut self, seed: u64) -> Self {
317 self.seed = Some(seed);
318 self
319 }
320
321 #[must_use]
323 pub fn build(self) -> DistillationConfig {
324 DistillationConfig {
325 loss: self.loss.unwrap_or(DistillationLoss::KLDivergence),
326 temperature: Temperature::new(self.temperature.unwrap_or(2.0)),
327 epochs: self.epochs.unwrap_or(100),
328 learning_rate: self.learning_rate.unwrap_or(0.001),
329 batch_size: self.batch_size.unwrap_or(32),
330 alpha: self.alpha.unwrap_or(0.5),
331 optimizer: self.optimizer.unwrap_or(OptimizerType::Adam),
332 lr_schedule: self.lr_schedule.unwrap_or(LearningRateSchedule::Constant),
333 early_stopping: self
334 .early_stopping
335 .unwrap_or(Some(EarlyStopping::default())),
336 gradient_clip: self.gradient_clip.unwrap_or(Some(1.0)),
337 validation_split: self.validation_split.unwrap_or(0.1),
338 num_classes: self.num_classes.unwrap_or(10),
339 seed: self.seed.unwrap_or(42),
340 }
341 }
342}
343
#[cfg(test)]
mod tests {
    use super::*;

    /// Builder-supplied values should appear verbatim in the built config.
    #[test]
    fn test_distillation_config_builder() {
        let config = DistillationConfig::builder()
            .loss(DistillationLoss::MSE)
            .temperature(3.0)
            .epochs(50)
            .learning_rate(0.01)
            .batch_size(64)
            .alpha(0.7)
            .build();

        assert_eq!(config.loss, DistillationLoss::MSE);
        assert!((config.temperature.0 - 3.0).abs() < 1e-6);
        assert_eq!(config.epochs, 50);
        assert!((config.learning_rate - 0.01).abs() < 1e-6);
        assert_eq!(config.batch_size, 64);
        assert!((config.alpha - 0.7).abs() < 1e-6);
    }

    /// The default config validates; out-of-range alpha and a negative
    /// learning rate are rejected.
    #[test]
    fn test_config_validation() {
        let valid_config = DistillationConfig::default();
        assert!(valid_config.validate().is_ok());

        let invalid_alpha = DistillationConfig {
            alpha: 1.5,
            ..Default::default()
        };
        assert!(invalid_alpha.validate().is_err());

        let invalid_lr = DistillationConfig {
            learning_rate: -0.1,
            ..Default::default()
        };
        assert!(invalid_lr.validate().is_err());
    }

    /// Logits are divided element-wise by the temperature.
    #[test]
    fn test_temperature_scaling() {
        let temp = Temperature::new(2.0);
        let logits = vec![1.0, 2.0, 3.0];
        let scaled = temp.scale_logits(&logits);

        assert!((scaled[0] - 0.5).abs() < 1e-6);
        assert!((scaled[1] - 1.0).abs() < 1e-6);
        assert!((scaled[2] - 1.5).abs() < 1e-6);
    }

    /// `Temperature::new` floors tiny values at the 0.1 minimum.
    #[test]
    fn test_temperature_minimum() {
        let temp = Temperature::new(0.01);
        assert!(temp.0 >= 0.1);
    }
}