/// Error-path coverage for `fused_q4k_q8k_dot`: a misaligned data length,
/// an undersized scales slice, and an undersized quants slice must each
/// yield a descriptive error.
#[test]
fn test_fused_q4k_q8k_dot_error_messages() {
    // 100 bytes is not a multiple of the 144-byte super-block size.
    let bad_len = vec![0u8; 100];
    let msg = fused_q4k_q8k_dot(&bad_len, &[1.0], &[1; 256])
        .unwrap_err()
        .to_string();
    assert!(msg.contains("not a multiple"));

    // One well-sized super-block, but an empty scales slice.
    let one_block = vec![0u8; 144];
    let msg = fused_q4k_q8k_dot(&one_block, &[], &[1; 256])
        .unwrap_err()
        .to_string();
    assert!(msg.contains("scales"));

    // Quants slice shorter than the 256 values one super-block needs.
    let msg = fused_q4k_q8k_dot(&one_block, &[1.0], &[1; 100])
        .unwrap_err()
        .to_string();
    assert!(msg.contains("quants"));
}
/// The SIMD variant must reject data whose length is not a multiple of 144.
#[test]
fn test_fused_q4k_q8k_dot_simd_error_paths() {
    let msg = fused_q4k_q8k_dot_simd(&[0u8; 100], &[1.0], &[1i8; 256])
        .unwrap_err()
        .to_string();
    assert!(msg.contains("not a multiple"));
}
/// The SIMD dot product must reject an activations slice whose length does
/// not match the element count implied by the Q4_K data (one 144-byte
/// super-block covers 256 elements; 100 activations do not match).
#[test]
fn test_fused_q4k_dot_simd_error_paths() {
    let one_block = vec![0u8; 144];
    let short_activations = vec![0.0f32; 100];
    let msg = fused_q4k_dot_simd(&one_block, &short_activations)
        .unwrap_err()
        .to_string();
    assert!(msg.contains("doesn't match"));
}
/// Exercises the packed 6-bit scale decoding path with hand-crafted scale
/// bytes, checking the result against a precomputed expected value.
#[test]
fn test_fused_q4k_dot_packed_scale_blocks() {
    let mut block = vec![0u8; 144];
    // d = 1.0 stored as a little-endian f16.
    block[0..2].copy_from_slice(&0x3C00u16.to_le_bytes());
    // Hand-picked scale bytes that exercise the 6-bit packing scheme.
    block[4] = 0b1100_0000;
    block[12] = 0b0010_0101;
    // Fill qs bytes 64..96 (data offsets 80..112) with the nibble pair 0x2|0x2.
    for qs in &mut block[16 + 64..16 + 96] {
        *qs = 0x22;
    }
    let activations = vec![1.0f32; 256];
    let result = fused_q4k_dot(&block, &activations).expect("should succeed");
    assert!(
        (result - 3392.0).abs() < 1.0,
        "Expected about 3392.0, got {}",
        result
    );
}
/// The dot product is linear in the activations, so negating every
/// activation must negate the result.
#[test]
fn test_fused_q4k_dot_sign_reversal() {
    let mut block = vec![0u8; 144];
    // d = 1.0 stored as a little-endian f16; first scale byte set to 1.
    block[0..2].copy_from_slice(&0x3C00u16.to_le_bytes());
    block[4] = 1;
    // Every qs byte carries the nibble pair 5|5.
    for qs in &mut block[16..16 + 128] {
        *qs = 0x55;
    }
    let pos_act = vec![1.0f32; 256];
    let neg_act = vec![-1.0f32; 256];
    let pos_result = fused_q4k_dot(&block, &pos_act).expect("pos");
    let neg_result = fused_q4k_dot(&block, &neg_act).expect("neg");
    assert!(
        (pos_result + neg_result).abs() < 0.01,
        "Negating activations should negate result: {} vs {}",
        pos_result,
        neg_result
    );
}
/// The Q8K-fused dot product is linear in the quants, so flipping the sign
/// of every quant must flip the sign of the result.
#[test]
fn test_fused_q4k_q8k_dot_sign_reversal() {
    let mut block = vec![0u8; 144];
    // d = 1.0 stored as a little-endian f16; first scale byte set to 1.
    block[0..2].copy_from_slice(&0x3C00u16.to_le_bytes());
    block[4] = 1;
    // Every qs byte carries the nibble pair 5|5.
    for qs in &mut block[16..16 + 128] {
        *qs = 0x55;
    }
    let q8k_scales = vec![1.0f32];
    let pos_quants = vec![10i8; 256];
    let neg_quants = vec![-10i8; 256];
    let pos_result = fused_q4k_q8k_dot(&block, &q8k_scales, &pos_quants).expect("pos");
    let neg_result = fused_q4k_q8k_dot(&block, &q8k_scales, &neg_quants).expect("neg");
    assert!(
        (pos_result + neg_result).abs() < 1.0,
        "Negating quants should negate result: {} vs {}",
        pos_result,
        neg_result
    );
}
/// Scalar-vs-SIMD parity across 16 super-blocks with varied scale bytes and
/// qs patterns, allowing up to 1% relative error.
#[test]
fn test_fused_q4k_dot_simd_16_super_blocks() {
    let mut data = vec![0u8; 16 * 144];
    for (sb, block) in data.chunks_exact_mut(144).enumerate() {
        // d and dmin stored as little-endian f16 values.
        block[0..2].copy_from_slice(&0x2E66u16.to_le_bytes());
        block[2..4].copy_from_slice(&0x2800u16.to_le_bytes());
        // Scale bytes varied per super-block, kept within 6-bit range.
        for i in 0..12 {
            block[4 + i] = ((sb + i * 5 + 1) % 63) as u8;
        }
        // Pseudo-random qs byte pattern, distinct per super-block.
        for i in 0..128 {
            block[16 + i] = ((sb * 37 + i * 23 + 5) % 256) as u8;
        }
    }
    let activations: Vec<f32> = (0..4096)
        .map(|i| ((i * 7 + 3) % 200) as f32 * 0.005 - 0.5)
        .collect();
    let scalar = fused_q4k_dot(&data, &activations).expect("scalar");
    let simd = fused_q4k_dot_simd(&data, &activations).expect("simd");
    // Compare relatively when the scalar result is meaningfully non-zero,
    // absolutely otherwise.
    let rel_err = if scalar.abs() > 1e-6 {
        (simd - scalar).abs() / scalar.abs()
    } else {
        (simd - scalar).abs()
    };
    assert!(
        rel_err < 0.01,
        "16-superblock parity: scalar={}, simd={}, rel_err={}",
        scalar,
        simd,
        rel_err
    );
}
/// Empty inputs are a valid degenerate case and must produce exactly 0.0.
#[test]
fn test_fused_q4k_q8k_dot_empty() {
    let sum = fused_q4k_q8k_dot(&[], &[], &[]).expect("empty should work");
    assert_eq!(sum, 0.0);
}
/// The SIMD Q8K variant must also accept empty inputs and return 0.0.
#[test]
fn test_fused_q4k_q8k_dot_simd_empty() {
    let sum = fused_q4k_q8k_dot_simd(&[], &[], &[]).expect("empty should work");
    assert_eq!(sum, 0.0);
}
/// The SIMD dot product must accept empty inputs and return 0.0.
#[test]
fn test_fused_q4k_dot_simd_empty() {
    let sum = fused_q4k_dot_simd(&[], &[]).expect("empty should work");
    assert_eq!(sum, 0.0);
}
/// PMAT-170 regression: `dequantize_q4_k` and `fused_q4k_parallel_matvec`
/// must agree on element ordering. The fused matrix is reconstructed one
/// column at a time by multiplying with standard-basis vectors, then
/// compared element-wise against the dequantized matrix.
#[test]
fn test_q4k_layout_consistency_pmat170() {
    use crate::apr::dequantize_q4_k;
    use crate::quantize::fused_q4k_parallel_matvec;
    const IN_DIM: usize = 256;
    const OUT_DIM: usize = 256;
    const NUM_ELEMENTS: usize = IN_DIM * OUT_DIM;
    const BYTES_PER_ROW: usize = 144;
    // Deterministic pseudo-random Q4_K payload.
    let q4k_bytes: Vec<u8> = (0..OUT_DIM * BYTES_PER_ROW)
        .map(|i| ((i * 17 + 37) % 256) as u8)
        .collect();
    let dequant = dequantize_q4_k(&q4k_bytes, NUM_ELEMENTS);
    // Recover the matrix column-by-column via matvec with basis vectors.
    let mut fused_matrix = vec![0.0f32; NUM_ELEMENTS];
    for col in 0..IN_DIM {
        let mut basis = vec![0.0f32; IN_DIM];
        basis[col] = 1.0;
        if let Ok(column) = fused_q4k_parallel_matvec(&q4k_bytes, &basis, IN_DIM, OUT_DIM) {
            for row in 0..OUT_DIM {
                fused_matrix[row * IN_DIM + col] = column[row];
            }
        }
    }
    // Count elements whose relative (or, near zero, absolute) error exceeds 1%.
    let mut mismatches = 0;
    let mut max_rel_err = 0.0f32;
    for i in 0..NUM_ELEMENTS {
        let diff = (dequant[i] - fused_matrix[i]).abs();
        let rel_err = if dequant[i].abs() > 1e-6 {
            diff / dequant[i].abs()
        } else {
            diff
        };
        if rel_err > 0.01 {
            mismatches += 1;
            max_rel_err = max_rel_err.max(rel_err);
        }
    }
    assert_eq!(
        mismatches, 0,
        "Q4K layout mismatch: {} elements differ (max rel_err: {:.4}%). \
        This indicates dequantize_q4_k has different element ordering \
        than fused_q4k_parallel_matvec, which would cause GPU explosion.",
        mismatches,
        max_rel_err * 100.0
    );
}
/// Builds one 144-byte Q4_K super-block for tests.
///
/// * `d` and `dmin` are stored as little-endian f16 in bytes 0..2 and 2..4.
/// * All 12 scale bytes (4..16) are set to 0x01.
/// * Every qs byte (16..144) packs `nibble_val` into both the low and high
///   nibble, so all quantized values share the same 4-bit code.
fn build_q4k_test_block(d: f32, dmin: f32, nibble_val: u8) -> [u8; 144] {
    let mut block = [0u8; 144];
    block[0..2].copy_from_slice(&half::f16::from_f32(d).to_bits().to_le_bytes());
    block[2..4].copy_from_slice(&half::f16::from_f32(dmin).to_bits().to_le_bytes());
    for scale in &mut block[4..16] {
        *scale = 0x01;
    }
    let nib = nibble_val & 0x0F;
    let packed = nib | (nib << 4);
    for qs in &mut block[16..] {
        *qs = packed;
    }
    block
}
/// The AVX2 kernel must agree with the scalar reference on a uniform block.
#[cfg(target_arch = "x86_64")]
#[test]
fn test_avx2_q4k_q8k_dot_parity_with_scalar() {
    if !is_x86_feature_detected!("avx2") {
        return;
    }
    let q4k_data = build_q4k_test_block(1.0, 0.0, 3).to_vec();
    let scales = vec![1.0f32];
    let quants = vec![1i8; 256];
    let scalar = fused_q4k_q8k_dot(&q4k_data, &scales, &quants).expect("scalar");
    // SAFETY: AVX2 availability was verified above.
    let avx2 = unsafe { fused_q4k_q8k_dot_avx2(&q4k_data, &scales, &quants) }.expect("avx2");
    let diff = (scalar - avx2).abs();
    assert!(
        diff < 1.0,
        "scalar={scalar} vs avx2={avx2}, diff={diff} exceeds tolerance"
    );
}
/// All-zero nibbles against all-zero quants must produce (approximately) zero.
#[cfg(target_arch = "x86_64")]
#[test]
fn test_avx2_q4k_q8k_dot_zero_quants() {
    if !is_x86_feature_detected!("avx2") {
        return;
    }
    let q4k_data = build_q4k_test_block(1.0, 0.0, 0).to_vec();
    let scales = vec![1.0f32];
    let quants = vec![0i8; 256];
    // SAFETY: AVX2 availability was verified above.
    let result = unsafe { fused_q4k_q8k_dot_avx2(&q4k_data, &scales, &quants) }.expect("result");
    assert!(
        result.abs() < 1e-6,
        "zero × zero should produce ~0, got {result}"
    );
}
/// AVX2-vs-scalar parity over four identical super-blocks, within 1%.
#[cfg(target_arch = "x86_64")]
#[test]
fn test_avx2_q4k_q8k_dot_multi_superblock() {
    if !is_x86_feature_detected!("avx2") {
        return;
    }
    // Four copies of the same 144-byte block back to back.
    let q4k_data = build_q4k_test_block(1.0, 0.0, 5).repeat(4);
    let scales = vec![1.0f32; 4];
    let quants = vec![2i8; 256 * 4];
    let scalar = fused_q4k_q8k_dot(&q4k_data, &scales, &quants).expect("scalar");
    // SAFETY: AVX2 availability was verified above.
    let avx2 = unsafe { fused_q4k_q8k_dot_avx2(&q4k_data, &scales, &quants) }.expect("avx2");
    let diff = (scalar - avx2).abs();
    let rel_tolerance = scalar.abs().max(1.0) * 0.01;
    assert!(
        diff < rel_tolerance,
        "4-block: scalar={scalar} vs avx2={avx2}, diff={diff}"
    );
}
/// AVX2-vs-scalar parity with negative Q8K quants, within 1%.
#[cfg(target_arch = "x86_64")]
#[test]
fn test_avx2_q4k_q8k_dot_negative_quants() {
    if !is_x86_feature_detected!("avx2") {
        return;
    }
    let q4k_data = build_q4k_test_block(1.0, 0.0, 7).to_vec();
    let scales = vec![1.0f32];
    let quants = vec![-3i8; 256];
    let scalar = fused_q4k_q8k_dot(&q4k_data, &scales, &quants).expect("scalar");
    // SAFETY: AVX2 availability was verified above.
    let avx2 = unsafe { fused_q4k_q8k_dot_avx2(&q4k_data, &scales, &quants) }.expect("avx2");
    let diff = (scalar - avx2).abs();
    let rel_tolerance = scalar.abs().max(1.0) * 0.01;
    assert!(
        diff < rel_tolerance,
        "neg quants: scalar={scalar} vs avx2={avx2}, diff={diff}"
    );
}
/// AVX2-vs-scalar parity with a non-zero dmin term, within 5%.
#[cfg(target_arch = "x86_64")]
#[test]
fn test_avx2_q4k_q8k_dot_with_dmin() {
    if !is_x86_feature_detected!("avx2") {
        return;
    }
    let q4k_data = build_q4k_test_block(1.0, 0.5, 4).to_vec();
    let scales = vec![2.0f32];
    let quants = vec![5i8; 256];
    let scalar = fused_q4k_q8k_dot(&q4k_data, &scales, &quants).expect("scalar");
    // SAFETY: AVX2 availability was verified above.
    let avx2 = unsafe { fused_q4k_q8k_dot_avx2(&q4k_data, &scales, &quants) }.expect("avx2");
    let diff = (scalar - avx2).abs();
    let rel_tolerance = scalar.abs().max(1.0) * 0.05;
    assert!(
        diff < rel_tolerance,
        "dmin: scalar={scalar} vs avx2={avx2}, diff={diff}"
    );
}
/// The AVX2 kernel must reject data not aligned to the 144-byte block size.
#[cfg(target_arch = "x86_64")]
#[test]
fn test_avx2_q4k_q8k_dot_invalid_data_length() {
    if !is_x86_feature_detected!("avx2") {
        return;
    }
    let q4k_data = vec![0u8; 100];
    let scales = vec![1.0f32];
    let quants = vec![1i8; 256];
    // SAFETY: AVX2 availability was verified above.
    let result = unsafe { fused_q4k_q8k_dot_avx2(&q4k_data, &scales, &quants) };
    assert!(result.is_err(), "should fail for non-144-aligned data");
}
/// The AVX2 kernel must reject a Q8K quants buffer shorter than 256 per block.
#[cfg(target_arch = "x86_64")]
#[test]
fn test_avx2_q4k_q8k_dot_buffer_too_small() {
    if !is_x86_feature_detected!("avx2") {
        return;
    }
    let q4k_data = build_q4k_test_block(1.0, 0.0, 1).to_vec();
    let scales = vec![1.0f32];
    let short_quants = vec![1i8; 128];
    // SAFETY: AVX2 availability was verified above.
    let result = unsafe { fused_q4k_q8k_dot_avx2(&q4k_data, &scales, &short_quants) };
    assert!(result.is_err(), "should fail for too-small Q8K buffer");
}