use std::collections::HashMap;
use smol_str::format_smolstr;
use crate::{
Array, Result,
error::{Error, NonFiniteScalarPayload, OutOfRangePayload},
lm::{
load::Weights,
tuner::optimizers::base::{LearningRate, Optimizer, zeros_like, zeros_like_map},
},
ops::arithmetic,
};
fn validate_eps(eps: f32) -> Result<()> {
if !eps.is_finite() {
return Err(Error::NonFiniteScalar(NonFiniteScalarPayload::new(
"Adagrad: eps",
eps as f64,
)));
}
if eps < 0.0 {
return Err(Error::OutOfRange(OutOfRangePayload::new(
"Adagrad: eps",
"must be >= 0.0",
format_smolstr!("{eps}"),
)));
}
Ok(())
}
/// Wraps `v` in an `f32` [`Array`] built with an empty shape list.
fn scalar(v: f32) -> Result<Array> {
    // No dimensions at all: the array holds the single value `v`.
    let no_dims: [i32; 0] = [];
    Array::full::<f32>(&no_dims, v)
}
/// Adagrad optimizer state.
///
/// Keeps one accumulator array per weight key and updates each weight as
/// `w -= lr * g / (sqrt(v) + eps)`, where `v` is the running sum of `g^2`.
pub struct Adagrad {
/// Learning-rate source; queried with `try_current(step)` to obtain the
/// scalar rate for a given step.
learning_rate: LearningRate,
/// Constant added to `sqrt(v)` in the update denominator. Validated to be
/// finite and `>= 0.0` by the constructors/builders.
eps: f32,
/// Incremented once per `apply_gradients` call, after `preflight` succeeds.
step_count: usize,
/// Most recently resolved learning-rate value; this is the rate used for
/// the next update.
current_lr: f32,
/// Step index for which `current_lr` was resolved, so `preflight` can skip
/// resolving the rate twice for the same step.
lr_resolved_for_step: Option<usize>,
/// Running sum of squared gradients, keyed like the weight/gradient maps.
state: HashMap<String, Array>,
}
impl Adagrad {
    /// Creates an optimizer with the given learning rate and `eps`.
    ///
    /// `eps` is validated first; the learning rate is then resolved for
    /// step 0 so that an invalid rate is reported at construction time.
    ///
    /// # Errors
    ///
    /// Returns the error from `validate_eps` or from
    /// `LearningRate::try_current(0)`.
    pub fn new(learning_rate: impl Into<LearningRate>, eps: f32) -> Result<Self> {
        validate_eps(eps)?;
        let schedule = learning_rate.into();
        let initial_lr = schedule.try_current(0)?;
        Ok(Self {
            learning_rate: schedule,
            eps,
            step_count: 0,
            current_lr: initial_lr,
            lr_resolved_for_step: Some(0),
            state: HashMap::new(),
        })
    }

    /// Creates an optimizer with the given learning rate and `eps = 1e-8`.
    ///
    /// # Errors
    ///
    /// Same as [`Adagrad::new`].
    pub fn default_with_lr(learning_rate: impl Into<LearningRate>) -> Result<Self> {
        const DEFAULT_EPS: f32 = 1e-8;
        Self::new(learning_rate, DEFAULT_EPS)
    }

    /// Borrows the configured learning-rate source.
    #[inline(always)]
    pub fn learning_rate_ref(&self) -> &LearningRate {
        &self.learning_rate
    }

    /// Returns the configured `eps`.
    #[inline(always)]
    pub fn eps(&self) -> f32 {
        self.eps
    }

    /// Replaces the learning rate, resolving it for the current step.
    ///
    /// # Errors
    ///
    /// Returns the error from `LearningRate::try_current` if the new rate
    /// cannot be resolved for the current step.
    pub fn with_learning_rate(self, learning_rate: impl Into<LearningRate>) -> Result<Self> {
        let step = self.step_count;
        let schedule = learning_rate.into();
        // Resolve before building the new value so a bad rate is rejected.
        let resolved = schedule.try_current(step)?;
        Ok(Self {
            learning_rate: schedule,
            current_lr: resolved,
            lr_resolved_for_step: Some(step),
            ..self
        })
    }

