Skip to main content

mlx_native/ops/
adam_update.rs

//! Adam optimizer step kernel + Rust dispatch.
//!
//! Used by hf2q's ADR-020 Track 2 DWQ-proper training loop (iter 13).
//! Performs an element-wise, in-place update of `param`, `m`, and `v`
//! following the standard Adam algorithm.  See `shaders/adam_update.metal`
//! for the math.
//!
//! The caller pre-computes the bias-correction denominators
//! `1 − β1^t` and `1 − β2^t` for the current step `t` and passes
//! them via the params buffer, which keeps the kernel purely element-wise.
10
11use metal::MTLSize;
12
13use crate::buffer::MlxBuffer;
14use crate::dtypes::DType;
15use crate::encoder::CommandEncoder;
16use crate::error::{MlxError, Result};
17use crate::kernel_registry::KernelRegistry;
18
19pub static ADAM_UPDATE_SHADER_SOURCE: &str =
20    include_str!("../shaders/adam_update.metal");
21
22pub fn register(registry: &mut KernelRegistry) {
23    registry.register_source("adam_update_f32", ADAM_UPDATE_SHADER_SOURCE);
24}
25
26/// Encode one Adam optimizer step.
27///
28/// `params_buf` must contain `[lr, beta1, beta2, eps,
29/// (1 − β1^t), (1 − β2^t)]` as f32 (24 bytes).
30/// `meta_buf` must contain `[n_elements]` as u32 (4 bytes).
31///
32/// All four data buffers (`param`, `grad`, `m`, `v`) must have the
33/// same f32 element count.
34#[allow(clippy::too_many_arguments)]
35pub fn dispatch_adam_update_f32(
36    encoder: &mut CommandEncoder,
37    registry: &mut KernelRegistry,
38    device: &metal::DeviceRef,
39    param: &MlxBuffer,
40    grad: &MlxBuffer,
41    m: &MlxBuffer,
42    v: &MlxBuffer,
43    params_buf: &MlxBuffer,
44    meta_buf: &MlxBuffer,
45) -> Result<()> {
46    let n = param.element_count();
47    if n == 0 {
48        return Err(MlxError::InvalidArgument(
49            "adam_update_f32: param must have at least one element".into(),
50        ));
51    }
52    for (label, buf) in [("grad", grad), ("m", m), ("v", v)] {
53        if buf.element_count() != n {
54            return Err(MlxError::InvalidArgument(format!(
55                "adam_update_f32: {label} element count {} != param element count {n}",
56                buf.element_count(),
57            )));
58        }
59        if buf.dtype() != DType::F32 {
60            return Err(MlxError::InvalidArgument(format!(
61                "adam_update_f32: {label} dtype {} not f32",
62                buf.dtype()
63            )));
64        }
65    }
66    if param.dtype() != DType::F32 {
67        return Err(MlxError::InvalidArgument(format!(
68            "adam_update_f32: param dtype {} not f32",
69            param.dtype()
70        )));
71    }
72    if params_buf.byte_len() < 24 {
73        return Err(MlxError::InvalidArgument(format!(
74            "adam_update_f32: params_buf too small (need 24 bytes for 6×f32, got {})",
75            params_buf.byte_len()
76        )));
77    }
78    if meta_buf.byte_len() < 4 {
79        return Err(MlxError::InvalidArgument(format!(
80            "adam_update_f32: meta_buf too small (need 4 bytes for u32, got {})",
81            meta_buf.byte_len()
82        )));
83    }
84
85    let pipeline = registry.get_pipeline("adam_update_f32", device)?;
86    let thread_count = n as u64;
87    let tg_size = std::cmp::min(256, thread_count);
88    encoder.encode(
89        pipeline,
90        &[
91            (0, param),
92            (1, grad),
93            (2, m),
94            (3, v),
95            (4, params_buf),
96            (5, meta_buf),
97        ],
98        MTLSize::new(thread_count, 1, 1),
99        MTLSize::new(tg_size, 1, 1),
100    );
101    Ok(())
102}
103
#[cfg(test)]
mod tests {
    // The helpers take each hyperparameter individually so they mirror the
    // kernel's params-buffer layout one-to-one.
    #![allow(clippy::too_many_arguments)]

