use crate::error::{OptimError, Result};
use crate::optimizers::Optimizer;
use scirs2_core::ndarray::{Array, Dimension, ScalarOperand};
use scirs2_core::numeric::Float;
use std::fmt::Debug;
/// LARS (Layer-wise Adaptive Rate Scaling) optimizer: hyperparameters plus the
/// momentum buffer carried between calls to `step`.
///
/// `new` fills every hyperparameter except the learning rate with a default
/// (momentum 0.9, weight decay 1e-4, trust coefficient 1e-3, eps 1e-8,
/// `exclude_bias_and_norm = true`); the `with_*` builders override them.
#[derive(Debug, Clone)]
pub struct LARS<A: Float> {
/// Base learning rate; `step` multiplies it by the norm-derived local rate.
learning_rate: A,
/// Factor applied to the previous velocity when a new update is accumulated.
momentum: A,
/// Weight-decay coefficient. `step` adds `weight_decay * params` to the
/// gradient only when this is strictly positive; it also appears in the
/// local-rate denominator.
weight_decay: A,
/// Multiplier on the parameter-norm / gradient-norm ratio in the local rate.
trust_coefficient: A,
/// Small constant added to the local-rate denominator.
eps: A,
/// Flag read by `step` when deciding whether to use the norm-based local rate.
exclude_bias_and_norm: bool,
/// Flat per-element velocity buffer; `None` until the first `step` call and
/// again after `reset`.
velocity: Option<Vec<A>>,
}
impl<A: Float + ScalarOperand + Debug + Send + Sync> LARS<A> {
    /// Creates an optimizer with the given base learning rate and default
    /// hyperparameters: momentum 0.9, weight decay 1e-4, trust coefficient
    /// 1e-3, eps 1e-8, `exclude_bias_and_norm = true`, and no velocity yet.
    ///
    /// # Panics
    ///
    /// Panics if one of the default constants cannot be converted into `A`.
    pub fn new(learning_rate: A) -> Self {
        // Single conversion point for the f64 defaults.
        let constant = |value: f64| A::from(value).expect("unwrap failed");
        Self {
            learning_rate,
            momentum: constant(0.9),
            weight_decay: constant(0.0001),
            trust_coefficient: constant(0.001),
            eps: constant(1e-8),
            exclude_bias_and_norm: true,
            velocity: None,
        }
    }

    /// Returns the optimizer with its momentum factor replaced.
    pub fn with_momentum(self, momentum: A) -> Self {
        Self { momentum, ..self }
    }

    /// Returns the optimizer with its weight-decay coefficient replaced.
    pub fn with_weight_decay(self, weight_decay: A) -> Self {
        Self {
            weight_decay,
            ..self
        }
    }

    /// Returns the optimizer with its trust coefficient replaced.
    pub fn with_trust_coefficient(self, trust_coefficient: A) -> Self {
        Self {
            trust_coefficient,
            ..self
        }
    }

    /// Returns the optimizer with its denominator epsilon replaced.
    pub fn with_eps(self, eps: A) -> Self {
        Self { eps, ..self }
    }

    /// Returns the optimizer with the `exclude_bias_and_norm` flag replaced.
    pub fn with_exclude_bias_and_norm(self, exclude_bias_and_norm: bool) -> Self {
        Self {
            exclude_bias_and_norm,
            ..self
        }
    }

