use super::super::gemv::compute_chunk_q4k_scalar;
use super::super::*;
/// Checks that `compute_chunk_q4k_scalar` honours a non-zero starting row.
///
/// Each row gets a different 4-bit quant fill byte, so rows dequantize to
/// different weights. The 2-element chunk starting at row 2 must therefore
/// match rows 2 and 3 of the full scalar matmul. An implementation that
/// ignored the starting row (and produced rows 0 and 1) would fail.
#[test]
fn test_compute_chunk_scalar_with_offset() {
    let in_dim = 256;
    // One distinct quant fill per output row (both nibbles of each byte equal).
    let row_fills = [0x11u8, 0x22, 0x33, 0x44];
    let out_dim = row_fills.len();
    let num_blocks_per_row = 1;
    let row_bytes = SUPER_BLOCK_BYTES;
    let start_row = 2;

    let mut q4k_data = vec![0u8; out_dim * row_bytes];
    for (block, &fill) in q4k_data.chunks_exact_mut(row_bytes).zip(row_fills.iter()) {
        // First 4 header bytes; presumably d = 1.0 (f16 0x3C00, little-endian)
        // and dmin = 0.0 -- assumes the llama.cpp Q4_K layout, confirm.
        block[..4].copy_from_slice(&[0x00, 0x3C, 0x00, 0x00]);
        // 12 packed scale/min bytes.
        block[4..16].fill(0x01);
        // 128 bytes of packed 4-bit quants; differs per row.
        block[16..144].fill(fill);
    }
    let input: Vec<f32> = (0..in_dim).map(|i| (i as f32) * 0.01).collect();

    let mut chunk = vec![0.0f32; 2];
    compute_chunk_q4k_scalar(
        &q4k_data,
        &input,
        &mut chunk,
        start_row,
        out_dim,
        in_dim,
        num_blocks_per_row,
        row_bytes,
    );

    // The chunk must reproduce the matching rows of the full scalar result.
    let reference = matmul_q4k_f32_scalar(&q4k_data, &input, out_dim, in_dim);
    for (i, &val) in chunk.iter().enumerate() {
        assert!(val.is_finite(), "chunk[{}] is not finite: {}", i, val);
        let expected = reference[start_row + i];
        // Relative tolerance: outputs here are in the hundreds-to-thousands,
        // so a fixed absolute bound would only allow a few f32 ULPs.
        let tol = 1e-4 * expected.abs().max(10.0);
        assert!(
            (val - expected).abs() <= tol,
            "chunk[{}] = {} but row {} of the full matmul = {} (tol {})",
            i,
            val,
            start_row + i,
            expected,
            tol
        );
    }
}
/// A chunk longer than the number of output rows: the first `out_dim` slots
/// must hold finite results, and the remaining slots must still be 0.0.
#[test]
fn test_compute_chunk_scalar_exceeds_outdim() {
    let (in_dim, out_dim) = (256, 2);
    let num_blocks_per_row = 1;
    let row_bytes = SUPER_BLOCK_BYTES;

    let mut q4k_data = vec![0u8; out_dim * row_bytes];
    for block in q4k_data.chunks_exact_mut(row_bytes) {
        // Header bytes: [0x00, 0x3C] then [0x00, 0x00].
        block[..4].copy_from_slice(&[0x00, 0x3C, 0x00, 0x00]);
        // 12 packed scale/min bytes.
        block[4..16].fill(0x01);
        // 128 bytes of packed quants.
        block[16..144].fill(0x33);
    }

    let input = vec![1.0f32; in_dim];
    // Deliberately two slots longer than out_dim.
    let mut chunk = vec![0.0f32; 4];
    compute_chunk_q4k_scalar(
        &q4k_data,
        &input,
        &mut chunk,
        0,
        out_dim,
        in_dim,
        num_blocks_per_row,
        row_bytes,
    );

    let (computed, tail) = chunk.split_at(out_dim);
    for &val in computed {
        assert!(val.is_finite());
    }
    for &val in tail {
        assert_eq!(val, 0.0, "Elements beyond out_dim should remain zero");
    }
}
/// Cross-checks the scalar and optimized Q4_K matmuls on rows that each span
/// two super-blocks.
///
/// Outputs here are roughly 2e3. At that magnitude one f32 ULP is about
/// 2.4e-4, so a fixed 1e-3 absolute tolerance allows only ~4 ULPs. That can
/// trip on a legitimately different summation order. The tolerance is
/// therefore relative, with a 1e-3 floor for small outputs.
#[test]
fn test_matmul_q4k_scalar_multiple_blocks() {
    let in_dim = 512;
    let out_dim = 2;
    let num_blocks = 2;

    let expected_len = out_dim * num_blocks * SUPER_BLOCK_BYTES;
    let mut q4k_data: Vec<u8> = Vec::with_capacity(expected_len);
    for _ in 0..out_dim * num_blocks {
        // Header; presumably d = 1.0 (f16 0x3C00, little-endian) -- confirm.
        q4k_data.extend_from_slice(&[0x00, 0x3C]);
        // Presumably dmin = 0.0.
        q4k_data.extend_from_slice(&[0x00, 0x00]);
        // 12 packed scale/min bytes.
        q4k_data.extend_from_slice(&[0x02u8; 12]);
        // 128 bytes of packed 4-bit quants.
        q4k_data.extend_from_slice(&[0x88u8; 128]);
    }
    assert_eq!(
        q4k_data.len(),
        expected_len,
        "fixture must consist of whole super-blocks"
    );

    let input: Vec<f32> = (0..in_dim).map(|i| (i as f32) * 0.001).collect();
    let output_scalar = matmul_q4k_f32_scalar(&q4k_data, &input, out_dim, in_dim);
    let output_optimized = matmul_q4k_f32(&q4k_data, &input, out_dim, in_dim);
    assert_eq!(output_scalar.len(), out_dim);
    assert_eq!(output_optimized.len(), out_dim);

    for (i, (s, o)) in output_scalar.iter().zip(output_optimized.iter()).enumerate() {
        let diff = (s - o).abs();
        // Never stricter than the old 1e-3 absolute bound; scales with |s|.
        let tol = 1e-4 * s.abs().max(10.0);
        assert!(
            diff <= tol,
            "Row {}: scalar={} vs optimized={}, diff={}, tol={}",
            i,
            s,
            o,
            diff,
            tol
        );
    }
}