#![allow(clippy::expect_used, clippy::unwrap_used, clippy::panic)]
#![cfg(target_vendor = "apple")]
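//! Integration tests for the `kv_cache_copy` Metal kernel: linear append,
//! sliding-window wrap-around, the empty (`n_new == 0`) no-op case, and
//! overflow rejection.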
use mlx_native::ops::kv_cache_copy;
use mlx_native::{DType, KernelRegistry, MlxDevice};
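/// Creates a Metal device and a kernel registry with the `kv_cache_copy`
/// kernels registered.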
fn setup() -> (MlxDevice, KernelRegistry) {
let device = MlxDevice::new().expect("MlxDevice::new");
let mut registry = KernelRegistry::new();
kv_cache_copy::register(&mut registry);
(device, registry)
}
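/// Converts an `f32` to little-endian bf16 bytes by truncating the mantissa
/// with round-to-nearest-even (the `+ 0x7FFF` bias plus the carry-in of the
/// kept low bit). NaN is not special-cased; all test values here are finite.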
fn f32_to_bf16_bytes(val: f32) -> [u8; 2] {
let bits = val.to_bits();
let bf16_bits = ((bits + 0x7FFF + ((bits >> 16) & 1)) >> 16) as u16;
bf16_bits.to_le_bytes()
}
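/// Decodes little-endian bf16 bytes back to `f32` by placing the 16-bit
/// pattern in the high half of the `f32` bit pattern.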
fn bf16_bytes_to_f32(bytes: [u8; 2]) -> f32 {
let bf16_bits = u16::from_le_bytes(bytes);
f32::from_bits((bf16_bits as u32) << 16)
}
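/// Encodes `values` as packed bf16 into `buf`, two bytes per element.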
fn write_bf16_values(buf: &mut [u8], values: &[f32]) {
for (i, &v) in values.iter().enumerate() {
let bytes = f32_to_bf16_bytes(v);
buf[i * 2] = bytes[0];
buf[i * 2 + 1] = bytes[1];
}
}
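// Added sanity check: values whose mantissas fit in bf16's 7 explicit bits
// must round-trip bit-exactly through the two helpers above.
#[test]
fn test_bf16_roundtrip_exact() {
    for &v in &[0.0f32, 1.0, -1.0, 0.5, -2.5, 256.0] {
        let bytes = f32_to_bf16_bytes(v);
        let back = bf16_bytes_to_f32(bytes);
        assert_eq!(v.to_bits(), back.to_bits(), "bf16 round-trip changed {v}");
    }
}
// Writes n_new=4 rows of 128 elements at write_pos=0 (no wrap) and checks
// that the copied region matches the source while the rest of the cache
// stays zeroed.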
#[test]
fn test_kv_cache_copy_linear() {
let (device, mut registry) = setup();
let n_new: u32 = 4;
    let row_size: u32 = 128;
    let cache_cap: u32 = 64;
let write_pos: u32 = 0;
let total_src = n_new as usize * row_size as usize;
let total_cache = cache_cap as usize * row_size as usize;
let src_f32: Vec<f32> = (0..total_src).map(|i| (i as f32) * 0.01).collect();
let mut src_bytes = vec![0u8; total_src * 2];
write_bf16_values(&mut src_bytes, &src_f32);
let mut src_buf = device
.alloc_buffer(total_src * 2, DType::BF16, vec![total_src])
.expect("alloc src");
src_buf
.as_mut_slice::<u8>()
.expect("write src")
.copy_from_slice(&src_bytes);
let mut cache_buf = device
.alloc_buffer(total_cache * 2, DType::BF16, vec![total_cache])
.expect("alloc cache");
    cache_buf
        .as_mut_slice::<u8>()
        .expect("write cache")
        .fill(0);
let mut encoder = device.command_encoder().expect("encoder");
kv_cache_copy::dispatch_kv_cache_copy(
&mut encoder,
&mut registry,
device.metal_device(),
&src_buf,
&cache_buf,
write_pos,
row_size,
n_new,
cache_cap,
        false,
    )
    .expect("dispatch_kv_cache_copy");
encoder.commit_and_wait().expect("commit_and_wait");
let cache_bytes: Vec<u8> = cache_buf.as_slice::<u8>().expect("read cache").to_vec();
for i in 0..total_src {
let src_val = bf16_bytes_to_f32([src_bytes[i * 2], src_bytes[i * 2 + 1]]);
let cache_val = bf16_bytes_to_f32([cache_bytes[i * 2], cache_bytes[i * 2 + 1]]);
        assert!(
            (src_val - cache_val).abs() < 1e-10,
            "linear copy: mismatch at element {}: src_bf16={}, cache_bf16={} (original f32={})",
            i, src_val, cache_val, src_f32[i],
        );
}
for i in total_src..total_cache {
let val = bf16_bytes_to_f32([cache_bytes[i * 2], cache_bytes[i * 2 + 1]]);
assert!(
val == 0.0,
"linear copy: cache element {} should be 0.0, got {}",
i, val
);
}
}
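// With sliding-window mode enabled, rows wrap modulo cache_cap: write_pos=6
// with cache_cap=4 places tokens 0, 1, 2 at cache rows 2, 3, and 0.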
#[test]
fn test_kv_cache_copy_sliding_wrap() {
let (device, mut registry) = setup();
let n_new: u32 = 3;
let row_size: u32 = 64;
    let cache_cap: u32 = 4;
    let write_pos: u32 = 6;
let total_src = n_new as usize * row_size as usize;
let total_cache = cache_cap as usize * row_size as usize;
let src_f32: Vec<f32> = (0..total_src).map(|i| ((i + 1) as f32) * 0.1).collect();
let mut src_bytes = vec![0u8; total_src * 2];
write_bf16_values(&mut src_bytes, &src_f32);
let cache_init: Vec<f32> = (0..total_cache).map(|i| -(i as f32) * 0.01).collect();
let mut cache_init_bytes = vec![0u8; total_cache * 2];
write_bf16_values(&mut cache_init_bytes, &cache_init);
let mut src_buf = device
.alloc_buffer(total_src * 2, DType::BF16, vec![total_src])
.expect("alloc src");
src_buf
.as_mut_slice::<u8>()
.expect("write src")
.copy_from_slice(&src_bytes);
let mut cache_buf = device
.alloc_buffer(total_cache * 2, DType::BF16, vec![total_cache])
.expect("alloc cache");
cache_buf
.as_mut_slice::<u8>()
.expect("write cache")
.copy_from_slice(&cache_init_bytes);
let mut encoder = device.command_encoder().expect("encoder");
kv_cache_copy::dispatch_kv_cache_copy(
&mut encoder,
&mut registry,
device.metal_device(),
&src_buf,
&cache_buf,
write_pos,
row_size,
n_new,
cache_cap,
        true,
    )
    .expect("dispatch_kv_cache_copy sliding");
encoder.commit_and_wait().expect("commit_and_wait");
let cache_bytes: Vec<u8> = cache_buf.as_slice::<u8>().expect("read cache").to_vec();
for t in 0..n_new as usize {
let cache_row = ((write_pos as usize) + t) % (cache_cap as usize);
for j in 0..row_size as usize {
let src_elem = t * row_size as usize + j;
let cache_elem = cache_row * row_size as usize + j;
let src_val = bf16_bytes_to_f32([src_bytes[src_elem * 2], src_bytes[src_elem * 2 + 1]]);
let cache_val = bf16_bytes_to_f32([cache_bytes[cache_elem * 2], cache_bytes[cache_elem * 2 + 1]]);
assert!(
(src_val - cache_val).abs() < 1e-10,
"sliding wrap: mismatch at token {} elem {} (cache_row={}): src={}, cache={}",
t, j, cache_row, src_val, cache_val,
);
}
}
}
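// A dispatch with n_new=0 must be accepted as a no-op; the encoder is never
// committed, only the dispatch result is checked.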
#[test]
fn test_kv_cache_copy_empty() {
let (device, mut registry) = setup();
let src_buf = device
.alloc_buffer(2, DType::BF16, vec![1])
.expect("alloc src");
let cache_buf = device
.alloc_buffer(2, DType::BF16, vec![1])
.expect("alloc cache");
let mut encoder = device.command_encoder().expect("encoder");
let result = kv_cache_copy::dispatch_kv_cache_copy(
&mut encoder,
&mut registry,
device.metal_device(),
&src_buf,
&cache_buf,
        0,     // write_pos
        64,    // row_size
        0,     // n_new: nothing to copy
        64,    // cache_cap
        false, // sliding
);
assert!(result.is_ok(), "n_new=0 should succeed (no-op)");
}
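// Without sliding mode, write_pos + n_new (3 + 2) exceeding cache_cap (4)
// must be rejected at dispatch time.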
#[test]
fn test_kv_cache_copy_global_overflow_error() {
let (device, mut registry) = setup();
let src_buf = device
.alloc_buffer(256, DType::BF16, vec![128])
.expect("alloc src");
let cache_buf = device
.alloc_buffer(256, DType::BF16, vec![128])
.expect("alloc cache");
let mut encoder = device.command_encoder().expect("encoder");
let result = kv_cache_copy::dispatch_kv_cache_copy(
&mut encoder,
&mut registry,
device.metal_device(),
&src_buf,
&cache_buf,
        3,     // write_pos
        1,     // row_size
        2,     // n_new: write_pos + n_new = 5 exceeds cache_cap = 4
        4,     // cache_cap
        false, // sliding disabled
    );
assert!(result.is_err(), "Should error on global cache overflow");
}