use mlx_core::{Result, Tensor};

/// A stateful optimizer that maps parameters and gradients to updated parameters.
pub trait Optimizer {
    /// Apply one optimization step and return the new parameter tensors.
    fn step(&mut self, params: &[Tensor], grads: &[Tensor]) -> Result<Vec<Tensor>>;
}

/// Stochastic gradient descent with optional classical momentum:
/// `v = momentum * v + g; p = p - lr * v`.
pub struct Sgd {
    lr: f32,
    momentum: f32,
    velocity: Vec<Tensor>,
}

impl Sgd {
    /// Create an SGD optimizer. A `momentum` of 0.0 gives plain gradient
    /// descent; velocity buffers are lazily allocated on the first `step`.
    pub fn new(lr: f32, momentum: f32) -> Self {
        Self {
            lr,
            momentum,
            velocity: Vec::new(),
        }
    }
}

impl Optimizer for Sgd {
    fn step(&mut self, params: &[Tensor], grads: &[Tensor]) -> Result<Vec<Tensor>> {
        if params.len() != grads.len() {
            return Err(mlx_core::MlxError::InvalidArgument(format!(
                "params length {} != grads length {}",
                params.len(),
                grads.len()
            )));
        }
        // Lazily allocate zero velocity buffers matching each parameter.
        if self.velocity.is_empty() {
            self.velocity = params
                .iter()
                .map(|p| Tensor::zeros(p.shape(), p.dtype(), p.device()))
                .collect::<Result<Vec<_>>>()?;
        }

        let lr_scalar = self.lr;
        let mom = self.momentum;

        let mut new_params = Vec::with_capacity(params.len());
        let mut new_velocity = Vec::with_capacity(params.len());

        for (i, (p, g)) in params.iter().zip(grads.iter()).enumerate() {
            if mom == 0.0 {
                // Plain SGD: p = p - lr * g.
                let lr_t = scalar_like(lr_scalar, p)?;
                let update = lr_t.mul(g)?;
                new_params.push(p.sub(&update)?);
                new_velocity.push(self.velocity[i].clone());
            } else {
                // Momentum: v = momentum * v + g.
                let mom_t = scalar_like(mom, p)?;
                let v = mom_t.mul(&self.velocity[i])?.add(g)?;
                // p = p - lr * v.
                let lr_t = scalar_like(lr_scalar, p)?;
                let update = lr_t.mul(&v)?;
                new_params.push(p.sub(&update)?);
                new_velocity.push(v);
            }
        }

        self.velocity = new_velocity;
        Ok(new_params)
    }
}
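
// A minimal usage sketch (illustrative only): `sgd_usage_sketch` is a
// hypothetical helper, not part of this module's API, and the tensor values
// are arbitrary. It assumes the `Device`/`Shape` constructors used in the
// tests below.
#[allow(dead_code)]
fn sgd_usage_sketch() -> Result<()> {
    let device = mlx_core::Device::Cpu;
    let mut opt = Sgd::new(0.1, 0.9);
    let mut params = vec![Tensor::from_f32(
        &[1.0, 2.0],
        &mlx_core::Shape::new(vec![2]),
        &device,
    )?];
    // In a real loop the gradients would come from backprop; here they are fixed.
    let grads = vec![Tensor::from_f32(
        &[0.5, 0.5],
        &mlx_core::Shape::new(vec![2]),
        &device,
    )?];
    for _ in 0..3 {
        // Each step applies v = 0.9 * v + g, then p = p - 0.1 * v.
        params = opt.step(&params, &grads)?;
    }
    Ok(())
}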

/// Adam with decoupled weight decay (AdamW). The weight decay multiplies the
/// parameter directly by `1 - lr * weight_decay` instead of being added to
/// the gradient.
pub struct AdamW {
    lr: f32,
    betas: (f32, f32),
    eps: f32,
    weight_decay: f32,
    t: u64,
    m: Vec<Tensor>,
    v: Vec<Tensor>,
}

impl AdamW {
    /// Create an AdamW optimizer with the common defaults:
    /// `betas = (0.9, 0.999)`, `eps = 1e-8`, `weight_decay = 0.01`.
    pub fn new(lr: f32) -> Self {
        Self {
            lr,
            betas: (0.9, 0.999),
            eps: 1e-8,
            weight_decay: 0.01,
            t: 0,
            m: Vec::new(),
            v: Vec::new(),
        }
    }

    /// Set the exponential decay rates for the first and second moments.
    pub fn betas(mut self, b1: f32, b2: f32) -> Self {
        self.betas = (b1, b2);
        self
    }

    /// Set the epsilon added to the denominator for numerical stability.
    pub fn eps(mut self, eps: f32) -> Self {
        self.eps = eps;
        self
    }

    /// Set the decoupled weight-decay coefficient.
    pub fn weight_decay(mut self, wd: f32) -> Self {
        self.weight_decay = wd;
        self
    }
}

impl Optimizer for AdamW {
    fn step(&mut self, params: &[Tensor], grads: &[Tensor]) -> Result<Vec<Tensor>> {
        if params.len() != grads.len() {
            return Err(mlx_core::MlxError::InvalidArgument(format!(
                "params length {} != grads length {}",
                params.len(),
                grads.len()
            )));
        }
        // Lazily allocate zero first- and second-moment buffers.
        if self.m.is_empty() {
            self.m = params
                .iter()
                .map(|p| Tensor::zeros(p.shape(), p.dtype(), p.device()))
                .collect::<Result<Vec<_>>>()?;
            self.v = params
                .iter()
                .map(|p| Tensor::zeros(p.shape(), p.dtype(), p.device()))
                .collect::<Result<Vec<_>>>()?;
        }

        self.t += 1;
        let (b1, b2) = self.betas;
        // Bias-correction terms: 1 - beta^t.
        let bc1 = 1.0 - b1.powi(self.t as i32);
        let bc2 = 1.0 - b2.powi(self.t as i32);

        let mut new_params = Vec::with_capacity(params.len());
        let mut new_m = Vec::with_capacity(params.len());
        let mut new_v = Vec::with_capacity(params.len());

        for (i, (p, g)) in params.iter().zip(grads.iter()).enumerate() {
            // m = b1 * m + (1 - b1) * g
            let b1_t = scalar_like(b1, p)?;
            let one_minus_b1 = scalar_like(1.0 - b1, p)?;
            let m_new = b1_t.mul(&self.m[i])?.add(&one_minus_b1.mul(g)?)?;

            // v = b2 * v + (1 - b2) * g^2
            let b2_t = scalar_like(b2, p)?;
            let one_minus_b2 = scalar_like(1.0 - b2, p)?;
            let g_sq = g.mul(g)?;
            let v_new = b2_t.mul(&self.v[i])?.add(&one_minus_b2.mul(&g_sq)?)?;

            // Bias-corrected moment estimates.
            let bc1_t = scalar_like(bc1, p)?;
            let bc2_t = scalar_like(bc2, p)?;
            let m_hat = m_new.div(&bc1_t)?;
            let v_hat = v_new.div(&bc2_t)?;

            // Decoupled weight decay, then the Adam update:
            // p = (1 - lr * wd) * p - lr * m_hat / (sqrt(v_hat) + eps)
            let decay_factor = scalar_like(1.0 - self.lr * self.weight_decay, p)?;
            let lr_t = scalar_like(self.lr, p)?;
            let eps_t = scalar_like(self.eps, p)?;
            let denom = v_hat.sqrt()?.add(&eps_t)?;
            let step = lr_t.mul(&m_hat)?.div(&denom)?;
            let p_new = decay_factor.mul(p)?.sub(&step)?;

            new_params.push(p_new);
            new_m.push(m_new);
            new_v.push(v_new);
        }

        self.m = new_m;
        self.v = new_v;
        Ok(new_params)
    }
}
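
// An illustrative builder-chain sketch: `adamw_usage_sketch` is a hypothetical
// helper and the hyperparameter values are just examples, not recommendations.
#[allow(dead_code)]
fn adamw_usage_sketch() -> Result<()> {
    let device = mlx_core::Device::Cpu;
    // Start from the defaults set by `new` and override selectively.
    let mut opt = AdamW::new(3e-4).betas(0.9, 0.95).weight_decay(0.1);
    let p = Tensor::from_f32(&[1.0], &mlx_core::Shape::new(vec![1]), &device)?;
    let g = Tensor::from_f32(&[0.5], &mlx_core::Shape::new(vec![1]), &device)?;
    let _updated = opt.step(&[p], &[g])?;
    Ok(())
}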

/// A learning-rate schedule mapping a global step count to a learning rate.
pub trait LrScheduler {
    /// Return the learning rate to use at the given step.
    fn get_lr(&self, step: u64) -> f32;
}

/// Decays the learning rate by a factor of `gamma` every `step_size` steps:
/// `lr = base_lr * gamma^(step / step_size)`.
pub struct StepLR {
    base_lr: f32,
    step_size: u64,
    gamma: f32,
}

impl StepLR {
    pub fn new(base_lr: f32, step_size: u64, gamma: f32) -> Self {
        Self {
            base_lr,
            step_size,
            gamma,
        }
    }
}

impl LrScheduler for StepLR {
    fn get_lr(&self, step: u64) -> f32 {
        self.base_lr * self.gamma.powi((step / self.step_size) as i32)
    }
}

/// Cosine annealing from `base_lr` down to `eta_min` over `t_max` steps:
/// `lr = eta_min + 0.5 * (base_lr - eta_min) * (1 + cos(pi * step / t_max))`.
pub struct CosineAnnealingLR {
    base_lr: f32,
    t_max: u64,
    eta_min: f32,
}

impl CosineAnnealingLR {
    pub fn new(base_lr: f32, t_max: u64, eta_min: f32) -> Self {
        Self {
            base_lr,
            t_max,
            eta_min,
        }
    }
}

impl LrScheduler for CosineAnnealingLR {
    fn get_lr(&self, step: u64) -> f32 {
        self.eta_min
            + 0.5
                * (self.base_lr - self.eta_min)
                * (1.0 + (std::f32::consts::PI * step as f32 / self.t_max as f32).cos())
    }
}
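
// Sketch: sampling a schedule across training. `scheduler_sketch` is a
// hypothetical helper; note the optimizers above take a fixed `lr` at
// construction and expose no setter, so this only queries the schedule.
#[allow(dead_code)]
fn scheduler_sketch() {
    let sched = CosineAnnealingLR::new(0.1, 100, 0.001);
    for step in [0u64, 25, 50, 75, 100] {
        // Starts at 0.1 and anneals toward 0.001 at step 100.
        let _lr = sched.get_lr(step);
    }
}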

/// Build a tensor filled with `val` on the same device as `like`, broadcast
/// to `like`'s shape so it can be combined elementwise with `like`.
fn scalar_like(val: f32, like: &Tensor) -> Result<Tensor> {
    Tensor::from_f32(&[val], &mlx_core::Shape::scalar(), like.device())?.broadcast_to(like.shape())
}

#[cfg(test)]
mod tests {
    use super::*;
    use mlx_core::{Device, Shape};

    fn cpu() -> Device {
        Device::Cpu
    }

    fn t(data: &[f32], shape: &[i64]) -> Tensor {
        Tensor::from_f32(data, &Shape::new(shape.to_vec()), &cpu()).unwrap()
    }

    #[test]
    fn test_sgd_no_momentum() {
        let mut opt = Sgd::new(0.1, 0.0);
        let p = t(&[1.0, 2.0, 3.0], &[3]);
        let g = t(&[0.5, 1.0, 1.5], &[3]);
        let new_p = opt.step(&[p], &[g]).unwrap();
        let vals = new_p[0].to_vec_f32().unwrap();
        // p - lr * g = [1.0 - 0.05, 2.0 - 0.1, 3.0 - 0.15]
        mlx_conformance::assert_allclose(&vals, &[0.95, 1.9, 2.85], 1e-5, 1e-5);
    }

    #[test]
    fn test_sgd_with_momentum() {
        let mut opt = Sgd::new(0.1, 0.9);
        let p = t(&[1.0, 2.0], &[2]);
        let g = t(&[1.0, 1.0], &[2]);

        // Step 1: v = g = 1.0, so the update is lr * v = 0.1.
        let new_p = opt.step(&[p], std::slice::from_ref(&g)).unwrap();
        let vals1 = new_p[0].to_vec_f32().unwrap();
        mlx_conformance::assert_allclose(&vals1, &[0.9, 1.9], 1e-5, 1e-5);

        // Step 2: v = 0.9 * 1.0 + 1.0 = 1.9, so the update is 0.19.
        let new_p2 = opt
            .step(std::slice::from_ref(&new_p[0]), std::slice::from_ref(&g))
            .unwrap();
        let vals2 = new_p2[0].to_vec_f32().unwrap();
        mlx_conformance::assert_allclose(&vals2, &[0.71, 1.71], 1e-5, 1e-5);
    }

    #[test]
    fn test_adamw_single_step() {
        let mut opt = AdamW::new(0.001)
            .betas(0.9, 0.999)
            .eps(1e-8)
            .weight_decay(0.01);
        let p = t(&[1.0, 2.0], &[2]);
        let g = t(&[0.1, 0.2], &[2]);

        let new_p = opt
            .step(std::slice::from_ref(&p), std::slice::from_ref(&g))
            .unwrap();
        let vals = new_p[0].to_vec_f32().unwrap();

        // At t = 1 the bias corrections cancel exactly:
        //   m_hat = (1 - b1) * g / (1 - b1) = g
        //   v_hat = (1 - b2) * g^2 / (1 - b2) = g^2
        // so the update reduces to
        //   p' = (1 - lr * wd) * p - lr * g / (|g| + eps)
        let expected_0 = 0.99999 * 1.0 - 0.001 * 0.1 / (0.01f32.sqrt() + 1e-8);
        let expected_1 = 0.99999 * 2.0 - 0.001 * 0.2 / (0.04f32.sqrt() + 1e-8);
        mlx_conformance::assert_allclose(&vals, &[expected_0, expected_1], 1e-4, 1e-4);
    }

    #[test]
    fn test_adamw_two_steps() {
        let mut opt = AdamW::new(0.001)
            .betas(0.9, 0.999)
            .eps(1e-8)
            .weight_decay(0.0);
        let p = t(&[1.0], &[1]);
        let g = t(&[1.0], &[1]);

        let p1 = opt.step(&[p], std::slice::from_ref(&g)).unwrap();
        let p2 = opt
            .step(std::slice::from_ref(&p1[0]), std::slice::from_ref(&g))
            .unwrap();

        // With a constant positive gradient the parameter must keep decreasing.
        let v1 = p1[0].to_vec_f32().unwrap()[0];
        let v2 = p2[0].to_vec_f32().unwrap()[0];
        assert!(v1 < 1.0, "param should decrease after step 1");
        assert!(v2 < v1, "param should decrease after step 2");
    }

    #[test]
    fn test_step_lr() {
        let sched = StepLR::new(0.1, 10, 0.5);
        assert!((sched.get_lr(0) - 0.1).abs() < 1e-6);
        assert!((sched.get_lr(5) - 0.1).abs() < 1e-6);
        assert!((sched.get_lr(10) - 0.05).abs() < 1e-6);
        assert!((sched.get_lr(20) - 0.025).abs() < 1e-6);
        assert!((sched.get_lr(30) - 0.0125).abs() < 1e-6);
    }

    #[test]
    fn test_cosine_annealing_lr() {
        let sched = CosineAnnealingLR::new(0.1, 100, 0.001);
        assert!((sched.get_lr(0) - 0.1).abs() < 1e-6);
        // Midpoint: cos(pi/2) = 0, so lr = eta_min + 0.5 * (base_lr - eta_min) = 0.0505.
        let lr_mid = sched.get_lr(50);
        assert!((lr_mid - 0.0505).abs() < 1e-3, "mid lr: {lr_mid}");
        assert!((sched.get_lr(100) - 0.001).abs() < 1e-6);
    }
}