use super::super::*;
/// Smoke test: the parallel kernel runs on a 2x256 problem and yields only
/// finite outputs.
#[test]
fn test_execute_parallel_basic() {
    let (m, k) = (2, 256);
    let matvec = TiledQ4KMatvec::new(m, k);

    // One super-block per row; only the first two bytes of each are set.
    // NOTE(review): presumably these are the little-endian f16 scale
    // (0x3C00 would be 1.0) — confirm against the Q4_K layout.
    let mut weights = vec![0u8; m * Q4K_SUPERBLOCK_BYTES];
    for superblock in weights.chunks_exact_mut(Q4K_SUPERBLOCK_BYTES) {
        superblock[0] = 0x00;
        superblock[1] = 0x3C;
    }

    let input = vec![1.0f32; k];
    let mut output = vec![0.0f32; m];
    matvec.execute_parallel(&weights, &input, &mut output);

    output.iter().for_each(|val| assert!(val.is_finite()));
}
/// The parallel kernel must agree with the scalar kernel (within 1e-5) on a
/// 4x256 problem with row-dependent weight bytes and a ramp input.
#[test]
fn test_execute_parallel_matches_scalar() {
    let (m, k) = (4, 256);
    let matvec = TiledQ4KMatvec::new(m, k);

    let mut weights = vec![0u8; m * Q4K_SUPERBLOCK_BYTES];
    for (row, superblock) in weights.chunks_exact_mut(Q4K_SUPERBLOCK_BYTES).enumerate() {
        // First six header bytes.
        // NOTE(review): presumably f16 d, f16 dmin, then the first two
        // packed scale bytes — confirm against the Q4_K layout.
        superblock[..6].copy_from_slice(&[0x00, 0x3C, 0x00, 0x00, 0x01, 0x02]);

        // Bytes 16..144 get a deterministic pattern that varies by row and position.
        for (i, byte) in (16..144).zip(superblock[16..144].iter_mut()) {
            *byte = ((row * 7 + i * 3) % 256) as u8;
        }
    }

    let input: Vec<f32> = (0..k).map(|i| (i as f32) * 0.01).collect();

    let mut output_scalar = vec![0.0f32; m];
    matvec.execute_scalar(&weights, &input, &mut output_scalar);

    let mut output_parallel = vec![0.0f32; m];
    matvec.execute_parallel(&weights, &input, &mut output_parallel);

    for (i, (scalar, parallel)) in output_scalar.iter().zip(&output_parallel).enumerate() {
        assert!(
            (scalar - parallel).abs() < 1e-5,
            "Mismatch at row {}: scalar={}, parallel={}",
            i,
            scalar,
            parallel
        );
    }
}
/// With k = 512 each row spans several super-blocks; the parallel and scalar
/// kernels must still agree within 1e-5.
#[test]
fn test_execute_parallel_multiple_superblocks() {
    let (m, k) = (3, 512);
    let sb_per_row = k / Q4K_SUPERBLOCK_SIZE;

    let matvec = TiledQ4KMatvec::new(m, k);
    assert_eq!(matvec.superblocks_per_row(), sb_per_row);

    // Every super-block of every row receives the same contents, so the
    // buffer can be filled one super-block-sized chunk at a time.
    let mut weights = vec![0u8; m * sb_per_row * Q4K_SUPERBLOCK_BYTES];
    for superblock in weights.chunks_exact_mut(Q4K_SUPERBLOCK_BYTES) {
        // NOTE(review): presumably a little-endian f16 scale (0x3800 would
        // be 0.5) — confirm against the Q4_K layout.
        superblock[0] = 0x00;
        superblock[1] = 0x38;
        superblock[16..144].fill(0x55);
    }

    let input = vec![1.0f32; k];

    let mut output_scalar = vec![0.0f32; m];
    matvec.execute_scalar(&weights, &input, &mut output_scalar);

    let mut output_parallel = vec![0.0f32; m];
    matvec.execute_parallel(&weights, &input, &mut output_parallel);

    for (i, (scalar, parallel)) in output_scalar.iter().zip(&output_parallel).enumerate() {
        assert!(
            (scalar - parallel).abs() < 1e-5,
            "Row {} mismatch: scalar={}, parallel={}",
            i,
            scalar,
            parallel
        );
    }
}
/// Scalar/parallel agreement on a 64-row matrix, using a tolerance of 1e-4.
#[test]
fn test_execute_parallel_larger_matrix() {
    let (m, k) = (64, 256);
    let matvec = TiledQ4KMatvec::new(m, k);

    let mut weights = vec![0u8; m * Q4K_SUPERBLOCK_BYTES];
    for (row, superblock) in weights.chunks_exact_mut(Q4K_SUPERBLOCK_BYTES).enumerate() {
        // NOTE(review): presumably a little-endian f16 scale (0x3C00 would
        // be 1.0) — confirm against the Q4_K layout.
        superblock[0] = 0x00;
        superblock[1] = 0x3C;

        // Bytes 16..144 vary with both the row and the byte position.
        for (i, byte) in (16..144).zip(superblock[16..144].iter_mut()) {
            *byte = ((row + i) % 256) as u8;
        }
    }

    // Input cycles through 0.0, 0.1, ..., 0.9.
    let input: Vec<f32> = (0..k).map(|i| ((i % 10) as f32) * 0.1).collect();

    let mut output_scalar = vec![0.0f32; m];
    matvec.execute_scalar(&weights, &input, &mut output_scalar);

    let mut output_parallel = vec![0.0f32; m];
    matvec.execute_parallel(&weights, &input, &mut output_parallel);

    for (i, (scalar, parallel)) in output_scalar.iter().zip(&output_parallel).enumerate() {
        assert!(
            (scalar - parallel).abs() < 1e-4,
            "Row {} mismatch: scalar={}, parallel={}",
            i,
            scalar,
            parallel
        );
    }
}
/// Smoke test with k = 1024: the parallel kernel must produce a finite value
/// for every one of the 8 rows.
#[test]
fn test_execute_parallel_large_k() {
    let (m, k) = (8, 1024);
    let sb_per_row = k / Q4K_SUPERBLOCK_SIZE;

    let matvec = TiledQ4KMatvec::new(m, k);
    // Deliberately compared against a literal rather than `sb_per_row`.
    assert_eq!(matvec.superblocks_per_row(), 4);

    // All super-blocks share the same contents, so fill them chunk by chunk.
    let mut weights = vec![0u8; m * sb_per_row * Q4K_SUPERBLOCK_BYTES];
    for superblock in weights.chunks_exact_mut(Q4K_SUPERBLOCK_BYTES) {
        // NOTE(review): presumably a little-endian f16 scale (0x4000 would
        // be 2.0) followed later by a packed scale byte — confirm against
        // the Q4_K layout.
        superblock[0] = 0x00;
        superblock[1] = 0x40;
        superblock[4] = 0x03;
        superblock[16..144].fill(0xAA);
    }

    let input = vec![0.5f32; k];
    let mut output = vec![0.0f32; m];
    matvec.execute_parallel(&weights, &input, &mut output);

    output
        .iter()
        .enumerate()
        .for_each(|(i, val)| assert!(val.is_finite(), "Row {} is not finite", i));
}
/// Degenerate shape: a single output row must give the same result from the
/// scalar and parallel kernels (within 1e-5).
#[test]
fn test_execute_parallel_single_row() {
    let (m, k) = (1, 256);
    let matvec = TiledQ4KMatvec::new(m, k);

    // A single super-block holds the whole matrix.
    // NOTE(review): bytes 0-1 are presumably a little-endian f16 scale
    // (0x3C00 would be 1.0) — confirm against the Q4_K layout.
    let mut weights = vec![0u8; Q4K_SUPERBLOCK_BYTES];
    weights[0] = 0x00;
    weights[1] = 0x3C;
    weights[16..144].fill(0x11);

    let input = vec![1.0f32; k];

    let mut output_scalar = vec![0.0f32; m];
    matvec.execute_scalar(&weights, &input, &mut output_scalar);

    let mut output_parallel = vec![0.0f32; m];
    matvec.execute_parallel(&weights, &input, &mut output_parallel);

    let (scalar, parallel) = (output_scalar[0], output_parallel[0]);
    assert!(
        (scalar - parallel).abs() < 1e-5,
        "scalar={}, parallel={}",
        scalar,
        parallel
    );
}