// crates/axonml-train/src/trainer.rs
1//! High-Level Training Utilities
2//!
3//! # File
4//! `crates/axonml-train/src/trainer.rs`
5//!
6//! # Author
7//! Andrew Jewell Sr. — AutomataNexus LLC
8//! ORCID: 0009-0005-2158-7060
9//!
10//! # Updated
11//! April 14, 2026 11:15 PM EST
12//!
13//! # Disclaimer
14//! Use at own risk. This software is provided "as is", without warranty of any
15//! kind, express or implied. The author and AutomataNexus shall not be held
16//! liable for any damages arising from the use of this software.
17
18use axonml_tensor::Tensor;
19
20use axonml_nn::Parameter;
21
22// =============================================================================
23// Training Configuration
24// =============================================================================
25
/// Configuration for training.
#[derive(Debug, Clone)]
pub struct TrainingConfig {
    /// Number of training epochs
    pub epochs: usize,
    /// Batch size
    pub batch_size: usize,
    /// Learning rate
    pub learning_rate: f32,
    /// Gradient clipping max norm (None = no clipping)
    pub gradient_clip_norm: Option<f32>,
    /// Number of gradient accumulation steps
    pub gradient_accumulation_steps: usize,
    /// Logging frequency (steps)
    pub log_every: usize,
    /// Evaluation frequency (epochs)
    pub eval_every: usize,
    /// Save checkpoints
    pub save_checkpoints: bool,
    /// Checkpoint directory
    pub checkpoint_dir: String,
    /// Use mixed precision training
    pub mixed_precision: bool,
    /// Seed for reproducibility
    pub seed: Option<u64>,
}

impl Default for TrainingConfig {
    fn default() -> Self {
        Self {
            epochs: 10,
            batch_size: 32,
            learning_rate: 1e-3,
            gradient_clip_norm: None,
            gradient_accumulation_steps: 1,
            log_every: 100,
            eval_every: 1,
            save_checkpoints: false,
            checkpoint_dir: "checkpoints".to_string(),
            mixed_precision: false,
            seed: None,
        }
    }
}

impl TrainingConfig {
    /// Creates a new training configuration with defaults.
    pub fn new() -> Self {
        Self::default()
    }

    /// Builder: set number of epochs.
    pub fn epochs(mut self, epochs: usize) -> Self {
        self.epochs = epochs;
        self
    }

    /// Builder: set batch size.
    pub fn batch_size(mut self, batch_size: usize) -> Self {
        self.batch_size = batch_size;
        self
    }

    /// Builder: set learning rate.
    pub fn learning_rate(mut self, lr: f32) -> Self {
        self.learning_rate = lr;
        self
    }

    /// Builder: set gradient clipping.
    pub fn gradient_clip_norm(mut self, max_norm: f32) -> Self {
        self.gradient_clip_norm = Some(max_norm);
        self
    }

    /// Builder: set gradient accumulation steps.
    ///
    /// A value of 0 is clamped to 1 (stepping "every 0 batches" is meaningless).
    pub fn gradient_accumulation_steps(mut self, steps: usize) -> Self {
        self.gradient_accumulation_steps = steps.max(1);
        self
    }

    /// Builder: set logging frequency.
    pub fn log_every(mut self, steps: usize) -> Self {
        self.log_every = steps;
        self
    }

    /// Builder: set evaluation frequency (in epochs).
    ///
    /// Added for parity: every other public field has a builder method.
    pub fn eval_every(mut self, epochs: usize) -> Self {
        self.eval_every = epochs;
        self
    }

    /// Builder: enable or disable checkpoint saving.
    pub fn save_checkpoints(mut self, enabled: bool) -> Self {
        self.save_checkpoints = enabled;
        self
    }

    /// Builder: set the checkpoint directory.
    pub fn checkpoint_dir(mut self, dir: &str) -> Self {
        self.checkpoint_dir = dir.to_string();
        self
    }

    /// Builder: enable mixed precision.
    pub fn mixed_precision(mut self, enabled: bool) -> Self {
        self.mixed_precision = enabled;
        self
    }

    /// Builder: set seed.
    pub fn seed(mut self, seed: u64) -> Self {
        self.seed = Some(seed);
        self
    }
}
125
126// =============================================================================
127// Training State
128// =============================================================================
129
/// Current training state.
#[derive(Debug, Clone)]
pub struct TrainingState {
    /// Current epoch (0-indexed)
    pub epoch: usize,
    /// Global step count
    pub global_step: usize,
    /// Best validation metric
    pub best_metric: f32,
    /// Training loss history
    pub train_losses: Vec<f32>,
    /// Validation loss history
    pub val_losses: Vec<f32>,
    /// Learning rate history
    pub lr_history: Vec<f32>,
}

