use std::collections::HashMap;
use smol_str::format_smolstr;
use crate::{
Array, Result,
error::{Error, NonFiniteScalarPayload, OutOfRangePayload},
lm::{
load::Weights,
tuner::optimizers::base::{LearningRate, Optimizer, zeros_like},
},
ops::arithmetic,
};
/// Wraps `v` in an `Array` built with an empty shape (zero dimensions).
///
/// # Errors
/// Propagates any error returned by `Array::full`.
fn scalar(v: f32) -> Result<Array> {
    // An empty shape list: the array carries a single value and no axes.
    let empty_shape: [i32; 0] = [];
    Array::full::<f32>(&empty_shape, v)
}
fn validate_rho(rho: f32) -> Result<()> {
if !rho.is_finite() {
return Err(Error::NonFiniteScalar(NonFiniteScalarPayload::new(
"AdaDelta: rho",
rho as f64,
)));
}
if !(0.0..1.0).contains(&rho) {
return Err(Error::OutOfRange(OutOfRangePayload::new(
"AdaDelta: rho",
"must be in [0.0, 1.0)",
format_smolstr!("{rho}"),
)));
}
Ok(())
}
fn validate_eps(eps: f32) -> Result<()> {
if !eps.is_finite() {
return Err(Error::NonFiniteScalar(NonFiniteScalarPayload::new(
"AdaDelta: eps",
eps as f64,
)));
}
if eps < 0.0 {
return Err(Error::OutOfRange(OutOfRangePayload::new(
"AdaDelta: eps",
"must be >= 0.0",
format_smolstr!("{eps}"),
)));
}
Ok(())
}
/// AdaDelta optimizer state: hyper-parameters, step bookkeeping, and per-key accumulators.
pub struct AdaDelta {
/// Learning-rate source; resolved to a concrete `f32` via `try_current(step)`.
learning_rate: LearningRate,
/// Decay coefficient for both running averages; validated to be finite and in `[0.0, 1.0)`.
rho: f32,
/// Constant added to both accumulators before the square root; validated finite and `>= 0.0`.
eps: f32,
/// Number of `apply_gradients` calls that got past `preflight`.
step_count: usize,
/// Most recently resolved learning rate; this is the value used to scale the update.
current_lr: f32,
/// Step index for which `current_lr` was resolved; lets `preflight` skip a redundant resolve.
lr_resolved_for_step: Option<usize>,
/// Per-key accumulator pair `(v, u)`: `v` averages squared gradients, `u` averages squared
/// updates.
state: HashMap<String, (Array, Array)>,
}
impl AdaDelta {
    /// Builds an optimizer from a learning-rate source, a decay `rho`, and a stabilizer `eps`.
    ///
    /// The learning rate is resolved once for step 0 so it can be read immediately.
    ///
    /// # Errors
    /// Fails when `rho` or `eps` do not pass validation (checked in that order), or when the
    /// learning rate cannot be resolved for step 0.
    pub fn new(learning_rate: impl Into<LearningRate>, rho: f32, eps: f32) -> Result<Self> {
        validate_rho(rho)?;
        validate_eps(eps)?;
        let schedule: LearningRate = learning_rate.into();
        let initial_lr = schedule.try_current(0)?;
        let optimizer = Self {
            learning_rate: schedule,
            rho,
            eps,
            step_count: 0,
            current_lr: initial_lr,
            lr_resolved_for_step: Some(0),
            state: HashMap::new(),
        };
        Ok(optimizer)
    }

    /// Builds an optimizer with `rho = 0.9` and `eps = 1e-6`.
    ///
    /// # Errors
    /// Same failure modes as [`AdaDelta::new`].
    pub fn default_with_lr(learning_rate: impl Into<LearningRate>) -> Result<Self> {
        const DEFAULT_RHO: f32 = 0.9;
        const DEFAULT_EPS: f32 = 1e-6;
        Self::new(learning_rate, DEFAULT_RHO, DEFAULT_EPS)
    }

    /// Borrows the configured learning-rate source.
    #[inline(always)]
    pub fn learning_rate_ref(&self) -> &LearningRate {
        &self.learning_rate
    }

    /// Returns the decay coefficient.
    #[inline(always)]
    pub fn rho(&self) -> f32 {
        self.rho
    }

