use super::gemv::compute_chunk_q4k_scalar;
use super::*;
/// Reference dense matrix-vector product used as the ground truth in the tests below.
///
/// `weights` is read as a row-major `out_dim x in_dim` matrix and each row is
/// dotted with the first `in_dim` entries of `input`, accumulating left to
/// right in `f32`. Returns one value per row.
///
/// Panics if `weights` or `input` is too short for the requested dimensions
/// (only when at least one row is computed).
fn matmul_f32_naive(weights: &[f32], input: &[f32], out_dim: usize, in_dim: usize) -> Vec<f32> {
    (0..out_dim)
        .map(|row| {
            let start = row * in_dim;
            let row_weights = &weights[start..start + in_dim];
            // Slice the input per row so a too-short input still panics, exactly
            // like direct indexing would, instead of being silently truncated by `zip`.
            let activations = &input[..in_dim];
            row_weights
                .iter()
                .zip(activations)
                .fold(0.0f32, |acc, (w, x)| acc + w * x)
        })
        .collect()
}
/// Golden-vector check: the fused kernel `matmul_q4k_f32` must agree with
/// "dequantize the whole matrix, then run the naive f32 matmul" on a
/// deterministic synthetic 8 x 512 matrix.
#[test]
fn test_golden_vector_q4k_matmul_vs_dequant() {
    use crate::backends::q4k::dequantize_q4k_to_f32;

    let in_dim = 512;
    let out_dim = 8;
    let blocks_per_row = in_dim / 256;

    // Every super-block is emitted as 2 + 2 header bytes, 12 scale bytes and
    // 128 bytes holding two 4-bit values each.
    let mut q4k_data = Vec::new();
    for row in 0..out_dim {
        for sb in 0..blocks_per_row {
            // NOTE(review): presumably little-endian f16 `d` (0x2E66 ~ 0.1) and
            // `dmin` (0x2A66 ~ 0.05) — confirm against the Q4K decoder.
            q4k_data.extend_from_slice(&[0x66, 0x2E]);
            q4k_data.extend_from_slice(&[0x66, 0x2A]);

            // Scale bytes vary with row/block so rows are not identical.
            let scale_base = ((row * 7 + sb * 3) % 16) as u8;
            q4k_data.extend((0..12u8).map(|i| scale_base + i % 4));

            // Packed nibbles: low and high halves follow shifted sequences.
            q4k_data.extend((0..128).map(|i| {
                let lo = ((row + sb + i) % 16) as u8;
                let hi = ((row + sb + i + 5) % 16) as u8;
                lo | (hi << 4)
            }));
        }
    }

    let input: Vec<f32> = (0..in_dim).map(|i| ((i as f32) * 0.017).sin() * 0.5).collect();

    // Path under test.
    let fused_output = matmul_q4k_f32(&q4k_data, &input, out_dim, in_dim);

    // Reference path: dequantize everything, then plain f32 matmul.
    let dequantized_weights = dequantize_q4k_to_f32(&q4k_data, in_dim * out_dim);
    let reference_output = matmul_f32_naive(&dequantized_weights, &input, out_dim, in_dim);

    assert_eq!(fused_output.len(), reference_output.len());

    let mut max_rel_error = 0.0f32;
    let mut max_abs_error = 0.0f32;
    for (i, (fused, reference)) in fused_output.iter().zip(reference_output.iter()).enumerate() {
        let abs_error = (fused - reference).abs();
        // Fall back to the absolute error when the reference is ~0 to avoid
        // dividing by a vanishing magnitude.
        let magnitude = reference.abs();
        let rel_error = if magnitude > 1e-6 { abs_error / magnitude } else { abs_error };

        max_rel_error = max_rel_error.max(rel_error);
        max_abs_error = max_abs_error.max(abs_error);

        // A row passes if it is within 5% relative OR 0.01 absolute.
        assert!(
            rel_error < 0.05 || abs_error < 0.01,
            "Golden invariant violated at row {}: fused={}, reference={}, \
             rel_error={:.4}%, abs_error={:.6}",
            i,
            fused,
            reference,
            rel_error * 100.0,
            abs_error
        );
    }

    eprintln!(
        "[Golden Q4K Test] max_rel_error={:.4}%, max_abs_error={:.6}",
        max_rel_error * 100.0,
        max_abs_error
    );
}
/// Golden-vector check for the dispatching entry point:
/// `matmul_q4k_f32_dispatch` must agree with "dequantize the whole matrix,
/// then run the naive f32 matmul" on a deterministic synthetic 16 x 1024 matrix.
#[test]
fn test_golden_vector_q4k_dispatch_vs_dequant() {
    use crate::backends::q4k::dequantize_q4k_to_f32;

    let in_dim = 1024;
    let out_dim = 16;

    // Every super-block is emitted as 2 + 2 header bytes, 12 scale bytes and
    // 128 bytes holding two 4-bit values each.
    let mut q4k_data = Vec::new();
    for row in 0..out_dim {
        for sb in 0..(in_dim / 256) {
            // NOTE(review): presumably little-endian f16 `d` (0x3C00 = 1.0) and
            // `dmin` (0x3800 = 0.5) — confirm against the Q4K decoder.
            q4k_data.extend_from_slice(&[0x00, 0x3C]);
            q4k_data.extend_from_slice(&[0x00, 0x38]);
            // Scale bytes stay below 64 and vary with row/block/position.
            for i in 0..12 {
                q4k_data.push(((row + sb + i) % 64) as u8);
            }
            for i in 0..128 {
                let low = ((row * 3 + sb * 7 + i) % 16) as u8;
                let high = ((row * 5 + sb * 11 + i * 2) % 16) as u8;
                q4k_data.push(low | (high << 4));
            }
        }
    }

    let input: Vec<f32> = (0..in_dim).map(|i| ((i as f32) * 0.013 + 0.5).cos() * 0.3).collect();

    // Path under test.
    let dispatch_output = matmul_q4k_f32_dispatch(&q4k_data, &input, out_dim, in_dim);

    // Reference path: dequantize everything, then plain f32 matmul.
    let total_elements = in_dim * out_dim;
    let dequantized = dequantize_q4k_to_f32(&q4k_data, total_elements);
    let reference_output = matmul_f32_naive(&dequantized, &input, out_dim, in_dim);

    // `zip` below stops at the shorter side, so without this check a short or
    // empty dispatch output would let the test pass without comparing anything.
    assert_eq!(
        dispatch_output.len(),
        reference_output.len(),
        "dispatch output must have one value per reference row"
    );

    let mut max_rel_error = 0.0f32;
    for (i, (dispatch, reference)) in
        dispatch_output.iter().zip(reference_output.iter()).enumerate()
    {
        let abs_error = (dispatch - reference).abs();
        // Fall back to the absolute error when the reference is ~0.
        let rel_error =
            if reference.abs() > 1e-6 { abs_error / reference.abs() } else { abs_error };
        max_rel_error = max_rel_error.max(rel_error);
        // A row passes if it is within 5% relative OR 0.01 absolute.
        assert!(
            rel_error < 0.05 || abs_error < 0.01,
            "Golden invariant violated (dispatch) at row {}: \
             dispatch={}, reference={}, rel_error={:.4}%",
            i,
            dispatch,
            reference,
            rel_error * 100.0
        );
    }
    eprintln!("[Golden Q4K Dispatch Test] max_rel_error={:.4}%", max_rel_error * 100.0);
}
/// An all-zero input vector must produce (approximately) zero for every output
/// row of `matmul_q4k_f32`, whatever the quantized weights contain.
#[test]
fn test_golden_vector_zero_input() {
    let in_dim = 256;
    let out_dim = 4;

    // One super-block per row: 2 + 2 header bytes, 12 scale bytes, 128 packed bytes.
    let mut q4k_data = Vec::new();
    for _row in 0..out_dim {
        // NOTE(review): presumably little-endian f16 `d` (0x2E66 ~ 0.1) and
        // `dmin` (0x0000 = 0.0) — confirm against the Q4K decoder.
        q4k_data.extend_from_slice(&[0x66, 0x2E]);
        q4k_data.extend_from_slice(&[0x00, 0x00]);
        q4k_data.extend_from_slice(&[0x01u8; 12]);
        // 0x55 stores the 4-bit value 5 in both halves of every byte.
        q4k_data.extend_from_slice(&[0x55u8; 128]);
    }

    let input = vec![0.0f32; in_dim];
    let output = matmul_q4k_f32(&q4k_data, &input, out_dim, in_dim);

    // Without this check an empty result would skip the loop below and the
    // test would pass without verifying anything.
    assert_eq!(output.len(), out_dim, "expected one output value per row");

    for (i, val) in output.iter().enumerate() {
        assert!(val.abs() < 1e-6, "Zero input should give ~zero output, got {} at row {}", val, i);
    }
}
/// With an all-ones input, `matmul_q4k_f32` must match the
/// "dequantize, then naive f32 matmul" reference within 5% on every row.
#[test]
fn test_golden_vector_uniform_input() {
    use crate::backends::q4k::dequantize_q4k_to_f32;

    let in_dim = 256;
    let out_dim = 2;

    // One super-block per row: 2 + 2 header bytes, 12 scale bytes, 128 packed bytes.
    let mut q4k_data = Vec::new();
    for row in 0..out_dim {
        // NOTE(review): presumably little-endian f16 `d` (0x3C00 = 1.0) and
        // `dmin` (0x0000 = 0.0) — confirm against the Q4K decoder.
        q4k_data.extend_from_slice(&[0x00, 0x3C]);
        q4k_data.extend_from_slice(&[0x00, 0x00]);
        q4k_data.extend_from_slice(&[0x01u8; 12]);
        // Row 0 packs 0x11 and row 1 packs 0x22, i.e. the same 4-bit value in
        // both halves of every byte.
        q4k_data.extend_from_slice(&[((row + 1) * 0x11) as u8; 128]);
    }

    let input = vec![1.0f32; in_dim];
    let fused_output = matmul_q4k_f32(&q4k_data, &input, out_dim, in_dim);

    let total_elements = in_dim * out_dim;
    let dequantized = dequantize_q4k_to_f32(&q4k_data, total_elements);
    let reference_output = matmul_f32_naive(&dequantized, &input, out_dim, in_dim);

    // `zip` below stops at the shorter side, so without this check a short or
    // empty fused output would let the test pass without comparing anything.
    assert_eq!(
        fused_output.len(),
        reference_output.len(),
        "fused output must have one value per reference row"
    );

    for (i, (fused, reference)) in fused_output.iter().zip(reference_output.iter()).enumerate() {
        // Fall back to the absolute error when the reference is ~0.
        let rel_error = if reference.abs() > 1e-6 {
            (fused - reference).abs() / reference.abs()
        } else {
            (fused - reference).abs()
        };
        assert!(
            rel_error < 0.05,
            "Uniform input failed at row {}: fused={}, ref={}, err={:.2}%",
            i,
            fused,
            reference,
            rel_error * 100.0
        );
    }
}
/// Runs `matmul_q4k_f32_dispatch` on a 4096 x 2048 matrix, checks that every
/// output is finite, and spot-checks every 512th row against
/// `matmul_q4k_f32_scalar`.
#[test]
fn test_parallel_dispatch_large_matrix() {
    let out_dim = 4096;
    let in_dim = 2048;

    // Guard so that shrinking the dimensions later cannot silently drop the
    // matrix below the size this test is meant to exercise.
    let total_work = out_dim * in_dim;
    assert!(total_work >= 8_000_000, "Test must trigger parallel path");

    let num_superblocks_per_row = (in_dim + SUPER_BLOCK_SIZE - 1) / SUPER_BLOCK_SIZE;
    let row_bytes = num_superblocks_per_row * SUPER_BLOCK_BYTES;
    let mut q4k_data = vec![0u8; out_dim * row_bytes];

    for row in 0..out_dim {
        for sb in 0..num_superblocks_per_row {
            let base = row * row_bytes + sb * SUPER_BLOCK_BYTES;
            // NOTE(review): presumably little-endian f16 `d` (0x3C00 = 1.0)
            // followed by `dmin` = 0 — confirm against the Q4K decoder.
            q4k_data[base..base + 4].copy_from_slice(&[0x00, 0x3C, 0x00, 0x00]);
            // The 12 scale bytes.
            q4k_data[base + 4..base + 16].fill(0x01);
            // 128 data bytes; only the low 4 bits of each byte are populated.
            for (i, byte) in q4k_data[base + 16..base + 144].iter_mut().enumerate() {
                *byte = ((row + sb + i) % 16) as u8;
            }
        }
    }

    let input: Vec<f32> = (0..in_dim).map(|i| (i % 10) as f32 * 0.1).collect();

    let result = matmul_q4k_f32_dispatch(&q4k_data, &input, out_dim, in_dim);
    assert_eq!(result.len(), out_dim);
    for (i, &val) in result.iter().enumerate() {
        assert!(val.is_finite(), "Result[{}] is not finite: {}", i, val);
    }

    // Compare a sample of rows against the scalar implementation, allowing
    // 1% relative error plus a small absolute slack.
    let scalar_result = matmul_q4k_f32_scalar(&q4k_data, &input, out_dim, in_dim);
    for i in (0..out_dim).step_by(512) {
        let parallel = result[i];
        let scalar = scalar_result[i];
        let diff = (parallel - scalar).abs();
        let tol = scalar.abs() * 0.01 + 1e-5;
        assert!(
            diff < tol,
            "Parallel vs scalar mismatch at row {}: parallel={}, scalar={}, diff={}",
            i,
            parallel,
            scalar,
            diff
        );
    }
}
/// Runs `matmul_q4k_f32_colmajor_dispatch` on a large (ne0 = 2048, ne1 = 4096)
/// matrix and checks that it returns `ne0` finite values.
#[test]
// NOTE(review): presumably needed because `matmul_q4k_f32_colmajor_dispatch`
// is marked deprecated — confirm at its definition.
#[allow(deprecated)]
fn test_parallel_colmajor_large_matrix() {
    let ne0 = 2048;
    let ne1 = 4096;

    let blocks_per_col = (ne0 + SUPER_BLOCK_SIZE - 1) / SUPER_BLOCK_SIZE;
    let col_bytes = blocks_per_col * SUPER_BLOCK_BYTES;
    let mut q4k_data = vec![0u8; ne1 * col_bytes];

    for col in 0..ne1 {
        for sb in 0..blocks_per_col {
            let base = col * col_bytes + sb * SUPER_BLOCK_BYTES;
            // NOTE(review): presumably little-endian f16 `d` (0x3800 = 0.5)
            // followed by `dmin` = 0 — confirm against the Q4K decoder.
            q4k_data[base..base + 4].copy_from_slice(&[0x00, 0x38, 0x00, 0x00]);
            // The 12 scale bytes.
            q4k_data[base + 4..base + 16].fill(0x02);
            // 128 data bytes; only the low 4 bits of each byte are populated.
            for (i, byte) in q4k_data[base + 16..base + 144].iter_mut().enumerate() {
                *byte = ((col ^ sb ^ i) % 16) as u8;
            }
        }
    }

    // The input has `ne1` entries, cycling through -0.3 ..= 0.3.
    let input: Vec<f32> = (0..ne1).map(|i| ((i % 7) as f32 - 3.0) * 0.1).collect();

    let result = matmul_q4k_f32_colmajor_dispatch(&q4k_data, &input, ne0, ne1);
    assert_eq!(result.len(), ne0);
    for (i, &val) in result.iter().enumerate() {
        assert!(val.is_finite(), "Result[{}] is not finite: {}", i, val);
    }
}
/// Smoke test for `compute_chunk_q4k_scalar` on a 4 x 256 matrix with one
/// super-block per row: every value written into the chunk must be finite.
#[test]
fn test_compute_chunk_scalar_small() {
    let in_dim = 256;
    let out_dim = 4;
    let num_blocks_per_row = 1;
    let row_bytes = SUPER_BLOCK_BYTES;

    let mut q4k_data = vec![0u8; out_dim * row_bytes];
    for row in 0..out_dim {
        let base = row * row_bytes;
        // NOTE(review): presumably little-endian f16 `d` (0x3C00 = 1.0)
        // followed by `dmin` = 0 — confirm against the Q4K decoder.
        q4k_data[base..base + 4].copy_from_slice(&[0x00, 0x3C, 0x00, 0x00]);
        // The 12 scale bytes.
        q4k_data[base + 4..base + 16].fill(0x01);
        // All 128 data bytes are zero.
        q4k_data[base + 16..base + 144].fill(0x00);
    }

    let input = vec![1.0f32; in_dim];
    let mut chunk = vec![0.0f32; out_dim];

    // The chunk covers the whole matrix: start at 0 and span `out_dim` rows.
    compute_chunk_q4k_scalar(
        &q4k_data,
        &input,
        &mut chunk,
        0,
        out_dim,
        in_dim,
        num_blocks_per_row,
        row_bytes,
    );

    for (i, &val) in chunk.iter().enumerate() {
        assert!(val.is_finite(), "Chunk[{}] is not finite: {}", i, val);
    }
}