    /// Replaces `eps` after validating it.
    ///
    /// # Errors
    ///
    /// Returns the error from `validate_eps` when `eps` is non-finite or
    /// negative.
    pub fn with_eps(self, eps: f32) -> Result<Self> {
        validate_eps(eps)?;
        Ok(Self { eps, ..self })
    }
}
impl Optimizer for Adagrad {
fn init(&mut self, params: &Weights) -> Result<()> {
self.state = zeros_like_map(params)?;
Ok(())
}
fn preflight(&mut self) -> Result<()> {
if self.lr_resolved_for_step == Some(self.step_count) {
return Ok(()); }
self.current_lr = self.learning_rate.try_current(self.step_count)?;
self.lr_resolved_for_step = Some(self.step_count);
Ok(())
}
fn apply_gradients(&mut self, gradients: &Weights, params: &mut Weights) -> Result<()> {
if self.state.is_empty() {
self.init(gradients)?;
}
self.preflight()?;
self.step_count += 1;
let eps_s = scalar(self.eps)?;
let lr_s = scalar(self.current_lr)?;
for (key, grad) in gradients {
let Some(param) = params.get(key) else {
continue;
};
let prev_v = match self.state.get(key) {
Some(v) => v.try_clone()?,
None => zeros_like(param)?,
};
let g_sq = arithmetic::square(grad)?;
let v_new = arithmetic::add(&prev_v, &g_sq)?;
let lr_g = arithmetic::multiply(&lr_s, grad)?;
let sqrt_v = arithmetic::sqrt(&v_new)?;
let denom = arithmetic::add(&sqrt_v, &eps_s)?;
let step_term = arithmetic::divide(&lr_g, &denom)?;
let new_w = arithmetic::subtract(param, &step_term)?;
params.insert(key.clone(), new_w);
self.state.insert(key.clone(), v_new);
}
Ok(())
}
fn step(&self) -> usize {
self.step_count
}
fn learning_rate(&self) -> f32 {
self.current_lr
}
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Reads a single `f32` out of `a` via a clone, since `item` is called
    /// on a mutable binding.
    fn read_scalar(a: &Array) -> Result<f32> {
        let mut clone = a.try_clone()?;
        clone.item::<f32>()
    }

    #[test]
    fn adagrad_single_step_matches_python_ref() -> Result<()> {
        let mut adagrad = Adagrad::default_with_lr(0.1)?;
        let mut params: Weights = HashMap::new();
        params.insert("w".into(), scalar(1.0)?);
        let mut grads: Weights = HashMap::new();
        grads.insert("w".into(), scalar(0.5)?);
        adagrad.apply_gradients(&grads, &mut params)?;
        // v = 0.5^2 = 0.25, sqrt(v) = 0.5, step = 0.1 * 0.5 / (0.5 + 1e-8) ~= 0.1,
        // so the weight moves from 1.0 to ~0.9.
        let got = read_scalar(&params["w"])?;
        assert!((got - 0.9).abs() < 1e-4, "got {got}");
        Ok(())
    }

    #[test]
    fn adagrad_rejects_negative_eps() {
        assert!(Adagrad::new(0.001, -1e-8).is_err());
    }

    #[test]
    fn adagrad_new_rejects_nan_eps() {
        assert!(Adagrad::new(0.001, f32::NAN).is_err());
    }

    #[test]
    fn adagrad_builder_with_eps_rejects_negative() {
        let res = Adagrad::default_with_lr(0.1).and_then(|a| a.with_eps(-1e-8));
        assert!(res.is_err());
    }

    #[test]
    fn adagrad_with_eps_rejects_nan() {
        let res = Adagrad::default_with_lr(0.1).and_then(|a| a.with_eps(f32::NAN));
        assert!(res.is_err());
    }

    #[test]
    fn adagrad_with_eps_rejects_inf() {
        let res = Adagrad::default_with_lr(0.1).and_then(|a| a.with_eps(f32::INFINITY));
        assert!(res.is_err());
    }

    #[test]
    fn adagrad_with_learning_rate_rejects_fixed_nan() {
        let res = Adagrad::default_with_lr(0.1)
            .and_then(|a| a.with_learning_rate(LearningRate::Fixed(f32::NAN)));
        assert!(res.is_err(), "with_learning_rate must reject Fixed(NaN)");
    }
}