use std::collections::HashMap;
use smol_str::format_smolstr;
use crate::{
Array, Result,
error::{Error, NonFiniteScalarPayload, OutOfRangePayload},
lm::{
load::Weights,
tuner::optimizers::base::{LearningRate, Optimizer, zeros_like, zeros_like_map},
},
ops::arithmetic,
};
fn validate_alpha(alpha: f32) -> Result<()> {
if !alpha.is_finite() {
return Err(Error::NonFiniteScalar(NonFiniteScalarPayload::new(
"RMSprop: alpha",
alpha as f64,
)));
}
if !(0.0..1.0).contains(&alpha) {
return Err(Error::OutOfRange(OutOfRangePayload::new(
"RMSprop: alpha",
"must be in [0.0, 1.0)",
format_smolstr!("{alpha}"),
)));
}
Ok(())
}
fn validate_eps(eps: f32) -> Result<()> {
if !eps.is_finite() {
return Err(Error::NonFiniteScalar(NonFiniteScalarPayload::new(
"RMSprop: eps",
eps as f64,
)));
}
if eps < 0.0 {
return Err(Error::OutOfRange(OutOfRangePayload::new(
"RMSprop: eps",
"must be >= 0.0",
format_smolstr!("{eps}"),
)));
}
Ok(())
}
/// Wraps `v` in an `f32` array with an empty shape (a rank-0 array).
///
/// Used to turn hyperparameters into arrays that the `arithmetic` ops accept.
fn scalar(v: f32) -> Result<Array> {
    let rank0_shape: [i32; 0] = [];
    Array::full::<f32>(&rank0_shape, v)
}
/// RMSprop optimizer: hyperparameters, learning-rate bookkeeping and per-key state.
///
/// Each update divides the learning-rate-scaled gradient by the square root of a
/// running average of squared gradients (plus `eps`).
pub struct RMSprop {
    // --- hyperparameters ---
    /// Decay factor of the squared-gradient running average (constructors require `[0.0, 1.0)`).
    alpha: f32,
    /// Added to the square root of the running average in the update denominator
    /// (constructors require finite and `>= 0.0`).
    eps: f32,

    // --- learning rate ---
    /// Learning-rate source (e.g. `LearningRate::Fixed`), queried per step via `try_current`.
    learning_rate: LearningRate,
    /// Learning rate most recently resolved from `learning_rate`.
    current_lr: f32,
    /// Step index `current_lr` was resolved for; lets repeated resolution for the
    /// same step be skipped.
    lr_resolved_for_step: Option<usize>,

    // --- optimization state ---
    /// Step counter, incremented by `apply_gradients`.
    step_count: usize,
    /// Running average of squared gradients, keyed by parameter name.
    state: HashMap<String, Array>,
}
impl RMSprop {
    /// Smoothing constant used by [`RMSprop::default_with_lr`].
    pub const DEFAULT_ALPHA: f32 = 0.99;

    /// Denominator epsilon used by [`RMSprop::default_with_lr`].
    pub const DEFAULT_EPS: f32 = 1e-8;

    /// Creates an RMSprop optimizer.
    ///
    /// The learning rate is resolved for step 0 immediately, so a schedule that cannot
    /// produce a value is reported here rather than on the first update.
    ///
    /// # Errors
    ///
    /// Returns an error if `alpha` is non-finite or outside `[0.0, 1.0)`, if `eps` is
    /// non-finite or negative, or if `learning_rate` cannot produce a value for step 0.
    pub fn new(learning_rate: impl Into<LearningRate>, alpha: f32, eps: f32) -> Result<Self> {
        validate_alpha(alpha)?;
        validate_eps(eps)?;
        let lr: LearningRate = learning_rate.into();
        let current_lr = lr.try_current(0)?;
        Ok(Self {
            learning_rate: lr,
            alpha,
            eps,
            step_count: 0,
            current_lr,
            lr_resolved_for_step: Some(0),
            state: HashMap::new(),
        })
    }

    /// Creates an optimizer with [`Self::DEFAULT_ALPHA`] (0.99) and
    /// [`Self::DEFAULT_EPS`] (1e-8).
    ///
    /// # Errors
    ///
    /// Returns an error if `learning_rate` cannot produce a value for step 0.
    pub fn default_with_lr(learning_rate: impl Into<LearningRate>) -> Result<Self> {
        Self::new(learning_rate, Self::DEFAULT_ALPHA, Self::DEFAULT_EPS)
    }

