// oxigdal_ml/optimization/distillation/optimizer.rs

use super::config::{EarlyStopping, LearningRateSchedule, OptimizerType};

/// Mutable training state carried across epochs: learning-rate bookkeeping,
/// early-stopping counters, optimizer buffers (SGD momentum and Adam
/// moments), and per-epoch loss/accuracy histories.
#[derive(Debug, Clone)]
pub struct TrainingState {
    pub epoch: usize,
    pub batch: usize,
    pub total_batches: usize,
    pub current_lr: f32,
    pub best_val_loss: f32,
    pub epochs_without_improvement: usize,
    pub momentum_buffer: Vec<f32>,
    pub adam_m: Vec<f32>,
    pub adam_v: Vec<f32>,
    pub adam_t: usize,
    pub train_loss_history: Vec<f32>,
    pub val_loss_history: Vec<f32>,
    pub train_acc_history: Vec<f32>,
    pub val_acc_history: Vec<f32>,
}

impl TrainingState {
    /// Creates a fresh state with zeroed optimizer buffers sized for
    /// `num_params` parameters.
    #[must_use]
    pub fn new(num_params: usize, initial_lr: f32) -> Self {
        Self {
            epoch: 0,
            batch: 0,
            total_batches: 0,
            current_lr: initial_lr,
            best_val_loss: f32::MAX,
            epochs_without_improvement: 0,
            momentum_buffer: vec![0.0; num_params],
            adam_m: vec![0.0; num_params],
            adam_v: vec![0.0; num_params],
            adam_t: 0,
            train_loss_history: Vec::new(),
            val_loss_history: Vec::new(),
            train_acc_history: Vec::new(),
            val_acc_history: Vec::new(),
        }
    }

    /// Recomputes `current_lr` for the current epoch under `schedule`.
    pub fn update_learning_rate(
        &mut self,
        base_lr: f32,
        schedule: &LearningRateSchedule,
        total_epochs: usize,
    ) {
        self.current_lr = match schedule {
            LearningRateSchedule::Constant => base_lr,
            LearningRateSchedule::StepDecay {
                decay_factor,
                step_size,
            } => {
                // Multiply by `decay_factor` once every `step_size` epochs.
                let num_decays = self.epoch / step_size;
                base_lr * decay_factor.powi(num_decays as i32)
            }
            LearningRateSchedule::CosineAnnealing { min_lr } => {
                // Cosine curve from `base_lr` down to `min_lr` over the run.
                let progress = self.epoch as f32 / total_epochs as f32;
                let cos_value = (std::f32::consts::PI * progress).cos();
                min_lr + (base_lr - min_lr) * (1.0 + cos_value) / 2.0
            }
            LearningRateSchedule::WarmupDecay {
                warmup_epochs,
                decay_factor,
            } => {
                if self.epoch < *warmup_epochs {
                    // Linear warmup: ramps from base_lr / warmup_epochs up to base_lr.
                    base_lr * (self.epoch + 1) as f32 / *warmup_epochs as f32
                } else {
                    // Exponential decay after warmup completes.
                    let epochs_after_warmup = self.epoch - warmup_epochs;
                    base_lr * decay_factor.powi(epochs_after_warmup as i32)
                }
            }
        };
    }

    /// Returns `true` once validation loss has failed to improve for
    /// `patience` consecutive epochs; always `false` when early stopping
    /// is disabled.
    pub fn should_stop(&self, config: &Option<EarlyStopping>) -> bool {
        if let Some(es) = config {
            self.epochs_without_improvement >= es.patience
        } else {
            false
        }
    }

    /// Records one epoch's validation loss: an improvement of at least
    /// `min_delta` resets the patience counter, anything else increments it.
    pub fn update_early_stopping(&mut self, val_loss: f32, config: &Option<EarlyStopping>) {
        if let Some(es) = config {
            if val_loss < self.best_val_loss - es.min_delta {
                self.best_val_loss = val_loss;
                self.epochs_without_improvement = 0;
            } else {
                self.epochs_without_improvement += 1;
            }
        }
    }
}

/// Vanilla SGD: `w <- w - lr * g`.
pub fn sgd_update(weights: &mut [f32], gradients: &[f32], lr: f32) {
    for (w, g) in weights.iter_mut().zip(gradients.iter()) {
        *w -= lr * g;
    }
}

/// SGD with classic (heavy-ball) momentum: the buffer accumulates
/// `m <- momentum * m + g` and the step is taken along the buffer.
pub fn sgd_momentum_update(
    weights: &mut [f32],
    gradients: &[f32],
    momentum_buffer: &mut [f32],
    lr: f32,
    momentum: f32,
) {
    for ((w, g), m) in weights
        .iter_mut()
        .zip(gradients.iter())
        .zip(momentum_buffer.iter_mut())
    {
        *m = momentum * *m + *g;
        *w -= lr * *m;
    }
}

/// Hyperparameters for Adam-style updates, defaulting to the usual values
/// (lr = 1e-3, beta1 = 0.9, beta2 = 0.999, epsilon = 1e-8).
#[derive(Debug, Clone, Copy)]
pub struct AdamParams {
    pub lr: f32,
    pub beta1: f32,
    pub beta2: f32,
    pub epsilon: f32,
}

impl Default for AdamParams {
    fn default() -> Self {
        Self {
            lr: 0.001,
            beta1: 0.9,
            beta2: 0.999,
            epsilon: 1e-8,
        }
    }
}

/// One Adam step with bias correction; `t` is the 1-based step count.
#[allow(clippy::too_many_arguments)]
pub fn adam_update(
    weights: &mut [f32],
    gradients: &[f32],
    m: &mut [f32],
    v: &mut [f32],
    t: usize,
    lr: f32,
    beta1: f32,
    beta2: f32,
    epsilon: f32,
) {
    let bias_correction1 = 1.0 - beta1.powi(t as i32);
    let bias_correction2 = 1.0 - beta2.powi(t as i32);

    for i in 0..weights.len() {
        // Exponential moving averages of the gradient and its square.
        m[i] = beta1 * m[i] + (1.0 - beta1) * gradients[i];
        v[i] = beta2 * v[i] + (1.0 - beta2) * gradients[i].powi(2);

        // Bias-corrected moment estimates.
        let m_hat = m[i] / bias_correction1;
        let v_hat = v[i] / bias_correction2;

        weights[i] -= lr * m_hat / (v_hat.sqrt() + epsilon);
    }
}

/// One AdamW step: identical to Adam except that weight decay is decoupled
/// from the gradient and applied directly to the weights.
#[allow(clippy::too_many_arguments)]
pub fn adamw_update(
    weights: &mut [f32],
    gradients: &[f32],
    m: &mut [f32],
    v: &mut [f32],
    t: usize,
    lr: f32,
    beta1: f32,
    beta2: f32,
    epsilon: f32,
    weight_decay: f32,
) {
    let bias_correction1 = 1.0 - beta1.powi(t as i32);
    let bias_correction2 = 1.0 - beta2.powi(t as i32);

    for i in 0..weights.len() {
        // Decoupled weight decay, applied before the Adam step.
        weights[i] -= lr * weight_decay * weights[i];

        m[i] = beta1 * m[i] + (1.0 - beta1) * gradients[i];
        v[i] = beta2 * v[i] + (1.0 - beta2) * gradients[i].powi(2);

        let m_hat = m[i] / bias_correction1;
        let v_hat = v[i] / bias_correction2;

        weights[i] -= lr * m_hat / (v_hat.sqrt() + epsilon);
    }
}

/// Global-norm gradient clipping: if the L2 norm of `gradients` exceeds
/// `max_norm`, every component is scaled down so the norm equals `max_norm`
/// (up to the small epsilon guarding against division by zero).
pub fn clip_gradients(gradients: &mut [f32], max_norm: f32) {
    let total_norm: f32 = gradients.iter().map(|g| g.powi(2)).sum::<f32>().sqrt();

    if total_norm > max_norm {
        let scale = max_norm / (total_norm + 1e-6);
        for g in gradients.iter_mut() {
            *g *= scale;
        }
    }
}

/// Dispatches one parameter update to the configured optimizer, using the
/// buffers held in `state`. `momentum` and `weight_decay` appear to be
/// stored in the config as integer percentages, hence the division by 100.
pub fn apply_optimizer_update(
    params: &mut [f32],
    gradients: &[f32],
    state: &mut TrainingState,
    optimizer: &OptimizerType,
) {
    match optimizer {
        OptimizerType::SGD => {
            sgd_update(params, gradients, state.current_lr);
        }
        OptimizerType::SGDMomentum { momentum } => {
            let momentum_f = *momentum as f32 / 100.0;
            sgd_momentum_update(
                params,
                gradients,
                &mut state.momentum_buffer,
                state.current_lr,
                momentum_f,
            );
        }
        OptimizerType::Adam => {
            state.adam_t += 1;
            adam_update(
                params,
                gradients,
                &mut state.adam_m,
                &mut state.adam_v,
                state.adam_t,
                state.current_lr,
                0.9,
                0.999,
                1e-8,
            );
        }
        OptimizerType::AdamW { weight_decay } => {
            state.adam_t += 1;
            let wd = *weight_decay as f32 / 100.0;
            adamw_update(
                params,
                gradients,
                &mut state.adam_m,
                &mut state.adam_v,
                state.adam_t,
                state.current_lr,
                0.9,
                0.999,
                1e-8,
                wd,
            );
        }
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_gradient_clipping() {
        let mut grads = vec![10.0, 20.0, 30.0];
        clip_gradients(&mut grads, 1.0);

        let norm: f32 = grads.iter().map(|g| g.powi(2)).sum::<f32>().sqrt();
        assert!(norm <= 1.0 + 1e-6);
    }
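
    // Companion check (added illustration): gradients already inside the
    // norm budget should pass through clip_gradients unchanged.
    #[test]
    fn test_gradient_clipping_noop_below_max_norm() {
        let mut grads = vec![0.1, 0.2];
        clip_gradients(&mut grads, 1.0); // norm ~= 0.224, below max_norm

        assert!((grads[0] - 0.1).abs() < 1e-6);
        assert!((grads[1] - 0.2).abs() < 1e-6);
    }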

    #[test]
    fn test_optimizer_sgd() {
        let mut weights = vec![1.0, 2.0, 3.0];
        let gradients = vec![0.1, 0.2, 0.3];

        sgd_update(&mut weights, &gradients, 0.1);

        assert!((weights[0] - 0.99).abs() < 1e-6);
        assert!((weights[1] - 1.98).abs() < 1e-6);
        assert!((weights[2] - 2.97).abs() < 1e-6);
    }
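
    // Added sketch: two successive momentum steps, checked against the
    // hand-computed recurrence m <- 0.9 * m + g, w <- w - lr * m.
    #[test]
    fn test_optimizer_sgd_momentum() {
        let mut weights = vec![1.0];
        let gradients = vec![1.0];
        let mut buffer = vec![0.0];

        // Step 1: m = 1.0, w = 1.0 - 0.1 * 1.0 = 0.9.
        sgd_momentum_update(&mut weights, &gradients, &mut buffer, 0.1, 0.9);
        assert!((weights[0] - 0.9).abs() < 1e-6);

        // Step 2: m = 0.9 * 1.0 + 1.0 = 1.9, w = 0.9 - 0.19 = 0.71.
        sgd_momentum_update(&mut weights, &gradients, &mut buffer, 0.1, 0.9);
        assert!((weights[0] - 0.71).abs() < 1e-6);
    }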

    #[test]
    fn test_optimizer_adam() {
        let mut weights = vec![1.0, 2.0, 3.0];
        let gradients = vec![0.1, 0.2, 0.3];
        let mut m = vec![0.0; 3];
        let mut v = vec![0.0; 3];

        adam_update(
            &mut weights,
            &gradients,
            &mut m,
            &mut v,
            1,
            0.001,
            0.9,
            0.999,
            1e-8,
        );

        assert!(weights[0] < 1.0);
        assert!(weights[1] < 2.0);
        assert!(weights[2] < 3.0);
    }
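
    // Added sketch: with a zero gradient the Adam term vanishes, so one
    // adamw_update step should show only the decoupled weight decay,
    // w <- w - lr * wd * w = 1.0 - 0.1 * 0.01 * 1.0 = 0.999.
    #[test]
    fn test_optimizer_adamw_decays_weights() {
        let mut weights = vec![1.0];
        let gradients = vec![0.0];
        let mut m = vec![0.0];
        let mut v = vec![0.0];

        adamw_update(
            &mut weights,
            &gradients,
            &mut m,
            &mut v,
            1,
            0.1,
            0.9,
            0.999,
            1e-8,
            0.01,
        );

        assert!((weights[0] - 0.999).abs() < 1e-6);
    }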

    #[test]
    fn test_training_state_lr_schedule() {
        let mut state = TrainingState::new(100, 0.1);

        state.epoch = 50;
        state.update_learning_rate(0.1, &LearningRateSchedule::Constant, 100);
        assert!((state.current_lr - 0.1).abs() < 1e-6);

        // Five decays of 0.5 by epoch 50 with step_size 10: 0.1 * 0.5^5.
        state.update_learning_rate(
            0.1,
            &LearningRateSchedule::StepDecay {
                decay_factor: 0.5,
                step_size: 10,
            },
            100,
        );
        assert!((state.current_lr - 0.003125).abs() < 1e-6);

        state.update_learning_rate(
            0.1,
            &LearningRateSchedule::CosineAnnealing { min_lr: 0.0 },
            100,
        );
        assert!(state.current_lr > 0.0 && state.current_lr < 0.1);
    }
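
    // Added sketch covering the remaining schedule: linear warmup for the
    // first warmup_epochs, exponential decay afterwards.
    #[test]
    fn test_training_state_lr_warmup_decay() {
        let mut state = TrainingState::new(100, 0.1);
        let schedule = LearningRateSchedule::WarmupDecay {
            warmup_epochs: 5,
            decay_factor: 0.9,
        };

        // Epoch 0: first warmup step, lr = 0.1 * 1 / 5 = 0.02.
        state.epoch = 0;
        state.update_learning_rate(0.1, &schedule, 100);
        assert!((state.current_lr - 0.02).abs() < 1e-6);

        // Epoch 5: warmup over, zero decay steps taken, lr = 0.1.
        state.epoch = 5;
        state.update_learning_rate(0.1, &schedule, 100);
        assert!((state.current_lr - 0.1).abs() < 1e-6);
    }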

    #[test]
    fn test_early_stopping() {
        let mut state = TrainingState::new(100, 0.1);
        let early_stopping = Some(EarlyStopping {
            patience: 3,
            min_delta: 0.01,
        });

        assert!(!state.should_stop(&early_stopping));

        // The first observation always improves on f32::MAX.
        state.update_early_stopping(1.0, &early_stopping);
        assert_eq!(state.epochs_without_improvement, 0);

        state.update_early_stopping(1.0, &early_stopping);
        assert_eq!(state.epochs_without_improvement, 1);

        // 0.995 beats 1.0, but not by min_delta, so it does not count.
        state.update_early_stopping(0.995, &early_stopping);
        assert_eq!(state.epochs_without_improvement, 2);

        state.update_early_stopping(1.0, &early_stopping);
        assert_eq!(state.epochs_without_improvement, 3);

        assert!(state.should_stop(&early_stopping));
    }
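
    // Added end-to-end sketch: dispatching through apply_optimizer_update
    // with the Adam variant should bump the shared step counter and move
    // the parameters against the gradient.
    #[test]
    fn test_apply_optimizer_update_adam() {
        let mut state = TrainingState::new(3, 0.001);
        let mut params = vec![1.0, 2.0, 3.0];
        let gradients = vec![0.1, 0.2, 0.3];

        apply_optimizer_update(&mut params, &gradients, &mut state, &OptimizerType::Adam);

        assert_eq!(state.adam_t, 1);
        assert!(params[0] < 1.0 && params[1] < 2.0 && params[2] < 3.0);
    }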
}