optirs_core/optimizers/
lars.rs

1// Layer-wise Adaptive Rate Scaling (LARS) optimizer
2//
3// LARS is an optimization algorithm specifically designed for large batch training
4// in deep neural networks. It scales the learning rate for each layer based on the
5// ratio of the weight norm to the gradient norm.
6//
7// References:
8// - [Large Batch Training of Convolutional Networks](https://arxiv.org/abs/1708.03888)
9
10use crate::error::{OptimError, Result};
11use crate::optimizers::Optimizer;
12use scirs2_core::ndarray::{Array, Dimension, ScalarOperand};
13use scirs2_core::numeric::Float;
14use std::fmt::Debug;
15
16/// Layer-wise Adaptive Rate Scaling (LARS) optimizer
17///
18/// LARS is an optimization algorithm specifically designed for large batch training,
19/// which allows scaling up the batch size significantly without loss of accuracy.
20/// It works by adapting the learning rate per layer based on the ratio of
21/// weight norm to gradient norm.
22///
23/// # Parameters
24///
25/// * `learning_rate` - Base learning rate
26/// * `momentum` - Momentum factor (default: 0.9)
27/// * `weight_decay` - Weight decay factor (default: 0.0001)
28/// * `trust_coefficient` - Trust coefficient for scaling (default: 0.001)
29/// * `eps` - Small constant for numerical stability (default: 1e-8)
30/// * `exclude_bias_and_norm` - Whether to exclude bias and normalization layers from LARS adaptation (default: true)
31///
32/// # Example
33///
34/// ```no_run
35/// use scirs2_core::ndarray::Array1;
36/// use optirs_core::optimizers::{LARS, Optimizer};
37///
38/// let mut optimizer = LARS::new(0.01)
39///     .with_momentum(0.9)
40///     .with_weight_decay(0.0001)
41///     .with_trust_coefficient(0.001);
42///
43/// let params = Array1::zeros(10);
44/// let gradients = Array1::ones(10);
45///
46/// let updated_params = optimizer.step(&params, &gradients).unwrap();
47/// // Parameters are automatically updated
48/// ```
49#[derive(Debug, Clone)]
50pub struct LARS<A: Float> {
51    learning_rate: A,
52    momentum: A,
53    weight_decay: A,
54    trust_coefficient: A,
55    eps: A,
56    exclude_bias_and_norm: bool,
57    velocity: Option<Vec<A>>,
58}
59
60impl<A: Float + ScalarOperand + Debug + Send + Sync> LARS<A> {
61    /// Create a new LARS optimizer with the given learning rate
62    pub fn new(learning_rate: A) -> Self {
63        Self {
64            learning_rate,
65            momentum: A::from(0.9).unwrap(),
66            weight_decay: A::from(0.0001).unwrap(),
67            trust_coefficient: A::from(0.001).unwrap(),
68            eps: A::from(1e-8).unwrap(),
69            exclude_bias_and_norm: true,
70            velocity: None,
71        }
72    }
73
74    /// Set the momentum factor
75    pub fn with_momentum(mut self, momentum: A) -> Self {
76        self.momentum = momentum;
77        self
78    }
79
80    /// Set the weight decay factor
81    pub fn with_weight_decay(mut self, weight_decay: A) -> Self {
82        self.weight_decay = weight_decay;
83        self
84    }
85
86    /// Set the trust coefficient
87    pub fn with_trust_coefficient(mut self, trust_coefficient: A) -> Self {
88        self.trust_coefficient = trust_coefficient;
89        self
90    }
91
92    /// Set the epsilon value for numerical stability
93    pub fn with_eps(mut self, eps: A) -> Self {
94        self.eps = eps;
95        self
96    }
97
98    /// Set whether to exclude bias and normalization layers from LARS adaptation
99    pub fn with_exclude_bias_and_norm(mut self, exclude_bias_and_norm: bool) -> Self {
100        self.exclude_bias_and_norm = exclude_bias_and_norm;
101        self
102    }
103
104    /// Reset the optimizer state
105    pub fn reset(&mut self) {
106        self.velocity = None;
107    }
108}
109
110impl<A: Float + ScalarOperand + Debug + Send + Sync, D: Dimension + Send + Sync> Optimizer<A, D>
111    for LARS<A>
112{
113    fn step(&mut self, params: &Array<A, D>, gradients: &Array<A, D>) -> Result<Array<A, D>> {
114        // Initialize velocity if not already created
115        let n_params = gradients.len();
116        if self.velocity.is_none() {
117            self.velocity = Some(vec![A::zero(); n_params]);
118        }
119
120        let velocity = match &mut self.velocity {
121            Some(v) => {
122                if v.len() != n_params {
123                    return Err(OptimError::InvalidConfig(format!(
124                        "LARS velocity length ({}) does not match gradients length ({})",
125                        v.len(),
126                        n_params
127                    )));
128                }
129                v
130            }
131            None => unreachable!(), // We already initialized it
132        };
133
134        // Make a clone of parameters for calculating update
135        let params_clone = params.clone();
136
137        // Calculate the weight decay term
138        let weight_decay_term = if self.weight_decay > A::zero() {
139            &params_clone * self.weight_decay
140        } else {
141            Array::zeros(params.raw_dim())
142        };
143
144        // Calculate weight norm and gradient norm
145        let weight_norm = params_clone.mapv(|x| x * x).sum().sqrt();
146        let grad_norm = gradients.mapv(|x| x * x).sum().sqrt();
147
148        // Determine if we should apply LARS scaling
149        let should_apply_lars = !self.exclude_bias_and_norm || weight_norm > A::zero();
150
151        // Calculate local learning rate using trust ratio
152        let local_lr = if should_apply_lars && weight_norm > A::zero() && grad_norm > A::zero() {
153            self.trust_coefficient * weight_norm
154                / (grad_norm + self.weight_decay * weight_norm + self.eps)
155        } else {
156            A::one()
157        };
158
159        // Apply local learning rate scaling
160        let scaled_lr = self.learning_rate * local_lr;
161
162        // Calculate gradient update with weight decay
163        let update_raw = gradients + &weight_decay_term;
164
165        // Apply scaled learning rate
166        let update_scaled = update_raw * scaled_lr;
167
168        // Create output array - will be our result
169        let mut updated_params = params.clone();
170
171        // Apply momentum and update parameters
172        for (idx, (p, &update)) in updated_params
173            .iter_mut()
174            .zip(update_scaled.iter())
175            .enumerate()
176        {
177            // Update velocity with momentum
178            velocity[idx] = self.momentum * velocity[idx] + update;
179            // Update parameter
180            *p = *p - velocity[idx];
181        }
182
183        Ok(updated_params)
184    }
185
186    fn set_learning_rate(&mut self, learning_rate: A) {
187        self.learning_rate = learning_rate;
188    }
189
190    fn get_learning_rate(&self) -> A {
191        self.learning_rate
192    }
193}
194
195#[cfg(test)]
196mod tests {
197    use super::*;
198    use approx::assert_abs_diff_eq;
199    use scirs2_core::ndarray::Array1;
200
201    #[test]
202    fn test_lars_creation() {
203        let optimizer = LARS::new(0.01);
204        assert_abs_diff_eq!(optimizer.learning_rate, 0.01);
205        assert_abs_diff_eq!(optimizer.momentum, 0.9);
206        assert_abs_diff_eq!(optimizer.weight_decay, 0.0001);
207        assert_abs_diff_eq!(optimizer.trust_coefficient, 0.001);
208        assert_abs_diff_eq!(optimizer.eps, 1e-8);
209        assert!(optimizer.exclude_bias_and_norm);
210    }
211
212    #[test]
213    fn test_lars_builder() {
214        let optimizer = LARS::new(0.01)
215            .with_momentum(0.95)
216            .with_weight_decay(0.0005)
217            .with_trust_coefficient(0.01)
218            .with_eps(1e-6)
219            .with_exclude_bias_and_norm(false);
220
221        assert_abs_diff_eq!(optimizer.momentum, 0.95);
222        assert_abs_diff_eq!(optimizer.weight_decay, 0.0005);
223        assert_abs_diff_eq!(optimizer.trust_coefficient, 0.01);
224        assert_abs_diff_eq!(optimizer.eps, 1e-6);
225        assert!(!optimizer.exclude_bias_and_norm);
226    }
227
228    #[test]
229    fn test_lars_update() {
230        let mut optimizer = LARS::new(0.1)
231            .with_momentum(0.9)
232            .with_weight_decay(0.0)
233            .with_trust_coefficient(1.0);
234
235        let params = Array1::from_vec(vec![1.0, 2.0, 3.0]);
236        let gradients = Array1::from_vec(vec![0.1, 0.2, 0.3]);
237
238        // First update
239        let updated_params = optimizer.step(&params, &gradients).unwrap();
240
241        // LARS scaling factor with trust_coefficient=1.0 should be:
242        // weight_norm / grad_norm = sqrt(14) / sqrt(0.14) ≈ 10
243        // So the effective learning rate is 0.1 * 10 = 1.0
244        // Scale is approximately 10, but let's check actual value (more precise)
245        let weight_norm = params.mapv(|x| x * x).sum().sqrt();
246        let grad_norm = gradients.mapv(|x| x * x).sum().sqrt();
247        let scale = weight_norm / grad_norm;
248
249        assert_abs_diff_eq!(updated_params[0], 1.0 - 0.1 * scale * 0.1, epsilon = 1e-5);
250        assert_abs_diff_eq!(updated_params[1], 2.0 - 0.1 * scale * 0.2, epsilon = 1e-5);
251        assert_abs_diff_eq!(updated_params[2], 3.0 - 0.1 * scale * 0.3, epsilon = 1e-5);
252
253        // Second update should include momentum
254        let updated_params2 = optimizer.step(&updated_params, &gradients).unwrap();
255
256        // For the second update, the velocity will be updated with momentum
257        // Just check that parameters continue to change in the expected direction
258        assert!(updated_params2[0] < updated_params[0]);
259        assert!(updated_params2[1] < updated_params[1]);
260        assert!(updated_params2[2] < updated_params[2]);
261    }
262
263    #[test]
264    fn test_lars_weight_decay() {
265        let mut optimizer = LARS::new(0.01)
266            .with_momentum(0.0) // No momentum for clarity
267            .with_weight_decay(0.1)
268            .with_trust_coefficient(1.0);
269
270        let params = Array1::from_vec(vec![1.0, 2.0, 3.0]);
271        let gradients = Array1::from_vec(vec![0.1, 0.2, 0.3]);
272
273        let updated_params = optimizer.step(&params, &gradients).unwrap();
274
275        // Gradients with weight decay: [0.1, 0.2, 0.3] + 0.1*[1.0, 2.0, 3.0] = [0.2, 0.4, 0.6]
276        // LARS scaling factor includes weight decay in denominator
277        // weight_norm / (grad_norm + weight_decay * weight_norm)
278        // = sqrt(14) / (sqrt(0.56) + 0.1*sqrt(14))
279        let weight_norm = params.mapv(|x| x * x).sum().sqrt();
280        let grad_norm = gradients.mapv(|x| x * x).sum().sqrt();
281        let expected_scale = weight_norm / (grad_norm + 0.1 * weight_norm);
282
283        // Check calculation is approximately correct (allowing for floating point differences)
284        let expected_p0 = 1.0 - 0.01 * expected_scale * (0.1 + 0.1 * 1.0);
285        let expected_p1 = 2.0 - 0.01 * expected_scale * (0.2 + 0.1 * 2.0);
286        let expected_p2 = 3.0 - 0.01 * expected_scale * (0.3 + 0.1 * 3.0);
287
288        assert_abs_diff_eq!(updated_params[0], expected_p0, epsilon = 1e-5);
289        assert_abs_diff_eq!(updated_params[1], expected_p1, epsilon = 1e-5);
290        assert_abs_diff_eq!(updated_params[2], expected_p2, epsilon = 1e-5);
291    }
292
293    #[test]
294    fn test_zero_gradients() {
295        let mut optimizer = LARS::new(0.01);
296        let params = Array1::from_vec(vec![1.0, 2.0, 3.0]);
297        let zero_gradients = Array1::zeros(3);
298
299        let updated_params = optimizer.step(&params, &zero_gradients).unwrap();
300
301        // With zero gradients, only weight decay should contribute to the update
302        // With small weight decay (0.0001), changes should be very small
303        assert_abs_diff_eq!(updated_params[0], params[0], epsilon = 1e-3);
304        assert_abs_diff_eq!(updated_params[1], params[1], epsilon = 1e-3);
305        assert_abs_diff_eq!(updated_params[2], params[2], epsilon = 1e-3);
306    }
307
308    #[test]
309    fn test_exclude_bias_and_norm() {
310        let mut optimizer_excluded = LARS::new(0.01)
311            .with_momentum(0.0)
312            .with_weight_decay(0.0)
313            .with_exclude_bias_and_norm(true);
314
315        let mut optimizer_included = LARS::new(0.01)
316            .with_momentum(0.0)
317            .with_weight_decay(0.0)
318            .with_exclude_bias_and_norm(false);
319
320        // Test with parameters that could be bias (small 1D array)
321        let bias_params = Array1::from_vec(vec![0.1, 0.2]);
322        let bias_grads = Array1::from_vec(vec![0.01, 0.02]);
323
324        let updated_excluded = optimizer_excluded.step(&bias_params, &bias_grads).unwrap();
325        let updated_included = optimizer_included.step(&bias_params, &bias_grads).unwrap();
326
327        // When excluded, should use base learning rate (but still include momentum calculation)
328        assert_abs_diff_eq!(updated_excluded[0], 0.1 - 0.01 * 0.01, epsilon = 1e-4);
329
330        // When included, should use LARS scaled learning rate
331        let weight_norm = (0.1f64.powi(2) + 0.2f64.powi(2)).sqrt();
332        let grad_norm = (0.01f64.powi(2) + 0.02f64.powi(2)).sqrt();
333        let expected_factor = 0.001 * weight_norm / grad_norm; // trust_coefficient * weight_norm / grad_norm
334
335        assert_abs_diff_eq!(
336            updated_included[0],
337            0.1 - 0.01 * expected_factor * 0.01,
338            epsilon = 1e-5
339        );
340    }
341}