#![cfg(target_os = "macos")]
use mlxrs::{
Array, Dtype,
ops::fast::metal_kernel::{KernelTemplateArg, MetalKernel, MetalKernelApplyConfig},
};
/// Runs a single-output elementwise kernel computing `exp(input)` over an
/// all-ones `[8]` f32 input and checks every element is approximately `e`.
#[test]
#[ignore = "requires a Metal-capable GPU"]
fn exp_kernel_writes_e_to_every_element() {
    let input = Array::ones::<f32>(&[8]).unwrap();
    let kernel = MetalKernel::new(
        "exp_kernel",
        &["input"],
        &["out"],
        "uint elem = thread_position_in_grid.x;\nout[elem] = exp(input[elem]);",
        "",
        true,
        false,
    )
    .unwrap();
    assert_eq!(kernel.output_arity(), 1);
    assert_eq!(kernel.output_names_slice(), &["out".to_string()]);

    // The grid is exactly as wide as the 8-element output; the kernel body
    // performs no bounds check, so the two must stay in sync.
    let cfg = MetalKernelApplyConfig::new(
        [8, 1, 1],
        [8, 1, 1],
        vec![vec![8]],
        vec![Dtype::F32],
    )
    .unwrap();
    let mut outs = kernel.apply(&[&input], &cfg).unwrap();
    assert_eq!(outs.len(), 1);
    assert_eq!(outs[0].shape(), vec![8]);
    assert_eq!(outs[0].dtype().unwrap(), Dtype::F32);

    let buf: Vec<f32> = outs[0].to_vec().unwrap();
    // Without this, the per-element loop below would pass vacuously on a
    // short (or empty) buffer.
    assert_eq!(buf.len(), 8);
    let e = std::f32::consts::E;
    for (i, v) in buf.iter().enumerate() {
        assert!((v - e).abs() < 1e-5, "out[{i}] = {v}, expected ≈ {e}");
    }
}
/// Checks that a template argument is substituted into the kernel source:
/// the kernel computes `out = float(ALPHA) * x + y` with `ALPHA` set to 2.
#[test]
#[ignore = "requires a Metal-capable GPU"]
fn saxpy_kernel_uses_template_alpha() {
    let x = Array::from_slice::<f32>(&[1.0, 2.0, 3.0, 4.0], &[4]).unwrap();
    let y = Array::from_slice::<f32>(&[10.0, 20.0, 30.0, 40.0], &[4]).unwrap();

    let source =
        "uint elem = thread_position_in_grid.x;\nout[elem] = float(ALPHA) * x[elem] + y[elem];";
    let kernel =
        MetalKernel::new("saxpy_kernel", &["x", "y"], &["out"], source, "", true, false).unwrap();

    // ALPHA is supplied at apply time as a template argument, not as an input buffer.
    let alpha = ("ALPHA".to_string(), KernelTemplateArg::Int(2));
    let cfg = MetalKernelApplyConfig::new([4, 1, 1], [4, 1, 1], vec![vec![4]], vec![Dtype::F32])
        .unwrap()
        .with_template(vec![alpha]);

    let mut outs = kernel.apply(&[&x, &y], &cfg).unwrap();
    assert_eq!(outs.len(), 1);

    let buf: Vec<f32> = outs[0].to_vec().unwrap();
    assert_eq!(buf, [12.0_f32, 24.0, 36.0, 48.0]);
}
/// Runs a kernel that declares two outputs (`sum`, `diff`). It checks that
/// `apply` returns one array per output, in declaration order, each holding
/// `input ± 1`.
#[test]
#[ignore = "requires a Metal-capable GPU"]
fn multi_output_kernel_emits_two_arrays() {
    let input = Array::full::<f32>(&[4], 5.0).unwrap();

    let source = "uint elem = thread_position_in_grid.x;\nsum[elem] = input[elem] + 1.0;\ndiff[elem] = input[elem] - 1.0;";
    let kernel = MetalKernel::new(
        "split_kernel",
        &["input"],
        &["sum", "diff"],
        source,
        "",
        true,
        false,
    )
    .unwrap();
    assert_eq!(kernel.output_arity(), 2);

    // Both outputs share the same 4-element f32 layout.
    let cfg = MetalKernelApplyConfig::new(
        [4, 1, 1],
        [4, 1, 1],
        vec![vec![4]; 2],
        vec![Dtype::F32, Dtype::F32],
    )
    .unwrap();

    let mut outs = kernel.apply(&[&input], &cfg).unwrap();
    assert_eq!(outs.len(), 2);
    assert_eq!(outs[0].shape(), vec![4]);
    assert_eq!(outs[1].shape(), vec![4]);

    let sum: Vec<f32> = outs[0].to_vec().unwrap();
    let diff: Vec<f32> = outs[1].to_vec().unwrap();
    assert_eq!(sum, [6.0_f32; 4]);
    assert_eq!(diff, [4.0_f32; 4]);
}
/// Checks that `apply` accepts a rank-3 output shape. A 1-D grid of
/// `4 * 8 * 16` threads, in threadgroups of 32, doubles an all-ones
/// `[4, 8, 16]` input. The result must keep that shape and the f32 dtype.
#[test]
#[ignore = "requires a Metal-capable GPU"]
fn apply_accepts_valid_multi_dim_output_shape() {
    let input = Array::ones::<f32>(&[4, 8, 16]).unwrap();
    let kernel = MetalKernel::new(
        "constant_3d_kernel",
        &["input"],
        &["out"],
        "uint elem = thread_position_in_grid.x;\nout[elem] = input[elem] * 2.0;",
        "",
        true,
        false,
    )
    .unwrap();

    // The kernel indexes both buffers linearly by `thread_position_in_grid.x`,
    // so the grid spans every element along x only.
    let cfg = MetalKernelApplyConfig::new(
        [4 * 8 * 16, 1, 1],
        [32, 1, 1],
        vec![vec![4, 8, 16]],
        vec![Dtype::F32],
    )
    .unwrap();

    let mut outs = kernel.apply(&[&input], &cfg).unwrap();
    assert_eq!(outs.len(), 1);
    assert_eq!(outs[0].shape(), vec![4, 8, 16]);
    assert_eq!(outs[0].dtype().unwrap(), Dtype::F32);

    let buf: Vec<f32> = outs[0].to_vec().unwrap();
    assert_eq!(buf.len(), 4 * 8 * 16);
    buf.iter().enumerate().for_each(|(i, v)| {
        assert!((v - 2.0_f32).abs() < 1e-5, "out[{i}] = {v}, expected 2.0");
    });
}
/// Checks that `apply` rejects a config supplying two output shapes for a
/// kernel that declares one output. The error must be `LengthMismatch` with
/// the declared (1) and supplied (2) counts.
#[test]
#[ignore = "requires a Metal-capable GPU"]
fn apply_rejects_shape_count_mismatch() {
    let kernel = MetalKernel::new(
        "noop",
        &["x"],
        &["out"],
        "uint elem = thread_position_in_grid.x; out[elem] = x[elem];",
        "",
        true,
        false,
    )
    .unwrap();
    let input = Array::ones::<f32>(&[4]).unwrap();

    // Deliberately two shapes/dtypes against a single declared output name.
    let cfg = MetalKernelApplyConfig::new(
        [4, 1, 1],
        [4, 1, 1],
        vec![vec![4]; 2],
        vec![Dtype::F32, Dtype::F32],
    )
    .unwrap();

    let err = kernel
        .apply(&[&input], &cfg)
        .expect_err("declared 1 output_name but supplied 2 output_shapes");
    let payload = match err {
        mlxrs::Error::LengthMismatch(payload) => payload,
        other => panic!("expected LengthMismatch, got: {other:?}"),
    };
    assert_eq!(payload.expected(), 1, "expected count: {:?}", payload);
    assert_eq!(payload.actual(), 2, "actual count: {:?}", payload);
}