use crate::error::{OptimError, Result};
use crate::schedulers::LearningRateScheduler;
use scirs2_core::ndarray::{Array, Array1, Dimension, ScalarOperand};
use scirs2_core::numeric::Float;
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
use std::fmt::Debug;
/// Hyperparameter bundle shared by the optimizers in this file.
///
/// Usually built with `OptimizerConfig::new` followed by the chained setter
/// methods (`weight_decay`, `grad_clip`, `param`, `params`).
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct OptimizerConfig<A: Float> {
/// Learning rate used to scale every update.
pub lr: A,
/// Weight-decay coefficient; the optimizers in this file only apply decay
/// when this is greater than zero.
pub weight_decay: A,
/// Maximum gradient L2 norm; `None` disables gradient clipping.
pub grad_clip: Option<A>,
/// Extra optimizer-specific values keyed by name. The optimizers in this
/// file read `"momentum"`, `"beta1"`, `"beta2"` and `"eps"`.
pub params: HashMap<String, A>,
}
impl<A: Float + Send + Sync> Default for OptimizerConfig<A> {
fn default() -> Self {
Self {
lr: A::from(0.001).expect("unwrap failed"),
weight_decay: A::zero(),
grad_clip: None,
params: HashMap::new(),
}
}
}
impl<A: Float + Send + Sync> OptimizerConfig<A> {
    /// Creates a configuration with learning rate `lr`; every other field
    /// takes its `Default` value.
    pub fn new(lr: A) -> Self {
        let mut cfg = Self::default();
        cfg.lr = lr;
        cfg
    }

    /// Returns the configuration with its weight-decay coefficient replaced.
    pub fn weight_decay(self, weightdecay: A) -> Self {
        Self {
            weight_decay: weightdecay,
            ..self
        }
    }

    /// Returns the configuration with gradient clipping enabled at `gradclip`.
    pub fn grad_clip(self, gradclip: A) -> Self {
        Self {
            grad_clip: Some(gradclip),
            ..self
        }
    }

    /// Stores one extra named value, overwriting any previous value for `key`.
    pub fn param<S: Into<String>>(mut self, key: S, value: A) -> Self {
        let name: String = key.into();
        self.params.insert(name, value);
        self
    }

    /// Merges `params` into the extra-value map; entries in `params` win over
    /// existing entries with the same key.
    pub fn params(mut self, params: HashMap<String, A>) -> Self {
        for (name, value) in params {
            self.params.insert(name, value);
        }
        self
    }
}
/// A named array of values together with its optional gradient.
#[derive(Debug, Clone)]
pub struct Parameter<A: Float, D: Dimension> {
/// Current parameter values; updated in place by the optimizers.
pub data: Array<A, D>,
/// Gradient associated with `data`, or `None` when no gradient is set.
pub grad: Option<Array<A, D>>,
/// When `false`, `set_grad` discards gradients and the optimizers in this
/// file leave the parameter untouched.
pub requires_grad: bool,
/// Parameter name; the optimizers in this file key their per-parameter
/// state by it.
pub name: String,
}
impl<A: Float + ScalarOperand, D: Dimension + Send + Sync> Parameter<A, D> {
pub fn new<S: Into<String>>(data: Array<A, D>, name: S) -> Self {
Self {
data,
grad: None,
requires_grad: true,
name: name.into(),
}
}
pub fn no_grad<S: Into<String>>(data: Array<A, D>, name: S) -> Self {
Self {
data,
grad: None,
requires_grad: false,
name: name.into(),
}
}
pub fn set_grad(&mut self, grad: Array<A, D>) {
if self.requires_grad {
self.grad = Some(grad);
}
}
pub fn zero_grad(&mut self) {
self.grad = None;
}
pub fn grad(&self) -> Option<&Array<A, D>> {
self.grad.as_ref()
}
pub fn clip_grad(&mut self, maxnorm: A) -> Result<()> {
if let Some(ref mut grad) = self.grad {
let _norm = grad
.iter()
.map(|x| (*x) * (*x))
.fold(A::zero(), |acc, x| acc + x)
.sqrt();
if _norm > maxnorm {
let scale = maxnorm / _norm;
grad.mapv_inplace(|x| x * scale);
}
}
Ok(())
}
}
/// Common interface of the optimizers in this file.
pub trait UnifiedOptimizer<A: Float> {
    /// Returns the optimizer's configuration.
    fn config(&self) -> &OptimizerConfig<A>;