impl Default for TrainingState {
    fn default() -> Self {
        // `best_metric` starts at +INFINITY so any first observation improves it.
        TrainingState {
            epoch: 0,
            global_step: 0,
            best_metric: f32::INFINITY,
            train_losses: Vec::new(),
            val_losses: Vec::new(),
            lr_history: Vec::new(),
        }
    }
}

impl TrainingState {
    /// Creates a fresh training state with all counters at zero.
    pub fn new() -> Self {
        Self::default()
    }

    /// The current epoch, 1-indexed for display purposes.
    pub fn current_epoch(&self) -> usize {
        self.epoch + 1
    }

    /// Mean of the recorded training losses; 0.0 when none have been recorded.
    pub fn avg_train_loss(&self) -> f32 {
        match self.train_losses.len() {
            0 => 0.0,
            n => self.train_losses.iter().sum::<f32>() / n as f32,
        }
    }

    /// The most recently recorded validation loss, if any.
    pub fn last_val_loss(&self) -> Option<f32> {
        self.val_losses.last().copied()
    }
}
185
186// =============================================================================
187// Training Metrics
188// =============================================================================
189
/// Metrics collected during training.
#[derive(Debug, Clone)]
pub struct TrainingMetrics {
    /// Loss value
    pub loss: f32,
    /// Accuracy (if applicable)
    pub accuracy: Option<f32>,
    /// Additional metrics
    pub extras: std::collections::HashMap<String, f32>,
}

impl TrainingMetrics {
    /// Builds a metrics record carrying only a loss value.
    pub fn new(loss: f32) -> Self {
        TrainingMetrics {
            loss,
            accuracy: None,
            extras: Default::default(),
        }
    }

    /// Builder: attach an accuracy value.
    pub fn with_accuracy(mut self, accuracy: f32) -> Self {
        self.accuracy = Some(accuracy);
        self
    }

    /// Builder: attach an arbitrary named metric.
    pub fn with_metric(mut self, name: &str, value: f32) -> Self {
        self.extras.insert(name.to_string(), value);
        self
    }
}
223
224// =============================================================================
225// Callback Trait
226// =============================================================================
227
/// Callback for training events.
///
/// All methods have no-op default implementations, so implementors override
/// only the hooks they care about. `Send` is required so callbacks can be
/// moved across threads by a training loop.
pub trait Callback: Send {
    /// Called at the start of training.
    fn on_train_begin(&mut self, _state: &TrainingState) {}

    /// Called at the end of training.
    fn on_train_end(&mut self, _state: &TrainingState) {}

    /// Called at the start of an epoch.
    fn on_epoch_begin(&mut self, _epoch: usize, _state: &TrainingState) {}

    /// Called at the end of an epoch.
    ///
    /// Returns `true` to continue training, `false` to request a stop
    /// (e.g. early stopping). The default always continues.
    fn on_epoch_end(&mut self, _epoch: usize, _state: &TrainingState) -> bool {
        true // Continue training
    }

    /// Called after each training step.
    fn on_step_end(&mut self, _step: usize, _metrics: &TrainingMetrics, _state: &TrainingState) {}

