use crate::error::{NeuralError, Result};
use crate::optimizers::Optimizer;
use scirs2_core::ndarray::{Array, ScalarOperand};
use scirs2_core::numeric::{Float, NumAssign};
use std::fmt::Debug;
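/// Layer-wise Adaptive Rate Scaling (LARS) optimizer (You et al., 2017,
/// "Large Batch Training of Convolutional Networks"), with optional
/// LARC-style clipping of the layer-wise learning rate.
///
/// Each parameter array is treated as one layer: its update is scaled by the
/// trust ratio `trust_coefficient * ||w|| / (||g|| + weight_decay * ||w||)`,
/// which keeps the step size proportional to the weight norm and stabilizes
/// large-batch training.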
#[derive(Debug, Clone)]
pub struct Lars<F: Float + NumAssign + ScalarOperand + Debug> {
    /// Global learning rate; scales the LARS step and caps the layer-wise
    /// rate in LARC mode.
    learning_rate: F,
    /// Momentum factor in `[0, 1)`.
    momentum: F,
    /// L2 weight-decay factor added to the gradient before scaling.
    weight_decay: F,
    /// Trust coefficient scaling the layer-wise learning rate.
    trust_coefficient: F,
    /// If `true`, clip the layer-wise rate at the global rate (LARC).
    clip: bool,
    /// Small constant guarding against division by zero.
    epsilon: F,
    /// Per-layer momentum buffers, allocated lazily on the first update.
    velocity: Vec<Array<F, scirs2_core::ndarray::IxDyn>>,
    /// Number of completed update steps.
    t: usize,
}
impl<F: Float + NumAssign + ScalarOperand + Debug> Lars<F> {
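    /// Creates a LARS optimizer with the given hyperparameters.
    ///
    /// Set `clip` to `true` for the LARC variant, which caps the layer-wise
    /// learning rate at the global `learning_rate`.
    ///
    /// # Errors
    ///
    /// Returns [`NeuralError::InvalidArgument`] if `learning_rate` is not
    /// positive or `momentum` lies outside `[0, 1)`.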
pub fn new(
learning_rate: F,
momentum: F,
weight_decay: F,
trust_coefficient: F,
clip: bool,
) -> Result<Self> {
if learning_rate <= F::zero() {
return Err(NeuralError::InvalidArgument(
"learning_rate must be positive".to_string(),
));
}
if momentum < F::zero() || momentum >= F::one() {
return Err(NeuralError::InvalidArgument(
"momentum must be in [0, 1)".to_string(),
));
}
        let epsilon = Self::f_const(1e-8)?;
Ok(Self {
learning_rate,
momentum,
weight_decay,
trust_coefficient,
clip,
epsilon,
velocity: Vec::new(),
t: 0,
})
}
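    /// Converts an `f64` constant to `F`, with a descriptive error on failure.
    /// (Small helper introduced here to avoid repeating the conversion
    /// boilerplate in the constructors.)
    fn f_const(value: f64) -> Result<F> {
        F::from(value).ok_or_else(|| {
            NeuralError::InvalidArgument(format!(
                "Failed to convert {value} to the floating point type"
            ))
        })
    }

    /// Creates a LARS optimizer with commonly used defaults
    /// (momentum 0.9, weight decay 1e-4, trust coefficient 0.001)
    /// and no clipping.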
    pub fn default_with_lr(learning_rate: F) -> Result<Self> {
        let momentum = Self::f_const(0.9)?;
        let weight_decay = Self::f_const(1e-4)?;
        let trust_coefficient = Self::f_const(0.001)?;
        Self::new(learning_rate, momentum, weight_decay, trust_coefficient, false)
    }
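    /// Creates a LARC optimizer (LARS with the layer-wise learning rate
    /// clipped at the global learning rate), using the same defaults as
    /// [`Self::default_with_lr`].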
    pub fn larc_with_lr(learning_rate: F) -> Result<Self> {
        let mut optimizer = Self::default_with_lr(learning_rate)?;
        optimizer.clip = true;
        Ok(optimizer)
    }
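    /// Computes the layer-wise learning rate from a layer's parameter and
    /// gradient norms: the raw trust ratio for LARS, or that ratio clipped
    /// at the global learning rate for LARC. Falls back to the global
    /// learning rate when the denominator is (near) zero, e.g. for an
    /// all-zero gradient with zero weight decay.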
    pub fn local_lr(&self, param_norm: F, grad_norm: F) -> F {
        let denom = grad_norm + self.weight_decay * param_norm;
        // Degenerate layer (zero gradient and zero weight-decay term):
        // fall back to the global learning rate. Note the check is done
        // before adding epsilon, so the guard can actually trigger.
        if denom < self.epsilon {
            return self.learning_rate;
        }
        let local = self.trust_coefficient * param_norm / (denom + self.epsilon);
        if self.clip {
            // LARC: never exceed the global learning rate.
            local.min(self.learning_rate)
        } else {
            local
        }
    }
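    /// Returns `true` if layer-wise rates are clipped (LARC mode).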
pub fn is_larc(&self) -> bool {
self.clip
}
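    /// Returns the trust coefficient.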
pub fn get_trust_coefficient(&self) -> F {
self.trust_coefficient
}
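    /// Returns the momentum factor.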
pub fn get_momentum(&self) -> F {
self.momentum
}
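    /// Returns the weight-decay factor.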
pub fn get_weight_decay(&self) -> F {
self.weight_decay
}
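    /// Clears the momentum buffers and resets the step counter.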
pub fn reset_state(&mut self) {
self.velocity.clear();
self.t = 0;
}
}
impl<F: Float + NumAssign + ScalarOperand + Debug> Optimizer<F> for Lars<F> {
fn update(
&mut self,
params: &mut [Array<F, scirs2_core::ndarray::IxDyn>],
grads: &[Array<F, scirs2_core::ndarray::IxDyn>],
) -> Result<()> {
if params.len() != grads.len() {
return Err(NeuralError::TrainingError(format!(
"Number of parameter arrays ({}) does not match number of gradient arrays ({})",
params.len(),
grads.len()
)));
}
self.t += 1;
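        // Lazily (re)allocate one velocity buffer per parameter array.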
if self.velocity.len() != params.len() {
self.velocity = params
.iter()
.map(|p| Array::zeros(p.raw_dim()))
.collect();
}
for i in 0..params.len() {
let param_norm_sq: F = params[i].iter().fold(F::zero(), |acc, &x| acc + x * x);
let param_norm = param_norm_sq.sqrt();
let grad_norm_sq: F = grads[i].iter().fold(F::zero(), |acc, &x| acc + x * x);
let grad_norm = grad_norm_sq.sqrt();
let local_lr = self.local_lr(param_norm, grad_norm);
            // Add the L2 weight-decay term to the gradient, as in the LARS paper.
            let effective_grad = if self.weight_decay > F::zero() {
                &grads[i] + &(&params[i] * self.weight_decay)
            } else {
                grads[i].clone()
            };
            // LARC folds the global learning rate into the clipped layer-wise
            // rate; plain LARS scales the trust ratio by the global rate, so
            // that `set_learning_rate` (and lr schedules) take effect.
            let step_size = if self.clip {
                local_lr
            } else {
                local_lr * self.learning_rate
            };
            self.velocity[i] = &self.velocity[i] * self.momentum
                + &(&effective_grad * step_size);
            params[i] = &params[i] - &self.velocity[i];
}
Ok(())
}
fn get_learning_rate(&self) -> F {
self.learning_rate
}
fn set_learning_rate(&mut self, lr: F) {
self.learning_rate = lr;
}
fn reset(&mut self) {
self.reset_state();
}
fn name(&self) -> &'static str {
"LARS"
}
}
#[cfg(test)]
mod tests {
use super::*;
    use scirs2_core::ndarray::{Array1, IxDyn};
fn make_param(vals: &[f64]) -> Array<f64, IxDyn> {
Array1::from_vec(vals.to_vec()).into_dyn()
}
#[test]
fn test_lars_default_config() {
let lars = Lars::<f64>::default_with_lr(0.01).expect("should succeed");
assert!((lars.get_learning_rate() - 0.01).abs() < 1e-12);
assert!((lars.get_momentum() - 0.9).abs() < 1e-12);
assert!((lars.get_weight_decay() - 1e-4).abs() < 1e-12);
assert!(!lars.is_larc());
}
#[test]
fn test_lars_local_lr_formula() {
let lars = Lars::<f64>::new(0.1, 0.9, 0.0, 0.01, false).expect("should succeed");
let local = lars.local_lr(3.0, 4.0);
let expected = 0.01 * 3.0 / (4.0 + 1e-8);
assert!(
(local - expected).abs() < 1e-6,
"local_lr={local}, expected={expected}"
);
}
#[test]
fn test_larc_clips_lr() {
let global_lr = 0.01;
let larc = Lars::<f64>::new(global_lr, 0.9, 0.0, 1.0, true).expect("should succeed");
let local = larc.local_lr(100.0, 1.0);
assert!(
local <= global_lr + 1e-12,
"LARC local_lr={local} must not exceed global_lr={global_lr}"
);
}
#[test]
fn test_lars_update_descends() {
let mut lars = Lars::<f64>::new(0.01, 0.9, 0.0, 0.001, false).expect("should succeed");
let mut params = vec![make_param(&[2.0_f64, 2.0])];
let grads = vec![make_param(&[1.0_f64, 1.0])];
let before_dot: f64 = params[0].iter().zip(grads[0].iter()).map(|(p, g)| p * g).sum();
lars.update(&mut params, &grads).expect("update should succeed");
let after_dot: f64 = params[0].iter().zip(grads[0].iter()).map(|(p, g)| p * g).sum();
assert!(
after_dot < before_dot,
"Update should reduce paramĀ·grad: before={before_dot}, after={after_dot}"
);
}
#[test]
fn test_lars_zero_grad() {
let mut lars = Lars::<f64>::new(0.01, 0.9, 0.0, 0.001, false).expect("should succeed");
let initial = vec![3.0_f64, -1.0, 2.0];
let mut params = vec![make_param(&initial)];
let grads = vec![make_param(&[0.0_f64, 0.0, 0.0])];
lars.update(&mut params, &grads).expect("zero grad update should succeed");
for (p, &orig) in params[0].iter().zip(initial.iter()) {
assert!(
(*p - orig).abs() < 1e-12,
"With zero grad and no weight_decay, params should not change: {p} vs {orig}"
);
}
}
#[test]
fn test_lars_mismatched_lengths() {
let mut lars = Lars::<f64>::default_with_lr(0.01).expect("should succeed");
let mut params = vec![make_param(&[1.0_f64, 2.0])];
let grads = vec![
make_param(&[0.1_f64, 0.2]),
make_param(&[0.3_f64, 0.4]),
];
assert!(
lars.update(&mut params, &grads).is_err(),
"Mismatched param/grad counts should error"
);
}
#[test]
fn test_larc_default_creation() {
let larc = Lars::<f64>::larc_with_lr(0.05).expect("should succeed");
assert!(larc.is_larc());
assert!((larc.get_learning_rate() - 0.05).abs() < 1e-12);
}
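    // Added regression check: with clip = false the step is
    // `learning_rate * local_lr`, so doubling the global learning rate
    // should double the first step (momentum and weight decay disabled).
    #[test]
    fn test_lars_step_scales_with_global_lr() {
        let first_step = |lr: f64| -> f64 {
            let mut lars = Lars::<f64>::new(lr, 0.0, 0.0, 0.001, false).expect("should succeed");
            let mut params = vec![make_param(&[2.0_f64, 2.0])];
            let grads = vec![make_param(&[1.0_f64, 1.0])];
            lars.update(&mut params, &grads).expect("update should succeed");
            2.0 - params[0].iter().next().copied().expect("non-empty param")
        };
        let small = first_step(0.01);
        let large = first_step(0.02);
        assert!(
            (large - 2.0 * small).abs() < 1e-12,
            "LARS step should scale linearly with the global lr: small={small}, large={large}"
        );
    }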
}