#![cfg(target_os = "macos")]
mod common;
use std::collections::BTreeMap;
use common::{
Dt,
max_abs_diff,
naive_aura_encode_f32,
pack_bytes,
pack_u32_bytes,
ramp,
srht_rotation,
unpack_bytes,
unpack_u32_bytes,
};
use metaltile_core::{dtype::DType, ir::KernelMode};
use metaltile_runtime::Context;
use metaltile_std::ffai::aura_encode::aura_encode_int4;
/// Serialises a slice of `f32` values into a raw byte buffer.
///
/// Thin convenience wrapper: delegates to the shared `pack_bytes` helper with
/// the `Dt::F32` tag, so the byte layout is whatever `pack_bytes` produces for
/// that tag.
fn f32_slice_to_bytes(vals: &[f32]) -> Vec<u8> {
    let dtype = Dt::F32;
    pack_bytes(vals, dtype)
}
/// Decodes a raw byte buffer back into a vector of `f32` values.
///
/// Counterpart of `f32_slice_to_bytes`: delegates to the shared `unpack_bytes`
/// helper with the `Dt::F32` tag.
fn bytes_to_f32_vec(bytes: &[u8]) -> Vec<f32> {
    let dtype = Dt::F32;
    unpack_bytes(bytes, dtype)
}
/// Returns a `dim x dim` identity matrix flattened into a single
/// `Vec<f32>` of length `dim * dim`.
///
/// Element `(d, d)` lives at flat index `d * dim + d`; every other entry is
/// zero. `dim == 0` yields an empty vector.
fn identity_rotation(dim: usize) -> Vec<f32> {
    let mut matrix = vec![0.0_f32; dim * dim];
    // Diagonal entries sit at flat indices 0, dim + 1, 2 * (dim + 1), ...,
    // i.e. d * dim + d for each d in 0..dim, so a stride of `dim + 1` visits
    // exactly the diagonal.
    matrix
        .iter_mut()
        .step_by(dim + 1)
        .for_each(|diagonal| *diagonal = 1.0);
    matrix
}
/// Builds a uniform 16-level codebook together with its decision boundaries.
///
/// Returns `(codebook, boundaries)` where:
/// * `codebook` holds 16 evenly spaced values running from `-1.0` to `1.0`;
/// * `boundaries` holds the 15 midpoints between adjacent codebook entries.
fn int4_uniform_codebook() -> (Vec<f32>, Vec<f32>) {
    const LEVELS: usize = 16;

    let codebook: Vec<f32> = (0..LEVELS)
        .map(|level| -1.0 + 2.0 * (level as f32) / (LEVELS as f32 - 1.0))
        .collect();

    // Each boundary is the midpoint of one adjacent pair of codebook values.
    let boundaries: Vec<f32> = codebook
        .windows(2)
        .map(|pair| 0.5 * (pair[0] + pair[1]))
        .collect();

    (codebook, boundaries)
}
/// Checks the F32 `aura_encode_int4` kernel against the CPU reference
/// `naive_aura_encode_f32` for two rows of width 128 under an identity
/// rotation.
///
/// The packed quantisation words must match the reference exactly; the
/// per-row norms must agree to within `1e-4`.
#[test]
fn aura_encode_int4_matches_naive_cpu_reference_f32() {
    let (dim, bits, rows) = (128usize, 4usize, 2usize);
    // Packed output is stored as u32 words: ceil(dim * bits / 32) words per row.
    let packed_width = (dim * bits).div_ceil(32);

    // Test fixtures: uniform codebook, identity rotation, ramp input.
    let (codebook, boundaries) = int4_uniform_codebook();
    let rotation = identity_rotation(dim);
    let input = ramp(rows * dim, 23, 9.0);

    // CPU reference the kernel output is compared against.
    let (expected_packed, expected_norms) =
        naive_aura_encode_f32(&input, &rotation, &boundaries, &codebook, rows, dim, bits);

    // Named byte buffers handed to the dispatch; output buffers start zeroed
    // and the two scalars are passed as little-endian u32.
    let buffers: BTreeMap<String, Vec<u8>> = BTreeMap::from([
        (String::from("input"), f32_slice_to_bytes(&input)),
        (String::from("rotation"), f32_slice_to_bytes(&rotation)),
        (String::from("boundaries"), f32_slice_to_bytes(&boundaries)),
        (String::from("codebook"), f32_slice_to_bytes(&codebook)),
        (String::from("packed_out"), pack_u32_bytes(&vec![0u32; rows * packed_width])),
        (String::from("norms_out"), f32_slice_to_bytes(&vec![0.0_f32; rows])),
        (String::from("dim"), (dim as u32).to_le_bytes().to_vec()),
        (String::from("packed_width"), (packed_width as u32).to_le_bytes().to_vec()),
    ]);

    let ctx = Context::new().expect("Context::new should succeed on macOS");
    let mut kernel = aura_encode_int4::kernel_ir_for(DType::F32);
    kernel.mode = KernelMode::Reduction;

    // Launch extents: `[rows, 1, 1]` and `[dim, 1, 1]`.
    let result = ctx
        .dispatch_with_grid(&kernel, &buffers, &BTreeMap::new(), [rows, 1, 1], [dim, 1, 1])
        .expect("dispatch_with_grid should succeed");

    // Fetch both output buffers before decoding either of them.
    let outputs = &result.outputs;
    let packed_raw = outputs.get("packed_out").expect("`packed_out` buffer in dispatch result");
    let norms_raw = outputs.get("norms_out").expect("`norms_out` buffer in dispatch result");
    let actual_packed = unpack_u32_bytes(packed_raw);
    let actual_norms = bytes_to_f32_vec(norms_raw);

    assert_eq!(actual_packed, expected_packed, "packed_out mismatch — quantisation indices differ",);

    let diff = max_abs_diff(&expected_norms, &actual_norms);
    assert!(diff < 1e-4, "norms_out diverges from CPU reference: max |diff| = {diff:.2e}",);
}
/// Runs the F32 `aura_encode_int4` kernel on a single row with `dim = 32`
/// (the case the test name calls the minimum dimension) under an identity
/// rotation, and compares against `naive_aura_encode_f32`.
///
/// The packed words must match exactly; the norm must agree to within `1e-4`.
#[test]
fn aura_encode_int4_minimum_dim_f32() {
    let (dim, bits, rows) = (32usize, 4usize, 1usize);
    // Packed output is stored as u32 words: ceil(dim * bits / 32) words per row.
    let packed_width = (dim * bits).div_ceil(32);

    // Test fixtures: uniform codebook, identity rotation, ramp input.
    let (codebook, boundaries) = int4_uniform_codebook();
    let rotation = identity_rotation(dim);
    let input = ramp(rows * dim, 13, 6.0);

    // CPU reference the kernel output is compared against.
    let (expected_packed, expected_norms) =
        naive_aura_encode_f32(&input, &rotation, &boundaries, &codebook, rows, dim, bits);

    // Named byte buffers handed to the dispatch; output buffers start zeroed
    // and the two scalars are passed as little-endian u32.
    let buffers: BTreeMap<String, Vec<u8>> = BTreeMap::from([
        (String::from("input"), f32_slice_to_bytes(&input)),
        (String::from("rotation"), f32_slice_to_bytes(&rotation)),
        (String::from("boundaries"), f32_slice_to_bytes(&boundaries)),
        (String::from("codebook"), f32_slice_to_bytes(&codebook)),
        (String::from("packed_out"), pack_u32_bytes(&vec![0u32; rows * packed_width])),
        (String::from("norms_out"), f32_slice_to_bytes(&vec![0.0_f32; rows])),
        (String::from("dim"), (dim as u32).to_le_bytes().to_vec()),
        (String::from("packed_width"), (packed_width as u32).to_le_bytes().to_vec()),
    ]);

    let ctx = Context::new().expect("Context::new should succeed on macOS");
    let mut kernel = aura_encode_int4::kernel_ir_for(DType::F32);
    kernel.mode = KernelMode::Reduction;

    // Launch extents: `[rows, 1, 1]` and `[dim, 1, 1]`.
    let result = ctx
        .dispatch_with_grid(&kernel, &buffers, &BTreeMap::new(), [rows, 1, 1], [dim, 1, 1])
        .expect("dispatch_with_grid should succeed");

    // Fetch both output buffers before decoding either of them.
    let outputs = &result.outputs;
    let packed_raw = outputs.get("packed_out").expect("`packed_out` buffer in dispatch result");
    let norms_raw = outputs.get("norms_out").expect("`norms_out` buffer in dispatch result");
    let actual_packed = unpack_u32_bytes(packed_raw);
    let actual_norms = bytes_to_f32_vec(norms_raw);

    assert_eq!(actual_packed, expected_packed, "packed_out mismatch at dim=32");

    let diff = max_abs_diff(&expected_norms, &actual_norms);
    assert!(diff < 1e-4, "norms_out diverges: max |diff| = {diff:.2e}");
}
/// Runs the F32 `aura_encode_int4` kernel on three rows of width 128 using
/// the non-identity rotation built by `srht_rotation(dim, 0xA09A_5EED)`, and
/// compares against `naive_aura_encode_f32`.
///
/// The packed words must match exactly; the per-row norms must agree to
/// within `1e-4`.
#[test]
fn aura_encode_int4_srht_rotation_f32() {
    let (dim, bits, rows) = (128usize, 4usize, 3usize);
    // Packed output is stored as u32 words: ceil(dim * bits / 32) words per row.
    let packed_width = (dim * bits).div_ceil(32);

    // Test fixtures: uniform codebook, SRHT rotation, ramp input.
    let (codebook, boundaries) = int4_uniform_codebook();
    let rotation = srht_rotation(dim, 0xA09A_5EED);
    let input = ramp(rows * dim, 29, 11.0);

    // CPU reference the kernel output is compared against.
    let (expected_packed, expected_norms) =
        naive_aura_encode_f32(&input, &rotation, &boundaries, &codebook, rows, dim, bits);

    // Named byte buffers handed to the dispatch; output buffers start zeroed
    // and the two scalars are passed as little-endian u32.
    let buffers: BTreeMap<String, Vec<u8>> = BTreeMap::from([
        (String::from("input"), f32_slice_to_bytes(&input)),
        (String::from("rotation"), f32_slice_to_bytes(&rotation)),
        (String::from("boundaries"), f32_slice_to_bytes(&boundaries)),
        (String::from("codebook"), f32_slice_to_bytes(&codebook)),
        (String::from("packed_out"), pack_u32_bytes(&vec![0u32; rows * packed_width])),
        (String::from("norms_out"), f32_slice_to_bytes(&vec![0.0_f32; rows])),
        (String::from("dim"), (dim as u32).to_le_bytes().to_vec()),
        (String::from("packed_width"), (packed_width as u32).to_le_bytes().to_vec()),
    ]);

    let ctx = Context::new().expect("Context::new should succeed on macOS");
    let mut kernel = aura_encode_int4::kernel_ir_for(DType::F32);
    kernel.mode = KernelMode::Reduction;

    // Launch extents: `[rows, 1, 1]` and `[dim, 1, 1]`.
    let result = ctx
        .dispatch_with_grid(&kernel, &buffers, &BTreeMap::new(), [rows, 1, 1], [dim, 1, 1])
        .expect("dispatch_with_grid should succeed");

    // Fetch both output buffers before decoding either of them.
    let outputs = &result.outputs;
    let packed_raw = outputs.get("packed_out").expect("`packed_out` buffer in dispatch result");
    let norms_raw = outputs.get("norms_out").expect("`norms_out` buffer in dispatch result");
    let actual_packed = unpack_u32_bytes(packed_raw);
    let actual_norms = bytes_to_f32_vec(norms_raw);

    assert_eq!(
        actual_packed, expected_packed,
        "packed_out mismatch under SRHT rotation — rotation matmul stage diverges",
    );

    let diff = max_abs_diff(&expected_norms, &actual_norms);
    assert!(diff < 1e-4, "norms_out diverges under SRHT rotation: max |diff| = {diff:.2e}",);
}