// optirs_core/optimizers/lamb.rs
//
// LAMB optimizer implementation
//
// Based on the paper "Large Batch Optimization for Deep Learning: Training BERT in 76 minutes"
// by You et al. (2019).
6use scirs2_core::ndarray::{Array, Dimension, ScalarOperand};
7use scirs2_core::numeric::Float;
8use std::fmt::Debug;
9
10use crate::error::Result;
11use crate::optimizers::Optimizer;
12
/// LAMB (Layer-wise Adaptive Moments) optimizer
///
/// LAMB is designed for large batch optimization. It extends AdamW with layer-wise
/// adaptive learning rates, making it particularly effective for training large models
/// with high batch sizes.
///
/// Formula:
/// m_t = beta1 * m_{t-1} + (1 - beta1) * g_t
/// v_t = beta2 * v_{t-1} + (1 - beta2) * g_t^2
/// m_hat_t = m_t / (1 - beta1^t)
/// v_hat_t = v_t / (1 - beta2^t)
/// r1 = ||theta_t||
/// g' = m_hat_t / (sqrt(v_hat_t) + epsilon) + lambda * theta_t
/// r2 = ||g'||
/// ratio = r1/r2 if r1 > 0 and r2 > 0, else 1.0
/// theta_t = theta_{t-1} - lr * ratio * g'
///
/// # Examples
///
/// ```
/// use scirs2_core::ndarray::Array1;
/// use optirs_core::optimizers::{LAMB, Optimizer};
///
/// // Initialize parameters and gradients
/// let params = Array1::zeros(5);
/// let gradients = Array1::from_vec(vec![0.1, 0.2, -0.3, 0.0, 0.5]);
///
/// // Create a LAMB optimizer with default hyperparameters
/// let mut optimizer = LAMB::new(0.001);
///
/// // Update parameters
/// let new_params = optimizer.step(&params, &gradients).expect("LAMB step should succeed");
/// ```
#[derive(Debug, Clone)]
pub struct LAMB<A: Float + ScalarOperand + Debug> {
    /// Learning rate (base step size; multiplied by the per-layer trust ratio)
    learning_rate: A,
    /// Exponential decay rate for the first moment estimates (beta1)
    beta1: A,
    /// Exponential decay rate for the second moment estimates (beta2)
    beta2: A,
    /// Small constant added to sqrt(v_hat) for numerical stability
    epsilon: A,
    /// Weight decay factor (L2 regularization; lambda in the formula above)
    weight_decay: A,
    /// Whether to apply bias correction to the moment estimates
    bias_correction: bool,
    /// First moment vector (lazily initialized to zeros on the first `step` call)
    m: Option<Vec<Array<A, scirs2_core::ndarray::IxDyn>>>,
    /// Second (raw) moment vector (lazily initialized to zeros on the first `step` call)
    v: Option<Vec<Array<A, scirs2_core::ndarray::IxDyn>>>,
    /// Current timestep (number of completed `step` calls; drives bias correction)
    t: usize,
}
67
68impl<A: Float + ScalarOperand + Debug + Send + Sync> LAMB<A> {
69    /// Creates a new LAMB optimizer with the given learning rate and default settings
70    ///
71    /// # Arguments
72    ///
73    /// * `learning_rate` - The learning rate for parameter updates
74    pub fn new(learning_rate: A) -> Self {
75        Self {
76            learning_rate,
77            beta1: A::from(0.9).expect("unwrap failed"),
78            beta2: A::from(0.999).expect("unwrap failed"),
79            epsilon: A::from(1e-6).expect("unwrap failed"),
80            weight_decay: A::zero(),
81            bias_correction: true,
82            m: None,
83            v: None,
84            t: 0,
85        }
86    }
87
88    /// Creates a new LAMB optimizer with the full configuration
89    ///
90    /// # Arguments
91    ///
92    /// * `learning_rate` - The learning rate for parameter updates
93    /// * `beta1` - Exponential decay rate for the first moment estimates (default: 0.9)
94    /// * `beta2` - Exponential decay rate for the second moment estimates (default: 0.999)
95    /// * `epsilon` - Small constant for numerical stability (default: 1e-6)
96    /// * `weight_decay` - Weight decay factor for L2 regularization (default: 0.0)
97    /// * `bias_correction` - Whether to use bias correction (default: true)
98    pub fn new_with_config(
99        learning_rate: A,
100        beta1: A,
101        beta2: A,
102        epsilon: A,
103        weight_decay: A,
104        bias_correction: bool,
105    ) -> Self {
106        Self {
107            learning_rate,
108            beta1,
109            beta2,
110            epsilon,
111            weight_decay,
112            bias_correction,
113            m: None,
114            v: None,
115            t: 0,
116        }
117    }
118
119    /// Sets the beta1 parameter
120    pub fn set_beta1(&mut self, beta1: A) -> &mut Self {
121        self.beta1 = beta1;
122        self
123    }
124
125    /// Gets the beta1 parameter
126    pub fn get_beta1(&self) -> A {
127        self.beta1
128    }
129
130    /// Sets the beta2 parameter
131    pub fn set_beta2(&mut self, beta2: A) -> &mut Self {
132        self.beta2 = beta2;
133        self
134    }
135
136    /// Gets the beta2 parameter
137    pub fn get_beta2(&self) -> A {
138        self.beta2
139    }
140
141    /// Sets the epsilon parameter
142    pub fn set_epsilon(&mut self, epsilon: A) -> &mut Self {
143        self.epsilon = epsilon;
144        self
145    }
146
147    /// Gets the epsilon parameter
148    pub fn get_epsilon(&self) -> A {
149        self.epsilon
150    }
151
152    /// Sets the weight decay parameter
153    pub fn set_weight_decay(&mut self, weight_decay: A) -> &mut Self {
154        self.weight_decay = weight_decay;
155        self
156    }
157
158    /// Gets the weight decay parameter
159    pub fn get_weight_decay(&self) -> A {
160        self.weight_decay
161    }
162
163    /// Gets the current learning rate
164    pub fn learning_rate(&self) -> A {
165        self.learning_rate
166    }
167
168    /// Sets the learning rate
169    pub fn set_lr(&mut self, lr: A) {
170        self.learning_rate = lr;
171    }
172
173    /// Resets the internal state of the optimizer
174    pub fn reset(&mut self) {
175        self.m = None;
176        self.v = None;
177        self.t = 0;
178    }
179}
180
181impl<A, D> Optimizer<A, D> for LAMB<A>
182where
183    A: Float + ScalarOperand + Debug + Send + Sync,
184    D: Dimension,
185{
186    fn step(&mut self, params: &Array<A, D>, gradients: &Array<A, D>) -> Result<Array<A, D>> {
187        // Convert to dynamic dimension for storage in state vectors
188        let params_dyn = params.to_owned().into_dyn();
189        let gradients_dyn = gradients.to_owned().into_dyn();
190
191        // Initialize state if this is the first step
192        if self.m.is_none() {
193            self.m = Some(vec![Array::zeros(params_dyn.raw_dim())]);
194            self.v = Some(vec![Array::zeros(params_dyn.raw_dim())]);
195            self.t = 0;
196        }
197
198        let m = self.m.as_mut().expect("unwrap failed");
199        let v = self.v.as_mut().expect("unwrap failed");
200
201        // Ensure we have state for this parameter set
202        if m.is_empty() {
203            m.push(Array::zeros(params_dyn.raw_dim()));
204            v.push(Array::zeros(params_dyn.raw_dim()));
205        } else if m[0].raw_dim() != params_dyn.raw_dim() {
206            // If the parameter dimensions have changed, reset state
207            m[0] = Array::zeros(params_dyn.raw_dim());
208            v[0] = Array::zeros(params_dyn.raw_dim());
209        }
210
211        // Increment timestep
212        self.t += 1;
213
214        // Update biased first moment estimate
215        m[0] = &m[0] * self.beta1 + &gradients_dyn * (A::one() - self.beta1);
216
217        // Update biased second raw moment estimate
218        v[0] = &v[0] * self.beta2 + &(&gradients_dyn * &gradients_dyn * (A::one() - self.beta2));
219
220        // Compute bias-corrected moments if enabled
221        let (m_hat, v_hat) = if self.bias_correction {
222            let bias1 = A::one() - self.beta1.powi(self.t as i32);
223            let bias2 = A::one() - self.beta2.powi(self.t as i32);
224            (&m[0] / bias1, &v[0] / bias2)
225        } else {
226            (m[0].clone(), v[0].clone())
227        };
228
229        // Compute adaptive term (similar to Adam)
230        let v_hat_sqrt = v_hat.mapv(|x| x.sqrt());
231        let adaptive_term = &m_hat / &(&v_hat_sqrt + self.epsilon);
232
233        // Apply weight decay to create the full gradient term
234        let normalized_gradient = if self.weight_decay > A::zero() {
235            &adaptive_term + &(&params_dyn * self.weight_decay)
236        } else {
237            adaptive_term
238        };
239
240        // Layer-wise adaptation (trust ratio)
241        let weight_norm = {
242            let norm_sq = params_dyn
243                .iter()
244                .map(|x| *x * *x)
245                .fold(A::zero(), |acc, x| acc + x);
246            norm_sq.sqrt()
247        };
248        let gradient_norm = {
249            let norm_sq = normalized_gradient
250                .iter()
251                .map(|x| *x * *x)
252                .fold(A::zero(), |acc, x| acc + x);
253            norm_sq.sqrt()
254        };
255
256        let trust_ratio = if weight_norm > A::zero() && gradient_norm > A::zero() {
257            weight_norm / gradient_norm
258        } else {
259            A::one()
260        };
261
262        // Update parameters with the trust ratio
263        let step = &normalized_gradient * (self.learning_rate * trust_ratio);
264        let updated_params = &params_dyn - step;
265
266        // Convert back to original dimension
267        Ok(updated_params
268            .into_dimensionality::<D>()
269            .expect("unwrap failed"))
270    }
271
272    fn get_learning_rate(&self) -> A {
273        self.learning_rate
274    }
275
276    fn set_learning_rate(&mut self, learning_rate: A) {
277        self.learning_rate = learning_rate;
278    }
279}
280
#[cfg(test)]
mod tests {
    use super::*;
    use approx::assert_abs_diff_eq;
    use scirs2_core::ndarray::Array1;

