use crate::error::{OptimError, Result};
use scirs2_core::ndarray_ext::{Array1, ArrayView1};
use scirs2_core::numeric::{Float, Zero};
use serde::{Deserialize, Serialize};
/// State and hyper-parameters of an AdaDelta optimizer over 1-D parameter arrays.
///
/// The two accumulators are allocated by the first call to `step` and cleared
/// again by `reset`. The whole struct derives `Serialize`/`Deserialize`, so
/// the optimizer state can be serialized together with its hyper-parameters.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct AdaDelta<T: Float> {
/// Decay factor of both running averages; `new` only accepts values in `(0, 1)`.
rho: T,
/// Positive constant added under every square root taken by the optimizer.
epsilon: T,
/// Running average of squared gradients; `None` until the first `step` and after `reset`.
accumulated_gradients: Option<Array1<T>>,
/// Running average of squared parameter updates; `None` until the first `step` and after `reset`.
accumulated_updates: Option<Array1<T>>,
/// Number of `step` calls that completed successfully since construction or the last `reset`.
step_count: usize,
}
impl<T: Float> Default for AdaDelta<T> {
    /// Builds an optimizer with `rho = 0.95` and `epsilon = 1e-6`.
    ///
    /// # Panics
    /// Panics if either constant cannot be converted into `T`, or if
    /// `AdaDelta::new` rejects the converted values.
    fn default() -> Self {
        let default_rho = T::from(0.95).expect("unwrap failed");
        let default_epsilon = T::from(1e-6).expect("unwrap failed");
        let built = Self::new(default_rho, default_epsilon);
        built.expect("unwrap failed")
    }
}
impl<T: Float> AdaDelta<T> {
    /// Number of initial steps whose update is multiplied by `WARMUP_FACTOR`.
    const WARMUP_STEPS: usize = 10;
    /// Multiplier applied to the update during the first `WARMUP_STEPS` steps.
    const WARMUP_FACTOR: f64 = 10.0;

    /// Creates an AdaDelta optimizer with decay rate `rho` and stabiliser `epsilon`.
    ///
    /// No accumulator state is allocated here; it is sized by the first call
    /// to [`step`](Self::step).
    ///
    /// # Errors
    /// Returns [`OptimError::InvalidParameter`] when
    /// * `rho` is NaN or lies outside the open interval `(0, 1)`,
    /// * `epsilon` is NaN, zero, negative or infinite,
    /// * either value cannot be converted to `f64` for validation.
    pub fn new(rho: T, epsilon: T) -> Result<Self> {
        let rho_f64 = rho.to_f64().ok_or_else(|| {
            OptimError::InvalidParameter("rho cannot be represented as f64".to_string())
        })?;
        let epsilon_f64 = epsilon.to_f64().ok_or_else(|| {
            OptimError::InvalidParameter("epsilon cannot be represented as f64".to_string())
        })?;

        // NaN compares false against every bound, so it has to be rejected
        // explicitly; otherwise it would slip through the range checks.
        if rho_f64.is_nan() || rho_f64 <= 0.0 || rho_f64 >= 1.0 {
            return Err(OptimError::InvalidParameter(format!(
                "rho must be in (0, 1), got {}",
                rho_f64
            )));
        }
        if epsilon_f64.is_nan() || epsilon_f64 <= 0.0 {
            return Err(OptimError::InvalidParameter(format!(
                "epsilon must be positive, got {}",
                epsilon_f64
            )));
        }
        // An infinite epsilon turns every update into inf / inf = NaN.
        if epsilon_f64.is_infinite() {
            return Err(OptimError::InvalidParameter(format!(
                "epsilon must be finite, got {}",
                epsilon_f64
            )));
        }

        Ok(Self {
            rho,
            epsilon,
            accumulated_gradients: None,
            accumulated_updates: None,
            step_count: 0,
        })
    }

