1use axonml_tensor::Tensor;
19
20use axonml_nn::Parameter;
21
/// Hyperparameters and bookkeeping options for a training run.
///
/// Build via `TrainingConfig::default()` or the chainable builder methods,
/// e.g. `TrainingConfig::new().epochs(20).batch_size(64)`.
#[derive(Debug, Clone)]
pub struct TrainingConfig {
    /// Number of passes over the training set.
    pub epochs: usize,
    /// Number of samples per batch.
    pub batch_size: usize,
    /// Base learning rate.
    pub learning_rate: f32,
    /// Maximum global gradient L2 norm; `None` disables clipping.
    pub gradient_clip_norm: Option<f32>,
    /// Steps to accumulate gradients over before an update
    /// (the builder clamps this to at least 1).
    pub gradient_accumulation_steps: usize,
    /// Emit a log line every this many steps.
    pub log_every: usize,
    /// Validation frequency (default 1 — presumably once per epoch;
    /// confirm against the trainer that consumes this config).
    pub eval_every: usize,
    /// Whether to write checkpoints during training.
    pub save_checkpoints: bool,
    /// Directory where checkpoints are written.
    pub checkpoint_dir: String,
    /// Whether to enable mixed-precision training.
    pub mixed_precision: bool,
    /// Optional RNG seed; `None` leaves seeding to the caller.
    pub seed: Option<u64>,
}
52
53impl Default for TrainingConfig {
54 fn default() -> Self {
55 Self {
56 epochs: 10,
57 batch_size: 32,
58 learning_rate: 1e-3,
59 gradient_clip_norm: None,
60 gradient_accumulation_steps: 1,
61 log_every: 100,
62 eval_every: 1,
63 save_checkpoints: false,
64 checkpoint_dir: "checkpoints".to_string(),
65 mixed_precision: false,
66 seed: None,
67 }
68 }
69}
70
71impl TrainingConfig {
72 pub fn new() -> Self {
74 Self::default()
75 }
76
77 pub fn epochs(mut self, epochs: usize) -> Self {
79 self.epochs = epochs;
80 self
81 }
82
83 pub fn batch_size(mut self, batch_size: usize) -> Self {
85 self.batch_size = batch_size;
86 self
87 }
88
89 pub fn learning_rate(mut self, lr: f32) -> Self {
91 self.learning_rate = lr;
92 self
93 }
94
95 pub fn gradient_clip_norm(mut self, max_norm: f32) -> Self {
97 self.gradient_clip_norm = Some(max_norm);
98 self
99 }
100
101 pub fn gradient_accumulation_steps(mut self, steps: usize) -> Self {
103 self.gradient_accumulation_steps = steps.max(1);
104 self
105 }
106
107 pub fn log_every(mut self, steps: usize) -> Self {
109 self.log_every = steps;
110 self
111 }
112
113 pub fn mixed_precision(mut self, enabled: bool) -> Self {
115 self.mixed_precision = enabled;
116 self
117 }
118
119 pub fn seed(mut self, seed: u64) -> Self {
121 self.seed = Some(seed);
122 self
123 }
124}
125
/// Mutable bookkeeping tracked over the course of a training run.
#[derive(Debug, Clone)]
pub struct TrainingState {
    /// Zero-based index of the current epoch (`current_epoch()` is 1-based).
    pub epoch: usize,
    /// Total optimization steps taken across all epochs.
    pub global_step: usize,
    /// Best metric value seen so far; initialized to +infinity, so the
    /// convention here is lower-is-better.
    pub best_metric: f32,
    /// Recorded training losses (averaged by `avg_train_loss`).
    pub train_losses: Vec<f32>,
    /// Recorded validation losses; the last entry is what
    /// `EarlyStopping` inspects.
    pub val_losses: Vec<f32>,
    /// Learning-rate values recorded over time.
    pub lr_history: Vec<f32>,
}
146
147impl Default for TrainingState {
148 fn default() -> Self {
149 Self {
150 epoch: 0,
151 global_step: 0,
152 best_metric: f32::INFINITY,
153 train_losses: Vec::new(),
154 val_losses: Vec::new(),
155 lr_history: Vec::new(),
156 }
157 }
158}
159
160impl TrainingState {
161 pub fn new() -> Self {
163 Self::default()
164 }
165
166 pub fn current_epoch(&self) -> usize {
168 self.epoch + 1
169 }
170
171 pub fn avg_train_loss(&self) -> f32 {
173 if self.train_losses.is_empty() {
174 0.0
175 } else {
176 self.train_losses.iter().sum::<f32>() / self.train_losses.len() as f32
177 }
178 }
179
180 pub fn last_val_loss(&self) -> Option<f32> {
182 self.val_losses.last().copied()
183 }
184}
185
/// Metrics produced for a single step or evaluation pass.
#[derive(Debug, Clone)]
pub struct TrainingMetrics {
    /// Loss value (always present).
    pub loss: f32,
    /// Accuracy as a fraction in [0, 1], when computed
    /// (`ProgressLogger` prints it as a percentage).
    pub accuracy: Option<f32>,
    /// Additional named metrics, e.g. "f1".
    pub extras: std::collections::HashMap<String, f32>,
}
200
201impl TrainingMetrics {
202 pub fn new(loss: f32) -> Self {
204 Self {
205 loss,
206 accuracy: None,
207 extras: std::collections::HashMap::new(),
208 }
209 }
210
211 pub fn with_accuracy(mut self, accuracy: f32) -> Self {
213 self.accuracy = Some(accuracy);
214 self
215 }
216
217 pub fn with_metric(mut self, name: &str, value: f32) -> Self {
219 self.extras.insert(name.to_string(), value);
220 self
221 }
222}
223
/// Hooks invoked by a training loop at well-defined points.
///
/// All methods have no-op defaults, so implementors override only what
/// they need. The `Send` bound allows callbacks to be moved across
/// threads by the trainer.
pub trait Callback: Send {
    /// Called once before training starts.
    fn on_train_begin(&mut self, _state: &TrainingState) {}

