// mlxrs 0.1.0
//
// Safe Rust bindings for Apple's MLX array framework, with LM, VLM, audio, and
// embeddings support.
//! [`Adagrad`] — cumulative-squared-gradients normalization (Duchi et al.,
//! 2011).
//!
//! Mirrors Python `mlx.optimizers.Adagrad`
//! (`mlx/python/mlx/optimizers/optimizers.py:353..=400`).
//!
//! Update formula:
//!
//! ```text
//! v = v + g²
//! w_new = w - lr·g / (sqrt(v) + eps)
//! ```
//!
//! Per-parameter state: a single `v` Array (Python `state["v"]`).

use std::collections::HashMap;

use smol_str::format_smolstr;

use crate::{
  Array, Result,
  error::{Error, NonFiniteScalarPayload, OutOfRangePayload},
  lm::{
    load::Weights,
    tuner::optimizers::base::{LearningRate, Optimizer, zeros_like, zeros_like_map},
  },
  ops::arithmetic,
};

/// Validate `eps` is finite and `>= 0.0`.
fn validate_eps(eps: f32) -> Result<()> {
  if !eps.is_finite() {
    return Err(Error::NonFiniteScalar(NonFiniteScalarPayload::new(
      "Adagrad: eps",
      eps as f64,
    )));
  }
  if eps < 0.0 {
    return Err(Error::OutOfRange(OutOfRangePayload::new(
      "Adagrad: eps",
      "must be >= 0.0",
      format_smolstr!("{eps}"),
    )));
  }
  Ok(())
}

/// Build a 0-d (shape `[]`) `f32` array holding `v`.
fn scalar(v: f32) -> Result<Array> {
  let empty_shape: [i32; 0] = [];
  Array::full::<f32>(&empty_shape, v)
}

/// Adagrad optimizer.
///
/// Each parameter keeps a running sum `v` of its squared gradients, and its
/// step size is `lr / (sqrt(v) + eps)`. Parameters with a large gradient
/// history therefore take smaller steps.
///
/// Build one with [`Adagrad::new`] or [`Adagrad::default_with_lr`]. Both
/// validate `eps` and resolve the learning rate for step 0.
pub struct Adagrad {
  /// Learning rate `λ`: a fixed value or a schedule resolved per step.
  learning_rate: LearningRate,
  /// Numerical-stability epsilon added to `sqrt(v)`. Validated to be finite
  /// and `>= 0.0`. Default Python: `1e-8`.
  eps: f32,
  /// Incremented once per `apply_gradients` call; reported by `step()`.
  step_count: usize,
  /// Learning rate resolved for the step recorded in `lr_resolved_for_step`;
  /// reported by `learning_rate()`.
  current_lr: f32,
  /// Skip-if-fresh stamp — `Some(N)` means `current_lr` is valid for step N.
  lr_resolved_for_step: Option<usize>,
  /// Per-parameter accumulator `v` (Python `state["v"]`), keyed by the same
  /// names as the `Weights` maps passed to `apply_gradients`.
  state: HashMap<String, Array>,
}

impl Adagrad {
  /// Create an [`Adagrad`] optimizer.
  ///
  /// `learning_rate` may be a fixed value or a schedule. It is resolved for
  /// step 0 immediately.
  ///
  /// # Errors
  ///
  /// Fails if `eps` is non-finite or negative, or if resolving the learning
  /// rate for step 0 fails (e.g. a non-finite value).
  pub fn new(learning_rate: impl Into<LearningRate>, eps: f32) -> Result<Self> {
    validate_eps(eps)?;
    let learning_rate: LearningRate = learning_rate.into();
    // Resolving here consumes the step-0 schedule slot. The cache is stamped
    // for step 0 so the first `preflight()` does not call a stateful schedule
    // a second time.
    let current_lr = learning_rate.try_current(0)?;
    Ok(Self {
      current_lr,
      lr_resolved_for_step: Some(0),
      step_count: 0,
      state: HashMap::new(),
      eps,
      learning_rate,
    })
  }

  /// Create an [`Adagrad`] optimizer with Python's default `eps = 1e-8`.
  ///
  /// # Errors
  ///
  /// Same as [`Adagrad::new`].
  pub fn default_with_lr(learning_rate: impl Into<LearningRate>) -> Result<Self> {
    const PYTHON_DEFAULT_EPS: f32 = 1e-8;
    Self::new(learning_rate, PYTHON_DEFAULT_EPS)
  }

  /// Borrow the configured learning rate (fixed value or schedule).
  #[inline]
  pub fn learning_rate_ref(&self) -> &LearningRate {
    &self.learning_rate
  }

