optirs_core/optimizers/mod.rs

use scirs2_core::ndarray::{Array, Dimension, ScalarOperand};
use scirs2_core::numeric::Float;
use std::fmt::Debug;

use crate::error::{OptimError, Result};

/// Common interface implemented by every optimizer in this module.
///
/// `A` is the floating-point element type and `D` is the array dimension.
pub trait Optimizer<A, D>
where
    A: Float + ScalarOperand + Debug,
    D: Dimension,
{
    /// Performs a single optimization step and returns the updated parameters.
    fn step(&mut self, params: &Array<A, D>, gradients: &Array<A, D>) -> Result<Array<A, D>>;

    /// Returns the current learning rate.
    fn get_learning_rate(&self) -> A;

    /// Sets the learning rate.
    fn set_learning_rate(&mut self, learning_rate: A);

    /// Performs one optimization step for each parameter/gradient pair.
    ///
    /// Returns an error if the two slices have different lengths.
    fn step_list(
        &mut self,
        params_list: &[&Array<A, D>],
        gradients_list: &[&Array<A, D>],
    ) -> Result<Vec<Array<A, D>>> {
        if params_list.len() != gradients_list.len() {
            return Err(OptimError::InvalidConfig(format!(
                "Number of parameter arrays ({}) does not match number of gradient arrays ({})",
                params_list.len(),
                gradients_list.len()
            )));
        }

        let mut results = Vec::with_capacity(params_list.len());
        for (params, grads) in params_list.iter().zip(gradients_list.iter()) {
            results.push(self.step(params, grads)?);
        }
        Ok(results)
    }
}

mod adabound;
mod adadelta;
mod adagrad;
mod adam;
mod adamw;
mod grouped_adam;
mod lamb;
mod lars;
mod lbfgs;
mod lion;
mod lookahead;
mod radam;
mod ranger;
mod rmsprop;
mod sam;
mod sgd;
mod sgd_simd;
mod sparse_adam;

pub use adabound::AdaBound;
pub use adadelta::AdaDelta;
pub use adagrad::Adagrad;
pub use adam::Adam;
pub use adamw::AdamW;
pub use grouped_adam::GroupedAdam;
pub use lamb::LAMB;
pub use lars::LARS;
pub use lbfgs::LBFGS;
pub use lion::Lion;
pub use lookahead::Lookahead;
pub use radam::RAdam;
pub use ranger::Ranger;
pub use rmsprop::RMSprop;
pub use sam::SAM;
pub use sgd::SGD;
pub use sgd_simd::SimdSGD;
pub use sparse_adam::{SparseAdam, SparseGradient};
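
Downstream code usually programs against the Optimizer trait rather than a concrete optimizer. The sketch below is hypothetical and only illustrates the trait surface shown above: PlainGd is an invented struct, the paths optirs_core::error::Result and optirs_core::optimizers::Optimizer assume the crate is consumed as optirs_core with those modules publicly reachable, and scirs2_core::ndarray is assumed to behave like the usual ndarray types. It implements step as plain gradient descent and relies on the trait's default step_list.

// Hypothetical usage sketch; crate/module paths are assumptions, see the note above.
use scirs2_core::ndarray::{Array, Dimension, ScalarOperand};
use scirs2_core::numeric::Float;
use std::fmt::Debug;

use optirs_core::error::Result;
use optirs_core::optimizers::Optimizer;

/// Deliberately minimal gradient-descent optimizer (illustration only).
struct PlainGd<A> {
    learning_rate: A,
}

impl<A, D> Optimizer<A, D> for PlainGd<A>
where
    A: Float + ScalarOperand + Debug,
    D: Dimension,
{
    fn step(&mut self, params: &Array<A, D>, gradients: &Array<A, D>) -> Result<Array<A, D>> {
        let lr = self.learning_rate;
        // new_params = params - learning_rate * gradients
        let mut new_params = params.clone();
        new_params.zip_mut_with(gradients, |p, &g| *p = *p - g * lr);
        Ok(new_params)
    }

    fn get_learning_rate(&self) -> A {
        self.learning_rate
    }

    fn set_learning_rate(&mut self, learning_rate: A) {
        self.learning_rate = learning_rate;
    }
}

fn main() -> Result<()> {
    let mut opt = PlainGd { learning_rate: 0.1_f64 };
    let params = Array::from_vec(vec![1.0, 2.0, 3.0]);
    let grads = Array::from_vec(vec![0.5, 0.5, 0.5]);

    // `step_list` comes for free from the trait's default implementation.
    let updated = opt.step_list(&[&params], &[&grads])?;
    println!("{:?}", updated[0]);
    Ok(())
}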