axonml_optim/
optimizer.rs

//! Optimizer Trait - Core Optimizer Interface
//!
//! Defines the trait that all optimizers implement.
//!
//! @version 0.1.0
//! @author `AutomataNexus` Development Team

8use axonml_nn::Parameter;

// =============================================================================
// Optimizer Trait
// =============================================================================

14/// Trait for all optimizers.
15///
16/// Optimizers update model parameters based on gradients.
17pub trait Optimizer {
18    /// Performs a single optimization step.
19    ///
20    /// Updates all parameters based on their gradients.
21    fn step(&mut self);
22
23    /// Zeros all parameter gradients.
24    fn zero_grad(&mut self);
25
26    /// Returns the current learning rate.
27    fn get_lr(&self) -> f32;
28
29    /// Sets the learning rate.
30    fn set_lr(&mut self, lr: f32);
31
32    /// Returns the parameters being optimized.
33    fn parameters(&self) -> &[Parameter];
34
35    /// Returns the number of parameters.
36    fn num_parameters(&self) -> usize {
37        self.parameters().len()
38    }
39}

// =============================================================================
// Parameter State
// =============================================================================

/// Per-parameter state carried between optimization steps.
///
/// Different optimizers store different state (e.g. momentum, variance).
/// Every buffer starts as `None` and is allocated, zero-filled, by the
/// matching `init_*` method; buffers an optimizer does not use stay `None`.
#[derive(Debug, Clone, PartialEq)]
pub struct ParamState {
    /// First moment (momentum) - used by SGD with momentum, Adam
    pub momentum_buffer: Option<Vec<f32>>,
    /// Second moment (variance) - used by Adam, `RMSprop`
    pub exp_avg_sq: Option<Vec<f32>>,
    // NOTE(review): the field name mirrors the AMSGrad `max_exp_avg_sq` buffer
    // in PyTorch; confirm which optimizer(s) in this crate actually read it.
    /// Max second moment - used by `AdaMax`
    pub max_exp_avg_sq: Option<Vec<f32>>,
    /// Step count for this parameter, used for bias correction.
    pub step: usize,
}

impl ParamState {
    /// Creates a state with no buffers allocated and a step count of zero.
    #[must_use]
    pub const fn new() -> Self {
        Self {
            momentum_buffer: None,
            exp_avg_sq: None,
            max_exp_avg_sq: None,
            step: 0,
        }
    }

    /// Sets the momentum buffer to `size` zeros.
    ///
    /// Any previously stored momentum is discarded.
    pub fn init_momentum(&mut self, size: usize) {
        self.momentum_buffer = Some(vec![0.0; size]);
    }

    /// Sets the exponential-average-of-squares buffer to `size` zeros.
    ///
    /// Any previously stored values are discarded.
    pub fn init_exp_avg_sq(&mut self, size: usize) {
        self.exp_avg_sq = Some(vec![0.0; size]);
    }

    /// Sets the max-second-moment buffer to `size` zeros.
    ///
    /// Completes the set of initializers so every buffer in this struct can
    /// be allocated the same way. Any previously stored values are discarded.
    pub fn init_max_exp_avg_sq(&mut self, size: usize) {
        self.max_exp_avg_sq = Some(vec![0.0; size]);
    }
}

impl Default for ParamState {
    /// Equivalent to [`ParamState::new`].
    fn default() -> Self {
        Self::new()
    }
}

// =============================================================================
// Tests
// =============================================================================

#[cfg(test)]
mod tests {
    use super::*;

    /// A fresh state has no buffers and zero steps; `init_momentum` allocates
    /// a zero-filled buffer of the requested length.
    #[test]
    fn test_param_state_creation() {
        let mut state = ParamState::new();
        assert!(state.momentum_buffer.is_none());
        assert!(state.exp_avg_sq.is_none());
        assert!(state.max_exp_avg_sq.is_none());
        assert_eq!(state.step, 0);

        state.init_momentum(10);
        let momentum = state
            .momentum_buffer
            .as_deref()
            .expect("momentum buffer was just initialized");
        assert_eq!(momentum.len(), 10);
        assert!(momentum.iter().all(|&v| v == 0.0));
    }

    /// `init_exp_avg_sq` allocates a zero-filled buffer and leaves the other
    /// buffers untouched.
    #[test]
    fn test_init_exp_avg_sq_zero_fills() {
        let mut state = ParamState::new();
        state.init_exp_avg_sq(3);
        assert_eq!(state.exp_avg_sq.as_deref(), Some(&[0.0_f32; 3][..]));
        assert!(state.momentum_buffer.is_none());
        assert!(state.max_exp_avg_sq.is_none());
    }

    /// Re-initializing a buffer discards its previous contents and length.
    #[test]
    fn test_reinit_resets_buffer() {
        let mut state = ParamState::new();
        state.init_momentum(4);
        if let Some(buf) = state.momentum_buffer.as_mut() {
            buf[0] = 1.5;
        }
        state.init_momentum(2);
        assert_eq!(state.momentum_buffer.as_deref(), Some(&[0.0_f32; 2][..]));
    }

    /// `Default` produces the same empty state as `new`.
    #[test]
    fn test_default_is_empty() {
        let state = ParamState::default();
        assert!(state.momentum_buffer.is_none());
        assert!(state.exp_avg_sq.is_none());
        assert!(state.max_exp_avg_sq.is_none());
        assert_eq!(state.step, 0);
    }
}