use scirs2_core::ndarray::{Array, Ix1, ScalarOperand};
use scirs2_core::numeric::Float;
use std::collections::HashMap;
use std::fmt::Debug;
use crate::error::{OptimError, Result};
use crate::optimizers::Optimizer;
/// A gradient stored in coordinate form: only the listed positions carry a
/// value, every other position of the `dim`-long dense vector is implicitly zero.
///
/// `indices` and `values` are parallel vectors: `values[k]` is the gradient
/// component at dense position `indices[k]`.
///
/// The fields are public, so the invariants checked by [`SparseGradient::new`]
/// (equal lengths, every index `< dim`) are not enforced after construction.
#[derive(Debug, Clone, PartialEq)]
pub struct SparseGradient<A: Float + ScalarOperand + Debug> {
    /// Dense positions of the stored entries.
    pub indices: Vec<usize>,
    /// Gradient values, one per entry of `indices`.
    pub values: Vec<A>,
    /// Length of the dense vector this gradient represents.
    pub dim: usize,
}
impl<A: Float + ScalarOperand + Debug + Send + Sync> SparseGradient<A> {
    /// Builds a sparse gradient from parallel `indices` / `values` vectors
    /// describing a dense vector of length `dim`.
    ///
    /// # Panics
    ///
    /// Panics when `indices` and `values` differ in length, or when the largest
    /// index is not strictly below `dim`.
    pub fn new(indices: Vec<usize>, values: Vec<A>, dim: usize) -> Self {
        assert_eq!(
            indices.len(),
            values.len(),
            "Indices and values must have the same length"
        );
        // Checking the maximum is enough: if it fits, every index fits.
        let largest = indices.iter().copied().max();
        if let Some(max_idx) = largest {
            assert!(
                max_idx < dim,
                "Index {} is out of bounds for dimension {}",
                max_idx,
                dim
            );
        }
        Self {
            indices,
            values,
            dim,
        }
    }

    /// Extracts the non-zero entries of a dense 1-D array.
    ///
    /// Positions whose value reports `is_zero()` are dropped; all others are
    /// kept in ascending index order. `dim` is set to the array length.
    pub fn from_array(array: &Array<A, Ix1>) -> Self {
        let (indices, values): (Vec<usize>, Vec<A>) = array
            .iter()
            .enumerate()
            .filter(|(_, entry)| !entry.is_zero())
            .map(|(position, &entry)| (position, entry))
            .unzip();
        Self {
            indices,
            values,
            dim: array.len(),
        }
    }

    /// Expands the sparse representation into a dense array of length `dim`.
    ///
    /// Entries are written in storage order, so if an index occurs more than
    /// once the last stored value is the one that remains.
    pub fn to_array(&self) -> Array<A, Ix1> {
        let mut dense: Array<A, Ix1> = Array::zeros(self.dim);
        let entries = self.indices.iter().copied().zip(self.values.iter().copied());
        for (position, entry) in entries {
            dense[position] = entry;
        }
        dense
    }

    /// Returns `true` when no index is stored.
    pub fn is_empty(&self) -> bool {
        self.indices.is_empty()
    }
}
/// Adam optimizer that keeps moment estimates only for the parameter indices
/// that have actually received a gradient.
///
/// State is stored in hash maps keyed by parameter index, so an index that
/// never appears in a sparse gradient has no entry in `m` or `v`.
#[derive(Debug, Clone)]
pub struct SparseAdam<A: Float + ScalarOperand + Debug> {
/// Step size multiplied into every parameter update.
learning_rate: A,
/// Exponential decay rate of the first-moment (mean) estimate.
beta1: A,
/// Exponential decay rate of the second-moment (squared gradient) estimate.
beta2: A,
/// Constant added to the square root of the second moment in the update denominator.
epsilon: A,
/// Coefficient of the `param * weight_decay` term added to the gradient;
/// only applied when strictly positive.
weight_decay: A,
/// First-moment estimate per parameter index.
m: HashMap<usize, A>,
/// Second-moment estimate per parameter index.
v: HashMap<usize, A>,
/// Number of non-empty steps taken so far; drives the bias correction.
t: usize,
}
impl<A: Float + ScalarOperand + Debug + Send + Sync> SparseAdam<A> {
    /// Creates an optimizer with the given learning rate and the defaults
    /// `beta1 = 0.9`, `beta2 = 0.999`, `epsilon = 1e-8`, `weight_decay = 0`.
    ///
    /// # Panics
    ///
    /// Panics if one of the default constants cannot be converted to `A`.
    pub fn new(learning_rate: A) -> Self {
        Self::new_with_config(
            learning_rate,
            A::from(0.9).expect("unwrap failed"),
            A::from(0.999).expect("unwrap failed"),
            A::from(1e-8).expect("unwrap failed"),
            A::zero(),
        )
    }

