//! Optimizer Trait - Core Optimizer Interface
//!
//! # File
//! `crates/axonml-optim/src/optimizer.rs`
//!
//! # Author
//! Andrew Jewell Sr - AutomataNexus
//!
//! # Updated
//! March 8, 2026
//!
//! # Disclaimer
//! Use at your own risk. This software is provided "as is", without warranty of any
//! kind, express or implied. The author and AutomataNexus shall not be held
//! liable for any damages arising from the use of this software.

use axonml_nn::Parameter;

// =============================================================================
// Optimizer Trait
// =============================================================================

/// Trait for all optimizers.
///
/// Optimizers update model parameters based on gradients.
pub trait Optimizer {
    /// Performs a single optimization step.
    ///
    /// Updates all parameters based on their gradients.
    fn step(&mut self);

    /// Zeros all parameter gradients.
    fn zero_grad(&mut self);

    /// Returns the current learning rate.
    fn get_lr(&self) -> f32;

    /// Sets the learning rate.
    fn set_lr(&mut self, lr: f32);

    /// Returns the parameters being optimized.
    fn parameters(&self) -> &[Parameter];

    /// Returns the number of parameters.
    fn num_parameters(&self) -> usize {
        self.parameters().len()
    }
}
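
// =============================================================================
// Example: Minimal Optimizer Implementation (illustrative sketch)
// =============================================================================

// A minimal plain-SGD implementer of `Optimizer`, shown only to
// illustrate the trait contract; it is not part of the original API.
// The gradient-access calls on `Parameter` are hypothetical and left
// as comments, since `axonml_nn::Parameter`'s accessors are not
// defined in this file.
#[allow(dead_code)]
struct SgdSketch {
    params: Vec<Parameter>,
    lr: f32,
}

impl Optimizer for SgdSketch {
    fn step(&mut self) {
        for _p in &mut self.params {
            // Hypothetical update, assuming a data/grad accessor pair:
            // _p.set_data(_p.data() - self.lr * _p.grad());
        }
    }

    fn zero_grad(&mut self) {
        for _p in &mut self.params {
            // Hypothetical: _p.zero_grad();
        }
    }

    fn get_lr(&self) -> f32 {
        self.lr
    }

    fn set_lr(&mut self, lr: f32) {
        self.lr = lr;
    }

    fn parameters(&self) -> &[Parameter] {
        &self.params
    }
}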

// =============================================================================
// Parameter State
// =============================================================================

/// State associated with a parameter during optimization.
///
/// Different optimizers store different state (e.g., momentum, variance).
#[derive(Debug, Clone)]
pub struct ParamState {
    /// First moment (momentum) - used by SGD with momentum, Adam
    pub momentum_buffer: Option<Vec<f32>>,
    /// Second moment (variance) - used by Adam, `RMSprop`
    pub exp_avg_sq: Option<Vec<f32>>,
    /// Max second moment - used by the `AMSGrad` variant of Adam
    pub max_exp_avg_sq: Option<Vec<f32>>,
    /// Step count for bias correction
    pub step: usize,
}

impl ParamState {
    /// Creates a new empty parameter state.
    #[must_use]
    pub fn new() -> Self {
        Self {
            momentum_buffer: None,
            exp_avg_sq: None,
            max_exp_avg_sq: None,
            step: 0,
        }
    }

    /// Initializes momentum buffer with zeros.
    pub fn init_momentum(&mut self, size: usize) {
        self.momentum_buffer = Some(vec![0.0; size]);
    }

    /// Initializes exponential average squared buffer with zeros.
    pub fn init_exp_avg_sq(&mut self, size: usize) {
        self.exp_avg_sq = Some(vec![0.0; size]);
    }
}

impl Default for ParamState {
    fn default() -> Self {
        Self::new()
    }
}
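
// =============================================================================
// Example: Adam-Style Update Using ParamState (illustrative sketch)
// =============================================================================

/// Illustrative sketch (not part of the original API): one Adam-style
/// update step driven by `ParamState`. `theta` and `grad` stand in for a
/// parameter's data and gradient buffers; `lr`, `beta1`, `beta2`, and
/// `eps` are assumed hyperparameters (typical values: 1e-3, 0.9, 0.999,
/// 1e-8).
#[allow(dead_code)]
fn adam_style_update(
    state: &mut ParamState,
    theta: &mut [f32],
    grad: &[f32],
    lr: f32,
    beta1: f32,
    beta2: f32,
    eps: f32,
) {
    // Lazily allocate the per-parameter buffers on first use.
    if state.momentum_buffer.is_none() {
        state.init_momentum(grad.len());
    }
    if state.exp_avg_sq.is_none() {
        state.init_exp_avg_sq(grad.len());
    }
    state.step += 1;

    // Bias correction: m_hat = m / (1 - beta1^t), v_hat = v / (1 - beta2^t).
    let t = state.step as i32;
    let bc1 = 1.0 - beta1.powi(t);
    let bc2 = 1.0 - beta2.powi(t);

    let m = state.momentum_buffer.as_mut().unwrap();
    let v = state.exp_avg_sq.as_mut().unwrap();
    for i in 0..grad.len() {
        // Update the first- and second-moment estimates.
        m[i] = beta1 * m[i] + (1.0 - beta1) * grad[i];
        v[i] = beta2 * v[i] + (1.0 - beta2) * grad[i] * grad[i];
        // Apply the bias-corrected update to the parameter data.
        theta[i] -= lr * (m[i] / bc1) / ((v[i] / bc2).sqrt() + eps);
    }
}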

// =============================================================================
// Tests
// =============================================================================

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_param_state_creation() {
        let mut state = ParamState::new();
        assert!(state.momentum_buffer.is_none());
        assert!(state.exp_avg_sq.is_none());
        assert_eq!(state.step, 0);

        state.init_momentum(10);
        assert!(state.momentum_buffer.is_some());
        assert_eq!(state.momentum_buffer.as_ref().unwrap().len(), 10);
    }
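
    // Added coverage (sketch): exercises `init_exp_avg_sq` and the
    // `Default` impl defined above, using only this file's API.
    #[test]
    fn test_param_state_exp_avg_sq_and_default() {
        // `Default` should behave like `new()`: no buffers, step 0.
        let mut state = ParamState::default();
        assert!(state.max_exp_avg_sq.is_none());
        assert_eq!(state.step, 0);

        state.init_exp_avg_sq(4);
        let buf = state.exp_avg_sq.as_ref().unwrap();
        assert_eq!(buf.len(), 4);
        assert!(buf.iter().all(|&x| x == 0.0));
    }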
}