1use metal::MTLSize;
12
13use crate::buffer::MlxBuffer;
14use crate::dtypes::DType;
15use crate::encoder::CommandEncoder;
16use crate::error::{MlxError, Result};
17use crate::kernel_registry::KernelRegistry;
18
19pub static ADAM_UPDATE_SHADER_SOURCE: &str =
20 include_str!("../shaders/adam_update.metal");
21
22pub fn register(registry: &mut KernelRegistry) {
23 registry.register_source("adam_update_f32", ADAM_UPDATE_SHADER_SOURCE);
24}
25
26#[allow(clippy::too_many_arguments)]
35pub fn dispatch_adam_update_f32(
36 encoder: &mut CommandEncoder,
37 registry: &mut KernelRegistry,
38 device: &metal::DeviceRef,
39 param: &MlxBuffer,
40 grad: &MlxBuffer,
41 m: &MlxBuffer,
42 v: &MlxBuffer,
43 params_buf: &MlxBuffer,
44 meta_buf: &MlxBuffer,
45) -> Result<()> {
46 let n = param.element_count();
47 if n == 0 {
48 return Err(MlxError::InvalidArgument(
49 "adam_update_f32: param must have at least one element".into(),
50 ));
51 }
52 for (label, buf) in [("grad", grad), ("m", m), ("v", v)] {
53 if buf.element_count() != n {
54 return Err(MlxError::InvalidArgument(format!(
55 "adam_update_f32: {label} element count {} != param element count {n}",
56 buf.element_count(),
57 )));
58 }
59 if buf.dtype() != DType::F32 {
60 return Err(MlxError::InvalidArgument(format!(
61 "adam_update_f32: {label} dtype {} not f32",
62 buf.dtype()
63 )));
64 }
65 }
66 if param.dtype() != DType::F32 {
67 return Err(MlxError::InvalidArgument(format!(
68 "adam_update_f32: param dtype {} not f32",
69 param.dtype()
70 )));
71 }
72 if params_buf.byte_len() < 24 {
73 return Err(MlxError::InvalidArgument(format!(
74 "adam_update_f32: params_buf too small (need 24 bytes for 6×f32, got {})",
75 params_buf.byte_len()
76 )));
77 }
78 if meta_buf.byte_len() < 4 {
79 return Err(MlxError::InvalidArgument(format!(
80 "adam_update_f32: meta_buf too small (need 4 bytes for u32, got {})",
81 meta_buf.byte_len()
82 )));
83 }
84
85 let pipeline = registry.get_pipeline("adam_update_f32", device)?;
86 let thread_count = n as u64;
87 let tg_size = std::cmp::min(256, thread_count);
88 encoder.encode(
89 pipeline,
90 &[
91 (0, param),
92 (1, grad),
93 (2, m),
94 (3, v),
95 (4, params_buf),
96 (5, meta_buf),
97 ],
98 MTLSize::new(thread_count, 1, 1),
99 MTLSize::new(tg_size, 1, 1),
100 );
101 Ok(())
102}
103
#[cfg(test)]
mod tests {
    use super::*;
    use crate::device::MlxDevice;