    /// Returns the configured learning-rate source.
    #[inline]
    pub fn learning_rate_ref(&self) -> &LearningRate {
        &self.learning_rate
    }

    /// Returns the smoothing constant of the squared-gradient running average.
    #[inline]
    pub fn alpha(&self) -> f32 {
        self.alpha
    }

    /// Returns the epsilon added to the update denominator.
    #[inline]
    pub fn eps(&self) -> f32 {
        self.eps
    }

    /// Replaces the learning-rate source, resolving it at the current step.
    ///
    /// # Errors
    ///
    /// Returns an error if the new source cannot produce a value for the current step.
    /// The optimizer is consumed in that case.
    pub fn with_learning_rate(mut self, learning_rate: impl Into<LearningRate>) -> Result<Self> {
        let lr: LearningRate = learning_rate.into();
        let current_lr = lr.try_current(self.step_count)?;
        self.learning_rate = lr;
        self.current_lr = current_lr;
        self.lr_resolved_for_step = Some(self.step_count);
        Ok(self)
    }

    /// Replaces the smoothing constant.
    ///
    /// # Errors
    ///
    /// Returns an error if `alpha` is non-finite or outside `[0.0, 1.0)`.
    pub fn with_alpha(mut self, alpha: f32) -> Result<Self> {
        validate_alpha(alpha)?;
        self.alpha = alpha;
        Ok(self)
    }

    /// Replaces the denominator epsilon.
    ///
    /// # Errors
    ///
    /// Returns an error if `eps` is non-finite or negative.
    pub fn with_eps(mut self, eps: f32) -> Result<Self> {
        validate_eps(eps)?;
        self.eps = eps;
        Ok(self)
    }
}
impl Optimizer for RMSprop {
    /// Replaces the squared-gradient state with zeros shaped like `params`.
    ///
    /// Any previously accumulated state is discarded.
    fn init(&mut self, params: &Weights) -> Result<()> {
        self.state = zeros_like_map(params)?;
        Ok(())
    }

    /// Resolves the learning rate for the current step, at most once per step.
    ///
    /// The resolved value is cached in `current_lr` and tagged with the step it belongs
    /// to, so repeated calls for the same step do not query the schedule again.
    fn preflight(&mut self) -> Result<()> {
        if self.lr_resolved_for_step == Some(self.step_count) {
            return Ok(());
        }
        self.current_lr = self.learning_rate.try_current(self.step_count)?;
        self.lr_resolved_for_step = Some(self.step_count);
        Ok(())
    }

    /// Applies one RMSprop step to every parameter that has a gradient.
    ///
    /// For each key present in both `gradients` and `params`:
    ///
    /// ```text
    /// v <- alpha * v + (1 - alpha) * g^2
    /// w <- w - lr * g / (sqrt(v) + eps)
    /// ```
    ///
    /// Gradient keys with no matching parameter are skipped. A key with no stored state
    /// starts from a zero accumulator shaped like its parameter. If the state is empty it
    /// is first initialised from `gradients`.
    ///
    /// # Errors
    ///
    /// Propagates errors from state initialisation, learning-rate resolution and the
    /// array arithmetic. All new values are computed before anything is written back.
    /// On error, `params`, the accumulated state and the step counter are therefore left
    /// unchanged. The only exception is the lazy state initialisation above.
    fn apply_gradients(&mut self, gradients: &Weights, params: &mut Weights) -> Result<()> {
        if self.state.is_empty() {
            self.init(gradients)?;
        }
        self.preflight()?;

        // Hyperparameters as arrays, built once per step and reused for every key.
        let alpha_s = scalar(self.alpha)?;
        let one_minus_alpha = scalar(1.0 - self.alpha)?;
        let eps_s = scalar(self.eps)?;
        let lr_s = scalar(self.current_lr)?;

        // Stage every update first, so an error on one key cannot leave `params` and
        // `state` half-updated or the step counter advanced.
        let mut staged = Vec::with_capacity(gradients.len());
        for (key, grad) in gradients {
            let Some(param) = params.get(key) else {
                continue;
            };

            // Borrow the existing accumulator. Zeros are only allocated on a state miss.
            let fresh_v;
            let prev_v = match self.state.get(key) {
                Some(v) => v,
                None => {
                    fresh_v = zeros_like(param)?;
                    &fresh_v
                }
            };

            let g_sq = arithmetic::square(grad)?;
            let v_scaled = arithmetic::multiply(&alpha_s, prev_v)?;
            let g_sq_scaled = arithmetic::multiply(&one_minus_alpha, &g_sq)?;
            let v_new = arithmetic::add(&v_scaled, &g_sq_scaled)?;

            let lr_g = arithmetic::multiply(&lr_s, grad)?;
            let sqrt_v = arithmetic::sqrt(&v_new)?;
            let denom = arithmetic::add(&sqrt_v, &eps_s)?;
            let step_term = arithmetic::divide(&lr_g, &denom)?;
            let new_w = arithmetic::subtract(param, &step_term)?;

            staged.push((key, new_w, v_new));
        }

        // Commit: every key succeeded, so write back and advance the step together.
        for (key, new_w, v_new) in staged {
            params.insert(key.clone(), new_w);
            self.state.insert(key.clone(), v_new);
        }
        self.step_count += 1;
        Ok(())
    }

