use super::super::gemv::compute_chunk_scalar;
use super::super::*;
/// Drives the dispatch entry point with a matrix big enough
/// (4096 x 2048 = 8,388,608 elements) to cross the parallel-work
/// threshold, then spot-checks the result against the scalar reference.
#[test]
fn test_parallel_dispatch_large_matrix() {
    let out_dim = 4096;
    let in_dim = 2048;
    // Guard: the workload must actually be large enough for the parallel path.
    assert!(out_dim * in_dim >= 8_000_000, "Test must trigger parallel path");

    let sb_per_row = (in_dim + SUPER_BLOCK_SIZE - 1) / SUPER_BLOCK_SIZE;
    let row_stride = sb_per_row * SUPER_BLOCK_BYTES;
    let mut weights = vec![0u8; out_dim * row_stride];
    for row in 0..out_dim {
        for sb in 0..sb_per_row {
            let base = row * row_stride + sb * SUPER_BLOCK_BYTES;
            // Leading two bytes 0x00, 0x3C (0x3C00 = 1.0 if read as f16).
            // NOTE(review): this test puts the 2-byte scale at the block
            // start; other tests in this file append it at the end —
            // confirm which layout the kernels expect.
            weights[base] = 0x00;
            weights[base + 1] = 0x3C;
            // 128 bytes of low quant data, varying per position.
            for (i, b) in weights[base + 2..base + 130].iter_mut().enumerate() {
                *b = ((row + sb + i) % 64) as u8;
            }
            // 64 bytes of high-bit data.
            for (i, b) in weights[base + 130..base + 194].iter_mut().enumerate() {
                *b = ((row ^ sb ^ i) % 4) as u8;
            }
            // 16 scale bytes, all 0x10.
            for b in weights[base + 194..base + 210].iter_mut() {
                *b = 0x10;
            }
        }
    }

    let activations: Vec<f32> = (0..in_dim).map(|i| (i % 10) as f32 * 0.1).collect();
    let parallel = matmul_q6k_f32_dispatch(&weights, &activations, out_dim, in_dim);
    assert_eq!(parallel.len(), out_dim);
    for (i, &v) in parallel.iter().enumerate() {
        assert!(v.is_finite(), "Result[{}] is not finite: {}", i, v);
    }

    // Compare every 512th row against the scalar reference (1% relative tol).
    let reference = matmul_q6k_f32_scalar(&weights, &activations, out_dim, in_dim);
    for i in (0..out_dim).step_by(512) {
        let delta = (parallel[i] - reference[i]).abs();
        let allowed = reference[i].abs() * 0.01 + 1e-4;
        assert!(
            delta < allowed,
            "Parallel vs scalar mismatch at row {}: parallel={}, scalar={}, diff={}",
            i,
            parallel[i],
            reference[i],
            delta
        );
    }
}
/// Exercises the (deprecated) column-major dispatch on a large
/// 2048 x 4096 problem and verifies every output element is finite.
#[test]
#[allow(deprecated)]
fn test_parallel_colmajor_large_matrix() {
    let ne0 = 2048;
    let ne1 = 4096;
    let sb_per_col = (ne0 + SUPER_BLOCK_SIZE - 1) / SUPER_BLOCK_SIZE;
    let col_stride = sb_per_col * SUPER_BLOCK_BYTES;
    let mut weights = vec![0u8; ne1 * col_stride];
    for col in 0..ne1 {
        for sb in 0..sb_per_col {
            let base = col * col_stride + sb * SUPER_BLOCK_BYTES;
            // Leading bytes 0x00, 0x38 (0x3800 = 0.5 if read as f16).
            weights[base] = 0x00;
            weights[base + 1] = 0x38;
            // 128 bytes of low quant data.
            for (i, b) in weights[base + 2..base + 130].iter_mut().enumerate() {
                *b = ((col ^ sb ^ i) % 64) as u8;
            }
            // 64 bytes of high-bit data (constant per super-block).
            for b in weights[base + 130..base + 194].iter_mut() {
                *b = ((col + sb) % 4) as u8;
            }
            // 16 scale bytes, all 0x20.
            for b in weights[base + 194..base + 210].iter_mut() {
                *b = 0x20;
            }
        }
    }
    let activations: Vec<f32> = (0..ne1).map(|i| ((i % 7) as f32 - 3.0) * 0.1).collect();
    let output = matmul_q6k_f32_colmajor_dispatch(&weights, &activations, ne0, ne1);
    assert_eq!(output.len(), ne0);
    for (i, &v) in output.iter().enumerate() {
        assert!(v.is_finite(), "Result[{}] is not finite: {}", i, v);
    }
}
/// Runs compute_chunk_scalar over a tiny 4 x 256 matrix (all quant bytes
/// zero, scale bytes 0x01) covering the full output range, and checks
/// that every produced value is finite.
#[test]
fn test_compute_chunk_scalar_small() {
    let in_dim = 256;
    let out_dim = 4;
    let num_blocks_per_row = 1;
    let row_bytes = SUPER_BLOCK_BYTES;
    let mut weights = vec![0u8; out_dim * row_bytes];
    for row in 0..out_dim {
        let base = row * row_bytes;
        // Leading bytes 0x00, 0x3C (0x3C00 = 1.0 if read as f16).
        weights[base] = 0x00;
        weights[base + 1] = 0x3C;
        // Quant regions stay zero (explicit writes kept for clarity).
        for b in weights[base + 2..base + 130].iter_mut() {
            *b = 0x00;
        }
        for b in weights[base + 130..base + 194].iter_mut() {
            *b = 0x00;
        }
        // 16 scale bytes of 0x01.
        for b in weights[base + 194..base + 210].iter_mut() {
            *b = 0x01;
        }
    }
    let activations = vec![1.0f32; in_dim];
    let mut chunk = vec![0.0f32; out_dim];
    compute_chunk_scalar(
        &weights,
        &activations,
        &mut chunk,
        0,
        out_dim,
        in_dim,
        num_blocks_per_row,
        row_bytes,
    );
    for (i, &v) in chunk.iter().enumerate() {
        assert!(v.is_finite(), "Chunk[{}] is not finite: {}", i, v);
    }
}
/// Calls compute_chunk_scalar with a nonzero start row (2) so only the
/// last two of four rows land in the 2-element chunk; the output must
/// stay finite.
#[test]
fn test_compute_chunk_scalar_with_offset() {
    let in_dim = 256;
    let out_dim = 4;
    let num_blocks_per_row = 1;
    let row_bytes = SUPER_BLOCK_BYTES;
    // Per row: 128 varying bytes, 64 varying bytes, 16 bytes of 0x02,
    // then the pair 0x00, 0x3C appended last.
    // NOTE(review): here the 2-byte pair comes at the end of the block,
    // while other tests in this file write it first — confirm which
    // layout the kernels expect.
    let mut weights = Vec::with_capacity(out_dim * row_bytes);
    for row in 0..out_dim {
        weights.extend((0..128).map(|i| ((row * 17 + i) % 256) as u8));
        weights.extend((0..64).map(|i| ((row * 7 + i) % 256) as u8));
        weights.extend_from_slice(&[0x02u8; 16]);
        weights.extend_from_slice(&[0x00, 0x3C]);
    }
    let activations: Vec<f32> = (0..in_dim).map(|i| (i as f32) * 0.01).collect();
    let mut chunk = vec![0.0f32; 2];
    compute_chunk_scalar(
        &weights,
        &activations,
        &mut chunk,
        2,
        out_dim,
        in_dim,
        num_blocks_per_row,
        row_bytes,
    );
    for &v in &chunk {
        assert!(v.is_finite());
    }
}
/// Hands compute_chunk_scalar a chunk buffer longer than out_dim and
/// verifies that entries past out_dim are never written (stay zero).
#[test]
fn test_compute_chunk_scalar_exceeds_outdim() {
    let in_dim = 256;
    let out_dim = 2;
    let num_blocks_per_row = 1;
    let row_bytes = SUPER_BLOCK_BYTES;
    let mut weights = Vec::with_capacity(out_dim * row_bytes);
    for _ in 0..out_dim {
        weights.extend_from_slice(&[0x33u8; 128]); // 128-byte quant region
        weights.extend_from_slice(&[0x11u8; 64]); // 64-byte high-bit region
        weights.extend_from_slice(&[0x01u8; 16]); // 16 scale bytes
        weights.extend_from_slice(&[0x00, 0x3C]); // trailing pair (f16 1.0 if so read)
    }
    let activations = vec![1.0f32; in_dim];
    // Deliberately oversized: only the first out_dim slots should be filled.
    let mut chunk = vec![0.0f32; 4];
    compute_chunk_scalar(
        &weights,
        &activations,
        &mut chunk,
        0,
        out_dim,
        in_dim,
        num_blocks_per_row,
        row_bytes,
    );
    for value in &chunk[..2] {
        assert!(value.is_finite());
    }
    assert_eq!(chunk[2], 0.0, "Elements beyond out_dim should remain zero");
    assert_eq!(chunk[3], 0.0, "Elements beyond out_dim should remain zero");
}
/// Uses two super-blocks per row (in_dim = 512) and checks that the
/// scalar and dispatch paths agree within an absolute 1e-3 tolerance.
#[test]
fn test_matmul_q6k_scalar_multiple_blocks() {
    let in_dim = 512;
    let out_dim = 2;
    let num_blocks = 2;
    let mut weights = Vec::with_capacity(out_dim * num_blocks * SUPER_BLOCK_BYTES);
    for _ in 0..out_dim {
        for _ in 0..num_blocks {
            weights.extend_from_slice(&[0x44u8; 128]); // 128-byte quant region
            weights.extend_from_slice(&[0x00u8; 64]); // 64-byte high-bit region
            weights.extend_from_slice(&[0x02u8; 16]); // 16 scale bytes
            weights.extend_from_slice(&[0x00, 0x3C]); // trailing pair (f16 1.0 if so read)
        }
    }
    let activations: Vec<f32> = (0..in_dim).map(|i| (i as f32) * 0.001).collect();
    let via_scalar = matmul_q6k_f32_scalar(&weights, &activations, out_dim, in_dim);
    let via_dispatch = matmul_q6k_f32_dispatch(&weights, &activations, out_dim, in_dim);
    assert_eq!(via_scalar.len(), out_dim);
    assert_eq!(via_dispatch.len(), out_dim);
    for (i, (s, d)) in via_scalar.iter().zip(via_dispatch.iter()).enumerate() {
        let gap = (s - d).abs();
        assert!(gap < 1e-3, "Row {}: scalar={} vs dispatch={}, diff={}", i, s, d, gap);
    }
}