    /// CPU reference for a single Adam step; updates `param`, `m` and `v` in place.
    ///
    /// `omb1_t` and `omb2_t` are the bias-correction denominators, which the
    /// callers compute as `1 - beta1^t` and `1 - beta2^t`.
    ///
    /// # Panics
    ///
    /// Panics if `grad`, `m` or `v` differs in length from `param`.
    #[allow(clippy::too_many_arguments)] // same scalars the tests pack into `params_buf`
    fn adam_cpu(
        param: &mut [f32],
        grad: &[f32],
        m: &mut [f32],
        v: &mut [f32],
        lr: f32,
        beta1: f32,
        beta2: f32,
        eps: f32,
        omb1_t: f32,
        omb2_t: f32,
    ) {
        // `zip` would silently stop at the shortest slice, so pin the lengths first.
        assert_eq!(grad.len(), param.len(), "grad length mismatch");
        assert_eq!(m.len(), param.len(), "m length mismatch");
        assert_eq!(v.len(), param.len(), "v length mismatch");

        let moments = m.iter_mut().zip(v.iter_mut());
        for ((p, &g), (m_i, v_i)) in param.iter_mut().zip(grad).zip(moments) {
            let m_new = beta1 * *m_i + (1.0 - beta1) * g;
            let v_new = beta2 * *v_i + (1.0 - beta2) * g * g;
            *m_i = m_new;
            *v_i = v_new;
            let m_hat = m_new / omb1_t;
            let v_hat = v_new / omb2_t;
            *p -= lr * m_hat / (v_hat.sqrt() + eps);
        }
    }

    /// Allocates an f32 buffer of shape `[data.len()]` on `device` and copies
    /// `data` into it. `what` is the panic message used if allocation fails.
    fn upload_f32(device: &MlxDevice, data: &[f32], what: &str) -> MlxBuffer {
        let len = data.len();
        let mut buf = device
            .alloc_buffer(len * 4, DType::F32, vec![len])
            .expect(what);
        buf.as_mut_slice::<f32>().unwrap().copy_from_slice(data);
        buf
    }

    /// Allocates the 4-byte meta buffer and stores `n` in it as a `u32`.
    ///
    /// The buffer is tagged `DType::F32` even though it holds a `u32`;
    /// `dispatch_adam_update_f32` only checks its byte length.
    fn upload_meta(device: &MlxDevice, n: usize) -> MlxBuffer {
        let mut buf = device
            .alloc_buffer(4, DType::F32, vec![1])
            .expect("alloc meta");
        // Checked conversion: a plain `as u32` would silently truncate.
        buf.as_mut_slice::<u32>().unwrap()[0] =
            u32::try_from(n).expect("element count must fit in u32");
        buf
    }

    /// Runs one Adam step through `dispatch_adam_update_f32` for step index `t`
    /// and returns the resulting `(param, m, v)` read back from the buffers.
    #[allow(clippy::too_many_arguments)] // one argument per Adam input
    fn run_adam_step(
        param: &[f32],
        grad: &[f32],
        m: &[f32],
        v: &[f32],
        lr: f32,
        beta1: f32,
        beta2: f32,
        eps: f32,
        t: u32,
    ) -> (Vec<f32>, Vec<f32>, Vec<f32>) {
        let device = MlxDevice::new().expect("device");
        let n = param.len();

        let p_buf = upload_f32(&device, param, "alloc param");
        let g_buf = upload_f32(&device, grad, "alloc grad");
        let m_buf = upload_f32(&device, m, "alloc m");
        let v_buf = upload_f32(&device, v, "alloc v");

        let omb1_t = 1.0 - beta1.powi(t as i32);
        let omb2_t = 1.0 - beta2.powi(t as i32);
        let params_buf = upload_f32(
            &device,
            &[lr, beta1, beta2, eps, omb1_t, omb2_t],
            "alloc params",
        );
        let meta_buf = upload_meta(&device, n);

        let mut registry = KernelRegistry::new();
        register(&mut registry);
        let mut encoder = device.command_encoder().expect("encoder");
        dispatch_adam_update_f32(
            &mut encoder,
            &mut registry,
            device.metal_device(),
            &p_buf,
            &g_buf,
            &m_buf,
            &v_buf,
            &params_buf,
            &meta_buf,
        )
        .expect("dispatch adam");
        encoder.commit_and_wait().expect("commit");

        (
            p_buf.as_slice::<f32>().unwrap().to_vec(),
            m_buf.as_slice::<f32>().unwrap().to_vec(),
            v_buf.as_slice::<f32>().unwrap().to_vec(),
        )
    }

