optirs_core/optimizers/adam.rs

// Adam optimizer implementation

use scirs2_core::ndarray::{Array, Dimension, ScalarOperand};
use scirs2_core::numeric::Float;
use std::fmt::Debug;

// SciRS2 Integration
// Note: OptiRS receives pre-computed gradients, so scirs2-autograd is not needed
use scirs2_optimize::stochastic::{minimize_adam, AdamOptions};

use crate::error::Result;
use crate::optimizers::Optimizer;

/// Adam optimizer
///
/// Implements the Adam optimization algorithm from the paper
/// "Adam: A Method for Stochastic Optimization" by Kingma and Ba (2014).
///
/// # Formula
///
/// ```text
/// m_t = beta1 * m_{t-1} + (1 - beta1) * g_t
/// v_t = beta2 * v_{t-1} + (1 - beta2) * g_t^2
/// m_hat_t = m_t / (1 - beta1^t)
/// v_hat_t = v_t / (1 - beta2^t)
/// theta_t = theta_{t-1} - alpha * m_hat_t / (sqrt(v_hat_t) + epsilon)
/// ```
///
/// # Examples
///
/// ```
/// use scirs2_core::ndarray::Array1;
/// use optirs_core::optimizers::{Adam, Optimizer};
///
/// // Initialize parameters and gradients
/// let params = Array1::zeros(5);
/// let gradients = Array1::from_vec(vec![0.1, 0.2, -0.3, 0.0, 0.5]);
///
/// // Create an Adam optimizer with default hyperparameters
/// let mut optimizer = Adam::new(0.001);
///
/// // Update parameters
/// let new_params = optimizer.step(&params, &gradients).unwrap();
/// ```
#[derive(Debug, Clone)]
pub struct Adam<A: Float + ScalarOperand + Debug> {
    /// Learning rate
    learning_rate: A,
    /// Exponential decay rate for the first moment estimates
    beta1: A,
    /// Exponential decay rate for the second moment estimates
    beta2: A,
    /// Small constant for numerical stability
    epsilon: A,
    /// Weight decay factor (L2 regularization)
    weight_decay: A,
    /// First moment vector
    m: Option<Vec<Array<A, scirs2_core::ndarray::IxDyn>>>,
    /// Second moment vector
    v: Option<Vec<Array<A, scirs2_core::ndarray::IxDyn>>>,
    /// Current timestep
    t: usize,
}

impl<A: Float + ScalarOperand + Debug + Send + Sync> Adam<A> {
    /// Creates a new Adam optimizer with the given learning rate and default settings
    ///
    /// # Arguments
    ///
    /// * `learning_rate` - The learning rate for parameter updates
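    ///
    /// # Example
    ///
    /// A minimal illustrative construction (the `f64` annotation is only for
    /// clarity; the element type is normally inferred from later use):
    ///
    /// ```
    /// use optirs_core::optimizers::Adam;
    ///
    /// // Defaults: beta1 = 0.9, beta2 = 0.999, epsilon = 1e-8, weight_decay = 0.0
    /// let optimizer: Adam<f64> = Adam::new(0.001);
    /// assert_eq!(optimizer.get_beta1(), 0.9);
    /// assert_eq!(optimizer.get_beta2(), 0.999);
    /// ```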
    pub fn new(learning_rate: A) -> Self {
        Self {
            learning_rate,
            beta1: A::from(0.9).unwrap(),
            beta2: A::from(0.999).unwrap(),
            epsilon: A::from(1e-8).unwrap(),
            weight_decay: A::zero(),
            m: None,
            v: None,
            t: 0,
        }
    }

    /// Creates a new Adam optimizer with the full configuration
    ///
    /// # Arguments
    ///
    /// * `learning_rate` - The learning rate for parameter updates
    /// * `beta1` - Exponential decay rate for the first moment estimates (default: 0.9)
    /// * `beta2` - Exponential decay rate for the second moment estimates (default: 0.999)
    /// * `epsilon` - Small constant for numerical stability (default: 1e-8)
    /// * `weight_decay` - Weight decay factor for L2 regularization (default: 0.0)
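    ///
    /// # Example
    ///
    /// An illustrative configuration; the hyperparameter values below are
    /// arbitrary and only demonstrate the argument order:
    ///
    /// ```
    /// use optirs_core::optimizers::Adam;
    ///
    /// // learning_rate, beta1, beta2, epsilon, weight_decay
    /// let optimizer = Adam::new_with_config(0.001_f64, 0.9, 0.999, 1e-8, 0.01);
    /// assert_eq!(optimizer.get_weight_decay(), 0.01);
    /// ```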
    pub fn new_with_config(
        learning_rate: A,
        beta1: A,
        beta2: A,
        epsilon: A,
        weight_decay: A,
    ) -> Self {
        Self {
            learning_rate,
            beta1,
            beta2,
            epsilon,
            weight_decay,
            m: None,
            v: None,
            t: 0,
        }
    }

    /// Sets the beta1 parameter
    pub fn set_beta1(&mut self, beta1: A) -> &mut Self {
        self.beta1 = beta1;
        self
    }

    /// Builder method to set beta1 and return self
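    ///
    /// # Example
    ///
    /// The builder methods can be chained; the hyperparameter values here are
    /// illustrative only:
    ///
    /// ```
    /// use optirs_core::optimizers::Adam;
    ///
    /// let optimizer = Adam::new(0.001_f64)
    ///     .with_beta1(0.95)
    ///     .with_beta2(0.9995)
    ///     .with_epsilon(1e-10)
    ///     .with_weight_decay(0.01);
    /// assert_eq!(optimizer.get_beta1(), 0.95);
    /// ```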
    pub fn with_beta1(mut self, beta1: A) -> Self {
        self.beta1 = beta1;
        self
    }

    /// Gets the beta1 parameter
    pub fn get_beta1(&self) -> A {
        self.beta1
    }

    /// Sets the beta2 parameter
    pub fn set_beta2(&mut self, beta2: A) -> &mut Self {
        self.beta2 = beta2;
        self
    }

    /// Builder method to set beta2 and return self
    pub fn with_beta2(mut self, beta2: A) -> Self {
        self.beta2 = beta2;
        self
    }

    /// Gets the beta2 parameter
    pub fn get_beta2(&self) -> A {
        self.beta2
    }

    /// Sets the epsilon parameter
    pub fn set_epsilon(&mut self, epsilon: A) -> &mut Self {
        self.epsilon = epsilon;
        self
    }

    /// Builder method to set epsilon and return self
    pub fn with_epsilon(mut self, epsilon: A) -> Self {
        self.epsilon = epsilon;
        self
    }

    /// Gets the epsilon parameter
    pub fn get_epsilon(&self) -> A {
        self.epsilon
    }

    /// Sets the weight decay parameter
    pub fn set_weight_decay(&mut self, weight_decay: A) -> &mut Self {
        self.weight_decay = weight_decay;
        self
    }

    /// Builder method to set weight decay and return self
    pub fn with_weight_decay(mut self, weight_decay: A) -> Self {
        self.weight_decay = weight_decay;
        self
    }

    /// Gets the weight decay parameter
    pub fn get_weight_decay(&self) -> A {
        self.weight_decay
    }

    /// Gets the current learning rate
    pub fn learning_rate(&self) -> A {
        self.learning_rate
    }

    /// Sets the learning rate
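    ///
    /// # Example
    ///
    /// An illustrative manual learning-rate adjustment (e.g. as part of a
    /// simple decay schedule):
    ///
    /// ```
    /// use optirs_core::optimizers::Adam;
    ///
    /// let mut optimizer: Adam<f64> = Adam::new(0.01);
    /// optimizer.set_lr(0.005);
    /// assert_eq!(optimizer.learning_rate(), 0.005);
    /// ```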
    pub fn set_lr(&mut self, lr: A) {
        self.learning_rate = lr;
    }

    /// Resets the internal state of the optimizer (first and second moment estimates and the timestep)
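    ///
    /// # Example
    ///
    /// A minimal sketch of clearing accumulated state between training runs:
    ///
    /// ```
    /// use scirs2_core::ndarray::Array1;
    /// use optirs_core::optimizers::{Adam, Optimizer};
    ///
    /// let mut optimizer = Adam::new(0.001);
    /// let params = Array1::from_vec(vec![1.0_f64, 2.0]);
    /// let gradients = Array1::from_vec(vec![0.1_f64, -0.1]);
    /// let _ = optimizer.step(&params, &gradients).unwrap();
    ///
    /// // Discards the moment estimates and timestep accumulated so far
    /// optimizer.reset();
    /// ```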
    pub fn reset(&mut self) {
        self.m = None;
        self.v = None;
        self.t = 0;
    }
}

impl<A, D> Optimizer<A, D> for Adam<A>
where
    A: Float + ScalarOperand + Debug + Send + Sync,
    D: Dimension,
{
    fn step(&mut self, params: &Array<A, D>, gradients: &Array<A, D>) -> Result<Array<A, D>> {
        // Validate that parameters and gradients have compatible shapes
        if params.shape() != gradients.shape() {
            return Err(crate::error::OptimError::DimensionMismatch(format!(
                "Incompatible shapes: parameters have shape {:?}, gradients have shape {:?}",
                params.shape(),
                gradients.shape()
            )));
        }

        // Convert to dynamic dimension for storage in state vectors
        let params_dyn = params.to_owned().into_dyn();
        let gradients_dyn = gradients.to_owned().into_dyn();

        // Apply weight decay to gradients if needed
        let adjusted_gradients = if self.weight_decay > A::zero() {
            &gradients_dyn + &(&params_dyn * self.weight_decay)
        } else {
            gradients_dyn
        };

        // Initialize state if this is the first step
        if self.m.is_none() {
            self.m = Some(vec![Array::zeros(params_dyn.raw_dim())]);
            self.v = Some(vec![Array::zeros(params_dyn.raw_dim())]);
            self.t = 0;
        }

        let m = self.m.as_mut().unwrap();
        let v = self.v.as_mut().unwrap();

        // Ensure we have state for this parameter set
        if m.is_empty() {
            m.push(Array::zeros(params_dyn.raw_dim()));
            v.push(Array::zeros(params_dyn.raw_dim()));
        } else if m[0].raw_dim() != params_dyn.raw_dim() {
            // If the parameter dimensions have changed, reset state
            m[0] = Array::zeros(params_dyn.raw_dim());
            v[0] = Array::zeros(params_dyn.raw_dim());
        }

        // Increment timestep with overflow protection
        self.t = self.t.checked_add(1).ok_or_else(|| {
            crate::error::OptimError::InvalidConfig(
                "Timestep counter overflow - too many optimization steps".to_string(),
            )
        })?;

        // Update biased first moment estimate
        m[0] = &m[0] * self.beta1 + &(&adjusted_gradients * (A::one() - self.beta1));

        // Update biased second raw moment estimate
        v[0] = &v[0] * self.beta2
            + &(&adjusted_gradients * &adjusted_gradients * (A::one() - self.beta2));

        // Convert the timestep once for the beta^t bias-correction terms
        let t_i32 = i32::try_from(self.t).map_err(|_| {
            crate::error::OptimError::InvalidConfig(
                "Timestep too large for bias correction calculation".to_string(),
            )
        })?;

        // Compute bias-corrected first moment estimate
        let m_hat = &m[0] / (A::one() - self.beta1.powi(t_i32));

        // Compute bias-corrected second raw moment estimate
        let v_hat = &v[0] / (A::one() - self.beta2.powi(t_i32));

        // Compute square root of v_hat
        let v_hat_sqrt = v_hat.mapv(|x| x.sqrt());

        // Update parameters
        let step = &m_hat / &(&v_hat_sqrt + self.epsilon) * self.learning_rate;
        let updated_params = &params_dyn - step;

        // Convert back to original dimension
        Ok(updated_params.into_dimensionality::<D>().unwrap())
    }

    fn get_learning_rate(&self) -> A {
        self.learning_rate
    }

    fn set_learning_rate(&mut self, learning_rate: A) {
        self.learning_rate = learning_rate;
    }
}
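
// Illustrative unit-test sketch (assumes 1-D `f64` arrays and the default
// hyperparameters from `Adam::new`): checks the direction of the first Adam
// update for positive, negative, and zero gradient components.
#[cfg(test)]
mod tests {
    use super::*;
    use scirs2_core::ndarray::Array1;

    #[test]
    fn step_moves_parameters_against_the_gradient() {
        let params = Array1::from_vec(vec![1.0_f64, 1.0, 1.0]);
        let gradients = Array1::from_vec(vec![0.1_f64, -0.2, 0.0]);
        let mut optimizer = Adam::new(0.001);

        let updated = optimizer.step(&params, &gradients).unwrap();

        // A positive gradient should decrease the parameter, a negative one
        // should increase it, and a zero gradient leaves the component
        // unchanged on the first step (its m and v estimates are both zero).
        assert!(updated[0] < params[0]);
        assert!(updated[1] > params[1]);
        assert!((updated[2] - params[2]).abs() < 1e-12);
    }
}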