#![allow(clippy::expect_used, clippy::unwrap_used, clippy::panic)]
use mlx_native::ops::hadamard;
use mlx_native::{DType, KernelRegistry, MlxDevice};
fn cpu_fwht_unnormalized(x: &mut [f32]) {
let n = x.len();
assert!(n.is_power_of_two(), "FWHT length must be a power of two");
let mut h = 1usize;
while h < n {
let mut i = 0;
while i < n {
for j in i..i + h {
let a = x[j];
let b = x[j + h];
x[j] = a + b;
x[j + h] = a - b;
}
i += h * 2;
}
h *= 2;
}
}
fn cpu_fwht(x: &mut [f32]) {
let n = x.len();
cpu_fwht_unnormalized(x);
let scale = 1.0_f32 / (n as f32).sqrt();
for v in x.iter_mut() {
*v *= scale;
}
}
fn setup() -> (MlxDevice, KernelRegistry) {
let device = MlxDevice::new().expect("MlxDevice::new");
let mut registry = KernelRegistry::new();
hadamard::register(&mut registry);
(device, registry)
}
fn run_gpu_fwht(
device: &MlxDevice,
registry: &mut KernelRegistry,
data: &mut Vec<f32>,
head_dim: u32,
num_heads: u32,
) {
let n = data.len();
let byte_len = n * 4;
let mut buf = device
.alloc_buffer(byte_len, DType::F32, vec![num_heads as usize, head_dim as usize])
.expect("alloc_buffer");
{
let slice: &mut [f32] = buf.as_mut_slice().expect("as_mut_slice");
slice.copy_from_slice(data);
}
{
let mut enc = device.command_encoder().expect("command_encoder");
hadamard::dispatch_hadamard_transform(
&mut enc,
registry,
device.metal_device(),
&buf,
head_dim,
num_heads,
)
.expect("dispatch_hadamard_transform");
enc.commit_and_wait().expect("commit_and_wait");
}
{
let slice: &[f32] = buf.as_slice().expect("as_slice");
data.copy_from_slice(slice);
}
}
fn l2_norm(v: &[f32]) -> f32 {
v.iter().map(|x| x * x).sum::<f32>().sqrt()
}
#[test]
fn test_hadamard_known_values_d4() {
let (device, mut registry) = setup();
let mut data = vec![1.0f32, 1.0, 1.0, 1.0];
run_gpu_fwht(&device, &mut registry, &mut data, 4, 1);
assert!(
(data[0] - 2.0).abs() < 1e-4,
"data[0] expected 2.0, got {}",
data[0]
);
for i in 1..4 {
assert!(
data[i].abs() < 1e-4,
"data[{}] expected 0.0, got {}",
i,
data[i]
);
}
}
fn involution_test(head_dim: u32, num_heads: u32) {
let (device, mut registry) = setup();
let n = (num_heads as usize) * (head_dim as usize);
let original: Vec<f32> = (0..n)
.map(|i| ((i as f32 * 1.618_034 + 0.5).sin()) * 2.0)
.collect();
let mut data = original.clone();
run_gpu_fwht(&device, &mut registry, &mut data, head_dim, num_heads);
run_gpu_fwht(&device, &mut registry, &mut data, head_dim, num_heads);
for (i, (&got, &exp)) in data.iter().zip(original.iter()).enumerate() {
assert!(
(got - exp).abs() < 1e-4,
"Involution failed at dim={head_dim} heads={num_heads} idx={i}: \
got {got:.6}, expected {exp:.6}"
);
}
}
#[test]
fn test_hadamard_involution_d128() {
involution_test(128, 4);
}
#[test]
fn test_hadamard_involution_d256() {
involution_test(256, 8);
}
#[test]
fn test_hadamard_involution_d512() {
involution_test(512, 2);
}
fn energy_test(head_dim: u32, num_heads: u32) {
let (device, mut registry) = setup();
let n = (num_heads as usize) * (head_dim as usize);
let original: Vec<f32> = (0..n)
.map(|i| ((i as f32 * 2.718_28 + 1.0).cos()) * 3.0)
.collect();
let original_norm = l2_norm(&original);
let mut data = original.clone();
run_gpu_fwht(&device, &mut registry, &mut data, head_dim, num_heads);
let transformed_norm = l2_norm(&data);
let rel_err = (transformed_norm - original_norm).abs() / original_norm.max(1e-8);
assert!(
rel_err < 1e-4,
"Energy not preserved at dim={head_dim} heads={num_heads}: \
original_norm={original_norm:.6} transformed_norm={transformed_norm:.6} \
rel_err={rel_err:.6}"
);
}
#[test]
fn test_hadamard_energy_d128() {
energy_test(128, 4);
}
#[test]
fn test_hadamard_energy_d256() {
energy_test(256, 8);
}
#[test]
fn test_hadamard_energy_d512() {
energy_test(512, 2);
}
fn gpu_vs_cpu_test(head_dim: u32, num_heads: u32) {
let (device, mut registry) = setup();
let n = (num_heads as usize) * (head_dim as usize);
let input: Vec<f32> = (0..n)
.map(|i| i as f32 * 0.123 - (n as f32) * 0.0615)
.collect();
let mut cpu_out = input.clone();
for h in 0..num_heads as usize {
let base = h * head_dim as usize;
cpu_fwht(&mut cpu_out[base..base + head_dim as usize]);
}
let mut gpu_out = input.clone();
run_gpu_fwht(&device, &mut registry, &mut gpu_out, head_dim, num_heads);
for (i, (&g, &c)) in gpu_out.iter().zip(cpu_out.iter()).enumerate() {
assert!(
(g - c).abs() < 1e-4,
"GPU vs CPU mismatch at dim={head_dim} heads={num_heads} idx={i}: \
gpu={g:.6} cpu={c:.6}"
);
}
}
#[test]
fn test_hadamard_gpu_vs_cpu_d128() {
gpu_vs_cpu_test(128, 4);
}
#[test]
fn test_hadamard_gpu_vs_cpu_d256() {
gpu_vs_cpu_test(256, 8);
}
#[test]
fn test_hadamard_gpu_vs_cpu_d512() {
gpu_vs_cpu_test(512, 2);
}
#[test]
fn test_hadamard_non_power_of_two_fails() {
let (device, mut registry) = setup();
let byte_len = 6 * 4;
let buf = device
.alloc_buffer(byte_len, DType::F32, vec![1, 6])
.expect("alloc_buffer");
let mut enc = device.command_encoder().expect("command_encoder");
let result = hadamard::dispatch_hadamard_transform(
&mut enc,
&mut registry,
device.metal_device(),
&buf,
6, 1,
);
assert!(result.is_err(), "Non-power-of-two head_dim should fail");
}
#[test]
fn test_hadamard_zero_heads_is_noop() {
let (device, mut registry) = setup();
let byte_len = 256 * 4;
let buf = device
.alloc_buffer(byte_len, DType::F32, vec![256])
.expect("alloc_buffer");
let mut enc = device.command_encoder().expect("command_encoder");
let result = hadamard::dispatch_hadamard_transform(
&mut enc,
&mut registry,
device.metal_device(),
&buf,
256,
0, );
assert!(result.is_ok(), "Zero num_heads should be a no-op, not an error");
enc.commit_and_wait().expect("commit_and_wait");
}