use std::collections::HashMap;
use std::collections::hash_map::Entry;
use yscv_autograd::{Graph, NodeId};
use yscv_tensor::Tensor;
use super::validate::{validate_lr, validate_momentum};
use super::{LearningRate, OptimError};
#[derive(Debug, Clone)]
pub struct Lars {
base_lr: f32,
momentum: f32,
weight_decay: f32,
trust_coefficient: f32,
velocity: HashMap<u64, Tensor>,
}
impl Lars {
pub fn new(base_lr: f32) -> Result<Self, OptimError> {
validate_lr(base_lr)?;
Ok(Self {
base_lr,
momentum: 0.0,
weight_decay: 0.0,
trust_coefficient: 0.001,
velocity: HashMap::new(),
})
}
pub fn with_momentum(mut self, momentum: f32) -> Result<Self, OptimError> {
validate_momentum(momentum)?;
self.momentum = momentum;
Ok(self)
}
pub fn with_weight_decay(mut self, weight_decay: f32) -> Result<Self, OptimError> {
if !weight_decay.is_finite() || weight_decay < 0.0 {
return Err(OptimError::InvalidWeightDecay { weight_decay });
}
self.weight_decay = weight_decay;
Ok(self)
}
pub fn with_trust_coefficient(mut self, trust_coefficient: f32) -> Result<Self, OptimError> {
if !trust_coefficient.is_finite() || trust_coefficient <= 0.0 {
return Err(OptimError::InvalidEpsilon {
epsilon: trust_coefficient,
});
}
self.trust_coefficient = trust_coefficient;
Ok(self)
}
pub fn clear_state(&mut self) {
self.velocity.clear();
}
pub fn learning_rate(&self) -> f32 {
self.base_lr
}
pub fn set_learning_rate(&mut self, lr: f32) -> Result<(), OptimError> {
validate_lr(lr)?;
self.base_lr = lr;
Ok(())
}
pub fn step(
&mut self,
parameter_id: u64,
weights: &mut Tensor,
grad: &Tensor,
) -> Result<(), OptimError> {
if weights.shape() != grad.shape() {
return Err(OptimError::ShapeMismatch {
weights: weights.shape().to_vec(),
grad: grad.shape().to_vec(),
});
}
let w_data = weights.data();
let g_data = grad.data();
let w_norm = w_data.iter().map(|x| x * x).sum::<f32>().sqrt();
let g_norm = g_data.iter().map(|x| x * x).sum::<f32>().sqrt();
let local_lr = if w_norm > 0.0 && g_norm > 0.0 {
self.trust_coefficient * w_norm / (g_norm + self.weight_decay * w_norm)
} else {
1.0
};
let mut g_with_wd = g_data.to_vec();
if self.weight_decay != 0.0 {
for (gv, wv) in g_with_wd.iter_mut().zip(w_data.iter()) {
*gv += self.weight_decay * *wv;
}
}
let effective_lr = local_lr * self.base_lr;
let velocity = match self.velocity.entry(parameter_id) {
Entry::Occupied(entry) => entry.into_mut(),
Entry::Vacant(entry) => entry.insert(Tensor::zeros(weights.shape().to_vec())?),
};
if velocity.shape() != weights.shape() {
*velocity = Tensor::zeros(weights.shape().to_vec())?;
}
let v_data = velocity.data_mut();
let weights_data = weights.data_mut();
for i in 0..weights_data.len() {
v_data[i] = self.momentum * v_data[i] + effective_lr * g_with_wd[i];
weights_data[i] -= v_data[i];
}
Ok(())
}
pub fn step_graph_node(&mut self, graph: &mut Graph, node: NodeId) -> Result<(), OptimError> {
if !graph.requires_grad(node)? {
return Ok(());
}
let grad = match graph.grad(node)? {
Some(grad) => grad.clone(),
None => return Err(OptimError::MissingGradient { node: node.0 }),
};
let weights = graph.value_mut(node)?;
self.step(node.0 as u64, weights, &grad)
}
}
impl LearningRate for Lars {
fn learning_rate(&self) -> f32 {
Lars::learning_rate(self)
}
fn set_learning_rate(&mut self, lr: f32) -> Result<(), OptimError> {
Lars::set_learning_rate(self, lr)
}
}