#[cfg(test)]
mod tests {
    //! Unit tests for the fused K-quant dot-product kernels.
    //!
    //! Byte offsets written by these tests (they appear to mirror ggml's super-block
    //! layouts — the kernel implementations are the authoritative definition):
    //! - Q6_K, 210 bytes/super-block: quant bits in `0..128` and `128..192`, sixteen
    //!   `i8` sub-block scales in `192..208`, `f16` scale `d` in `208..210`.
    //! - Q5_K, 176 bytes/super-block: `f16` values at `0..2` (`d`) and `2..4`
    //!   (presumably `dmin`), packed scales in `4..16`, high bits in `16..48`,
    //!   low nibbles in `48..176`.
    //! - Q4_K, 144 bytes/super-block: `f16` values at `0..2` and `2..4`, packed scales
    //!   in `4..16`, nibbles in `16..144`.
    //!
    //! `f16` bit patterns used: `0x3C00` = 1.0, `0x3800` = 0.5, `0x2E66` ≈ 0.1.

    use super::*;

    /// Asserts that `result` is an error whose `to_string()` contains `needle`,
    /// reporting the actual outcome when it does not.
    fn assert_err_contains<T: std::fmt::Debug, E: ToString>(result: Result<T, E>, needle: &str) {
        match result {
            Ok(value) => panic!(
                "expected an error containing {:?}, got Ok({:?})",
                needle, value
            ),
            Err(err) => {
                let msg = err.to_string();
                assert!(
                    msg.contains(needle),
                    "error message {:?} does not contain {:?}",
                    msg,
                    needle
                );
            }
        }
    }

    /// Asserts that a kernel call succeeded and returned a finite value.
    ///
    /// Using `expect` (rather than `assert!(result.is_ok())`) makes a failure print the
    /// underlying error instead of a bare `false`.
    fn assert_ok_finite<E: std::fmt::Debug>(result: Result<f32, E>) {
        let value = result.expect("kernel should accept well-formed input");
        assert!(value.is_finite(), "kernel returned non-finite value {}", value);
    }

    /// Asserts that a SIMD kernel result agrees with its scalar reference.
    ///
    /// Passes when the absolute difference is below `1e-6`, or when the error is below
    /// `rel_tol`. The error is relative to `scalar`, or absolute when `|scalar| <= 1e-6`,
    /// where a relative error is meaningless. A purely absolute `1e-6` bound is too strict:
    /// at magnitudes of about 16 and above a single f32 ULP already exceeds it, so any
    /// re-association of the accumulation in the SIMD path would fail spuriously.
    fn assert_simd_matches_scalar(scalar: f32, simd: f32, rel_tol: f32) {
        let abs_diff = (simd - scalar).abs();
        let err = if scalar.abs() > 1e-6 {
            abs_diff / scalar.abs()
        } else {
            abs_diff
        };
        assert!(
            abs_diff < 1e-6 || err < rel_tol,
            "scalar={} simd={} abs_diff={} err={}",
            scalar,
            simd,
            abs_diff,
            err
        );
    }

    #[test]
    fn test_fused_q6k_dot_invalid_data_length() {
        // 100 bytes is not a whole number of 210-byte super-blocks.
        let data = vec![0u8; 100];
        let activations = vec![0.0f32; 256];
        assert_err_contains(fused_q6k_dot(&data, &activations), "not a multiple");
    }

    #[test]
    fn test_fused_q6k_dot_activation_length_mismatch() {
        // One super-block (256 values) against only 128 activations.
        let data = vec![0u8; 210];
        let activations = vec![0.0f32; 128];
        assert_err_contains(fused_q6k_dot(&data, &activations), "doesn't match");
    }

    #[test]
    fn test_fused_q6k_dot_zero_data() {
        // All-zero bytes: the f16 scale at 208..210 is 0.0, so the product must be 0.
        let data = vec![0u8; 210];
        let activations = vec![1.0f32; 256];
        let result = fused_q6k_dot(&data, &activations).expect("result");
        assert_eq!(result, 0.0);
    }

    #[test]
    fn test_fused_q6k_dot_single_super_block() {
        let mut data = vec![0u8; 210];
        data[208..210].copy_from_slice(&0x3C00u16.to_le_bytes()); // d = 1.0
        data[192..208].fill(1); // every sub-block scale = 1
        for (i, byte) in data[..128].iter_mut().enumerate() {
            let nibble = (i % 16) as u8;
            *byte = nibble | (nibble << 4);
        }
        let activations = vec![1.0f32; 256];
        assert_ok_finite(fused_q6k_dot(&data, &activations));
    }