    /// Called once after training finishes.
    fn on_train_end(&mut self, _state: &TrainingState) {}

    /// Called at the start of each epoch (`_epoch` is zero-based).
    fn on_epoch_begin(&mut self, _epoch: usize, _state: &TrainingState) {}

    /// Called at the end of each epoch. Return `false` to request that
    /// training stop early (see `EarlyStopping`); the default continues.
    fn on_epoch_end(&mut self, _epoch: usize, _state: &TrainingState) -> bool {
        true
    }

    /// Called after each optimization step with that step's metrics.
    fn on_step_end(&mut self, _step: usize, _metrics: &TrainingMetrics, _state: &TrainingState) {}

    /// Called after each validation pass with the validation metrics.
    fn on_validation_end(&mut self, _metrics: &TrainingMetrics, _state: &TrainingState) {}
}
250
/// Callback that stops training when the monitored validation loss stops
/// improving.
///
/// Tracks the best value seen and counts consecutive epochs without an
/// improvement of more than `min_delta`; once the count reaches
/// `patience`, `on_epoch_end` returns `false`.
pub struct EarlyStopping {
    // Consecutive non-improving epochs tolerated before stopping.
    patience: usize,
    // Minimum change that counts as an improvement.
    min_delta: f32,
    // Consecutive epochs without improvement so far.
    counter: usize,
    // Best value observed; starts at +infinity for the default "min" mode.
    best_loss: f32,
    // Comparison mode: "min" (lower is better) or "max" (higher is better).
    mode: String,
}
263
264impl EarlyStopping {
265 pub fn new(patience: usize) -> Self {
267 Self {
268 patience,
269 min_delta: 0.0,
270 counter: 0,
271 best_loss: f32::INFINITY,
272 mode: "min".to_string(),
273 }
274 }
275
276 pub fn min_delta(mut self, delta: f32) -> Self {
278 self.min_delta = delta;
279 self
280 }
281
282 pub fn mode(mut self, mode: &str) -> Self {
284 self.mode = mode.to_string();
285 self
286 }
287}
288
289impl Callback for EarlyStopping {
290 fn on_epoch_end(&mut self, _epoch: usize, state: &TrainingState) -> bool {
291 let current = state.val_losses.last().copied().unwrap_or(f32::INFINITY);
292
293 let improved = if self.mode == "min" {
294 current < self.best_loss - self.min_delta
295 } else {
296 current > self.best_loss + self.min_delta
297 };
298
299 if improved {
300 self.best_loss = current;
301 self.counter = 0;
302 } else {
303 self.counter += 1;
304 }
305
306 self.counter < self.patience
307 }
308}
309
/// Callback that prints training progress to stdout.
pub struct ProgressLogger {
    // Print step metrics every `log_every` steps (used as a modulus in
    // `on_step_end`).
    log_every: usize,
}
318
319impl ProgressLogger {
320 pub fn new(log_every: usize) -> Self {
322 Self { log_every }
323 }
324}
325
326impl Callback for ProgressLogger {
327 fn on_epoch_begin(&mut self, epoch: usize, _state: &TrainingState) {
328 println!("Epoch {}", epoch + 1);
329 }
330
331 fn on_step_end(&mut self, step: usize, metrics: &TrainingMetrics, _state: &TrainingState) {
332 if step % self.log_every == 0 {
333 print!(" Step {}: loss = {:.4}", step, metrics.loss);
334 if let Some(acc) = metrics.accuracy {
335 print!(", accuracy = {:.2}%", acc * 100.0);
336 }
337 println!();
338 }
339 }
340
341 fn on_epoch_end(&mut self, epoch: usize, state: &TrainingState) -> bool {
342 println!(
343 "Epoch {} complete: avg_loss = {:.4}",
344 epoch + 1,
345 state.avg_train_loss()
346 );
347 if let Some(val_loss) = state.last_val_loss() {
348 println!(" Validation loss: {:.4}", val_loss);
349 }
350 true
351 }
352}
353
/// Summary of a training run's recorded curves and outcome.
#[derive(Debug, Clone)]
pub struct TrainingHistory {
    /// Training losses recorded over the run (minimized by
    /// `best_train_loss`). Presumably one entry per epoch — confirm
    /// against the trainer that fills this in.
    pub train_loss: Vec<f32>,
    /// Validation losses recorded over the run.
    pub val_loss: Vec<f32>,
    /// Learning-rate values recorded over the run.
    pub learning_rates: Vec<f32>,
    /// Wall-clock duration of the run, in seconds.
    pub duration_secs: f64,
    /// Number of epochs that finished.
    pub epochs_completed: usize,
    /// Whether the run finished (presumably `false` when stopped early —
    /// set by the trainer, not modified here).
    pub completed: bool,
}
374
375impl TrainingHistory {
376 pub fn new() -> Self {
378 Self {
379 train_loss: Vec::new(),
380 val_loss: Vec::new(),
381 learning_rates: Vec::new(),
382 duration_secs: 0.0,
383 epochs_completed: 0,
384 completed: false,
385 }
386 }
387
388 pub fn best_train_loss(&self) -> Option<f32> {
390 self.train_loss.iter().copied().reduce(f32::min)
391 }
392
393 pub fn best_val_loss(&self) -> Option<f32> {
395 self.val_loss.iter().copied().reduce(f32::min)
396 }
397}
398
// `Default` simply delegates to `TrainingHistory::new()`.
impl Default for TrainingHistory {
    fn default() -> Self {
        Self::new()
    }
}
404
405pub fn clip_grad_norm(parameters: &[Parameter], max_norm: f32) -> f32 {
411 let mut total_norm_sq = 0.0f32;
412
413 for param in parameters {
414 if let Some(grad) = param.grad() {
415 let grad_vec = grad.to_vec();
416 total_norm_sq += grad_vec.iter().map(|x| x * x).sum::<f32>();
417 }
418 }
419
420 let total_norm = total_norm_sq.sqrt();
421
422 if total_norm > max_norm {
423 let clip_coef = max_norm / (total_norm + 1e-6);
424 for param in parameters {
425 if let Some(grad) = param.grad() {
426 let clipped: Vec<f32> = grad.to_vec().iter().map(|x| x * clip_coef).collect();
427 {
428 let clipped_tensor = Tensor::from_vec(clipped, grad.shape()).unwrap();
429 param.variable().set_grad(clipped_tensor);
430 }
431 }
432 }
433 }
434
435 total_norm
436}
437
438pub fn compute_accuracy(predictions: &Tensor<f32>, targets: &Tensor<f32>) -> f32 {
440 let pred_vec = predictions.to_vec();
441 let target_vec = targets.to_vec();
442
443 let batch_size = predictions.shape()[0];
445 let num_classes = if predictions.shape().len() > 1 {
446 predictions.shape()[1]
447 } else {
448 1
449 };
450
451 let mut correct = 0;
452
453 for (b, &target_f) in target_vec.iter().enumerate().take(batch_size) {
454 let mut max_idx = 0;
456 let mut max_val = f32::NEG_INFINITY;
457 for c in 0..num_classes {
458 let idx = b * num_classes + c;
459 if pred_vec[idx] > max_val {
460 max_val = pred_vec[idx];
461 max_idx = c;
462 }
463 }
464
465 let target = target_f as usize;
467 if max_idx == target {
468 correct += 1;
469 }
470 }
471
472 correct as f32 / batch_size as f32
473}
474
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_training_config_default() {
        let config = TrainingConfig::default();
        assert_eq!(config.epochs, 10);
        assert_eq!(config.batch_size, 32);
    }

    #[test]
    fn test_training_config_builder() {
        // Chained builder calls override the defaults.
        let config = TrainingConfig::new()
            .epochs(20)
            .batch_size(64)
            .learning_rate(0.01)
            .gradient_clip_norm(1.0);

        assert_eq!(config.epochs, 20);
        assert_eq!(config.batch_size, 64);
        assert!((config.learning_rate - 0.01).abs() < 1e-6);
        assert_eq!(config.gradient_clip_norm, Some(1.0));
    }

    #[test]
    fn test_training_state() {
        let mut state = TrainingState::new();
        state.train_losses.extend([0.5, 0.3]);

        assert!((state.avg_train_loss() - 0.4).abs() < 1e-6);
    }

    #[test]
    fn test_early_stopping() {
        let mut callback = EarlyStopping::new(3);
        let mut state = TrainingState::new();

        // Two improving epochs: training continues, counter stays at 0.
        state.val_losses.push(1.0);
        assert!(callback.on_epoch_end(0, &state));
        state.val_losses.push(0.8);
        assert!(callback.on_epoch_end(1, &state));

        // Three consecutive non-improving epochs exhaust patience = 3.
        state.val_losses.push(0.9);
        assert!(callback.on_epoch_end(2, &state));
        state.val_losses.push(0.85);
        assert!(callback.on_epoch_end(3, &state));
        state.val_losses.push(0.82);
        assert!(!callback.on_epoch_end(4, &state));
    }

    #[test]
    fn test_training_metrics() {
        let metrics = TrainingMetrics::new(0.5)
            .with_accuracy(0.9)
            .with_metric("f1", 0.85);

        assert!((metrics.loss - 0.5).abs() < 1e-6);
        assert_eq!(metrics.accuracy, Some(0.9));
        assert_eq!(metrics.extras.get("f1"), Some(&0.85));
    }

    #[test]
    fn test_training_history() {
        let mut history = TrainingHistory::new();
        history.train_loss = vec![0.5, 0.3, 0.2];
        history.val_loss = vec![0.6, 0.4, 0.35];

        assert_eq!(history.best_train_loss(), Some(0.2));
        assert_eq!(history.best_val_loss(), Some(0.35));
    }

    #[test]
    fn test_compute_accuracy() {
        use axonml_tensor::Tensor;

        // Two samples, three classes: argmax rows are [1, 0].
        let predictions = Tensor::from_vec(vec![0.1, 0.8, 0.1, 0.9, 0.05, 0.05], &[2, 3]).unwrap();
        // Targets match both argmax indices, so accuracy is 100%.
        let targets = Tensor::from_vec(vec![1.0, 0.0], &[2]).unwrap();

        let accuracy = compute_accuracy(&predictions, &targets);
        assert!((accuracy - 1.0).abs() < 1e-6);
    }
}