use super::super::*;
#[test]
fn test_fused_q4k_golden_parity() {
    // Golden parity: the fused Q4_K matmul must agree with a
    // dequantize-then-dot-product reference within 1e-3 per row.
    let in_dim = 256;
    let out_dim = 4;

    // One synthetic super-block per output row.
    let mut q4k_data = Vec::with_capacity(out_dim * SUPER_BLOCK_BYTES);
    for row in 0..out_dim {
        let d: u16 = 0x2E66; // f16 bit pattern for the block scale
        let dmin: u16 = 0x2A66; // f16 bit pattern for the block min
        q4k_data.extend_from_slice(&d.to_le_bytes());
        q4k_data.extend_from_slice(&dmin.to_le_bytes());
        q4k_data.extend_from_slice(&[0x01u8; 12]); // packed 6-bit sub-scales

        // Row-dependent nibble pattern so each output row differs.
        let qs: Vec<u8> = (0..128)
            .map(|i| {
                let low = ((row + i) % 16) as u8;
                let high = ((row + i + 1) % 16) as u8;
                low | (high << 4)
            })
            .collect();
        q4k_data.extend_from_slice(&qs);
    }

    let input: Vec<f32> = (0..in_dim).map(|i| (i as f32) * 0.01).collect();
    let fused_output = matmul_q4k_f32(&q4k_data, &input, out_dim, in_dim);

    // Reference path: dequantize each row to f32 and take a plain dot product.
    let reference_output: Vec<f32> = (0..out_dim)
        .map(|row| {
            let start = row * SUPER_BLOCK_BYTES;
            let row_q4k = &q4k_data[start..start + SUPER_BLOCK_BYTES];
            let f32_weights = dequantize_q4k_to_f32(row_q4k, in_dim);
            f32_weights.iter().zip(input.iter()).map(|(w, x)| w * x).sum()
        })
        .collect();

    for (i, (fused, reference)) in fused_output.iter().zip(reference_output.iter()).enumerate() {
        let diff = (fused - reference).abs();
        assert!(
            diff < 1e-3,
            "Row {}: Fused kernel divergence: {} vs {} (d={})",
            i,
            fused,
            reference,
            diff
        );
    }
}
#[test]
fn test_scalar_vs_optimized_parity() {
    // The optimized kernel must agree with the scalar reference kernel
    // within 1e-4 per row on identical quantized data.
    let in_dim = 256;
    let out_dim = 2;

    let mut q4k_data = Vec::new();
    for _ in 0..out_dim {
        q4k_data.extend_from_slice(&[0x00, 0x3C]); // d = 1.0 in f16 (0x3C00)
        q4k_data.extend_from_slice(&[0x00, 0x00]); // dmin = 0.0
        q4k_data.extend_from_slice(&[0x01u8; 12]); // packed sub-scales
        q4k_data.extend_from_slice(&[0x55u8; 128]); // nibbles 5/5 throughout
    }

    let input: Vec<f32> = (0..in_dim).map(|i| (i as f32) * 0.001).collect();

    let scalar_output = matmul_q4k_f32_scalar(&q4k_data, &input, out_dim, in_dim);
    let optimized_output = matmul_q4k_f32(&q4k_data, &input, out_dim, in_dim);

    for (i, (s, o)) in scalar_output.iter().zip(optimized_output.iter()).enumerate() {
        let diff = (s - o).abs();
        assert!(
            diff < 1e-4,
            "Row {}: Scalar vs optimized divergence: {} vs {} (d={})",
            i,
            s,
            o,
            diff
        );
    }
}
#[test]
fn test_no_nan_inf() {
    // Saturated quantized data (all-max nibbles, non-trivial scales) must
    // still yield finite outputs — no NaN/Inf leaking from the kernel.
    let in_dim = 256;
    let out_dim = 4;

    let mut q4k_data = Vec::new();
    for _ in 0..out_dim {
        q4k_data.extend_from_slice(&[0x00, 0x3C]); // d = 1.0 in f16 (0x3C00)
        q4k_data.extend_from_slice(&[0x00, 0x38]); // dmin = 0.5 in f16 (0x3800)
        q4k_data.extend_from_slice(&[0x3Fu8; 12]); // sub-scales at a high setting
        q4k_data.extend_from_slice(&[0xFFu8; 128]); // every nibble maxed (15)
    }

    let input: Vec<f32> = (0..in_dim).map(|i| (i as f32) * 0.01).collect();
    let output = matmul_q4k_f32(&q4k_data, &input, out_dim, in_dim);

    for (i, &val) in output.iter().enumerate() {
        assert!(val.is_finite(), "Row {}: Output is not finite: {}", i, val);
    }
}
#[test]
fn test_determinism() {
    // Two invocations on identical inputs must be bit-for-bit identical
    // (compared via to_bits, so even -0.0 vs 0.0 or NaN payloads would differ).
    let in_dim = 256;
    let out_dim = 2;

    let mut q4k_data = Vec::new();
    for _ in 0..out_dim {
        q4k_data.extend_from_slice(&[0x66, 0x2E]); // d bits (little-endian f16)
        q4k_data.extend_from_slice(&[0x66, 0x2A]); // dmin bits (little-endian f16)
        q4k_data.extend_from_slice(&[0x15u8; 12]); // packed sub-scales
        q4k_data.extend_from_slice(&[0xABu8; 128]); // fixed nibble pattern
    }

    let input: Vec<f32> = (0..in_dim).map(|i| (i as f32) * 0.005).collect();

    let output1 = matmul_q4k_f32(&q4k_data, &input, out_dim, in_dim);
    let output2 = matmul_q4k_f32(&q4k_data, &input, out_dim, in_dim);

    for (i, (a, b)) in output1.iter().zip(output2.iter()).enumerate() {
        assert_eq!(
            a.to_bits(),
            b.to_bits(),
            "Row {}: Non-deterministic output: {} vs {}",
            i,
            a,
            b
        );
    }
}