    use super::*;
    use crate::device::MlxDevice;

    /// Pure-Rust reference Adam step; updates `param`, `m`, `v` in place.
    ///
    /// `omb1_t` / `omb2_t` are the bias-correction denominators
    /// `1 − β1^t` / `1 − β2^t`, in the same form the GPU params buffer
    /// carries them.
    fn adam_cpu(
        param: &mut [f32],
        grad: &[f32],
        m: &mut [f32],
        v: &mut [f32],
        lr: f32,
        beta1: f32,
        beta2: f32,
        eps: f32,
        omb1_t: f32,
        omb2_t: f32,
    ) {
        for (idx, p) in param.iter_mut().enumerate() {
            let g = grad[idx];
            m[idx] = beta1 * m[idx] + (1.0 - beta1) * g;
            v[idx] = beta2 * v[idx] + (1.0 - beta2) * g * g;
            let m_hat = m[idx] / omb1_t;
            let v_hat = v[idx] / omb2_t;
            *p -= lr * m_hat / (v_hat.sqrt() + eps);
        }
    }

    /// Allocates an f32 buffer of shape `[data.len()]` holding a copy of
    /// `data`; panics with `what` if allocation fails.
    fn upload_f32(device: &MlxDevice, data: &[f32], what: &str) -> MlxBuffer {
        let len = data.len();
        let mut buf = device
            .alloc_buffer(len * 4, DType::F32, vec![len])
            .expect(what);
        buf.as_mut_slice::<f32>().unwrap().copy_from_slice(data);
        buf
    }

    /// Runs a single GPU Adam step at step index `t` on fresh copies of the
    /// inputs and returns the updated `(param, m, v)`.
    fn run_adam_step(
        param: &[f32],
        grad: &[f32],
        m: &[f32],
        v: &[f32],
        lr: f32,
        beta1: f32,
        beta2: f32,
        eps: f32,
        t: u32,
    ) -> (Vec<f32>, Vec<f32>, Vec<f32>) {
        let device = MlxDevice::new().expect("device");
        let n = param.len();
        let p_buf = upload_f32(&device, param, "alloc param");
        let g_buf = upload_f32(&device, grad, "alloc grad");
        let m_buf = upload_f32(&device, m, "alloc m");
        let v_buf = upload_f32(&device, v, "alloc v");

        let exponent = t as i32;
        let hyper = [
            lr,
            beta1,
            beta2,
            eps,
            1.0 - beta1.powi(exponent),
            1.0 - beta2.powi(exponent),
        ];
        let mut params_buf = device
            .alloc_buffer(24, DType::F32, vec![6])
            .expect("alloc params");
        params_buf
            .as_mut_slice::<f32>()
            .unwrap()
            .copy_from_slice(&hyper);

        // NOTE(review): the meta buffer holds a u32 but is allocated as
        // DType::F32. The dispatch only checks its byte length, so this works.
        // Confirm whether a u32 dtype was intended.
        let mut meta_buf = device
            .alloc_buffer(4, DType::F32, vec![1])
            .expect("alloc meta");
        meta_buf.as_mut_slice::<u32>().unwrap()[0] = n as u32;

        let mut registry = KernelRegistry::new();
        register(&mut registry);
        let mut encoder = device.command_encoder().expect("encoder");
        dispatch_adam_update_f32(
            &mut encoder,
            &mut registry,
            device.metal_device(),
            &p_buf,
            &g_buf,
            &m_buf,
            &v_buf,
            &params_buf,
            &meta_buf,
        )
        .expect("dispatch adam");
        encoder.commit_and_wait().expect("commit");

        let read_back = |buf: &MlxBuffer| buf.as_slice::<f32>().unwrap().to_vec();
        (read_back(&p_buf), read_back(&m_buf), read_back(&v_buf))
    }

