// mlx_optim/lib.rs

//! Optimizers for MLX training.
//!
//! Tensors are immutable graph node handles — optimizers create new tensors
//! for updated parameters rather than mutating in place.

use mlx_core::{Result, Tensor};

/// Optimizer trait: apply one step, returning updated parameters.
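///
/// # Example
///
/// A minimal sketch of the intended step-based training loop, using this
/// crate's [`Sgd`]; `compute_grads` is a hypothetical stand-in for whatever
/// produces gradients in your setup:
///
/// ```ignore
/// use mlx_core::{Device, Shape, Tensor};
/// use mlx_optim::{Optimizer, Sgd};
///
/// let mut opt = Sgd::new(0.1, 0.9);
/// let mut params = vec![Tensor::from_f32(&[1.0], &Shape::new(vec![1]), &Device::Cpu).unwrap()];
/// for _ in 0..10 {
///     let grads = compute_grads(&params); // hypothetical gradient source
///     params = opt.step(&params, &grads).unwrap();
/// }
/// ```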
pub trait Optimizer {
    fn step(&mut self, params: &[Tensor], grads: &[Tensor]) -> Result<Vec<Tensor>>;
}

// ── SGD ──────────────────────────────────────────────────────────────────

/// Stochastic Gradient Descent with optional momentum.
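///
/// With learning rate `lr` and momentum factor `mu`, each call to `step`
/// computes, per parameter (mirroring the code below):
///
/// ```text
/// v = mu * v + g
/// p = p - lr * v
/// ```
///
/// With `momentum = 0.0` this reduces to plain `p = p - lr * g`.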
pub struct Sgd {
    lr: f32,
    momentum: f32,
    velocity: Vec<Tensor>,
}

impl Sgd {
    /// Create a new SGD optimizer.
    ///
    /// - `lr`: learning rate
    /// - `momentum`: momentum factor (0.0 = no momentum)
    pub fn new(lr: f32, momentum: f32) -> Self {
        Self {
            lr,
            momentum,
            velocity: Vec::new(),
        }
    }
}

impl Optimizer for Sgd {
    fn step(&mut self, params: &[Tensor], grads: &[Tensor]) -> Result<Vec<Tensor>> {
        if params.len() != grads.len() {
            return Err(mlx_core::MlxError::InvalidArgument(format!(
                "params length {} != grads length {}",
                params.len(),
                grads.len()
            )));
        }
        // Initialize velocity on first call
        if self.velocity.is_empty() {
            self.velocity = params
                .iter()
                .map(|p| Tensor::zeros(p.shape(), p.dtype(), p.device()))
                .collect::<Result<Vec<_>>>()?;
        }

        let lr_scalar = self.lr;
        let mom = self.momentum;

        let mut new_params = Vec::with_capacity(params.len());
        let mut new_velocity = Vec::with_capacity(params.len());

        for (i, (p, g)) in params.iter().zip(grads.iter()).enumerate() {
            if mom == 0.0 {
                // p_new = p - lr * g
                let lr_t = scalar_like(lr_scalar, p)?;
                let update = lr_t.mul(g)?;
                new_params.push(p.sub(&update)?);
                new_velocity.push(self.velocity[i].clone());
            } else {
                // v = momentum * v + g
                let mom_t = scalar_like(mom, p)?;
                let v = mom_t.mul(&self.velocity[i])?.add(g)?;
                // p_new = p - lr * v
                let lr_t = scalar_like(lr_scalar, p)?;
                let update = lr_t.mul(&v)?;
                new_params.push(p.sub(&update)?);
                new_velocity.push(v);
            }
        }

        self.velocity = new_velocity;
        Ok(new_params)
    }
}

// ── AdamW ────────────────────────────────────────────────────────────────

/// AdamW optimizer (Adam with decoupled weight decay).
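///
/// Each call to `step` computes the standard Adam moment updates with bias
/// correction, then applies the weight decay directly to the parameters
/// rather than folding it into the gradient (mirroring the code below):
///
/// ```text
/// m = b1 * m + (1 - b1) * g
/// v = b2 * v + (1 - b2) * g^2
/// m_hat = m / (1 - b1^t)
/// v_hat = v / (1 - b2^t)
/// p = p * (1 - lr * wd) - lr * m_hat / (sqrt(v_hat) + eps)
/// ```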
pub struct AdamW {
    lr: f32,
    betas: (f32, f32),
    eps: f32,
    weight_decay: f32,
    t: u64,
    m: Vec<Tensor>,
    v: Vec<Tensor>,
}

impl AdamW {
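    /// Create an AdamW optimizer with the defaults set below:
    /// `betas = (0.9, 0.999)`, `eps = 1e-8`, `weight_decay = 0.01`.
    /// Override any of them with the builder methods.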
    pub fn new(lr: f32) -> Self {
        Self {
            lr,
            betas: (0.9, 0.999),
            eps: 1e-8,
            weight_decay: 0.01,
            t: 0,
            m: Vec::new(),
            v: Vec::new(),
        }
    }

    /// Set the exponential decay rates `(b1, b2)` for the moment estimates.
    pub fn betas(mut self, b1: f32, b2: f32) -> Self {
        self.betas = (b1, b2);
        self
    }

    /// Set the denominator epsilon for numerical stability.
    pub fn eps(mut self, eps: f32) -> Self {
        self.eps = eps;
        self
    }

    /// Set the decoupled weight-decay coefficient.
    pub fn weight_decay(mut self, wd: f32) -> Self {
        self.weight_decay = wd;
        self
    }
}

impl Optimizer for AdamW {
    fn step(&mut self, params: &[Tensor], grads: &[Tensor]) -> Result<Vec<Tensor>> {
        if params.len() != grads.len() {
            return Err(mlx_core::MlxError::InvalidArgument(format!(
                "params length {} != grads length {}",
                params.len(),
                grads.len()
            )));
        }
        // Initialize moments on first call
        if self.m.is_empty() {
            self.m = params
                .iter()
                .map(|p| Tensor::zeros(p.shape(), p.dtype(), p.device()))
                .collect::<Result<Vec<_>>>()?;
            self.v = params
                .iter()
                .map(|p| Tensor::zeros(p.shape(), p.dtype(), p.device()))
                .collect::<Result<Vec<_>>>()?;
        }

        self.t += 1;
        let (b1, b2) = self.betas;
        let bc1 = 1.0 - b1.powi(self.t as i32);
        let bc2 = 1.0 - b2.powi(self.t as i32);

        let mut new_params = Vec::with_capacity(params.len());
        let mut new_m = Vec::with_capacity(params.len());
        let mut new_v = Vec::with_capacity(params.len());

        for (i, (p, g)) in params.iter().zip(grads.iter()).enumerate() {
            // m = β1 * m + (1 - β1) * g
            let b1_t = scalar_like(b1, p)?;
            let one_minus_b1 = scalar_like(1.0 - b1, p)?;
            let m_new = b1_t.mul(&self.m[i])?.add(&one_minus_b1.mul(g)?)?;

            // v = β2 * v + (1 - β2) * g²
            let b2_t = scalar_like(b2, p)?;
            let one_minus_b2 = scalar_like(1.0 - b2, p)?;
            let g_sq = g.mul(g)?;
            let v_new = b2_t.mul(&self.v[i])?.add(&one_minus_b2.mul(&g_sq)?)?;

            // Bias-corrected estimates
            let bc1_t = scalar_like(bc1, p)?;
            let bc2_t = scalar_like(bc2, p)?;
            let m_hat = m_new.div(&bc1_t)?;
            let v_hat = v_new.div(&bc2_t)?;

            // p_new = p * (1 - lr * wd) - lr * m_hat / (sqrt(v_hat) + eps)
            let decay_factor = scalar_like(1.0 - self.lr * self.weight_decay, p)?;
            let lr_t = scalar_like(self.lr, p)?;
            let eps_t = scalar_like(self.eps, p)?;
            let denom = v_hat.sqrt()?.add(&eps_t)?;
            let step = lr_t.mul(&m_hat)?.div(&denom)?;
            let p_new = decay_factor.mul(p)?.sub(&step)?;

            new_params.push(p_new);
            new_m.push(m_new);
            new_v.push(v_new);
        }

        self.m = new_m;
        self.v = new_v;
        Ok(new_params)
    }
}

// ── Learning Rate Schedulers ─────────────────────────────────────────────

/// Trait for learning rate schedulers.
pub trait LrScheduler {
    /// Return the learning rate for a given step.
    fn get_lr(&self, step: u64) -> f32;
}

/// Step-decay learning rate scheduler.
///
/// Decays the learning rate by `gamma` every `step_size` steps:
/// `lr = base_lr * gamma^(step / step_size)`
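///
/// # Example
///
/// A small sketch with the values exercised in the tests below:
///
/// ```ignore
/// let sched = StepLR::new(0.1, 10, 0.5);
/// sched.get_lr(0);  // 0.1  (steps 0..9 use the base rate)
/// sched.get_lr(10); // 0.05 (halved every 10 steps)
/// sched.get_lr(30); // 0.0125
/// ```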
pub struct StepLR {
    base_lr: f32,
    step_size: u64,
    gamma: f32,
}

impl StepLR {
    pub fn new(base_lr: f32, step_size: u64, gamma: f32) -> Self {
        Self {
            base_lr,
            step_size,
            gamma,
        }
    }
}

impl LrScheduler for StepLR {
    fn get_lr(&self, step: u64) -> f32 {
        self.base_lr * self.gamma.powi((step / self.step_size) as i32)
    }
}

/// Cosine annealing learning rate scheduler.
///
/// `lr = eta_min + 0.5 * (base_lr - eta_min) * (1 + cos(pi * step / t_max))`
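///
/// # Example
///
/// A small sketch with the values exercised in the tests below; the rate
/// sweeps from `base_lr` at step 0 down to `eta_min` at step `t_max`:
///
/// ```ignore
/// let sched = CosineAnnealingLR::new(0.1, 100, 0.001);
/// sched.get_lr(0);   // 0.1
/// sched.get_lr(50);  // ≈ 0.0505, the midpoint of the cosine curve
/// sched.get_lr(100); // 0.001
/// ```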
pub struct CosineAnnealingLR {
    base_lr: f32,
    t_max: u64,
    eta_min: f32,
}

impl CosineAnnealingLR {
    pub fn new(base_lr: f32, t_max: u64, eta_min: f32) -> Self {
        Self {
            base_lr,
            t_max,
            eta_min,
        }
    }
}

impl LrScheduler for CosineAnnealingLR {
    fn get_lr(&self, step: u64) -> f32 {
        self.eta_min
            + 0.5
                * (self.base_lr - self.eta_min)
                * (1.0 + (std::f32::consts::PI * step as f32 / self.t_max as f32).cos())
    }
}

/// Create an f32 scalar tensor broadcast to the shape and device of `like`.
fn scalar_like(val: f32, like: &Tensor) -> Result<Tensor> {
    Tensor::from_f32(&[val], &mlx_core::Shape::scalar(), like.device())?.broadcast_to(like.shape())
}

#[cfg(test)]
mod tests {
    use super::*;
    use mlx_core::{Device, Shape};

    fn cpu() -> Device {
        Device::Cpu
    }

    fn t(data: &[f32], shape: &[i64]) -> Tensor {
        Tensor::from_f32(data, &Shape::new(shape.to_vec()), &cpu()).unwrap()
    }

    #[test]
    fn test_sgd_no_momentum() {
        let mut opt = Sgd::new(0.1, 0.0);
        let p = t(&[1.0, 2.0, 3.0], &[3]);
        let g = t(&[0.5, 1.0, 1.5], &[3]);
        let new_p = opt.step(&[p], &[g]).unwrap();
        let vals = new_p[0].to_vec_f32().unwrap();
        // p - lr * g = [1 - 0.05, 2 - 0.1, 3 - 0.15]
        mlx_conformance::assert_allclose(&vals, &[0.95, 1.9, 2.85], 1e-5, 1e-5);
    }

    #[test]
    fn test_sgd_with_momentum() {
        let mut opt = Sgd::new(0.1, 0.9);
        let p = t(&[1.0, 2.0], &[2]);
        let g = t(&[1.0, 1.0], &[2]);

        // Step 1: v = 0.9*0 + g = [1,1], p = p - 0.1*v = [0.9, 1.9]
        let new_p = opt.step(&[p], std::slice::from_ref(&g)).unwrap();
        let vals1 = new_p[0].to_vec_f32().unwrap();
        mlx_conformance::assert_allclose(&vals1, &[0.9, 1.9], 1e-5, 1e-5);

        // Step 2: v = 0.9*[1,1] + [1,1] = [1.9,1.9], p = [0.9,1.9] - 0.1*[1.9,1.9] = [0.71, 1.71]
        let new_p2 = opt
            .step(std::slice::from_ref(&new_p[0]), std::slice::from_ref(&g))
            .unwrap();
        let vals2 = new_p2[0].to_vec_f32().unwrap();
        mlx_conformance::assert_allclose(&vals2, &[0.71, 1.71], 1e-5, 1e-5);
    }

    #[test]
    fn test_adamw_single_step() {
        let mut opt = AdamW::new(0.001)
            .betas(0.9, 0.999)
            .eps(1e-8)
            .weight_decay(0.01);
        let p = t(&[1.0, 2.0], &[2]);
        let g = t(&[0.1, 0.2], &[2]);

        let new_p = opt
            .step(std::slice::from_ref(&p), std::slice::from_ref(&g))
            .unwrap();
        let vals = new_p[0].to_vec_f32().unwrap();

        // Hand-compute step 1:
        // m = (1 - 0.9) * g = [0.01, 0.02]
        // v = (1 - 0.999) * g^2 = [0.00001, 0.00004]
        // m_hat = m / (1 - 0.9) = [0.1, 0.2]
        // v_hat = v / (1 - 0.999) = [0.01, 0.04]
        // decay = 1 - 0.001*0.01 = 0.99999
        // step = lr * m_hat / (sqrt(v_hat) + eps) = 0.001 * [0.1, 0.2] / ([0.1, 0.2] + 1e-8) ≈ [0.001, 0.001]
        // p_new = decay*p - step ≈ [0.99999 - 0.001, 1.99998 - 0.001] ≈ [0.99899, 1.99898]
        let expected_0 = 0.99999 * 1.0 - 0.001 * 0.1 / (0.01f32.sqrt() + 1e-8);
        let expected_1 = 0.99999 * 2.0 - 0.001 * 0.2 / (0.04f32.sqrt() + 1e-8);
        mlx_conformance::assert_allclose(&vals, &[expected_0, expected_1], 1e-4, 1e-4);
    }

    #[test]
    fn test_adamw_two_steps() {
        let mut opt = AdamW::new(0.001)
            .betas(0.9, 0.999)
            .eps(1e-8)
            .weight_decay(0.0);
        let p = t(&[1.0], &[1]);
        let g = t(&[1.0], &[1]);

        let p1 = opt.step(&[p], std::slice::from_ref(&g)).unwrap();
        let p2 = opt
            .step(std::slice::from_ref(&p1[0]), std::slice::from_ref(&g))
            .unwrap();

        // After 2 steps, parameter should have decreased
        let v1 = p1[0].to_vec_f32().unwrap()[0];
        let v2 = p2[0].to_vec_f32().unwrap()[0];
        assert!(v1 < 1.0, "param should decrease after step 1");
        assert!(v2 < v1, "param should decrease after step 2");
    }

    #[test]
    fn test_step_lr() {
        let sched = StepLR::new(0.1, 10, 0.5);
        assert!((sched.get_lr(0) - 0.1).abs() < 1e-6);
        assert!((sched.get_lr(5) - 0.1).abs() < 1e-6);
        assert!((sched.get_lr(10) - 0.05).abs() < 1e-6);
        assert!((sched.get_lr(20) - 0.025).abs() < 1e-6);
        assert!((sched.get_lr(30) - 0.0125).abs() < 1e-6);
    }

    #[test]
    fn test_cosine_annealing_lr() {
        let sched = CosineAnnealingLR::new(0.1, 100, 0.001);
        // At step 0: lr = 0.001 + 0.5*0.099*(1+cos(0)) = 0.001 + 0.099 = 0.1
        assert!((sched.get_lr(0) - 0.1).abs() < 1e-6);
        // At step t_max/2 (50): lr = 0.001 + 0.5*0.099*(1+cos(pi/2)) ≈ 0.001 + 0.0495 ≈ 0.0505
        let lr_mid = sched.get_lr(50);
        assert!((lr_mid - 0.0505).abs() < 1e-3, "mid lr: {lr_mid}");
        // At step t_max (100): lr = 0.001 + 0.5*0.099*(1+cos(pi)) = 0.001 + 0 = 0.001
        assert!((sched.get_lr(100) - 0.001).abs() < 1e-6);
    }
    }
}