use crate::error::{NeuralError, Result};
use crate::optimizers::Optimizer;
use scirs2_core::ndarray::{Array, ScalarOperand};
use scirs2_core::numeric::{Float, NumAssign};
use std::fmt::Debug;
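
/// AdamW optimizer: Adam with decoupled weight decay
/// (Loshchilov & Hutter, "Decoupled Weight Decay Regularization").
/// Unlike L2 regularization folded into the gradients, the weight-decay term
/// is applied directly to the parameters at each step.
///
/// # Examples
///
/// A minimal usage sketch (assuming `scirs2_core::ndarray` re-exports the
/// `ndarray` API, as this module does, and with the `Optimizer` trait in scope):
///
/// ```ignore
/// use scirs2_core::ndarray::Array;
///
/// // Learning rate 1e-3 with the standard defaults (0.9, 0.999, 1e-8, 0.01).
/// let mut opt = AdamW::<f64>::default_with_lr(1e-3).unwrap();
///
/// // One parameter tensor and its gradient, both dynamic-dimensional.
/// let mut params = vec![Array::zeros((2, 2)).into_dyn()];
/// let grads = vec![Array::ones((2, 2)).into_dyn()];
///
/// // Apply a single AdamW step in place.
/// opt.update(&mut params, &grads).unwrap();
/// ```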
#[derive(Debug, Clone)]
pub struct AdamW<F: Float + NumAssign + ScalarOperand + Debug> {
learning_rate: F,
beta1: F,
beta2: F,
epsilon: F,
weight_decay: F,
m: Vec<Array<F, scirs2_core::ndarray::IxDyn>>,
v: Vec<Array<F, scirs2_core::ndarray::IxDyn>>,
t: usize,
}
impl<F: Float + NumAssign + ScalarOperand + Debug> AdamW<F> {
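/// Creates a new AdamW optimizer with explicit hyperparameters:
/// `learning_rate`, the exponential decay rates `beta1`/`beta2` for the first
/// and second moment estimates, the numerical-stability term `epsilon`, and
/// the decoupled `weight_decay` coefficient.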
pub fn new(learning_rate: F, beta1: F, beta2: F, epsilon: F, weight_decay: F) -> Self {
Self {
learning_rate,
beta1,
beta2,
epsilon,
weight_decay,
m: Vec::new(),
v: Vec::new(),
t: 0,
}
}
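/// Creates an AdamW optimizer with the given learning rate and the common
/// defaults beta1 = 0.9, beta2 = 0.999, epsilon = 1e-8, weight_decay = 0.01.
/// Returns an error if any default cannot be represented in `F`.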
pub fn default_with_lr(learning_rate: F) -> Result<Self> {
let beta1 = F::from(0.9).ok_or_else(|| {
NeuralError::InvalidArgument(
"Failed to convert 0.9 to the appropriate floating point type".to_string(),
)
})?;
let beta2 = F::from(0.999).ok_or_else(|| {
NeuralError::InvalidArgument(
"Failed to convert 0.999 to the appropriate floating point type".to_string(),
)
})?;
let epsilon = F::from(1e-8).ok_or_else(|| {
NeuralError::InvalidArgument(
"Failed to convert 1e-8 to the appropriate floating point type".to_string(),
)
})?;
let weight_decay = F::from(0.01).ok_or_else(|| {
NeuralError::InvalidArgument(
"Failed to convert 0.01 to the appropriate floating point type".to_string(),
)
})?;
Ok(Self::new(
learning_rate,
beta1,
beta2,
epsilon,
weight_decay,
))
}
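/// Returns the current beta1 (first-moment decay) coefficient.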
pub fn get_beta1(&self) -> F {
self.beta1
}
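/// Sets beta1 and returns `&mut Self` for method chaining.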
pub fn set_beta1(&mut self, beta1: F) -> &mut Self {
self.beta1 = beta1;
self
}
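/// Returns the current beta2 (second-moment decay) coefficient.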
pub fn get_beta2(&self) -> F {
self.beta2
}
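/// Sets beta2 and returns `&mut Self` for method chaining.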
pub fn set_beta2(&mut self, beta2: F) -> &mut Self {
self.beta2 = beta2;
self
}
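/// Returns the current epsilon (numerical-stability) term.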
pub fn get_epsilon(&self) -> F {
self.epsilon
}
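/// Sets epsilon and returns `&mut Self` for method chaining.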
pub fn set_epsilon(&mut self, epsilon: F) -> &mut Self {
self.epsilon = epsilon;
self
}
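/// Returns the current decoupled weight-decay coefficient.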
pub fn get_weight_decay(&self) -> F {
self.weight_decay
}
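/// Sets the weight-decay coefficient and returns `&mut Self` for method chaining.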
pub fn set_weight_decay(&mut self, weight_decay: F) -> &mut Self {
self.weight_decay = weight_decay;
self
}
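/// Resets the optimizer state, clearing the moment buffers and the timestep.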
pub fn reset(&mut self) {
self.m.clear();
self.v.clear();
self.t = 0;
}
}
impl<F: Float + NumAssign + ScalarOperand + Debug> Optimizer<F> for AdamW<F> {
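/// Applies one AdamW step to every parameter/gradient pair:
///
/// ```text
/// m_t     = beta1 * m_{t-1} + (1 - beta1) * g_t
/// v_t     = beta2 * v_{t-1} + (1 - beta2) * g_t^2
/// m_hat   = m_t / (1 - beta1^t)
/// v_hat   = v_t / (1 - beta2^t)
/// theta_t = theta_{t-1} * (1 - lr * weight_decay) - lr * m_hat / (sqrt(v_hat) + epsilon)
/// ```
///
/// Returns an error if `params` and `grads` have different lengths.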
fn update(
&mut self,
params: &mut [Array<F, scirs2_core::ndarray::IxDyn>],
grads: &[Array<F, scirs2_core::ndarray::IxDyn>],
) -> Result<()> {
if params.len() != grads.len() {
return Err(NeuralError::TrainingError(format!(
"Number of parameter arrays ({}) does not match number of gradient arrays ({})",
params.len(),
grads.len()
)));
}
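// Advance the timestep used for bias correction.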
self.t += 1;
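// Lazily (re)initialize the moment buffers when the number of parameter arrays changes.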
if self.m.len() != params.len() {
self.m = params.iter().map(|p| Array::zeros(p.raw_dim())).collect();
self.v = params.iter().map(|p| Array::zeros(p.raw_dim())).collect();
}
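// Precompute decay complements and bias-correction factors for this step.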
let one_minus_beta1 = F::one() - self.beta1;
let one_minus_beta2 = F::one() - self.beta2;
let beta1_pow_t = self.beta1.powi(self.t as i32);
let beta2_pow_t = self.beta2.powi(self.t as i32);
let bias_correction1 = F::one() - beta1_pow_t;
let bias_correction2 = F::one() - beta2_pow_t;
for i in 0..params.len() {
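// Update the exponential moving averages of the gradient (m) and the squared gradient (v).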
self.m[i] = &self.m[i] * self.beta1 + &(&grads[i] * one_minus_beta1);
self.v[i] = &self.v[i] * self.beta2 + &(grads[i].mapv(|x| x * x) * one_minus_beta2);
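// Bias-corrected first and second moment estimates.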
let m_hat = &self.m[i] / bias_correction1;
let v_hat = &self.v[i] / bias_correction2;
let denom = v_hat.mapv(|x| x.sqrt()) + self.epsilon;
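// Decoupled weight decay shrinks the parameters directly; the Adam step is then subtracted.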
params[i] = &params[i] * (F::one() - self.learning_rate * self.weight_decay)
- &(m_hat / denom * self.learning_rate);
}
Ok(())
}
fn get_learning_rate(&self) -> F {
self.learning_rate
}
fn set_learning_rate(&mut self, lr: F) {
self.learning_rate = lr;
}
fn name(&self) -> &'static str {
"AdamW"
}
}