use super::super::*;
/// Builds a deterministic Q4_K-style weight buffer with `out_dim` rows that
/// cover `in_dim` input columns.
///
/// Each row holds `ceil(in_dim / SUPER_BLOCK_SIZE)` super-blocks of
/// `SUPER_BLOCK_BYTES` bytes, stored row-major. Within every super-block
/// (with `row`, `sb` its position and `i` the byte index inside a field):
/// - bytes 0..4: `[0x00, 0x38, 0x00, 0x34]`, i.e. the little-endian fp16
///   values 0.5 and 0.25 — presumably the `d` / `dmin` header; confirm
///   against the kernel;
/// - bytes 4..16: `(row + sb + i + 1) & 0x3F` (presumably packed 6-bit
///   scales/mins);
/// - bytes 16..144: nibble pairs, low = `(row + sb + i) % 16`,
///   high = `(row + sb + i + 5) % 16`.
///
/// Any bytes beyond offset 144 stay zero. The contents depend only on
/// `(row + sb) mod 64`, so rows repeat with a period of 64.
fn build_q4k_test_data(out_dim: usize, in_dim: usize) -> Vec<u8> {
    const HEADER: [u8; 4] = [0x00, 0x38, 0x00, 0x34];
    let blocks_per_row = (in_dim + SUPER_BLOCK_SIZE - 1) / SUPER_BLOCK_SIZE;
    let mut buf = vec![0u8; out_dim * blocks_per_row * SUPER_BLOCK_BYTES];
    // When `blocks_per_row` is 0 the buffer is empty and the loop body never
    // runs, so the divisions below cannot divide by zero.
    for (flat, block) in buf.chunks_exact_mut(SUPER_BLOCK_BYTES).enumerate() {
        // Super-blocks are enumerated row-major; recover `row + sb`.
        let seed = flat / blocks_per_row + flat % blocks_per_row;
        block[..4].copy_from_slice(&HEADER);
        for (i, scale) in block[4..16].iter_mut().enumerate() {
            *scale = ((seed + i + 1) & 0x3F) as u8;
        }
        for (i, packed) in block[16..144].iter_mut().enumerate() {
            let lo = ((seed + i) % 16) as u8;
            let hi = ((seed + i + 5) % 16) as u8;
            *packed = lo | (hi << 4);
        }
    }
    buf
}
/// Deterministic, smoothly varying input vector of length `in_dim`:
/// element `i` is `sin(i * 0.00137)` evaluated in `f32`.
fn build_input(in_dim: usize) -> Vec<f32> {
    const STEP: f32 = 0.00137;
    (0..in_dim)
        .map(|idx| idx as f32)
        .map(|phase| (phase * STEP).sin())
        .collect()
}
/// `out_dim * in_dim >= 8_000_000`, large enough to take the parallel dispatch
/// path. Every output row must agree with the scalar reference.
///
/// `build_q4k_test_data` produces weights that repeat every 64 rows, so a
/// `step_by(64)` spot check would compare the same row pattern over and over.
/// Comparing every row covers all distinct patterns and any row/chunk boundary
/// the parallel path may introduce; 4096 comparisons is negligible cost.
#[test]
fn test_q4k_parallel_dispatch_matches_scalar() {
    let out_dim = 4096;
    let in_dim = 2048;
    let total_work = out_dim * in_dim;
    assert!(total_work >= 8_000_000, "Must trigger parallel path");
    assert_eq!(in_dim % SUPER_BLOCK_SIZE, 0);
    let q4k_data = build_q4k_test_data(out_dim, in_dim);
    let input = build_input(in_dim);
    let scalar = matmul_q4k_f32_scalar(&q4k_data, &input, out_dim, in_dim);
    let dispatch = matmul_q4k_f32_dispatch(&q4k_data, &input, out_dim, in_dim);
    assert_eq!(scalar.len(), out_dim);
    assert_eq!(dispatch.len(), out_dim);
    // Lengths are asserted equal above, so `zip` visits every row.
    for (i, (&s, &d)) in scalar.iter().zip(&dispatch).enumerate() {
        let diff = (s - d).abs();
        let tol = s.abs() * 1e-4 + 1e-4;
        assert!(
            diff < tol,
            "Row {}: scalar={}, dispatch={}, diff={}",
            i,
            s,
            d,
            diff
        );
    }
}
/// A prime `out_dim` (4099) cannot be divided into equal multi-row chunks, so
/// the parallel path may end with a ragged tail. Every row is compared against
/// the scalar reference, not just a few spot rows. That way a wrong or skipped
/// row anywhere, including inside the tail, is caught.
#[test]
fn test_q4k_parallel_dispatch_prime_outdim() {
    let out_dim = 4099;
    let in_dim = 2048;
    let total_work = out_dim * in_dim;
    assert!(total_work >= 8_000_000);
    assert_eq!(in_dim % SUPER_BLOCK_SIZE, 0);
    let q4k_data = build_q4k_test_data(out_dim, in_dim);
    let input = build_input(in_dim);
    let scalar = matmul_q4k_f32_scalar(&q4k_data, &input, out_dim, in_dim);
    let dispatch = matmul_q4k_f32_dispatch(&q4k_data, &input, out_dim, in_dim);
    assert_eq!(scalar.len(), out_dim);
    assert_eq!(dispatch.len(), out_dim);
    // Lengths are asserted equal above, so `zip` visits every row.
    for (i, (&s, &d)) in scalar.iter().zip(&dispatch).enumerate() {
        let diff = (s - d).abs();
        let tol = s.abs() * 1e-4 + 1e-4;
        assert!(
            diff < tol,
            "Row {}: scalar={}, dispatch={}, diff={}",
            i,
            s,
            d,
            diff
        );
    }
}
/// Only two output rows but a very wide input (4 Mi columns). `out_dim * in_dim`
/// still clears the 8_000_000 work bound used for the parallel path, even
/// though there are almost no rows to split. Both rows are checked.
#[test]
fn test_q4k_parallel_dispatch_few_rows_large_indim() {
    let (out_dim, in_dim) = (2, 4_194_304);
    let total_work = out_dim * in_dim;
    assert!(total_work >= 8_000_000);
    assert_eq!(in_dim % SUPER_BLOCK_SIZE, 0);

    let weights = build_q4k_test_data(out_dim, in_dim);
    let x = build_input(in_dim);
    let reference = matmul_q4k_f32_scalar(&weights, &x, out_dim, in_dim);
    let result = matmul_q4k_f32_dispatch(&weights, &x, out_dim, in_dim);
    assert_eq!(result.len(), out_dim);

    // Walk the dispatch output; index the scalar result so a short reference
    // vector still panics rather than being silently truncated.
    for (row, &got) in result.iter().enumerate() {
        let want = reference[row];
        let diff = (want - got).abs();
        let tol = want.abs() * 1e-4 + 1e-4;
        assert!(
            diff < tol,
            "Row {}: scalar={}, dispatch={}, diff={}",
            row,
            want,
            got,
            diff
        );
    }
}
/// `31_250 * 256` lands exactly on 8_000_000, the work bound the other tests
/// use for the parallel path. Spot-checks rows at the start, middle and end.
#[test]
fn test_q4k_parallel_dispatch_exact_threshold() {
    let (out_dim, in_dim) = (31_250, 256);
    let total_work = out_dim * in_dim;
    assert_eq!(total_work, 8_000_000);

    let weights = build_q4k_test_data(out_dim, in_dim);
    let x = build_input(in_dim);
    let reference = matmul_q4k_f32_scalar(&weights, &x, out_dim, in_dim);
    let result = matmul_q4k_f32_dispatch(&weights, &x, out_dim, in_dim);
    assert_eq!(result.len(), out_dim);

    for row in [0, 100, 10_000, 31_249] {
        let (want, got) = (reference[row], result[row]);
        let diff = (want - got).abs();
        assert!(
            diff < want.abs() * 1e-4 + 1e-4,
            "Row {}: scalar={}, dispatch={}, diff={}",
            row,
            want,
            got,
            diff
        );
    }
}
/// One row short of the 8_000_000 work bound (`31_249 * 256`). Dispatch
/// presumably stays on its non-parallel path here; either way, results must
/// still match the scalar reference at the spot-checked rows.
#[test]
fn test_q4k_dispatch_just_below_threshold() {
    let (out_dim, in_dim) = (31_249, 256);
    let total_work = out_dim * in_dim;
    assert!(total_work < 8_000_000);

    let weights = build_q4k_test_data(out_dim, in_dim);
    let x = build_input(in_dim);
    let reference = matmul_q4k_f32_scalar(&weights, &x, out_dim, in_dim);
    let result = matmul_q4k_f32_dispatch(&weights, &x, out_dim, in_dim);
    assert_eq!(result.len(), out_dim);

    for row in [0, 100, 15_000, 31_248] {
        let (want, got) = (reference[row], result[row]);
        let diff = (want - got).abs();
        assert!(
            diff < want.abs() * 1e-4 + 1e-4,
            "Row {}: scalar={}, dispatch={}, diff={}",
            row,
            want,
            got,
            diff
        );
    }
}
/// A single output row with 8 Mi input columns, so all the work sits in one
/// row. It uses a looser relative tolerance (2e-4) than the multi-row tests.
#[test]
fn test_q4k_parallel_dispatch_single_row() {
    let (out_dim, in_dim) = (1, 8_388_608);
    let total_work = out_dim * in_dim;
    assert!(total_work >= 8_000_000);
    assert_eq!(in_dim % SUPER_BLOCK_SIZE, 0);

    let weights = build_q4k_test_data(out_dim, in_dim);
    let x = build_input(in_dim);
    let reference = matmul_q4k_f32_scalar(&weights, &x, out_dim, in_dim);
    let result = matmul_q4k_f32_dispatch(&weights, &x, out_dim, in_dim);
    assert_eq!(result.len(), 1);

    let (want, got) = (reference[0], result[0]);
    let diff = (want - got).abs();
    assert!(
        diff < want.abs() * 2e-4 + 1e-4,
        "scalar={}, dispatch={}, diff={}",
        want,
        got,
        diff
    );
}
/// With an all-zero input vector on the parallel-sized problem, every output
/// must compare equal to 0.0 (IEEE `-0.0 == 0.0`, so a negative zero passes too).
#[test]
fn test_q4k_parallel_dispatch_zero_input() {
    let (out_dim, in_dim) = (4096, 2048);
    assert!(out_dim * in_dim >= 8_000_000);

    let weights = build_q4k_test_data(out_dim, in_dim);
    let zeros = vec![0.0f32; in_dim];
    let result = matmul_q4k_f32_dispatch(&weights, &zeros, out_dim, in_dim);
    assert_eq!(result.len(), out_dim);

    result.iter().enumerate().for_each(|(row, &val)| {
        assert_eq!(val, 0.0, "Row {}: expected 0.0 with zero input, got {}", row, val);
    });
}