use oxibonsai_core::{BlockTQ2_0, BlockTQ2_0_g128, QK_TQ2_0, QK_TQ2_0_G128};
use crate::error::{KernelError, KernelResult};
/// Decodes one 2-bit ternary code from a packed byte into a weight in `{-1, 0, +1}`.
///
/// `lane` selects the bit pair: lane 0 is bits 0-1 (least significant), lane 3 is bits 6-7.
/// Code mapping: `0b00 -> -1.0`, `0b01 -> 0.0`, `0b10 -> +1.0`. The fourth pattern,
/// `0b11`, is not a ternary value and is decoded as `0.0`.
///
/// `lane` is expected to be in `0..4`; larger values shift past the width of a `u8`.
#[inline]
fn decode_code_f32(byte: u8, lane: usize) -> f32 {
    // Indexed by the raw 2-bit code; the last entry covers the unused 0b11 pattern.
    const TERNARY_LUT: [f32; 4] = [-1.0, 0.0, 1.0, 0.0];
    let shift = lane * 2;
    let code = (byte >> shift) & 0b11;
    TERNARY_LUT[usize::from(code)]
}
/// Dequantizes ternary `TQ2_0` blocks with `QK_TQ2_0_G128`-element groups into `f32`.
///
/// Each block carries a scale `d` (converted with `to_f32`) and packed 2-bit codes in
/// `qs`, four codes per byte with the least-significant bit pair first. Element `j` of a
/// block becomes `d * decode_code_f32(qs[j / 4], j % 4)`, i.e. `-d`, `0` or `+d`.
///
/// Block `bi` is written to `output[bi * QK_TQ2_0_G128..(bi + 1) * QK_TQ2_0_G128]`.
/// Elements of `output` beyond `blocks.len() * QK_TQ2_0_G128` are left untouched.
///
/// # Errors
///
/// Returns [`KernelError::BufferTooSmall`] if `output` holds fewer than
/// `blocks.len() * QK_TQ2_0_G128` elements; `output` is not modified in that case.
pub fn dequant_tq2_0_g128(blocks: &[BlockTQ2_0_g128], output: &mut [f32]) -> KernelResult<()> {
    // Each packed byte holds four 2-bit codes.
    const CODES_PER_BYTE: usize = 4;

    let needed = blocks.len() * QK_TQ2_0_G128;
    if output.len() < needed {
        return Err(KernelError::BufferTooSmall {
            needed,
            available: output.len(),
        });
    }

    // Pair each block with its own output window. The window width comes from the
    // group-size constant rather than a hard-coded byte count, and slicing to `needed`
    // up front lets the writes below proceed without per-element bounds checks.
    for (block, out_block) in blocks
        .iter()
        .zip(output[..needed].chunks_exact_mut(QK_TQ2_0_G128))
    {
        // The packed-code array must cover exactly one group; catch layout drift early.
        debug_assert_eq!(block.qs.len() * CODES_PER_BYTE, QK_TQ2_0_G128);
        let d = block.d.to_f32();
        for (&byte, out_quad) in block
            .qs
            .iter()
            .zip(out_block.chunks_exact_mut(CODES_PER_BYTE))
        {
            for (lane, out) in out_quad.iter_mut().enumerate() {
                *out = d * decode_code_f32(byte, lane);
            }
        }
    }
    Ok(())
}
/// Dequantizes ternary `TQ2_0` blocks with `QK_TQ2_0`-element groups into `f32`.
///
/// Each block carries a scale `d` (converted with `to_f32`) and packed 2-bit codes in
/// `qs`, four codes per byte with the least-significant bit pair first. Element `j` of a
/// block becomes `d * decode_code_f32(qs[j / 4], j % 4)`, i.e. `-d`, `0` or `+d`.
///
/// Block `bi` is written to `output[bi * QK_TQ2_0..(bi + 1) * QK_TQ2_0]`.
/// Elements of `output` beyond `blocks.len() * QK_TQ2_0` are left untouched.
///
/// # Errors
///
/// Returns [`KernelError::BufferTooSmall`] if `output` holds fewer than
/// `blocks.len() * QK_TQ2_0` elements; `output` is not modified in that case.
pub fn dequant_tq2_0(blocks: &[BlockTQ2_0], output: &mut [f32]) -> KernelResult<()> {
    // Each packed byte holds four 2-bit codes.
    const CODES_PER_BYTE: usize = 4;

    let needed = blocks.len() * QK_TQ2_0;
    if output.len() < needed {
        return Err(KernelError::BufferTooSmall {
            needed,
            available: output.len(),
        });
    }

    // Pair each block with its own output window. The window width comes from the
    // group-size constant rather than a hard-coded byte count, and slicing to `needed`
    // up front lets the writes below proceed without per-element bounds checks.
    for (block, out_block) in blocks
        .iter()
        .zip(output[..needed].chunks_exact_mut(QK_TQ2_0))
    {
        // The packed-code array must cover exactly one group; catch layout drift early.
        debug_assert_eq!(block.qs.len() * CODES_PER_BYTE, QK_TQ2_0);
        let d = block.d.to_f32();
        for (&byte, out_quad) in block
            .qs
            .iter()
            .zip(out_block.chunks_exact_mut(CODES_PER_BYTE))
        {
            for (lane, out) in out_quad.iter_mut().enumerate() {
                *out = d * decode_code_f32(byte, lane);
            }
        }
    }
    Ok(())
}
#[cfg(test)]
mod tests {
    use super::*;
    use half::f16;

