use oxibonsai_core::{BlockTQ2_0, BlockTQ2_0_g128, QK_TQ2_0, QK_TQ2_0_G128};
use crate::error::{KernelError, KernelResult};
/// Decodes the 2-bit weight code held in `lane` of a packed byte.
///
/// Lane `n` occupies bits `2n` and `2n + 1` of `byte`, so lane 0 is the two
/// least-significant bits. The codes decode as:
///
/// | code   | weight |
/// |--------|--------|
/// | `0b00` | `-1.0` |
/// | `0b01` | `0.0`  |
/// | `0b10` | `1.0`  |
/// | `0b11` | `0.0`  |
///
/// Callers in this file only pass `lane` in `0..4`.
#[inline]
fn decode_weight_f32(byte: u8, lane: usize) -> f32 {
    // Indexed directly by the 2-bit code extracted below.
    const WEIGHT_BY_CODE: [f32; 4] = [-1.0_f32, 0.0_f32, 1.0_f32, 0.0_f32];

    let shift = lane * 2;
    let code = (byte >> shift) & 0b11;
    WEIGHT_BY_CODE[usize::from(code)]
}
/// Matrix-vector product for a matrix stored as [`BlockTQ2_0_g128`] blocks.
///
/// The matrix has `n_rows` rows of `k` weights each. Blocks are laid out row by
/// row: row `r` uses `blocks[r * (k / QK_TQ2_0_G128)..]`, and block `b` of a row
/// covers `input[b * QK_TQ2_0_G128..(b + 1) * QK_TQ2_0_G128]`. Within a block,
/// each byte of `qs` packs four 2-bit weight codes (lowest bits first) and `d`
/// is a per-block scale applied to the block's dot product.
///
/// For every row the result `Σ_blocks d · Σ_i w_i · input[i]` is written to
/// `output[row]`. Elements of `input` past `k`, of `output` past `n_rows`, and
/// any surplus `blocks` are left untouched / ignored.
///
/// # Errors
///
/// * [`KernelError::NotBlockAligned`] if `k` is not a multiple of
///   `QK_TQ2_0_G128`.
/// * [`KernelError::DimensionMismatch`] if `input` holds fewer than `k` values,
///   or `blocks` holds fewer than `n_rows * (k / QK_TQ2_0_G128)` blocks.
/// * [`KernelError::BufferTooSmall`] if `output` holds fewer than `n_rows`
///   values.
pub fn gemv_tq2_0_g128(
    blocks: &[BlockTQ2_0_g128],
    input: &[f32],
    output: &mut [f32],
    n_rows: usize,
    k: usize,
) -> KernelResult<()> {
    if k % QK_TQ2_0_G128 != 0 {
        return Err(KernelError::NotBlockAligned {
            count: k,
            block_size: QK_TQ2_0_G128,
        });
    }
    if input.len() < k {
        return Err(KernelError::DimensionMismatch {
            expected: k,
            got: input.len(),
        });
    }
    if output.len() < n_rows {
        return Err(KernelError::BufferTooSmall {
            needed: n_rows,
            available: output.len(),
        });
    }

    let blocks_per_row = k / QK_TQ2_0_G128;
    // A plain `*` could wrap in release builds (notably on 32-bit targets),
    // letting an undersized `blocks` slice pass the check below and panic on
    // indexing. Saturating keeps the check meaningful: no slice can hold
    // `usize::MAX` elements, so an overflowing request is reported as an error.
    let expected_blocks = n_rows.saturating_mul(blocks_per_row);
    if blocks.len() < expected_blocks {
        return Err(KernelError::DimensionMismatch {
            expected: expected_blocks,
            got: blocks.len(),
        });
    }

    // Only the first `k` inputs take part in the product.
    let input = &input[..k];

    for (row, out) in output[..n_rows].iter_mut().enumerate() {
        // Cannot overflow: `(row + 1) * blocks_per_row <= expected_blocks`.
        let first_block = row * blocks_per_row;
        let row_blocks = &blocks[first_block..first_block + blocks_per_row];

        let mut sum = 0.0_f32;
        for (block, x_block) in row_blocks.iter().zip(input.chunks_exact(QK_TQ2_0_G128)) {
            // Accumulate the unscaled ternary dot product first, then apply the
            // block scale once.
            let mut block_sum = 0.0_f32;
            for (&byte, x_quad) in block.qs.iter().zip(x_block.chunks_exact(4)) {
                for (lane, &x) in x_quad.iter().enumerate() {
                    block_sum += decode_weight_f32(byte, lane) * x;
                }
            }
            sum += block.d.to_f32() * block_sum;
        }
        *out = sum;
    }
    Ok(())
}
/// Matrix-vector product for a matrix stored as [`BlockTQ2_0`] blocks.
///
/// The matrix has `n_rows` rows of `k` weights each. Blocks are laid out row by
/// row: row `r` uses `blocks[r * (k / QK_TQ2_0)..]`, and block `b` of a row
/// covers `input[b * QK_TQ2_0..(b + 1) * QK_TQ2_0]`. Within a block, each byte
/// of `qs` packs four 2-bit weight codes (lowest bits first) and `d` is a
/// per-block scale applied to the block's dot product.
///
/// For every row the result `Σ_blocks d · Σ_i w_i · input[i]` is written to
/// `output[row]`. Elements of `input` past `k`, of `output` past `n_rows`, and
/// any surplus `blocks` are left untouched / ignored.
///
/// # Errors
///
/// * [`KernelError::NotBlockAligned`] if `k` is not a multiple of `QK_TQ2_0`.
/// * [`KernelError::DimensionMismatch`] if `input` holds fewer than `k` values,
///   or `blocks` holds fewer than `n_rows * (k / QK_TQ2_0)` blocks.
/// * [`KernelError::BufferTooSmall`] if `output` holds fewer than `n_rows`
///   values.
pub fn gemv_tq2_0(
    blocks: &[BlockTQ2_0],
    input: &[f32],
    output: &mut [f32],
    n_rows: usize,
    k: usize,
) -> KernelResult<()> {
    if k % QK_TQ2_0 != 0 {
        return Err(KernelError::NotBlockAligned {
            count: k,
            block_size: QK_TQ2_0,
        });
    }
    if input.len() < k {
        return Err(KernelError::DimensionMismatch {
            expected: k,
            got: input.len(),
        });
    }
    if output.len() < n_rows {
        return Err(KernelError::BufferTooSmall {
            needed: n_rows,
            available: output.len(),
        });
    }

    let blocks_per_row = k / QK_TQ2_0;
    // A plain `*` could wrap in release builds (notably on 32-bit targets),
    // letting an undersized `blocks` slice pass the check below and panic on
    // indexing. Saturating keeps the check meaningful: no slice can hold
    // `usize::MAX` elements, so an overflowing request is reported as an error.
    let expected_blocks = n_rows.saturating_mul(blocks_per_row);
    if blocks.len() < expected_blocks {
        return Err(KernelError::DimensionMismatch {
            expected: expected_blocks,
            got: blocks.len(),
        });
    }

    // Only the first `k` inputs take part in the product.
    let input = &input[..k];

    for (row, out) in output[..n_rows].iter_mut().enumerate() {
        // Cannot overflow: `(row + 1) * blocks_per_row <= expected_blocks`.
        let first_block = row * blocks_per_row;
        let row_blocks = &blocks[first_block..first_block + blocks_per_row];

        let mut sum = 0.0_f32;
        for (block, x_block) in row_blocks.iter().zip(input.chunks_exact(QK_TQ2_0)) {
            // Accumulate the unscaled ternary dot product first, then apply the
            // block scale once.
            let mut block_sum = 0.0_f32;
            for (&byte, x_quad) in block.qs.iter().zip(x_block.chunks_exact(4)) {
                for (lane, &x) in x_quad.iter().enumerate() {
                    block_sum += decode_weight_f32(byte, lane) * x;
                }
            }
            sum += block.d.to_f32() * block_sum;
        }
        *out = sum;
    }
    Ok(())
}
#[cfg(test)]
mod tests {
    use super::*;
    use half::f16;

