use axonml_nn::Parameter;
/// Interface shared by the optimizers in this crate.
///
/// Implementors own a collection of [`Parameter`]s and expose a scalar
/// learning rate. The precise update rule is defined by each implementor.
pub trait Optimizer {
    /// Applies one update to the managed parameters.
    ///
    /// The update rule itself is implementation-defined.
    fn step(&mut self);

    /// Clears the gradients of the managed parameters.
    ///
    /// NOTE(review): exact semantics (zeroing vs. dropping the gradient) are
    /// up to the implementor — confirm against the concrete optimizer.
    fn zero_grad(&mut self);

    /// Returns the learning rate currently in use.
    fn get_lr(&self) -> f32;

    /// Replaces the learning rate with `lr`.
    fn set_lr(&mut self, lr: f32);

    /// Borrows the parameters managed by this optimizer.
    fn parameters(&self) -> &[Parameter];

    /// Returns how many entries [`Optimizer::parameters`] yields.
    ///
    /// This is the length of the parameter slice, i.e. the number of
    /// `Parameter` handles, not the number of scalar values they hold.
    fn num_parameters(&self) -> usize {
        let managed = self.parameters();
        managed.len()
    }
}
/// Per-parameter bookkeeping an optimizer keeps between updates.
///
/// Every buffer starts out as `None` and is only allocated once the matching
/// `init_*` method is called (or the public field is assigned directly).
/// `Default` is derived, so `ParamState::default()` and [`ParamState::new`]
/// produce the same empty state.
#[derive(Debug, Clone, Default, PartialEq)]
pub struct ParamState {
    /// Momentum buffer; `None` until [`ParamState::init_momentum`] is called.
    pub momentum_buffer: Option<Vec<f32>>,
    /// Buffer allocated by [`ParamState::init_exp_avg_sq`]; `None` before that.
    ///
    /// NOTE(review): presumably the running average of squared gradients —
    /// confirm against the optimizers that fill it.
    pub exp_avg_sq: Option<Vec<f32>>,
    /// Buffer allocated by [`ParamState::init_max_exp_avg_sq`]; `None` before
    /// that.
    ///
    /// NOTE(review): presumably the running maximum of `exp_avg_sq` — confirm
    /// against the optimizers that fill it.
    pub max_exp_avg_sq: Option<Vec<f32>>,
    /// Step counter, starting at 0. None of the methods on this type modify
    /// it; callers update the field directly.
    pub step: usize,
}

impl ParamState {
    /// Creates an empty state: no buffers allocated and `step == 0`.
    #[must_use]
    pub fn new() -> Self {
        Self::default()
    }

    /// Allocates `momentum_buffer` as `size` zeros, replacing any existing
    /// buffer.
    pub fn init_momentum(&mut self, size: usize) {
        self.momentum_buffer = Some(vec![0.0; size]);
    }

    /// Allocates `exp_avg_sq` as `size` zeros, replacing any existing buffer.
    pub fn init_exp_avg_sq(&mut self, size: usize) {
        self.exp_avg_sq = Some(vec![0.0; size]);
    }

    /// Allocates `max_exp_avg_sq` as `size` zeros, replacing any existing
    /// buffer.
    ///
    /// Mirrors the other `init_*` helpers so all three buffers can be set up
    /// the same way.
    pub fn init_max_exp_avg_sq(&mut self, size: usize) {
        self.max_exp_avg_sq = Some(vec![0.0; size]);
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// A fresh state has no buffers and a zero step count; `init_momentum`
    /// allocates only the momentum buffer, filled with zeros.
    #[test]
    fn test_param_state_creation() {
        let mut state = ParamState::new();
        assert!(state.momentum_buffer.is_none());
        assert!(state.exp_avg_sq.is_none());
        assert!(state.max_exp_avg_sq.is_none());
        assert_eq!(state.step, 0);

        state.init_momentum(10);
        let momentum = state
            .momentum_buffer
            .as_deref()
            .expect("init_momentum must allocate the momentum buffer");
        assert_eq!(momentum.len(), 10);
        assert!(momentum.iter().all(|&value| value == 0.0));

        // Initializing one buffer must leave the others alone.
        assert!(state.exp_avg_sq.is_none());
        assert!(state.max_exp_avg_sq.is_none());
        assert_eq!(state.step, 0);
    }

    /// `init_exp_avg_sq` allocates only its own buffer, filled with zeros.
    #[test]
    fn test_param_state_init_exp_avg_sq() {
        let mut state = ParamState::new();
        state.init_exp_avg_sq(3);

        let squares = state
            .exp_avg_sq
            .as_deref()
            .expect("init_exp_avg_sq must allocate the buffer");
        assert_eq!(squares.len(), 3);
        assert!(squares.iter().all(|&value| value == 0.0));
        assert!(state.momentum_buffer.is_none());
        assert!(state.max_exp_avg_sq.is_none());
    }

    /// `Default::default()` yields the same empty state as `new()`.
    #[test]
    fn test_param_state_default_matches_new() {
        let state = ParamState::default();
        assert!(state.momentum_buffer.is_none());
        assert!(state.exp_avg_sq.is_none());
        assert!(state.max_exp_avg_sq.is_none());
        assert_eq!(state.step, 0);
    }
}