    /// Returns the stabilizing constant.
    #[inline(always)]
    pub fn eps(&self) -> f32 {
        self.eps
    }

    /// Swaps in a new learning-rate source, resolving it for the current step.
    ///
    /// # Errors
    /// Fails when the new source cannot be resolved for the current step; nothing is
    /// modified before that check passes.
    pub fn with_learning_rate(mut self, learning_rate: impl Into<LearningRate>) -> Result<Self> {
        let schedule: LearningRate = learning_rate.into();
        let at_step = self.step_count;
        let resolved = schedule.try_current(at_step)?;
        self.lr_resolved_for_step = Some(at_step);
        self.current_lr = resolved;
        self.learning_rate = schedule;
        Ok(self)
    }

    /// Replaces the decay coefficient after validating it.
    ///
    /// # Errors
    /// Fails when `rho` is non-finite or outside `[0.0, 1.0)`.
    pub fn with_rho(mut self, rho: f32) -> Result<Self> {
        validate_rho(rho).map(move |()| {
            self.rho = rho;
            self
        })
    }

    /// Replaces the stabilizing constant after validating it.
    ///
    /// # Errors
    /// Fails when `eps` is non-finite or negative.
    pub fn with_eps(mut self, eps: f32) -> Result<Self> {
        validate_eps(eps).map(move |()| {
            self.eps = eps;
            self
        })
    }
}
impl Optimizer for AdaDelta {
    /// Replaces the accumulator table with one zero-filled `(v, u)` pair per entry of `params`.
    ///
    /// The existing table is only overwritten once every pair has been created.
    fn init(&mut self, params: &Weights) -> Result<()> {
        let mut fresh = HashMap::with_capacity(params.len());
        for (name, tensor) in params {
            let squared_grad_avg = zeros_like(tensor)?;
            let squared_update_avg = zeros_like(tensor)?;
            fresh.insert(name.clone(), (squared_grad_avg, squared_update_avg));
        }
        self.state = fresh;
        Ok(())
    }

    /// Resolves the learning rate for the current step unless that was already done.
    fn preflight(&mut self) -> Result<()> {
        let step = self.step_count;
        if self.lr_resolved_for_step != Some(step) {
            self.current_lr = self.learning_rate.try_current(step)?;
            self.lr_resolved_for_step = Some(step);
        }
        Ok(())
    }

    /// Applies one AdaDelta update to every entry of `params` that has a gradient.
    ///
    /// For each key present in both maps:
    /// `v = rho*v + (1-rho)*g^2`, `d = sqrt(u+eps)/sqrt(v+eps) * g`,
    /// `u = rho*u + (1-rho)*d^2`, `w = w - lr*d`.
    ///
    /// Gradient keys missing from `params` are skipped. Keys without stored accumulators
    /// start from zeros shaped like the parameter.
    fn apply_gradients(&mut self, gradients: &Weights, params: &mut Weights) -> Result<()> {
        // Lazily create accumulators on first use, keyed by the gradient map.
        if self.state.is_empty() {
            self.init(gradients)?;
        }
        self.preflight()?;
        self.step_count += 1;

        let decay = scalar(self.rho)?;
        let blend = scalar(1.0 - self.rho)?;
        let stabilizer = scalar(self.eps)?;
        let rate = scalar(self.current_lr)?;

        for (name, grad) in gradients {
            let weight = match params.get(name) {
                Some(found) => found,
                None => continue,
            };
            let (v_old, u_old) = if let Some((v, u)) = self.state.get(name) {
                (v.try_clone()?, u.try_clone()?)
            } else {
                (zeros_like(weight)?, zeros_like(weight)?)
            };

            // Running average of squared gradients.
            let grad_sq = arithmetic::square(grad)?;
            let v_new = arithmetic::add(
                &arithmetic::multiply(&decay, &v_old)?,
                &arithmetic::multiply(&blend, &grad_sq)?,
            )?;

            // Update direction: ratio of the two root-mean-square terms times the gradient.
            let u_shifted = arithmetic::add(&u_old, &stabilizer)?;
            let v_shifted = arithmetic::add(&v_new, &stabilizer)?;
            let rms_u = arithmetic::sqrt(&u_shifted)?;
            let rms_v = arithmetic::sqrt(&v_shifted)?;
            let delta = arithmetic::multiply(&arithmetic::divide(&rms_u, &rms_v)?, grad)?;

            // Running average of squared updates; uses the previous `u`, not the new one.
            let delta_sq = arithmetic::square(&delta)?;
            let u_new = arithmetic::add(
                &arithmetic::multiply(&decay, &u_old)?,
                &arithmetic::multiply(&blend, &delta_sq)?,
            )?;

            let updated = arithmetic::subtract(weight, &arithmetic::multiply(&rate, &delta)?)?;
            params.insert(name.clone(), updated);
            self.state.insert(name.clone(), (v_new, u_new));
        }
        Ok(())
    }

