oxigdal_ml/optimization/distillation/optimizer.rs

//! Optimizer implementations for knowledge distillation

use super::config::{EarlyStopping, LearningRateSchedule, OptimizerType};

/// Training state for tracking optimizer momentum and history
#[derive(Debug, Clone)]
pub struct TrainingState {
    /// Current epoch
    pub epoch: usize,
    /// Current batch within epoch
    pub batch: usize,
    /// Total batches processed
    pub total_batches: usize,
    /// Current learning rate
    pub current_lr: f32,
    /// Best validation loss seen
    pub best_val_loss: f32,
    /// Epochs since improvement (for early stopping)
    pub epochs_without_improvement: usize,
    /// Momentum buffer for SGD with momentum
    pub momentum_buffer: Vec<f32>,
    /// First moment estimate for Adam (m)
    pub adam_m: Vec<f32>,
    /// Second moment estimate for Adam (v)
    pub adam_v: Vec<f32>,
    /// Adam timestep
    pub adam_t: usize,
    /// Training loss history
    pub train_loss_history: Vec<f32>,
    /// Validation loss history
    pub val_loss_history: Vec<f32>,
    /// Training accuracy history
    pub train_acc_history: Vec<f32>,
    /// Validation accuracy history
    pub val_acc_history: Vec<f32>,
}

impl TrainingState {
    /// Creates a new training state
    #[must_use]
    pub fn new(num_params: usize, initial_lr: f32) -> Self {
        Self {
            epoch: 0,
            batch: 0,
            total_batches: 0,
            current_lr: initial_lr,
            best_val_loss: f32::MAX,
            epochs_without_improvement: 0,
            momentum_buffer: vec![0.0; num_params],
            adam_m: vec![0.0; num_params],
            adam_v: vec![0.0; num_params],
            adam_t: 0,
            train_loss_history: Vec::new(),
            val_loss_history: Vec::new(),
            train_acc_history: Vec::new(),
            val_acc_history: Vec::new(),
        }
    }

    /// Updates learning rate based on schedule
    pub fn update_learning_rate(
        &mut self,
        base_lr: f32,
        schedule: &LearningRateSchedule,
        total_epochs: usize,
    ) {
        self.current_lr = match schedule {
            LearningRateSchedule::Constant => base_lr,
            LearningRateSchedule::StepDecay {
                decay_factor,
                step_size,
            } => {
                let num_decays = self.epoch / step_size;
                base_lr * decay_factor.powi(num_decays as i32)
            }
            LearningRateSchedule::CosineAnnealing { min_lr } => {
                let progress = self.epoch as f32 / total_epochs as f32;
                let cos_value = (std::f32::consts::PI * progress).cos();
                min_lr + (base_lr - min_lr) * (1.0 + cos_value) / 2.0
            }
            LearningRateSchedule::WarmupDecay {
                warmup_epochs,
                decay_factor,
            } => {
                if self.epoch < *warmup_epochs {
                    base_lr * (self.epoch + 1) as f32 / *warmup_epochs as f32
                } else {
                    let epochs_after_warmup = self.epoch - warmup_epochs;
                    base_lr * decay_factor.powi(epochs_after_warmup as i32)
                }
            }
        };
    }

    /// Checks if early stopping should trigger
    pub fn should_stop(&self, config: &Option<EarlyStopping>) -> bool {
        if let Some(es) = config {
            self.epochs_without_improvement >= es.patience
        } else {
            false
        }
    }

    /// Updates early stopping state based on validation loss
    pub fn update_early_stopping(&mut self, val_loss: f32, config: &Option<EarlyStopping>) {
        if let Some(es) = config {
            if val_loss < self.best_val_loss - es.min_delta {
                self.best_val_loss = val_loss;
                self.epochs_without_improvement = 0;
            } else {
                self.epochs_without_improvement += 1;
            }
        }
    }
}

/// Applies SGD update to weights
pub fn sgd_update(weights: &mut [f32], gradients: &[f32], lr: f32) {
    for (w, g) in weights.iter_mut().zip(gradients.iter()) {
        *w -= lr * g;
    }
}

/// Applies SGD with momentum update to weights
pub fn sgd_momentum_update(
    weights: &mut [f32],
    gradients: &[f32],
    momentum_buffer: &mut [f32],
    lr: f32,
    momentum: f32,
) {
    for ((w, g), m) in weights
        .iter_mut()
        .zip(gradients.iter())
        .zip(momentum_buffer.iter_mut())
    {
        *m = momentum * *m + g;
        *w -= lr * *m;
    }
}

/// Adam optimizer parameters
#[derive(Debug, Clone, Copy)]
pub struct AdamParams {
    /// Learning rate
    pub lr: f32,
    /// Beta1 parameter (first moment decay)
    pub beta1: f32,
    /// Beta2 parameter (second moment decay)
    pub beta2: f32,
    /// Epsilon for numerical stability
    pub epsilon: f32,
}

impl Default for AdamParams {
    fn default() -> Self {
        Self {
            lr: 0.001,
            beta1: 0.9,
            beta2: 0.999,
            epsilon: 1e-8,
        }
    }
}

/// Applies Adam optimizer update to weights
#[allow(clippy::too_many_arguments)]
pub fn adam_update(
    weights: &mut [f32],
    gradients: &[f32],
    m: &mut [f32],
    v: &mut [f32],
    t: usize,
    lr: f32,
    beta1: f32,
    beta2: f32,
    epsilon: f32,
) {
    let bias_correction1 = 1.0 - beta1.powi(t as i32);
    let bias_correction2 = 1.0 - beta2.powi(t as i32);

    for i in 0..weights.len() {
        // Update biased first moment estimate
        m[i] = beta1 * m[i] + (1.0 - beta1) * gradients[i];
        // Update biased second raw moment estimate
        v[i] = beta2 * v[i] + (1.0 - beta2) * gradients[i].powi(2);

        // Compute bias-corrected estimates
        let m_hat = m[i] / bias_correction1;
        let v_hat = v[i] / bias_correction2;

        // Update weights
        weights[i] -= lr * m_hat / (v_hat.sqrt() + epsilon);
    }
}

/// Applies AdamW optimizer update with decoupled weight decay
#[allow(clippy::too_many_arguments)]
pub fn adamw_update(
    weights: &mut [f32],
    gradients: &[f32],
    m: &mut [f32],
    v: &mut [f32],
    t: usize,
    lr: f32,
    beta1: f32,
    beta2: f32,
    epsilon: f32,
    weight_decay: f32,
) {
    let bias_correction1 = 1.0 - beta1.powi(t as i32);
    let bias_correction2 = 1.0 - beta2.powi(t as i32);

    for i in 0..weights.len() {
        // Decoupled weight decay
        weights[i] -= lr * weight_decay * weights[i];

        // Update biased first moment estimate
        m[i] = beta1 * m[i] + (1.0 - beta1) * gradients[i];
        // Update biased second raw moment estimate
        v[i] = beta2 * v[i] + (1.0 - beta2) * gradients[i].powi(2);

        // Compute bias-corrected estimates
        let m_hat = m[i] / bias_correction1;
        let v_hat = v[i] / bias_correction2;

        // Update weights
        weights[i] -= lr * m_hat / (v_hat.sqrt() + epsilon);
    }
}

/// Clips gradients by global norm
pub fn clip_gradients(gradients: &mut [f32], max_norm: f32) {
    let total_norm: f32 = gradients.iter().map(|g| g.powi(2)).sum::<f32>().sqrt();

    if total_norm > max_norm {
        let scale = max_norm / (total_norm + 1e-6);
        for g in gradients.iter_mut() {
            *g *= scale;
        }
    }
}

/// Applies optimizer update based on configuration
pub fn apply_optimizer_update(
    params: &mut [f32],
    gradients: &[f32],
    state: &mut TrainingState,
    optimizer: &OptimizerType,
) {
    match optimizer {
        OptimizerType::SGD => {
            sgd_update(params, gradients, state.current_lr);
        }
        OptimizerType::SGDMomentum { momentum } => {
            let momentum_f = *momentum as f32 / 100.0;
            sgd_momentum_update(
                params,
                gradients,
                &mut state.momentum_buffer,
                state.current_lr,
                momentum_f,
            );
        }
        OptimizerType::Adam => {
            state.adam_t += 1;
            adam_update(
                params,
                gradients,
                &mut state.adam_m,
                &mut state.adam_v,
                state.adam_t,
                state.current_lr,
                0.9,
                0.999,
                1e-8,
            );
        }
        OptimizerType::AdamW { weight_decay } => {
            state.adam_t += 1;
            let wd = *weight_decay as f32 / 100.0;
            adamw_update(
                params,
                gradients,
                &mut state.adam_m,
                &mut state.adam_v,
                state.adam_t,
                state.current_lr,
                0.9,
                0.999,
                1e-8,
                wd,
            );
        }
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_gradient_clipping() {
        let mut grads = vec![10.0, 20.0, 30.0];
        clip_gradients(&mut grads, 1.0);

        let norm: f32 = grads.iter().map(|g| g.powi(2)).sum::<f32>().sqrt();
        assert!(norm <= 1.0 + 1e-6);
    }
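
    // Minimal sketch of clip_gradients behaviour, derived from the function above:
    // vectors already within max_norm are left untouched, and clipped vectors keep
    // their direction because every component is rescaled by the same factor.
    #[test]
    fn test_gradient_clipping_sketch() {
        let mut grads = vec![3.0_f32, 4.0];

        // Norm is 5.0, below max_norm = 10.0, so nothing changes.
        clip_gradients(&mut grads, 10.0);
        assert_eq!(grads, vec![3.0_f32, 4.0]);

        // Clipping to 1.0 rescales both components by the same factor.
        clip_gradients(&mut grads, 1.0);
        let norm: f32 = grads.iter().map(|g| g.powi(2)).sum::<f32>().sqrt();
        assert!((norm - 1.0).abs() < 1e-4);
        assert!((grads[1] / grads[0] - 4.0 / 3.0).abs() < 1e-4);
    }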

    #[test]
    fn test_optimizer_sgd() {
        let mut weights = vec![1.0, 2.0, 3.0];
        let gradients = vec![0.1, 0.2, 0.3];

        sgd_update(&mut weights, &gradients, 0.1);

        assert!((weights[0] - 0.99).abs() < 1e-6);
        assert!((weights[1] - 1.98).abs() < 1e-6);
        assert!((weights[2] - 2.97).abs() < 1e-6);
    }
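
    // Minimal sketch of sgd_momentum_update: with a constant gradient the velocity
    // buffer accumulates as v = momentum * v + g, and the weight moves by lr * v
    // each step.
    #[test]
    fn test_optimizer_sgd_momentum_sketch() {
        let mut weights = vec![1.0_f32];
        let gradients = vec![0.5_f32];
        let mut buffer = vec![0.0_f32];

        // First step: v = 0.5, w = 1.0 - 0.1 * 0.5 = 0.95
        sgd_momentum_update(&mut weights, &gradients, &mut buffer, 0.1, 0.9);
        assert!((buffer[0] - 0.5).abs() < 1e-6);
        assert!((weights[0] - 0.95).abs() < 1e-6);

        // Second step: v = 0.9 * 0.5 + 0.5 = 0.95, w = 0.95 - 0.1 * 0.95 = 0.855
        sgd_momentum_update(&mut weights, &gradients, &mut buffer, 0.1, 0.9);
        assert!((buffer[0] - 0.95).abs() < 1e-6);
        assert!((weights[0] - 0.855).abs() < 1e-6);
    }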

    #[test]
    fn test_optimizer_adam() {
        let mut weights = vec![1.0, 2.0, 3.0];
        let gradients = vec![0.1, 0.2, 0.3];
        let mut m = vec![0.0; 3];
        let mut v = vec![0.0; 3];

        adam_update(
            &mut weights,
            &gradients,
            &mut m,
            &mut v,
            1,
            0.001,
            0.9,
            0.999,
            1e-8,
        );

        assert!(weights[0] < 1.0);
        assert!(weights[1] < 2.0);
        assert!(weights[2] < 3.0);
    }
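
    // Minimal sketch of adamw_update's decoupled weight decay: with zero gradients
    // the Adam step contributes nothing, but the weight still shrinks by
    // lr * weight_decay * w.
    #[test]
    fn test_optimizer_adamw_weight_decay_sketch() {
        let mut weights = vec![1.0_f32];
        let gradients = vec![0.0_f32];
        let mut m = vec![0.0_f32];
        let mut v = vec![0.0_f32];

        adamw_update(
            &mut weights,
            &gradients,
            &mut m,
            &mut v,
            1,
            0.1,
            0.9,
            0.999,
            1e-8,
            0.1,
        );

        // 1.0 - 0.1 * 0.1 * 1.0 = 0.99
        assert!((weights[0] - 0.99).abs() < 1e-6);
    }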

    #[test]
    fn test_training_state_lr_schedule() {
        let mut state = TrainingState::new(100, 0.1);

        state.epoch = 50;
        state.update_learning_rate(0.1, &LearningRateSchedule::Constant, 100);
        assert!((state.current_lr - 0.1).abs() < 1e-6);

        state.update_learning_rate(
            0.1,
            &LearningRateSchedule::StepDecay {
                decay_factor: 0.5,
                step_size: 10,
            },
            100,
        );
        assert!((state.current_lr - 0.003125).abs() < 1e-6);

        state.epoch = 50;
        state.update_learning_rate(
            0.1,
            &LearningRateSchedule::CosineAnnealing { min_lr: 0.0 },
            100,
        );
        assert!(state.current_lr > 0.0 && state.current_lr < 0.1);
    }
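
    // Minimal sketch of the WarmupDecay schedule branch: the learning rate ramps up
    // linearly over the warmup epochs, then decays geometrically by decay_factor per
    // epoch after warmup.
    #[test]
    fn test_warmup_decay_schedule_sketch() {
        let mut state = TrainingState::new(10, 0.1);
        let schedule = LearningRateSchedule::WarmupDecay {
            warmup_epochs: 5,
            decay_factor: 0.5,
        };

        // Epoch 2 (third warmup epoch): 0.1 * 3 / 5 = 0.06
        state.epoch = 2;
        state.update_learning_rate(0.1, &schedule, 100);
        assert!((state.current_lr - 0.06).abs() < 1e-6);

        // Epoch 7 (two epochs past warmup): 0.1 * 0.5^2 = 0.025
        state.epoch = 7;
        state.update_learning_rate(0.1, &schedule, 100);
        assert!((state.current_lr - 0.025).abs() < 1e-6);
    }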

    #[test]
    fn test_early_stopping() {
        let mut state = TrainingState::new(100, 0.1);
        let early_stopping = Some(EarlyStopping {
            patience: 3,
            min_delta: 0.01,
        });

        assert!(!state.should_stop(&early_stopping));

        state.update_early_stopping(1.0, &early_stopping);
        assert_eq!(state.epochs_without_improvement, 0);

        state.update_early_stopping(1.0, &early_stopping);
        assert_eq!(state.epochs_without_improvement, 1);

        state.update_early_stopping(0.995, &early_stopping);
        assert_eq!(state.epochs_without_improvement, 2);

        state.update_early_stopping(1.0, &early_stopping);
        assert_eq!(state.epochs_without_improvement, 3);

        assert!(state.should_stop(&early_stopping));
    }
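
    // Minimal sketch of the apply_optimizer_update dispatcher: the Adam branch
    // advances state.adam_t and updates the moment buffers stored in TrainingState,
    // so the same state can be reused across steps.
    #[test]
    fn test_apply_optimizer_update_adam_sketch() {
        let mut params = vec![1.0_f32, -1.0];
        let gradients = vec![0.5_f32, -0.5];
        let mut state = TrainingState::new(params.len(), 0.001);

        apply_optimizer_update(&mut params, &gradients, &mut state, &OptimizerType::Adam);

        assert_eq!(state.adam_t, 1);
        // A positive gradient pushes the weight down, a negative one pushes it up.
        assert!(params[0] < 1.0);
        assert!(params[1] > -1.0);
        // Moment buffers now hold non-zero estimates.
        assert!(state.adam_m[0].abs() > 0.0);
        assert!(state.adam_v[0] > 0.0);
    }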
}