use super::*;
/// A 100-byte weight buffer must be rejected by `fused_q4k_dot`, and the
/// error text must mention "not a multiple".
#[test]
fn test_fused_q4k_dot_invalid_data_length() {
    // NOTE(review): the other tests use 144-byte units, so 100 bytes is
    // presumably not a whole number of super-blocks — confirm against the kernel.
    let truncated = vec![0u8; 100];
    let acts = vec![0.0f32; 256];

    let outcome = fused_q4k_dot(&truncated, &acts);
    assert!(outcome.is_err());

    let message = outcome.unwrap_err().to_string();
    assert!(message.contains("not a multiple"));
}
/// A 144-byte weight buffer paired with only 128 activations must fail with
/// an error whose text contains "doesn't match".
#[test]
fn test_fused_q4k_dot_activation_length_mismatch() {
    let weights = vec![0u8; 144];
    // Deliberately shorter than the 256 activations used with 144 bytes elsewhere.
    let short_acts = vec![0.0f32; 128];

    let outcome = fused_q4k_dot(&weights, &short_acts);
    assert!(outcome.is_err());

    let message = outcome.unwrap_err().to_string();
    assert!(message.contains("doesn't match"));
}
/// All-zero weight bytes dotted with all-one activations must yield exactly 0.0.
#[test]
fn test_fused_q4k_dot_zero_data() {
    let zero_weights = vec![0u8; 144];
    let ones = vec![1.0f32; 256];

    let dot = fused_q4k_dot(&zero_weights, &ones).unwrap();
    assert_eq!(dot, 0.0);
}
/// Smoke test: one populated 144-byte buffer with unit activations must be
/// accepted (only `is_ok()` is checked, not the numeric value).
#[test]
fn test_fused_q4k_dot_single_super_block() {
    let mut block = vec![0u8; 144];

    // 0x3C00 / 0x3800 are 1.0 / 0.5 as IEEE half floats.
    // NOTE(review): presumably the `d` and `dmin` header fields — confirm layout.
    block[..2].copy_from_slice(&0x3C00u16.to_le_bytes());
    block[2..4].copy_from_slice(&0x3800u16.to_le_bytes());

    // Bytes 4..16 and 16..144 get constant patterns.
    // NOTE(review): presumably packed scales and 4-bit quants — confirm layout.
    block[4..16].fill(0x11);
    block[16..144].fill(0x55);

    let ones = vec![1.0f32; 256];
    let outcome = fused_q4k_dot(&block, &ones);
    assert!(outcome.is_ok());
}
/// Smoke test: a 288-byte buffer (two 144-byte units) with 512 activations
/// must be accepted.
#[test]
fn test_fused_q4k_dot_multiple_super_blocks() {
    // 0x3800 is 0.5 as an IEEE half float; written at the start of each unit.
    let half_bytes = 0x3800u16.to_le_bytes();

    let mut blocks = vec![0u8; 288];
    blocks[..2].copy_from_slice(&half_bytes);
    blocks[144..146].copy_from_slice(&half_bytes);

    let acts = vec![0.5f32; 512];
    let outcome = fused_q4k_dot(&blocks, &acts);
    assert!(outcome.is_ok());
}
/// Empty weights and empty activations are a valid input and must give 0.0.
#[test]
fn test_fused_q4k_dot_empty_data() {
    let no_weights: Vec<u8> = Vec::new();
    let no_acts: Vec<f32> = Vec::new();

    let outcome = fused_q4k_dot(&no_weights, &no_acts);
    assert!(outcome.is_ok());
    assert_eq!(outcome.unwrap(), 0.0);
}
/// `fused_q4k_dot_simd` must return an error for a 100-byte weight buffer.
#[test]
fn test_fused_q4k_dot_simd_invalid_input() {
    let truncated = vec![0u8; 100];
    let acts = vec![0.0f32; 256];

    let outcome = fused_q4k_dot_simd(&truncated, &acts);
    assert!(outcome.is_err());
}
/// `fused_q4k_dot_simd` and `fused_q4k_dot` must agree to within a relative
/// error of 0.001 on one buffer with varied bytes and ramped activations.
#[test]
fn test_fused_q4k_dot_simd_matches_scalar() {
    let mut block = vec![0u8; 144];

    // Header: 0x3C00 = 1.0 and 0x3800 = 0.5 as IEEE half floats.
    block[..2].copy_from_slice(&0x3C00u16.to_le_bytes());
    block[2..4].copy_from_slice(&0x3800u16.to_le_bytes());
    block[4..16].fill(0x11);

    // Varying byte pattern over the remaining 128 bytes.
    for (idx, byte) in block[16..144].iter_mut().enumerate() {
        *byte = ((idx * 3) % 256) as u8;
    }

    // Activations ramp linearly: 0.00, 0.01, 0.02, ...
    let acts: Vec<f32> = (0..256).map(|k| (k as f32) * 0.01).collect();

    let scalar_result = fused_q4k_dot(&block, &acts).unwrap();
    let simd_result = fused_q4k_dot_simd(&block, &acts).unwrap();

    // Use relative error unless the reference is too close to zero, in which
    // case fall back to the absolute difference.
    let abs_diff = (simd_result - scalar_result).abs();
    let reference_mag = scalar_result.abs();
    let rel_err = if reference_mag > 1e-6 {
        abs_diff / reference_mag
    } else {
        abs_diff
    };

    assert!(
        rel_err < 0.001,
        "scalar={} simd={} rel_err={}",
        scalar_result,
        simd_result,
        rel_err
    );
}
/// With all-zero activations the SIMD kernel must return exactly 0.0.
#[test]
fn test_fused_q4k_dot_simd_zero_activations() {
    let mut block = vec![0u8; 144];
    // 0x3C00 is 1.0 as an IEEE half float.
    block[..2].copy_from_slice(&0x3C00u16.to_le_bytes());

    let zeros = vec![0.0f32; 256];
    let dot = fused_q4k_dot_simd(&block, &zeros).unwrap();
    assert_eq!(dot, 0.0);
}
/// `fused_q4k_q8k_dot` must return an error for a 100-byte weight buffer.
#[test]
fn test_fused_q4k_q8k_dot_invalid_data_length() {
    let truncated = vec![0u8; 100];
    let scales = vec![1.0f32; 1];
    let quants = vec![10i8; 256];

    let outcome = fused_q4k_q8k_dot(&truncated, &scales, &quants);
    assert!(outcome.is_err());
}
/// An empty Q8_K scale slice alongside a 144-byte weight buffer must be
/// rejected.
#[test]
fn test_fused_q4k_q8k_dot_scales_too_small() {
    let weights = vec![0u8; 144];
    let no_scales: Vec<f32> = Vec::new();
    let quants = vec![10i8; 256];

    let outcome = fused_q4k_q8k_dot(&weights, &no_scales, &quants);
    assert!(outcome.is_err());
}
/// Only 128 Q8_K quants alongside a 144-byte weight buffer must be rejected.
#[test]
fn test_fused_q4k_q8k_dot_quants_too_small() {
    let weights = vec![0u8; 144];
    let scales = vec![1.0f32; 1];
    // Half of the 256 quants used with 144 bytes in the passing tests.
    let short_quants = vec![10i8; 128];

    let outcome = fused_q4k_q8k_dot(&weights, &scales, &short_quants);
    assert!(outcome.is_err());
}
/// All-zero weight bytes with a zero Q8_K scale must produce exactly 0.0.
#[test]
fn test_fused_q4k_q8k_dot_zero_data() {
    let zero_weights = vec![0u8; 144];
    let zero_scales = vec![0.0f32; 1];
    let quants = vec![10i8; 256];

    let dot = fused_q4k_q8k_dot(&zero_weights, &zero_scales, &quants).unwrap();
    assert_eq!(dot, 0.0);
}
/// Smoke test: a populated 144-byte buffer with constant Q8_K inputs must be
/// accepted by `fused_q4k_q8k_dot`.
#[test]
fn test_fused_q4k_q8k_dot_basic() {
    let mut block = vec![0u8; 144];

    // 0x3C00 / 0x3800 are 1.0 / 0.5 as IEEE half floats.
    block[..2].copy_from_slice(&0x3C00u16.to_le_bytes());
    block[2..4].copy_from_slice(&0x3800u16.to_le_bytes());

    // Constant patterns for bytes 4..16 and 16..144.
    block[4..16].fill(0x11);
    block[16..144].fill(0x55);

    let scales = vec![0.1f32; 1];
    let quants = vec![10i8; 256];

    let outcome = fused_q4k_q8k_dot(&block, &scales, &quants);
    assert!(outcome.is_ok());
}
/// `fused_q4k_q8k_dot_simd` must return an error for a 100-byte weight buffer.
#[test]
fn test_fused_q4k_q8k_dot_simd_invalid_input() {
    let truncated = vec![0u8; 100];
    let scales = vec![1.0f32; 1];
    let quants = vec![10i8; 256];

    let outcome = fused_q4k_q8k_dot_simd(&truncated, &scales, &quants);
    assert!(outcome.is_err());
}
/// `fused_q4k_q8k_dot_simd` and `fused_q4k_q8k_dot` must agree to within a
/// relative error of 0.02 on one buffer with varied bytes and quants.
#[test]
fn test_fused_q4k_q8k_dot_simd_matches_scalar() {
    let mut block = vec![0u8; 144];

    // Header: 0x3C00 = 1.0 and 0x3800 = 0.5 as IEEE half floats.
    block[..2].copy_from_slice(&0x3C00u16.to_le_bytes());
    block[2..4].copy_from_slice(&0x3800u16.to_le_bytes());
    block[4..16].fill(0x11);

    // Varying byte pattern over the remaining 128 bytes.
    for (idx, byte) in block[16..144].iter_mut().enumerate() {
        *byte = ((idx * 7) % 256) as u8;
    }

    let scales = vec![0.1f32; 1];
    // Quants sweep -32..=31 repeatedly, covering both signs.
    let quants: Vec<i8> = (0..256i32).map(|k| ((k % 64) - 32) as i8).collect();

    let scalar_result = fused_q4k_q8k_dot(&block, &scales, &quants).unwrap();
    let simd_result = fused_q4k_q8k_dot_simd(&block, &scales, &quants).unwrap();

    // Use relative error unless the reference is too close to zero, in which
    // case fall back to the absolute difference.
    let abs_diff = (simd_result - scalar_result).abs();
    let reference_mag = scalar_result.abs();
    let rel_err = if reference_mag > 1e-6 {
        abs_diff / reference_mag
    } else {
        abs_diff
    };

    assert!(
        rel_err < 0.02,
        "scalar={} simd={} rel_err={}",
        scalar_result,
        simd_result,
        rel_err
    );
}
/// Smoke test: a 288-byte buffer (two 144-byte units) with two Q8_K scales
/// and 512 quants must be accepted.
#[test]
fn test_fused_q4k_q8k_dot_multiple_super_blocks() {
    // 0x3C00 is 1.0 as an IEEE half float; written at the start of each unit.
    let one_bytes = 0x3C00u16.to_le_bytes();

    let mut blocks = vec![0u8; 288];
    blocks[..2].copy_from_slice(&one_bytes);
    blocks[144..146].copy_from_slice(&one_bytes);

    let scales = vec![0.1f32; 2];
    let quants = vec![5i8; 512];

    let outcome = fused_q4k_q8k_dot(&blocks, &scales, &quants);
    assert!(outcome.is_ok());
}
/// Smoke test: bytes 16..144 all set to 0xFF (both nibbles at their maximum)
/// must still be accepted by `fused_q4k_dot`.
#[test]
fn test_fused_q4k_dot_max_nibble_values() {
    let mut block = vec![0u8; 144];
    // 0x3C00 is 1.0 as an IEEE half float.
    block[..2].copy_from_slice(&0x3C00u16.to_le_bytes());
    block[16..144].fill(0xFF);

    let ones = vec![1.0f32; 256];
    let outcome = fused_q4k_dot(&block, &ones);
    assert!(outcome.is_ok());
}
/// Smoke test: uniformly negative Q8_K quants must be accepted.
#[test]
fn test_fused_q4k_q8k_dot_negative_quants() {
    let mut block = vec![0u8; 144];
    // 0x3C00 is 1.0 as an IEEE half float.
    block[..2].copy_from_slice(&0x3C00u16.to_le_bytes());

    let scales = vec![0.1f32; 1];
    let negative_quants = vec![-10i8; 256];

    let outcome = fused_q4k_q8k_dot(&block, &scales, &negative_quants);
    assert!(outcome.is_ok());
}
/// Scalar and SIMD kernels must agree (relative error below 0.001) on an
/// eight-unit input: 1152 weight bytes and 2048 activations.
#[test]
fn test_fused_q4k_dot_simd_large_input() {
    let mut blocks = vec![0u8; 1152];

    // Stamp the first two bytes of each 144-byte unit with 0x2E66, which is
    // approximately 0.1 as an IEEE half float. 1152 = 8 * 144, so this visits
    // exactly eight units.
    let header = 0x2E66u16.to_le_bytes();
    for unit in blocks.chunks_exact_mut(144) {
        unit[..2].copy_from_slice(&header);
    }

    let acts = vec![0.5f32; 2048];
    let scalar_result = fused_q4k_dot(&blocks, &acts).unwrap();
    let simd_result = fused_q4k_dot_simd(&blocks, &acts).unwrap();

    // Use relative error unless the reference is too close to zero, in which
    // case fall back to the absolute difference.
    let abs_diff = (simd_result - scalar_result).abs();
    let reference_mag = scalar_result.abs();
    let rel_err = if reference_mag > 1e-6 {
        abs_diff / reference_mag
    } else {
        abs_diff
    };

    assert!(
        rel_err < 0.001,
        "scalar={} simd={} rel_err={}",
        scalar_result,
        simd_result,
        rel_err
    );
}
/// Smoke test: alternating nibble patterns in the weights combined with
/// alternating-sign Q8_K quants must be accepted.
#[test]
fn test_fused_q4k_q8k_dot_mixed_signs() {
    let mut block = vec![0u8; 144];

    // 0x3C00 / 0x3800 are 1.0 / 0.5 as IEEE half floats.
    block[..2].copy_from_slice(&0x3C00u16.to_le_bytes());
    block[2..4].copy_from_slice(&0x3800u16.to_le_bytes());

    // Even offsets get 0x0F, odd offsets get 0xF0.
    for (idx, byte) in block[16..144].iter_mut().enumerate() {
        *byte = if idx % 2 == 0 { 0x0F } else { 0xF0 };
    }

    let scales = vec![0.1f32; 1];
    // +10 at even indices, -10 at odd indices.
    let quants: Vec<i8> = [10i8, -10i8].iter().copied().cycle().take(256).collect();

    let outcome = fused_q4k_q8k_dot(&block, &scales, &quants);
    assert!(outcome.is_ok());
}
/// With a non-zero header, byte 4 set to 0x3F and non-zero payload bytes,
/// `fused_q4k_dot` must succeed with a non-zero result.
#[test]
fn test_fused_q4k_dot_scale_extraction() {
    let mut block = vec![0u8; 144];

    // 0x4000 is 2.0 as an IEEE half float; bytes 2..4 stay zero.
    block[..2].copy_from_slice(&0x4000u16.to_le_bytes());
    block[2..4].fill(0x00);

    // NOTE(review): 0x3F presumably sets the first 6-bit scale to its
    // maximum (63) — confirm against the scale-unpacking code.
    block[4] = 0x3F;
    block[16..144].fill(0x01);

    let ones = vec![1.0f32; 256];
    let outcome = fused_q4k_dot(&block, &ones);
    assert!(outcome.is_ok());

    let magnitude = outcome.unwrap().abs();
    assert!(magnitude > 0.0);
}
// `include!` pastes each file's contents into this module at this point, so
// the included items share this module's scope (including `use super::*`).
// NOTE(review): the file names suggest additional Q4_K, error-path and AVX2
// test cases — confirm by inspecting the files.
include!("fused_k_tests_q4k.rs");
include!("fused_k_tests_dot_errors.rs");
include!("fused_k_tests_avx2_q4k.rs");