    /// Creates an optimizer with every hyperparameter given explicitly.
    ///
    /// The moment maps start empty and the step counter starts at zero.
    /// No validation is performed on the supplied values.
    pub fn new_with_config(
        learning_rate: A,
        beta1: A,
        beta2: A,
        epsilon: A,
        weight_decay: A,
    ) -> Self {
        Self {
            learning_rate,
            beta1,
            beta2,
            epsilon,
            weight_decay,
            m: HashMap::new(),
            v: HashMap::new(),
            t: 0,
        }
    }

    /// Sets the first-moment decay rate in place; returns `self` for chaining.
    pub fn set_beta1(&mut self, beta1: A) -> &mut Self {
        self.beta1 = beta1;
        self
    }

    /// Builder-style variant of [`set_beta1`](Self::set_beta1).
    pub fn with_beta1(mut self, beta1: A) -> Self {
        self.beta1 = beta1;
        self
    }

    /// Returns the first-moment decay rate.
    pub fn get_beta1(&self) -> A {
        self.beta1
    }

    /// Sets the second-moment decay rate in place; returns `self` for chaining.
    pub fn set_beta2(&mut self, beta2: A) -> &mut Self {
        self.beta2 = beta2;
        self
    }

    /// Builder-style variant of [`set_beta2`](Self::set_beta2).
    pub fn with_beta2(mut self, beta2: A) -> Self {
        self.beta2 = beta2;
        self
    }

    /// Returns the second-moment decay rate.
    pub fn get_beta2(&self) -> A {
        self.beta2
    }

    /// Sets the denominator stabiliser in place; returns `self` for chaining.
    pub fn set_epsilon(&mut self, epsilon: A) -> &mut Self {
        self.epsilon = epsilon;
        self
    }

    /// Builder-style variant of [`set_epsilon`](Self::set_epsilon).
    pub fn with_epsilon(mut self, epsilon: A) -> Self {
        self.epsilon = epsilon;
        self
    }

    /// Returns the denominator stabiliser.
    pub fn get_epsilon(&self) -> A {
        self.epsilon
    }

    /// Sets the weight-decay coefficient in place; returns `self` for chaining.
    pub fn set_weight_decay(&mut self, weight_decay: A) -> &mut Self {
        self.weight_decay = weight_decay;
        self
    }

    /// Builder-style variant of [`set_weight_decay`](Self::set_weight_decay).
    pub fn with_weight_decay(mut self, weight_decay: A) -> Self {
        self.weight_decay = weight_decay;
        self
    }

    /// Returns the weight-decay coefficient.
    pub fn get_weight_decay(&self) -> A {
        self.weight_decay
    }

