use super::super::*;
/// A 4096x4096 matvec should expose its dimensions and report
/// 16 superblocks per row (4096 / 256 weights per superblock).
#[test]
fn test_tiled_q4k_matvec_creation() {
    let mv = TiledQ4KMatvec::new(4096, 4096);
    assert_eq!(mv.m, 4096);
    assert_eq!(mv.k, 4096);
    assert_eq!(mv.superblocks_per_row(), 16);
    assert_eq!(mv.total_superblocks(), 4096 * 16);
}
/// Constructing with a K that is not a multiple of the superblock size
/// must panic with a message mentioning the K dimension.
#[test]
#[should_panic(expected = "K dimension")]
fn test_tiled_q4k_matvec_unaligned_k() {
    // 100 is not a multiple of 256, so `new` should reject it.
    let _ = TiledQ4KMatvec::new(4096, 100);
}
/// Row byte offsets advance by superblocks_per_row * Q4K_SUPERBLOCK_BYTES.
#[test]
fn test_tiled_q4k_matvec_weight_offset() {
    // K = 512 -> 2 superblocks per row, so row 1 starts 2 blocks in.
    let mv = TiledQ4KMatvec::new(100, 512);
    assert_eq!(mv.weight_row_offset(0), 0);
    assert_eq!(mv.weight_row_offset(1), 2 * Q4K_SUPERBLOCK_BYTES);
}
/// With a 256 KiB cache budget the suggested parallel-row count must stay
/// within a sane range: at least 4, and never more rows than the matrix has.
#[test]
fn test_tiled_q4k_matvec_optimal_rows() {
    let mv = TiledQ4KMatvec::new(4096, 4096);
    let parallel_rows = mv.optimal_parallel_rows(256 * 1024);
    assert!(parallel_rows >= 4);
    assert!(parallel_rows <= 4096);
}
/// Sanity-check the tiling statistics for a square 4096 matvec.
#[test]
fn test_tiled_q4k_matvec_stats() {
    let mv = TiledQ4KMatvec::new(4096, 4096);
    let s = mv.stats();
    assert_eq!(s.superblocks, 4096 * 16);
    // One multiply + one add per weight element.
    assert_eq!(s.arithmetic_ops, 4096 * 4096 * 2);
    // Q4_K packs ~4.5 bits/weight, so ops per byte should exceed 1.
    assert!(s.arithmetic_intensity > 1.0);
}
/// Q4_K layout constants: 256 weights per superblock packed into 144 bytes
/// (2x f16 scales + 12 scale/min bytes + 128 bytes of 4-bit quants).
#[test]
fn test_q4k_constants() {
    assert_eq!(Q4K_SUPERBLOCK_SIZE, 256);
    assert_eq!(Q4K_SUPERBLOCK_BYTES, 144);
}
/// Degenerate tiling: K exactly one superblock wide means one block per row.
#[test]
fn test_k_equals_superblock() {
    let mv = TiledQ4KMatvec::new(100, 256);
    assert_eq!(mv.superblocks_per_row(), 1);
    assert_eq!(mv.total_superblocks(), 100);
}
/// Very tall matrices (100k rows) must still tile correctly and keep the
/// parallel-row floor of 4.
#[test]
fn test_large_m_dimension() {
    let mv = TiledQ4KMatvec::new(100_000, 256);
    assert_eq!(mv.superblocks_per_row(), 1);
    assert_eq!(mv.total_superblocks(), 100_000);
    let parallel_rows = mv.optimal_parallel_rows(256 * 1024);
    assert!(parallel_rows >= 4);
}
/// Very wide rows (K = 32768 -> 128 superblocks) must still yield positive
/// arithmetic intensity in the stats.
#[test]
fn test_large_k_dimension() {
    let mv = TiledQ4KMatvec::new(10, 32768);
    assert_eq!(mv.superblocks_per_row(), 128);
    let s = mv.stats();
    assert!(s.arithmetic_intensity > 0.0);
}
/// Every field of the tiling stats should be populated for a 100x512 matvec.
#[test]
fn test_tiling_stats_complete() {
    let mv = TiledQ4KMatvec::new(100, 512);
    let s = mv.stats();
    // f32 input/output vectors: 4 bytes per element.
    assert_eq!(s.input_bytes, 512 * 4);
    assert_eq!(s.output_bytes, 100 * 4);
    // 512 / 256 = 2 superblocks per row across 100 rows.
    assert_eq!(s.superblocks, 100 * 2);
    assert!(s.total_weight_bytes > 0);
}
/// Spot-check f16 -> f32 conversion against well-known IEEE 754 binary16
/// bit patterns (bytes are little-endian: low byte first).
#[test]
fn test_f16_conversion() {
    // 0x0000 -> +0.0
    assert_eq!(f16_to_f32(&[0x00, 0x00]), 0.0);
    // 0x3C00 -> 1.0
    assert!((f16_to_f32(&[0x00, 0x3C]) - 1.0).abs() < 0.001);
    // 0xBC00 -> -1.0
    assert!((f16_to_f32(&[0x00, 0xBC]) + 1.0).abs() < 0.001);
    // 0x7C00 -> +infinity; 0x7C01 -> NaN (nonzero mantissa, max exponent)
    assert!(f16_to_f32(&[0x00, 0x7C]).is_infinite());
    assert!(f16_to_f32(&[0x01, 0x7C]).is_nan());
}
/// Edge-of-range f16 values: smallest positive subnormal, negative zero,
/// and negative infinity.
#[test]
fn test_f16_subnormal() {
    // 0x0001 is the smallest positive subnormal: tiny but strictly positive.
    let tiny = f16_to_f32(&[0x01, 0x00]);
    assert!(tiny > 0.0);
    assert!(tiny < 0.001);
    // 0x8000 is negative zero: compares equal to -0.0 and is sign-negative.
    let minus_zero = f16_to_f32(&[0x00, 0x80]);
    assert_eq!(minus_zero, -0.0);
    assert!(minus_zero.is_sign_negative());
    // 0xFC00 is negative infinity.
    let minus_inf = f16_to_f32(&[0x00, 0xFC]);
    assert!(minus_inf.is_infinite());
    assert!(minus_inf.is_sign_negative());
}
#[test]
fn test_extract_scale_min_6bit() {
let scales = [0x3F, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00];
let (sc, mn) = extract_scale_min_6bit(&scales, 0);
assert_eq!(sc, 63.0);
assert_eq!(mn, 0.0);
let scales2 = [0x00, 0x2A, 0x00, 0x00, 0x00, 0x15, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00];
let (sc1, mn1) = extract_scale_min_6bit(&scales2, 1);
assert_eq!(sc1, 42.0); assert_eq!(mn1, 21.0);
let scales3 = [0xC0, 0x00, 0x00, 0x00, 0x80, 0x00, 0x00, 0x00, 0x97, 0x00, 0x00, 0x00];
let (sc4, mn4) = extract_scale_min_6bit(&scales3, 4);
assert_eq!(sc4, 55.0);
assert_eq!(mn4, 41.0);
let scales4 = [0x00, 0x00, 0x00, 0x40, 0x00, 0x00, 0x00, 0xC0, 0x00, 0x00, 0x00, 0xFE];
let (sc7, mn7) = extract_scale_min_6bit(&scales4, 7);
assert_eq!(sc7, 30.0);
assert_eq!(mn7, 63.0);
}
/// Cross-check extract_scale_min_6bit against an inline re-implementation of
/// the ggml Q4_K packing scheme for all 8 superblock indices.
#[test]
fn test_extract_scale_min_matches_reference() {
    // First 4 bytes are the d/dmin f16 header; the 12 bytes after are scales.
    let header: [u8; 16] = [
        0x00, 0x3C, 0x00, 0x3C, 0xC5, 0x8A, 0x4F, 0xD4,
        0x91, 0xD6, 0x63, 0xAB, 0x37, 0xB9, 0x2C, 0xE5,
    ];
    let scales = &header[4..16];
    // Superblocks 0-3: plain 6-bit fields in bytes [0..4] and [4..8].
    for sb in 0..4usize {
        let (sc, mn) = extract_scale_min_6bit(scales, sb);
        let want_sc = (scales[sb] & 0x3F) as f32;
        let want_mn = (scales[sb + 4] & 0x3F) as f32;
        assert_eq!(sc, want_sc, "SB {sb} scale mismatch");
        assert_eq!(mn, want_mn, "SB {sb} min mismatch");
    }
    // Superblocks 4-7: nibbles of bytes [8..12] plus the spare high 2 bits
    // of the corresponding low-index bytes.
    for sb in 4..8usize {
        let (sc, mn) = extract_scale_min_6bit(scales, sb);
        let combo = scales[sb + 4];
        let want_sc = ((combo & 0x0F) | ((scales[sb - 4] >> 6) << 4)) as f32;
        let want_mn = (((combo >> 4) & 0x0F) | ((scales[sb] >> 6) << 4)) as f32;
        assert_eq!(sc, want_sc, "SB {sb} scale mismatch");
        assert_eq!(mn, want_mn, "SB {sb} min mismatch");
    }
}
/// Smoke test: execute_scalar on mostly-zero weight blocks must produce
/// finite outputs for every row.
#[test]
fn test_execute_scalar() {
    let mv = TiledQ4KMatvec::new(2, 256);
    let mut weights = vec![0u8; 2 * Q4K_SUPERBLOCK_BYTES];
    // Row 0: d = 1.0 (f16 0x3C00, little-endian), dmin = 0.
    weights[0] = 0x00;
    weights[1] = 0x3C;
    weights[2] = 0x00;
    weights[3] = 0x00;
    // Row 1: same d = 1.0 header at the start of the second superblock.
    weights[Q4K_SUPERBLOCK_BYTES] = 0x00;
    weights[Q4K_SUPERBLOCK_BYTES + 1] = 0x3C;
    let input = vec![1.0f32; 256];
    let mut output = vec![0.0f32; 2];
    mv.execute_scalar(&weights, &input, &mut output);
    assert!(output[0].is_finite());
    assert!(output[1].is_finite());
}
/// Regression test (GH-182): the fused scalar superblock dot product must
/// agree with a dequantize-then-dot oracle to within float rounding.
#[test]
fn test_scalar_dot_matches_dequantize_oracle() {
    use crate::backends::q4k::dequantize_q4k_to_f32;
    // Build one superblock: d = f16 0x3800, dmin = f16 0x3400, a fixed
    // pseudo-random scales block, and patterned 4-bit quants.
    let mut sb = vec![0u8; 144];
    sb[0] = 0x00;
    sb[1] = 0x38;
    sb[2] = 0x00;
    sb[3] = 0x34;
    let scale_bytes = [
        0xCA, 0x94, 0x5E, 0x28, 0x45, 0x8F, 0xD9, 0x23, 0x35, 0x7A, 0x2F, 0x83,
    ];
    sb[4..16].copy_from_slice(&scale_bytes);
    for i in 0..128 {
        let lo = ((i * 7 + 3) % 16) as u8;
        let hi = ((i * 11 + 5) % 16) as u8;
        sb[16 + i] = lo | (hi << 4);
    }
    // Oracle: dequantize the whole block, then take a plain f32 dot product.
    let input: Vec<f32> = (0..256).map(|i| (i as f32 * 0.01) - 1.28).collect();
    let dequant = dequantize_q4k_to_f32(&sb, 256);
    let expected: f32 = dequant.iter().zip(input.iter()).map(|(w, x)| w * x).sum();
    // Fused path under test.
    let mv = TiledQ4KMatvec::new(1, 256);
    let mut output = vec![0.0f32; 1];
    mv.execute_scalar(&sb, &input, &mut output);
    let actual = output[0];
    // Relative error when the oracle is comfortably nonzero, absolute otherwise.
    let rel_err = if expected.abs() > 1e-6 {
        (actual - expected).abs() / expected.abs()
    } else {
        (actual - expected).abs()
    };
    assert!(
        rel_err < 1e-5,
        "GH-182: scalar_superblock_dot diverges from dequantize oracle!\n\
         expected={expected}, actual={actual}, rel_err={rel_err}\n\
         This indicates scale extraction or qs addressing mismatch."
    );
}