use mlxrs::{
Array,
ops::{
arithmetic::{add, exp, multiply, power, sin, square, tanh},
linalg_basic::{inner, matmul},
reduction::sum,
shape::contiguous,
},
transforms::{Closure, async_eval, checkpoint, custom_vjp, eval, grad, jvp, value_and_grad, vjp},
};
/// Returns `true` when `a` and `b` differ by at most `eps` in absolute terms.
///
/// The bound is inclusive, so a gap exactly equal to `eps` passes. A NaN gap
/// (e.g. either input is NaN) yields `false`, since comparisons with NaN are
/// always false.
fn approx_eq(a: f32, b: f32, eps: f32) -> bool {
    let gap = (a - b).abs();
    gap <= eps
}
/// Building a `Closure` from a Rust closure must succeed and expose a
/// non-null `ctx` pointer through its raw handle.
#[test]
fn closure_construction_succeeds_and_round_trips() {
    let closure = Closure::new(|xs: &[Array]| Ok(vec![square(&xs[0])?])).unwrap();
    let ctx_is_null = closure.as_raw().ctx.is_null();
    assert!(!ctx_is_null, "Closure ctx must be non-null");
    // Drop explicitly so teardown happens inside the test body.
    drop(closure);
}
/// Constructs and drops a `Closure` ten times, then checks that the reported
/// peak memory did not go backwards.
///
/// NOTE(review): the final assertion only verifies that `peak_memory` is
/// non-decreasing across the loop; on its own it does not prove the handle was
/// released. The construct/drop loop mainly guards against crashes on drop.
#[test]
fn closure_drop_releases_ffi_handle() {
    let baseline = mlxrs::memory::peak_memory().unwrap();
    for _ in 0..10 {
        // Propagate failures with `?`, matching every other closure in this
        // file, instead of panicking inside the closure body via `.unwrap()`.
        let cls = Closure::new(|xs: &[Array]| Ok(vec![square(&xs[0])?])).unwrap();
        drop(cls);
    }
    let after = mlxrs::memory::peak_memory().unwrap();
    assert!(after >= baseline, "peak_memory must be monotonic");
}
/// Stress-tests closure payload ownership by building many closures that own
/// a captured value.
///
/// Phase 1 constructs and immediately drops 64 capturing `Closure`s. Phase 2
/// builds 32 `custom_vjp` functions whose backward pass returns the captured
/// value, differentiates them, and checks that value comes back intact.
///
/// NOTE(review): only successful constructions are exercised here; no
/// constructor failure is actually provoked despite the test name.
#[test]
fn closure_constructor_failure_does_not_double_free_payload() {
    // Phase 1: capture, construct, drop.
    for step in 0..64 {
        let factor = step as f32;
        let closure = Closure::new(move |xs: &[Array]| {
            let squared = square(&xs[0])?;
            let factor_arr = Array::full::<f32>(&[0i32; 0], factor)?;
            Ok(vec![multiply(&squared, &factor_arr)?])
        })
        .unwrap();
        assert!(!closure.as_raw().ctx.is_null());
        drop(closure);
    }

    // Phase 2: the captured value must survive into the custom backward pass.
    for step in 0..32 {
        let captured = step as f32;
        let wrapped = custom_vjp(
            move |xs| Ok(vec![square(&xs[0])?]),
            move |primals, _cot, _outputs| {
                let shape = primals[0].shape();
                Ok(vec![Array::full::<f32>(&&shape[..], captured)?])
            },
        )
        .unwrap();
        let grad_fn = grad(wrapped, &[0]).unwrap();
        let input = Array::full::<f32>(&[0i32; 0], 2.0).unwrap();
        let mut grads = grad_fn(&[input]).unwrap();
        let got = grads[0].item::<f32>().unwrap();
        assert!(approx_eq(got, captured, 1e-5));
    }
}
/// A transformed function must stay callable after the scope that built it
/// (and declared the captured local) has ended.
///
/// The function is `2 * x^2`; at `x = 3` the value is 18 and the gradient 12.
#[test]
fn closure_outlives_construction_scope() {
    let value_and_grad_fn = {
        // `factor` is moved into the closure, so it must outlive this block.
        let factor = 2.0_f32;
        value_and_grad(
            move |xs| {
                let squared = square(&xs[0])?;
                let factor_arr = Array::full::<f32>(&[0i32; 0], factor)?;
                Ok(vec![multiply(&squared, &factor_arr)?])
            },
            &[0],
        )
        .unwrap()
    };

    let input = Array::full::<f32>(&[0i32; 0], 3.0).unwrap();
    let (mut values, mut grads) = value_and_grad_fn(&[input]).unwrap();
    let value = values[0].item::<f32>().unwrap();
    assert!(approx_eq(value, 18.0, 1e-5));
    let slope = grads[0].item::<f32>().unwrap();
    assert!(approx_eq(slope, 12.0, 1e-5));
}
/// An empty `argnums` slice must be rejected with `Error::EmptyInput`, both by
/// `value_and_grad` directly and by `grad`, and the error context must mention
/// `argnums`.
#[test]
fn value_and_grad_rejects_empty_argnums() {
    // Direct `value_and_grad` call.
    let outcome = value_and_grad(|xs| Ok(vec![square(&xs[0])?]), &[]);
    match outcome.err().expect("empty argnums must be rejected") {
        mlxrs::Error::EmptyInput(payload) => {
            let ctx = payload.context();
            assert!(
                ctx.contains("value_and_grad") && ctx.contains("argnums"),
                "context names the argnums site: {}",
                ctx
            );
        }
        other => panic!("expected Error::EmptyInput for empty argnums, got {other:?}"),
    }

    // Same rejection when going through `grad`.
    let outcome = grad(|xs| Ok(vec![square(&xs[0])?]), &[]);
    match outcome
        .err()
        .expect("empty argnums must be rejected by grad")
    {
        mlxrs::Error::EmptyInput(payload) => {
            let ctx = payload.context();
            assert!(
                ctx.contains("argnums"),
                "grad-delegated context still names argnums: {}",
                ctx
            );
        }
        other => panic!("expected Error::EmptyInput from grad delegation, got {other:?}"),
    }
}
/// `value_and_grad` of `x^2` at `x = 3` yields value 9 and gradient 6.
#[test]
fn value_and_grad_simple_quadratic() {
    let value_and_grad_fn = value_and_grad(|xs| Ok(vec![square(&xs[0])?]), &[0]).unwrap();
    let input = Array::full::<f32>(&[0i32; 0], 3.0).unwrap();
    let (mut values, mut grads) = value_and_grad_fn(&[input]).unwrap();
    let value = values[0].item::<f32>().unwrap();
    assert!(approx_eq(value, 9.0, 1e-5));
    let slope = grads[0].item::<f32>().unwrap();
    assert!(approx_eq(slope, 6.0, 1e-5));
}
/// Differentiates `x^2 + y^3` with respect to both arguments.
///
/// At `(x, y) = (2, 1)` the partials are `2x = 4` and `3y^2 = 3`.
#[test]
fn value_and_grad_multivariate() {
    let value_and_grad_fn = value_and_grad(
        |xs| {
            let x_squared = square(&xs[0])?;
            let exponent = Array::full::<f32>(&[0i32; 0], 3.0)?;
            let y_cubed = power(&xs[1], &exponent)?;
            Ok(vec![add(&x_squared, &y_cubed)?])
        },
        &[0, 1],
    )
    .unwrap();

    let x_in = Array::full::<f32>(&[0i32; 0], 2.0).unwrap();
    let y_in = Array::full::<f32>(&[0i32; 0], 1.0).unwrap();
    let (_values, mut grads) = value_and_grad_fn(&[x_in, y_in]).unwrap();

    assert_eq!(grads.len(), 2);
    let d_dx = grads[0].item::<f32>().unwrap();
    assert!(approx_eq(d_dx, 4.0, 1e-5));
    let d_dy = grads[1].item::<f32>().unwrap();
    assert!(approx_eq(d_dy, 3.0, 1e-5));
}
/// Applying `grad` to the output of `grad` gives the second derivative.
///
/// For `x^3` the second derivative is `6x`, i.e. 12 at `x = 2`.
#[test]
fn grad_composition_yields_second_derivative() {
    let first_derivative = grad(
        |xs| {
            let exponent = Array::full::<f32>(&[0i32; 0], 3.0)?;
            Ok(vec![power(&xs[0], &exponent)?])
        },
        &[0],
    )
    .unwrap();
    let second_derivative = grad(move |xs| first_derivative(xs), &[0]).unwrap();

    let input = Array::full::<f32>(&[0i32; 0], 2.0).unwrap();
    let mut grads = second_derivative(&[input]).unwrap();
    let curvature = grads[0].item::<f32>().unwrap();
    assert!(approx_eq(curvature, 12.0, 1e-4));
}
/// An `Err` returned by the user closure must surface from the gradient call
/// with its own payload, not the generic "closure returned non-zero" wrapper
/// message.
#[test]
fn closure_user_error_propagates_through_grad() {
    use mlxrs::Error;

    let failing_grad = grad(
        |_xs: &[Array]| -> mlxrs::Result<Vec<Array>> {
            Err(Error::Backend("USER_ERROR_PAYLOAD".into()))
        },
        &[0],
    )
    .unwrap();

    let input = Array::full::<f32>(&[0i32; 0], 3.0).unwrap();
    let msg = failing_grad(&[input])
        .expect_err("user error must surface")
        .to_string();

    assert!(
        msg.contains("USER_ERROR_PAYLOAD"),
        "expected user error payload to surface; got: {msg}"
    );
    assert!(
        !msg.contains("mlx_closure returned a non-zero value"),
        "must NOT surface mlx-c's generic closure-non-zero wrapper; got: {msg}"
    );
}
/// A panic inside the user closure must come back from the gradient call as an
/// `Err` whose message carries the panic payload and mentions "panic", rather
/// than the generic "closure returned non-zero" wrapper message.
#[test]
fn closure_user_panic_propagates_through_grad_as_error() {
    let panicking_grad = grad(
        |_xs: &[Array]| -> mlxrs::Result<Vec<Array>> { panic!("USER_PANIC_PAYLOAD") },
        &[0],
    )
    .unwrap();

    let input = Array::full::<f32>(&[0i32; 0], 3.0).unwrap();
    let msg = panicking_grad(&[input])
        .expect_err("user panic must surface as Err")
        .to_string();

    assert!(
        msg.contains("USER_PANIC_PAYLOAD"),
        "expected user panic payload to surface; got: {msg}"
    );
    assert!(
        msg.contains("panic"),
        "expected indication that the closure panicked; got: {msg}"
    );
    assert!(
        !msg.contains("mlx_closure returned a non-zero value"),
        "must NOT surface mlx-c's generic closure-non-zero wrapper; got: {msg}"
    );
}
/// With a unit cotangent, `vjp` of `x^2` at `x = 3` gives value 9 and the
/// ordinary gradient 6.
#[test]
fn vjp_matches_grad_for_scalar_output() {
    let point = vec![Array::full::<f32>(&[0i32; 0], 3.0).unwrap()];
    let unit_cotangent = vec![Array::full::<f32>(&[0i32; 0], 1.0).unwrap()];
    let (mut values, mut pullback) =
        vjp(|xs| Ok(vec![square(&xs[0])?]), &point, &unit_cotangent).unwrap();
    let value = values[0].item::<f32>().unwrap();
    assert!(approx_eq(value, 9.0, 1e-5));
    let slope = pullback[0].item::<f32>().unwrap();
    assert!(approx_eq(slope, 6.0, 1e-5));
}
/// With a unit tangent, `jvp` of `x^2` at `x = 3` gives value 9 and
/// directional derivative 6.
#[test]
fn jvp_matches_directional_derivative() {
    let point = vec![Array::full::<f32>(&[0i32; 0], 3.0).unwrap()];
    let unit_tangent = vec![Array::full::<f32>(&[0i32; 0], 1.0).unwrap()];
    let (mut values, mut pushforward) =
        jvp(|xs| Ok(vec![square(&xs[0])?]), &point, &unit_tangent).unwrap();
    let value = values[0].item::<f32>().unwrap();
    assert!(approx_eq(value, 9.0, 1e-5));
    let directional = pushforward[0].item::<f32>().unwrap();
    assert!(approx_eq(directional, 6.0, 1e-5));
}
/// A `custom_vjp` backward rule replaces the automatically derived gradient.
///
/// The forward pass is `x^2`, but the custom rule always returns 42, so the
/// gradient at `x = 3` must be 42 rather than 6.
#[test]
fn custom_vjp_overrides_autograd() {
    let wrapped = custom_vjp(
        |xs| Ok(vec![square(&xs[0])?]),
        |primals, _cot, _outputs| {
            // Constant gradient shaped like the primal.
            let shape = primals[0].shape();
            Ok(vec![Array::full::<f32>(&&shape[..], 42.0)?])
        },
    )
    .unwrap();
    let grad_fn = grad(wrapped, &[0]).unwrap();

    let input = Array::full::<f32>(&[0i32; 0], 3.0).unwrap();
    let mut grads = grad_fn(&[input]).unwrap();
    let overridden = grads[0].item::<f32>().unwrap();
    assert!(approx_eq(overridden, 42.0, 1e-5));
}
/// Pins the `(primals, cotangents, outputs)` argument order of the custom
/// backward callback.
///
/// The backward rule computes `10 * cotangent + output + 1000 * primal`. With
/// primal 3 (so output 9) and cotangent 2 this is `20 + 9 + 3000 = 3029`. If
/// the cotangent and output slots were swapped the result would instead be
/// `90 + 2 + 3000 = 3092`.
#[test]
fn custom_vjp_trampoline_argument_order_regression() {
    let wrapped = custom_vjp(
        |xs| Ok(vec![square(&xs[0])?]),
        |primals, cotangents, outputs| {
            // Distinct weights per slot make any slot mix-up change the sum.
            let cot_weight = Array::full::<f32>(&[0i32; 0], 10.0)?;
            let primal_weight = Array::full::<f32>(&[0i32; 0], 1000.0)?;
            let weighted_cot = multiply(&cotangents[0], &cot_weight)?;
            let weighted_primal = multiply(&primals[0], &primal_weight)?;
            let partial = add(&weighted_cot, &outputs[0])?;
            Ok(vec![add(&partial, &weighted_primal)?])
        },
    )
    .unwrap();

    let primal_in = Array::full::<f32>(&[0i32; 0], 3.0).unwrap();
    let cotangent_in = Array::full::<f32>(&[0i32; 0], 2.0).unwrap();
    let (_values, mut grads) = vjp(wrapped, &[primal_in], &[cotangent_in]).unwrap();
    assert_eq!(grads.len(), 1);

    let got = grads[0].item::<f32>().unwrap();
    let expected = 3029.0_f32;
    let swapped_value = 3092.0_f32;
    assert!(
        approx_eq(got, expected, 1e-3),
        "trampoline arg order regression: got {got}, expected {expected} \
         (a value near {swapped_value} would indicate a \
         `(primals, outputs, cotangents)` slot ordering has been reintroduced)"
    );
}
/// The forward value of a `checkpoint`-wrapped `x^2` must equal the value of
/// calling `square` directly on the same input.
#[test]
fn checkpoint_returns_same_value_as_uncheckpointed() {
    let input = Array::full::<f32>(&[0i32; 0], 3.0).unwrap();

    // Reference value, computed before `input` is moved into the wrapped call.
    let mut plain = square(&input).unwrap();
    let plain_value = plain.item::<f32>().unwrap();

    let checkpointed = checkpoint(|xs| Ok(vec![square(&xs[0])?])).unwrap();
    let mut outputs = checkpointed(&[input]).unwrap();
    let checkpointed_value = outputs[0].item::<f32>().unwrap();

    assert!(approx_eq(plain_value, checkpointed_value, 1e-6));
}
/// The gradient of a `checkpoint`-wrapped `x^2` must match the gradient of the
/// plain function at the same point (`x = 4`).
#[test]
fn checkpoint_gradient_matches_uncheckpointed() {
    let plain_grad = grad(|xs| Ok(vec![square(&xs[0])?]), &[0]).unwrap();
    let checkpointed = checkpoint(|xs| Ok(vec![square(&xs[0])?])).unwrap();
    let checkpointed_grad = grad(checkpointed, &[0]).unwrap();

    let input = Array::full::<f32>(&[0i32; 0], 4.0).unwrap();
    // The first call gets a clone so the original can be moved into the second.
    let input_copy = input.try_clone().unwrap();
    let mut plain_out = plain_grad(&[input_copy]).unwrap();
    let mut checkpointed_out = checkpointed_grad(&[input]).unwrap();

    let plain_slope = plain_out[0].item::<f32>().unwrap();
    let checkpointed_slope = checkpointed_out[0].item::<f32>().unwrap();
    assert!(approx_eq(plain_slope, checkpointed_slope, 1e-5));
}
/// Both `eval` and `async_eval` must accept an empty slice and return `Ok`.
#[test]
fn eval_empty_slice_is_noop() {
    let sync_result = eval(&[]);
    sync_result.unwrap();
    let async_result = async_eval(&[]);
    async_result.unwrap();
}
/// A single `eval` call over several arrays must leave each of them readable
/// with the expected contents.
///
/// Uses 2x2 constants: `1 + 2 = 3` and `2 * 3 = 6` everywhere.
#[test]
fn eval_materializes_all_arrays() {
    let shape = (2usize, 2usize);
    let ones = Array::full::<f32>(&shape, 1.0).unwrap();
    let twos = Array::full::<f32>(&shape, 2.0).unwrap();
    let threes = Array::full::<f32>(&shape, 3.0).unwrap();

    let mut total = add(&ones, &twos).unwrap();
    let mut product = multiply(&twos, &threes).unwrap();
    eval(&[&total, &product]).unwrap();

    let total_values = total.to_vec::<f32>().unwrap();
    let product_values = product.to_vec::<f32>().unwrap();
    let total_ok = total_values.iter().all(|&v| approx_eq(v, 3.0, 1e-6));
    assert!(total_ok);
    let product_ok = product_values.iter().all(|&v| approx_eq(v, 6.0, 1e-6));
    assert!(product_ok);
}
/// The gradient of `sum(x)` over a 4-element vector is 1 in every position.
#[test]
fn grad_of_sum_is_ones_vector() {
    let grad_fn = grad(|xs| Ok(vec![sum(&xs[0], false)?]), &[0]).unwrap();
    let input = Array::from_slice::<f32>(&[1.0, 2.0, 3.0, 4.0], &[4]).unwrap();
    let grads = grad_fn(&[input]).unwrap();
    assert_eq!(grads.len(), 1);

    // Pass the gradient through `contiguous` before reading it out.
    let mut dense = contiguous(&grads[0], false).unwrap();
    let gv = dense.to_vec::<f32>().unwrap();
    assert_eq!(gv.len(), 4, "grad must keep the input's element count");
    assert!(
        gv.iter().all(|&v| approx_eq(v, 1.0, 1e-5)),
        "d/dx_i[Σx] must be 1 everywhere; got {gv:?}"
    );
}
/// For `f(x) = inner(x, c)` the value is the dot product and the gradient is
/// `c` itself.
///
/// With `x = [1, 2, 3]` and `c = [4, 5, 6]`: value 32, gradient `[4, 5, 6]`.
#[test]
fn value_and_grad_of_dot_returns_value_and_constant_grad() {
    let weights = Array::from_slice::<f32>(&[4.0, 5.0, 6.0], &[3]).unwrap();
    let value_and_grad_fn =
        value_and_grad(move |xs| Ok(vec![inner(&xs[0], &weights)?]), &[0]).unwrap();

    let input = Array::from_slice::<f32>(&[1.0, 2.0, 3.0], &[3]).unwrap();
    let (mut values, mut grads) = value_and_grad_fn(&[input]).unwrap();
    assert!(
        approx_eq(values[0].item::<f32>().unwrap(), 32.0, 1e-4),
        "value must equal the dot product x·c = 32"
    );

    let gv = grads[0].to_vec::<f32>().unwrap();
    assert_eq!(gv.len(), 3);
    // Length was just checked, so zipping covers all three components.
    let matches_weights = gv
        .iter()
        .zip([4.0_f32, 5.0, 6.0])
        .all(|(&got, want)| approx_eq(got, want, 1e-5));
    assert!(
        matches_weights,
        "d/dx[x·c] must equal c = [4,5,6]; got {gv:?}"
    );
}
/// Differentiates `sum(X @ W)` with respect to `X`.
///
/// With `X = ones(2, 2)` and `W = ones(2, 3)` the value is 12 and every entry
/// of the gradient is 3 (the number of columns of `W`).
#[test]
fn grad_through_matmul_sum_is_n_times_ones() {
    let weights = Array::full::<f32>(&(2usize, 3usize), 1.0).unwrap();
    let value_and_grad_fn = value_and_grad(
        move |xs| {
            let projected = matmul(&xs[0], &weights)?;
            Ok(vec![sum(&projected, false)?])
        },
        &[0],
    )
    .unwrap();

    let input = Array::full::<f32>(&(2usize, 2usize), 1.0).unwrap();
    let (mut values, mut grads) = value_and_grad_fn(&[input]).unwrap();
    assert!(
        approx_eq(values[0].item::<f32>().unwrap(), 12.0, 1e-4),
        "Σ(ones(2,2) @ ones(2,3)) = 12"
    );

    let gv = grads[0].to_vec::<f32>().unwrap();
    assert_eq!(gv.len(), 4, "grad_X must be 2x2 = 4 elements");
    assert!(
        gv.iter().all(|&v| approx_eq(v, 3.0, 1e-5)),
        "∂Σ(XW)/∂X = n·ones with n=3; got {gv:?}"
    );
}
/// `argnums = [1]` must differentiate only the second argument.
///
/// For `x^2 + y^2` at `(2, 5)`: value 29, and the single returned gradient is
/// `2y = 10` (4 would mean the first argument was differentiated instead).
#[test]
fn value_and_grad_argnums_selects_second_arg_only() {
    let value_and_grad_fn = value_and_grad(
        |xs| {
            let first_sq = square(&xs[0])?;
            let second_sq = square(&xs[1])?;
            Ok(vec![add(&first_sq, &second_sq)?])
        },
        &[1],
    )
    .unwrap();

    let x_in = Array::full::<f32>(&[0i32; 0], 2.0).unwrap();
    let y_in = Array::full::<f32>(&[0i32; 0], 5.0).unwrap();
    let (mut values, mut grads) = value_and_grad_fn(&[x_in, y_in]).unwrap();

    let value = values[0].item::<f32>().unwrap();
    assert!(approx_eq(value, 29.0, 1e-4));
    assert_eq!(
        grads.len(),
        1,
        "argnums=[1] selects exactly one grad target"
    );
    assert!(
        approx_eq(grads[0].item::<f32>().unwrap(), 10.0, 1e-4),
        "grad must be ∂f/∂y = 2y = 10 (a value of 4 would mean arg 0 was differentiated)"
    );
}
/// The `vjp` result scales linearly with the cotangent.
///
/// For `x^2` at `x = 3` with cotangent 2: value 9, pullback `2 * 6 = 12`.
#[test]
fn vjp_scales_by_nonunit_cotangent() {
    let point = vec![Array::full::<f32>(&[0i32; 0], 3.0).unwrap()];
    let cotangent = vec![Array::full::<f32>(&[0i32; 0], 2.0).unwrap()];
    let (mut values, mut pullback) =
        vjp(|xs| Ok(vec![square(&xs[0])?]), &point, &cotangent).unwrap();

    let value = values[0].item::<f32>().unwrap();
    assert!(approx_eq(value, 9.0, 1e-5));
    assert!(
        approx_eq(pullback[0].item::<f32>().unwrap(), 12.0, 1e-4),
        "vjp = cotangent · 2x = 2 · 6 = 12"
    );
}
/// For `sum(x)` over a 3-vector with scalar cotangent 2, the pullback is 2 in
/// every position and the forward value is 6.
#[test]
fn vjp_of_vector_sum_broadcasts_cotangent() {
    let point = vec![Array::from_slice::<f32>(&[1.0, 2.0, 3.0], &[3]).unwrap()];
    let cotangent = vec![Array::full::<f32>(&[0i32; 0], 2.0).unwrap()];
    let (mut values, pullback) =
        vjp(|xs| Ok(vec![sum(&xs[0], false)?]), &point, &cotangent).unwrap();

    let value = values[0].item::<f32>().unwrap();
    assert!(approx_eq(value, 6.0, 1e-5));

    // Pass the pullback through `contiguous` before reading it out.
    let mut dense = contiguous(&pullback[0], false).unwrap();
    let gv = dense.to_vec::<f32>().unwrap();
    assert_eq!(gv.len(), 3, "vjp output matches the primal's shape");
    assert!(
        gv.iter().all(|&v| approx_eq(v, 2.0, 1e-5)),
        "vjp_i = cotangent · 1 = 2 everywhere; got {gv:?}"
    );
}
/// For `sum(x)` the `jvp` is the sum of the tangent components.
///
/// With tangent `[2, 3, 4]` that is 9; the forward value is `1 + 2 + 3 = 6`.
#[test]
fn jvp_of_vector_sum_contracts_tangent() {
    let point = vec![Array::from_slice::<f32>(&[1.0, 2.0, 3.0], &[3]).unwrap()];
    let tangent = vec![Array::from_slice::<f32>(&[2.0, 3.0, 4.0], &[3]).unwrap()];
    let (mut values, mut pushforward) =
        jvp(|xs| Ok(vec![sum(&xs[0], false)?]), &point, &tangent).unwrap();

    let value = values[0].item::<f32>().unwrap();
    assert!(approx_eq(value, 6.0, 1e-5));
    assert!(
        approx_eq(pushforward[0].item::<f32>().unwrap(), 9.0, 1e-4),
        "jvp = Σ(1 · v_i) = 2+3+4 = 9"
    );
}
/// With two primals, `jvp` adds the contribution of each tangent.
///
/// For `x * y` at `(2, 3)` with unit tangents: value 6, jvp `y + x = 5`.
#[test]
fn jvp_multi_primal_sums_directional_contributions() {
    // Local shorthand for building a scalar `f32` array.
    let scalar = |v: f32| Array::full::<f32>(&[0i32; 0], v).unwrap();

    let point = vec![scalar(2.0), scalar(3.0)];
    let tangent = vec![scalar(1.0), scalar(1.0)];
    let (mut values, mut pushforward) =
        jvp(|xs| Ok(vec![multiply(&xs[0], &xs[1])?]), &point, &tangent).unwrap();

    let value = values[0].item::<f32>().unwrap();
    assert!(approx_eq(value, 6.0, 1e-5));
    assert!(
        approx_eq(pushforward[0].item::<f32>().unwrap(), 5.0, 1e-4),
        "jvp = y·dx + x·dy = 3 + 2 = 5"
    );
}
/// Numerically estimates `f'(x)` with the symmetric difference quotient
/// `(f(x + h) - f(x - h)) / (2h)`.
///
/// `f` is evaluated on scalar `f32` arrays. Any error from building the
/// inputs, calling `f`, or reading the results is returned to the caller.
fn central_difference<F>(f: &F, x: f32, h: f32) -> mlxrs::Result<f32>
where
    F: Fn(&Array) -> mlxrs::Result<Array>,
{
    let x_plus = Array::full::<f32>(&[0i32; 0], x + h)?;
    let mut f_plus = f(&x_plus)?;
    let x_minus = Array::full::<f32>(&[0i32; 0], x - h)?;
    let mut f_minus = f(&x_minus)?;
    let rise = f_plus.item::<f32>()? - f_minus.item::<f32>()?;
    Ok(rise / (2.0 * h))
}
/// Evaluates the autodiff gradient of the single-input function `f` at the
/// scalar point `x`.
///
/// `f` is wrapped into the slice-in / vec-out shape that `grad` takes, and
/// differentiated with respect to its only argument.
fn analytic_grad<F>(f: F, x: f32) -> mlxrs::Result<f32>
where
    F: Fn(&Array) -> mlxrs::Result<Array> + 'static,
{
    let grad_fn = grad(move |xs| f(&xs[0]).map(|y| vec![y]), &[0])?;
    let point = Array::full::<f32>(&[0i32; 0], x)?;
    let mut grads = grad_fn(&[point])?;
    grads[0].item::<f32>()
}
/// The autodiff gradient of `sin` must agree with a central-difference
/// estimate (step `1e-3`, tolerance `2e-3`) at several sample points.
#[test]
fn finite_diff_matches_grad_sin() {
    let step = 1e-3_f32;
    let sample_points = [-1.5_f32, -0.3, 0.0, 0.7, 2.0];
    for x in sample_points {
        let analytic = analytic_grad(sin, x).unwrap();
        let numeric = central_difference(&(|a: &Array| sin(a)), x, step).unwrap();
        assert!(
            approx_eq(analytic, numeric, 2e-3),
            "d/dx sin at x={x}: analytic grad {analytic} vs central-diff {numeric}"
        );
    }
}
/// The autodiff gradient of `exp` must agree with a central-difference
/// estimate (step `1e-3`, tolerance `5e-3`) at several sample points.
#[test]
fn finite_diff_matches_grad_exp() {
    let step = 1e-3_f32;
    let sample_points = [-1.0_f32, -0.2, 0.5, 1.2];
    for x in sample_points {
        let analytic = analytic_grad(exp, x).unwrap();
        let numeric = central_difference(&(|a: &Array| exp(a)), x, step).unwrap();
        assert!(
            approx_eq(analytic, numeric, 5e-3),
            "d/dx exp at x={x}: analytic grad {analytic} vs central-diff {numeric}"
        );
    }
}
/// The autodiff gradient of `tanh` must agree with a central-difference
/// estimate (step `1e-3`, tolerance `2e-3`) at several sample points.
#[test]
fn finite_diff_matches_grad_tanh() {
    let step = 1e-3_f32;
    let sample_points = [-1.3_f32, -0.4, 0.0, 0.6, 1.5];
    for x in sample_points {
        let analytic = analytic_grad(tanh, x).unwrap();
        let numeric = central_difference(&(|a: &Array| tanh(a)), x, step).unwrap();
        assert!(
            approx_eq(analytic, numeric, 2e-3),
            "d/dx tanh at x={x}: analytic grad {analytic} vs central-diff {numeric}"
        );
    }
}
/// The autodiff gradient of `x^3 + 2x` must agree with a central-difference
/// estimate (step `1e-3`, tolerance `8e-3`) at several sample points.
///
/// The closed form `3x^2 + 2` is included in the failure message for
/// reference only; it is not asserted against.
#[test]
fn finite_diff_matches_grad_cubic_plus_linear() {
    let cubic_plus_linear = |a: &Array| -> mlxrs::Result<Array> {
        let exponent = Array::full::<f32>(&[0i32; 0], 3.0)?;
        let slope = Array::full::<f32>(&[0i32; 0], 2.0)?;
        let cubic_term = power(a, &exponent)?;
        let linear_term = multiply(a, &slope)?;
        add(&cubic_term, &linear_term)
    };

    let step = 1e-3_f32;
    let sample_points = [-1.4_f32, -0.5, 0.3, 1.1, 2.0];
    for x in sample_points {
        let analytic = analytic_grad(cubic_plus_linear, x).unwrap();
        let numeric = central_difference(&cubic_plus_linear, x, step).unwrap();
        assert!(
            approx_eq(analytic, numeric, 8e-3),
            "d/dx (x³+2x) at x={x}: analytic grad {analytic} vs central-diff {numeric} \
             (closed form 3x²+2 = {})",
            3.0 * x * x + 2.0
        );
    }
}
/// Both partial derivatives of `f(x, y) = x^2 * y + y` must agree with
/// central-difference estimates at `(1.5, 2.0)`.
///
/// The closed forms `2xy` and `x^2 + 1` appear in the failure messages for
/// reference only.
#[test]
fn finite_diff_matches_partial_grads_multivariate() {
    let x0 = 1.5_f32;
    let y0 = 2.0_f32;
    let step = 1e-3_f32;

    // Local shorthand for building a scalar `f32` array.
    let scalar = |v: f32| Array::full::<f32>(&[0i32; 0], v).unwrap();

    let forward = |xs: &[Array]| -> mlxrs::Result<Vec<Array>> {
        let x_squared = square(&xs[0])?;
        let x_squared_y = multiply(&x_squared, &xs[1])?;
        Ok(vec![add(&x_squared_y, &xs[1])?])
    };

    // Autodiff partials, one `grad` per argument index.
    let grad_wrt_x = grad(forward, &[0]).unwrap();
    let grad_wrt_y = grad(forward, &[1]).unwrap();
    let mut dx = grad_wrt_x(&[scalar(x0), scalar(y0)]).unwrap();
    let mut dy = grad_wrt_y(&[scalar(x0), scalar(y0)]).unwrap();
    let analytic_dx = dx[0].item::<f32>().unwrap();
    let analytic_dy = dy[0].item::<f32>().unwrap();

    // Plain forward evaluation, used for the numeric estimates.
    let evaluate = |x: f32, y: f32| -> f32 {
        let x_arr = scalar(x);
        let y_arr = scalar(y);
        let mut outputs = forward(&[x_arr, y_arr]).unwrap();
        outputs[0].item::<f32>().unwrap()
    };
    let numeric_dx = (evaluate(x0 + step, y0) - evaluate(x0 - step, y0)) / (2.0 * step);
    let numeric_dy = (evaluate(x0, y0 + step) - evaluate(x0, y0 - step)) / (2.0 * step);

    assert!(
        approx_eq(analytic_dx, numeric_dx, 5e-3),
        "∂f/∂x: analytic {analytic_dx} vs central-diff {numeric_dx} (2xy = {})",
        2.0 * x0 * y0
    );
    assert!(
        approx_eq(analytic_dy, numeric_dy, 5e-3),
        "∂f/∂y: analytic {analytic_dy} vs central-diff {numeric_dy} (x²+1 = {})",
        x0 * x0 + 1.0
    );
}