    /// Builds a 128-weight block from a scale and 32 packed code bytes.
    fn make_g128_block(scale: f32, qs: [u8; 32]) -> BlockTQ2_0_g128 {
        BlockTQ2_0_g128 {
            qs,
            d: f16::from_f32(scale),
        }
    }

    /// Builds a 256-weight block from a scale and 64 packed code bytes.
    fn make_g256_block(scale: f32, qs: [u8; 64]) -> BlockTQ2_0 {
        BlockTQ2_0 {
            qs,
            d: f16::from_f32(scale),
        }
    }

    #[test]
    fn gemv_tq2_0_g128_identity() {
        // 0xAA = four `0b10` codes, i.e. every weight is +1.
        let blocks = vec![
            make_g128_block(1.0, [0xAA; 32]),
            make_g128_block(1.0, [0xAA; 32]),
        ];
        let input = vec![1.0_f32; 128];
        let mut output = vec![0.0_f32; 2];
        gemv_tq2_0_g128(&blocks, &input, &mut output, 2, 128).expect("gemv should succeed");
        assert!(
            (output[0] - 128.0).abs() < 0.5,
            "output[0]: expected 128.0, got {}",
            output[0]
        );
        assert!(
            (output[1] - 128.0).abs() < 0.5,
            "output[1]: expected 128.0, got {}",
            output[1]
        );
    }