    /// Asserts that `gpu` and `cpu` have equal length and that every element
    /// pair is within `abs_tol` absolutely, or within `rel_tol` relative to
    /// `max(|gpu|, |cpu|, 1.0)`.
    fn assert_close_vec(label: &str, gpu: &[f32], cpu: &[f32], rel_tol: f32, abs_tol: f32) {
        assert_eq!(gpu.len(), cpu.len(), "{label}: length mismatch");
        for (i, (g, c)) in gpu.iter().zip(cpu.iter()).enumerate() {
            let diff = (g - c).abs();
            let scale = g.abs().max(c.abs()).max(1.0);
            assert!(
                diff <= abs_tol || diff / scale <= rel_tol,
                "{label}: i={i}: gpu={g} cpu={c} diff={diff}"
            );
        }
    }

    /// Runs step `t` through the dispatched kernel and through `adam_cpu`
    /// from the same inputs, then compares param, m and v.
    #[allow(clippy::too_many_arguments)] // forwards the full Adam input set
    fn assert_step_matches_cpu(
        param: &[f32],
        grad: &[f32],
        m: &[f32],
        v: &[f32],
        lr: f32,
        beta1: f32,
        beta2: f32,
        eps: f32,
        t: u32,
    ) {
        let (p_gpu, m_gpu, v_gpu) = run_adam_step(param, grad, m, v, lr, beta1, beta2, eps, t);

        let mut p_cpu = param.to_vec();
        let mut m_cpu = m.to_vec();
        let mut v_cpu = v.to_vec();
        adam_cpu(
            &mut p_cpu,
            grad,
            &mut m_cpu,
            &mut v_cpu,
            lr,
            beta1,
            beta2,
            eps,
            1.0 - beta1.powi(t as i32),
            1.0 - beta2.powi(t as i32),
        );

        assert_close_vec(&format!("adam param t={t}"), &p_gpu, &p_cpu, 1e-5, 1e-7);
        assert_close_vec(&format!("adam m t={t}"), &m_gpu, &m_cpu, 1e-5, 1e-7);
        assert_close_vec(&format!("adam v t={t}"), &v_gpu, &v_cpu, 1e-5, 1e-7);
    }

    /// First step (t = 1) from zero moments matches the CPU reference.
    #[test]
    fn adam_step_t1_byte_close_to_cpu() {
        let n = 64;
        let param: Vec<f32> = (0..n).map(|i| (i as f32) * 0.1 - 1.0).collect();
        let grad: Vec<f32> = (0..n).map(|i| ((i as f32) * 0.013).sin() * 0.5).collect();
        let m = vec![0f32; n];
        let v = vec![0f32; n];
        assert_step_matches_cpu(&param, &grad, &m, &v, 1e-3, 0.9, 0.999, 1e-8, 1);
    }

    /// A later step (t = 10) starting from non-zero moments matches the CPU reference.
    #[test]
    fn adam_step_t10_with_nontrivial_state() {
        let n = 32;
        let param: Vec<f32> = (0..n).map(|i| (i as f32) * 0.05).collect();
        let grad: Vec<f32> = (0..n).map(|i| ((i as f32) * 0.011).cos() * 0.3).collect();
        let m: Vec<f32> = (0..n).map(|i| (i as f32) * 0.001).collect();
        let v: Vec<f32> = (0..n).map(|i| (i as f32) * 0.0001 + 0.001).collect();
        assert_step_matches_cpu(&param, &grad, &m, &v, 5e-4, 0.9, 0.999, 1e-8, 10);
    }

