#![allow(clippy::suboptimal_flops, clippy::neg_cmp_op_on_partial_ord)]
extern crate alloc;
use alloc::boxed::Box;
use alloc::vec;
#[derive(Debug, Clone)]
pub struct AdamF64 {
params: Box<[f64]>,
m: Box<[f64]>,
v: Box<[f64]>,
learning_rate: f64,
beta1: f64,
beta2: f64,
epsilon: f64,
weight_decay: f64,
beta1_t: f64,
beta2_t: f64,
dims: usize,
count: u64,
}
#[derive(Debug, Clone)]
pub struct AdamF64Builder {
dimensions: Option<usize>,
learning_rate: Option<f64>,
beta1: Option<f64>,
beta2: Option<f64>,
epsilon: Option<f64>,
weight_decay: Option<f64>,
}
impl AdamF64 {
#[inline]
#[must_use]
pub fn builder() -> AdamF64Builder {
AdamF64Builder {
dimensions: Option::None,
learning_rate: Option::None,
beta1: Option::None,
beta2: Option::None,
epsilon: Option::None,
weight_decay: Option::None,
}
}
#[inline]
pub fn step(&mut self, gradient: &[f64]) -> Result<(), crate::DataError> {
assert_eq!(
gradient.len(),
self.dims,
"gradient length {} != dimensions {}",
gradient.len(),
self.dims,
);
for &g in gradient {
check_finite!(g);
}
let one_minus_b1 = 1.0 - self.beta1;
let one_minus_b2 = 1.0 - self.beta2;
for i in 0..self.dims {
self.m[i] = self.beta1 * self.m[i] + one_minus_b1 * gradient[i];
self.v[i] = self.beta2 * self.v[i] + one_minus_b2 * gradient[i] * gradient[i];
}
self.beta1_t *= self.beta1;
self.beta2_t *= self.beta2;
let m_correction = 1.0 / (1.0 - self.beta1_t);
let v_correction = 1.0 / (1.0 - self.beta2_t);
for i in 0..self.dims {
let m_hat = self.m[i] * m_correction;
let v_hat = self.v[i] * v_correction;
self.params[i] -=
self.learning_rate * m_hat / (crate::math::sqrt(v_hat) + self.epsilon);
if self.weight_decay > 0.0 {
self.params[i] -= self.weight_decay * self.params[i];
}
}
self.count += 1;
Ok(())
}
#[inline]
#[must_use]
pub fn parameters(&self) -> &[f64] {
&self.params
}
#[inline]
#[must_use]
pub fn parameter(&self, index: usize) -> f64 {
self.params[index]
}
#[inline]
pub fn set_parameters(&mut self, params: &[f64]) {
assert_eq!(
params.len(),
self.dims,
"params length {} != dimensions {}",
params.len(),
self.dims,
);
self.params.copy_from_slice(params);
}
#[inline]
#[must_use]
pub fn first_moment(&self) -> &[f64] {
&self.m
}
#[inline]
#[must_use]
pub fn second_moment(&self) -> &[f64] {
&self.v
}
#[inline]
#[must_use]
pub fn dimensions(&self) -> usize {
self.dims
}
#[inline]
#[must_use]
pub fn learning_rate(&self) -> f64 {
self.learning_rate
}
#[inline]
#[must_use]
pub fn beta1(&self) -> f64 {
self.beta1
}
#[inline]
#[must_use]
pub fn beta2(&self) -> f64 {
self.beta2
}
#[inline]
#[must_use]
pub fn epsilon(&self) -> f64 {
self.epsilon
}
#[inline]
#[must_use]
pub fn weight_decay(&self) -> f64 {
self.weight_decay
}
#[inline]
#[must_use]
pub fn count(&self) -> u64 {
self.count
}
#[inline]
#[must_use]
pub fn is_primed(&self) -> bool {
self.count > 0
}
#[inline]
pub fn reset(&mut self) {
self.params.fill(0.0);
self.m.fill(0.0);
self.v.fill(0.0);
self.beta1_t = 1.0;
self.beta2_t = 1.0;
self.count = 0;
}
}
impl AdamF64Builder {
#[inline]
#[must_use]
pub fn dimensions(mut self, dims: usize) -> Self {
self.dimensions = Option::Some(dims);
self
}
#[inline]
#[must_use]
pub fn learning_rate(mut self, lr: f64) -> Self {
self.learning_rate = Option::Some(lr);
self
}
#[inline]
#[must_use]
pub fn beta1(mut self, b1: f64) -> Self {
self.beta1 = Option::Some(b1);
self
}
#[inline]
#[must_use]
pub fn beta2(mut self, b2: f64) -> Self {
self.beta2 = Option::Some(b2);
self
}
#[inline]
#[must_use]
pub fn epsilon(mut self, eps: f64) -> Self {
self.epsilon = Option::Some(eps);
self
}
#[inline]
#[must_use]
pub fn weight_decay(mut self, wd: f64) -> Self {
self.weight_decay = Option::Some(wd);
self
}
#[inline]
pub fn build(self) -> Result<AdamF64, crate::ConfigError> {
let dims = self
.dimensions
.ok_or(crate::ConfigError::Missing("dimensions"))?;
let lr = self
.learning_rate
.ok_or(crate::ConfigError::Missing("learning_rate"))?;
let b1 = self.beta1.unwrap_or(0.9);
let b2 = self.beta2.unwrap_or(0.999);
let eps = self.epsilon.unwrap_or(1e-8);
let wd = self.weight_decay.unwrap_or(0.0);
if dims < 1 {
return Err(crate::ConfigError::Invalid("dimensions must be >= 1"));
}
if !(lr > 0.0) {
return Err(crate::ConfigError::Invalid(
"learning_rate must be positive",
));
}
if !(b1 > 0.0 && b1 < 1.0) {
return Err(crate::ConfigError::Invalid("beta1 must be in (0, 1)"));
}
if !(b2 > 0.0 && b2 < 1.0) {
return Err(crate::ConfigError::Invalid("beta2 must be in (0, 1)"));
}
if !(eps > 0.0) {
return Err(crate::ConfigError::Invalid("epsilon must be positive"));
}
if !(wd >= 0.0) {
return Err(crate::ConfigError::Invalid(
"weight_decay must be non-negative",
));
}
Ok(AdamF64 {
params: vec![0.0; dims].into_boxed_slice(),
m: vec![0.0; dims].into_boxed_slice(),
v: vec![0.0; dims].into_boxed_slice(),
learning_rate: lr,
beta1: b1,
beta2: b2,
epsilon: eps,
weight_decay: wd,
beta1_t: 1.0,
beta2_t: 1.0,
dims,
count: 0,
})
}
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn minimizes_quadratic() {
let mut adam = AdamF64::builder()
.dimensions(2)
.learning_rate(0.1)
.build()
.unwrap();
adam.set_parameters(&[5.0, -3.0]);
for _ in 0..2000 {
let p = adam.parameters();
let grad = [2.0 * p[0], 2.0 * p[1]];
adam.step(&grad).unwrap();
}
assert!(
adam.parameters()[0].abs() < 0.01,
"x = {}",
adam.parameters()[0]
);
assert!(
adam.parameters()[1].abs() < 0.01,
"y = {}",
adam.parameters()[1]
);
}
#[test]
fn minimizes_rosenbrock() {
let mut adam = AdamF64::builder()
.dimensions(2)
.learning_rate(0.001)
.build()
.unwrap();
adam.set_parameters(&[-1.0, -1.0]);
for _ in 0..50_000 {
let p = adam.parameters();
let x = p[0];
let y = p[1];
let dx = -2.0 * (1.0 - x) - 400.0 * x * (y - x * x);
let dy = 200.0 * (y - x * x);
adam.step(&[dx, dy]).unwrap();
}
let p = adam.parameters();
assert!((p[0] - 1.0).abs() < 0.1, "x = {}, expected ~1.0", p[0]);
assert!((p[1] - 1.0).abs() < 0.1, "y = {}, expected ~1.0", p[1]);
}
#[test]
fn bias_correction_works() {
let mut adam = AdamF64::builder()
.dimensions(1)
.learning_rate(0.1)
.build()
.unwrap();
adam.set_parameters(&[5.0]);
adam.step(&[10.0]).unwrap();
let p = adam.parameters()[0];
assert!(p < 5.0, "parameter should decrease with positive gradient");
assert!(
(5.0 - p - 0.1).abs() < 0.01,
"step should be ~learning_rate, got {}",
5.0 - p
);
}
#[test]
fn weight_decay_shrinks_params() {
let mut adam = AdamF64::builder()
.dimensions(1)
.learning_rate(0.001)
.weight_decay(0.1)
.build()
.unwrap();
adam.set_parameters(&[10.0]);
for _ in 0..100 {
adam.step(&[0.0]).unwrap();
}
assert!(
adam.parameters()[0].abs() < 10.0,
"weight decay should shrink parameters, got {}",
adam.parameters()[0]
);
}
#[test]
fn zero_weight_decay_is_standard_adam() {
let mut adam = AdamF64::builder()
.dimensions(1)
.learning_rate(0.1)
.weight_decay(0.0)
.build()
.unwrap();
assert!((adam.weight_decay() - 0.0).abs() < f64::EPSILON);
adam.set_parameters(&[5.0]);
adam.step(&[0.0]).unwrap();
assert!((adam.parameters()[0] - 5.0).abs() < f64::EPSILON);
}
#[test]
fn default_hyperparameters() {
let adam = AdamF64::builder()
.dimensions(1)
.learning_rate(0.001)
.build()
.unwrap();
assert!((adam.beta1() - 0.9).abs() < f64::EPSILON);
assert!((adam.beta2() - 0.999).abs() < f64::EPSILON);
assert!((adam.epsilon() - 1e-8).abs() < 1e-15);
assert!((adam.weight_decay() - 0.0).abs() < f64::EPSILON);
}
#[test]
fn rejects_nan_gradient() {
let mut adam = AdamF64::builder()
.dimensions(2)
.learning_rate(0.001)
.build()
.unwrap();
assert_eq!(
adam.step(&[1.0, f64::NAN]),
Err(crate::DataError::NotANumber)
);
assert_eq!(adam.count(), 0);
}
#[test]
fn rejects_inf_gradient() {
let mut adam = AdamF64::builder()
.dimensions(1)
.learning_rate(0.001)
.build()
.unwrap();
assert_eq!(
adam.step(&[f64::INFINITY]),
Err(crate::DataError::Infinite)
);
assert_eq!(adam.count(), 0);
}
#[test]
#[should_panic(expected = "gradient length")]
fn dimension_mismatch_panics() {
let mut adam = AdamF64::builder()
.dimensions(3)
.learning_rate(0.001)
.build()
.unwrap();
let _ = adam.step(&[1.0]);
}
#[test]
fn reset_clears_all() {
let mut adam = AdamF64::builder()
.dimensions(2)
.learning_rate(0.001)
.build()
.unwrap();
adam.set_parameters(&[1.0, 2.0]);
adam.step(&[1.0, 1.0]).unwrap();
adam.step(&[1.0, 1.0]).unwrap();
adam.reset();
assert_eq!(adam.count(), 0);
assert!(adam.parameters().iter().all(|&p| p == 0.0));
assert!(adam.first_moment().iter().all(|&m| m == 0.0));
assert!(adam.second_moment().iter().all(|&v| v == 0.0));
}
#[test]
fn set_parameters_preserves_moments() {
let mut adam = AdamF64::builder()
.dimensions(2)
.learning_rate(0.001)
.build()
.unwrap();
adam.step(&[1.0, 2.0]).unwrap();
let m_before: alloc::vec::Vec<f64> = adam.first_moment().to_vec();
let v_before: alloc::vec::Vec<f64> = adam.second_moment().to_vec();
adam.set_parameters(&[10.0, 20.0]);
assert_eq!(adam.first_moment(), m_before.as_slice());
assert_eq!(adam.second_moment(), v_before.as_slice());
}
#[test]
fn count_tracks_steps() {
let mut adam = AdamF64::builder()
.dimensions(1)
.learning_rate(0.001)
.build()
.unwrap();
assert_eq!(adam.count(), 0);
assert!(!adam.is_primed());
adam.step(&[1.0]).unwrap();
assert_eq!(adam.count(), 1);
assert!(adam.is_primed());
}
#[test]
fn builder_validates_beta1() {
let result = AdamF64::builder()
.dimensions(1)
.learning_rate(0.001)
.beta1(0.0)
.build();
assert!(result.is_err());
let result = AdamF64::builder()
.dimensions(1)
.learning_rate(0.001)
.beta1(1.0)
.build();
assert!(result.is_err());
}
#[test]
fn builder_validates_beta2() {
let result = AdamF64::builder()
.dimensions(1)
.learning_rate(0.001)
.beta2(0.0)
.build();
assert!(result.is_err());
}
#[test]
fn builder_validates_weight_decay() {
let result = AdamF64::builder()
.dimensions(1)
.learning_rate(0.001)
.weight_decay(-0.01)
.build();
assert!(result.is_err());
}
#[test]
fn builder_requires_dimensions() {
let result = AdamF64::builder().learning_rate(0.001).build();
assert!(matches!(
result,
Err(crate::ConfigError::Missing("dimensions"))
));
}
#[test]
fn builder_requires_learning_rate() {
let result = AdamF64::builder().dimensions(2).build();
assert!(matches!(
result,
Err(crate::ConfigError::Missing("learning_rate"))
));
}
}