use candle::{Result, Tensor, Var};

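/// Trait shared by the optimizers in this module. `Config` carries the
/// hyper-parameters, `step` applies a single update given a gradient store, and the
/// provided `backward_step` helper runs backpropagation on a loss tensor and then
/// steps on the resulting gradients. `empty` and `from_slice` are convenience
/// constructors.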
pub trait Optimizer: Sized {
    type Config: Sized;

    fn new(vars: Vec<Var>, config: Self::Config) -> Result<Self>;

    fn step(&mut self, grads: &candle::backprop::GradStore) -> Result<()>;

    fn learning_rate(&self) -> f64;

    fn set_learning_rate(&mut self, lr: f64);

    fn empty(config: Self::Config) -> Result<Self> {
        Self::new(vec![], config)
    }

    fn backward_step(&mut self, loss: &Tensor) -> Result<()> {
        let grads = loss.backward()?;
        self.step(&grads)
    }

    fn from_slice(vars: &[&Var], config: Self::Config) -> Result<Self> {
        let vars: Vec<_> = vars.iter().map(|&v| v.clone()).collect();
        Self::new(vars, config)
    }
}

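/// Stochastic gradient descent: `step` moves every tracked float variable against its
/// gradient, scaled by the learning rate. Non-float variables passed to `new` are
/// silently dropped, and no momentum state is kept.
///
/// Illustrative sketch (not compiled here); it assumes `candle`'s `Var::new` and
/// `Device::Cpu` constructors and a scalar loss:
///
/// ```ignore
/// use candle::{Device, Var};
/// let w = Var::new(3f32, &Device::Cpu)?;
/// let mut opt = SGD::new(vec![w.clone()], 0.1)?;
/// let loss = w.as_tensor().sqr()?; // minimize w^2
/// opt.backward_step(&loss)?;       // w becomes 3.0 - 0.1 * (2.0 * 3.0) = 2.4
/// ```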
#[derive(Debug)]
pub struct SGD {
    vars: Vec<Var>,
    learning_rate: f64,
}

impl Optimizer for SGD {
    type Config = f64;

    fn new(vars: Vec<Var>, learning_rate: f64) -> Result<Self> {
        let vars = vars
            .into_iter()
            .filter(|var| var.dtype().is_float())
            .collect();
        Ok(Self {
            vars,
            learning_rate,
        })
    }

    fn learning_rate(&self) -> f64 {
        self.learning_rate
    }

    fn step(&mut self, grads: &candle::backprop::GradStore) -> Result<()> {
        for var in self.vars.iter() {
            if let Some(grad) = grads.get(var) {
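                // Vanilla update with no momentum: theta <- theta - learning_rate * grad.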
                var.set(&var.sub(&(grad * self.learning_rate)?)?)?;
            }
        }
        Ok(())
    }

    fn set_learning_rate(&mut self, lr: f64) {
        self.learning_rate = lr
    }
}

impl SGD {
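    /// Consumes the optimizer and returns the variables it was tracking.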
    pub fn into_inner(self) -> Vec<Var> {
        self.vars
    }

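    /// Registers an additional variable to be optimized by subsequent calls to `step`.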
    pub fn push(&mut self, var: &Var) {
        self.vars.push(var.clone())
    }
}

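/// Hyper-parameters for [`AdamW`]: learning rate `lr`, exponential decay rates `beta1`
/// and `beta2` for the first and second moment estimates, the numerical-stability term
/// `eps`, and the decoupled `weight_decay` coefficient. `Default` matches the common
/// AdamW settings (lr 1e-3, betas 0.9/0.999, eps 1e-8, weight decay 0.01).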
#[derive(Clone, Debug)]
pub struct ParamsAdamW {
    pub lr: f64,
    pub beta1: f64,
    pub beta2: f64,
    pub eps: f64,
    pub weight_decay: f64,
}

impl Default for ParamsAdamW {
    fn default() -> Self {
        Self {
            lr: 0.001,
            beta1: 0.9,
            beta2: 0.999,
            eps: 1e-8,
            weight_decay: 0.01,
        }
    }
}

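/// Per-variable optimizer state: the variable itself together with the running first
/// and second moment estimates of its gradient.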
#[derive(Debug)]
struct VarAdamW {
    var: Var,
    first_moment: Var,
    second_moment: Var,
}

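/// Adam with decoupled weight decay (AdamW). `step_t` counts the number of updates so
/// far and drives the bias correction. For a parameter `theta` with gradient `g`, one
/// step computes:
///
/// ```text
/// m     <- beta1 * m + (1 - beta1) * g
/// v     <- beta2 * v + (1 - beta2) * g^2
/// m_hat  = m / (1 - beta1^t)
/// v_hat  = v / (1 - beta2^t)
/// theta <- theta * (1 - lr * weight_decay) - lr * m_hat / (sqrt(v_hat) + eps)
/// ```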
#[derive(Debug)]
pub struct AdamW {
    vars: Vec<VarAdamW>,
    step_t: usize,
    params: ParamsAdamW,
}

impl Optimizer for AdamW {
    type Config = ParamsAdamW;

    fn new(vars: Vec<Var>, params: ParamsAdamW) -> Result<Self> {
        let vars = vars
            .into_iter()
            .filter(|var| var.dtype().is_float())
            .map(|var| {
                let dtype = var.dtype();
                let shape = var.shape();
                let device = var.device();
                let first_moment = Var::zeros(shape, dtype, device)?;
                let second_moment = Var::zeros(shape, dtype, device)?;
                Ok(VarAdamW {
                    var,
                    first_moment,
                    second_moment,
                })
            })
            .collect::<Result<Vec<_>>>()?;
        Ok(Self {
            vars,
            params,
            step_t: 0,
        })
    }

    fn learning_rate(&self) -> f64 {
        self.params.lr
    }

    fn set_learning_rate(&mut self, lr: f64) {
        self.params.lr = lr
    }

    fn step(&mut self, grads: &candle::backprop::GradStore) -> Result<()> {
        self.step_t += 1;
        let lr = self.params.lr;
        let lambda = self.params.weight_decay;
        let lr_lambda = lr * lambda;
        let beta1 = self.params.beta1;
        let beta2 = self.params.beta2;
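        // Bias-correction factors: the moment estimates are divided by (1 - beta^t) so
        // that early steps are not biased towards zero.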
        let scale_m = 1f64 / (1f64 - beta1.powi(self.step_t as i32));
        let scale_v = 1f64 / (1f64 - beta2.powi(self.step_t as i32));
        for var in self.vars.iter() {
            let theta = &var.var;
            let m = &var.first_moment;
            let v = &var.second_moment;
            if let Some(g) = grads.get(theta) {
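                // Update the exponential moving averages of the gradient and of its
                // square, apply bias correction, shrink theta by the decoupled weight
                // decay term, then subtract the scaled Adam update.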
                let next_m = ((m.as_tensor() * beta1)? + (g * (1.0 - beta1))?)?;
                let next_v = ((v.as_tensor() * beta2)? + (g.sqr()? * (1.0 - beta2))?)?;
                let m_hat = (&next_m * scale_m)?;
                let v_hat = (&next_v * scale_v)?;
                let next_theta = (theta.as_tensor() * (1f64 - lr_lambda))?;
                let adjusted_grad = (m_hat / (v_hat.sqrt()? + self.params.eps)?)?;
                let next_theta = (next_theta - (adjusted_grad * lr)?)?;
                m.set(&next_m)?;
                v.set(&next_v)?;
                theta.set(&next_theta)?;
            }
        }
        Ok(())
    }
}

impl AdamW {
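    /// Builds an AdamW optimizer using the default parameters with the learning rate
    /// overridden by `learning_rate`.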
    pub fn new_lr(vars: Vec<Var>, learning_rate: f64) -> Result<Self> {
        let params = ParamsAdamW {
            lr: learning_rate,
            ..ParamsAdamW::default()
        };
        Self::new(vars, params)
    }

    pub fn params(&self) -> &ParamsAdamW {
        &self.params
    }

    pub fn set_params(&mut self, params: ParamsAdamW) {
        self.params = params;
    }
}