use metal::MTLSize;
use crate::buffer::MlxBuffer;
use crate::dtypes::DType;
use crate::encoder::CommandEncoder;
use crate::error::{MlxError, Result};
use crate::kernel_registry::KernelRegistry;
/// Source text of the Adam update shader, embedded into the binary at compile
/// time from `../shaders/adam_update.metal` (resolved relative to this file).
///
/// [`register`] hands this string to the kernel registry under the name
/// `adam_update_f32`.
pub static ADAM_UPDATE_SHADER_SOURCE: &str =
include_str!("../shaders/adam_update.metal");
/// Registers the Adam update shader source with `registry`.
///
/// The source is stored under the kernel name `adam_update_f32`, which is the
/// same name `dispatch_adam_update_f32` uses when it asks the registry for a
/// pipeline.
pub fn register(registry: &mut KernelRegistry) {
    let kernel_name = "adam_update_f32";
    registry.register_source(kernel_name, ADAM_UPDATE_SHADER_SOURCE);
}
/// Validates the buffers for an f32 Adam update and encodes the
/// `adam_update_f32` kernel onto `encoder`.
///
/// This function only calls `encoder.encode`; it does not commit the encoder.
///
/// # Arguments
///
/// * `encoder` - command encoder that receives the dispatch.
/// * `registry` - kernel registry used to look up the `adam_update_f32` pipeline.
/// * `device` - Metal device passed to the registry's pipeline lookup.
/// * `param` - f32 buffer with at least one element; bound at index 0.
/// * `grad`, `m`, `v` - f32 buffers with the same element count as `param`;
///   bound at indices 1, 2 and 3.
/// * `params_buf` - at least 24 bytes (six f32 values); bound at index 4.
///   NOTE(review): the tests in this file fill it as
///   `[lr, beta1, beta2, eps, 1 - beta1^t, 1 - beta2^t]` — confirm against the shader.
/// * `meta_buf` - at least 4 bytes (one u32); bound at index 5.
///   NOTE(review): the tests store the element count here — confirm against the shader.
///
/// # Errors
///
/// Returns `MlxError::InvalidArgument` when `param` is empty, when `grad`, `m`
/// or `v` differ from `param` in element count, when any of the four tensors is
/// not f32, or when `params_buf` / `meta_buf` are too small. Any error from
/// `registry.get_pipeline` is propagated unchanged.
#[allow(clippy::too_many_arguments)]
pub fn dispatch_adam_update_f32(
    encoder: &mut CommandEncoder,
    registry: &mut KernelRegistry,
    device: &metal::DeviceRef,
    param: &MlxBuffer,
    grad: &MlxBuffer,
    m: &MlxBuffer,
    v: &MlxBuffer,
    params_buf: &MlxBuffer,
    meta_buf: &MlxBuffer,
) -> Result<()> {
    let n = param.element_count();
    if n == 0 {
        return Err(MlxError::InvalidArgument(
            "adam_update_f32: param must have at least one element".into(),
        ));
    }

    // Every per-element companion buffer has to mirror `param`: same element
    // count first, then f32 dtype.
    let check_companion = |label: &str, buf: &MlxBuffer| -> Result<()> {
        if buf.element_count() != n {
            return Err(MlxError::InvalidArgument(format!(
                "adam_update_f32: {label} element count {} != param element count {n}",
                buf.element_count(),
            )));
        }
        if buf.dtype() != DType::F32 {
            return Err(MlxError::InvalidArgument(format!(
                "adam_update_f32: {label} dtype {} not f32",
                buf.dtype()
            )));
        }
        Ok(())
    };
    check_companion("grad", grad)?;
    check_companion("m", m)?;
    check_companion("v", v)?;

    if param.dtype() != DType::F32 {
        return Err(MlxError::InvalidArgument(format!(
            "adam_update_f32: param dtype {} not f32",
            param.dtype()
        )));
    }

    // Only minimum byte sizes are enforced for the two auxiliary buffers.
    if params_buf.byte_len() < 24 {
        return Err(MlxError::InvalidArgument(format!(
            "adam_update_f32: params_buf too small (need 24 bytes for 6×f32, got {})",
            params_buf.byte_len()
        )));
    }
    if meta_buf.byte_len() < 4 {
        return Err(MlxError::InvalidArgument(format!(
            "adam_update_f32: meta_buf too small (need 4 bytes for u32, got {})",
            meta_buf.byte_len()
        )));
    }

    let pipeline = registry.get_pipeline("adam_update_f32", device)?;

    // Grid width is the element count; the threadgroup width is capped at 256
    // and never exceeds the grid width.
    let grid_width = n as u64;
    let group_width = grid_width.min(256);
    let grid = MTLSize::new(grid_width, 1, 1);
    let threadgroup = MTLSize::new(group_width, 1, 1);

    encoder.encode(
        pipeline,
        &[
            (0, param),
            (1, grad),
            (2, m),
            (3, v),
            (4, params_buf),
            (5, meta_buf),
        ],
        grid,
        threadgroup,
    );
    Ok(())
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::device::MlxDevice;

    /// CPU reference for a single Adam step; updates `param`, `m` and `v` in place.
    ///
    /// `omb1_t` and `omb2_t` are the bias-correction denominators, which the
    /// callers in this module compute as `1 - beta1^t` and `1 - beta2^t`.
    ///
    /// # Panics
    ///
    /// Panics if the four slices do not all have the same length.
    #[allow(clippy::too_many_arguments)]
    fn adam_cpu(
        param: &mut [f32],
        grad: &[f32],
        m: &mut [f32],
        v: &mut [f32],
        lr: f32,
        beta1: f32,
        beta2: f32,
        eps: f32,
        omb1_t: f32,
        omb2_t: f32,
    ) {
        // `zip` would silently truncate on a length mismatch; fail loudly instead.
        assert_eq!(grad.len(), param.len(), "adam_cpu: grad length mismatch");
        assert_eq!(m.len(), param.len(), "adam_cpu: m length mismatch");
        assert_eq!(v.len(), param.len(), "adam_cpu: v length mismatch");

        let moments = m.iter_mut().zip(v.iter_mut());
        for ((p, &g), (m_i, v_i)) in param.iter_mut().zip(grad).zip(moments) {
            let m_new = beta1 * *m_i + (1.0 - beta1) * g;
            let v_new = beta2 * *v_i + (1.0 - beta2) * g * g;
            *m_i = m_new;
            *v_i = v_new;
            let m_hat = m_new / omb1_t;
            let v_hat = v_new / omb2_t;
            *p -= lr * m_hat / (v_hat.sqrt() + eps);
        }
    }

    /// Runs one GPU Adam step at timestep `t` and returns the resulting
    /// `(param, m, v)` contents read back from the buffers.
    ///
    /// The inputs are copied into freshly allocated buffers, so the caller's
    /// slices are left untouched.
    #[allow(clippy::too_many_arguments)]
    fn run_adam_step(
        param: &[f32],
        grad: &[f32],
        m: &[f32],
        v: &[f32],
        lr: f32,
        beta1: f32,
        beta2: f32,
        eps: f32,
        t: u32,
    ) -> (Vec<f32>, Vec<f32>, Vec<f32>) {
        let device = MlxDevice::new().expect("device");
        let n = param.len();

        // Allocates an n-element f32 buffer and fills it from `data`.
        // `copy_from_slice` panics if `data` does not have exactly n elements.
        let upload = |data: &[f32], what: &str| {
            let mut buf = device
                .alloc_buffer(n * std::mem::size_of::<f32>(), DType::F32, vec![n])
                .expect(what);
            buf.as_mut_slice::<f32>().unwrap().copy_from_slice(data);
            buf
        };
        let p_buf = upload(param, "alloc param");
        let g_buf = upload(grad, "alloc grad");
        let m_buf = upload(m, "alloc m");
        let v_buf = upload(v, "alloc v");

        // Bias-correction denominators for timestep `t`.
        let omb1_t = 1.0 - beta1.powi(t as i32);
        let omb2_t = 1.0 - beta2.powi(t as i32);
        let mut params_buf = device
            .alloc_buffer(24, DType::F32, vec![6])
            .expect("alloc params");
        params_buf
            .as_mut_slice::<f32>()
            .unwrap()
            .copy_from_slice(&[lr, beta1, beta2, eps, omb1_t, omb2_t]);

        // Tagged as F32 at allocation but written through a u32 view; the
        // dispatcher only validates this buffer's byte length.
        let mut meta_buf = device
            .alloc_buffer(4, DType::F32, vec![1])
            .expect("alloc meta");
        meta_buf.as_mut_slice::<u32>().unwrap()[0] = n as u32;

        let mut registry = KernelRegistry::new();
        register(&mut registry);
        let mut encoder = device.command_encoder().expect("encoder");
        dispatch_adam_update_f32(
            &mut encoder,
            &mut registry,
            device.metal_device(),
            &p_buf,
            &g_buf,
            &m_buf,
            &v_buf,
            &params_buf,
            &meta_buf,
        )
        .expect("dispatch adam");
        encoder.commit_and_wait().expect("commit");

        (
            p_buf.as_slice::<f32>().unwrap().to_vec(),
            m_buf.as_slice::<f32>().unwrap().to_vec(),
            v_buf.as_slice::<f32>().unwrap().to_vec(),
        )
    }

    /// Asserts element-wise closeness: each pair passes if the absolute
    /// difference is within `abs_tol`, or if the difference divided by
    /// `max(|gpu|, |cpu|, 1)` is within `rel_tol`.
    fn assert_close_vec(label: &str, gpu: &[f32], cpu: &[f32], rel_tol: f32, abs_tol: f32) {
        assert_eq!(gpu.len(), cpu.len(), "{label}: length mismatch");
        for (i, (g, c)) in gpu.iter().zip(cpu.iter()).enumerate() {
            let diff = (g - c).abs();
            let scale = g.abs().max(c.abs()).max(1.0);
            assert!(
                diff <= abs_tol || diff / scale <= rel_tol,
                "{label}: i={i}: gpu={g} cpu={c} diff={diff}"
            );
        }
    }

    /// First step (t = 1) from zero moments matches the CPU reference.
    #[test]
    fn adam_step_t1_byte_close_to_cpu() {
        let n = 64;
        let param: Vec<f32> = (0..n).map(|i| (i as f32) * 0.1 - 1.0).collect();
        let grad: Vec<f32> = (0..n).map(|i| ((i as f32) * 0.013).sin() * 0.5).collect();
        let m = vec![0f32; n];
        let v = vec![0f32; n];
        let lr = 1e-3_f32;
        let beta1 = 0.9_f32;
        let beta2 = 0.999_f32;
        let eps = 1e-8_f32;

        let (p_gpu, m_gpu, v_gpu) =
            run_adam_step(&param, &grad, &m, &v, lr, beta1, beta2, eps, 1);

        let mut p_cpu = param.clone();
        let mut m_cpu = m.clone();
        let mut v_cpu = v.clone();
        adam_cpu(
            &mut p_cpu,
            &grad,
            &mut m_cpu,
            &mut v_cpu,
            lr,
            beta1,
            beta2,
            eps,
            1.0 - beta1.powi(1),
            1.0 - beta2.powi(1),
        );

        assert_close_vec("adam param t=1", &p_gpu, &p_cpu, 1e-5, 1e-7);
        assert_close_vec("adam m t=1", &m_gpu, &m_cpu, 1e-5, 1e-7);
        assert_close_vec("adam v t=1", &v_gpu, &v_cpu, 1e-5, 1e-7);
    }

    /// A later step (t = 10) starting from non-zero moments matches the CPU reference.
    #[test]
    fn adam_step_t10_with_nontrivial_state() {
        let n = 32;
        let param: Vec<f32> = (0..n).map(|i| (i as f32) * 0.05).collect();
        let grad: Vec<f32> = (0..n).map(|i| ((i as f32) * 0.011).cos() * 0.3).collect();
        let m: Vec<f32> = (0..n).map(|i| (i as f32) * 0.001).collect();
        let v: Vec<f32> = (0..n).map(|i| (i as f32) * 0.0001 + 0.001).collect();
        let lr = 5e-4_f32;
        let beta1 = 0.9_f32;
        let beta2 = 0.999_f32;
        let eps = 1e-8_f32;

        let (p_gpu, m_gpu, v_gpu) =
            run_adam_step(&param, &grad, &m, &v, lr, beta1, beta2, eps, 10);

        let mut p_cpu = param.clone();
        let mut m_cpu = m.clone();
        let mut v_cpu = v.clone();
        adam_cpu(
            &mut p_cpu,
            &grad,
            &mut m_cpu,
            &mut v_cpu,
            lr,
            beta1,
            beta2,
            eps,
            1.0 - beta1.powi(10),
            1.0 - beta2.powi(10),
        );

        assert_close_vec("adam param t=10", &p_gpu, &p_cpu, 1e-5, 1e-7);
        assert_close_vec("adam m t=10", &m_gpu, &m_cpu, 1e-5, 1e-7);
        assert_close_vec("adam v t=10", &v_gpu, &v_cpu, 1e-5, 1e-7);
    }

    /// With a zero gradient and zero moments, the parameters and both moments
    /// must come back unchanged.
    #[test]
    fn adam_zero_grad_leaves_param_unchanged() {
        let n = 16;
        let param: Vec<f32> = (0..n).map(|i| (i as f32) - 8.0).collect();
        let grad = vec![0f32; n];
        let m = vec![0f32; n];
        let v = vec![0f32; n];

        let (p_gpu, m_gpu, v_gpu) =
            run_adam_step(&param, &grad, &m, &v, 1e-3, 0.9, 0.999, 1e-8, 1);

        for (i, (p_in, p_out)) in param.iter().zip(p_gpu.iter()).enumerate() {
            assert!(
                (p_in - p_out).abs() < 1e-9,
                "i={i}: param changed from {p_in} to {p_out}"
            );
        }
        assert!(m_gpu.iter().all(|&x| x == 0.0));
        assert!(v_gpu.iter().all(|&x| x == 0.0));
    }

    /// Minimizes `(x - 5)^2` from `x = 0` with 200 GPU Adam steps and checks
    /// that `x` ends within 0.05 of 5.
    #[test]
    fn adam_simple_optimization_converges() {
        let device = MlxDevice::new().expect("device");

        let mut p_buf = device.alloc_buffer(4, DType::F32, vec![1]).expect("p");
        p_buf.as_mut_slice::<f32>().unwrap()[0] = 0.0;
        let mut g_buf = device.alloc_buffer(4, DType::F32, vec![1]).expect("g");

        // Zero the moment buffers explicitly so the test does not depend on the
        // allocator handing back zero-filled memory.
        let mut m_buf = device.alloc_buffer(4, DType::F32, vec![1]).expect("m");
        m_buf.as_mut_slice::<f32>().unwrap()[0] = 0.0;
        let mut v_buf = device.alloc_buffer(4, DType::F32, vec![1]).expect("v");
        v_buf.as_mut_slice::<f32>().unwrap()[0] = 0.0;

        let mut params_buf = device
            .alloc_buffer(24, DType::F32, vec![6])
            .expect("params");
        let mut meta_buf = device.alloc_buffer(4, DType::F32, vec![1]).expect("meta");
        meta_buf.as_mut_slice::<u32>().unwrap()[0] = 1u32;

        let lr = 0.1_f32;
        let beta1 = 0.9_f32;
        let beta2 = 0.999_f32;
        let eps = 1e-8_f32;

        let mut registry = KernelRegistry::new();
        register(&mut registry);

        for step in 1..=200u32 {
            // Gradient of (x - 5)^2 at the current parameter value.
            let x = p_buf.as_slice::<f32>().unwrap()[0];
            let g = 2.0 * (x - 5.0);
            g_buf.as_mut_slice::<f32>().unwrap()[0] = g;

            // The bias-correction terms depend on the step, so refresh them each time.
            params_buf.as_mut_slice::<f32>().unwrap().copy_from_slice(&[
                lr,
                beta1,
                beta2,
                eps,
                1.0 - beta1.powi(step as i32),
                1.0 - beta2.powi(step as i32),
            ]);

            let mut encoder = device.command_encoder().expect("encoder");
            dispatch_adam_update_f32(
                &mut encoder,
                &mut registry,
                device.metal_device(),
                &p_buf,
                &g_buf,
                &m_buf,
                &v_buf,
                &params_buf,
                &meta_buf,
            )
            .unwrap();
            encoder.commit_and_wait().unwrap();
        }

        let final_x = p_buf.as_slice::<f32>().unwrap()[0];
        assert!(
            (final_x - 5.0).abs() < 0.05,
            "expected x ≈ 5 after 200 Adam steps; got {final_x}"
        );
    }

    /// A `grad` buffer whose element count differs from `param` must be
    /// rejected before anything is encoded.
    #[test]
    fn adam_rejects_mismatched_sizes() {
        let device = MlxDevice::new().expect("device");
        let p = device.alloc_buffer(16, DType::F32, vec![4]).expect("p");
        // Deliberately 8 elements against param's 4.
        let g = device.alloc_buffer(32, DType::F32, vec![8]).expect("g");
        let m = device.alloc_buffer(16, DType::F32, vec![4]).expect("m");
        let v = device.alloc_buffer(16, DType::F32, vec![4]).expect("v");
        let params = device.alloc_buffer(24, DType::F32, vec![6]).expect("params");
        let meta = device.alloc_buffer(4, DType::F32, vec![1]).expect("meta");

        let mut registry = KernelRegistry::new();
        register(&mut registry);
        let mut encoder = device.command_encoder().expect("encoder");
        let err = dispatch_adam_update_f32(
            &mut encoder,
            &mut registry,
            device.metal_device(),
            &p,
            &g,
            &m,
            &v,
            &params,
            &meta,
        )
        .expect_err("must reject mismatched sizes");
        assert!(format!("{err}").contains("grad element count"));
    }
}