    /// Performs one Adam update using a sparse gradient and returns the new
    /// parameter vector.
    ///
    /// Only the positions listed in `gradient.indices` are changed; all other
    /// parameters are copied through untouched and their moment estimates are
    /// left as they were. An empty gradient returns a copy of `params` and
    /// does not advance the step counter.
    ///
    /// When `weight_decay` is strictly positive, `params[idx] * weight_decay`
    /// is added to the gradient before the moments are updated.
    ///
    /// Indices are expected to be unique. If an index is repeated, its moments
    /// are updated once per occurrence and the last occurrence determines the
    /// written parameter, because every step is computed from the *input*
    /// `params`.
    ///
    /// # Errors
    ///
    /// Returns [`OptimError::InvalidConfig`] when
    /// * `params.len()` differs from `gradient.dim`,
    /// * `gradient.indices` and `gradient.values` differ in length, or
    /// * any index is not strictly below `gradient.dim`.
    ///
    /// The optimizer state is not modified when an error is returned.
    pub fn step_sparse(
        &mut self,
        params: &Array<A, Ix1>,
        gradient: &SparseGradient<A>,
    ) -> Result<Array<A, Ix1>> {
        if params.len() != gradient.dim {
            return Err(OptimError::InvalidConfig(format!(
                "Parameter dimension ({}) doesn't match gradient dimension ({})",
                params.len(),
                gradient.dim
            )));
        }
        // `SparseGradient` exposes its fields publicly, so the invariants that
        // `SparseGradient::new` asserts may have been bypassed. Re-check them
        // here and report a recoverable error instead of panicking on the
        // indexing below or silently dropping unmatched entries in `zip`.
        if gradient.indices.len() != gradient.values.len() {
            return Err(OptimError::InvalidConfig(format!(
                "Sparse gradient has {} indices but {} values",
                gradient.indices.len(),
                gradient.values.len()
            )));
        }
        if let Some(&bad_idx) = gradient.indices.iter().find(|&&idx| idx >= gradient.dim) {
            return Err(OptimError::InvalidConfig(format!(
                "Sparse gradient index {} is out of bounds for dimension {}",
                bad_idx, gradient.dim
            )));
        }
        if gradient.is_empty() {
            return Ok(params.clone());
        }

        self.t += 1;
        // `powi` takes an `i32`; saturate rather than wrap so that a very large
        // step count cannot turn into a negative exponent.
        let exponent = self.t.min(i32::MAX as usize) as i32;

        let (beta1, beta2) = (self.beta1, self.beta2);
        let one = A::one();
        let bias_correction1 = one - beta1.powi(exponent);
        let bias_correction2 = one - beta2.powi(exponent);
        // Loop-invariant factors, hoisted out of the per-index update.
        let one_minus_beta1 = one - beta1;
        let one_minus_beta2 = one - beta2;
        let apply_decay = self.weight_decay > A::zero();

        let mut updated_params = params.clone();
        for (&idx, &grad_val) in gradient.indices.iter().zip(&gradient.values) {
            let param = params[idx];
            let adjusted_grad = if apply_decay {
                grad_val + param * self.weight_decay
            } else {
                grad_val
            };

            // A missing entry means "no history yet", i.e. a zero moment.
            let m_slot = self.m.entry(idx).or_insert_with(A::zero);
            *m_slot = beta1 * *m_slot + one_minus_beta1 * adjusted_grad;
            let m_t = *m_slot;

            let v_slot = self.v.entry(idx).or_insert_with(A::zero);
            *v_slot = beta2 * *v_slot + one_minus_beta2 * adjusted_grad * adjusted_grad;
            let v_t = *v_slot;

            let m_hat = m_t / bias_correction1;
            let v_hat = v_t / bias_correction2;
            let step = self.learning_rate * m_hat / (v_hat.sqrt() + self.epsilon);
            updated_params[idx] = param - step;
        }
        Ok(updated_params)
    }

