axonml_optim/optimizer.rs
//! Optimizer Trait - Core Optimizer Interface
//!
//! Defines the `Optimizer` trait that all optimizers implement, together with
//! `ParamState`, the state associated with a parameter during optimization.
//!
//! @version 0.1.0
//! @author `AutomataNexus` Development Team
7
8use axonml_nn::Parameter;
9
10// =============================================================================
11// Optimizer Trait
12// =============================================================================
13
14/// Trait for all optimizers.
15///
16/// Optimizers update model parameters based on gradients.
17pub trait Optimizer {
18 /// Performs a single optimization step.
19 ///
20 /// Updates all parameters based on their gradients.
21 fn step(&mut self);
22
23 /// Zeros all parameter gradients.
24 fn zero_grad(&mut self);
25
26 /// Returns the current learning rate.
27 fn get_lr(&self) -> f32;
28
29 /// Sets the learning rate.
30 fn set_lr(&mut self, lr: f32);
31
32 /// Returns the parameters being optimized.
33 fn parameters(&self) -> &[Parameter];
34
35 /// Returns the number of parameters.
36 fn num_parameters(&self) -> usize {
37 self.parameters().len()
38 }
39}
40
41// =============================================================================
42// Parameter State
43// =============================================================================
44
/// Per-parameter state kept during optimization.
///
/// Different optimizers store different state (e.g., momentum, variance), so
/// every buffer is optional and starts out as `None`.
#[derive(Debug, Clone, Default)]
pub struct ParamState {
    /// First moment (momentum) - used by SGD with momentum, Adam
    pub momentum_buffer: Option<Vec<f32>>,
    /// Second moment (variance) - used by Adam, `RMSprop`
    pub exp_avg_sq: Option<Vec<f32>>,
    /// Max second moment - used by `AdaMax`
    pub max_exp_avg_sq: Option<Vec<f32>>,
    /// Step count for bias correction
    pub step: usize,
}

impl ParamState {
    /// Builds an empty state: no buffers allocated and a step count of zero.
    ///
    /// Equivalent to [`ParamState::default`].
    #[must_use]
    pub fn new() -> Self {
        Self::default()
    }

    /// Sets the momentum buffer to `size` zeros, replacing any existing buffer.
    pub fn init_momentum(&mut self, size: usize) {
        self.momentum_buffer = Some(Self::zeros(size));
    }

    /// Sets the exponential average squared buffer to `size` zeros, replacing
    /// any existing buffer.
    pub fn init_exp_avg_sq(&mut self, size: usize) {
        self.exp_avg_sq = Some(Self::zeros(size));
    }

    /// Produces a zero-filled buffer holding `len` elements.
    fn zeros(len: usize) -> Vec<f32> {
        vec![0.0; len]
    }
}
87
88// =============================================================================
89// Tests
90// =============================================================================
91
#[cfg(test)]
mod tests {
    use super::*;

    /// A new state has no momentum or squared-average buffer and a zero step
    /// count; initializing momentum yields a buffer of the requested length.
    #[test]
    fn test_param_state_creation() {
        let mut fresh = ParamState::new();
        assert_eq!(fresh.step, 0);
        assert!(fresh.momentum_buffer.is_none());
        assert!(fresh.exp_avg_sq.is_none());

        fresh.init_momentum(10);
        // A single comparison covers both "buffer exists" and "has length 10".
        assert_eq!(fresh.momentum_buffer.as_ref().map(Vec::len), Some(10));
    }
}
107}