candle_nn/optim.rs

//! Various optimization algorithms.
use candle::{Result, Tensor, Var};

/// The interface optimizers should implement.
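///
/// A minimal sketch of a single training step (assuming a CPU device and the [`SGD`]
/// optimizer defined below; any other implementor works the same way):
///
/// ```no_run
/// use candle::{Device, Var};
/// use candle_nn::{Optimizer, SGD};
///
/// # fn main() -> candle::Result<()> {
/// // A single trainable scalar parameter, initialized to 3.
/// let w = Var::new(3f32, &Device::Cpu)?;
/// let mut opt = SGD::new(vec![w.clone()], 0.1)?;
/// // Minimize (w - 1)^2: compute the gradients, then apply one update.
/// let loss = (w.as_tensor() - 1f64)?.sqr()?;
/// opt.backward_step(&loss)?;
/// # Ok(())
/// # }
/// ```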
pub trait Optimizer: Sized {
    type Config: Sized;

    fn new(vars: Vec<Var>, config: Self::Config) -> Result<Self>;

    fn step(&mut self, grads: &candle::backprop::GradStore) -> Result<()>;

    fn learning_rate(&self) -> f64;

    fn set_learning_rate(&mut self, lr: f64);

    /// Creates an optimizer that does not track any variables yet.
    fn empty(config: Self::Config) -> Result<Self> {
        Self::new(vec![], config)
    }

    /// Computes the gradients of `loss` and performs a single optimization step.
    fn backward_step(&mut self, loss: &Tensor) -> Result<()> {
        let grads = loss.backward()?;
        self.step(&grads)
    }

    /// Builds an optimizer from a slice of variable references.
    fn from_slice(vars: &[&Var], config: Self::Config) -> Result<Self> {
        let vars: Vec<_> = vars.iter().map(|&v| v.clone()).collect();
        Self::new(vars, config)
    }
}

/// Optimizer for Stochastic Gradient Descent.
///
/// Unlike the PyTorch implementation of SGD, this version does not support momentum.
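///
/// Each call to [`Optimizer::step`] applies the plain update
/// `theta <- theta - learning_rate * grad` to every tracked floating-point variable for
/// which a gradient is available.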
#[derive(Debug)]
pub struct SGD {
    vars: Vec<Var>,
    learning_rate: f64,
}

impl Optimizer for SGD {
    type Config = f64;

    fn new(vars: Vec<Var>, learning_rate: f64) -> Result<Self> {
        // Only floating-point variables can receive gradient updates.
        let vars = vars
            .into_iter()
            .filter(|var| var.dtype().is_float())
            .collect();
        Ok(Self {
            vars,
            learning_rate,
        })
    }

    fn learning_rate(&self) -> f64 {
        self.learning_rate
    }

    fn step(&mut self, grads: &candle::backprop::GradStore) -> Result<()> {
        for var in self.vars.iter() {
            if let Some(grad) = grads.get(var) {
                var.set(&var.sub(&(grad * self.learning_rate)?)?)?;
            }
        }
        Ok(())
    }

    fn set_learning_rate(&mut self, lr: f64) {
        self.learning_rate = lr
    }
}

impl SGD {
    /// Consumes the optimizer and returns the variables it was tracking.
    pub fn into_inner(self) -> Vec<Var> {
        self.vars
    }

    /// Registers an additional variable to be optimized.
    pub fn push(&mut self, var: &Var) {
        self.vars.push(var.clone())
    }
}
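
// A minimal sketch of registering variables after construction (e.g. when parameters are
// created lazily); `w` stands for some previously created `Var`:
//
//     let mut sgd = SGD::empty(0.01)?;
//     sgd.push(&w);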

#[derive(Clone, Debug)]
pub struct ParamsAdamW {
    /// Learning rate.
    pub lr: f64,
    /// Exponential decay rate for the first moment estimate.
    pub beta1: f64,
    /// Exponential decay rate for the second moment estimate.
    pub beta2: f64,
    /// Small constant added to the denominator for numerical stability.
    pub eps: f64,
    /// Decoupled weight decay coefficient.
    pub weight_decay: f64,
}

impl Default for ParamsAdamW {
    fn default() -> Self {
        Self {
            lr: 0.001,
            beta1: 0.9,
            beta2: 0.999,
            eps: 1e-8,
            weight_decay: 0.01,
        }
    }
}

/// Per-variable AdamW state: the parameter itself plus its first and second moment
/// estimates.
#[derive(Debug)]
struct VarAdamW {
    var: Var,
    first_moment: Var,
    second_moment: Var,
}

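/// AdamW optimizer: Adam with decoupled weight decay, following Loshchilov & Hutter,
/// "Decoupled Weight Decay Regularization", <https://arxiv.org/abs/1711.05101>.
///
/// A minimal construction sketch (assuming a CPU device and a single scalar parameter):
///
/// ```no_run
/// use candle::{Device, Var};
/// use candle_nn::{AdamW, Optimizer, ParamsAdamW};
///
/// # fn main() -> candle::Result<()> {
/// let w = Var::new(0f32, &Device::Cpu)?;
/// let params = ParamsAdamW { lr: 1e-3, ..Default::default() };
/// let _opt = AdamW::new(vec![w.clone()], params)?;
/// // `AdamW::new_lr(vars, lr)` is a shorthand that only overrides the learning rate.
/// # Ok(())
/// # }
/// ```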
#[derive(Debug)]
pub struct AdamW {
    vars: Vec<VarAdamW>,
    step_t: usize,
    params: ParamsAdamW,
}

impl Optimizer for AdamW {
    type Config = ParamsAdamW;

    fn new(vars: Vec<Var>, params: ParamsAdamW) -> Result<Self> {
        let vars = vars
            .into_iter()
            .filter(|var| var.dtype().is_float())
            .map(|var| {
                let dtype = var.dtype();
                let shape = var.shape();
                let device = var.device();
                let first_moment = Var::zeros(shape, dtype, device)?;
                let second_moment = Var::zeros(shape, dtype, device)?;
                Ok(VarAdamW {
                    var,
                    first_moment,
                    second_moment,
                })
            })
            .collect::<Result<Vec<_>>>()?;
        Ok(Self {
            vars,
            params,
            step_t: 0,
        })
    }

    fn learning_rate(&self) -> f64 {
        self.params.lr
    }

    fn set_learning_rate(&mut self, lr: f64) {
        self.params.lr = lr
    }

    fn step(&mut self, grads: &candle::backprop::GradStore) -> Result<()> {
        self.step_t += 1;
        let lr = self.params.lr;
        let lambda = self.params.weight_decay;
        let lr_lambda = lr * lambda;
        let beta1 = self.params.beta1;
        let beta2 = self.params.beta2;
        let scale_m = 1f64 / (1f64 - beta1.powi(self.step_t as i32));
        let scale_v = 1f64 / (1f64 - beta2.powi(self.step_t as i32));
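        // With gradient g for parameter theta, the loop below computes (the standard
        // AdamW update with decoupled weight decay):
        //   m <- beta1 * m + (1 - beta1) * g
        //   v <- beta2 * v + (1 - beta2) * g^2
        //   m_hat = m * scale_m,  v_hat = v * scale_v      (bias correction)
        //   theta <- theta * (1 - lr * lambda) - lr * m_hat / (sqrt(v_hat) + eps)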
        for var in self.vars.iter() {
            let theta = &var.var;
            let m = &var.first_moment;
            let v = &var.second_moment;
            if let Some(g) = grads.get(theta) {
                // This involves locking 3 RwLocks per parameter. If the parameters are
                // large this should not be an issue, but it may be problematic for models
                // with lots of small parameters.
                let next_m = ((m.as_tensor() * beta1)? + (g * (1.0 - beta1))?)?;
                let next_v = ((v.as_tensor() * beta2)? + (g.sqr()? * (1.0 - beta2))?)?;
                let m_hat = (&next_m * scale_m)?;
                let v_hat = (&next_v * scale_v)?;
                let next_theta = (theta.as_tensor() * (1f64 - lr_lambda))?;
                let adjusted_grad = (m_hat / (v_hat.sqrt()? + self.params.eps)?)?;
                let next_theta = (next_theta - (adjusted_grad * lr)?)?;
                m.set(&next_m)?;
                v.set(&next_v)?;
                theta.set(&next_theta)?;
            }
        }
        Ok(())
    }
}

impl AdamW {
    /// Creates an AdamW optimizer using the default parameters with the provided
    /// learning rate.
    pub fn new_lr(vars: Vec<Var>, learning_rate: f64) -> Result<Self> {
        let params = ParamsAdamW {
            lr: learning_rate,
            ..ParamsAdamW::default()
        };
        Self::new(vars, params)
    }

    /// Returns the current optimizer parameters.
    pub fn params(&self) -> &ParamsAdamW {
        &self.params
    }

    /// Replaces the optimizer parameters.
    pub fn set_params(&mut self, params: ParamsAdamW) {
        self.params = params;
    }
}