    /// With zero gradient and zero moments, the parameters stay put and the
    /// moments stay exactly zero.
    #[test]
    fn adam_zero_grad_leaves_param_unchanged() {
        let n = 16;
        let param: Vec<f32> = (0..n).map(|i| (i as f32) - 8.0).collect();
        let grad = vec![0f32; n];
        let m = vec![0f32; n];
        let v = vec![0f32; n];
        let (p_gpu, m_gpu, v_gpu) =
            run_adam_step(&param, &grad, &m, &v, 1e-3, 0.9, 0.999, 1e-8, 1);
        for (i, (p_in, p_out)) in param.iter().zip(p_gpu.iter()).enumerate() {
            assert!(
                (p_in - p_out).abs() < 1e-9,
                "i={i}: param changed from {p_in} to {p_out}"
            );
        }
        assert!(m_gpu.iter().all(|&x| x == 0.0));
        assert!(v_gpu.iter().all(|&x| x == 0.0));
    }

    /// Minimizes f(x) = (x - 5)^2 from x = 0 with 200 dispatched Adam steps
    /// and expects x to land within 0.05 of 5.
    #[test]
    fn adam_simple_optimization_converges() {
        let device = MlxDevice::new().expect("device");

        // Every piece of optimizer state is zeroed explicitly rather than
        // relying on the allocator handing back zero-filled memory.
        let p_buf = upload_f32(&device, &[0.0], "p");
        let mut g_buf = upload_f32(&device, &[0.0], "g");
        let m_buf = upload_f32(&device, &[0.0], "m");
        let v_buf = upload_f32(&device, &[0.0], "v");
        let mut params_buf = upload_f32(&device, &[0.0; 6], "params");
        let meta_buf = upload_meta(&device, 1);

        let lr = 0.1_f32;
        let beta1 = 0.9_f32;
        let beta2 = 0.999_f32;
        let eps = 1e-8_f32;

        let mut registry = KernelRegistry::new();
        register(&mut registry);

        for step in 1..=200u32 {
            // Gradient of (x - 5)^2 at the current parameter value.
            let x = p_buf.as_slice::<f32>().unwrap()[0];
            let g = 2.0 * (x - 5.0);
            g_buf.as_mut_slice::<f32>().unwrap()[0] = g;
            params_buf.as_mut_slice::<f32>().unwrap().copy_from_slice(&[
                lr,
                beta1,
                beta2,
                eps,
                1.0 - beta1.powi(step as i32),
                1.0 - beta2.powi(step as i32),
            ]);
            let mut encoder = device.command_encoder().expect("encoder");
            dispatch_adam_update_f32(
                &mut encoder,
                &mut registry,
                device.metal_device(),
                &p_buf,
                &g_buf,
                &m_buf,
                &v_buf,
                &params_buf,
                &meta_buf,
            )
            .unwrap();
            encoder.commit_and_wait().unwrap();
        }

        let final_x = p_buf.as_slice::<f32>().unwrap()[0];
        assert!(
            (final_x - 5.0).abs() < 0.05,
            "expected x ≈ 5 after 200 Adam steps; got {final_x}"
        );
    }

    /// A `grad` buffer with a different element count than `param` is rejected
    /// with an error that names the offending buffer.
    #[test]
    fn adam_rejects_mismatched_sizes() {
        let device = MlxDevice::new().expect("device");
        let p = device.alloc_buffer(16, DType::F32, vec![4]).expect("p");
        // Deliberately 8 elements against param's 4.
        let g = device.alloc_buffer(32, DType::F32, vec![8]).expect("g");
        let m = device.alloc_buffer(16, DType::F32, vec![4]).expect("m");
        let v = device.alloc_buffer(16, DType::F32, vec![4]).expect("v");
        let params = device
            .alloc_buffer(24, DType::F32, vec![6])
            .expect("params");
        let meta = device.alloc_buffer(4, DType::F32, vec![1]).expect("meta");
        let mut registry = KernelRegistry::new();
        register(&mut registry);
        let mut encoder = device.command_encoder().expect("encoder");
        let err = dispatch_adam_update_f32(
            &mut encoder,
            &mut registry,
            device.metal_device(),
            &p,
            &g,
            &m,
            &v,
            &params,
            &meta,
        )
        .expect_err("must reject mismatched sizes");
        assert!(err.to_string().contains("grad element count"));
    }
}