    /// Performs one optimization step and returns the updated parameters.
    ///
    /// For every element the method
    /// 1. folds the squared gradient into the gradient accumulator,
    /// 2. computes `delta = -(sqrt(acc_update + eps) / sqrt(acc_grad + eps)) * grad`,
    /// 3. folds the squared `delta` into the update accumulator,
    /// 4. adds `delta` to the parameter.
    ///
    /// During the first `WARMUP_STEPS` steps `delta` is additionally multiplied
    /// by `WARMUP_FACTOR`.
    /// NOTE(review): this boost is not part of the textbook AdaDelta rule; it is
    /// presumably there to speed up early progress - confirm before changing it.
    ///
    /// `params` is not modified; the accumulators and the step counter are
    /// only touched when the step succeeds.
    ///
    /// # Errors
    /// * [`OptimError::DimensionMismatch`] if `grads` and `params` differ in
    ///   length, or if the accumulators kept from an earlier step have a
    ///   different length than `params` (call [`reset`](Self::reset) first).
    /// * [`OptimError::InvalidParameter`] if the stored state is incomplete
    ///   (for example after deserializing inconsistent data) or the warm-up
    ///   factor cannot be converted into `T`.
    pub fn step(&mut self, params: ArrayView1<T>, grads: ArrayView1<T>) -> Result<Array1<T>> {
        let n = params.len();
        if grads.len() != n {
            return Err(OptimError::DimensionMismatch(format!(
                "Expected gradient size {}, got {}",
                n,
                grads.len()
            )));
        }

        let warmup_boost = if self.step_count < Self::WARMUP_STEPS {
            T::from(Self::WARMUP_FACTOR).ok_or_else(|| {
                OptimError::InvalidParameter(
                    "warm-up factor cannot be represented in the parameter type".to_string(),
                )
            })?
        } else {
            T::one()
        };

        // First step, or first step after `reset`: start both averages at zero.
        if self.accumulated_gradients.is_none() {
            self.accumulated_gradients = Some(Array1::zeros(n));
            self.accumulated_updates = Some(Array1::zeros(n));
        }

        let (acc_grad, acc_update) = match (
            self.accumulated_gradients.as_mut(),
            self.accumulated_updates.as_mut(),
        ) {
            (Some(acc_grad), Some(acc_update)) => (acc_grad, acc_update),
            _ => {
                return Err(OptimError::InvalidParameter(
                    "optimizer state is incomplete; call reset() before stepping".to_string(),
                ))
            }
        };

        // State carried over from an earlier call must match the current
        // problem size; otherwise the element-wise pairing below would be wrong.
        if acc_grad.len() != n || acc_update.len() != n {
            return Err(OptimError::DimensionMismatch(format!(
                "Optimizer state has sizes {} and {}, but parameters have size {}; call reset() first",
                acc_grad.len(),
                acc_update.len(),
                n
            )));
        }

        let rho = self.rho;
        let epsilon = self.epsilon;
        let one_minus_rho = T::one() - rho;

        let mut updated_params = params.to_owned();
        let inputs = updated_params.iter_mut().zip(grads.iter());
        let state = acc_grad.iter_mut().zip(acc_update.iter_mut());

        // Every element is independent, so all four stages run in one pass.
        for ((param, &grad), (sq_grad, sq_update)) in inputs.zip(state) {
            *sq_grad = rho * *sq_grad + one_minus_rho * grad * grad;

            let rms_grad = (*sq_grad + epsilon).sqrt();
            // Uses the update average from *before* this step's delta is folded in.
            let rms_update = (*sq_update + epsilon).sqrt();
            let delta = -(rms_update / rms_grad) * grad * warmup_boost;

            *sq_update = rho * *sq_update + one_minus_rho * delta * delta;
            *param = *param + delta;
        }

        self.step_count += 1;
        Ok(updated_params)
    }

    /// Returns how many steps have completed since construction or the last reset.
    pub fn step_count(&self) -> usize {
        self.step_count
    }

    /// Drops both accumulators and sets the step counter back to zero.
    ///
    /// The hyper-parameters `rho` and `epsilon` are kept. The next `step`
    /// re-allocates the accumulators for whatever size it is given.
    pub fn reset(&mut self) {
        self.accumulated_gradients = None;
        self.accumulated_updates = None;
        self.step_count = 0;
    }

    /// Returns element-wise `sqrt(accumulated squared gradient + epsilon)`,
    /// or `None` if no step has been taken since construction or reset.
    pub fn rms_gradients(&self) -> Option<Array1<T>> {
        self.rms_of(self.accumulated_gradients.as_ref())
    }