    /// Returns how many update steps have been counted so far.
    fn step(&self) -> usize {
        self.step_count
    }

    /// Returns the most recently resolved learning rate.
    fn learning_rate(&self) -> f32 {
        self.current_lr
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Reads the single `f32` held by `a`; works on a clone because `item` needs `&mut`.
    fn read_scalar(a: &Array) -> Result<f32> {
        let mut clone = a.try_clone()?;
        clone.item::<f32>()
    }

    // With w = 1.0, g = 0.5, rho = 0.9, eps = 1e-6, lr = 1.0 and zeroed accumulators:
    // v = 0.025, d = sqrt(1e-6) / sqrt(0.025001) * 0.5 ~= 0.003162, so w ~= 0.996838.
    #[test]
    fn adadelta_single_step_matches_python_ref() -> Result<()> {
        let mut adadelta = AdaDelta::default_with_lr(1.0)?;
        let mut params: Weights = HashMap::new();
        params.insert("w".into(), scalar(1.0)?);
        let mut grads: Weights = HashMap::new();
        grads.insert("w".into(), scalar(0.5)?);
        adadelta.apply_gradients(&grads, &mut params)?;
        let got = read_scalar(&params["w"])?;
        assert!((got - 0.996_837).abs() < 1e-4, "got {got}");
        Ok(())
    }

    #[test]
    fn adadelta_rejects_negative_rho() {
        assert!(AdaDelta::new(1.0, -0.1, 1e-6).is_err());
    }

    #[test]
    fn adadelta_rejects_rho_at_one() {
        assert!(AdaDelta::new(1.0, 1.0, 1e-6).is_err());
    }

    #[test]
    fn adadelta_rejects_non_finite_rho() {
        assert!(AdaDelta::new(1.0, f32::NAN, 1e-6).is_err());
        assert!(AdaDelta::new(1.0, f32::INFINITY, 1e-6).is_err());
    }

    #[test]
    fn adadelta_rejects_negative_eps() {
        assert!(AdaDelta::new(1.0, 0.9, -1e-6).is_err());
    }

    #[test]
    fn adadelta_rejects_non_finite_eps() {
        assert!(AdaDelta::new(1.0, 0.9, f32::NAN).is_err());
    }

    #[test]
    fn adadelta_builder_with_rho_rejects_negative() {
        let res = AdaDelta::default_with_lr(1.0).and_then(|a| a.with_rho(-0.5));
        assert!(res.is_err());
    }

    #[test]
    fn adadelta_builder_with_rho_rejects_at_one() {
        let res = AdaDelta::default_with_lr(1.0).and_then(|a| a.with_rho(1.0));
        assert!(res.is_err());
    }

    #[test]
    fn adadelta_builder_with_rho_rejects_non_finite() {
        let res = AdaDelta::default_with_lr(1.0).and_then(|a| a.with_rho(f32::NAN));
        assert!(res.is_err());
    }

    #[test]
    fn adadelta_builder_with_eps_rejects_negative() {
        let res = AdaDelta::default_with_lr(1.0).and_then(|a| a.with_eps(-1e-6));
        assert!(res.is_err());
    }

    #[test]
    fn adadelta_builder_with_eps_rejects_non_finite() {
        let res = AdaDelta::default_with_lr(1.0).and_then(|a| a.with_eps(f32::NAN));
        assert!(res.is_err());
    }

    #[test]
    fn adadelta_with_learning_rate_rejects_fixed_nan() {
        let res = AdaDelta::default_with_lr(1.0)
            .and_then(|a| a.with_learning_rate(LearningRate::Fixed(f32::NAN)));
        assert!(res.is_err(), "with_learning_rate must reject Fixed(NaN)");
    }

    #[test]
    fn adadelta_getters_echo_inputs() -> Result<()> {
        let ad = AdaDelta::new(LearningRate::Fixed(0.5), 0.8, 1e-5)?;
        assert!(
            ad.learning_rate_ref().is_fixed(),
            "learning_rate_ref must echo the Fixed schedule"
        );
        assert_eq!(ad.rho(), 0.8);
        assert_eq!(ad.eps(), 1e-5);
        assert_eq!(ad.learning_rate(), 0.5);
        assert_eq!(ad.step(), 0);
        Ok(())
    }

    #[test]
    fn adadelta_default_with_lr_getters() -> Result<()> {
        let ad = AdaDelta::default_with_lr(1.0)?;
        assert_eq!(ad.rho(), 0.9);
        assert_eq!(ad.eps(), 1e-6);
        assert!(ad.learning_rate_ref().is_fixed());
        Ok(())
    }

    #[test]
    fn adadelta_builder_success_paths_echo() -> Result<()> {
        let ad = AdaDelta::default_with_lr(1.0)?
            .with_learning_rate(LearningRate::Fixed(0.5))?
            .with_rho(0.7)?
            .with_eps(2e-7)?;
        assert_eq!(ad.learning_rate(), 0.5);
        assert!(ad.learning_rate_ref().is_fixed());
        assert_eq!(ad.rho(), 0.7);
        assert_eq!(ad.eps(), 2e-7);
        Ok(())
    }

    // Exercises the preflight branch that re-resolves the learning rate on the second step.
    #[test]
    fn adadelta_two_steps_preflight_re_resolves() -> Result<()> {
        let mut ad = AdaDelta::default_with_lr(1.0)?;
        let mut params: Weights = HashMap::new();
        params.insert("w".into(), scalar(1.0)?);
        let mut grads: Weights = HashMap::new();
        grads.insert("w".into(), scalar(0.5)?);
        ad.init(&params)?;
        assert_eq!(ad.step(), 0);
        ad.apply_gradients(&grads, &mut params)?;
        let after_one = read_scalar(&params["w"])?;
        assert_eq!(ad.step(), 1);
        ad.apply_gradients(&grads, &mut params)?;
        let after_two = read_scalar(&params["w"])?;
        assert_eq!(ad.step(), 2);
        assert_eq!(ad.learning_rate(), 1.0);
        assert!(after_two < after_one, "weight should keep decreasing");
        Ok(())
    }

    // State is initialised for "a" only, so "b" goes through the zero-accumulator fallback.
    #[test]
    fn adadelta_step_none_state_arm_via_uninit_grad_key() -> Result<()> {
        let mut ad = AdaDelta::default_with_lr(1.0)?;
        let mut init_params: Weights = HashMap::new();
        init_params.insert("a".into(), scalar(1.0)?);
        ad.init(&init_params)?;
        assert!(
            !ad.state.is_empty(),
            "explicit init populated state for 'a'"
        );
        let mut params: Weights = HashMap::new();
        params.insert("a".into(), scalar(1.0)?);
        params.insert("b".into(), scalar(1.0)?);
        let mut grads: Weights = HashMap::new();
        grads.insert("a".into(), scalar(0.5)?);
        grads.insert("b".into(), scalar(0.5)?);
        ad.apply_gradients(&grads, &mut params)?;
        let got_b = read_scalar(&params["b"])?;
        assert!((got_b - 0.996_837).abs() < 1e-4, "b got {got_b}");
        Ok(())
    }

    #[test]
    fn adadelta_skips_grad_key_absent_from_params() -> Result<()> {
        let mut ad = AdaDelta::default_with_lr(1.0)?;
        let mut params: Weights = HashMap::new();
        params.insert("present".into(), scalar(1.0)?);
        let mut grads: Weights = HashMap::new();
        grads.insert("present".into(), scalar(0.5)?);
        grads.insert("absent".into(), scalar(0.5)?);
        ad.apply_gradients(&grads, &mut params)?;
        let got = read_scalar(&params["present"])?;
        assert!((got - 0.996_837).abs() < 1e-4, "present got {got}");
        assert!(
            !params.contains_key("absent"),
            "absent grad must not be added to params"
        );
        Ok(())
    }
}