  /// The numerical-stability epsilon added to `sqrt(v)`.
  #[inline]
  pub fn eps(&self) -> f32 {
    self.eps
  }

  /// Replace the learning rate (or schedule), resolving it at the current step.
  ///
  /// # Errors
  ///
  /// Fails if the value resolved at the current step is non-finite.
  pub fn with_learning_rate(mut self, learning_rate: impl Into<LearningRate>) -> Result<Self> {
    let schedule: LearningRate = learning_rate.into();
    let step = self.step_count;
    self.current_lr = schedule.try_current(step)?;
    self.learning_rate = schedule;
    self.lr_resolved_for_step = Some(step);
    Ok(self)
  }

  /// Replace epsilon.
  ///
  /// # Errors
  ///
  /// Fails if `eps` is non-finite or `< 0.0`.
  pub fn with_eps(mut self, eps: f32) -> Result<Self> {
    validate_eps(eps).map(move |()| {
      self.eps = eps;
      self
    })
  }
}

impl Optimizer for Adagrad {
  /// (Re)initialize the per-parameter state: every entry of `params` gets a
  /// zero `v` accumulator. Any state from earlier steps is discarded.
  fn init(&mut self, params: &Weights) -> Result<()> {
    self.state = zeros_like_map(params)?;
    Ok(())
  }

  /// Resolve the learning rate for the current step, consulting the schedule
  /// at most once per step.
  ///
  /// The constructor, `with_learning_rate`, and earlier `preflight` calls all
  /// stamp `lr_resolved_for_step`. When that stamp already matches
  /// `step_count`, this is a no-op, so stateful schedules are not advanced
  /// twice.
  fn preflight(&mut self) -> Result<()> {
    if self.lr_resolved_for_step == Some(self.step_count) {
      return Ok(()); // cache hit: schedule already consulted at this step
    }
    self.current_lr = self.learning_rate.try_current(self.step_count)?;
    self.lr_resolved_for_step = Some(self.step_count);
    Ok(())
  }

  /// Apply one Adagrad update to every parameter that has a gradient:
  ///
  /// ```text
  /// v     = v + g²
  /// w_new = w - lr·g / (sqrt(v) + eps)
  /// ```
  ///
  /// State is lazily initialized from `gradients` when it is empty. Gradients
  /// whose key is absent from `params` are ignored.
  ///
  /// # Errors
  ///
  /// Propagates learning-rate resolution and array-op errors. Every update is
  /// computed before anything is committed, so on error `params`, the stored
  /// `v` values and the step counter are not modified. The lazy state
  /// initialization and this step's learning-rate resolution may already have
  /// happened.
  fn apply_gradients(&mut self, gradients: &Weights, params: &mut Weights) -> Result<()> {
    if self.state.is_empty() {
      self.init(gradients)?;
    }
    // Resolve scheduled LR via skip-if-fresh cache (no-op if MultiOptimizer
    // already preflighted this step). Matches Python `optimizers.py:102..=106`.
    self.preflight()?;
    // NOTE(review): Python casts the learning rate to the gradient dtype and
    // passes `eps` as a plain Python float. Here both are 0-d f32 arrays, which
    // may promote half-precision params/state to f32. Confirm the MLX promotion
    // rules and cast to `grad`'s dtype if so.
    let eps_s = scalar(self.eps)?;
    let lr_s = scalar(self.current_lr)?;

    // Phase 1: build every new (w, v) pair without touching `params`,
    // `self.state` or the step counter.
    let mut updates = Vec::with_capacity(gradients.len());
    for (key, grad) in gradients {
      let Some(param) = params.get(key) else {
        continue;
      };
      // v = v + g² (borrow the stored accumulator; a missing entry counts as zeros)
      let g_sq = arithmetic::square(grad)?;
      let v_new = match self.state.get(key) {
        Some(prev_v) => arithmetic::add(prev_v, &g_sq)?,
        None => arithmetic::add(&zeros_like(param)?, &g_sq)?,
      };
      // w_new = w - lr·g / (sqrt(v) + eps)
      let lr_g = arithmetic::multiply(&lr_s, grad)?;
      let sqrt_v = arithmetic::sqrt(&v_new)?;
      let denom = arithmetic::add(&sqrt_v, &eps_s)?;
      let step_term = arithmetic::divide(&lr_g, &denom)?;
      let new_w = arithmetic::subtract(param, &step_term)?;
      updates.push((key, new_w, v_new));
    }

    // Phase 2: commit. Nothing below can fail.
    self.step_count += 1;
    for (key, new_w, v_new) in updates {
      params.insert(key.clone(), new_w);
      self.state.insert(key.clone(), v_new);
    }
    Ok(())
  }

