#![allow(clippy::suboptimal_flops, clippy::neg_cmp_op_on_partial_ord)]
extern crate alloc;
use alloc::boxed::Box;
use alloc::vec;
#[derive(Debug, Clone)]
pub struct OnlineGdF64 {
params: Box<[f64]>,
learning_rate: f64,
dims: usize,
count: u64,
}
#[derive(Debug, Clone)]
pub struct OnlineGdF64Builder {
dimensions: Option<usize>,
learning_rate: Option<f64>,
}
impl OnlineGdF64 {
#[inline]
#[must_use]
pub fn builder() -> OnlineGdF64Builder {
OnlineGdF64Builder {
dimensions: Option::None,
learning_rate: Option::None,
}
}
#[inline]
pub fn step(&mut self, gradient: &[f64]) -> Result<(), crate::DataError> {
assert_eq!(
gradient.len(),
self.dims,
"gradient length {} != dimensions {}",
gradient.len(),
self.dims,
);
for &g in gradient {
check_finite!(g);
}
for i in 0..self.dims {
self.params[i] -= self.learning_rate * gradient[i];
}
self.count += 1;
Ok(())
}
#[inline]
#[must_use]
pub fn parameters(&self) -> &[f64] {
&self.params
}
#[inline]
#[must_use]
pub fn parameter(&self, index: usize) -> f64 {
self.params[index]
}
#[inline]
pub fn set_parameters(&mut self, params: &[f64]) {
assert_eq!(
params.len(),
self.dims,
"params length {} != dimensions {}",
params.len(),
self.dims,
);
self.params.copy_from_slice(params);
}
#[inline]
#[must_use]
pub fn dimensions(&self) -> usize {
self.dims
}
#[inline]
#[must_use]
pub fn learning_rate(&self) -> f64 {
self.learning_rate
}
#[inline]
#[must_use]
pub fn count(&self) -> u64 {
self.count
}
#[inline]
#[must_use]
pub fn is_primed(&self) -> bool {
self.count > 0
}
#[inline]
pub fn reset(&mut self) {
self.params.fill(0.0);
self.count = 0;
}
}
impl OnlineGdF64Builder {
#[inline]
#[must_use]
pub fn dimensions(mut self, dims: usize) -> Self {
self.dimensions = Option::Some(dims);
self
}
#[inline]
#[must_use]
pub fn learning_rate(mut self, lr: f64) -> Self {
self.learning_rate = Option::Some(lr);
self
}
#[inline]
pub fn build(self) -> Result<OnlineGdF64, crate::ConfigError> {
let dims = self
.dimensions
.ok_or(crate::ConfigError::Missing("dimensions"))?;
let lr = self
.learning_rate
.ok_or(crate::ConfigError::Missing("learning_rate"))?;
if dims < 1 {
return Err(crate::ConfigError::Invalid("dimensions must be >= 1"));
}
if !(lr > 0.0) {
return Err(crate::ConfigError::Invalid(
"learning_rate must be positive",
));
}
Ok(OnlineGdF64 {
params: vec![0.0; dims].into_boxed_slice(),
learning_rate: lr,
dims,
count: 0,
})
}
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn minimizes_quadratic() {
let mut gd = OnlineGdF64::builder()
.dimensions(1)
.learning_rate(0.1)
.build()
.unwrap();
gd.set_parameters(&[5.0]);
for _ in 0..100 {
let grad = [2.0 * gd.parameters()[0]];
gd.step(&grad).unwrap();
}
assert!(
gd.parameters()[0].abs() < 1e-6,
"x = {}",
gd.parameters()[0]
);
}
#[test]
fn minimizes_2d_quadratic() {
let mut gd = OnlineGdF64::builder()
.dimensions(2)
.learning_rate(0.1)
.build()
.unwrap();
gd.set_parameters(&[5.0, -3.0]);
for _ in 0..100 {
let p = gd.parameters();
let grad = [2.0 * p[0], 2.0 * p[1]];
gd.step(&grad).unwrap();
}
assert!(gd.parameters()[0].abs() < 1e-6);
assert!(gd.parameters()[1].abs() < 1e-6);
}
#[test]
fn high_lr_diverges() {
let mut gd = OnlineGdF64::builder()
.dimensions(1)
.learning_rate(1.5) .build()
.unwrap();
gd.set_parameters(&[1.0]);
for _ in 0..20 {
let grad = [2.0 * gd.parameters()[0]];
gd.step(&grad).unwrap();
}
assert!(gd.parameters()[0].abs() > 1.0);
}
#[test]
fn rejects_nan_gradient() {
let mut gd = OnlineGdF64::builder()
.dimensions(2)
.learning_rate(0.01)
.build()
.unwrap();
assert_eq!(
gd.step(&[1.0, f64::NAN]),
Err(crate::DataError::NotANumber)
);
assert_eq!(gd.count(), 0);
}
#[test]
fn rejects_inf_gradient() {
let mut gd = OnlineGdF64::builder()
.dimensions(2)
.learning_rate(0.01)
.build()
.unwrap();
assert_eq!(
gd.step(&[f64::INFINITY, 0.0]),
Err(crate::DataError::Infinite)
);
assert_eq!(
gd.step(&[0.0, f64::NEG_INFINITY]),
Err(crate::DataError::Infinite)
);
assert_eq!(gd.count(), 0);
}
#[test]
#[should_panic(expected = "gradient length")]
fn dimension_mismatch_panics() {
let mut gd = OnlineGdF64::builder()
.dimensions(3)
.learning_rate(0.01)
.build()
.unwrap();
let _ = gd.step(&[1.0, 2.0]);
}
#[test]
fn reset_zeros_params() {
let mut gd = OnlineGdF64::builder()
.dimensions(2)
.learning_rate(0.1)
.build()
.unwrap();
gd.set_parameters(&[5.0, -3.0]);
gd.step(&[1.0, 1.0]).unwrap();
assert!(gd.count() > 0);
gd.reset();
assert_eq!(gd.count(), 0);
assert!(gd.parameters().iter().all(|&p| p == 0.0));
}
#[test]
fn set_parameters_works() {
let mut gd = OnlineGdF64::builder()
.dimensions(3)
.learning_rate(0.01)
.build()
.unwrap();
gd.set_parameters(&[1.0, 2.0, 3.0]);
assert!((gd.parameter(0) - 1.0).abs() < f64::EPSILON);
assert!((gd.parameter(1) - 2.0).abs() < f64::EPSILON);
assert!((gd.parameter(2) - 3.0).abs() < f64::EPSILON);
}
#[test]
fn count_tracks_steps() {
let mut gd = OnlineGdF64::builder()
.dimensions(1)
.learning_rate(0.01)
.build()
.unwrap();
assert_eq!(gd.count(), 0);
assert!(!gd.is_primed());
gd.step(&[1.0]).unwrap();
assert_eq!(gd.count(), 1);
assert!(gd.is_primed());
}
#[test]
fn builder_rejects_zero_dimensions() {
let result = OnlineGdF64::builder()
.dimensions(0)
.learning_rate(0.01)
.build();
assert!(result.is_err());
}
#[test]
fn builder_rejects_negative_lr() {
let result = OnlineGdF64::builder()
.dimensions(2)
.learning_rate(-0.01)
.build();
assert!(result.is_err());
}
#[test]
fn builder_requires_dimensions() {
let result = OnlineGdF64::builder().learning_rate(0.01).build();
assert!(matches!(
result,
Err(crate::ConfigError::Missing("dimensions"))
));
}
#[test]
fn builder_requires_learning_rate() {
let result = OnlineGdF64::builder().dimensions(2).build();
assert!(matches!(
result,
Err(crate::ConfigError::Missing("learning_rate"))
));
}
}