use super::super::*;
/// Builds a deterministic byte buffer holding `out_dim` rows of packed
/// super-blocks, used as weight data for the Q6_K matmul tests.
///
/// Every row contains `ceil(in_dim / SUPER_BLOCK_SIZE)` blocks of
/// `SUPER_BLOCK_BYTES` bytes each. Inside a block (row `r`, block `b`, index `k`
/// relative to the start of each region) the bytes are filled as follows:
///
/// * `0..128`   — `(r + b + k + 1) % 256`
/// * `128..192` — `(r ^ b ^ k) % 4`
/// * `192..208` — `(r + b + k) % 64 + 1`
/// * `208..210` — the fixed pair `0x00, 0x38`
///
/// Any bytes of a block past offset 210 stay zero.
///
/// NOTE(review): the four regions presumably correspond to the `ql`, `qh`,
/// `scales` and f16 `d` fields of a Q6_K block (0x3800 read little-endian as a
/// half float would be 0.5) — confirm against the decoder.
fn build_q6k_test_data(out_dim: usize, in_dim: usize) -> Vec<u8> {
    let blocks_per_row = (in_dim + SUPER_BLOCK_SIZE - 1) / SUPER_BLOCK_SIZE;
    let row_stride = blocks_per_row * SUPER_BLOCK_BYTES;
    let mut bytes = vec![0u8; out_dim * row_stride];

    for r in 0..out_dim {
        for b in 0..blocks_per_row {
            let start = r * row_stride + b * SUPER_BLOCK_BYTES;
            let block = &mut bytes[start..start + SUPER_BLOCK_BYTES];

            // First region: a rolling byte pattern that wraps at 256.
            for (k, slot) in block[..128].iter_mut().enumerate() {
                *slot = ((r + b + k + 1) % 256) as u8;
            }
            // Second region: only the two low bits of each byte are ever set.
            for (k, slot) in block[128..192].iter_mut().enumerate() {
                *slot = ((r ^ b ^ k) % 4) as u8;
            }
            // Third region: values in 1..=64, never zero.
            for (k, slot) in block[192..208].iter_mut().enumerate() {
                *slot = ((r + b + k) % 64 + 1) as u8;
            }
            // Trailing two-byte constant.
            block[208] = 0x00;
            block[209] = 0x38;
        }
    }

    bytes
}
/// Produces a deterministic input vector of length `in_dim` whose `i`-th
/// element is `sin(i * 0.00137)`, computed in `f32`.
fn build_input(in_dim: usize) -> Vec<f32> {
    let mut samples = Vec::with_capacity(in_dim);
    for idx in 0..in_dim {
        let phase = idx as f32 * 0.00137;
        samples.push(phase.sin());
    }
    samples
}
/// Runs a 4096 x 2048 problem through both the scalar kernel and the
/// dispatcher and compares every 64th output row plus the final row.
///
/// The size is asserted to be at least 8,000,000 elements of work, which per
/// the assertion message is what routes the dispatcher to its parallel path.
#[test]
fn test_q6k_parallel_dispatch_matches_scalar() {
    let (out_dim, in_dim) = (4096, 2048);
    let total_work = out_dim * in_dim;
    assert!(total_work >= 8_000_000, "Must trigger parallel path");
    assert_eq!(in_dim % SUPER_BLOCK_SIZE, 0);

    let weights = build_q6k_test_data(out_dim, in_dim);
    let activations = build_input(in_dim);
    let reference = matmul_q6k_f32_scalar(&weights, &activations, out_dim, in_dim);
    let routed = matmul_q6k_f32_dispatch(&weights, &activations, out_dim, in_dim);
    assert_eq!(reference.len(), out_dim);
    assert_eq!(routed.len(), out_dim);

    // Rows 0, 64, 128, ... followed by the last row.
    let sampled_rows = (0..out_dim).step_by(64).chain(std::iter::once(out_dim - 1));
    for row in sampled_rows {
        let (want, got) = (reference[row], routed[row]);
        let diff = (want - got).abs();
        // Relative tolerance with a small absolute floor.
        let tol = want.abs() * 1e-4 + 1e-4;
        assert!(
            diff < tol,
            "Row {}: scalar={}, dispatch={}, diff={}",
            row,
            want,
            got,
            diff
        );
    }
}
/// Same scalar-vs-dispatch comparison with `out_dim = 4099`, a row count that
/// cannot be split evenly; checks the first, middle and last rows.
#[test]
fn test_q6k_parallel_dispatch_prime_outdim() {
    let (out_dim, in_dim) = (4099, 2048);
    let total_work = out_dim * in_dim;
    assert!(total_work >= 8_000_000);
    assert_eq!(in_dim % SUPER_BLOCK_SIZE, 0);

    let weights = build_q6k_test_data(out_dim, in_dim);
    let activations = build_input(in_dim);
    let reference = matmul_q6k_f32_scalar(&weights, &activations, out_dim, in_dim);
    let routed = matmul_q6k_f32_dispatch(&weights, &activations, out_dim, in_dim);
    assert_eq!(routed.len(), out_dim);

    let probes = [0, out_dim / 2, out_dim - 1];
    for row in probes.iter().copied() {
        let want = reference[row];
        let got = routed[row];
        let diff = (want - got).abs();
        let tol = want.abs() * 1e-4 + 1e-4;
        assert!(
            diff < tol,
            "Row {}: scalar={}, dispatch={}, diff={}",
            row,
            want,
            got,
            diff
        );
    }
}
/// Scalar-vs-dispatch comparison for a very wide, very short matrix
/// (2 rows of 4,194,304 inputs); both output rows are checked.
#[test]
fn test_q6k_parallel_dispatch_few_rows_large_indim() {
    let (out_dim, in_dim) = (2, 4_194_304);
    let total_work = out_dim * in_dim;
    assert!(total_work >= 8_000_000);
    assert_eq!(in_dim % SUPER_BLOCK_SIZE, 0);

    let weights = build_q6k_test_data(out_dim, in_dim);
    let activations = build_input(in_dim);
    let reference = matmul_q6k_f32_scalar(&weights, &activations, out_dim, in_dim);
    let routed = matmul_q6k_f32_dispatch(&weights, &activations, out_dim, in_dim);
    assert_eq!(routed.len(), out_dim);

    (0..out_dim).for_each(|row| {
        let (want, got) = (reference[row], routed[row]);
        let diff = (want - got).abs();
        let tol = want.abs() * 1e-4 + 1e-4;
        assert!(
            diff < tol,
            "Row {}: scalar={}, dispatch={}, diff={}",
            row,
            want,
            got,
            diff
        );
    });
}
/// Scalar-vs-dispatch comparison where the work size is exactly 8,000,000
/// (31,250 rows x 256 inputs); four fixed rows including the last are checked.
#[test]
fn test_q6k_parallel_dispatch_exact_threshold() {
    let (out_dim, in_dim) = (31_250, 256);
    let total_work = out_dim * in_dim;
    assert_eq!(total_work, 8_000_000);

    let weights = build_q6k_test_data(out_dim, in_dim);
    let activations = build_input(in_dim);
    let reference = matmul_q6k_f32_scalar(&weights, &activations, out_dim, in_dim);
    let routed = matmul_q6k_f32_dispatch(&weights, &activations, out_dim, in_dim);
    assert_eq!(routed.len(), out_dim);

    let probes = [0, 100, 10_000, 31_249];
    for row in probes.iter().copied() {
        let want = reference[row];
        let got = routed[row];
        let diff = (want - got).abs();
        let tol = want.abs() * 1e-4 + 1e-4;
        assert!(
            diff < tol,
            "Row {}: scalar={}, dispatch={}, diff={}",
            row,
            want,
            got,
            diff
        );
    }
}
/// Scalar-vs-dispatch comparison with one row fewer than the exact-threshold
/// case (31,249 x 256), so the work size stays under 8,000,000.
#[test]
fn test_q6k_dispatch_just_below_threshold() {
    let (out_dim, in_dim) = (31_249, 256);
    let total_work = out_dim * in_dim;
    assert!(total_work < 8_000_000);

    let weights = build_q6k_test_data(out_dim, in_dim);
    let activations = build_input(in_dim);
    let reference = matmul_q6k_f32_scalar(&weights, &activations, out_dim, in_dim);
    let routed = matmul_q6k_f32_dispatch(&weights, &activations, out_dim, in_dim);
    assert_eq!(routed.len(), out_dim);

    let probes = [0, 100, 15_000, 31_248];
    for row in probes.iter().copied() {
        let want = reference[row];
        let got = routed[row];
        let diff = (want - got).abs();
        let tol = want.abs() * 1e-4 + 1e-4;
        assert!(
            diff < tol,
            "Row {}: scalar={}, dispatch={}, diff={}",
            row,
            want,
            got,
            diff
        );
    }
}
/// Scalar-vs-dispatch comparison for a single output row with 8,388,608
/// inputs, i.e. a large work size with nothing to split across rows.
#[test]
fn test_q6k_parallel_dispatch_single_row() {
    let (out_dim, in_dim) = (1, 8_388_608);
    let total_work = out_dim * in_dim;
    assert!(total_work >= 8_000_000);
    assert_eq!(in_dim % SUPER_BLOCK_SIZE, 0);

    let weights = build_q6k_test_data(out_dim, in_dim);
    let activations = build_input(in_dim);
    let reference = matmul_q6k_f32_scalar(&weights, &activations, out_dim, in_dim);
    let routed = matmul_q6k_f32_dispatch(&weights, &activations, out_dim, in_dim);
    assert_eq!(routed.len(), 1);

    let (want, got) = (reference[0], routed[0]);
    let diff = (want - got).abs();
    let tol = want.abs() * 1e-4 + 1e-4;
    assert!(diff < tol, "scalar={}, dispatch={}, diff={}", want, got, diff);
}
/// Feeds an all-zero input vector through the dispatcher on a 4096 x 2048
/// problem and requires every output row to compare equal to `0.0`.
#[test]
fn test_q6k_parallel_dispatch_zero_input() {
    let (out_dim, in_dim) = (4096, 2048);
    assert!(out_dim * in_dim >= 8_000_000);

    let weights = build_q6k_test_data(out_dim, in_dim);
    let zeros = vec![0.0f32; in_dim];
    let routed = matmul_q6k_f32_dispatch(&weights, &zeros, out_dim, in_dim);
    assert_eq!(routed.len(), out_dim);

    routed.iter().enumerate().for_each(|(row, &val)| {
        assert_eq!(val, 0.0, "Row {}: expected 0.0 with zero input, got {}", row, val);
    });
}
/// Calls the public `matmul_q6k_f32` entry point on a 4096 x 2048 problem and
/// compares three rows against the scalar kernel's output.
#[test]
fn test_q6k_public_api_parallel_route() {
    let (out_dim, in_dim) = (4096, 2048);
    assert!(out_dim * in_dim >= 8_000_000);

    let weights = build_q6k_test_data(out_dim, in_dim);
    let activations = build_input(in_dim);
    let via_api = matmul_q6k_f32(&weights, &activations, out_dim, in_dim);
    let reference = matmul_q6k_f32_scalar(&weights, &activations, out_dim, in_dim);
    assert_eq!(via_api.len(), out_dim);

    let probes = [0, 1000, 4095];
    for row in probes.iter().copied() {
        let want = reference[row];
        let got = via_api[row];
        let diff = (want - got).abs();
        let tol = want.abs() * 1e-4 + 1e-4;
        assert!(
            diff < tol,
            "Row {}: scalar={}, public_api={}, diff={}",
            row,
            want,
            got,
            diff
        );
    }
}