optirs_core/optimizers/lamb.rs

// LAMB optimizer implementation
//
// Based on the paper "Large Batch Optimization for Deep Learning: Training BERT in 76 minutes"
// by You et al. (2019).

use scirs2_core::ndarray::{Array, Dimension, ScalarOperand};
use scirs2_core::numeric::Float;
use std::fmt::Debug;

use crate::error::Result;
use crate::optimizers::Optimizer;
/// LAMB (Layer-wise Adaptive Moments) optimizer
///
/// LAMB is designed for large-batch optimization. It extends AdamW with a
/// layer-wise trust ratio (per-layer adaptive learning rates), making it
/// particularly effective for training large models at very large batch sizes.
///
/// Formula:
/// m_t = beta1 * m_{t-1} + (1 - beta1) * g_t
/// v_t = beta2 * v_{t-1} + (1 - beta2) * g_t^2
/// m_hat_t = m_t / (1 - beta1^t)
/// v_hat_t = v_t / (1 - beta2^t)
/// g' = m_hat_t / (sqrt(v_hat_t) + epsilon) + lambda * theta_{t-1}
/// r1 = ||theta_{t-1}||
/// r2 = ||g'||
/// ratio = r1 / r2 if r1 > 0 and r2 > 0, else 1.0
/// theta_t = theta_{t-1} - lr * ratio * g'
///
/// Here the whole parameter array passed to `step` is treated as a single layer
/// when computing the trust ratio.
///
/// # Examples
///
/// ```
/// use scirs2_core::ndarray::Array1;
/// use optirs_core::optimizers::{LAMB, Optimizer};
///
/// // Initialize parameters and gradients
/// let params = Array1::zeros(5);
/// let gradients = Array1::from_vec(vec![0.1, 0.2, -0.3, 0.0, 0.5]);
///
/// // Create a LAMB optimizer with default hyperparameters
/// let mut optimizer = LAMB::new(0.001);
///
/// // Update parameters
/// let new_params = optimizer.step(&params, &gradients).unwrap();
/// ```
#[derive(Debug, Clone)]
pub struct LAMB<A: Float + ScalarOperand + Debug> {
    /// Learning rate
    learning_rate: A,
    /// Exponential decay rate for the first moment estimates
    beta1: A,
    /// Exponential decay rate for the second moment estimates
    beta2: A,
    /// Small constant for numerical stability
    epsilon: A,
    /// Weight decay factor (decoupled weight decay, as in AdamW)
    weight_decay: A,
    /// Whether to use bias correction
    bias_correction: bool,
    /// First moment vector
    m: Option<Vec<Array<A, scirs2_core::ndarray::IxDyn>>>,
    /// Second moment vector
    v: Option<Vec<Array<A, scirs2_core::ndarray::IxDyn>>>,
    /// Current timestep
    t: usize,
}

impl<A: Float + ScalarOperand + Debug + Send + Sync> LAMB<A> {
    /// Creates a new LAMB optimizer with the given learning rate and default settings
    ///
    /// # Arguments
    ///
    /// * `learning_rate` - The learning rate for parameter updates
    pub fn new(learning_rate: A) -> Self {
        Self {
            learning_rate,
            beta1: A::from(0.9).unwrap(),
            beta2: A::from(0.999).unwrap(),
            epsilon: A::from(1e-6).unwrap(),
            weight_decay: A::zero(),
            bias_correction: true,
            m: None,
            v: None,
            t: 0,
        }
    }

    /// Creates a new LAMB optimizer with the full configuration
    ///
    /// # Arguments
    ///
    /// * `learning_rate` - The learning rate for parameter updates
    /// * `beta1` - Exponential decay rate for the first moment estimates (default: 0.9)
    /// * `beta2` - Exponential decay rate for the second moment estimates (default: 0.999)
    /// * `epsilon` - Small constant for numerical stability (default: 1e-6)
    /// * `weight_decay` - Weight decay factor, applied as decoupled weight decay (default: 0.0)
    /// * `bias_correction` - Whether to use bias correction (default: true)
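    ///
    /// # Examples
    ///
    /// An illustrative configuration; the values shown are placeholders, not
    /// tuned recommendations.
    ///
    /// ```
    /// use optirs_core::optimizers::LAMB;
    ///
    /// let optimizer: LAMB<f64> = LAMB::new_with_config(
    ///     1e-3,  // learning_rate
    ///     0.9,   // beta1
    ///     0.999, // beta2
    ///     1e-6,  // epsilon
    ///     0.01,  // weight_decay
    ///     true,  // bias_correction
    /// );
    /// assert!(optimizer.get_weight_decay() > 0.0);
    /// ```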
    pub fn new_with_config(
        learning_rate: A,
        beta1: A,
        beta2: A,
        epsilon: A,
        weight_decay: A,
        bias_correction: bool,
    ) -> Self {
        Self {
            learning_rate,
            beta1,
            beta2,
            epsilon,
            weight_decay,
            bias_correction,
            m: None,
            v: None,
            t: 0,
        }
    }

    /// Sets the beta1 parameter
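    ///
    /// Note: `set_beta1`, `set_beta2`, `set_epsilon`, and `set_weight_decay` all
    /// return `&mut Self`, so they can be chained, e.g.
    /// `opt.set_beta1(0.9).set_beta2(0.995);` (the variable `opt` is illustrative).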
    pub fn set_beta1(&mut self, beta1: A) -> &mut Self {
        self.beta1 = beta1;
        self
    }

    /// Gets the beta1 parameter
    pub fn get_beta1(&self) -> A {
        self.beta1
    }

    /// Sets the beta2 parameter
    pub fn set_beta2(&mut self, beta2: A) -> &mut Self {
        self.beta2 = beta2;
        self
    }

    /// Gets the beta2 parameter
    pub fn get_beta2(&self) -> A {
        self.beta2
    }

    /// Sets the epsilon parameter
    pub fn set_epsilon(&mut self, epsilon: A) -> &mut Self {
        self.epsilon = epsilon;
        self
    }

    /// Gets the epsilon parameter
    pub fn get_epsilon(&self) -> A {
        self.epsilon
    }

    /// Sets the weight decay parameter
    pub fn set_weight_decay(&mut self, weight_decay: A) -> &mut Self {
        self.weight_decay = weight_decay;
        self
    }

    /// Gets the weight decay parameter
    pub fn get_weight_decay(&self) -> A {
        self.weight_decay
    }

    /// Gets the current learning rate
    pub fn learning_rate(&self) -> A {
        self.learning_rate
    }

    /// Sets the learning rate
    pub fn set_lr(&mut self, lr: A) {
        self.learning_rate = lr;
    }

    /// Resets the internal state of the optimizer
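    ///
    /// # Examples
    ///
    /// A minimal illustration of resetting after a step:
    ///
    /// ```
    /// use scirs2_core::ndarray::Array1;
    /// use optirs_core::optimizers::{LAMB, Optimizer};
    ///
    /// let mut optimizer = LAMB::new(0.01);
    /// let params = Array1::from_vec(vec![1.0_f64]);
    /// let gradients = Array1::from_vec(vec![0.5]);
    /// let _ = optimizer.step(&params, &gradients).unwrap();
    ///
    /// // Clears the moment buffers and timestep, as if the optimizer had just
    /// // been constructed
    /// optimizer.reset();
    /// ```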
    pub fn reset(&mut self) {
        self.m = None;
        self.v = None;
        self.t = 0;
    }
}

impl<A, D> Optimizer<A, D> for LAMB<A>
where
    A: Float + ScalarOperand + Debug + Send + Sync,
    D: Dimension,
{
    fn step(&mut self, params: &Array<A, D>, gradients: &Array<A, D>) -> Result<Array<A, D>> {
        // Convert to dynamic dimension for storage in state vectors
        let params_dyn = params.to_owned().into_dyn();
        let gradients_dyn = gradients.to_owned().into_dyn();

        // Initialize state if this is the first step
        if self.m.is_none() {
            self.m = Some(vec![Array::zeros(params_dyn.raw_dim())]);
            self.v = Some(vec![Array::zeros(params_dyn.raw_dim())]);
            self.t = 0;
        }

        let m = self.m.as_mut().unwrap();
        let v = self.v.as_mut().unwrap();

        // Ensure we have state for this parameter set
        if m.is_empty() {
            m.push(Array::zeros(params_dyn.raw_dim()));
            v.push(Array::zeros(params_dyn.raw_dim()));
        } else if m[0].raw_dim() != params_dyn.raw_dim() {
            // If the parameter dimensions have changed, reset state
            m[0] = Array::zeros(params_dyn.raw_dim());
            v[0] = Array::zeros(params_dyn.raw_dim());
        }

        // Increment timestep
        self.t += 1;

        // Update biased first moment estimate
        m[0] = &m[0] * self.beta1 + &gradients_dyn * (A::one() - self.beta1);

        // Update biased second raw moment estimate
        v[0] = &v[0] * self.beta2 + &(&gradients_dyn * &gradients_dyn * (A::one() - self.beta2));

        // Compute bias-corrected moments if enabled
        let (m_hat, v_hat) = if self.bias_correction {
            let bias1 = A::one() - self.beta1.powi(self.t as i32);
            let bias2 = A::one() - self.beta2.powi(self.t as i32);
            (&m[0] / bias1, &v[0] / bias2)
        } else {
            (m[0].clone(), v[0].clone())
        };
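        // Note: at t = 1 with beta1 = 0.9 the correction divides the first
        // moment by (1 - 0.9) = 0.1, compensating for the zero initialization;
        // both correction factors approach 1 as t grows.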

        // Compute adaptive term (similar to Adam)
        let v_hat_sqrt = v_hat.mapv(|x| x.sqrt());
        let adaptive_term = &m_hat / &(&v_hat_sqrt + self.epsilon);

        // Add decoupled weight decay (scaled parameters) to form the full update direction
        let normalized_gradient = if self.weight_decay > A::zero() {
            &adaptive_term + &(&params_dyn * self.weight_decay)
        } else {
            adaptive_term
        };

        // Layer-wise adaptation (trust ratio)
        let weight_norm = {
            let norm_sq = params_dyn
                .iter()
                .map(|x| *x * *x)
                .fold(A::zero(), |acc, x| acc + x);
            norm_sq.sqrt()
        };
        let gradient_norm = {
            let norm_sq = normalized_gradient
                .iter()
                .map(|x| *x * *x)
                .fold(A::zero(), |acc, x| acc + x);
            norm_sq.sqrt()
        };

        let trust_ratio = if weight_norm > A::zero() && gradient_norm > A::zero() {
            weight_norm / gradient_norm
        } else {
            A::one()
        };
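        // The whole parameter array is treated as one layer here. For example,
        // if the parameter norm is 2.0 and the update norm is 0.5, the trust
        // ratio is 4.0 and the effective step size for this array becomes
        // 4.0 * learning_rate.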

        // Update parameters with the trust ratio
        let step = &normalized_gradient * (self.learning_rate * trust_ratio);
        let updated_params = &params_dyn - step;

        // Convert back to original dimension
        Ok(updated_params.into_dimensionality::<D>().unwrap())
    }

    fn get_learning_rate(&self) -> A {
        self.learning_rate
    }

    fn set_learning_rate(&mut self, learning_rate: A) {
        self.learning_rate = learning_rate;
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use approx::assert_abs_diff_eq;
    use scirs2_core::ndarray::Array1;

    #[test]
    fn test_lamb_basic_creation() {
        let optimizer: LAMB<f64> = LAMB::new(0.001);
        assert_abs_diff_eq!(optimizer.learning_rate(), 0.001);
        assert_abs_diff_eq!(optimizer.get_beta1(), 0.9);
        assert_abs_diff_eq!(optimizer.get_beta2(), 0.999);
        assert_abs_diff_eq!(optimizer.get_epsilon(), 1e-6);
        assert_abs_diff_eq!(optimizer.get_weight_decay(), 0.0);
        assert!(optimizer.bias_correction);
    }

    #[test]
    fn test_lamb_convergence() {
        let mut optimizer: LAMB<f64> = LAMB::new(0.1);

        // Minimize a simple quadratic function: f(x) = x^2 + y^2
        let mut params = Array1::from_vec(vec![5.0, 3.0]);

        for _ in 0..50 {
            // Gradient of x^2 + y^2 is (2x, 2y)
            let gradients = Array1::from_vec(vec![2.0 * params[0], 2.0 * params[1]]);
            params = optimizer.step(&params, &gradients).unwrap();
        }

        // Should converge towards (0, 0)
        assert!(params[0].abs() < 1.0);
        assert!(params[1].abs() < 1.0);
    }

    #[test]
    fn test_lamb_with_weight_decay() {
        let mut optimizer: LAMB<f64> = LAMB::new_with_config(
            0.1,   // learning_rate
            0.9,   // beta1
            0.999, // beta2
            1e-6,  // epsilon
            0.1,   // weight_decay
            true,  // bias_correction
        );

        // Start from (1.0, 1.0)
        let mut params = Array1::from_vec(vec![1.0, 1.0]);

        // Run optimization with small gradients
        for _ in 0..20 {
            let gradients = Array1::from_vec(vec![0.1, 0.1]);
            params = optimizer.step(&params, &gradients).unwrap();
        }

        // With weight decay, parameters should decrease
        assert!(params[0] < 1.0);
        assert!(params[1] < 1.0);
    }

    #[test]
    fn test_lamb_reset() {
        let mut optimizer: LAMB<f64> = LAMB::new(0.1);

        // Perform a step to initialize state
        let params = Array1::from_vec(vec![1.0]);
        let gradients = Array1::from_vec(vec![0.5]);
        let _ = optimizer.step(&params, &gradients).unwrap();

        // State should exist
        assert!(optimizer.m.is_some());
        assert!(optimizer.v.is_some());
        assert_eq!(optimizer.t, 1);

        // Reset
        optimizer.reset();

        // State should be cleared
        assert!(optimizer.m.is_none());
        assert!(optimizer.v.is_none());
        assert_eq!(optimizer.t, 0);
    }
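
    // Supplementary sanity check: with bias correction disabled the optimizer
    // should still produce finite, non-trivial updates.
    #[test]
    fn test_lamb_without_bias_correction() {
        let mut optimizer: LAMB<f64> =
            LAMB::new_with_config(0.1, 0.9, 0.999, 1e-6, 0.0, false);
        let params = Array1::from_vec(vec![1.0, -2.0]);
        let gradients = Array1::from_vec(vec![0.3, -0.1]);

        let new_params = optimizer.step(&params, &gradients).unwrap();

        // Updates should be finite and actually move the parameters
        assert!(new_params.iter().all(|x| x.is_finite()));
        assert_ne!(new_params[0], params[0]);
        assert_ne!(new_params[1], params[1]);
    }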

    #[test]
    fn test_lamb_trust_ratio() {
        // Test with normal gradient and parameters
        let mut optimizer: LAMB<f64> = LAMB::new(0.1);
        let params = Array1::from_vec(vec![2.0, 3.0]);
        let gradients = Array1::from_vec(vec![0.4, 0.6]);

        let new_params = optimizer.step(&params, &gradients).unwrap();

        // Parameters should be updated
        assert_ne!(new_params[0], params[0]);
        assert_ne!(new_params[1], params[1]);

        // Check they moved in the right direction
        assert!(new_params[0] < params[0]); // gradient was positive
        assert!(new_params[1] < params[1]); // gradient was positive
    }
}