optirs_core/optimizers/
mod.rs

// Optimization algorithms for machine learning
//
// This module provides various optimization algorithms commonly used in machine learning,
// such as Stochastic Gradient Descent (SGD), Adam, RMSprop, and others.

use scirs2_core::ndarray::{Array, Dimension, ScalarOperand};
use scirs2_core::numeric::Float;
use std::fmt::Debug;

use crate::error::{OptimError, Result};

/// Trait that defines the interface for optimization algorithms
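///
/// # Example
///
/// A minimal sketch of implementing the trait for plain gradient descent.
/// `PlainGradientDescent` is illustrative only and is not one of the optimizers
/// re-exported by this module.
///
/// ```ignore
/// // Assumes `Optimizer` and this crate's `Result` are in scope.
/// use scirs2_core::ndarray::{Array, Dimension, ScalarOperand};
/// use scirs2_core::numeric::Float;
/// use std::fmt::Debug;
///
/// struct PlainGradientDescent<A> {
///     learning_rate: A,
/// }
///
/// impl<A, D> Optimizer<A, D> for PlainGradientDescent<A>
/// where
///     A: Float + ScalarOperand + Debug,
///     D: Dimension,
/// {
///     fn step(&mut self, params: &Array<A, D>, gradients: &Array<A, D>) -> Result<Array<A, D>> {
///         // Plain gradient descent: new_params = params - learning_rate * gradients
///         Ok(params - &(gradients * self.learning_rate))
///     }
///
///     fn get_learning_rate(&self) -> A {
///         self.learning_rate
///     }
///
///     fn set_learning_rate(&mut self, learning_rate: A) {
///         self.learning_rate = learning_rate;
///     }
/// }
/// ```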
pub trait Optimizer<A, D>
where
    A: Float + ScalarOperand + Debug,
    D: Dimension,
{
    /// Updates parameters using the given gradients
    ///
    /// # Arguments
    ///
    /// * `params` - The current parameter values
    /// * `gradients` - The gradients of the parameters
    ///
    /// # Returns
    ///
    /// The updated parameters, or an error if the update fails
    fn step(&mut self, params: &Array<A, D>, gradients: &Array<A, D>) -> Result<Array<A, D>>;

    /// Gets the current learning rate
    fn get_learning_rate(&self) -> A;

    /// Sets a new learning rate
    fn set_learning_rate(&mut self, learning_rate: A);

    /// Updates multiple parameter arrays at once
    ///
    /// The default implementation calls `step` on each parameter/gradient pair.
    ///
    /// # Arguments
    ///
    /// * `params_list` - List of parameter arrays
    /// * `gradients_list` - List of gradient arrays corresponding to the parameters
    ///
    /// # Returns
    ///
    /// The updated parameter arrays, in the same order as the inputs
    ///
    /// # Errors
    ///
    /// Returns an error if `params_list` and `gradients_list` have different lengths
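    ///
    /// # Example
    ///
    /// A sketch of updating two parameter arrays in one call; `opt`, `w1`, `w2`,
    /// `g1`, and `g2` are placeholders for any optimizer and its data:
    ///
    /// ```ignore
    /// let updated = opt.step_list(&[&w1, &w2], &[&g1, &g2])?;
    /// assert_eq!(updated.len(), 2);
    /// ```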
    fn step_list(
        &mut self,
        params_list: &[&Array<A, D>],
        gradients_list: &[&Array<A, D>],
    ) -> Result<Vec<Array<A, D>>> {
        if params_list.len() != gradients_list.len() {
            return Err(OptimError::InvalidConfig(format!(
                "Number of parameter arrays ({}) does not match number of gradient arrays ({})",
                params_list.len(),
                gradients_list.len()
            )));
        }

        let mut results = Vec::with_capacity(params_list.len());
        for (params, grads) in params_list.iter().zip(gradients_list.iter()) {
            results.push(self.step(params, grads)?);
        }
        Ok(results)
    }
}

// Optimizer implementation submodules
mod adabound;
mod adadelta;
mod adagrad;
mod adam;
mod adamw;
mod grouped_adam;
mod lamb;
mod lars;
mod lbfgs;
mod lion;
mod lookahead;
mod radam;
mod ranger;
mod rmsprop;
mod sam;
mod sgd;
mod sgd_simd;
mod sparse_adam;

// Re-export specific optimizers
pub use adabound::AdaBound;
pub use adadelta::AdaDelta;
pub use adagrad::Adagrad;
pub use adam::Adam;
pub use adamw::AdamW;
pub use grouped_adam::GroupedAdam;
pub use lamb::LAMB;
pub use lars::LARS;
pub use lbfgs::LBFGS;
pub use lion::Lion;
pub use lookahead::Lookahead;
pub use radam::RAdam;
pub use ranger::Ranger;
pub use rmsprop::RMSprop;
pub use sam::SAM;
pub use sgd::SGD;
pub use sgd_simd::SimdSGD;
pub use sparse_adam::{SparseAdam, SparseGradient};
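
// Example (sketch): code written against the `Optimizer` trait works with any of
// the optimizers re-exported above. The function below is illustrative and not
// part of this crate; `Array1` and `Ix1` are assumed to be re-exported by
// `scirs2_core::ndarray`.
//
// use scirs2_core::ndarray::{Array1, Ix1};
//
// fn apply_update<O: Optimizer<f64, Ix1>>(
//     optimizer: &mut O,
//     params: &Array1<f64>,
//     gradients: &Array1<f64>,
// ) -> Result<Array1<f64>> {
//     optimizer.step(params, gradients)
// }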