  /// Number of `apply_gradients` calls that have committed an update.
  fn step(&self) -> usize {
    self.step_count
  }

  /// The learning rate most recently resolved by the constructor,
  /// `with_learning_rate`, or `preflight`.
  fn learning_rate(&self) -> f32 {
    self.current_lr
  }
}

#[cfg(test)]
mod tests {
  use super::*;

  /// Read the single `f32` held by a 0-d array.
  fn read_scalar(a: &Array) -> Result<f32> {
    let mut clone = a.try_clone()?;
    clone.item::<f32>()
  }

  /// Build `{"w": 1.0}` params and `{"w": 0.5}` gradients.
  fn w_param_and_grad() -> Result<(Weights, Weights)> {
    let mut params: Weights = HashMap::new();
    params.insert("w".into(), scalar(1.0)?);
    let mut grads: Weights = HashMap::new();
    grads.insert("w".into(), scalar(0.5)?);
    Ok((params, grads))
  }

  #[test]
  fn adagrad_single_step_matches_python_ref() -> Result<()> {
    // Python first step: v = g², w_new = w - lr·g / (sqrt(v) + eps).
    // w=1.0, g=0.5, lr=0.1, eps=1e-8
    //   v = 0.25
    //   sqrt(v) = 0.5
    //   step = 0.1·0.5 / (0.5 + 1e-8) ≈ 0.1
    //   w_new ≈ 0.9
    let mut adagrad = Adagrad::default_with_lr(0.1)?;
    let (mut params, grads) = w_param_and_grad()?;
    adagrad.apply_gradients(&grads, &mut params)?;
    let got = read_scalar(&params["w"])?;
    assert!((got - 0.9).abs() < 1e-4, "got {got}");
    Ok(())
  }

  #[test]
  fn adagrad_two_steps_accumulate_squared_gradients() -> Result<()> {
    // Step 1: v = 0.25, w = 1.0 - 0.1·0.5/0.5       = 0.9
    // Step 2: v = 0.50, w = 0.9 - 0.1·0.5/sqrt(0.5) ≈ 0.829_289
    let mut adagrad = Adagrad::default_with_lr(0.1)?;
    let (mut params, grads) = w_param_and_grad()?;
    adagrad.apply_gradients(&grads, &mut params)?;
    adagrad.apply_gradients(&grads, &mut params)?;
    let got = read_scalar(&params["w"])?;
    assert!((got - 0.829_289).abs() < 1e-4, "got {got}");
    let v = read_scalar(&adagrad.state["w"])?;
    assert!((v - 0.5).abs() < 1e-6, "v = {v}");
    assert_eq!(adagrad.step(), 2);
    Ok(())
  }

  #[test]
  fn adagrad_ignores_gradients_without_matching_param() -> Result<()> {
    let mut adagrad = Adagrad::default_with_lr(0.1)?;
    let (mut params, mut grads) = w_param_and_grad()?;
    grads.insert("orphan".into(), scalar(0.5)?);
    adagrad.apply_gradients(&grads, &mut params)?;
    assert!(!params.contains_key("orphan"));
    assert_eq!(params.len(), 1);
    Ok(())
  }

  #[test]
  fn adagrad_rejects_negative_eps() {
    assert!(Adagrad::new(0.001, -1e-8).is_err());
  }

  #[test]
  fn adagrad_new_rejects_nan_eps() {
    assert!(Adagrad::new(0.001, f32::NAN).is_err());
  }

  #[test]
  fn adagrad_new_rejects_inf_eps() {
    assert!(Adagrad::new(0.001, f32::INFINITY).is_err());
  }

  #[test]
  fn adagrad_new_accepts_zero_eps() {
    assert!(Adagrad::new(0.001, 0.0).is_ok());
  }

  #[test]
  fn adagrad_builder_with_eps_rejects_negative() {
    let res = Adagrad::default_with_lr(0.1).and_then(|a| a.with_eps(-1e-8));
    assert!(res.is_err());
  }

  #[test]
  fn adagrad_with_eps_rejects_nan() {
    let res = Adagrad::default_with_lr(0.1).and_then(|a| a.with_eps(f32::NAN));
    assert!(res.is_err());
  }

  #[test]
  fn adagrad_with_eps_rejects_inf() {
    let res = Adagrad::default_with_lr(0.1).and_then(|a| a.with_eps(f32::INFINITY));
    assert!(res.is_err());
  }

  #[test]
  fn adagrad_with_learning_rate_rejects_fixed_nan() {
    let res = Adagrad::default_with_lr(0.1)
      .and_then(|a| a.with_learning_rate(LearningRate::Fixed(f32::NAN)));
    assert!(res.is_err(), "with_learning_rate must reject Fixed(NaN)");
  }
}