use crate::error::{OptimError, Result};
use crate::optimizers::Optimizer;
use crate::parameter_groups::{
GroupManager, GroupedOptimizer, ParameterGroup, ParameterGroupConfig,
};
use scirs2_core::ndarray::{Array, Dimension, ScalarOperand};
use scirs2_core::numeric::Float;
use std::fmt::Debug;
/// Adam optimizer whose parameters are organised into groups handled by a
/// [`GroupManager`].
///
/// The `default_*` fields are the fallback values passed to each group's
/// accessors when an update is computed.
#[derive(Debug)]
pub struct GroupedAdam<A, D>
where
    A: Float + Send + Sync,
    D: Dimension,
{
    /// Fallback learning rate passed to the group's `learning_rate` accessor.
    defaultlr: A,
    /// Fallback for the group's custom `"beta1"` parameter (first-moment decay).
    default_beta1: A,
    /// Fallback for the group's custom `"beta2"` parameter (second-moment decay).
    default_beta2: A,
    /// Fallback weight decay passed to the group's `weight_decay` accessor.
    default_weight_decay: A,
    /// Constant added to the square-rooted second moment before dividing.
    epsilon: A,
    /// When `true`, the update divides by a running element-wise maximum of
    /// the bias-corrected second moment (AMSGrad).
    amsgrad: bool,
    /// Storage for the parameter groups and their per-group state.
    group_manager: GroupManager<A, D>,
    /// Counter shared by all groups; incremented on every `step_group` call.
    step: usize,
}
impl<A: Float + ScalarOperand + Debug + Send + Sync, D: Dimension + Send + Sync> GroupedAdam<A, D> {
    /// Creates a grouped Adam optimizer with the given default learning rate.
    ///
    /// Initial settings: `beta1 = 0.9`, `beta2 = 0.999`, `epsilon = 1e-8`,
    /// zero weight decay, AMSGrad disabled, no groups and a step counter of 0.
    ///
    /// # Panics
    ///
    /// Panics if one of the floating-point constants cannot be converted to `A`.
    pub fn new(defaultlr: A) -> Self {
        Self {
            defaultlr,
            default_beta1: A::from(0.9).expect("unwrap failed"),
            default_beta2: A::from(0.999).expect("unwrap failed"),
            default_weight_decay: A::zero(),
            epsilon: A::from(1e-8).expect("unwrap failed"),
            amsgrad: false,
            group_manager: GroupManager::new(),
            step: 0,
        }
    }

    /// Sets the fallback `beta1` used for groups without a custom `"beta1"`.
    pub fn with_beta1(mut self, beta1: A) -> Self {
        self.default_beta1 = beta1;
        self
    }

    /// Sets the fallback `beta2` used for groups without a custom `"beta2"`.
    pub fn with_beta2(mut self, beta2: A) -> Self {
        self.default_beta2 = beta2;
        self
    }

    /// Sets the fallback weight decay passed to each group's `weight_decay`.
    pub fn with_weight_decay(mut self, weight_decay: A) -> Self {
        self.default_weight_decay = weight_decay;
        self
    }

    /// Enables the AMSGrad variant of the update.
    pub fn with_amsgrad(mut self) -> Self {
        self.amsgrad = true;
        self
    }

    /// Ensures the moment buffers required by the current configuration exist
    /// for group `groupid`.
    ///
    /// Each missing buffer (`"m_t"`, `"v_t"` and, with AMSGrad, `"v_hat_max"`)
    /// is created as one zero array per parameter. Buffers are checked
    /// individually so that a group whose state holds only some of them (for
    /// example because AMSGrad was switched on afterwards) is completed
    /// instead of causing a failure during the update.
    ///
    /// # Errors
    ///
    /// Propagates the error from the group lookup.
    fn init_group_state(&mut self, groupid: usize) -> Result<()> {
        let amsgrad = self.amsgrad;
        let group = self.group_manager.get_group_mut(groupid)?;

        let required: &[&str] = if amsgrad {
            &["m_t", "v_t", "v_hat_max"]
        } else {
            &["m_t", "v_t"]
        };

        for &key in required {
            if !group.state.contains_key(key) {
                let zeros: Vec<Array<A, D>> = group
                    .params
                    .iter()
                    .map(|param| Array::zeros(param.raw_dim()))
                    .collect();
                group.state.insert(key.to_string(), zeros);
            }
        }
        Ok(())
    }

