#![allow(clippy::expect_used, clippy::unwrap_used, clippy::panic)]
use mlx_native::{DType, KernelRegistry, MlxDevice};
/// Builds the per-test fixtures: a Metal device (panics if none is available)
/// paired with a brand-new, empty kernel registry.
fn setup() -> (MlxDevice, KernelRegistry) {
    let device = MlxDevice::new().expect("MlxDevice::new");
    (device, KernelRegistry::new())
}
/// Allocates the two-element `F32` parameter buffer handed to `dispatch_l2_norm`.
///
/// Layout written here: `[eps, dim as f32]`. `dim` is stored as a float because
/// the whole buffer is typed `DType::F32`.
fn alloc_params(device: &MlxDevice, eps: f32, dim: u32) -> mlx_native::MlxBuffer {
    const PARAM_COUNT: usize = 2;
    let values: [f32; PARAM_COUNT] = [eps, dim as f32];
    let mut params = device
        .alloc_buffer(
            PARAM_COUNT * std::mem::size_of::<f32>(),
            DType::F32,
            vec![PARAM_COUNT],
        )
        .expect("alloc params");
    params.as_mut_slice::<f32>().expect("mut params")[..PARAM_COUNT].copy_from_slice(&values);
    params
}
/// Normalizes the single `F32` row `[3, 4]` (L2 norm 5) with `eps = 0` and
/// expects `[0.6, 0.8]` to within `1e-5`.
#[test]
fn test_l2_norm_f32_3_4_5_triangle() {
    let (device, mut registry) = setup();
    let eps = 0.0f32;
    let dim = 2u32;
    let rows = 1u32;
    let input_data = [3.0f32, 4.0f32];
    // Derive byte sizes from the data so shape and allocation cannot drift apart.
    let bytes = std::mem::size_of_val(&input_data);
    let mut input = device
        .alloc_buffer(bytes, DType::F32, vec![dim as usize])
        .expect("alloc input");
    input
        .as_mut_slice::<f32>()
        .expect("mut input")
        .copy_from_slice(&input_data);
    let output = device
        .alloc_buffer(bytes, DType::F32, vec![dim as usize])
        .expect("alloc output");
    let params = alloc_params(&device, eps, dim);
    let mut encoder = device.command_encoder().expect("encoder");
    mlx_native::ops::l2_norm::dispatch_l2_norm(
        &mut encoder,
        &mut registry,
        device.metal_device(),
        &input,
        &output,
        &params,
        rows,
        dim,
    )
    .expect("dispatch");
    encoder.commit_and_wait().expect("commit");
    let got: &[f32] = output.as_slice().expect("read");
    let expected = [0.6f32, 0.8f32];
    assert!(
        got.len() >= expected.len(),
        "output too short: {} < {}",
        got.len(),
        expected.len()
    );
    for (i, (&g, &e)) in got.iter().zip(expected.iter()).enumerate() {
        let diff = (g - e).abs();
        assert!(
            diff < 1e-5,
            "f32 3-4-5 triangle mismatch at {}: got {}, expected {}, diff {}",
            i,
            g,
            e,
            diff
        );
    }
}
/// Normalizes three independent `F32` rows of width 4 in one dispatch and checks
/// each row against its hand-computed unit vector (tolerance `1e-5`).
#[test]
fn test_l2_norm_f32_multirow() {
    let (device, mut registry) = setup();
    let eps = 0.0f32;
    let dim = 4u32;
    let rows = 3u32;
    let n = (rows * dim) as usize;
    // One row per line: already unit length, all ones (norm 2), and a scaled 3-4-5 row (norm 0.5).
    let input_data: [f32; 12] = [
        1.0, 0.0, 0.0, 0.0,
        1.0, 1.0, 1.0, 1.0,
        0.3, 0.4, 0.0, 0.0,
    ];
    let bytes = n * std::mem::size_of::<f32>();
    let mut input = device
        .alloc_buffer(bytes, DType::F32, vec![rows as usize, dim as usize])
        .expect("input");
    input
        .as_mut_slice::<f32>()
        .expect("mut")
        .copy_from_slice(&input_data);
    let output = device
        .alloc_buffer(bytes, DType::F32, vec![rows as usize, dim as usize])
        .expect("output");
    let params = alloc_params(&device, eps, dim);
    let mut encoder = device.command_encoder().expect("enc");
    mlx_native::ops::l2_norm::dispatch_l2_norm(
        &mut encoder,
        &mut registry,
        device.metal_device(),
        &input,
        &output,
        &params,
        rows,
        dim,
    )
    .expect("dispatch");
    encoder.commit_and_wait().expect("commit");
    let got: &[f32] = output.as_slice().expect("read");
    let expected: [f32; 12] = [
        1.0, 0.0, 0.0, 0.0,
        0.5, 0.5, 0.5, 0.5,
        0.6, 0.8, 0.0, 0.0,
    ];
    assert!(got.len() >= n, "output too short: {} < {}", got.len(), n);
    for (i, (&g, &e)) in got.iter().zip(expected.iter()).enumerate() {
        let diff = (g - e).abs();
        assert!(
            diff < 1e-5,
            "multirow mismatch at {}: got {}, expected {}, diff {}",
            i,
            g,
            e,
            diff
        );
    }
}
/// Normalizes 8 pseudo-random `F32` rows of width 64, then multiplies each output
/// row by the row's CPU-computed L2 norm (accumulated in `f64`) and checks that the
/// original input is recovered to within `1e-5`.
#[test]
fn test_l2_norm_f32_round_trip() {
    let (device, mut registry) = setup();
    let eps = 0.0f32;
    let dim = 64u32;
    let rows = 8u32;
    let n = (rows * dim) as usize;
    let width = dim as usize;

    // Deterministic LCG so any failure is reproducible; values land in [-1, 1].
    let mut seed = 0x1234u32;
    let input_data: Vec<f32> = (0..n)
        .map(|_| {
            seed = seed.wrapping_mul(1103515245).wrapping_add(12345);
            (seed as i32 as f32) / (i32::MAX as f32)
        })
        .collect();

    let bytes = n * std::mem::size_of::<f32>();
    let mut input = device
        .alloc_buffer(bytes, DType::F32, vec![rows as usize, width])
        .expect("input");
    input
        .as_mut_slice::<f32>()
        .expect("mut")
        .copy_from_slice(&input_data);
    let output = device
        .alloc_buffer(bytes, DType::F32, vec![rows as usize, width])
        .expect("output");
    let params = alloc_params(&device, eps, dim);
    let mut encoder = device.command_encoder().expect("enc");
    mlx_native::ops::l2_norm::dispatch_l2_norm(
        &mut encoder,
        &mut registry,
        device.metal_device(),
        &input,
        &output,
        &params,
        rows,
        dim,
    )
    .expect("dispatch");
    encoder.commit_and_wait().expect("commit");
    let got: &[f32] = output.as_slice().expect("read");
    assert!(got.len() >= n, "output too short: {} < {}", got.len(), n);

    let row_pairs = input_data
        .chunks_exact(width)
        .zip(got.chunks_exact(width))
        .enumerate();
    for (r, (in_row, out_row)) in row_pairs {
        let row_norm = in_row
            .iter()
            .map(|&v| f64::from(v) * f64::from(v))
            .sum::<f64>()
            .sqrt() as f32;
        for (c, (&original, &normalized)) in in_row.iter().zip(out_row.iter()).enumerate() {
            let reconstructed = normalized * row_norm;
            let diff = (reconstructed - original).abs();
            assert!(
                diff < 1e-5,
                "round-trip mismatch at (r={}, c={}): got {}, expected {}, diff {}",
                r,
                c,
                reconstructed,
                original,
                diff
            );
        }
    }
}
/// An all-zero `F32` row normalized with `eps = 1e-6` must produce finite values
/// (no `0 / 0` NaN) whose magnitudes stay below `1e-3`.
#[test]
fn test_l2_norm_f32_zero_row_with_eps() {
    let (device, mut registry) = setup();
    let eps = 1e-6f32;
    let dim = 4u32;
    let rows = 1u32;
    let input_data = [0.0f32; 4];
    let bytes = std::mem::size_of_val(&input_data);
    let mut input = device
        .alloc_buffer(bytes, DType::F32, vec![dim as usize])
        .expect("input");
    input
        .as_mut_slice::<f32>()
        .expect("mut")
        .copy_from_slice(&input_data);
    let output = device
        .alloc_buffer(bytes, DType::F32, vec![dim as usize])
        .expect("output");
    let params = alloc_params(&device, eps, dim);
    let mut encoder = device.command_encoder().expect("enc");
    mlx_native::ops::l2_norm::dispatch_l2_norm(
        &mut encoder,
        &mut registry,
        device.metal_device(),
        &input,
        &output,
        &params,
        rows,
        dim,
    )
    .expect("dispatch");
    encoder.commit_and_wait().expect("commit");
    let got: &[f32] = output.as_slice().expect("read");
    for (i, v) in got.iter().enumerate().take(input_data.len()) {
        assert!(v.is_finite(), "zero-row produced non-finite at {}: {}", i, v);
        assert!(v.abs() < 1e-3, "zero-row not near zero at {}: {}", i, v);
    }
}
/// Pins where `eps` enters the formula. With `eps = 9`, the row `[0, 4]` is
/// expected to become `[0, 0.8]`, which is `4 / sqrt(4² + 9) = 4 / 5`. Adding
/// eps to the norm instead (`4 / (4 + 9)`) would give about 0.31, so this value
/// only matches when eps is added to the sum of squares.
#[test]
fn test_l2_norm_f32_eps_effect() {
    let (device, mut registry) = setup();
    let eps = 9.0f32;
    let dim = 2u32;
    let rows = 1u32;
    let input_data = [0.0f32, 4.0f32];
    let bytes = std::mem::size_of_val(&input_data);
    let mut input = device
        .alloc_buffer(bytes, DType::F32, vec![dim as usize])
        .expect("input");
    input
        .as_mut_slice::<f32>()
        .expect("mut")
        .copy_from_slice(&input_data);
    let output = device
        .alloc_buffer(bytes, DType::F32, vec![dim as usize])
        .expect("output");
    let params = alloc_params(&device, eps, dim);
    let mut encoder = device.command_encoder().expect("enc");
    mlx_native::ops::l2_norm::dispatch_l2_norm(
        &mut encoder,
        &mut registry,
        device.metal_device(),
        &input,
        &output,
        &params,
        rows,
        dim,
    )
    .expect("dispatch");
    encoder.commit_and_wait().expect("commit");
    let got: &[f32] = output.as_slice().expect("read");
    let expected = [0.0f32, 0.8f32];
    assert!(
        got.len() >= expected.len(),
        "output too short: {} < {}",
        got.len(),
        expected.len()
    );
    for (i, (&g, &e)) in got.iter().zip(expected.iter()).enumerate() {
        let diff = (g - e).abs();
        assert!(
            diff < 1e-5,
            "eps-effect mismatch at {}: got {}, expected {}, diff {}",
            i,
            g,
            e,
            diff
        );
    }
}
/// Same 3-4-5 check as the F32 variant, but with `BF16` input and output buffers.
/// The tolerance is loosened to `1e-2` to account for bf16's short mantissa.
#[test]
fn test_l2_norm_bf16_3_4_5_triangle() {
    use half::bf16;
    let (device, mut registry) = setup();
    let eps = 0.0f32;
    let dim = 2u32;
    let rows = 1u32;
    let input_data = [bf16::from_f32(3.0), bf16::from_f32(4.0)];
    // 2 bytes per bf16 element.
    let bytes = std::mem::size_of_val(&input_data);
    let mut input = device
        .alloc_buffer(bytes, DType::BF16, vec![dim as usize])
        .expect("input");
    input
        .as_mut_slice::<bf16>()
        .expect("mut")
        .copy_from_slice(&input_data);
    let output = device
        .alloc_buffer(bytes, DType::BF16, vec![dim as usize])
        .expect("output");
    // The params buffer stays F32 regardless of the data dtype.
    let params = alloc_params(&device, eps, dim);
    let mut encoder = device.command_encoder().expect("enc");
    mlx_native::ops::l2_norm::dispatch_l2_norm(
        &mut encoder,
        &mut registry,
        device.metal_device(),
        &input,
        &output,
        &params,
        rows,
        dim,
    )
    .expect("dispatch");
    encoder.commit_and_wait().expect("commit");
    let got: &[bf16] = output.as_slice().expect("read");
    let expected = [0.6f32, 0.8f32];
    assert!(
        got.len() >= expected.len(),
        "output too short: {} < {}",
        got.len(),
        expected.len()
    );
    for (i, (g, &e)) in got.iter().map(|v| v.to_f32()).zip(expected.iter()).enumerate() {
        let diff = (g - e).abs();
        assert!(
            diff < 1e-2,
            "bf16 3-4-5 triangle mismatch at {}: got {}, expected {}, diff {}",
            i,
            g,
            e,
            diff
        );
    }
}
/// `dispatch_l2_norm` must return `Err` when asked to process zero rows.
#[test]
fn test_l2_norm_rejects_zero_rows() {
    let (device, mut registry) = setup();
    let dim = 4u32;
    let rows = 0u32;
    let bytes = dim as usize * std::mem::size_of::<f32>();
    let input = device
        .alloc_buffer(bytes, DType::F32, vec![dim as usize])
        .expect("input");
    let output = device
        .alloc_buffer(bytes, DType::F32, vec![dim as usize])
        .expect("output");
    let params = alloc_params(&device, 0.0, dim);
    let mut encoder = device.command_encoder().expect("enc");
    let res = mlx_native::ops::l2_norm::dispatch_l2_norm(
        &mut encoder,
        &mut registry,
        device.metal_device(),
        &input,
        &output,
        &params,
        rows,
        dim,
    );
    assert!(res.is_err(), "zero rows should error");
}
/// `dispatch_l2_norm` must return `Err` when the input (`F32`) and output (`BF16`)
/// buffers have different dtypes. Each buffer is still sized correctly for `dim`
/// elements of its own dtype, so the dtype mismatch is the only defect.
#[test]
fn test_l2_norm_rejects_mismatched_dtype() {
    let (device, mut registry) = setup();
    let dim = 4u32;
    let rows = 1u32;
    // 4 bytes per F32 element.
    let input = device
        .alloc_buffer(dim as usize * 4, DType::F32, vec![dim as usize])
        .expect("input");
    // 2 bytes per BF16 element.
    let output = device
        .alloc_buffer(dim as usize * 2, DType::BF16, vec![dim as usize])
        .expect("output");
    let params = alloc_params(&device, 0.0, dim);
    let mut encoder = device.command_encoder().expect("enc");
    let res = mlx_native::ops::l2_norm::dispatch_l2_norm(
        &mut encoder,
        &mut registry,
        device.metal_device(),
        &input,
        &output,
        &params,
        rows,
        dim,
    );
    assert!(res.is_err(), "dtype mismatch should error");
}