    #[test]
    fn test_fused_q6k_dot_simd_matches_scalar() {
        let mut data = vec![0u8; 210];
        data[208..210].copy_from_slice(&0x3C00u16.to_le_bytes()); // d = 1.0
        data[192..208].fill(1);
        for (i, byte) in data[..128].iter_mut().enumerate() {
            *byte = (i % 256) as u8;
        }
        for (i, byte) in data[128..192].iter_mut().enumerate() {
            *byte = (i % 256) as u8;
        }
        let activations: Vec<f32> = (0..256).map(|i| (i as f32) * 0.01).collect();

        let scalar_result = fused_q6k_dot(&data, &activations).expect("scalar_result");
        let simd_result = fused_q6k_dot_simd(&data, &activations).expect("simd_result");
        assert_simd_matches_scalar(scalar_result, simd_result, 0.01);
    }

    #[test]
    fn test_fused_q6k_dot_simd_invalid_input() {
        let data = vec![0u8; 100];
        let activations = vec![0.0f32; 256];
        assert!(fused_q6k_dot_simd(&data, &activations).is_err());
    }

    #[test]
    fn test_fused_q6k_dot_multiple_super_blocks() {
        let mut data = vec![0u8; 420];
        for block in data.chunks_exact_mut(210) {
            block[208..210].copy_from_slice(&0x3800u16.to_le_bytes()); // d = 0.5
            block[192..208].fill(2);
        }
        let activations = vec![0.5f32; 512];
        assert_ok_finite(fused_q6k_dot(&data, &activations));
    }

    #[test]
    fn test_fused_q5k_dot_invalid_data_length() {
        // 100 bytes is not a whole number of 176-byte super-blocks.
        let data = vec![0u8; 100];
        let activations = vec![0.0f32; 256];
        assert_err_contains(fused_q5k_dot(&data, &activations), "not a multiple");
    }

    #[test]
    fn test_fused_q5k_dot_activation_length_mismatch() {
        let data = vec![0u8; 176];
        let activations = vec![0.0f32; 128];
        assert_err_contains(fused_q5k_dot(&data, &activations), "doesn't match");
    }

    #[test]
    fn test_fused_q5k_dot_zero_data() {
        let data = vec![0u8; 176];
        let activations = vec![1.0f32; 256];
        let result = fused_q5k_dot(&data, &activations).expect("result");
        // Compare with a tolerance: the result may legitimately be -0.0.
        assert!(result.abs() < 1e-6);
    }

    #[test]
    fn test_fused_q5k_dot_single_super_block() {
        let mut data = vec![0u8; 176];
        data[0..2].copy_from_slice(&0x3C00u16.to_le_bytes()); // 1.0
        data[2..4].copy_from_slice(&0x3800u16.to_le_bytes()); // 0.5
        for (i, byte) in data[48..176].iter_mut().enumerate() {
            let lo = (i % 16) as u8;
            let hi = ((i + 1) % 16) as u8;
            *byte = lo | (hi << 4);
        }
        let activations = vec![1.0f32; 256];
        assert_ok_finite(fused_q5k_dot(&data, &activations));
    }

    #[test]
    fn test_fused_q5k_dot_simd_matches_scalar() {
        let mut data = vec![0u8; 176];
        data[0..2].copy_from_slice(&0x3C00u16.to_le_bytes()); // 1.0
        data[2..4].copy_from_slice(&0x3800u16.to_le_bytes()); // 0.5
        data[4..16].fill(0x11);
        for (i, byte) in data[16..48].iter_mut().enumerate() {
            *byte = (i % 256) as u8;
        }
        for (i, byte) in data[48..176].iter_mut().enumerate() {
            *byte = ((i * 3) % 256) as u8;
        }
        let activations: Vec<f32> = (0..256).map(|i| (i as f32) * 0.01).collect();

        let scalar_result = fused_q5k_dot(&data, &activations).expect("scalar_result");
        let simd_result = fused_q5k_dot_simd(&data, &activations).expect("simd_result");
        // Relative bound: tight enough to catch decoding bugs (which cause O(1) errors),
        // loose enough to tolerate re-associated f32 accumulation.
        assert_simd_matches_scalar(scalar_result, simd_result, 1e-4);
    }

    #[test]
    fn test_fused_q5k_dot_simd_invalid_input() {
        let data = vec![0u8; 100];
        let activations = vec![0.0f32; 256];
        assert!(fused_q5k_dot_simd(&data, &activations).is_err());
    }

    #[test]
    fn test_fused_q5k_dot_multiple_super_blocks() {
        let mut data = vec![0u8; 352];
        for block in data.chunks_exact_mut(176) {
            block[0..2].copy_from_slice(&0x3C00u16.to_le_bytes()); // 1.0
            block[2..4].copy_from_slice(&0x3800u16.to_le_bytes()); // 0.5
        }
        let activations = vec![0.5f32; 512];
        assert_ok_finite(fused_q5k_dot(&data, &activations));
    }

