// optirs_core/optimizers/lars.rs
use crate::error::{OptimError, Result};
use crate::optimizers::Optimizer;
use scirs2_core::ndarray::{Array, Dimension, ScalarOperand};
use scirs2_core::numeric::Float;
use std::fmt::Debug;
/// Layer-wise Adaptive Rate Scaling (LARS) optimizer.
///
/// LARS scales each step by a per-layer "trust ratio" derived from the
/// parameter-norm / gradient-norm ratio (computed in `step`), combined
/// with classic momentum and optional weight decay.
#[derive(Debug, Clone)]
pub struct LARS<A: Float> {
    /// Global learning rate; multiplied by the per-layer trust ratio in `step`.
    learning_rate: A,
    /// Momentum coefficient applied to the velocity buffer.
    momentum: A,
    /// Weight-decay (L2) coefficient added to the gradient in `step`.
    weight_decay: A,
    /// Trust coefficient scaling the layer-wise learning rate.
    trust_coefficient: A,
    /// Small constant added to the trust-ratio denominator for stability.
    eps: A,
    /// When true, zero-norm parameters skip the LARS trust-ratio scaling.
    /// NOTE(review): the check in `step` is norm-based rather than
    /// shape-based, so this likely does not exclude typical bias/norm
    /// tensors as the name suggests — confirm intent.
    exclude_bias_and_norm: bool,
    /// Flattened momentum buffer; allocated lazily on the first `step`
    /// and cleared by `reset`.
    velocity: Option<Vec<A>>,
}
59
60impl<A: Float + ScalarOperand + Debug + Send + Sync> LARS<A> {
61 pub fn new(learning_rate: A) -> Self {
63 Self {
64 learning_rate,
65 momentum: A::from(0.9).expect("unwrap failed"),
66 weight_decay: A::from(0.0001).expect("unwrap failed"),
67 trust_coefficient: A::from(0.001).expect("unwrap failed"),
68 eps: A::from(1e-8).expect("unwrap failed"),
69 exclude_bias_and_norm: true,
70 velocity: None,
71 }
72 }
73
74 pub fn with_momentum(mut self, momentum: A) -> Self {
76 self.momentum = momentum;
77 self
78 }
79
80 pub fn with_weight_decay(mut self, weight_decay: A) -> Self {
82 self.weight_decay = weight_decay;
83 self
84 }
85
86 pub fn with_trust_coefficient(mut self, trust_coefficient: A) -> Self {
88 self.trust_coefficient = trust_coefficient;
89 self
90 }
91
92 pub fn with_eps(mut self, eps: A) -> Self {
94 self.eps = eps;
95 self
96 }
97
98 pub fn with_exclude_bias_and_norm(mut self, exclude_bias_and_norm: bool) -> Self {
100 self.exclude_bias_and_norm = exclude_bias_and_norm;
101 self
102 }
103
104 pub fn reset(&mut self) {
106 self.velocity = None;
107 }
108}
109
110impl<A: Float + ScalarOperand + Debug + Send + Sync, D: Dimension + Send + Sync> Optimizer<A, D>
111 for LARS<A>
112{
113 fn step(&mut self, params: &Array<A, D>, gradients: &Array<A, D>) -> Result<Array<A, D>> {
114 let n_params = gradients.len();
116 if self.velocity.is_none() {
117 self.velocity = Some(vec![A::zero(); n_params]);
118 }
119
120 let velocity = match &mut self.velocity {
121 Some(v) => {
122 if v.len() != n_params {
123 return Err(OptimError::InvalidConfig(format!(
124 "LARS velocity length ({}) does not match gradients length ({})",
125 v.len(),
126 n_params
127 )));
128 }
129 v
130 }
131 None => unreachable!(), };
133
134 let params_clone = params.clone();
136
137 let weight_decay_term = if self.weight_decay > A::zero() {
139 ¶ms_clone * self.weight_decay
140 } else {
141 Array::zeros(params.raw_dim())
142 };
143
144 let weight_norm = params_clone.mapv(|x| x * x).sum().sqrt();
146 let grad_norm = gradients.mapv(|x| x * x).sum().sqrt();
147
148 let should_apply_lars = !self.exclude_bias_and_norm || weight_norm > A::zero();
150
151 let local_lr = if should_apply_lars && weight_norm > A::zero() && grad_norm > A::zero() {
153 self.trust_coefficient * weight_norm
154 / (grad_norm + self.weight_decay * weight_norm + self.eps)
155 } else {
156 A::one()
157 };
158
159 let scaled_lr = self.learning_rate * local_lr;
161
162 let update_raw = gradients + &weight_decay_term;
164
165 let update_scaled = update_raw * scaled_lr;
167
168 let mut updated_params = params.clone();
170
171 for (idx, (p, &update)) in updated_params
173 .iter_mut()
174 .zip(update_scaled.iter())
175 .enumerate()
176 {
177 velocity[idx] = self.momentum * velocity[idx] + update;
179 *p = *p - velocity[idx];
181 }
182
183 Ok(updated_params)
184 }
185
186 fn set_learning_rate(&mut self, learning_rate: A) {
187 self.learning_rate = learning_rate;
188 }
189
190 fn get_learning_rate(&self) -> A {
191 self.learning_rate
192 }
193}
194
#[cfg(test)]
mod tests {
    use super::*;
    use approx::assert_abs_diff_eq;
    use scirs2_core::ndarray::Array1;

