// candle_optimisers/adamax.rs

use candle_core::{Result, Var};
use candle_nn::optim::Optimizer;

use crate::{Decay, OptimParams};

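/// AdaMax optimiser: a variant of Adam based on the infinity norm, introduced
/// in Kingma & Ba, "Adam: A Method for Stochastic Optimization" (2014),
/// <https://arxiv.org/abs/1412.6980>.
///
/// For a gradient `g` at step `t`, [`Optimizer::step`] performs
///
/// ```text
/// m_t     = beta_1 * m_{t-1} + (1 - beta_1) * g
/// u_t     = max(beta_2 * u_{t-1}, |g| + eps)
/// theta_t = theta_{t-1} - lr / (1 - beta_1^t) * m_t / u_t
/// ```
///
/// Minimal usage sketch (assuming `w` and `b` are [`Var`]s and `loss` is a
/// scalar `Tensor` computed from them):
///
/// ```ignore
/// let params = ParamsAdaMax { lr: 0.004, ..Default::default() };
/// let mut optim = Adamax::new(vec![w.clone(), b.clone()], params)?;
/// let grads = loss.backward()?;
/// optim.step(&grads)?;
/// ```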
#[derive(Debug)]
pub struct Adamax {
    vars: Vec<VarAdaMax>,
    params: ParamsAdaMax,
    t: f64,
}

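/// Per-variable optimiser state: the parameter being optimised (`theta`),
/// the first moment estimate (`m`) and the exponentially weighted infinity
/// norm (`u`).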
#[derive(Debug)]
struct VarAdaMax {
    theta: Var,
    m: Var,
    u: Var,
}

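/// Hyper-parameters for the AdaMax optimiser.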
#[derive(Clone, Debug, PartialEq, PartialOrd)]
pub struct ParamsAdaMax {
    /// Learning rate
    pub lr: f64,
    /// Coefficient for the first moment estimate
    pub beta_1: f64,
    /// Coefficient for the exponentially weighted infinity norm
    pub beta_2: f64,
    /// Optional weight decay (classic or decoupled)
    pub weight_decay: Option<Decay>,
    /// Term added to the gradient magnitude to avoid division by zero
    pub eps: f64,
}

impl Default for ParamsAdaMax {
    fn default() -> Self {
        Self {
            lr: 1.0,
            beta_1: 0.9,
            beta_2: 0.999,
            weight_decay: None,
            eps: 1e-8,
        }
    }
}

impl Optimizer for Adamax {
    type Config = ParamsAdaMax;

    fn new(vars: Vec<Var>, params: ParamsAdaMax) -> Result<Self> {
        let vars = vars
            .into_iter()
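            // Only floating-point variables are tracked; per-variable state is
            // allocated as zero tensors of matching shape, dtype and device.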
            .filter(|var| var.dtype().is_float())
            .map(|var| {
                let dtype = var.dtype();
                let shape = var.shape();
                let device = var.device();
                let m = Var::zeros(shape, dtype, device)?;
                let u = Var::zeros(shape, dtype, device)?;
                Ok(VarAdaMax { theta: var, m, u })
            })
            .collect::<Result<Vec<VarAdaMax>>>()?;
        Ok(Self {
            vars,
            params,
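            // The step counter starts at 1 so that the bias correction
            // 1 - beta_1^t is non-zero on the first step.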
            t: 1.,
        })
    }

    fn learning_rate(&self) -> f64 {
        self.params.lr
    }

    fn step(&mut self, grads: &candle_core::backprop::GradStore) -> Result<()> {
        if let Some(decay) = self.params.weight_decay {
            match decay {
                Decay::WeightDecay(decay) => {
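                    // Classic L2 weight decay: the decay term is added to the
                    // gradient before the AdaMax update.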
                    for var in &self.vars {
                        let theta = &var.theta;
                        let m = &var.m;
                        let u = &var.u;
                        if let Some(grad) = grads.get(theta) {
                            let grad = &(grad + (decay * theta.as_tensor())?)?;
                            let m_next = ((self.params.beta_1 * m.as_tensor())?
                                + (1. - self.params.beta_1) * grad)?;
                            let u_next = (self.params.beta_2 * u.as_tensor())?
                                .maximum(&(grad.abs()? + self.params.eps)?)?;
                            let delta = (&m_next * self.params.lr)?
                                .div(&(&u_next * (1. - self.params.beta_1.powf(self.t)))?)?;
                            theta.set(&theta.sub(&delta)?)?;
                            m.set(&m_next)?;
                            u.set(&u_next)?;
                        }
                    }
                }
                Decay::DecoupledWeightDecay(decay) => {
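                    // Decoupled weight decay (as in AdamW): parameters are shrunk
                    // directly, independently of the gradient-based update.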
                    for var in &self.vars {
                        let theta = &var.theta;
                        let m = &var.m;
                        let u = &var.u;
                        if let Some(grad) = grads.get(theta) {
                            theta
                                .set(&(theta.as_tensor() * self.params.lr.mul_add(-decay, 1.))?)?;
                            let m_next = ((self.params.beta_1 * m.as_tensor())?
                                + (1. - self.params.beta_1) * grad)?;
                            let u_next = (self.params.beta_2 * u.as_tensor())?
                                .maximum(&(grad.abs()? + self.params.eps)?)?;
                            let delta = (&m_next * self.params.lr)?
                                .div(&(&u_next * (1. - self.params.beta_1.powf(self.t)))?)?;
                            theta.set(&theta.sub(&delta)?)?;
                            m.set(&m_next)?;
                            u.set(&u_next)?;
                        }
                    }
                }
            }
        } else {
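            // No weight decay: plain AdaMax update on the raw gradient.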
            for var in &self.vars {
                let theta = &var.theta;
                let m = &var.m;
                let u = &var.u;
                if let Some(grad) = grads.get(theta) {
                    let m_next =
                        ((self.params.beta_1 * m.as_tensor())? + (1. - self.params.beta_1) * grad)?;
                    let u_next = (self.params.beta_2 * u.as_tensor())?
                        .maximum(&(grad.abs()? + self.params.eps)?)?;
                    let delta = (&m_next * self.params.lr)?
                        .div(&(&u_next * (1. - self.params.beta_1.powf(self.t)))?)?;
                    theta.set(&theta.sub(&delta)?)?;
                    m.set(&m_next)?;
                    u.set(&u_next)?;
                }
            }
        }
        self.t += 1.;
        Ok(())
    }

    fn set_learning_rate(&mut self, lr: f64) {
        self.params.lr = lr;
    }
}

impl OptimParams for Adamax {
    fn params(&self) -> &Self::Config {
        &self.params
    }

    fn set_params(&mut self, config: Self::Config) {
        self.params = config;
    }
}

impl Adamax {
    /// Consumes the optimiser and returns the optimised variables.
    #[must_use]
    pub fn into_inner(self) -> Vec<Var> {
        self.vars.into_iter().map(|v| v.theta).collect()
    }
}

#[cfg(test)]
mod tests {
    use anyhow::Result;
    use assert_approx_eq::assert_approx_eq;
    use candle_core::{Device, Var};
    use candle_nn::Optimizer;

    use super::*;

    #[test]
    fn lr_test() -> Result<()> {
        let params = ParamsAdaMax {
            lr: 0.004,
            ..Default::default()
        };
        let w = Var::new(&[[0f32, 0.]], &Device::Cpu)?;
        let b = Var::new(0f32, &Device::Cpu)?;
        let mut optim = Adamax::new(vec![w.clone(), b.clone()], params)?;
        assert_approx_eq!(0.004, optim.learning_rate());
        optim.set_learning_rate(0.002);
        assert_approx_eq!(0.002, optim.learning_rate());
        Ok(())
    }

    #[test]
    fn into_inner_test() -> Result<()> {
        let params = ParamsAdaMax::default();
        let w = Var::new(&[[3f32, 1.]], &Device::Cpu)?;
        let b = Var::new(-2f32, &Device::Cpu)?;
        let optim = Adamax::new(vec![w.clone(), b.clone()], params)?;
        let inner = optim.into_inner();
        assert_eq!(inner[0].as_tensor().to_vec2::<f32>()?, &[[3f32, 1.]]);
        assert_approx_eq!(inner[1].as_tensor().to_vec0::<f32>()?, -2_f32);
        Ok(())
    }

    #[test]
    fn params_test() -> Result<()> {
        let params = ParamsAdaMax {
            lr: 0.004,
            ..Default::default()
        };
        let w = Var::new(&[[0f32, 0.]], &Device::Cpu)?;
        let b = Var::new(0f32, &Device::Cpu)?;
        let mut optim = Adamax::new(vec![w.clone(), b.clone()], params.clone())?;
        assert_eq!(params, optim.params().clone());
        let new_params = ParamsAdaMax {
            lr: 0.002,
            ..Default::default()
        };
        optim.set_params(new_params.clone());
        assert_eq!(new_params, optim.params().clone());
        Ok(())
    }
}