axonml_optim/optimizer.rs
1//! Optimizer Trait - Core Optimizer Interface
2//!
3//! # File
4//! `crates/axonml-optim/src/optimizer.rs`
5//!
6//! # Author
7//! Andrew Jewell Sr - AutomataNexus
8//!
9//! # Updated
10//! March 8, 2026
11//!
12//! # Disclaimer
13//! Use at own risk. This software is provided "as is", without warranty of any
14//! kind, express or implied. The author and AutomataNexus shall not be held
15//! liable for any damages arising from the use of this software.
16
17use axonml_nn::Parameter;
18
19// =============================================================================
20// Optimizer Trait
21// =============================================================================
22
23/// Trait for all optimizers.
24///
25/// Optimizers update model parameters based on gradients.
26pub trait Optimizer {
27 /// Performs a single optimization step.
28 ///
29 /// Updates all parameters based on their gradients.
30 fn step(&mut self);
31
32 /// Zeros all parameter gradients.
33 fn zero_grad(&mut self);
34
35 /// Returns the current learning rate.
36 fn get_lr(&self) -> f32;
37
38 /// Sets the learning rate.
39 fn set_lr(&mut self, lr: f32);
40
41 /// Returns the parameters being optimized.
42 fn parameters(&self) -> &[Parameter];
43
44 /// Returns the number of parameters.
45 fn num_parameters(&self) -> usize {
46 self.parameters().len()
47 }
48}
49
50// =============================================================================
51// Parameter State
52// =============================================================================
53
/// Per-parameter state carried by an optimizer between steps.
///
/// Different optimizers store different state (e.g. momentum, variance).
/// Every buffer is `Option`al so that an optimizer only allocates the
/// buffers its update rule actually needs; [`ParamState::new`] starts with
/// all buffers set to `None` and `step` at zero.
#[derive(Debug, Clone, PartialEq)]
pub struct ParamState {
    /// First moment (momentum) buffer — used by SGD with momentum and Adam.
    pub momentum_buffer: Option<Vec<f32>>,
    /// Second moment (exponential average of squared gradients) — used by
    /// Adam and `RMSprop`.
    pub exp_avg_sq: Option<Vec<f32>>,
    /// Running maximum of the second moment.
    ///
    /// NOTE(review): previously documented as used by `AdaMax`; in PyTorch's
    /// naming this buffer backs AMSGrad instead — confirm which optimizer in
    /// this crate reads it.
    pub max_exp_avg_sq: Option<Vec<f32>>,
    /// Number of steps taken for this parameter, used for bias correction.
    pub step: usize,
}
68
69impl ParamState {
70 /// Creates a new empty parameter state.
71 #[must_use]
72 pub fn new() -> Self {
73 Self {
74 momentum_buffer: None,
75 exp_avg_sq: None,
76 max_exp_avg_sq: None,
77 step: 0,
78 }
79 }
80
81 /// Initializes momentum buffer with zeros.
82 pub fn init_momentum(&mut self, size: usize) {
83 self.momentum_buffer = Some(vec![0.0; size]);
84 }
85
86 /// Initializes exponential average squared buffer with zeros.
87 pub fn init_exp_avg_sq(&mut self, size: usize) {
88 self.exp_avg_sq = Some(vec![0.0; size]);
89 }
90}
91
92impl Default for ParamState {
93 fn default() -> Self {
94 Self::new()
95 }
96}
97
98// =============================================================================
99// Tests
100// =============================================================================
101
#[cfg(test)]
mod tests {
    use super::*;

    /// A fresh state has no buffers; `init_momentum` allocates a zero-filled buffer.
    #[test]
    fn test_param_state_creation() {
        let mut state = ParamState::new();
        assert!(state.momentum_buffer.is_none());
        assert!(state.exp_avg_sq.is_none());
        assert!(state.max_exp_avg_sq.is_none());
        assert_eq!(state.step, 0);

        state.init_momentum(10);
        // Check contents, not just length: the buffer must be all zeros.
        assert_eq!(state.momentum_buffer.as_deref(), Some(&[0.0_f32; 10][..]));
    }

    /// `init_exp_avg_sq` allocates zeros and leaves the other buffers alone.
    #[test]
    fn test_init_exp_avg_sq_zeroes_buffer() {
        let mut state = ParamState::new();
        state.init_exp_avg_sq(4);
        assert_eq!(state.exp_avg_sq.as_deref(), Some(&[0.0_f32; 4][..]));
        assert!(state.momentum_buffer.is_none());
        assert!(state.max_exp_avg_sq.is_none());
    }

    /// Re-initializing overwrites any previous contents with fresh zeros.
    #[test]
    fn test_init_replaces_existing_buffer() {
        let mut state = ParamState::new();
        state.momentum_buffer = Some(vec![1.5, -2.0]);
        state.init_momentum(3);
        assert_eq!(state.momentum_buffer.as_deref(), Some(&[0.0_f32; 3][..]));
    }

    /// `Default` produces the same empty state as `new`.
    #[test]
    fn test_default_matches_new() {
        let state = ParamState::default();
        assert!(state.momentum_buffer.is_none());
        assert!(state.exp_avg_sq.is_none());
        assert!(state.max_exp_avg_sq.is_none());
        assert_eq!(state.step, 0);
    }
}