use ghostflow_core::Tensor;
use crate::optimizer::Optimizer;

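/// Adam optimizer with bias-corrected first and second moment estimates and
/// optional L2 weight decay.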
pub struct Adam {
    params: Vec<Tensor>,
    lr: f32,
    betas: (f32, f32),
    eps: f32,
    weight_decay: f32,
    m: Vec<Vec<f32>>,
    v: Vec<Vec<f32>>,
    t: usize,
}

impl Adam {
    pub fn new(params: Vec<Tensor>, lr: f32) -> Self {
        let m = params.iter().map(|p| vec![0.0f32; p.numel()]).collect();
        let v = params.iter().map(|p| vec![0.0f32; p.numel()]).collect();

        Adam {
            params,
            lr,
            betas: (0.9, 0.999),
            eps: 1e-8,
            weight_decay: 0.0,
            m,
            v,
            t: 0,
        }
    }

    pub fn betas(mut self, beta1: f32, beta2: f32) -> Self {
        self.betas = (beta1, beta2);
        self
    }

    pub fn eps(mut self, eps: f32) -> Self {
        self.eps = eps;
        self
    }

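    /// Sets the L2 weight decay coefficient, which is folded into the gradient
    /// before the moment updates (classic, coupled weight decay).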
    pub fn weight_decay(mut self, weight_decay: f32) -> Self {
        self.weight_decay = weight_decay;
        self
    }
}

impl Optimizer for Adam {
    fn step(&mut self) {
        self.t += 1;
        let (beta1, beta2) = self.betas;

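        // Bias-correction factors: the moment estimates start at zero, so early
        // steps are rescaled by 1 / (1 - beta^t).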
        let bias_correction1 = 1.0 - beta1.powi(self.t as i32);
        let bias_correction2 = 1.0 - beta2.powi(self.t as i32);

        for (i, param) in self.params.iter_mut().enumerate() {
            if let Some(grad) = param.grad() {
                let mut grad_data = grad.data_f32();
                let param_data = param.data_f32();

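                // Coupled L2 regularization: add the decay term to the gradient.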
                if self.weight_decay != 0.0 {
                    for (g, &p) in grad_data.iter_mut().zip(param_data.iter()) {
                        *g += self.weight_decay * p;
                    }
                }

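                // Update the exponential moving averages of the gradient (first moment)
                // and the squared gradient (second moment).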
                for (j, &g) in grad_data.iter().enumerate() {
                    self.m[i][j] = beta1 * self.m[i][j] + (1.0 - beta1) * g;
                    self.v[i][j] = beta2 * self.v[i][j] + (1.0 - beta2) * g * g;
                }

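                // Parameter update: p -= lr * m_hat / (sqrt(v_hat) + eps).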
                let new_data: Vec<f32> = param_data.iter()
                    .enumerate()
                    .map(|(j, &p)| {
                        let m_hat = self.m[i][j] / bias_correction1;
                        let v_hat = self.v[i][j] / bias_correction2;
                        p - self.lr * m_hat / (v_hat.sqrt() + self.eps)
                    })
                    .collect();

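                // Write the updated values back by rebuilding the parameter tensor.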
                *param = Tensor::from_slice(&new_data, param.dims()).unwrap();
            }
        }
    }

    fn zero_grad(&mut self) {
        for param in &mut self.params {
            param.zero_grad();
        }
    }

    fn get_lr(&self) -> f32 {
        self.lr
    }

    fn set_lr(&mut self, lr: f32) {
        self.lr = lr;
    }

    fn parameters(&self) -> &[Tensor] {
        &self.params
    }
}

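/// AdamW: Adam with decoupled weight decay, applied directly to the parameters
/// instead of being folded into the gradient.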
pub struct AdamW {
    params: Vec<Tensor>,
    lr: f32,
    betas: (f32, f32),
    eps: f32,
    weight_decay: f32,
    m: Vec<Vec<f32>>,
    v: Vec<Vec<f32>>,
    t: usize,
}

impl AdamW {
    pub fn new(params: Vec<Tensor>, lr: f32) -> Self {
        let m = params.iter().map(|p| vec![0.0f32; p.numel()]).collect();
        let v = params.iter().map(|p| vec![0.0f32; p.numel()]).collect();

        AdamW {
            params,
            lr,
            betas: (0.9, 0.999),
            eps: 1e-8,
            weight_decay: 0.01,
            m,
            v,
            t: 0,
        }
    }

    pub fn betas(mut self, beta1: f32, beta2: f32) -> Self {
        self.betas = (beta1, beta2);
        self
    }

    pub fn eps(mut self, eps: f32) -> Self {
        self.eps = eps;
        self
    }

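    /// Sets the decoupled weight decay coefficient (defaults to 0.01).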
    pub fn weight_decay(mut self, weight_decay: f32) -> Self {
        self.weight_decay = weight_decay;
        self
    }
}

impl Optimizer for AdamW {
    fn step(&mut self) {
        self.t += 1;
        let (beta1, beta2) = self.betas;

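        // Same bias correction as Adam: compensate for the zero-initialized moments.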
        let bias_correction1 = 1.0 - beta1.powi(self.t as i32);
        let bias_correction2 = 1.0 - beta2.powi(self.t as i32);

        for (i, param) in self.params.iter_mut().enumerate() {
            if let Some(grad) = param.grad() {
                let grad_data = grad.data_f32();
                let param_data = param.data_f32();

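                // Update first and second moment estimates from the raw gradient;
                // unlike Adam above, weight decay is not mixed into the gradient here.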
                for (j, &g) in grad_data.iter().enumerate() {
                    self.m[i][j] = beta1 * self.m[i][j] + (1.0 - beta1) * g;
                    self.v[i][j] = beta2 * self.v[i][j] + (1.0 - beta2) * g * g;
                }

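                // Decoupled weight decay: shrink the parameter directly, then apply
                // the bias-corrected Adam step.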
                let new_data: Vec<f32> = param_data.iter()
                    .enumerate()
                    .map(|(j, &p)| {
                        let m_hat = self.m[i][j] / bias_correction1;
                        let v_hat = self.v[i][j] / bias_correction2;
                        let p_decayed = p * (1.0 - self.lr * self.weight_decay);
                        p_decayed - self.lr * m_hat / (v_hat.sqrt() + self.eps)
                    })
                    .collect();

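                // Rebuild the parameter tensor from the updated values.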
                *param = Tensor::from_slice(&new_data, param.dims()).unwrap();
            }
        }
    }

    fn zero_grad(&mut self) {
        for param in &mut self.params {
            param.zero_grad();
        }
    }

    fn get_lr(&self) -> f32 {
        self.lr
    }

    fn set_lr(&mut self, lr: f32) {
        self.lr = lr;
    }

    fn parameters(&self) -> &[Tensor] {
        &self.params
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_adam_step() {
        let mut param = Tensor::ones(&[3]);
        param.set_requires_grad(true);
        param.set_grad(Tensor::full(&[3], 0.1f32));

        let mut adam = Adam::new(vec![param], 0.001);

        for _ in 0..10 {
            adam.step();
        }

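        // With a positive gradient and positive learning rate, the update
        // subtracts from the parameter, so it should drop below its initial 1.0.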
        let updated = &adam.params[0];
        assert!(updated.data_f32()[0] < 1.0);
    }
}