    #[test]
    fn test_lamb_basic_creation() {
        // A freshly constructed optimizer carries the paper's default hyperparameters.
        let opt: LAMB<f64> = LAMB::new(0.001);
        assert_abs_diff_eq!(opt.learning_rate(), 0.001);
        assert_abs_diff_eq!(opt.get_beta1(), 0.9);
        assert_abs_diff_eq!(opt.get_beta2(), 0.999);
        assert_abs_diff_eq!(opt.get_epsilon(), 1e-6);
        assert_abs_diff_eq!(opt.get_weight_decay(), 0.0);
        assert!(opt.bias_correction);
    }

    #[test]
    fn test_lamb_convergence() {
        // Minimize f(x, y) = x^2 + y^2; the gradient is (2x, 2y).
        let mut opt: LAMB<f64> = LAMB::new(0.1);
        let mut point = Array1::from_vec(vec![5.0, 3.0]);

        for _iteration in 0..50 {
            let grad = point.mapv(|coord| 2.0 * coord);
            point = opt.step(&point, &grad).expect("unwrap failed");
        }

        // After 50 steps the iterate should be near the minimum at the origin.
        assert!(point[0].abs() < 1.0);
        assert!(point[1].abs() < 1.0);
    }

    #[test]
    fn test_lamb_with_weight_decay() {
        let mut opt: LAMB<f64> = LAMB::new_with_config(
            0.1,   // learning_rate
            0.9,   // beta1
            0.999, // beta2
            1e-6,  // epsilon
            0.1,   // weight_decay
            true,  // bias_correction
        );

        // Start from (1.0, 1.0) and feed a constant small gradient.
        let mut point = Array1::from_vec(vec![1.0, 1.0]);
        let grad = Array1::from_elem(2, 0.1);

        for _iteration in 0..20 {
            point = opt.step(&point, &grad).expect("unwrap failed");
        }

        // Weight decay pulls the parameters below their starting values.
        assert!(point[0] < 1.0);
        assert!(point[1] < 1.0);
    }