    /// Applies one update to a single parameter.
    fn step_param<D: Dimension>(&mut self, param: &mut Parameter<A, D>) -> Result<()>
    where
        A: ScalarOperand + Debug;

    /// Applies `step_param` to every parameter in order, stopping at and
    /// returning the first error.
    fn step_params<D: Dimension>(&mut self, params: &mut [Parameter<A, D>]) -> Result<()>
    where
        A: ScalarOperand + Debug,
    {
        params.iter_mut().try_for_each(|p| self.step_param(p))
    }

    /// Clears the gradient of every parameter in `params`.
    fn zero_grad<D: Dimension>(&self, params: &mut [Parameter<A, D>]) {
        params.iter_mut().for_each(|p| {
            p.grad = None;
        });
    }

    /// Replaces the learning rate.
    fn set_lr(&mut self, lr: A);

    /// Returns the current learning rate.
    fn get_lr(&self) -> A;

    /// Exports optimizer state as named byte blobs.
    fn state_dict(&self) -> HashMap<String, Vec<u8>>;

    /// Loads optimizer state from named byte blobs.
    fn load_state_dict(&mut self, statedict: HashMap<String, Vec<u8>>) -> Result<()>;
}
/// Stochastic gradient descent optimizer with optional momentum.
#[derive(Debug)]
pub struct UnifiedSGD<A: Float> {
/// Learning rate, weight decay, clipping threshold and the optional
/// `"momentum"` entry read on every step.
config: OptimizerConfig<A>,
/// One-dimensional momentum buffer per parameter, keyed by parameter name.
momentum_buffers: HashMap<String, Array1<A>>,
}
impl<A: Float + ScalarOperand + Debug + Send + Sync> UnifiedSGD<A> {
pub fn new(config: OptimizerConfig<A>) -> Self {
Self {
config,
momentum_buffers: HashMap::new(),
}
}
pub fn with_momentum(mut config: OptimizerConfig<A>, momentum: A) -> Self {
config.params.insert("momentum".to_string(), momentum);
Self::new(config)
}
}
impl<A: Float + ScalarOperand + Debug + Send + Sync> UnifiedOptimizer<A> for UnifiedSGD<A> {
    /// Returns the configuration this optimizer was built with.
    fn config(&self) -> &OptimizerConfig<A> {
        &self.config
    }

    /// Applies one SGD update to `param`.
    ///
    /// Steps, in order:
    /// 1. optional gradient-norm clipping (`grad_clip`),
    /// 2. multiplicative weight decay, `data *= 1 - weight_decay * lr`, when
    ///    `weight_decay > 0`,
    /// 3. either a plain step `data -= lr * grad`, or, when the `"momentum"`
    ///    entry of the config is greater than zero, a momentum step
    ///    `buf = momentum * buf + grad; data -= lr * buf`.
    ///
    /// The momentum buffer is stored per parameter name and is initialised to
    /// the first gradient seen for that name. Parameters with
    /// `requires_grad == false` are left untouched.
    ///
    /// # Errors
    ///
    /// Returns `OptimError::InvalidConfig` when the parameter requires a
    /// gradient but none is set; the parameter data is not modified in that
    /// case.
    fn step_param<D: Dimension>(&mut self, param: &mut Parameter<A, D>) -> Result<()> {
        if !param.requires_grad {
            return Ok(());
        }
        // Clipping does nothing when no gradient is set, so the
        // missing-gradient check below still runs before any data changes.
        if let Some(max_norm) = self.config.grad_clip {
            param.clip_grad(max_norm)?;
        }
        let grad = param
            .grad
            .as_ref()
            .ok_or_else(|| OptimError::InvalidConfig("Parameter has no gradient".to_string()))?;

        let lr = self.config.lr;
        let weight_decay = self.config.weight_decay;
        if weight_decay > A::zero() {
            // The decay factor is the same for every element; compute it once.
            let keep = A::one() - weight_decay * lr;
            param.data.mapv_inplace(|x| x * keep);
        }

        let momentum = self
            .config
            .params
            .get("momentum")
            .copied()
            .unwrap_or(A::zero());

        if momentum > A::zero() {
            if let Some(buffer) = self.momentum_buffers.get_mut(&param.name) {
                for (m, g) in buffer.iter_mut().zip(grad.iter()) {
                    *m = momentum * (*m) + *g;
                }
                for (p, m) in param.data.iter_mut().zip(buffer.iter()) {
                    *p = *p - lr * (*m);
                }
            } else {
                // First update for this name: the buffer starts as a copy of
                // the gradient.
                let mut buffer = Array1::zeros(grad.len());
                for (m, g) in buffer.iter_mut().zip(grad.iter()) {
                    *m = *g;
                }
                for (p, m) in param.data.iter_mut().zip(buffer.iter()) {
                    *p = *p - lr * (*m);
                }
                self.momentum_buffers.insert(param.name.clone(), buffer);
            }
        } else {
            for (p, g) in param.data.iter_mut().zip(grad.iter()) {
                *p = *p - lr * (*g);
            }
        }
        Ok(())
    }