    /// A freshly constructed optimizer carries the documented defaults.
    #[test]
    fn test_lars_creation() {
        let opt = LARS::new(0.01);
        assert_abs_diff_eq!(opt.learning_rate, 0.01);
        assert_abs_diff_eq!(opt.momentum, 0.9);
        assert_abs_diff_eq!(opt.weight_decay, 0.0001);
        assert_abs_diff_eq!(opt.trust_coefficient, 0.001);
        assert_abs_diff_eq!(opt.eps, 1e-8);
        assert!(opt.exclude_bias_and_norm);
    }

    /// Each builder setter overwrites its corresponding field.
    #[test]
    fn test_lars_builder() {
        let opt = LARS::new(0.01)
            .with_momentum(0.95)
            .with_weight_decay(0.0005)
            .with_trust_coefficient(0.01)
            .with_eps(1e-6)
            .with_exclude_bias_and_norm(false);

        assert_abs_diff_eq!(opt.momentum, 0.95);
        assert_abs_diff_eq!(opt.weight_decay, 0.0005);
        assert_abs_diff_eq!(opt.trust_coefficient, 0.01);
        assert_abs_diff_eq!(opt.eps, 1e-6);
        assert!(!opt.exclude_bias_and_norm);
    }

    /// With trust coefficient 1 and no weight decay, the first update is
    /// exactly lr * (|w| / |g|) * g for each element.
    #[test]
    fn test_lars_update() {
        let mut opt = LARS::new(0.1)
            .with_momentum(0.9)
            .with_weight_decay(0.0)
            .with_trust_coefficient(1.0);

        let weights = Array1::from_vec(vec![1.0, 2.0, 3.0]);
        let grads = Array1::from_vec(vec![0.1, 0.2, 0.3]);

        let step1 = opt.step(&weights, &grads).expect("unwrap failed");

        let w_norm = weights.mapv(|x| x * x).sum().sqrt();
        let g_norm = grads.mapv(|x| x * x).sum().sqrt();
        let ratio = w_norm / g_norm;

        assert_abs_diff_eq!(step1[0], 1.0 - 0.1 * ratio * 0.1, epsilon = 1e-5);
        assert_abs_diff_eq!(step1[1], 2.0 - 0.1 * ratio * 0.2, epsilon = 1e-5);
        assert_abs_diff_eq!(step1[2], 3.0 - 0.1 * ratio * 0.3, epsilon = 1e-5);

        // A second step with the same gradients must keep moving every
        // parameter downward (momentum accumulates in that direction).
        let step2 = opt.step(&step1, &grads).expect("unwrap failed");
        assert!(step2[0] < step1[0]);
        assert!(step2[1] < step1[1]);
        assert!(step2[2] < step1[2]);
    }