    #[test]
    fn gemv_tq2_0_g128_all_negative() {
        // 0x00 = four `0b00` codes, i.e. every weight is -1.
        let blocks = vec![
            make_g128_block(1.0, [0x00; 32]),
            make_g128_block(1.0, [0x00; 32]),
        ];
        let input = vec![1.0_f32; 128];
        let mut output = vec![0.0_f32; 2];
        gemv_tq2_0_g128(&blocks, &input, &mut output, 2, 128).expect("gemv should succeed");
        assert!(
            (output[0] + 128.0).abs() < 0.5,
            "output[0]: expected -128.0, got {}",
            output[0]
        );
        assert!(
            (output[1] + 128.0).abs() < 0.5,
            "output[1]: expected -128.0, got {}",
            output[1]
        );
    }

    #[test]
    fn gemv_tq2_0_g128_alternating() {
        // 0x46 = 0b01_00_01_10 decodes (lane 0 first) to +1, 0, -1, 0, which
        // cancels against an all-ones input.
        let blocks = vec![make_g128_block(1.0, [0x46; 32])];
        let input = vec![1.0_f32; 128];
        let mut output = vec![0.0_f32; 1];
        gemv_tq2_0_g128(&blocks, &input, &mut output, 1, 128).expect("gemv should succeed");
        assert!(
            output[0].abs() < 1e-4,
            "alternating: expected 0.0, got {}",
            output[0]
        );
    }

    #[test]
    fn gemv_tq2_0_g128_reserved_code_is_zero() {
        // 0xFF = four `0b11` codes, which decode to a zero weight.
        let blocks = vec![make_g128_block(1.0, [0xFF; 32])];
        let input = vec![1.0_f32; 128];
        let mut output = vec![7.0_f32; 1];
        gemv_tq2_0_g128(&blocks, &input, &mut output, 1, 128).expect("gemv should succeed");
        assert!(
            output[0].abs() < 1e-6,
            "reserved code: expected 0.0, got {}",
            output[0]
        );
    }

    #[test]
    fn gemv_tq2_0_g128_applies_block_scale() {
        // 0.5 is exactly representable in f16, so the expected value is exact.
        let blocks = vec![make_g128_block(0.5, [0xAA; 32])];
        let input = vec![1.0_f32; 128];
        let mut output = vec![0.0_f32; 1];
        gemv_tq2_0_g128(&blocks, &input, &mut output, 1, 128).expect("gemv should succeed");
        assert!(
            (output[0] - 64.0).abs() < 1e-4,
            "scaled: expected 64.0, got {}",
            output[0]
        );
    }

