//! `yscv_optim/lars.rs` — Layer-wise Adaptive Rate Scaling (LARS) optimizer.

use std::collections::HashMap;
use std::collections::hash_map::Entry;

use yscv_autograd::{Graph, NodeId};
use yscv_tensor::Tensor;

use super::validate::{validate_lr, validate_momentum};
use super::{LearningRate, OptimError};

10/// Layer-wise Adaptive Rate Scaling (LARS) optimizer.
11///
12/// Scales the learning rate per layer using the ratio of parameter norm to
13/// gradient norm, enabling stable training with very large batch sizes.
14#[derive(Debug, Clone)]
15pub struct Lars {
16    base_lr: f32,
17    momentum: f32,
18    weight_decay: f32,
19    trust_coefficient: f32,
20    velocity: HashMap<u64, Tensor>,
21}
22
23impl Lars {
24    /// Creates LARS with required base learning rate.
25    pub fn new(base_lr: f32) -> Result<Self, OptimError> {
26        validate_lr(base_lr)?;
27        Ok(Self {
28            base_lr,
29            momentum: 0.0,
30            weight_decay: 0.0,
31            trust_coefficient: 0.001,
32            velocity: HashMap::new(),
33        })
34    }
35
36    /// Sets momentum factor in `[0, 1)`.
37    pub fn with_momentum(mut self, momentum: f32) -> Result<Self, OptimError> {
38        validate_momentum(momentum)?;
39        self.momentum = momentum;
40        Ok(self)
41    }
42
43    /// Sets L2 weight decay factor in `[0, +inf)`.
44    pub fn with_weight_decay(mut self, weight_decay: f32) -> Result<Self, OptimError> {
45        if !weight_decay.is_finite() || weight_decay < 0.0 {
46            return Err(OptimError::InvalidWeightDecay { weight_decay });
47        }
48        self.weight_decay = weight_decay;
49        Ok(self)
50    }
51
52    /// Sets trust coefficient for the local learning rate scaling.
53    pub fn with_trust_coefficient(mut self, trust_coefficient: f32) -> Result<Self, OptimError> {
54        if !trust_coefficient.is_finite() || trust_coefficient <= 0.0 {
55            return Err(OptimError::InvalidEpsilon {
56                epsilon: trust_coefficient,
57            });
58        }
59        self.trust_coefficient = trust_coefficient;
60        Ok(self)
61    }
62
63    /// Drops optimizer state (for example when restarting training).
64    pub fn clear_state(&mut self) {
65        self.velocity.clear();
66    }
67
68    /// Returns current learning rate.
69    pub fn learning_rate(&self) -> f32 {
70        self.base_lr
71    }
72
73    /// Overrides current learning rate.
74    pub fn set_learning_rate(&mut self, lr: f32) -> Result<(), OptimError> {
75        validate_lr(lr)?;
76        self.base_lr = lr;
77        Ok(())
78    }
79
80    /// Applies one update to raw tensor weights.
81    pub fn step(
82        &mut self,
83        parameter_id: u64,
84        weights: &mut Tensor,
85        grad: &Tensor,
86    ) -> Result<(), OptimError> {
87        if weights.shape() != grad.shape() {
88            return Err(OptimError::ShapeMismatch {
89                weights: weights.shape().to_vec(),
90                grad: grad.shape().to_vec(),
91            });
92        }
93
94        // Compute weight norm and gradient norm.
95        let w_data = weights.data();
96        let g_data = grad.data();
97
98        let w_norm = w_data.iter().map(|x| x * x).sum::<f32>().sqrt();
99        let g_norm = g_data.iter().map(|x| x * x).sum::<f32>().sqrt();
100
101        // Compute local learning rate.
102        let local_lr = if w_norm > 0.0 && g_norm > 0.0 {
103            self.trust_coefficient * w_norm / (g_norm + self.weight_decay * w_norm)
104        } else {
105            1.0
106        };
107
108        // Compute gradient with weight decay: g_with_wd = g + weight_decay * w
109        let mut g_with_wd = g_data.to_vec();
110        if self.weight_decay != 0.0 {
111            for (gv, wv) in g_with_wd.iter_mut().zip(w_data.iter()) {
112                *gv += self.weight_decay * *wv;
113            }
114        }
115
116        let effective_lr = local_lr * self.base_lr;
117
118        // Update velocity and weights.
119        let velocity = match self.velocity.entry(parameter_id) {
120            Entry::Occupied(entry) => entry.into_mut(),
121            Entry::Vacant(entry) => entry.insert(Tensor::zeros(weights.shape().to_vec())?),
122        };
123
124        if velocity.shape() != weights.shape() {
125            *velocity = Tensor::zeros(weights.shape().to_vec())?;
126        }
127
128        let v_data = velocity.data_mut();
129        let weights_data = weights.data_mut();
130
131        for i in 0..weights_data.len() {
132            v_data[i] = self.momentum * v_data[i] + effective_lr * g_with_wd[i];
133            weights_data[i] -= v_data[i];
134        }
135
136        Ok(())
137    }
138
139    /// Applies one update to a trainable graph node by its `NodeId`.
140    pub fn step_graph_node(&mut self, graph: &mut Graph, node: NodeId) -> Result<(), OptimError> {
141        if !graph.requires_grad(node)? {
142            return Ok(());
143        }
144
145        let grad = match graph.grad(node)? {
146            Some(grad) => grad.clone(),
147            None => return Err(OptimError::MissingGradient { node: node.0 }),
148        };
149        let weights = graph.value_mut(node)?;
150        self.step(node.0 as u64, weights, &grad)
151    }
152}
153
154impl LearningRate for Lars {
155    fn learning_rate(&self) -> f32 {
156        Lars::learning_rate(self)
157    }
158
159    fn set_learning_rate(&mut self, lr: f32) -> Result<(), OptimError> {
160        Lars::set_learning_rate(self, lr)
161    }
162}