    /// Discards the accumulated velocity; the next `step` call allocates a
    /// fresh zeroed buffer. Hyperparameters are left untouched.
    pub fn reset(&mut self) {
        self.velocity = None;
    }
}
impl<A: Float + ScalarOperand + Debug + Send + Sync, D: Dimension + Send + Sync> Optimizer<A, D>
    for LARS<A>
{
    /// Performs one LARS update and returns the new parameters.
    ///
    /// The per-element update is
    /// `v = momentum * v + scaled_lr * (grad + weight_decay * param)` followed by
    /// `param = param - v`, where the weight-decay term is only added when
    /// `weight_decay > 0`, and
    /// `scaled_lr = learning_rate * trust_coefficient * ||param|| /
    /// (||grad|| + weight_decay * ||param|| + eps)`.
    /// When either norm is zero the local factor falls back to `1`, so
    /// `scaled_lr == learning_rate`.
    ///
    /// The velocity buffer is allocated (zeroed) on the first call and reused on
    /// later calls; `params` itself is not modified.
    ///
    /// # Errors
    ///
    /// Returns `OptimError::InvalidConfig` when `params` and `gradients` have
    /// different shapes, or when the stored velocity length does not match the
    /// number of gradient elements (e.g. the optimizer was previously used with
    /// a differently sized tensor and not `reset`).
    fn step(&mut self, params: &Array<A, D>, gradients: &Array<A, D>) -> Result<Array<A, D>> {
        // Reject mismatched inputs up front instead of panicking inside the
        // array arithmetic or silently truncating in `zip`.
        if params.shape() != gradients.shape() {
            return Err(OptimError::InvalidConfig(format!(
                "LARS parameter shape ({:?}) does not match gradient shape ({:?})",
                params.shape(),
                gradients.shape()
            )));
        }

        let n_params = gradients.len();
        let velocity = self
            .velocity
            .get_or_insert_with(|| vec![A::zero(); n_params]);
        if velocity.len() != n_params {
            return Err(OptimError::InvalidConfig(format!(
                "LARS velocity length ({}) does not match gradients length ({})",
                velocity.len(),
                n_params
            )));
        }

        let zero = A::zero();
        let weight_decay = self.weight_decay;
        let momentum = self.momentum;
        // Non-positive weight decay contributes nothing to the update itself.
        let apply_decay = weight_decay > zero;

        // Euclidean norms over all elements.
        let weight_norm = params.mapv(|x| x * x).sum().sqrt();
        let grad_norm = gradients.mapv(|x| x * x).sum().sqrt();

        // NOTE(review): the branch below also requires `weight_norm > 0`, so
        // `exclude_bias_and_norm` does not currently change the outcome. The
        // expression is kept as-is to preserve existing behavior.
        let should_apply_lars = !self.exclude_bias_and_norm || weight_norm > zero;
        let local_lr = if should_apply_lars && weight_norm > zero && grad_norm > zero {
            self.trust_coefficient * weight_norm
                / (grad_norm + weight_decay * weight_norm + self.eps)
        } else {
            A::one()
        };
        let scaled_lr = self.learning_rate * local_lr;

        // Single fused pass: no intermediate arrays beyond the returned one.
        let mut updated_params = params.clone();
        for ((p, &g), v) in updated_params
            .iter_mut()
            .zip(gradients.iter())
            .zip(velocity.iter_mut())
        {
            let decay = if apply_decay { *p * weight_decay } else { zero };
            *v = momentum * *v + (g + decay) * scaled_lr;
            *p = *p - *v;
        }
        Ok(updated_params)
    }

    /// Replaces the base learning rate used by subsequent `step` calls.
    fn set_learning_rate(&mut self, learning_rate: A) {
        self.learning_rate = learning_rate;
    }

    /// Returns the current base learning rate.
    fn get_learning_rate(&self) -> A {
        self.learning_rate
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use approx::assert_abs_diff_eq;
    use scirs2_core::ndarray::Array1;

    /// `new` must install the documented default hyperparameters.
    #[test]
    fn test_lars_creation() {
        let optimizer = LARS::new(0.01);
        assert_abs_diff_eq!(optimizer.learning_rate, 0.01);
        assert_abs_diff_eq!(optimizer.momentum, 0.9);
        assert_abs_diff_eq!(optimizer.weight_decay, 0.0001);
        assert_abs_diff_eq!(optimizer.trust_coefficient, 0.001);
        assert_abs_diff_eq!(optimizer.eps, 1e-8);
        assert!(optimizer.exclude_bias_and_norm);
    }

    /// Every `with_*` builder must override exactly its own field.
    #[test]
    fn test_lars_builder() {
        let optimizer = LARS::new(0.01)
            .with_momentum(0.95)
            .with_weight_decay(0.0005)
            .with_trust_coefficient(0.01)
            .with_eps(1e-6)
            .with_exclude_bias_and_norm(false);
        assert_abs_diff_eq!(optimizer.momentum, 0.95);
        assert_abs_diff_eq!(optimizer.weight_decay, 0.0005);
        assert_abs_diff_eq!(optimizer.trust_coefficient, 0.01);
        assert_abs_diff_eq!(optimizer.eps, 1e-6);
        assert!(!optimizer.exclude_bias_and_norm);
    }

    /// Without weight decay the first step is `lr * ||w|| / ||g|| * g`, and a
    /// second step with the same gradient keeps moving in the same direction.
    #[test]
    fn test_lars_update() {
        let mut optimizer = LARS::new(0.1)
            .with_momentum(0.9)
            .with_weight_decay(0.0)
            .with_trust_coefficient(1.0);
        let params = Array1::from_vec(vec![1.0, 2.0, 3.0]);
        let gradients = Array1::from_vec(vec![0.1, 0.2, 0.3]);
        let updated_params = optimizer.step(&params, &gradients).expect("unwrap failed");
        let weight_norm = params.mapv(|x| x * x).sum().sqrt();
        let grad_norm = gradients.mapv(|x| x * x).sum().sqrt();
        let scale = weight_norm / grad_norm;
        assert_abs_diff_eq!(updated_params[0], 1.0 - 0.1 * scale * 0.1, epsilon = 1e-5);
        assert_abs_diff_eq!(updated_params[1], 2.0 - 0.1 * scale * 0.2, epsilon = 1e-5);
        assert_abs_diff_eq!(updated_params[2], 3.0 - 0.1 * scale * 0.3, epsilon = 1e-5);
        let updated_params2 = optimizer
            .step(&updated_params, &gradients)
            .expect("unwrap failed");
        assert!(updated_params2[0] < updated_params[0]);
        assert!(updated_params2[1] < updated_params[1]);
        assert!(updated_params2[2] < updated_params[2]);
    }

    /// Weight decay enters both the update (`g + wd * w`) and the local-rate
    /// denominator (`||g|| + wd * ||w||`).
    #[test]
    fn test_lars_weight_decay() {
        let mut optimizer = LARS::new(0.01)
            .with_momentum(0.0)
            .with_weight_decay(0.1)
            .with_trust_coefficient(1.0);
        let params = Array1::from_vec(vec![1.0, 2.0, 3.0]);
        let gradients = Array1::from_vec(vec![0.1, 0.2, 0.3]);
        let updated_params = optimizer.step(&params, &gradients).expect("unwrap failed");
        let weight_norm = params.mapv(|x| x * x).sum().sqrt();
        let grad_norm = gradients.mapv(|x| x * x).sum().sqrt();
        let expected_scale = weight_norm / (grad_norm + 0.1 * weight_norm);
        let expected_p0 = 1.0 - 0.01 * expected_scale * (0.1 + 0.1 * 1.0);
        let expected_p1 = 2.0 - 0.01 * expected_scale * (0.2 + 0.1 * 2.0);
        let expected_p2 = 3.0 - 0.01 * expected_scale * (0.3 + 0.1 * 3.0);
        assert_abs_diff_eq!(updated_params[0], expected_p0, epsilon = 1e-5);
        assert_abs_diff_eq!(updated_params[1], expected_p1, epsilon = 1e-5);
        assert_abs_diff_eq!(updated_params[2], expected_p2, epsilon = 1e-5);
    }

    /// With all-zero gradients only the (tiny) default weight decay moves the
    /// parameters, so they stay within a loose tolerance of their inputs.
    #[test]
    fn test_zero_gradients() {
        let mut optimizer = LARS::new(0.01);
        let params = Array1::from_vec(vec![1.0, 2.0, 3.0]);
        let zero_gradients = Array1::zeros(3);
        let updated_params = optimizer
            .step(&params, &zero_gradients)
            .expect("unwrap failed");
        assert_abs_diff_eq!(updated_params[0], params[0], epsilon = 1e-3);
        assert_abs_diff_eq!(updated_params[1], params[1], epsilon = 1e-3);
        assert_abs_diff_eq!(updated_params[2], params[2], epsilon = 1e-3);
    }

    /// Exercises `step` with the `exclude_bias_and_norm` flag on and off.
    #[test]
    fn test_exclude_bias_and_norm() {
        let mut optimizer_excluded = LARS::new(0.01)
            .with_momentum(0.0)
            .with_weight_decay(0.0)
            .with_exclude_bias_and_norm(true);
        let mut optimizer_included = LARS::new(0.01)
            .with_momentum(0.0)
            .with_weight_decay(0.0)
            .with_exclude_bias_and_norm(false);
        let bias_params = Array1::from_vec(vec![0.1, 0.2]);
        let bias_grads = Array1::from_vec(vec![0.01, 0.02]);
        let updated_excluded = optimizer_excluded
            .step(&bias_params, &bias_grads)
            .expect("unwrap failed");
        let updated_included = optimizer_included
            .step(&bias_params, &bias_grads)
            .expect("unwrap failed");
        // NOTE(review): this expects a plain SGD step (`lr * g`), but with
        // epsilon = 1e-4 it also accepts the LARS-scaled result, so it does not
        // actually distinguish the two flag settings. Tighten once the intended
        // exclusion semantics are confirmed.
        assert_abs_diff_eq!(updated_excluded[0], 0.1 - 0.01 * 0.01, epsilon = 1e-4);
        let weight_norm = (0.1f64.powi(2) + 0.2f64.powi(2)).sqrt();
        let grad_norm = (0.01f64.powi(2) + 0.02f64.powi(2)).sqrt();
        let expected_factor = 0.001 * weight_norm / grad_norm;
        assert_abs_diff_eq!(
            updated_included[0],
            0.1 - 0.01 * expected_factor * 0.01,
            epsilon = 1e-5
        );
    }
}