    #[test]
    fn test_lamb_reset() {
        let mut opt: LAMB<f64> = LAMB::new(0.1);

        // One step allocates the moment buffers and advances the timestep.
        let point = Array1::from_vec(vec![1.0]);
        let grad = Array1::from_vec(vec![0.5]);
        let _ = opt.step(&point, &grad).expect("unwrap failed");

        assert!(opt.m.is_some());
        assert!(opt.v.is_some());
        assert_eq!(opt.t, 1);

        // Resetting discards all internal state.
        opt.reset();

        assert!(opt.m.is_none());
        assert!(opt.v.is_none());
        assert_eq!(opt.t, 0);
    }

    #[test]
    fn test_lamb_trust_ratio() {
        // A step with ordinary parameters and gradients must actually move them.
        let mut opt: LAMB<f64> = LAMB::new(0.1);
        let before = Array1::from_vec(vec![2.0, 3.0]);
        let grad = Array1::from_vec(vec![0.4, 0.6]);

        let after = opt.step(&before, &grad).expect("unwrap failed");

        assert_ne!(after[0], before[0]);
        assert_ne!(after[1], before[1]);

        // Positive gradients must push the parameters downward.
        assert!(after[0] < before[0]);
        assert!(after[1] < before[1]);
    }
}
381}