    /// Weight decay folds wd * w into the effective gradient and into the
    /// trust-ratio denominator.
    #[test]
    fn test_lars_weight_decay() {
        let mut opt = LARS::new(0.01)
            .with_momentum(0.0)
            .with_weight_decay(0.1)
            .with_trust_coefficient(1.0);

        let weights = Array1::from_vec(vec![1.0, 2.0, 3.0]);
        let grads = Array1::from_vec(vec![0.1, 0.2, 0.3]);

        let stepped = opt.step(&weights, &grads).expect("unwrap failed");

        let w_norm = weights.mapv(|x| x * x).sum().sqrt();
        let g_norm = grads.mapv(|x| x * x).sum().sqrt();
        let expected_scale = w_norm / (g_norm + 0.1 * w_norm);

        let want0 = 1.0 - 0.01 * expected_scale * (0.1 + 0.1 * 1.0);
        let want1 = 2.0 - 0.01 * expected_scale * (0.2 + 0.1 * 2.0);
        let want2 = 3.0 - 0.01 * expected_scale * (0.3 + 0.1 * 3.0);

        assert_abs_diff_eq!(stepped[0], want0, epsilon = 1e-5);
        assert_abs_diff_eq!(stepped[1], want1, epsilon = 1e-5);
        assert_abs_diff_eq!(stepped[2], want2, epsilon = 1e-5);
    }

    /// Zero gradients should leave parameters (almost) unchanged —
    /// only the default weight decay contributes a tiny shift.
    #[test]
    fn test_zero_gradients() {
        let mut opt = LARS::new(0.01);
        let weights = Array1::from_vec(vec![1.0, 2.0, 3.0]);
        let no_grads = Array1::zeros(3);

        let stepped = opt.step(&weights, &no_grads).expect("unwrap failed");

        assert_abs_diff_eq!(stepped[0], weights[0], epsilon = 1e-3);
        assert_abs_diff_eq!(stepped[1], weights[1], epsilon = 1e-3);
        assert_abs_diff_eq!(stepped[2], weights[2], epsilon = 1e-3);
    }

    /// Compares updates with and without the bias/norm exclusion flag.
    #[test]
    fn test_exclude_bias_and_norm() {
        let mut opt_excluded = LARS::new(0.01)
            .with_momentum(0.0)
            .with_weight_decay(0.0)
            .with_exclude_bias_and_norm(true);

        let mut opt_included = LARS::new(0.01)
            .with_momentum(0.0)
            .with_weight_decay(0.0)
            .with_exclude_bias_and_norm(false);

        let bias = Array1::from_vec(vec![0.1, 0.2]);
        let bias_grads = Array1::from_vec(vec![0.01, 0.02]);

        let out_excluded = opt_excluded
            .step(&bias, &bias_grads)
            .expect("unwrap failed");
        let out_included = opt_included
            .step(&bias, &bias_grads)
            .expect("unwrap failed");

        // Excluded: plain SGD-style step, lr * g.
        assert_abs_diff_eq!(out_excluded[0], 0.1 - 0.01 * 0.01, epsilon = 1e-4);

        // Included: trust ratio scales the step.
        let w_norm = (0.1f64.powi(2) + 0.2f64.powi(2)).sqrt();
        let g_norm = (0.01f64.powi(2) + 0.02f64.powi(2)).sqrt();
        let expected_factor = 0.001 * w_norm / g_norm;
        assert_abs_diff_eq!(
            out_included[0],
            0.1 - 0.01 * expected_factor * 0.01,
            epsilon = 1e-5
        );
    }
}
349}