    /// Asserts element-wise closeness: each pair must be within `abs_tol`
    /// absolutely, or within `rel_tol` relative to `max(|gpu|, |cpu|, 1)`.
    fn assert_close_vec(label: &str, gpu: &[f32], cpu: &[f32], rel_tol: f32, abs_tol: f32) {
        assert_eq!(gpu.len(), cpu.len(), "{label}: length mismatch");
        for (i, (&g, &c)) in gpu.iter().zip(cpu).enumerate() {
            let diff = (g - c).abs();
            let scale = g.abs().max(c.abs()).max(1.0);
            let within_abs = diff <= abs_tol;
            let within_rel = diff / scale <= rel_tol;
            assert!(
                within_abs || within_rel,
                "{label}: i={i}: gpu={g} cpu={c} diff={diff}"
            );
        }
    }

    /// Runs one step on the GPU and on the CPU oracle from identical inputs.
    /// It then asserts that `param`, `m`, and `v` agree (rel 1e-5, abs 1e-7).
    fn check_step_against_cpu(
        param: &[f32],
        grad: &[f32],
        m: &[f32],
        v: &[f32],
        lr: f32,
        beta1: f32,
        beta2: f32,
        eps: f32,
        t: u32,
    ) {
        let (p_gpu, m_gpu, v_gpu) = run_adam_step(param, grad, m, v, lr, beta1, beta2, eps, t);

        let mut p_cpu = param.to_vec();
        let mut m_cpu = m.to_vec();
        let mut v_cpu = v.to_vec();
        adam_cpu(
            &mut p_cpu,
            grad,
            &mut m_cpu,
            &mut v_cpu,
            lr,
            beta1,
            beta2,
            eps,
            1.0 - beta1.powi(t as i32),
            1.0 - beta2.powi(t as i32),
        );

        let pairs = [
            ("param", &p_gpu, &p_cpu),
            ("m", &m_gpu, &m_cpu),
            ("v", &v_gpu, &v_cpu),
        ];
        for (name, gpu, cpu) in pairs {
            assert_close_vec(&format!("adam {name} t={t}"), gpu, cpu, 1e-5, 1e-7);
        }
    }

    #[test]
    fn adam_step_t1_byte_close_to_cpu() {
        let n = 64;
        let param: Vec<f32> = (0..n).map(|i| (i as f32) * 0.1 - 1.0).collect();
        let grad: Vec<f32> = (0..n).map(|i| ((i as f32) * 0.013).sin() * 0.5).collect();
        // Fresh optimizer state: both moments start at zero.
        let zeros = vec![0f32; n];
        check_step_against_cpu(&param, &grad, &zeros, &zeros, 1e-3, 0.9, 0.999, 1e-8, 1);
    }

    #[test]
    fn adam_step_t10_with_nontrivial_state() {
        // Non-zero m and v stand in for the state left by earlier steps.
        let n = 32;
        let param: Vec<f32> = (0..n).map(|i| (i as f32) * 0.05).collect();
        let grad: Vec<f32> = (0..n).map(|i| ((i as f32) * 0.011).cos() * 0.3).collect();
        let m: Vec<f32> = (0..n).map(|i| (i as f32) * 0.001).collect();
        let v: Vec<f32> = (0..n).map(|i| (i as f32) * 0.0001 + 0.001).collect();
        check_step_against_cpu(&param, &grad, &m, &v, 5e-4, 0.9, 0.999, 1e-8, 10);
    }