    /// Builds a Q8_0 block with the given scale and all 32 quants set to `quant_val`.
    fn make_q8_block(scale: f32, quant_val: i8) -> Q8_0Block {
        Q8_0Block {
            scale,
            quants: [quant_val; 32],
        }
    }

    /// Builds an all-zero Q8_0 block (scale 0.0, every quant 0).
    fn zero_q8_block() -> Q8_0Block {
        make_q8_block(0.0, 0)
    }

    #[test]
    fn test_fused_q4k_q8_dot_invalid_data_length() {
        // 100 bytes is not a whole number of 144-byte super-blocks.
        let data = vec![0u8; 100];
        let q8_blocks: Vec<Q8_0Block> = (0..8).map(|_| zero_q8_block()).collect();
        assert_err_contains(fused_q4k_q8_dot(&data, &q8_blocks), "not a multiple");
    }

    #[test]
    fn test_fused_q4k_q8_dot_block_count_mismatch() {
        // One Q4_K super-block needs 8 Q8_0 blocks (8 × 32 = 256 values); give only 4.
        let data = vec![0u8; 144];
        let q8_blocks: Vec<Q8_0Block> = (0..4).map(|_| zero_q8_block()).collect();
        assert_err_contains(fused_q4k_q8_dot(&data, &q8_blocks), "doesn't match");
    }

    #[test]
    fn test_fused_q4k_q8_dot_zero_data() {
        let data = vec![0u8; 144];
        let q8_blocks: Vec<Q8_0Block> = (0..8).map(|_| zero_q8_block()).collect();
        let result = fused_q4k_q8_dot(&data, &q8_blocks).expect("result");
        assert_eq!(result, 0.0);
    }

    #[test]
    fn test_fused_q4k_q8_dot_single_super_block() {
        let mut data = vec![0u8; 144];
        data[0..2].copy_from_slice(&0x3C00u16.to_le_bytes()); // 1.0
        data[2..4].copy_from_slice(&0x3800u16.to_le_bytes()); // 0.5
        data[16..144].fill(0x55); // both nibbles = 5
        let q8_blocks: Vec<Q8_0Block> = (0..8).map(|_| make_q8_block(0.1, 10)).collect();
        assert_ok_finite(fused_q4k_q8_dot(&data, &q8_blocks));
    }

    #[test]
    fn test_fused_q4k_q8_dot_multiple_super_blocks() {
        let mut data = vec![0u8; 288];
        for block in data.chunks_exact_mut(144) {
            block[0..2].copy_from_slice(&0x3C00u16.to_le_bytes()); // 1.0
        }
        let q8_blocks: Vec<Q8_0Block> = (0..16).map(|_| zero_q8_block()).collect();
        assert_ok_finite(fused_q4k_q8_dot(&data, &q8_blocks));
    }

    #[test]
    fn test_fused_q6k_dot_negative_scales() {
        let mut data = vec![0u8; 210];
        data[208..210].copy_from_slice(&0x3C00u16.to_le_bytes()); // d = 1.0
        // Reinterpret -5 as its two's-complement byte (0xFB) without a lossy `as` cast.
        data[192..208].fill((-5i8).to_le_bytes()[0]);
        let activations = vec![1.0f32; 256];
        assert_ok_finite(fused_q6k_dot(&data, &activations));
    }

    #[test]
    fn test_fused_q5k_dot_with_high_bits() {
        let mut data = vec![0u8; 176];
        data[0..2].copy_from_slice(&0x3C00u16.to_le_bytes()); // 1.0
        data[2..4].copy_from_slice(&0x2E66u16.to_le_bytes()); // ≈ 0.1
        data[16..48].fill(0xFF); // every high bit set
        data[48..176].fill(0xF0);
        let activations = vec![1.0f32; 256];
        assert_ok_finite(fused_q5k_dot(&data, &activations));
    }

    #[test]
    fn test_fused_q6k_dot_empty_data() {
        let data: Vec<u8> = Vec::new();
        let activations: Vec<f32> = Vec::new();
        let result = fused_q6k_dot(&data, &activations).expect("result");
        assert_eq!(result, 0.0);
    }

    #[test]
    fn test_fused_q5k_dot_empty_data() {
        let data: Vec<u8> = Vec::new();
        let activations: Vec<f32> = Vec::new();
        let result = fused_q5k_dot(&data, &activations).expect("result");
        assert_eq!(result, 0.0);
    }
}