    /// Applies one Adam update to every parameter of group `groupid`.
    ///
    /// Stores the updated parameters in the group and returns a copy of them.
    /// `gradients[i]` is the gradient of the group's `i`-th parameter.
    ///
    /// # Errors
    ///
    /// Returns an error when the group lookup fails, when the number of
    /// gradients differs from the number of parameters, when the step counter
    /// does not fit in an `i32`, or when a moment buffer is missing or shorter
    /// than the parameter list.
    fn step_group_internal(
        &mut self,
        groupid: usize,
        gradients: &[Array<A, D>],
    ) -> Result<Vec<Array<A, D>>> {
        // `step_group` increments `self.step` before delegating here, so the
        // counter already is the 1-based timestep of this update. Adding one
        // more would bias-correct the very first update as if it were the
        // second. `max(1)` keeps the correction denominators non-zero even if
        // this helper runs while the counter is still 0.
        let step = self.step.max(1);
        if step > i32::MAX as usize {
            return Err(OptimError::InvalidConfig(format!(
                "Step count ({}) is too large for bias correction",
                step
            )));
        }
        // Lossless: the range was checked just above.
        let t = step as i32;

        self.init_group_state(groupid)?;

        let epsilon = self.epsilon;
        let amsgrad = self.amsgrad;
        let group = self.group_manager.get_group_mut(groupid)?;
        if gradients.len() != group.params.len() {
            return Err(OptimError::InvalidConfig(format!(
                "Number of gradients ({}) doesn't match number of parameters ({})",
                gradients.len(),
                group.params.len()
            )));
        }

        let lr = group.learning_rate(self.defaultlr);
        let beta1 = group.get_custom_param("beta1", self.default_beta1);
        let beta2 = group.get_custom_param("beta2", self.default_beta2);
        let weightdecay = group.weight_decay(self.default_weight_decay);

        // The bias-correction denominators depend only on the timestep and the
        // betas, so compute them once rather than once per parameter.
        let one = A::one();
        let bias_correction1 = one - beta1.powi(t);
        let bias_correction2 = one - beta2.powi(t);

        let missing_state = |key: &str| {
            OptimError::InvalidConfig(format!(
                "Optimizer state '{}' is missing or incomplete for group {}",
                key, groupid
            ))
        };

        let mut updated_params = Vec::with_capacity(group.params.len());
        for (i, (param, grad)) in group.params.iter().zip(gradients).enumerate() {
            // Weight decay is folded into the gradient (L2 regularisation).
            let grad_with_decay = if weightdecay > A::zero() {
                grad + &(param * weightdecay)
            } else {
                grad.clone()
            };

            // First moment: exponential moving average of the gradient.
            let m_hat = {
                let m = group
                    .state
                    .get_mut("m_t")
                    .and_then(|buffers| buffers.get_mut(i))
                    .ok_or_else(|| missing_state("m_t"))?;
                let next_m = &*m * beta1 + &grad_with_decay * (one - beta1);
                *m = next_m;
                &*m / bias_correction1
            };

            // Second moment: exponential moving average of the squared gradient.
            let v_hat = {
                let v = group
                    .state
                    .get_mut("v_t")
                    .and_then(|buffers| buffers.get_mut(i))
                    .ok_or_else(|| missing_state("v_t"))?;
                let next_v = &*v * beta2 + &grad_with_decay * &grad_with_decay * (one - beta2);
                *v = next_v;
                &*v / bias_correction2
            };

            // AMSGrad divides by the running element-wise maximum of `v_hat`.
            let second_moment_root = if amsgrad {
                let v_max = group
                    .state
                    .get_mut("v_hat_max")
                    .and_then(|buffers| buffers.get_mut(i))
                    .ok_or_else(|| missing_state("v_hat_max"))?;
                v_max.zip_mut_with(&v_hat, |a, &b| *a = a.max(b));
                v_max.mapv(|x| x.sqrt())
            } else {
                v_hat.mapv(|x| x.sqrt())
            };

            updated_params.push(param - &(&m_hat * lr / (&second_moment_root + epsilon)));
        }

        group.params.clone_from(&updated_params);
        Ok(updated_params)
    }
}
impl<A: Float + ScalarOperand + Debug + Send + Sync, D: Dimension + Send + Sync>
    GroupedOptimizer<A, D> for GroupedAdam<A, D>
{
    /// Registers `params` with `config` as a new group and returns the id the
    /// group manager assigned to it.
    fn add_group(
        &mut self,
        params: Vec<Array<A, D>>,
        config: ParameterGroupConfig<A>,
    ) -> Result<usize> {
        let groupid = self.group_manager.add_group(params, config);
        Ok(groupid)
    }

    /// Looks up group `groupid` for reading; delegates to the group manager.
    fn get_group(&self, groupid: usize) -> Result<&ParameterGroup<A, D>> {
        let manager = &self.group_manager;
        manager.get_group(groupid)
    }

    /// Looks up group `groupid` for modification; delegates to the group manager.
    fn get_group_mut(&mut self, groupid: usize) -> Result<&mut ParameterGroup<A, D>> {
        let manager = &mut self.group_manager;
        manager.get_group_mut(groupid)
    }

    /// Returns all groups held by the group manager.
    fn groups(&self) -> &[ParameterGroup<A, D>] {
        let manager = &self.group_manager;
        manager.groups()
    }

    /// Returns all groups held by the group manager, mutably.
    fn groups_mut(&mut self) -> &mut [ParameterGroup<A, D>] {
        let manager = &mut self.group_manager;
        manager.groups_mut()
    }

    /// Advances the shared step counter and applies one Adam update to group
    /// `groupid` using `gradients`.
    ///
    /// The counter is incremented before the update is attempted, so it also
    /// advances when the update returns an error.
    fn step_group(
        &mut self,
        groupid: usize,
        gradients: &[Array<A, D>],
    ) -> Result<Vec<Array<A, D>>> {
        self.step += 1;
        let updated = self.step_group_internal(groupid, gradients);
        updated
    }

    /// Overrides the learning rate stored in the configuration of group `groupid`.
    fn set_group_learning_rate(&mut self, groupid: usize, lr: A) -> Result<()> {
        self.group_manager.get_group_mut(groupid).map(|group| {
            group.config.learning_rate = Some(lr);
        })
    }

    /// Overrides the weight decay stored in the configuration of group `groupid`.
    fn set_group_weight_decay(&mut self, groupid: usize, wd: A) -> Result<()> {
        self.group_manager.get_group_mut(groupid).map(|group| {
            group.config.weight_decay = Some(wd);
        })
    }
}
impl<A: Float + ScalarOperand + Debug + Send + Sync, D: Dimension + Send + Sync> Optimizer<A, D>
for GroupedAdam<A, D>
{
fn step(&mut self, params: &Array<A, D>, gradients: &Array<A, D>) -> Result<Array<A, D>> {
let params_vec = vec![params.clone()];
let gradients_vec = vec![gradients.clone()];
let config = ParameterGroupConfig::new();
let groupid = self.add_group(params_vec, config)?;
let result = self.step_group(groupid, &gradients_vec)?;
Ok(result.into_iter().next().expect("unwrap failed"))
}
fn get_learning_rate(&self) -> A {
self.defaultlr
}
fn set_learning_rate(&mut self, learning_rate: A) {
self.defaultlr = learning_rate;
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use scirs2_core::ndarray::Array1;

    /// The constructor stores the learning rate and the stock beta values.
    #[test]
    fn test_grouped_adam_creation() {
        let adam: GroupedAdam<f64, scirs2_core::ndarray::Ix1> = GroupedAdam::new(0.001);

        assert_eq!(adam.defaultlr, 0.001);
        assert_eq!(adam.default_beta1, 0.9);
        assert_eq!(adam.default_beta2, 0.999);
    }

    /// Two groups with different learning rates can be stepped independently.
    #[test]
    fn test_grouped_adam_multiple_groups() {
        let mut adam = GroupedAdam::new(0.001);

        // Register both groups before stepping either of them.
        let fast_params = vec![Array1::from_vec(vec![1.0, 2.0])];
        let fast_config = ParameterGroupConfig::new().with_learning_rate(0.01);
        let fast_id = adam
            .add_group(fast_params, fast_config)
            .expect("unwrap failed");

        let slow_params = vec![Array1::from_vec(vec![3.0, 4.0, 5.0])];
        let slow_config = ParameterGroupConfig::new().with_learning_rate(0.0001);
        let slow_id = adam
            .add_group(slow_params, slow_config)
            .expect("unwrap failed");

        let fast_grads = vec![Array1::from_vec(vec![0.1, 0.2])];
        let fast_updated = adam
            .step_group(fast_id, &fast_grads)
            .expect("unwrap failed");

        let slow_grads = vec![Array1::from_vec(vec![0.3, 0.4, 0.5])];
        let slow_updated = adam
            .step_group(slow_id, &slow_grads)
            .expect("unwrap failed");

        // A positive gradient moves the first group's first value below 1.0.
        assert!(fast_updated[0][0] < 1.0);
        // The second group's first value remains above 2.9 (it started at 3.0).
        assert!(slow_updated[0][0] > 2.9);
    }

    /// Custom beta values stored in the group configuration can be read back.
    #[test]
    fn test_grouped_adam_custom_betas() {
        let mut adam = GroupedAdam::new(0.001);

        let weights = vec![Array1::from_vec(vec![1.0, 2.0])];
        let betas = ParameterGroupConfig::new()
            .with_custom_param("beta1".to_string(), 0.8)
            .with_custom_param("beta2".to_string(), 0.99);
        let id = adam.add_group(weights, betas).expect("unwrap failed");

        let stored = adam.get_group(id).expect("unwrap failed");
        assert_eq!(stored.get_custom_param("beta1", 0.0), 0.8);
        assert_eq!(stored.get_custom_param("beta2", 0.0), 0.99);
    }

    /// Replacing the group manager and zeroing the counter empties the optimizer.
    #[test]
    fn test_grouped_adam_clear() {
        let mut adam = GroupedAdam::new(0.001);

        let zeros = vec![Array1::zeros(2)];
        let plain_config = ParameterGroupConfig::new();
        adam.add_group(zeros, plain_config).expect("unwrap failed");
        assert_eq!(adam.groups().len(), 1);

        // Reset by hand: swap in a fresh manager and rewind the step counter.
        adam.group_manager = GroupManager::new();
        adam.step = 0;

        assert_eq!(adam.groups().len(), 0);
        assert_eq!(adam.step, 0);
    }
}