    /// Replaces the learning rate stored in the configuration.
    fn set_lr(&mut self, lr: A) {
        self.config.lr = lr;
    }

    /// Returns the learning rate stored in the configuration.
    fn get_lr(&self) -> A {
        self.config.lr
    }

    /// Returns an empty map: this optimizer does not currently serialize its
    /// momentum buffers.
    fn state_dict(&self) -> HashMap<String, Vec<u8>> {
        HashMap::new()
    }

    /// Accepts and ignores `_statedict`; always returns `Ok(())`.
    fn load_state_dict(&mut self, _statedict: HashMap<String, Vec<u8>>) -> Result<()> {
        Ok(())
    }
}
/// Adam optimizer whose state is keyed by parameter name.
///
/// For every parameter name it keeps exponential moving averages of the
/// gradient and of the squared gradient, plus the number of updates applied
/// to that name. The per-name update count drives the bias correction, so
/// updating several parameters in one optimization step does not advance the
/// correction more than once per parameter.
#[derive(Debug)]
pub struct UnifiedAdam<A: Float> {
    /// Learning rate, weight decay, clipping threshold and the `"beta1"`,
    /// `"beta2"` and `"eps"` entries.
    config: OptimizerConfig<A>,
    /// Number of updates applied so far, per parameter name.
    step_counts: HashMap<String, usize>,
    /// Moving average of the gradient, per parameter name.
    exp_avg: HashMap<String, Array1<A>>,
    /// Moving average of the squared gradient, per parameter name.
    exp_avg_sq: HashMap<String, Array1<A>>,
}

impl<A: Float + ScalarOperand + Debug + Send + Sync> UnifiedAdam<A> {
    /// Creates an Adam optimizer from `config`.
    ///
    /// Any of `"beta1"`, `"beta2"` and `"eps"` missing from `config.params`
    /// is filled in with `0.9`, `0.999` and `1e-8` respectively; values
    /// already present are kept.
    ///
    /// # Panics
    ///
    /// Panics if a default value that has to be inserted cannot be converted
    /// into `A`.
    pub fn new(mut config: OptimizerConfig<A>) -> Self {
        let defaults: [(&str, f64); 3] = [("beta1", 0.9), ("beta2", 0.999), ("eps", 1e-8)];
        for &(key, value) in defaults.iter() {
            config.params.entry(key.to_string()).or_insert_with(|| {
                A::from(value).expect("default Adam hyperparameter must be representable in A")
            });
        }
        Self {
            config,
            step_counts: HashMap::new(),
            exp_avg: HashMap::new(),
            exp_avg_sq: HashMap::new(),
        }
    }

    /// Creates an Adam optimizer after storing `beta1` and `beta2` in
    /// `config`, replacing any existing values for those keys.
    pub fn with_betas(mut config: OptimizerConfig<A>, beta1: A, beta2: A) -> Self {
        config.params.insert("beta1".to_string(), beta1);
        config.params.insert("beta2".to_string(), beta2);
        Self::new(config)
    }

    /// Looks up a required hyperparameter, reporting a missing key as an
    /// error instead of panicking.
    fn hyperparameter(&self, key: &str) -> Result<A> {
        self.config.params.get(key).copied().ok_or_else(|| {
            OptimError::InvalidConfig(format!("Adam hyperparameter `{}` is missing", key))
        })
    }
}

impl<A: Float + ScalarOperand + Debug + Send + Sync> UnifiedOptimizer<A> for UnifiedAdam<A> {
    /// Returns the configuration, including any defaulted hyperparameters.
    fn config(&self) -> &OptimizerConfig<A> {
        &self.config
    }