    #[test]
    fn adam_zero_grad_leaves_param_unchanged() {
        // grad = m = v = 0 keeps both moments at zero, so the update is
        // 0 / (0 + eps) = 0 and param must come back untouched.
        let n = 16;
        let param: Vec<f32> = (0..n).map(|i| (i as f32) - 8.0).collect();
        let zeros = vec![0f32; n];
        let (p_gpu, m_gpu, v_gpu) =
            run_adam_step(&param, &zeros, &zeros, &zeros, 1e-3, 0.9, 0.999, 1e-8, 1);
        for (i, (p_in, p_out)) in param.iter().zip(&p_gpu).enumerate() {
            assert!(
                (p_in - p_out).abs() < 1e-9,
                "i={i}: param changed from {p_in} to {p_out}"
            );
        }
        assert!(m_gpu.iter().all(|&x| x == 0.0));
        assert!(v_gpu.iter().all(|&x| x == 0.0));
    }

    #[test]
    fn adam_simple_optimization_converges() {
        // Minimise f(x) = (x − 5)², whose gradient is 2·(x − 5), starting
        // from x = 0. After 200 steps x should sit close to the minimum.
        let target = 5.0_f32;
        let device = MlxDevice::new().expect("device");
        let p_buf = upload_f32(&device, &[0.0], "p");
        let mut g_buf = device.alloc_buffer(4, DType::F32, vec![1]).expect("g");
        // alloc_buffer is zero-fill, so both moments start at zero.
        let m_buf = device.alloc_buffer(4, DType::F32, vec![1]).expect("m");
        let v_buf = device.alloc_buffer(4, DType::F32, vec![1]).expect("v");
        let mut params_buf = device
            .alloc_buffer(24, DType::F32, vec![6])
            .expect("params");
        let mut meta_buf = device.alloc_buffer(4, DType::F32, vec![1]).expect("meta");
        meta_buf.as_mut_slice::<u32>().unwrap()[0] = 1u32;

        let (lr, beta1, beta2, eps) = (0.1_f32, 0.9_f32, 0.999_f32, 1e-8_f32);

        let mut registry = KernelRegistry::new();
        register(&mut registry);

        for step in 1..=200u32 {
            let x = p_buf.as_slice::<f32>().unwrap()[0];
            g_buf.as_mut_slice::<f32>().unwrap()[0] = 2.0 * (x - target);
            let exponent = step as i32;
            params_buf.as_mut_slice::<f32>().unwrap().copy_from_slice(&[
                lr,
                beta1,
                beta2,
                eps,
                1.0 - beta1.powi(exponent),
                1.0 - beta2.powi(exponent),
            ]);
            let mut encoder = device.command_encoder().expect("encoder");
            dispatch_adam_update_f32(
                &mut encoder,
                &mut registry,
                device.metal_device(),
                &p_buf,
                &g_buf,
                &m_buf,
                &v_buf,
                &params_buf,
                &meta_buf,
            )
            .unwrap();
            encoder.commit_and_wait().unwrap();
        }

        let final_x = p_buf.as_slice::<f32>().unwrap()[0];
        assert!(
            (final_x - target).abs() < 0.05,
            "expected x ≈ 5 after 200 Adam steps; got {final_x}"
        );
    }

    #[test]
    fn adam_rejects_mismatched_sizes() {
        let device = MlxDevice::new().expect("device");
        let p = device.alloc_buffer(16, DType::F32, vec![4]).expect("p");
        // grad deliberately has twice as many elements as param.
        let g = device.alloc_buffer(32, DType::F32, vec![8]).expect("g");
        let m = device.alloc_buffer(16, DType::F32, vec![4]).expect("m");
        let v = device.alloc_buffer(16, DType::F32, vec![4]).expect("v");
        let params = device.alloc_buffer(24, DType::F32, vec![6]).expect("params");
        let meta = device.alloc_buffer(4, DType::F32, vec![1]).expect("meta");

        let mut registry = KernelRegistry::new();
        register(&mut registry);
        let mut encoder = device.command_encoder().expect("encoder");
        let err = dispatch_adam_update_f32(
            &mut encoder,
            &mut registry,
            device.metal_device(),
            &p,
            &g,
            &m,
            &v,
            &params,
            &meta,
        )
        .expect_err("must reject mismatched sizes");
        let message = err.to_string();
        assert!(message.contains("grad element count"));
    }
}