use scirs2_core::ndarray::{Array, Dimension, ScalarOperand};
use scirs2_core::numeric::Float;
use std::fmt::Debug;
use crate::error::{OptimError, Result};
/// Common interface for optimization algorithms operating on `ndarray` arrays.
///
/// `A` is the element type of the parameter and gradient arrays, and `D` is
/// their dimensionality.
pub trait Optimizer<A, D>
where
    A: Float + ScalarOperand + Debug,
    D: Dimension,
{
    /// Performs one optimization step for a single parameter array.
    ///
    /// Receives the current `params` together with their `gradients` and
    /// returns a newly computed array. The update rule is defined by each
    /// implementation; the `&mut self` receiver lets an implementation keep
    /// internal state between calls.
    ///
    /// # Errors
    ///
    /// Returns an error when the implementation cannot perform the step.
    fn step(&mut self, params: &Array<A, D>, gradients: &Array<A, D>) -> Result<Array<A, D>>;

    /// Returns the optimizer's current learning rate.
    fn get_learning_rate(&self) -> A;

    /// Replaces the optimizer's learning rate with `learning_rate`.
    fn set_learning_rate(&mut self, learning_rate: A);

    /// Runs [`step`](Optimizer::step) on each `(params, gradients)` pair, in order.
    ///
    /// Entry `i` of the returned vector is the result of stepping
    /// `params_list[i]` with `gradients_list[i]`.
    ///
    /// # Errors
    ///
    /// Returns [`OptimError::InvalidConfig`] if the two slices differ in length.
    /// In that case no step is taken. Otherwise the first error produced by
    /// `step` is returned, and the remaining pairs are not processed. State
    /// changes made by earlier successful steps are not rolled back.
    fn step_list(
        &mut self,
        params_list: &[&Array<A, D>],
        gradients_list: &[&Array<A, D>],
    ) -> Result<Vec<Array<A, D>>> {
        let (n_params, n_grads) = (params_list.len(), gradients_list.len());
        if n_params != n_grads {
            return Err(OptimError::InvalidConfig(format!(
                "Number of parameter arrays ({}) does not match number of gradient arrays ({})",
                n_params, n_grads
            )));
        }
        // Collecting an iterator of `Result`s stops at the first `Err`,
        // exactly like the `?` operator inside an explicit loop would.
        params_list
            .iter()
            .zip(gradients_list)
            .map(|(params, grads)| self.step(params, grads))
            .collect()
    }
}
mod adabound;
mod adadelta;
mod adagrad;
mod adam;
mod adamw;
mod grouped_adam;
mod lamb;
mod lars;
mod lbfgs;
mod lion;
mod lookahead;
mod meta_sgd;
mod radam;
mod ranger;
mod reptile;
mod rmsprop;
mod sam;
mod sgd;
mod sgd_simd;
mod sparse_adam;
pub use adabound::AdaBound;
pub use adadelta::AdaDelta;
pub use adagrad::Adagrad;
pub use adam::Adam;
pub use adamw::AdamW;
pub use grouped_adam::GroupedAdam;
pub use lamb::LAMB;
pub use lars::LARS;
pub use lbfgs::LBFGS;
pub use lion::Lion;
pub use lookahead::Lookahead;
pub use meta_sgd::MetaSGD;
pub use radam::RAdam;
pub use ranger::Ranger;
pub use reptile::ReptileOptimizer;
pub use rmsprop::RMSprop;
pub use sam::SAM;
pub use sgd::SGD;
pub use sgd_simd::SimdSGD;
pub use sparse_adam::{SparseAdam, SparseGradient};