    /// Called after validation.
    fn on_validation_end(&mut self, _metrics: &TrainingMetrics, _state: &TrainingState) {}
}
250
251// =============================================================================
252// Early Stopping Callback
253// =============================================================================
254
255/// Early stopping callback.
256pub struct EarlyStopping {
257    patience: usize,
258    min_delta: f32,
259    counter: usize,
260    best_loss: f32,
261    mode: String,
262}
263
264impl EarlyStopping {
265    /// Creates a new early stopping callback.
266    pub fn new(patience: usize) -> Self {
267        Self {
268            patience,
269            min_delta: 0.0,
270            counter: 0,
271            best_loss: f32::INFINITY,
272            mode: "min".to_string(),
273        }
274    }
275
276    /// Sets minimum delta for improvement.
277    pub fn min_delta(mut self, delta: f32) -> Self {
278        self.min_delta = delta;
279        self
280    }
281
282    /// Sets mode ("min" or "max").
283    pub fn mode(mut self, mode: &str) -> Self {
284        self.mode = mode.to_string();
285        self
286    }
287}
288
289impl Callback for EarlyStopping {
290    fn on_epoch_end(&mut self, _epoch: usize, state: &TrainingState) -> bool {
291        let current = state.val_losses.last().copied().unwrap_or(f32::INFINITY);
292
293        let improved = if self.mode == "min" {
294            current < self.best_loss - self.min_delta
295        } else {
296            current > self.best_loss + self.min_delta
297        };
298
299        if improved {
300            self.best_loss = current;
301            self.counter = 0;
302        } else {
303            self.counter += 1;
304        }
305
306        self.counter < self.patience
307    }
308}
309
310// =============================================================================
311// Progress Logger Callback
312// =============================================================================
313
314/// Simple progress logging callback.
315pub struct ProgressLogger {
316    log_every: usize,
317}
318
319impl ProgressLogger {
320    /// Creates a new progress logger.
321    pub fn new(log_every: usize) -> Self {
322        Self { log_every }
323    }
324}
325
326impl Callback for ProgressLogger {
327    fn on_epoch_begin(&mut self, epoch: usize, _state: &TrainingState) {
328        println!("Epoch {}", epoch + 1);
329    }
330
331    fn on_step_end(&mut self, step: usize, metrics: &TrainingMetrics, _state: &TrainingState) {
332        if step % self.log_every == 0 {
333            print!("  Step {}: loss = {:.4}", step, metrics.loss);
334            if let Some(acc) = metrics.accuracy {
335                print!(", accuracy = {:.2}%", acc * 100.0);
336            }
337            println!();
338        }
339    }
340
341    fn on_epoch_end(&mut self, epoch: usize, state: &TrainingState) -> bool {
342        println!(
343            "Epoch {} complete: avg_loss = {:.4}",
344            epoch + 1,
345            state.avg_train_loss()
346        );
347        if let Some(val_loss) = state.last_val_loss() {
348            println!("  Validation loss: {:.4}", val_loss);
349        }
350        true
351    }
352}
353
354// =============================================================================
355// Training History
356// =============================================================================
357
/// Complete training history.
///
/// The derived `Default` (empty vectors, zeros, `false`) is exactly the
/// "nothing recorded yet" state, so `new()` simply delegates to it.
#[derive(Debug, Clone, Default)]
pub struct TrainingHistory {
    /// Training losses per epoch
    pub train_loss: Vec<f32>,
    /// Validation losses per epoch
    pub val_loss: Vec<f32>,
    /// Learning rates per epoch
    pub learning_rates: Vec<f32>,
    /// Training duration in seconds
    pub duration_secs: f64,
    /// Number of epochs completed
    pub epochs_completed: usize,
    /// Whether training completed successfully
    pub completed: bool,
}

impl TrainingHistory {
    /// Creates an empty history.
    pub fn new() -> Self {
        Self::default()
    }

    /// Lowest training loss recorded, or `None` if the history is empty.
    pub fn best_train_loss(&self) -> Option<f32> {
        self.train_loss.iter().copied().reduce(|a, b| a.min(b))
    }

    /// Lowest validation loss recorded, or `None` if the history is empty.
    pub fn best_val_loss(&self) -> Option<f32> {
        self.val_loss.iter().copied().reduce(|a, b| a.min(b))
    }
}
404
405// =============================================================================
406// Utility Functions
407// =============================================================================
408
409/// Clips gradients by global norm.
410pub fn clip_grad_norm(parameters: &[Parameter], max_norm: f32) -> f32 {
411    let mut total_norm_sq = 0.0f32;
412
413    for param in parameters {
414        if let Some(grad) = param.grad() {
415            let grad_vec = grad.to_vec();
416            total_norm_sq += grad_vec.iter().map(|x| x * x).sum::<f32>();
417        }
418    }
419
420    let total_norm = total_norm_sq.sqrt();
421
422    if total_norm > max_norm {
423        let clip_coef = max_norm / (total_norm + 1e-6);
424        for param in parameters {
425            if let Some(grad) = param.grad() {
426                let clipped: Vec<f32> = grad.to_vec().iter().map(|x| x * clip_coef).collect();
427                {
428                    let clipped_tensor = Tensor::from_vec(clipped, grad.shape()).unwrap();
429                    param.variable().set_grad(clipped_tensor);
430                }
431            }
432        }
433    }
434
435    total_norm
436}
437
438/// Computes accuracy for classification.
439pub fn compute_accuracy(predictions: &Tensor<f32>, targets: &Tensor<f32>) -> f32 {
440    let pred_vec = predictions.to_vec();
441    let target_vec = targets.to_vec();
442
443    // Assume predictions are logits [batch, num_classes] and targets are indices
444    let batch_size = predictions.shape()[0];
445    let num_classes = if predictions.shape().len() > 1 {
446        predictions.shape()[1]
447    } else {
448        1
449    };
450
451    let mut correct = 0;
452
453    for (b, &target_f) in target_vec.iter().enumerate().take(batch_size) {
454        // Find argmax of predictions
455        let mut max_idx = 0;
456        let mut max_val = f32::NEG_INFINITY;
457        for c in 0..num_classes {
458            let idx = b * num_classes + c;
459            if pred_vec[idx] > max_val {
460                max_val = pred_vec[idx];
461                max_idx = c;
462            }
463        }
464
465        // Compare with target
466        let target = target_f as usize;
467        if max_idx == target {
468            correct += 1;
469        }
470    }
471
472    correct as f32 / batch_size as f32
473}
474
475// =============================================================================
476// Tests
477// =============================================================================
478
#[cfg(test)]
mod tests {
    use super::*;