    /// Applies one Adam update to `param`.
    ///
    /// Steps, in order:
    /// 1. optional gradient-norm clipping (`grad_clip`),
    /// 2. update of both moving averages for this parameter name,
    /// 3. bias-corrected step
    ///    `data -= lr * sqrt(1 - beta2^t) / (1 - beta1^t) * m / (sqrt(v) + eps)`,
    ///    where `t` is the number of updates applied to this parameter name,
    /// 4. multiplicative weight decay, `data *= 1 - weight_decay * lr`, when
    ///    `weight_decay > 0`.
    ///
    /// Parameters with `requires_grad == false` are left untouched.
    ///
    /// # Errors
    ///
    /// Returns `OptimError::InvalidConfig` when the parameter requires a
    /// gradient but none is set, or when a required hyperparameter is missing
    /// from the configuration. No state or data is modified in either case.
    fn step_param<D: Dimension>(&mut self, param: &mut Parameter<A, D>) -> Result<()> {
        if !param.requires_grad {
            return Ok(());
        }
        // Clipping does nothing when no gradient is set, so the
        // missing-gradient check below still runs before any state changes.
        if let Some(max_norm) = self.config.grad_clip {
            param.clip_grad(max_norm)?;
        }
        let grad = param
            .grad
            .as_ref()
            .ok_or_else(|| OptimError::InvalidConfig("Parameter has no gradient".to_string()))?;

        let beta1 = self.hyperparameter("beta1")?;
        let beta2 = self.hyperparameter("beta2")?;
        let eps = self.hyperparameter("eps")?;
        let lr = self.config.lr;
        let weight_decay = self.config.weight_decay;

        // The timestep is tracked per parameter name: a single shared counter
        // would advance once for every parameter updated in the same step.
        let step = {
            let count = self.step_counts.entry(param.name.clone()).or_insert(0);
            *count = count.saturating_add(1);
            *count
        };
        // `powi` takes an `i32`; saturate rather than wrap for huge counts.
        let exponent = step.min(i32::MAX as usize) as i32;
        let bias_correction1 = A::one() - beta1.powi(exponent);
        let bias_correction2 = A::one() - beta2.powi(exponent);
        let step_size = lr * (bias_correction2.sqrt() / bias_correction1);

        let exp_avg = self
            .exp_avg
            .entry(param.name.clone())
            .or_insert_with(|| Array1::zeros(grad.len()));
        let exp_avg_sq = self
            .exp_avg_sq
            .entry(param.name.clone())
            .or_insert_with(|| Array1::zeros(grad.len()));

        for ((m, v), g) in exp_avg
            .iter_mut()
            .zip(exp_avg_sq.iter_mut())
            .zip(grad.iter())
        {
            *m = beta1 * (*m) + (A::one() - beta1) * (*g);
            *v = beta2 * (*v) + (A::one() - beta2) * (*g) * (*g);
        }

        for ((p, m), v) in param
            .data
            .iter_mut()
            .zip(exp_avg.iter())
            .zip(exp_avg_sq.iter())
        {
            let denom = v.sqrt() + eps;
            *p = *p - step_size * (*m) / denom;
        }

        if weight_decay > A::zero() {
            // The decay factor is the same for every element; compute it once.
            let keep = A::one() - weight_decay * lr;
            param.data.mapv_inplace(|x| x * keep);
        }
        Ok(())
    }

    /// Replaces the learning rate stored in the configuration.
    fn set_lr(&mut self, lr: A) {
        self.config.lr = lr;
    }

    /// Returns the learning rate stored in the configuration.
    fn get_lr(&self) -> A {
        self.config.lr
    }

    /// Returns an empty map: this optimizer does not currently serialize its
    /// moment estimates or step counts.
    fn state_dict(&self) -> HashMap<String, Vec<u8>> {
        HashMap::new()
    }

    /// Accepts and ignores `_statedict`; always returns `Ok(())`.
    fn load_state_dict(&mut self, _statedict: HashMap<String, Vec<u8>>) -> Result<()> {
        Ok(())
    }
}
/// Namespace of convenience constructors for the optimizers in this file.
pub struct OptimizerFactory;

impl OptimizerFactory {
    /// Builds a [`UnifiedSGD`] from `config`.
    pub fn sgd<A>(config: OptimizerConfig<A>) -> UnifiedSGD<A>
    where
        A: Float + ScalarOperand + Debug + Send + Sync,
    {
        UnifiedSGD::<A>::new(config)
    }

    /// Builds a [`UnifiedAdam`] from `config`.
    pub fn adam<A>(config: OptimizerConfig<A>) -> UnifiedAdam<A>
    where
        A: Float + ScalarOperand + Debug + Send + Sync,
    {
        UnifiedAdam::<A>::new(config)
    }

    /// Builds a [`UnifiedSGD`] that uses the given `momentum`.
    pub fn sgd_momentum<A>(config: OptimizerConfig<A>, momentum: A) -> UnifiedSGD<A>
    where
        A: Float + ScalarOperand + Debug + Send + Sync,
    {
        UnifiedSGD::<A>::with_momentum(config, momentum)
    }

