use half::f16;
use crate::error::{BonsaiError, BonsaiResult};
/// Number of weights per quantized block (group size of 128).
pub const QK1_0_G128: usize = 128;
/// Serialized size of one block: 2-byte `f16` scale + 128/8 = 16 bytes of sign bits.
pub const BLOCK_SIZE_BYTES: usize = 18;
/// One quantized block: 128 one-bit weights sharing a single `f16` scale.
///
/// Layout is `repr(C)`, 18 bytes total: the 2-byte scale followed by 16 bytes
/// holding 128 sign bits, LSB-first within each byte (see `sign_bit`).
#[derive(Debug, Clone, Copy, PartialEq)]
#[repr(C)]
pub struct BlockQ1_0G128 {
// Per-block scale; a set bit decodes to `+d`, a clear bit to `-d` (see `weight`).
pub d: f16,
// 128 sign bits: bit `i` lives in `qs[i / 8]` at bit offset `i % 8`.
pub qs: [u8; QK1_0_G128 / 8],
}
// Compile-time guarantee that the struct layout matches the serialized size
// (2 + 16 = 18 bytes, no padding) so the byte-reinterpreting casts are sound.
const _: () = assert!(std::mem::size_of::<BlockQ1_0G128>() == BLOCK_SIZE_BYTES);
impl BlockQ1_0G128 {
    /// Reinterprets the leading `BLOCK_SIZE_BYTES` of `data` as one block.
    ///
    /// Trailing bytes beyond the first block are ignored.
    ///
    /// # Errors
    /// Returns `InvalidBlockSize` when `data` is shorter than one block, or when
    /// it is not aligned for `BlockQ1_0G128` (2 bytes, imposed by the `f16` field).
    pub fn from_bytes(data: &[u8]) -> BonsaiResult<&Self> {
        if data.len() < BLOCK_SIZE_BYTES {
            return Err(BonsaiError::InvalidBlockSize { actual: data.len() });
        }
        // A `&[u8]` only guarantees 1-byte alignment; dereferencing a misaligned
        // `*const BlockQ1_0G128` is undefined behavior, so reject such input.
        // NOTE(review): reusing InvalidBlockSize here; a dedicated alignment
        // error variant would be clearer — confirm against the error enum.
        if data.as_ptr() as usize % std::mem::align_of::<Self>() != 0 {
            return Err(BonsaiError::InvalidBlockSize { actual: data.len() });
        }
        // SAFETY: length and alignment were checked above, the bit patterns of
        // `f16` and `[u8; 16]` are valid for any initialized bytes, and the
        // returned reference borrows `data`, so the lifetime is correct.
        Ok(unsafe { &*(data.as_ptr() as *const Self) })
    }

    /// Reinterprets `data` as a slice of consecutive blocks.
    ///
    /// # Errors
    /// Returns `InvalidBlockSize` when `data` is not an exact multiple of
    /// `BLOCK_SIZE_BYTES`, or when it is misaligned for `BlockQ1_0G128`.
    pub fn slice_from_bytes(data: &[u8]) -> BonsaiResult<&[Self]> {
        if data.len() % BLOCK_SIZE_BYTES != 0 {
            return Err(BonsaiError::InvalidBlockSize { actual: data.len() });
        }
        // An empty `&[u8]` may use a dangling 1-aligned pointer, which would
        // spuriously fail the alignment check below; zero blocks is valid.
        if data.is_empty() {
            return Ok(&[]);
        }
        // See `from_bytes`: misaligned reinterpretation is undefined behavior.
        if data.as_ptr() as usize % std::mem::align_of::<Self>() != 0 {
            return Err(BonsaiError::InvalidBlockSize { actual: data.len() });
        }
        let count = data.len() / BLOCK_SIZE_BYTES;
        // SAFETY: length (exact multiple) and alignment were checked above,
        // and the resulting slice borrows `data` for its whole extent.
        Ok(unsafe { std::slice::from_raw_parts(data.as_ptr() as *const Self, count) })
    }

    /// Returns the raw sign bit of weight `i` (LSB-first within each byte).
    ///
    /// Panics (via slice indexing) if `i >= QK1_0_G128` in release builds;
    /// debug builds assert first.
    #[inline]
    pub fn sign_bit(&self, i: usize) -> bool {
        debug_assert!(i < QK1_0_G128);
        let byte_index = i / 8;
        let bit_offset = i % 8;
        (self.qs[byte_index] >> bit_offset) & 1 != 0
    }

    /// Dequantizes weight `i`: a set sign bit decodes to `+d`, clear to `-d`.
    #[inline]
    pub fn weight(&self, i: usize) -> f32 {
        let d = self.d.to_f32();
        if self.sign_bit(i) {
            d
        } else {
            -d
        }
    }
}
/// A named tensor view over 1-bit quantized blocks borrowed from raw bytes.
///
/// Owns its metadata (`name`, `shape`) but only borrows the block data,
/// so the backing buffer must outlive the tensor.
#[derive(Debug)]
pub struct OneBitTensor<'a> {
// Tensor identifier (presumably the name from the model file — confirm at caller).
pub name: String,
// Logical dimensions; NOTE(review): not validated against the block count here.
pub shape: Vec<u64>,
// Borrowed quantized blocks backing this tensor.
blocks: &'a [BlockQ1_0G128],
}
impl<'a> OneBitTensor<'a> {
    /// Builds a tensor view by reinterpreting `data` as quantized blocks.
    ///
    /// # Errors
    /// Propagates `slice_from_bytes` failures for ill-sized `data`.
    pub fn from_raw(name: String, shape: Vec<u64>, data: &'a [u8]) -> BonsaiResult<Self> {
        BlockQ1_0G128::slice_from_bytes(data).map(|blocks| Self { name, shape, blocks })
    }