    /// Returns element-wise `sqrt(accumulated squared update + epsilon)`,
    /// or `None` if no step has been taken since construction or reset.
    pub fn rms_updates(&self) -> Option<Array1<T>> {
        self.rms_of(self.accumulated_updates.as_ref())
    }

    /// Maps an accumulator of running mean squares to `sqrt(x + epsilon)` per element.
    fn rms_of(&self, accumulator: Option<&Array1<T>>) -> Option<Array1<T>> {
        let epsilon = self.epsilon;
        accumulator.map(|acc| acc.mapv(|x| (x + epsilon).sqrt()))
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use approx::assert_relative_eq;
    use scirs2_core::ndarray_ext::array;

    /// A freshly built optimizer has taken no steps and holds no state.
    #[test]
    fn test_adadelta_creation() {
        let optimizer = AdaDelta::<f32>::new(0.95, 1e-6).expect("unwrap failed");
        assert_eq!(optimizer.step_count(), 0);
        assert!(optimizer.rms_gradients().is_none());
        assert!(optimizer.rms_updates().is_none());
    }

    /// `rho` must lie strictly inside `(0, 1)`; both end points are rejected.
    #[test]
    fn test_adadelta_invalid_rho() {
        assert!(AdaDelta::<f32>::new(1.5, 1e-6).is_err());
        assert!(AdaDelta::<f32>::new(-0.1, 1e-6).is_err());
        assert!(AdaDelta::<f32>::new(0.0, 1e-6).is_err());
        assert!(AdaDelta::<f32>::new(1.0, 1e-6).is_err());
    }

    /// `epsilon` must be strictly positive.
    #[test]
    fn test_adadelta_invalid_epsilon() {
        assert!(AdaDelta::<f32>::new(0.95, -1e-6).is_err());
        assert!(AdaDelta::<f32>::new(0.95, 0.0).is_err());
    }

    /// One step changes every parameter and bumps the step counter.
    #[test]
    fn test_adadelta_single_step() {
        let mut optimizer = AdaDelta::<f32>::new(0.9, 1e-6).expect("unwrap failed");
        let params = array![1.0, 2.0, 3.0];
        let grads = array![0.1, 0.2, 0.3];
        let updated_params = optimizer
            .step(params.view(), grads.view())
            .expect("unwrap failed");
        assert_eq!(updated_params.len(), 3);
        assert_eq!(optimizer.step_count(), 1);
        for i in 0..3 {
            assert_ne!(updated_params[i], params[i]);
        }
    }

    /// Pins the numeric result of the very first step, which starts from
    /// zeroed accumulators and includes the tenfold warm-up factor.
    #[test]
    fn test_adadelta_first_step_value() {
        let rho = 0.9_f64;
        let eps = 1e-6_f64;
        let mut optimizer = AdaDelta::<f64>::new(rho, eps).expect("unwrap failed");
        let params = array![1.0, 2.0, 3.0];
        let grads = array![0.1, 0.2, 0.3];
        let updated_params = optimizer
            .step(params.view(), grads.view())
            .expect("unwrap failed");
        for i in 0..3 {
            let g: f64 = grads[i];
            let rms_grad = ((1.0 - rho) * g * g + eps).sqrt();
            let rms_update = eps.sqrt();
            let delta = -(rms_update / rms_grad) * g * 10.0;
            assert_relative_eq!(updated_params[i], params[i] + delta, max_relative = 1e-12);
        }
    }

    /// Positive gradients keep pushing every parameter downwards.
    #[test]
    fn test_adadelta_multiple_steps() {
        let mut optimizer = AdaDelta::<f32>::new(0.95, 1e-6).expect("unwrap failed");
        let mut params = array![1.0, 2.0, 3.0];
        for _ in 0..10 {
            let grads = array![0.1, 0.2, 0.3];
            params = optimizer
                .step(params.view(), grads.view())
                .expect("unwrap failed");
        }
        assert_eq!(optimizer.step_count(), 10);
        assert!(params[0] < 1.0);
        assert!(params[1] < 2.0);
        assert!(params[2] < 3.0);
    }

    /// A gradient of the wrong length is rejected without touching any state.
    #[test]
    fn test_adadelta_shape_mismatch() {
        let mut optimizer = AdaDelta::<f32>::new(0.95, 1e-6).expect("unwrap failed");
        let params = array![1.0, 2.0, 3.0];
        let grads = array![0.1, 0.2];
        assert!(optimizer.step(params.view(), grads.view()).is_err());
        assert_eq!(optimizer.step_count(), 0);
        assert!(optimizer.accumulated_gradients.is_none());
        assert!(optimizer.accumulated_updates.is_none());
    }

    /// `reset` clears counter and accumulators, after which a problem of a
    /// different size can be optimized with the same instance.
    #[test]
    fn test_adadelta_reset() {
        let mut optimizer = AdaDelta::<f32>::new(0.95, 1e-6).expect("unwrap failed");
        let params = array![1.0, 2.0, 3.0];
        let grads = array![0.1, 0.2, 0.3];
        optimizer
            .step(params.view(), grads.view())
            .expect("unwrap failed");
        assert_eq!(optimizer.step_count(), 1);
        assert!(optimizer.accumulated_gradients.is_some());

        optimizer.reset();
        assert_eq!(optimizer.step_count(), 0);
        assert!(optimizer.accumulated_gradients.is_none());
        assert!(optimizer.accumulated_updates.is_none());
        assert!(optimizer.rms_gradients().is_none());
        assert!(optimizer.rms_updates().is_none());

        let smaller_params = array![1.0, 2.0];
        let smaller_grads = array![0.1, 0.2];
        let updated = optimizer
            .step(smaller_params.view(), smaller_grads.view())
            .expect("unwrap failed");
        assert_eq!(updated.len(), 2);
        assert_eq!(optimizer.step_count(), 1);
    }

    /// Minimizes `f(x) = x^2` (gradient `2x`) starting from `x = 10`.
    #[test]
    fn test_adadelta_convergence() {
        let mut optimizer = AdaDelta::<f64>::new(0.99, 1e-6).expect("unwrap failed");
        let mut params = array![10.0];
        for _ in 0..500 {
            let grads = params.mapv(|x| 2.0 * x);
            params = optimizer
                .step(params.view(), grads.view())
                .expect("unwrap failed");
        }
        assert!(
            params[0].abs() < 0.5,
            "Failed to converge, got {}",
            params[0]
        );
    }

    /// The RMS accessors are empty before the first step and afterwards
    /// report `sqrt(running mean square + epsilon)` for each element.
    #[test]
    fn test_adadelta_rms_values() {
        let rho = 0.9_f32;
        let eps = 1e-6_f32;
        let mut optimizer = AdaDelta::<f32>::new(rho, eps).expect("unwrap failed");
        assert!(optimizer.rms_gradients().is_none());
        assert!(optimizer.rms_updates().is_none());

        let params = array![1.0, 2.0, 3.0];
        let grads = array![0.1, 0.2, 0.3];
        optimizer
            .step(params.view(), grads.view())
            .expect("unwrap failed");

        let rms_grads = optimizer.rms_gradients().expect("unwrap failed");
        let rms_updates = optimizer.rms_updates().expect("unwrap failed");
        assert_eq!(rms_grads.len(), 3);
        assert_eq!(rms_updates.len(), 3);

        for i in 0..3 {
            let g: f32 = grads[i];
            let expected_rms_grad = ((1.0 - rho) * g * g + eps).sqrt();
            assert_relative_eq!(rms_grads[i], expected_rms_grad, max_relative = 1e-5);

            // First step: the update accumulator was zero when delta was computed.
            let delta = -(eps.sqrt() / expected_rms_grad) * g * 10.0;
            let expected_rms_update = ((1.0 - rho) * delta * delta + eps).sqrt();
            assert_relative_eq!(rms_updates[i], expected_rms_update, max_relative = 1e-5);
        }
    }

    /// The optimizer also works with `f64` and a smaller epsilon.
    #[test]
    fn test_adadelta_f64() {
        let mut optimizer = AdaDelta::<f64>::new(0.95, 1e-8).expect("unwrap failed");
        let params = array![1.0, 2.0, 3.0];
        let grads = array![0.1, 0.2, 0.3];
        let updated_params = optimizer
            .step(params.view(), grads.view())
            .expect("unwrap failed");
        assert_eq!(updated_params.len(), 3);
        assert_eq!(optimizer.step_count(), 1);
    }
}