    /// Builds a 128-element block from a scale and 32 packed code bytes.
    fn make_g128_block(scale: f32, qs: [u8; 32]) -> BlockTQ2_0_g128 {
        BlockTQ2_0_g128 {
            qs,
            d: f16::from_f32(scale),
        }
    }

    /// Builds a 256-element block from a scale and 64 packed code bytes.
    fn make_g256_block(scale: f32, qs: [u8; 64]) -> BlockTQ2_0 {
        BlockTQ2_0 {
            qs,
            d: f16::from_f32(scale),
        }
    }

    /// Asserts that every element of `values` is within `tol` of `expected`.
    fn assert_all_close(values: &[f32], expected: f32, tol: f32) {
        for (i, &v) in values.iter().enumerate() {
            assert!(
                (v - expected).abs() < tol,
                "index {i}: expected {expected}, got {v}"
            );
        }
    }

    // ---- TQ2_0_g128 (128-element groups) ----

    #[test]
    fn tq2_0_g128_dequant_all_zero() {
        let block = make_g128_block(1.0, [0x55; 32]);
        let mut output = vec![0.0f32; QK_TQ2_0_G128];
        dequant_tq2_0_g128(&[block], &mut output).expect("dequant should succeed");
        assert_all_close(&output, 0.0, 1e-6);
    }

    #[test]
    fn tq2_0_g128_dequant_all_pos() {
        let block = make_g128_block(2.0, [0xAA; 32]);
        let mut output = vec![0.0f32; QK_TQ2_0_G128];
        dequant_tq2_0_g128(&[block], &mut output).expect("dequant should succeed");
        assert_all_close(&output, 2.0, 1e-4);
    }

    #[test]
    fn tq2_0_g128_dequant_all_neg() {
        let block = make_g128_block(2.0, [0x00; 32]);
        let mut output = vec![0.0f32; QK_TQ2_0_G128];
        dequant_tq2_0_g128(&[block], &mut output).expect("dequant should succeed");
        assert_all_close(&output, -2.0, 1e-4);
    }

    #[test]
    fn tq2_0_g128_dequant_reserved_code_is_zero() {
        // 0b11 is not a ternary value; it must decode to 0 regardless of scale.
        let block = make_g128_block(5.0, [0xFF; 32]);
        let mut output = vec![1.0f32; QK_TQ2_0_G128];
        dequant_tq2_0_g128(&[block], &mut output).expect("dequant should succeed");
        assert_all_close(&output, 0.0, 1e-6);
    }

    #[test]
    fn tq2_0_g128_dequant_buffer_too_small() {
        let block = make_g128_block(1.0, [0xAA; 32]);
        let mut output = vec![0.0f32; 0];
        let result = dequant_tq2_0_g128(&[block], &mut output);
        assert!(result.is_err(), "expected BufferTooSmall error");
    }

    #[test]
    fn tq2_0_g128_dequant_buffer_too_small_reports_sizes() {
        let block = make_g128_block(1.0, [0xAA; 32]);
        let mut output = vec![0.0f32; QK_TQ2_0_G128 - 1];
        match dequant_tq2_0_g128(&[block], &mut output) {
            Err(KernelError::BufferTooSmall { needed, available }) => {
                assert_eq!(needed, QK_TQ2_0_G128);
                assert_eq!(available, QK_TQ2_0_G128 - 1);
            }
            other => panic!("expected BufferTooSmall, got {other:?}"),
        }
    }

    #[test]
    fn tq2_0_g128_dequant_empty_input() {
        let mut output: Vec<f32> = Vec::new();
        dequant_tq2_0_g128(&[], &mut output).expect("empty input should succeed");
    }