    /// Number of quantized blocks backing this tensor.
    pub fn num_blocks(&self) -> usize {
        self.blocks.len()
    }

    /// Total number of weights (blocks × weights-per-block).
    pub fn element_count(&self) -> usize {
        QK1_0_G128 * self.blocks.len()
    }

    /// Returns the block at `index`; panics when `index` is out of range.
    pub fn block(&self, index: usize) -> &BlockQ1_0G128 {
        &self.blocks[index]
    }

    /// All blocks backing this tensor.
    pub fn blocks(&self) -> &[BlockQ1_0G128] {
        self.blocks
    }

    /// Decodes every weight to `f32`, in block order: a set sign bit yields
    /// `+scale`, a clear bit yields `-scale` (bits are read LSB-first).
    pub fn dequantize_all(&self) -> Vec<f32> {
        let mut values = Vec::with_capacity(self.element_count());
        for block in self.blocks {
            let scale = block.d.to_f32();
            for byte in block.qs {
                for offset in 0..8 {
                    let positive = (byte >> offset) & 1 != 0;
                    values.push(if positive { scale } else { -scale });
                }
            }
        }
        values
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a block with the given scale and raw sign-bit bytes.
    fn make_block(scale: f32, bits: [u8; 16]) -> BlockQ1_0G128 {
        BlockQ1_0G128 {
            d: f16::from_f32(scale),
            qs: bits,
        }
    }

    /// Views a block's memory as raw bytes. The slice borrows the block
    /// directly, so it inherits the block's (correct) 2-byte alignment and is
    /// safe to hand back to `from_bytes`.
    fn block_as_bytes(block: &BlockQ1_0G128) -> &[u8] {
        // SAFETY: any initialized bytes of a repr(C) struct are valid `u8`s,
        // and the compile-time assert guarantees the size is BLOCK_SIZE_BYTES.
        unsafe {
            std::slice::from_raw_parts(block as *const BlockQ1_0G128 as *const u8, BLOCK_SIZE_BYTES)
        }
    }

    #[test]
    fn block_size_is_18_bytes() {
        assert_eq!(std::mem::size_of::<BlockQ1_0G128>(), 18);
    }

    #[test]
    fn all_ones_dequantize_to_positive() {
        let block = make_block(2.0, [0xFF; 16]);
        for i in 0..128 {
            assert!(block.sign_bit(i));
            assert!((block.weight(i) - 2.0).abs() < 0.01);
        }
    }

    #[test]
    fn all_zeros_dequantize_to_negative() {
        let block = make_block(3.0, [0x00; 16]);
        for i in 0..128 {
            assert!(!block.sign_bit(i));
            assert!((block.weight(i) + 3.0).abs() < 0.01);
        }
    }

    #[test]
    fn alternating_bits() {
        // 0xAA = 0b1010_1010: even bit positions clear, odd positions set.
        let block = make_block(1.0, [0xAA; 16]);
        for i in 0..128 {
            if i % 2 == 0 {
                assert!(!block.sign_bit(i), "bit {i} should be 0");
            } else {
                assert!(block.sign_bit(i), "bit {i} should be 1");
            }
        }
    }

    #[test]
    fn from_bytes_roundtrip() {
        let block = make_block(1.5, [0xFF; 16]);
        let bytes = block_as_bytes(&block);
        let parsed = BlockQ1_0G128::from_bytes(bytes).expect("block parse should succeed");
        assert_eq!(parsed, &block);
    }

    #[test]
    fn one_bit_tensor_dequantize() {
        // Back the byte buffer with a Vec<BlockQ1_0G128> so the bytes satisfy
        // the block's 2-byte alignment. The previous version copied into a
        // Vec<u8>, which only guarantees 1-byte alignment — reinterpreting
        // that as blocks is undefined behavior.
        let blocks = vec![make_block(2.0, [0xFF; 16])];
        let bytes: &[u8] = unsafe {
            // SAFETY: `blocks` holds exactly one fully-initialized block of
            // BLOCK_SIZE_BYTES bytes, and the slice borrows `blocks`.
            std::slice::from_raw_parts(blocks.as_ptr() as *const u8, BLOCK_SIZE_BYTES)
        };
        let tensor = OneBitTensor::from_raw("test".to_string(), vec![128], bytes)
            .expect("tensor creation should succeed");
        assert_eq!(tensor.num_blocks(), 1);
        assert_eq!(tensor.element_count(), 128);
        let values = tensor.dequantize_all();
        for &v in &values {
            assert!((v - 2.0).abs() < 0.01);
        }
    }
}