#![allow(clippy::expect_used, clippy::unwrap_used, clippy::panic)]
#![cfg(target_vendor = "apple")]
use mlx_native::ops::gather_bench;
use mlx_native::{DType, KernelRegistry, MlxDevice};
/// Creates the fixtures shared by the tests in this file: a device handle and
/// a kernel registry on which `gather_bench::register` has been called.
///
/// # Panics
///
/// Panics with the message `MlxDevice::new` if the device cannot be created.
fn setup() -> (MlxDevice, KernelRegistry) {
    // The device is created first so a missing GPU fails before any registry work.
    let gpu = MlxDevice::new().expect("MlxDevice::new");

    let mut kernels = KernelRegistry::new();
    gather_bench::register(&mut kernels);

    (gpu, kernels)
}
/// Checks `gather_bench::dispatch_gather_nibble` against a CPU reference.
///
/// Fixture layout (all established by the code below):
/// * centroid table: `n_centroids * head_dim` f32 values, where entry
///   `(k, c)` is `(k + 1) * (c + 1)` so every entry in a column is distinct;
/// * packed indices: `head_dim / 2` bytes per position, each byte holding two
///   4-bit centroid indices — low nibble for the even coordinate, high nibble
///   for the odd coordinate;
/// * expected output: `out[p, c] = centroids[idx(p, c), c]`.
///
/// The GPU result must match the CPU reference element-wise within `1e-5`.
#[test]
fn test_gather_nibble_correctness() {
    let (device, mut registry) = setup();

    let capacity: u32 = 4;
    let head_dim: u32 = 8;
    let n_centroids: usize = 16;

    // usize views of the kernel parameters, used for all host-side indexing.
    let positions = capacity as usize;
    let dim = head_dim as usize;
    let bytes_per_pos = dim / 2;

    // Centroid table: entry (k, c) = (k + 1) * (c + 1).
    let mut centroids_cpu = vec![0.0f32; n_centroids * dim];
    for (k, row) in centroids_cpu.chunks_exact_mut(dim).enumerate() {
        for (c, value) in row.iter_mut().enumerate() {
            *value = ((k + 1) * (c + 1)) as f32;
        }
    }

    // Packed nibbles: every byte of position `p` carries index `p % 16` in the
    // low nibble and `(p + 1) % 16` in the high nibble.
    let packed_len = positions * bytes_per_pos;
    let mut packed_cpu = vec![0u8; packed_len];
    for (p, row) in packed_cpu.chunks_exact_mut(bytes_per_pos).enumerate() {
        let idx_even = (p % 16) as u8;
        let idx_odd = ((p + 1) % 16) as u8;
        row.fill((idx_odd << 4) | idx_even);
    }

    // CPU reference: unpack the nibble for coordinate `c`, then look up the
    // matching centroid component.
    let mut expected = vec![0.0f32; positions * dim];
    for (p, out_row) in expected.chunks_exact_mut(dim).enumerate() {
        for (c, out) in out_row.iter_mut().enumerate() {
            let byte = packed_cpu[p * bytes_per_pos + c / 2];
            let nibble = if c % 2 == 0 { byte & 0x0F } else { byte >> 4 };
            *out = centroids_cpu[usize::from(nibble) * dim + c];
        }
    }

    // Upload the packed indices.
    let mut packed_buf = device
        .alloc_buffer(packed_len, DType::U8, vec![packed_len])
        .expect("alloc packed");
    packed_buf
        .as_mut_slice::<u8>()
        .expect("write packed")
        .copy_from_slice(&packed_cpu);

    // Upload the centroid table (first argument is sized in bytes).
    let centroid_byte_len = centroids_cpu.len() * std::mem::size_of::<f32>();
    let mut centroid_buf = device
        .alloc_buffer(centroid_byte_len, DType::F32, vec![centroids_cpu.len()])
        .expect("alloc centroids");
    {
        let slice: &mut [f32] = centroid_buf.as_mut_slice().expect("write centroids");
        slice.copy_from_slice(&centroids_cpu);
    }

    // Output buffer: one f32 per (position, coordinate).
    let out_len = positions * dim;
    let out_buf = device
        .alloc_buffer(out_len * std::mem::size_of::<f32>(), DType::F32, vec![out_len])
        .expect("alloc out");

    // Encode, submit, and block until the GPU has finished.
    let mut encoder = device.command_encoder().expect("encoder");
    gather_bench::dispatch_gather_nibble(
        &mut encoder,
        &mut registry,
        device.metal_device(),
        &packed_buf,
        &centroid_buf,
        &out_buf,
        capacity,
        head_dim,
    )
    .expect("dispatch_gather_nibble");
    encoder.commit_and_wait().expect("commit_and_wait");

    let result: Vec<f32> = out_buf.as_slice::<f32>().expect("read out").to_vec();

    // Length is checked first so the zip below cannot silently skip elements.
    assert_eq!(result.len(), expected.len());
    for (i, (&got, &exp)) in result.iter().zip(expected.iter()).enumerate() {
        assert!(
            (got - exp).abs() < 1e-5,
            "mismatch at element {i}: got {got}, expected {exp}"
        );
    }

    println!("test_gather_nibble_correctness: PASSED (capacity={capacity}, head_dim={head_dim})");
    // Clamp the preview so shrinking the fixture cannot cause an out-of-range slice.
    let preview = &result[..result.len().min(8)];
    println!(" First 8 output values: {:?}", preview);
}
/// Checks `gather_bench::dispatch_gather_f16_seq` on a cache filled with the
/// half-precision value 1.0.
///
/// The cache buffer holds `capacity * head_dim` f16 elements, each written as
/// the bit pattern `0x3C00` (IEEE 754 binary16 for 1.0) in little-endian byte
/// order. After the kernel runs, the f32 output buffer must contain exactly
/// `capacity * head_dim` elements, each within `1e-4` of 1.0.
///
/// NOTE(review): because every input element is identical, this test cannot
/// detect an indexing error in the kernel — it only verifies the f16 -> f32
/// value path and the output length.
#[test]
fn test_gather_f16_seq_correctness() {
    let (device, mut registry) = setup();

    let capacity: u32 = 4;
    let head_dim: u32 = 8;
    let n = capacity as usize * head_dim as usize;

    // f16 1.0 as raw little-endian bytes: [0x00, 0x3C].
    let one_f16 = 0x3C00u16.to_le_bytes();

    // Cache buffer: two bytes per f16 element.
    let cache_byte_len = n * 2;
    let mut cache_buf = device
        .alloc_buffer(cache_byte_len, DType::F16, vec![n])
        .expect("alloc cache");
    {
        let raw: &mut [u8] = cache_buf.as_mut_slice().expect("write cache");
        for chunk in raw.chunks_exact_mut(2) {
            chunk.copy_from_slice(&one_f16);
        }
    }

    // Output buffer: one f32 per cache element.
    let out_buf = device
        .alloc_buffer(n * std::mem::size_of::<f32>(), DType::F32, vec![n])
        .expect("alloc out");

    // Encode, submit, and block until the GPU has finished.
    let mut encoder = device.command_encoder().expect("encoder");
    gather_bench::dispatch_gather_f16_seq(
        &mut encoder,
        &mut registry,
        device.metal_device(),
        &cache_buf,
        &out_buf,
        capacity,
        head_dim,
    )
    .expect("dispatch_gather_f16_seq");
    encoder.commit_and_wait().expect("commit_and_wait");

    let result: Vec<f32> = out_buf.as_slice::<f32>().expect("read out").to_vec();

    // Without this check a short readback would be only partially validated
    // by the loop below; the nibble test performs the same length check.
    assert_eq!(result.len(), n);
    for (i, &v) in result.iter().enumerate() {
        assert!(
            (v - 1.0f32).abs() < 1e-4,
            "element {i}: expected 1.0, got {v}"
        );
    }

    println!("test_gather_f16_seq_correctness: PASSED (capacity={capacity}, head_dim={head_dim})");
    println!(" All {n} elements = {:.4}", result[0]);
}