optirs_core/optimizers/lars.rs

use crate::error::{OptimError, Result};
use crate::optimizers::Optimizer;
use scirs2_core::ndarray::{Array, Dimension, ScalarOperand};
use scirs2_core::numeric::Float;
use std::fmt::Debug;

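/// Layer-wise Adaptive Rate Scaling (LARS) optimizer, introduced in
/// "Large Batch Training of Convolutional Networks" (You, Gitman, Ginsburg, 2017).
///
/// The sketch below restates the update implemented in `step`; when either the
/// parameter norm or the gradient norm is zero, the local learning rate falls
/// back to 1:
///
/// ```text
/// local_lr = trust_coefficient * ||w|| / (||g|| + weight_decay * ||w|| + eps)
/// v        = momentum * v + learning_rate * local_lr * (g + weight_decay * w)
/// w        = w - v
/// ```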
#[derive(Debug, Clone)]
pub struct LARS<A: Float> {
    learning_rate: A,
    momentum: A,
    weight_decay: A,
    trust_coefficient: A,
    eps: A,
    exclude_bias_and_norm: bool,
    velocity: Option<Vec<A>>,
}

impl<A: Float + ScalarOperand + Debug + Send + Sync> LARS<A> {
    /// Creates a LARS optimizer with the given learning rate and default
    /// momentum (0.9), weight decay (1e-4), trust coefficient (0.001), and
    /// epsilon (1e-8).
    pub fn new(learning_rate: A) -> Self {
        Self {
            learning_rate,
            momentum: A::from(0.9).unwrap(),
            weight_decay: A::from(0.0001).unwrap(),
            trust_coefficient: A::from(0.001).unwrap(),
            eps: A::from(1e-8).unwrap(),
            exclude_bias_and_norm: true,
            velocity: None,
        }
    }

    /// Sets the momentum coefficient.
    pub fn with_momentum(mut self, momentum: A) -> Self {
        self.momentum = momentum;
        self
    }

    /// Sets the weight decay coefficient.
    pub fn with_weight_decay(mut self, weight_decay: A) -> Self {
        self.weight_decay = weight_decay;
        self
    }

    /// Sets the trust coefficient used in the layer-wise learning rate.
    pub fn with_trust_coefficient(mut self, trust_coefficient: A) -> Self {
        self.trust_coefficient = trust_coefficient;
        self
    }

    /// Sets the epsilon added to the trust-ratio denominator for numerical stability.
    pub fn with_eps(mut self, eps: A) -> Self {
        self.eps = eps;
        self
    }

    /// Sets the flag controlling whether bias and normalization parameters are
    /// excluded from LARS scaling.
    pub fn with_exclude_bias_and_norm(mut self, exclude_bias_and_norm: bool) -> Self {
        self.exclude_bias_and_norm = exclude_bias_and_norm;
        self
    }

    /// Clears the accumulated velocity state.
    pub fn reset(&mut self) {
        self.velocity = None;
    }
}
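
// A minimal configuration sketch using the builder methods above; the values
// shown are illustrative, not requirements of this implementation:
//
//     let optimizer = LARS::new(0.01)
//         .with_momentum(0.9)
//         .with_weight_decay(1e-4)
//         .with_trust_coefficient(0.001)
//         .with_eps(1e-8);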

impl<A: Float + ScalarOperand + Debug + Send + Sync, D: Dimension + Send + Sync> Optimizer<A, D>
    for LARS<A>
{
    fn step(&mut self, params: &Array<A, D>, gradients: &Array<A, D>) -> Result<Array<A, D>> {
        let n_params = gradients.len();

        // Lazily allocate the velocity buffer on the first step.
        if self.velocity.is_none() {
            self.velocity = Some(vec![A::zero(); n_params]);
        }

        // Validate that the velocity buffer matches the gradient length.
        let velocity = match &mut self.velocity {
            Some(v) => {
                if v.len() != n_params {
                    return Err(OptimError::InvalidConfig(format!(
                        "LARS velocity length ({}) does not match gradients length ({})",
                        v.len(),
                        n_params
                    )));
                }
                v
            }
            None => unreachable!(),
        };

        let params_clone = params.clone();

        // Weight decay contribution: weight_decay * w, or zeros when disabled.
        let weight_decay_term = if self.weight_decay > A::zero() {
            &params_clone * self.weight_decay
        } else {
            Array::zeros(params.raw_dim())
        };

        // L2 norms of the parameters and gradients for this layer.
        let weight_norm = params_clone.mapv(|x| x * x).sum().sqrt();
        let grad_norm = gradients.mapv(|x| x * x).sum().sqrt();

        let should_apply_lars = !self.exclude_bias_and_norm || weight_norm > A::zero();

        // LARS trust ratio; falls back to 1.0 when either norm is zero.
        let local_lr = if should_apply_lars && weight_norm > A::zero() && grad_norm > A::zero() {
            self.trust_coefficient * weight_norm
                / (grad_norm + self.weight_decay * weight_norm + self.eps)
        } else {
            A::one()
        };

        let scaled_lr = self.learning_rate * local_lr;

        // Combine gradient and weight decay, then scale by the layer-adapted learning rate.
        let update_raw = gradients + &weight_decay_term;
        let update_scaled = update_raw * scaled_lr;

        let mut updated_params = params.clone();

        // Momentum update: v = momentum * v + update; w = w - v.
        for (idx, (p, &update)) in updated_params
            .iter_mut()
            .zip(update_scaled.iter())
            .enumerate()
        {
            velocity[idx] = self.momentum * velocity[idx] + update;
            *p = *p - velocity[idx];
        }

        Ok(updated_params)
    }

    fn set_learning_rate(&mut self, learning_rate: A) {
        self.learning_rate = learning_rate;
    }

    fn get_learning_rate(&self) -> A {
        self.learning_rate
    }
}
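
// A minimal usage sketch of the `Optimizer` implementation above, mirroring the
// pattern in the tests below; the numbers are illustrative only:
//
//     let mut optimizer = LARS::new(0.01);
//     let params = Array1::from_vec(vec![1.0f64, 2.0, 3.0]);
//     let gradients = Array1::from_vec(vec![0.1f64, 0.2, 0.3]);
//     let updated = optimizer.step(&params, &gradients).unwrap();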

#[cfg(test)]
mod tests {
    use super::*;
    use approx::assert_abs_diff_eq;
    use scirs2_core::ndarray::Array1;

    #[test]
    fn test_lars_creation() {
        let optimizer = LARS::new(0.01);
        assert_abs_diff_eq!(optimizer.learning_rate, 0.01);
        assert_abs_diff_eq!(optimizer.momentum, 0.9);
        assert_abs_diff_eq!(optimizer.weight_decay, 0.0001);
        assert_abs_diff_eq!(optimizer.trust_coefficient, 0.001);
        assert_abs_diff_eq!(optimizer.eps, 1e-8);
        assert!(optimizer.exclude_bias_and_norm);
    }

    #[test]
    fn test_lars_builder() {
        let optimizer = LARS::new(0.01)
            .with_momentum(0.95)
            .with_weight_decay(0.0005)
            .with_trust_coefficient(0.01)
            .with_eps(1e-6)
            .with_exclude_bias_and_norm(false);

        assert_abs_diff_eq!(optimizer.momentum, 0.95);
        assert_abs_diff_eq!(optimizer.weight_decay, 0.0005);
        assert_abs_diff_eq!(optimizer.trust_coefficient, 0.01);
        assert_abs_diff_eq!(optimizer.eps, 1e-6);
        assert!(!optimizer.exclude_bias_and_norm);
    }

    #[test]
    fn test_lars_update() {
        let mut optimizer = LARS::new(0.1)
            .with_momentum(0.9)
            .with_weight_decay(0.0)
            .with_trust_coefficient(1.0);

        let params = Array1::from_vec(vec![1.0, 2.0, 3.0]);
        let gradients = Array1::from_vec(vec![0.1, 0.2, 0.3]);

        let updated_params = optimizer.step(&params, &gradients).unwrap();

        // With zero weight decay, the first step is lr * (||w|| / ||g||) * g.
        let weight_norm = params.mapv(|x| x * x).sum().sqrt();
        let grad_norm = gradients.mapv(|x| x * x).sum().sqrt();
        let scale = weight_norm / grad_norm;

        assert_abs_diff_eq!(updated_params[0], 1.0 - 0.1 * scale * 0.1, epsilon = 1e-5);
        assert_abs_diff_eq!(updated_params[1], 2.0 - 0.1 * scale * 0.2, epsilon = 1e-5);
        assert_abs_diff_eq!(updated_params[2], 3.0 - 0.1 * scale * 0.3, epsilon = 1e-5);

        // A second step with the same gradients should keep decreasing the parameters.
        let updated_params2 = optimizer.step(&updated_params, &gradients).unwrap();

        assert!(updated_params2[0] < updated_params[0]);
        assert!(updated_params2[1] < updated_params[1]);
        assert!(updated_params2[2] < updated_params[2]);
    }

    #[test]
    fn test_lars_weight_decay() {
        let mut optimizer = LARS::new(0.01)
            .with_momentum(0.0)
            .with_weight_decay(0.1)
            .with_trust_coefficient(1.0);

        let params = Array1::from_vec(vec![1.0, 2.0, 3.0]);
        let gradients = Array1::from_vec(vec![0.1, 0.2, 0.3]);

        let updated_params = optimizer.step(&params, &gradients).unwrap();

        // With weight decay, the trust-ratio denominator gains a 0.1 * ||w|| term
        // and the update itself includes 0.1 * w.
        let weight_norm = params.mapv(|x| x * x).sum().sqrt();
        let grad_norm = gradients.mapv(|x| x * x).sum().sqrt();
        let expected_scale = weight_norm / (grad_norm + 0.1 * weight_norm);

        let expected_p0 = 1.0 - 0.01 * expected_scale * (0.1 + 0.1 * 1.0);
        let expected_p1 = 2.0 - 0.01 * expected_scale * (0.2 + 0.1 * 2.0);
        let expected_p2 = 3.0 - 0.01 * expected_scale * (0.3 + 0.1 * 3.0);

        assert_abs_diff_eq!(updated_params[0], expected_p0, epsilon = 1e-5);
        assert_abs_diff_eq!(updated_params[1], expected_p1, epsilon = 1e-5);
        assert_abs_diff_eq!(updated_params[2], expected_p2, epsilon = 1e-5);
    }

    #[test]
    fn test_zero_gradients() {
        let mut optimizer = LARS::new(0.01);
        let params = Array1::from_vec(vec![1.0, 2.0, 3.0]);
        let zero_gradients = Array1::zeros(3);

        let updated_params = optimizer.step(&params, &zero_gradients).unwrap();

        // Only the (tiny) weight decay term moves the parameters, so they stay
        // close to their original values.
        assert_abs_diff_eq!(updated_params[0], params[0], epsilon = 1e-3);
        assert_abs_diff_eq!(updated_params[1], params[1], epsilon = 1e-3);
        assert_abs_diff_eq!(updated_params[2], params[2], epsilon = 1e-3);
    }

    #[test]
    fn test_exclude_bias_and_norm() {
        let mut optimizer_excluded = LARS::new(0.01)
            .with_momentum(0.0)
            .with_weight_decay(0.0)
            .with_exclude_bias_and_norm(true);

        let mut optimizer_included = LARS::new(0.01)
            .with_momentum(0.0)
            .with_weight_decay(0.0)
            .with_exclude_bias_and_norm(false);

        let bias_params = Array1::from_vec(vec![0.1, 0.2]);
        let bias_grads = Array1::from_vec(vec![0.01, 0.02]);

        let updated_excluded = optimizer_excluded.step(&bias_params, &bias_grads).unwrap();
        let updated_included = optimizer_included.step(&bias_params, &bias_grads).unwrap();

        assert_abs_diff_eq!(updated_excluded[0], 0.1 - 0.01 * 0.01, epsilon = 1e-4);

        let weight_norm = (0.1f64.powi(2) + 0.2f64.powi(2)).sqrt();
        let grad_norm = (0.01f64.powi(2) + 0.02f64.powi(2)).sqrt();
        let expected_factor = 0.001 * weight_norm / grad_norm;
        assert_abs_diff_eq!(
            updated_included[0],
            0.1 - 0.01 * expected_factor * 0.01,
            epsilon = 1e-5
        );
    }
}