    /// Discards all moment estimates and resets the step counter to zero.
    ///
    /// Hyperparameters (learning rate, betas, epsilon, weight decay) are kept.
    pub fn reset(&mut self) {
        self.m.clear();
        self.v.clear();
        self.t = 0;
    }
}
impl<A> Optimizer<A, Ix1> for SparseAdam<A>
where
    A: Float + ScalarOperand + Debug + Send + Sync,
{
    /// Dense entry point: the non-zero components of `gradients` are extracted
    /// into a [`SparseGradient`] and handed to [`SparseAdam::step_sparse`].
    ///
    /// Components that are exactly zero are therefore treated as "no gradient"
    /// and leave the corresponding moment estimates untouched.
    fn step(&mut self, params: &Array<A, Ix1>, gradients: &Array<A, Ix1>) -> Result<Array<A, Ix1>> {
        self.step_sparse(params, &SparseGradient::from_array(gradients))
    }

    /// Returns the current learning rate.
    fn get_learning_rate(&self) -> A {
        self.learning_rate
    }

    /// Replaces the learning rate used by subsequent steps.
    fn set_learning_rate(&mut self, learning_rate: A) {
        self.learning_rate = learning_rate;
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use approx::assert_abs_diff_eq;
    use scirs2_core::ndarray::Array1;

    /// `new` stores the supplied indices, values and dimension unchanged.
    #[test]
    fn test_sparse_gradient_creation() {
        let indices = vec![0, 2, 4];
        let values = vec![1.0, 2.0, 3.0];
        let dim = 5;
        let sparse_grad = SparseGradient::new(indices, values, dim);
        assert_eq!(sparse_grad.indices, vec![0, 2, 4]);
        assert_eq!(sparse_grad.values, vec![1.0, 2.0, 3.0]);
        assert_eq!(sparse_grad.dim, 5);
    }

    /// `from_array` keeps only the non-zero entries of a dense array.
    #[test]
    fn test_sparse_gradient_from_array() {
        let dense = Array1::from_vec(vec![1.0, 0.0, 2.0, 0.0, 3.0]);
        let sparse_grad = SparseGradient::from_array(&dense);
        assert_eq!(sparse_grad.indices, vec![0, 2, 4]);
        assert_eq!(sparse_grad.values, vec![1.0, 2.0, 3.0]);
        assert_eq!(sparse_grad.dim, 5);
    }

    /// `to_array` fills unlisted positions with zero.
    #[test]
    fn test_sparse_gradient_to_array() {
        let indices = vec![0, 2, 4];
        let values = vec![1.0, 2.0, 3.0];
        let dim = 5;
        let sparse_grad = SparseGradient::new(indices, values, dim);
        let dense = sparse_grad.to_array();
        let expected = Array1::from_vec(vec![1.0, 0.0, 2.0, 0.0, 3.0]);
        assert_eq!(dense, expected);
    }

    /// `SparseAdam::new` applies the documented default hyperparameters.
    #[test]
    fn test_sparse_adam_creation() {
        let optimizer = SparseAdam::<f64>::new(0.001);
        assert_eq!(optimizer.get_learning_rate(), 0.001);
        assert_eq!(optimizer.get_beta1(), 0.9);
        assert_eq!(optimizer.get_beta2(), 0.999);
        assert_eq!(optimizer.get_epsilon(), 1e-8);
        assert_eq!(optimizer.get_weight_decay(), 0.0);
    }

    /// A single step moves only the indexed parameters, against the gradient.
    #[test]
    fn test_sparse_adam_step() {
        let mut optimizer = SparseAdam::<f64>::new(0.1);
        let params = Array1::from_vec(vec![0.0, 0.0, 0.0, 0.0, 0.0]);
        let sparse_grad = SparseGradient::new(vec![1, 3], vec![0.2, 0.5], 5);
        let updated_params = optimizer
            .step_sparse(&params, &sparse_grad)
            .expect("unwrap failed");
        assert_abs_diff_eq!(updated_params[0], 0.0);
        assert!(updated_params[1] < 0.0);
        assert_abs_diff_eq!(updated_params[2], 0.0);
        assert!(updated_params[3] < 0.0);
        assert_abs_diff_eq!(updated_params[4], 0.0);
        assert!(updated_params[3].abs() > updated_params[1].abs());
    }

    /// The first sparse step matches the dense Adam implementation.
    #[test]
    fn test_sparse_adam_vs_dense_adam() {
        let mut sparse_optimizer = SparseAdam::<f64>::new(0.1);
        let mut dense_optimizer = crate::optimizers::adam::Adam::<f64>::new(0.1);
        let params = Array1::from_vec(vec![0.0, 0.0, 0.0, 0.0, 0.0]);
        let dense_grad = Array1::from_vec(vec![0.0, 0.2, 0.0, 0.5, 0.0]);
        let sparse_grad = SparseGradient::from_array(&dense_grad);
        let sparse_result = sparse_optimizer
            .step_sparse(&params, &sparse_grad)
            .expect("unwrap failed");
        let dense_result = dense_optimizer
            .step(&params, &dense_grad)
            .expect("unwrap failed");
        assert_abs_diff_eq!(sparse_result[0], dense_result[0]);
        assert_abs_diff_eq!(sparse_result[1], dense_result[1], epsilon = 1e-10);
        assert_abs_diff_eq!(sparse_result[2], dense_result[2]);
        assert_abs_diff_eq!(sparse_result[3], dense_result[3], epsilon = 1e-10);
        assert_abs_diff_eq!(sparse_result[4], dense_result[4]);
    }

    /// Successive steps over different index sets accumulate as expected.
    #[test]
    fn test_sparse_adam_multiple_steps() {
        let mut optimizer = SparseAdam::<f64>::new(0.1);
        let mut params = Array1::from_vec(vec![0.0, 0.0, 0.0, 0.0, 0.0]);

        // Step 1 touches indices 1 and 3.
        let sparse_grad1 = SparseGradient::new(vec![1, 3], vec![0.2, 0.5], 5);
        params = optimizer
            .step_sparse(&params, &sparse_grad1)
            .expect("unwrap failed");

        // Step 2 touches indices 0 and 2; index 4 is never touched.
        let sparse_grad2 = SparseGradient::new(vec![0, 2], vec![0.3, 0.4], 5);
        params = optimizer
            .step_sparse(&params, &sparse_grad2)
            .expect("unwrap failed");
        assert!(params[0] < 0.0);
        assert!(params[1] < 0.0);
        assert!(params[2] < 0.0);
        assert!(params[3] < 0.0);
        assert_abs_diff_eq!(params[4], 0.0);

        // Repeating the same gradient keeps pushing the parameters further.
        params = optimizer
            .step_sparse(&params, &sparse_grad2)
            .expect("unwrap failed");
        let prev_param0 = params[0];
        let prev_param2 = params[2];
        params = optimizer
            .step_sparse(&params, &sparse_grad2)
            .expect("unwrap failed");
        assert!(params[0].abs() > prev_param0.abs());
        assert!(params[2].abs() > prev_param2.abs());
    }

    /// Weight decay changes the update of touched indices only.
    #[test]
    fn test_sparse_adam_with_weight_decay() {
        let mut optimizer = SparseAdam::<f64>::new(0.1).with_weight_decay(0.01);
        let params = Array1::from_vec(vec![0.1, 0.2, 0.3, 0.4, 0.5]);
        let sparse_grad = SparseGradient::new(vec![1, 3], vec![0.2, 0.5], 5);
        let mut optimizer_no_decay = SparseAdam::<f64>::new(0.1);
        let with_decay = optimizer
            .step_sparse(&params, &sparse_grad)
            .expect("unwrap failed");
        let without_decay = optimizer_no_decay
            .step_sparse(&params, &sparse_grad)
            .expect("unwrap failed");
        assert!(with_decay[1] != without_decay[1]);
        assert!(with_decay[3] != without_decay[3]);
        assert_abs_diff_eq!(with_decay[0], params[0]);
        assert_abs_diff_eq!(with_decay[2], params[2]);
        assert_abs_diff_eq!(with_decay[4], params[4]);
    }

    /// An empty gradient returns the parameters unchanged.
    #[test]
    fn test_sparse_adam_empty_gradient() {
        let mut optimizer = SparseAdam::<f64>::new(0.1);
        let params = Array1::from_vec(vec![0.1, 0.2, 0.3, 0.4, 0.5]);
        let sparse_grad = SparseGradient::new(vec![], vec![], 5);
        let result = optimizer
            .step_sparse(&params, &sparse_grad)
            .expect("unwrap failed");
        assert_eq!(result, params);
    }

    /// After `reset`, the optimizer behaves like a freshly constructed one.
    #[test]
    fn test_sparse_adam_reset() {
        let mut optimizer = SparseAdam::<f64>::new(0.1);
        let params = Array1::from_vec(vec![0.0; 5]);
        let sparse_grad = SparseGradient::new(vec![1, 3], vec![0.2, 0.5], 5);
        for _ in 0..10 {
            optimizer
                .step_sparse(&params, &sparse_grad)
                .expect("unwrap failed");
        }
        optimizer.reset();
        let mut new_optimizer = SparseAdam::<f64>::new(0.1);
        let reset_result = optimizer
            .step_sparse(&params, &sparse_grad)
            .expect("unwrap failed");
        let new_result = new_optimizer
            .step_sparse(&params, &sparse_grad)
            .expect("unwrap failed");
        assert_abs_diff_eq!(reset_result[1], new_result[1], epsilon = 1e-10);
        assert_abs_diff_eq!(reset_result[3], new_result[3], epsilon = 1e-10);
    }
}