    #[test]
    fn gemv_tq2_0_g128_not_block_aligned() {
        let blocks = vec![make_g128_block(1.0, [0xAA; 32])];
        let input = vec![1.0_f32; 100];
        let mut output = vec![0.0_f32; 1];
        let result = gemv_tq2_0_g128(&blocks, &input, &mut output, 1, 100);
        assert!(
            matches!(result, Err(KernelError::NotBlockAligned { .. })),
            "expected NotBlockAligned error, got {result:?}"
        );
    }

    #[test]
    fn gemv_tq2_0_g128_dimension_validation() {
        // Two rows are requested but only one block is supplied.
        let blocks = vec![make_g128_block(1.0, [0xAA; 32])];
        let input = vec![1.0_f32; 128];
        let mut output = vec![0.0_f32; 2];
        let result = gemv_tq2_0_g128(&blocks, &input, &mut output, 2, 128);
        assert!(
            matches!(result, Err(KernelError::DimensionMismatch { .. })),
            "expected DimensionMismatch error, got {result:?}"
        );
    }

    #[test]
    fn gemv_tq2_0_g128_output_too_small() {
        let blocks = vec![
            make_g128_block(1.0, [0xAA; 32]),
            make_g128_block(1.0, [0xAA; 32]),
        ];
        let input = vec![1.0_f32; 128];
        let mut output = vec![0.0_f32; 1];
        let result = gemv_tq2_0_g128(&blocks, &input, &mut output, 2, 128);
        assert!(
            matches!(result, Err(KernelError::BufferTooSmall { .. })),
            "expected BufferTooSmall error, got {result:?}"
        );
    }

    #[test]
    fn gemv_tq2_0_g128_multiple_rows() {
        // 4 rows x 256 columns = 2 blocks per row, 8 blocks in total.
        let blocks = vec![make_g128_block(1.0, [0xAA; 32]); 8];
        let input = vec![1.0_f32; 256];
        let mut output = vec![0.0_f32; 4];
        gemv_tq2_0_g128(&blocks, &input, &mut output, 4, 256).expect("gemv should succeed");
        for (i, &v) in output.iter().enumerate() {
            assert!(
                (v - 256.0).abs() < 1.0,
                "output[{i}]: expected 256.0, got {v}",
            );
        }
    }

    #[test]
    fn gemv_tq2_0_identity() {
        let blocks = vec![make_g256_block(1.0, [0xAA; 64])];
        let input = vec![1.0_f32; 256];
        let mut output = vec![0.0_f32; 1];
        gemv_tq2_0(&blocks, &input, &mut output, 1, 256).expect("gemv should succeed");
        assert!(
            (output[0] - 256.0).abs() < 1.0,
            "expected 256.0, got {}",
            output[0]
        );
    }

    #[test]
    fn gemv_tq2_0_not_block_aligned() {
        let blocks = vec![make_g256_block(1.0, [0xAA; 64])];
        let input = vec![1.0_f32; 100];
        let mut output = vec![0.0_f32; 1];
        let result = gemv_tq2_0(&blocks, &input, &mut output, 1, 100);
        assert!(
            matches!(result, Err(KernelError::NotBlockAligned { .. })),
            "expected NotBlockAligned error, got {result:?}"
        );
    }

    #[test]
    fn gemv_tq2_0_dimension_validation() {
        // Two rows are requested but only one block is supplied.
        let blocks = vec![make_g256_block(1.0, [0xAA; 64])];
        let input = vec![1.0_f32; 256];
        let mut output = vec![0.0_f32; 2];
        let result = gemv_tq2_0(&blocks, &input, &mut output, 2, 256);
        assert!(
            matches!(result, Err(KernelError::DimensionMismatch { .. })),
            "expected DimensionMismatch error, got {result:?}"
        );
    }
}