    #[test]
    fn tq2_0_g128_dequant_mixed() {
        let mut qs = [0x55u8; 32];
        // Lanes (LSB first): 0b10 (+1), 0b00 (-1), 0b01 (0), 0b10 (+1).
        qs[0] = 0b10_01_00_10;
        let block = make_g128_block(3.0, qs);
        let mut output = vec![0.0f32; QK_TQ2_0_G128];
        dequant_tq2_0_g128(&[block], &mut output).expect("dequant should succeed");
        assert!(
            (output[0] - 3.0).abs() < 1e-4,
            "output[0]: expected 3.0, got {}",
            output[0]
        );
        assert!(
            (output[1] + 3.0).abs() < 1e-4,
            "output[1]: expected -3.0, got {}",
            output[1]
        );
        assert!(
            output[2].abs() < 1e-6,
            "output[2]: expected 0.0, got {}",
            output[2]
        );
        assert!(
            (output[3] - 3.0).abs() < 1e-4,
            "output[3]: expected 3.0, got {}",
            output[3]
        );
        for (offset, val) in output[4..].iter().enumerate() {
            assert!(
                val.abs() < 1e-6,
                "output[{}]: expected 0.0, got {}",
                offset + 4,
                val
            );
        }
    }

    #[test]
    fn tq2_0_g128_dequant_multi_block_offsets() {
        // Distinct scales and signs per block verify each block lands in its own window.
        let blocks = [
            make_g128_block(1.0, [0xAA; 32]),
            make_g128_block(2.0, [0x00; 32]),
        ];
        let mut output = vec![0.0f32; 2 * QK_TQ2_0_G128];
        dequant_tq2_0_g128(&blocks, &mut output).expect("dequant should succeed");
        assert_all_close(&output[..QK_TQ2_0_G128], 1.0, 1e-4);
        assert_all_close(&output[QK_TQ2_0_G128..], -2.0, 1e-4);
    }

    #[test]
    fn tq2_0_g128_dequant_leaves_tail_untouched() {
        let block = make_g128_block(1.0, [0xAA; 32]);
        let mut output = vec![7.0f32; QK_TQ2_0_G128 + 4];
        dequant_tq2_0_g128(&[block], &mut output).expect("dequant should succeed");
        assert_all_close(&output[..QK_TQ2_0_G128], 1.0, 1e-4);
        assert_all_close(&output[QK_TQ2_0_G128..], 7.0, 1e-6);
    }

    // ---- TQ2_0 (256-element groups) ----

    #[test]
    fn tq2_0_dequant_all_zero() {
        let block = make_g256_block(1.0, [0x55; 64]);
        let mut output = vec![0.0f32; QK_TQ2_0];
        dequant_tq2_0(&[block], &mut output).expect("dequant should succeed");
        assert_all_close(&output, 0.0, 1e-6);
    }

    #[test]
    fn tq2_0_dequant_all_pos() {
        let block = make_g256_block(2.0, [0xAA; 64]);
        let mut output = vec![0.0f32; QK_TQ2_0];
        dequant_tq2_0(&[block], &mut output).expect("dequant should succeed");
        assert_all_close(&output, 2.0, 1e-4);
    }

    #[test]
    fn tq2_0_dequant_all_neg() {
        let block = make_g256_block(2.0, [0x00; 64]);
        let mut output = vec![0.0f32; QK_TQ2_0];
        dequant_tq2_0(&[block], &mut output).expect("dequant should succeed");
        assert_all_close(&output, -2.0, 1e-4);
    }

    #[test]
    fn tq2_0_dequant_buffer_too_small() {
        let block = make_g256_block(1.0, [0xAA; 64]);
        let mut output = vec![0.0f32; 0];
        let result = dequant_tq2_0(&[block], &mut output);
        assert!(result.is_err(), "expected BufferTooSmall error");
    }

    #[test]
    fn tq2_0_dequant_mixed_last_byte() {
        // Exercise the final byte of the block so the end of the window is covered.
        let mut qs = [0x55u8; 64];
        qs[63] = 0b10_01_00_10;
        let block = make_g256_block(3.0, qs);
        let mut output = vec![0.0f32; QK_TQ2_0];
        dequant_tq2_0(&[block], &mut output).expect("dequant should succeed");
        let tail = QK_TQ2_0 - 4;
        assert_all_close(&output[..tail], 0.0, 1e-6);
        assert!((output[tail] - 3.0).abs() < 1e-4, "got {}", output[tail]);
        assert!((output[tail + 1] + 3.0).abs() < 1e-4, "got {}", output[tail + 1]);
        assert!(output[tail + 2].abs() < 1e-6, "got {}", output[tail + 2]);
        assert!((output[tail + 3] - 3.0).abs() < 1e-4, "got {}", output[tail + 3]);
    }

    #[test]
    fn tq2_0_dequant_multi_block_offsets() {
        let blocks = [
            make_g256_block(1.0, [0x00; 64]),
            make_g256_block(4.0, [0xAA; 64]),
        ];
        let mut output = vec![0.0f32; 2 * QK_TQ2_0];
        dequant_tq2_0(&blocks, &mut output).expect("dequant should succeed");
        assert_all_close(&output[..QK_TQ2_0], -1.0, 1e-4);
        assert_all_close(&output[QK_TQ2_0..], 4.0, 1e-4);
    }
}