axonml_optim/optimizer.rs
//! `Optimizer` trait — the core interface for all gradient-based optimizers.
//!
//! The trait requires `step()` (apply one update), `zero_grad()` (clear
//! accumulated gradients), `get_lr()` / `set_lr()` (learning-rate access),
//! and `parameters()` (the list of tracked `Parameter` refs). The module
//! also provides a `clip_grad_norm` utility for clipping gradients before
//! the step.
//!
//! # File
//! `crates/axonml-optim/src/optimizer.rs`
//!
//! # Author
//! Andrew Jewell Sr. — AutomataNexus LLC
//! ORCID: 0009-0005-2158-7060
//!
//! # Updated
//! April 14, 2026 11:15 PM EST
//!
//! # Disclaimer
//! Use at your own risk. This software is provided "as is", without warranty
//! of any kind, express or implied. The author and AutomataNexus shall not be
//! held liable for any damages arising from the use of this software.

use axonml_nn::Parameter;

// =============================================================================
// Optimizer Trait
// =============================================================================

/// Trait for all optimizers.
///
/// Optimizers update model parameters based on gradients.
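///
/// # Examples
///
/// A typical training-loop interaction, shown as an illustrative sketch
/// (`model`, `batch`, and `loss` are assumed stand-ins, not part of the API
/// shown in this file):
///
/// ```text
/// optimizer.zero_grad();              // clear gradients from the last step
/// let loss = model.forward(&batch);   // compute the loss
/// loss.backward();                    // populate parameter gradients
/// optimizer.step();                   // apply one update from those gradients
/// ```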
pub trait Optimizer {
    /// Performs a single optimization step.
    ///
    /// Updates all parameters based on their gradients.
    fn step(&mut self);

    /// Zeros all parameter gradients.
    fn zero_grad(&mut self);

    /// Returns the current learning rate.
    fn get_lr(&self) -> f32;

    /// Sets the learning rate.
    fn set_lr(&mut self, lr: f32);

    /// Returns the parameters being optimized.
    fn parameters(&self) -> &[Parameter];

    /// Returns the number of parameters.
    fn num_parameters(&self) -> usize {
        self.parameters().len()
    }
}
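
// =============================================================================
// Example: Minimal SGD (sketch)
// =============================================================================

// A minimal `Optimizer` implementation, kept as a comment because it relies on
// hypothetical `Parameter` accessors (`grad_data`, `apply_update`, `zero_grad`);
// the real accessor names live in `axonml_nn` and are not shown in this file.
//
// pub struct Sgd {
//     params: Vec<Parameter>,
//     lr: f32,
// }
//
// impl Optimizer for Sgd {
//     fn step(&mut self) {
//         for p in &mut self.params {
//             // theta <- theta - lr * grad, using the hypothetical accessors.
//             let update: Vec<f32> = p.grad_data().iter().map(|g| -self.lr * g).collect();
//             p.apply_update(&update);
//         }
//     }
//
//     fn zero_grad(&mut self) {
//         for p in &mut self.params {
//             p.zero_grad();
//         }
//     }
//
//     fn get_lr(&self) -> f32 {
//         self.lr
//     }
//
//     fn set_lr(&mut self, lr: f32) {
//         self.lr = lr;
//     }
//
//     fn parameters(&self) -> &[Parameter] {
//         &self.params
//     }
// }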

// =============================================================================
// Parameter State
// =============================================================================

/// State associated with a parameter during optimization.
///
/// Different optimizers store different state (e.g., momentum, variance).
#[derive(Debug, Clone)]
pub struct ParamState {
    /// First moment (momentum) - used by SGD with momentum, Adam
    pub momentum_buffer: Option<Vec<f32>>,
    /// Second moment (variance) - used by Adam, `RMSprop`
    pub exp_avg_sq: Option<Vec<f32>>,
    /// Max second moment - used by the AMSGrad variant of Adam
    pub max_exp_avg_sq: Option<Vec<f32>>,
    /// Step count for bias correction
    pub step: usize,
}

impl ParamState {
    /// Creates a new empty parameter state.
    #[must_use]
    pub fn new() -> Self {
        Self {
            momentum_buffer: None,
            exp_avg_sq: None,
            max_exp_avg_sq: None,
            step: 0,
        }
    }

    /// Initializes momentum buffer with zeros.
    pub fn init_momentum(&mut self, size: usize) {
        self.momentum_buffer = Some(vec![0.0; size]);
    }

    /// Initializes exponential average squared buffer with zeros.
    pub fn init_exp_avg_sq(&mut self, size: usize) {
        self.exp_avg_sq = Some(vec![0.0; size]);
    }
}

impl Default for ParamState {
    fn default() -> Self {
        Self::new()
    }
}
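
// =============================================================================
// Gradient Clipping
// =============================================================================

/// Clips gradients in place so that their global L2 norm does not exceed
/// `max_norm`, returning the norm measured before clipping.
///
/// The module header references a `clip_grad_norm` utility; what follows is a
/// minimal sketch that works on raw gradient buffers. A production version
/// would presumably take the optimizer's `Parameter` refs instead of bare
/// `Vec<f32>` slices.
pub fn clip_grad_norm(grads: &mut [Vec<f32>], max_norm: f32) -> f32 {
    // Global L2 norm across all buffers, treating them as one long vector.
    let total_norm: f32 = grads
        .iter()
        .flat_map(|g| g.iter())
        .map(|v| v * v)
        .sum::<f32>()
        .sqrt();

    if total_norm > max_norm {
        // Uniformly rescale every gradient; the epsilon guards against a
        // division by zero when the norm is tiny.
        let scale = max_norm / (total_norm + 1e-6);
        for g in grads.iter_mut() {
            for v in g.iter_mut() {
                *v *= scale;
            }
        }
    }
    total_norm
}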

// =============================================================================
// Tests
// =============================================================================

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_param_state_creation() {
        let mut state = ParamState::new();
        assert!(state.momentum_buffer.is_none());
        assert!(state.exp_avg_sq.is_none());
        assert_eq!(state.step, 0);

        state.init_momentum(10);
        assert!(state.momentum_buffer.is_some());
        assert_eq!(state.momentum_buffer.as_ref().unwrap().len(), 10);
    }
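
    // Sketch test: how an optimizer might use `ParamState` for Adam-style
    // bias correction. The beta value is a conventional default, assumed
    // here for illustration rather than mandated by this module.
    #[test]
    fn test_param_state_bias_correction_sketch() {
        let mut state = ParamState::new();
        state.init_exp_avg_sq(1);

        let beta2: f32 = 0.999;
        let grad: f32 = 0.5;

        // One EMA update of the squared gradient, then bias-correct by
        // 1 - beta2^t as Adam does.
        state.step += 1;
        let buf = state.exp_avg_sq.as_mut().unwrap();
        buf[0] = beta2 * buf[0] + (1.0 - beta2) * grad * grad;

        let bias_correction = 1.0 - beta2.powi(state.step as i32);
        let corrected = buf[0] / bias_correction;

        // After exactly one step, the corrected estimate equals grad^2.
        assert!((corrected - grad * grad).abs() < 1e-6);
    }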
}