#![cfg(target_os = "macos")]
mod common;
use std::collections::BTreeMap;
use common::{Dt, gpu_lock, pack_bytes, unpack_bytes};
use metaltile_core::{dtype::DType, ir::KernelMode};
use metaltile_runtime::Context;
use metaltile_std::ffai::kv_cache::kv_cache_update;
/// Packs `vals` into a raw byte buffer tagged as `Dt::F32`, delegating to
/// `common::pack_bytes`.
fn f32_slice_to_bytes(vals: &[f32]) -> Vec<u8> {
    let dtype = Dt::F32;
    pack_bytes(vals, dtype)
}

/// Inverse of `f32_slice_to_bytes`: decodes a raw byte buffer as `Dt::F32`
/// values, delegating to `common::unpack_bytes`.
fn bytes_to_f32_vec(bytes: &[u8]) -> Vec<f32> {
    let dtype = Dt::F32;
    unpack_bytes(bytes, dtype)
}
/// Dispatches `kv_cache_update` (f32) and checks that exactly one sequence slot
/// of the cache is overwritten.
///
/// The test lays out `src` as `[n_kv_heads][head_dim]` and expects the cache to
/// be `[n_kv_heads][max_seq][head_dim]` (see the index math below). The cache is
/// pre-filled with a sentinel so that any stray write outside `position` is
/// detected.
#[test]
fn kv_cache_update_writes_to_correct_slot_f32() {
    // Presumably a guard serializing GPU access across tests. It is bound to
    // `_g` rather than `_` so it is not dropped immediately and lives until the
    // end of the test.
    let _g = gpu_lock();

    let n_kv_heads = 4usize;
    let head_dim = 16usize;
    let max_seq = 8usize;
    let position = 3usize;
    let sentinel = 999.0_f32;

    let cache = vec![sentinel; n_kv_heads * max_seq * head_dim];
    // Distinct, exactly-representable values so a misplaced copy is visible.
    let src: Vec<f32> = (0..n_kv_heads * head_dim).map(|i| 10.0 + i as f32).collect();

    // Scalar kernel params are passed as little-endian u32 buffers. Checked
    // conversion instead of `as` so an oversized value fails loudly rather
    // than silently truncating.
    let u32_le = |v: usize| {
        u32::try_from(v)
            .expect("scalar kernel parameter must fit in u32")
            .to_le_bytes()
            .to_vec()
    };

    let mut buffers: BTreeMap<String, Vec<u8>> = BTreeMap::new();
    buffers.insert("src".into(), f32_slice_to_bytes(&src));
    buffers.insert("out".into(), f32_slice_to_bytes(&cache));
    buffers.insert("head_dim".into(), u32_le(head_dim));
    buffers.insert("max_seq".into(), u32_le(max_seq));
    buffers.insert("position".into(), u32_le(position));

    let ctx = Context::new().expect("Context::new should succeed on macOS");
    let mut kernel = kv_cache_update::kernel_ir_for(DType::F32);
    kernel.mode = KernelMode::Grid3D;

    // Presumably one thread per copied element (every head, every dim).
    let total_threads = n_kv_heads * head_dim;
    let result = ctx
        .dispatch_with_grid(&kernel, &buffers, &BTreeMap::new(), [1, 1, 1], [total_threads, 1, 1])
        .expect("dispatch_with_grid should succeed");

    let out_bytes = result.outputs.get("out").expect("`out` buffer in dispatch result");
    let actual = bytes_to_f32_vec(out_bytes);

    // Validate the size before indexing. A short buffer should fail with a
    // clear message instead of an out-of-bounds panic. A long one must not
    // leave trailing elements unchecked.
    assert_eq!(
        actual.len(),
        cache.len(),
        "`out` decoded to {} f32 elements, expected {}",
        actual.len(),
        cache.len(),
    );

    // Single sweep over every cell: the `position` slot must hold the source
    // row for its head, and every other slot must still hold the sentinel.
    for h in 0..n_kv_heads {
        for p in 0..max_seq {
            for d in 0..head_dim {
                let cache_idx = (h * max_seq + p) * head_dim + d;
                let got = actual[cache_idx];
                if p == position {
                    let expected_val = src[h * head_dim + d];
                    assert!(
                        (got - expected_val).abs() < 1e-6,
                        "cache[h={h}, pos={position}, d={d}] = {} (expected {})",
                        got,
                        expected_val,
                    );
                } else {
                    assert!(
                        (got - sentinel).abs() < 1e-6,
                        "cache[h={h}, pos={p}, d={d}] = {} (should be sentinel {})",
                        got,
                        sentinel,
                    );
                }
            }
        }
    }
}