    // Defaults documented on TrainingConfig::default().
    #[test]
    fn test_training_config_default() {
        let config = TrainingConfig::default();
        assert_eq!(config.epochs, 10);
        assert_eq!(config.batch_size, 32);
    }

    // Builder methods return Self, so they chain.
    #[test]
    fn test_training_config_builder() {
        let config = TrainingConfig::new()
            .epochs(20)
            .batch_size(64)
            .learning_rate(0.01)
            .gradient_clip_norm(1.0);

        assert_eq!(config.epochs, 20);
        assert_eq!(config.batch_size, 64);
        assert!((config.learning_rate - 0.01).abs() < 1e-6);
        assert_eq!(config.gradient_clip_norm, Some(1.0));
    }

    // avg_train_loss is the arithmetic mean of the recorded losses.
    #[test]
    fn test_training_state() {
        let mut state = TrainingState::new();
        state.train_losses.push(0.5);
        state.train_losses.push(0.3);

        assert!((state.avg_train_loss() - 0.4).abs() < 1e-6);
    }

    // With patience 3, the callback returns false on the third
    // consecutive epoch without improvement.
    #[test]
    fn test_early_stopping() {
        let mut callback = EarlyStopping::new(3);
        let mut state = TrainingState::new();

        // Improving
        state.val_losses.push(1.0);
        assert!(callback.on_epoch_end(0, &state));

        state.val_losses.push(0.8);
        assert!(callback.on_epoch_end(1, &state));

        // Not improving
        state.val_losses.push(0.9);
        assert!(callback.on_epoch_end(2, &state)); // counter = 1

        state.val_losses.push(0.85);
        assert!(callback.on_epoch_end(3, &state)); // counter = 2

        state.val_losses.push(0.82);
        assert!(!callback.on_epoch_end(4, &state)); // counter = 3, stop
    }

    // Metrics builders attach accuracy and named extras.
    #[test]
    fn test_training_metrics() {
        let metrics = TrainingMetrics::new(0.5)
            .with_accuracy(0.9)
            .with_metric("f1", 0.85);

        assert!((metrics.loss - 0.5).abs() < 1e-6);
        assert_eq!(metrics.accuracy, Some(0.9));
        assert_eq!(metrics.extras.get("f1"), Some(&0.85));
    }

    // best_*_loss report the minimum of each history vector.
    #[test]
    fn test_training_history() {
        let mut history = TrainingHistory::new();
        history.train_loss = vec![0.5, 0.3, 0.2];
        history.val_loss = vec![0.6, 0.4, 0.35];

        assert_eq!(history.best_train_loss(), Some(0.2));
        assert_eq!(history.best_val_loss(), Some(0.35));
    }

    // Accuracy compares the per-row argmax against float-encoded targets.
    #[test]
    fn test_compute_accuracy() {
        use axonml_tensor::Tensor;

        // 2 samples, 3 classes
        // Sample 0: [0.1, 0.8, 0.1] -> predicted class 1
        // Sample 1: [0.9, 0.05, 0.05] -> predicted class 0
        let predictions = Tensor::from_vec(vec![0.1, 0.8, 0.1, 0.9, 0.05, 0.05], &[2, 3]).unwrap();

        // Targets: [1, 0] (both correct)
        let targets = Tensor::from_vec(vec![1.0, 0.0], &[2]).unwrap();

        let accuracy = compute_accuracy(&predictions, &targets);
        assert!((accuracy - 1.0).abs() < 1e-6);
    }
}