    /// Builds a [`UnifiedAdam`] that uses the given `beta1` and `beta2`.
    pub fn adam_custom<A>(config: OptimizerConfig<A>, beta1: A, beta2: A) -> UnifiedAdam<A>
    where
        A: Float + ScalarOperand + Debug + Send + Sync,
    {
        UnifiedAdam::<A>::with_betas(config, beta1, beta2)
    }
}
/// Pairs an optimizer with an optional learning-rate scheduler.
pub struct TrainingLoop<A: Float, O: UnifiedOptimizer<A>> {
/// Optimizer that performs the parameter updates.
optimizer: O,
/// Scheduler consulted after each step; `None` leaves the learning rate alone.
scheduler: Option<Box<dyn LearningRateScheduler<A>>>,
/// Zero-sized marker carrying the scalar type `A`.
_phantom: std::marker::PhantomData<A>,
}
impl<A: Float + ScalarOperand + Debug, O: UnifiedOptimizer<A> + Send + Sync> TrainingLoop<A, O> {
pub fn new(optimizer: O) -> Self {
Self {
optimizer,
scheduler: None,
_phantom: std::marker::PhantomData,
}
}
pub fn with_scheduler(mut self, scheduler: Box<dyn LearningRateScheduler<A>>) -> Self {
self.scheduler = Some(scheduler);
self
}
pub fn step<D: Dimension>(&mut self, params: &mut [Parameter<A, D>]) -> Result<()> {
self.optimizer.step_params(params)?;
if let Some(ref mut scheduler) = self.scheduler {
let new_lr = scheduler.step();
self.optimizer.set_lr(new_lr);
}
Ok(())
}
pub fn zero_grad<D: Dimension>(&self, params: &mut [Parameter<A, D>]) {
for param in params.iter_mut() {
param.grad = None;
}
}
pub fn get_lr(&self) -> A {
self.optimizer.get_lr()
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use scirs2_core::ndarray::Array1;

    /// A plain SGD step subtracts `lr * grad` from every element.
    #[test]
    fn test_unified_sgd() {
        let mut optimizer = UnifiedSGD::new(OptimizerConfig::new(0.1f64));
        let weights = Array1::from_vec(vec![1.0, 2.0, 3.0]);
        let mut param = Parameter::new(weights, "test_param");
        param.set_grad(Array1::from_vec(vec![0.1, 0.2, 0.3]));

        optimizer.step_param(&mut param).expect("unwrap failed");

        let expected: [f64; 3] = [0.99, 1.98, 2.97];
        for (idx, want) in expected.iter().enumerate() {
            assert!((param.data[idx] - *want).abs() < 1e-10);
        }
    }

    /// With positive gradients, one Adam step lowers every element.
    #[test]
    fn test_unified_adam() {
        let mut optimizer = UnifiedAdam::new(OptimizerConfig::new(0.001f64));
        let initial: [f64; 3] = [1.0, 2.0, 3.0];
        let mut param = Parameter::new(Array1::from_vec(initial.to_vec()), "test_param");
        param.set_grad(Array1::from_vec(vec![0.1, 0.2, 0.3]));

        optimizer.step_param(&mut param).expect("unwrap failed");

        for (idx, start) in initial.iter().enumerate() {
            assert!(param.data[idx] < *start);
        }
    }

    /// The factory can build both optimizers from one configuration.
    #[test]
    fn test_optimizer_factory() {
        let config = OptimizerConfig::new(0.01f64).weight_decay(0.0001);
        let sgd_config = config.clone();
        let _sgd = OptimizerFactory::sgd(sgd_config);
        let _adam = OptimizerFactory::adam(config);
    }

    /// Setting, clipping and clearing a gradient behave as expected.
    #[test]
    fn test_parameter_operations() {
        let mut param = Parameter::new(Array1::from_vec(vec![1.0, 2.0, 3.0]), "test");
        param.set_grad(Array1::from_vec(vec![0.1, 0.2, 0.3]));
        assert!(param.grad().is_some());

        param.clip_grad(0.1).expect("unwrap failed");
        let clipped = param.grad().expect("unwrap failed");
        let sum_sq: f64 = clipped.iter().map(|x| x * x).sum::<f64>();
        let norm = sum_sq.sqrt();
        assert!((norm - 0.1).abs() < 1e-10);

        param.zero_grad();
        assert!(param.grad().is_none());
    }
}