//! `Optimizer` trait — the core interface for all gradient-based optimizers.
//!
//! Implementors provide `step()` (apply one update), `zero_grad()` (clear
//! accumulated gradients), `get_lr()` / `set_lr()` (learning-rate access), and
//! `parameters()` (the list of tracked `Parameter` refs). A `clip_grad_norm`
//! utility is also provided for gradient clipping before the step.
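//!
//! # Example
//!
//! A minimal sketch of the intended call sequence. Here `opt` is any value
//! implementing the `Optimizer` trait below; constructing it and producing the
//! gradients it consumes are outside the scope of this sketch.
//!
//! ```ignore
//! opt.zero_grad();   // clear gradients left over from the previous step
//! // ...forward pass, loss computation, and backward pass run here,
//! // populating each tracked parameter's gradient...
//! // (gradients may optionally be clipped with `clip_grad_norm` before stepping)
//! opt.step();        // apply one update using those gradients
//! assert!(opt.get_lr() > 0.0);
//! ```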
//!
//! # File
//! `crates/axonml-optim/src/optimizer.rs`
//!
//! # Author
//! Andrew Jewell Sr. — AutomataNexus LLC
//! ORCID: 0009-0005-2158-7060
//!
//! # Updated
//! April 14, 2026 11:15 PM EST
//!
//! # Disclaimer
//! Use at own risk. This software is provided "as is", without warranty of any
//! kind, express or implied. The author and AutomataNexus shall not be held
//! liable for any damages arising from the use of this software.

use axonml_nn::Parameter;

// =============================================================================
// Optimizer Trait
// =============================================================================

/// Trait for all optimizers.
///
/// Optimizers update model parameters based on gradients.
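///
/// # Example
///
/// A minimal implementor sketch. `MySgd` is an illustrative type, not part of
/// this crate, and the update and gradient-clearing logic are elided because
/// they depend on the `Parameter` API, which is not shown in this file.
///
/// ```ignore
/// use axonml_nn::Parameter;
///
/// struct MySgd {
///     params: Vec<Parameter>,
///     lr: f32,
/// }
///
/// impl Optimizer for MySgd {
///     fn step(&mut self) {
///         // For each parameter: p -= lr * grad (details depend on `Parameter`).
///     }
///
///     fn zero_grad(&mut self) {
///         // Clear each parameter's accumulated gradient.
///     }
///
///     fn get_lr(&self) -> f32 {
///         self.lr
///     }
///
///     fn set_lr(&mut self, lr: f32) {
///         self.lr = lr;
///     }
///
///     fn parameters(&self) -> &[Parameter] {
///         &self.params
///     }
/// }
/// ```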
pub trait Optimizer {
    /// Performs a single optimization step.
    ///
    /// Updates all parameters based on their gradients.
    fn step(&mut self);

    /// Zeros all parameter gradients.
    fn zero_grad(&mut self);

    /// Returns the current learning rate.
    fn get_lr(&self) -> f32;

    /// Sets the learning rate.
    fn set_lr(&mut self, lr: f32);

    /// Returns the parameters being optimized.
    fn parameters(&self) -> &[Parameter];

    /// Returns the number of parameters.
    fn num_parameters(&self) -> usize {
        self.parameters().len()
    }
}

// =============================================================================
// Parameter State
// =============================================================================

/// State associated with a parameter during optimization.
///
/// Different optimizers store different state (e.g., momentum, variance).
#[derive(Debug, Clone)]
pub struct ParamState {
    /// First moment (momentum) - used by SGD with momentum, Adam
    pub momentum_buffer: Option<Vec<f32>>,
    /// Second moment (variance) - used by Adam, `RMSprop`
    pub exp_avg_sq: Option<Vec<f32>>,
    /// Max second moment - used by `AdaMax`
    pub max_exp_avg_sq: Option<Vec<f32>>,
    /// Step count for bias correction
    pub step: usize,
}

impl ParamState {
    /// Creates a new empty parameter state.
    #[must_use]
    pub fn new() -> Self {
        Self {
            momentum_buffer: None,
            exp_avg_sq: None,
            max_exp_avg_sq: None,
            step: 0,
        }
    }

    /// Initializes momentum buffer with zeros.
    pub fn init_momentum(&mut self, size: usize) {
        self.momentum_buffer = Some(vec![0.0; size]);
    }

    /// Initializes exponential average squared buffer with zeros.
    pub fn init_exp_avg_sq(&mut self, size: usize) {
        self.exp_avg_sq = Some(vec![0.0; size]);
    }
}

impl Default for ParamState {
    fn default() -> Self {
        Self::new()
    }
}

// =============================================================================
// Tests
// =============================================================================

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_param_state_creation() {
        let mut state = ParamState::new();
        assert!(state.momentum_buffer.is_none());
        assert!(state.exp_avg_sq.is_none());
        assert_eq!(state.step, 0);

        state.init_momentum(10);
        assert!(state.momentum_buffer.is_some());
        assert_eq!(state.momentum_buffer.as_ref().unwrap().len(), 10);
    }
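
    // Illustrative sketch: how an Adam-style optimizer could use `ParamState`
    // (`momentum_buffer` as the first moment, `exp_avg_sq` as the second moment,
    // and `step` for bias correction). The hyperparameters and update rule below
    // are the standard Adam formulas, not this crate's Adam implementation.
    #[test]
    fn test_param_state_adam_style_update() {
        let mut state = ParamState::new();
        state.init_momentum(2);
        state.init_exp_avg_sq(2);

        let mut params = vec![1.0_f32, -2.0];
        let grads = vec![0.5_f32, -0.25];
        let (lr, beta1, beta2, eps) = (0.01_f32, 0.9_f32, 0.999_f32, 1e-8_f32);

        // One bias-corrected update step.
        state.step += 1;
        let m = state.momentum_buffer.as_mut().unwrap();
        let v = state.exp_avg_sq.as_mut().unwrap();
        for i in 0..params.len() {
            m[i] = beta1 * m[i] + (1.0 - beta1) * grads[i];
            v[i] = beta2 * v[i] + (1.0 - beta2) * grads[i] * grads[i];
            let m_hat = m[i] / (1.0 - beta1.powi(state.step as i32));
            let v_hat = v[i] / (1.0 - beta2.powi(state.step as i32));
            params[i] -= lr * m_hat / (v_hat.sqrt() + eps);
        }

        // Each parameter moves opposite the sign of its gradient.
        assert!(params[0] < 1.0);
        assert!(params[1] > -2.0);
        assert_eq!(state.step, 1);
    }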
}