    /// Returns the step counter.
    fn step(&self) -> usize {
        self.step_count
    }

    /// Returns the most recently resolved learning rate.
    fn learning_rate(&self) -> f32 {
        self.current_lr
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Extracts the `f32` value of `a` via `item`.
    ///
    /// `item` is called on a mutable clone, so `a` itself is left untouched.
    fn read_scalar(a: &Array) -> Result<f32> {
        let mut clone = a.try_clone()?;
        clone.item::<f32>()
    }

    #[test]
    fn rmsprop_single_step_matches_python_ref() -> Result<()> {
        // Starting from v = 0 with alpha = 0.99, g = 0.5, lr = 0.001:
        //   v = (1 - 0.99) * 0.5^2 = 0.0025, sqrt(v) = 0.05,
        //   step = 0.001 * 0.5 / (0.05 + 1e-8) ~= 0.01, so w ~= 1.0 - 0.01 = 0.99.
        let mut rms = RMSprop::default_with_lr(0.001)?;
        let mut params: Weights = HashMap::new();
        params.insert("w".into(), scalar(1.0)?);
        let mut grads: Weights = HashMap::new();
        grads.insert("w".into(), scalar(0.5)?);
        rms.apply_gradients(&grads, &mut params)?;
        let got = read_scalar(&params["w"])?;
        assert!((got - 0.99).abs() < 1e-4, "got {got}");
        Ok(())
    }

    #[test]
    fn rmsprop_rejects_negative_alpha() {
        assert!(RMSprop::new(0.001, -0.1, 1e-8).is_err());
    }

    #[test]
    fn rmsprop_new_rejects_alpha_above_one() {
        assert!(RMSprop::new(0.001, 1.0, 1e-8).is_err());
        assert!(RMSprop::new(0.001, 1.5, 1e-8).is_err());
    }

    #[test]
    fn rmsprop_new_rejects_nan_alpha() {
        assert!(RMSprop::new(0.001, f32::NAN, 1e-8).is_err());
    }

    #[test]
    fn rmsprop_rejects_negative_eps() {
        assert!(RMSprop::new(0.001, 0.99, -1e-8).is_err());
    }

    #[test]
    fn rmsprop_new_rejects_nan_eps() {
        assert!(RMSprop::new(0.001, 0.99, f32::NAN).is_err());
    }

    #[test]
    fn rmsprop_builder_with_alpha_rejects_negative() {
        let res = RMSprop::default_with_lr(0.001).and_then(|r| r.with_alpha(-1.0));
        assert!(res.is_err());
    }

    #[test]
    fn rmsprop_with_alpha_rejects_above_one() {
        let res = RMSprop::default_with_lr(0.001).and_then(|r| r.with_alpha(1.0));
        assert!(res.is_err());
    }

    #[test]
    fn rmsprop_with_alpha_rejects_nan() {
        let res = RMSprop::default_with_lr(0.001).and_then(|r| r.with_alpha(f32::NAN));
        assert!(res.is_err());
    }

    #[test]
    fn rmsprop_builder_with_eps_rejects_negative() {
        let res = RMSprop::default_with_lr(0.001).and_then(|r| r.with_eps(-1.0));
        assert!(res.is_err());
    }

    #[test]
    fn rmsprop_with_eps_rejects_nan() {
        let res = RMSprop::default_with_lr(0.001).and_then(|r| r.with_eps(f32::NAN));
        assert!(res.is_err());
    }

    #[test]
    fn rmsprop_with_learning_rate_rejects_fixed_nan() {
        let res = RMSprop::default_with_lr(0.001)
            .and_then(|r| r.with_learning_rate(LearningRate::Fixed(f32::NAN)));
        assert!(res.is_err(), "with_learning_rate must reject Fixed(NaN)");
    }

    #[test]
    fn rmsprop_getters_echo_inputs() -> Result<()> {
        let rms = RMSprop::new(LearningRate::Fixed(0.02), 0.95, 1e-6)?;
        assert!(
            rms.learning_rate_ref().is_fixed(),
            "learning_rate_ref must echo the Fixed schedule"
        );
        assert_eq!(rms.alpha(), 0.95);
        assert_eq!(rms.eps(), 1e-6);
        assert_eq!(rms.learning_rate(), 0.02);
        assert_eq!(rms.step(), 0);
        Ok(())
    }

    #[test]
    fn rmsprop_default_with_lr_getters() -> Result<()> {
        let rms = RMSprop::default_with_lr(0.001)?;
        assert_eq!(rms.alpha(), 0.99);
        assert_eq!(rms.eps(), 1e-8);
        assert!(rms.learning_rate_ref().is_fixed());
        Ok(())
    }

    #[test]
    fn rmsprop_builder_success_paths_echo() -> Result<()> {
        let rms = RMSprop::default_with_lr(0.001)?
            .with_learning_rate(LearningRate::Fixed(0.05))?
            .with_alpha(0.9)?
            .with_eps(2e-7)?;
        assert_eq!(rms.learning_rate(), 0.05);
        assert!(rms.learning_rate_ref().is_fixed());
        assert_eq!(rms.alpha(), 0.9);
        assert_eq!(rms.eps(), 2e-7);
        Ok(())
    }

    #[test]
    fn rmsprop_two_steps_preflight_re_resolves() -> Result<()> {
        let mut rms = RMSprop::default_with_lr(0.001)?;
        let mut params: Weights = HashMap::new();
        params.insert("w".into(), scalar(1.0)?);
        let mut grads: Weights = HashMap::new();
        grads.insert("w".into(), scalar(0.5)?);
        rms.init(&params)?;
        assert_eq!(rms.step(), 0);
        rms.apply_gradients(&grads, &mut params)?;
        let after_one = read_scalar(&params["w"])?;
        assert_eq!(rms.step(), 1);
        rms.apply_gradients(&grads, &mut params)?;
        let after_two = read_scalar(&params["w"])?;
        assert_eq!(rms.step(), 2);
        assert_eq!(rms.learning_rate(), 0.001);
        assert!(after_two < after_one, "weight should keep decreasing");
        Ok(())
    }

    #[test]
    fn rmsprop_step_none_state_arm_via_uninit_grad_key() -> Result<()> {
        // State is initialised for "a" only, so "b" exercises the missing-state path
        // and must behave like a fresh single step (w ~= 0.99).
        let mut rms = RMSprop::default_with_lr(0.001)?;
        let mut init_params: Weights = HashMap::new();
        init_params.insert("a".into(), scalar(1.0)?);
        rms.init(&init_params)?;
        assert!(
            !rms.state.is_empty(),
            "explicit init populated state for 'a'"
        );
        let mut params: Weights = HashMap::new();
        params.insert("a".into(), scalar(1.0)?);
        params.insert("b".into(), scalar(1.0)?);
        let mut grads: Weights = HashMap::new();
        grads.insert("a".into(), scalar(0.5)?);
        grads.insert("b".into(), scalar(0.5)?);
        rms.apply_gradients(&grads, &mut params)?;
        let got_b = read_scalar(&params["b"])?;
        assert!((got_b - 0.99).abs() < 1e-4, "b got {got_b}");
        Ok(())
    }

    #[test]
    fn rmsprop_skips_grad_key_absent_from_params() -> Result<()> {
        let mut rms = RMSprop::default_with_lr(0.001)?;
        let mut params: Weights = HashMap::new();
        params.insert("present".into(), scalar(1.0)?);
        let mut grads: Weights = HashMap::new();
        grads.insert("present".into(), scalar(0.5)?);
        grads.insert("absent".into(), scalar(0.5)?);
        rms.apply_gradients(&grads, &mut params)?;
        let got = read_scalar(&params["present"])?;
        assert!((got - 0.99).abs() < 1e-4, "present got {got}");
        assert!(
            !params.contains_key("absent"),
            "absent grad must not be added to params"
        );
        Ok(())
    }
}