Skip to main content

oxibonsai_core/
tensor.rs

//! `Q1_0_g128` tensor types and 1-bit data access.
//!
//! Defines the [`BlockQ1_0G128`] structure matching the PrismML GGUF format
//! and the [`OneBitTensor`] wrapper for efficient tensor access.

6use half::f16;
7
8use crate::error::{BonsaiError, BonsaiResult};
9
/// Number of weights per `Q1_0_g128` block.
pub const QK1_0_G128: usize = 128;

// The sign bits are packed eight per byte, so the group size must divide evenly.
const _: () = assert!(QK1_0_G128 % 8 == 0, "QK1_0_G128 must be a multiple of 8");

/// Size of a `Q1_0_g128` block in bytes: a 2-byte FP16 scale (same size as `u16`)
/// followed by `QK1_0_G128 / 8` = 16 bytes of packed sign bits, 18 bytes in total.
pub const BLOCK_SIZE_BYTES: usize = std::mem::size_of::<u16>() + QK1_0_G128 / 8;

16/// A single Q1\_0\_g128 quantized block.
17///
18/// Layout (18 bytes total):
19/// - `d`: FP16 scale factor (2 bytes) — shared by all 128 weights
20/// - `qs`: 128 sign bits packed into 16 bytes
21///
22/// Weight reconstruction: `w[i] = bit[i] ? +d : -d`
23#[derive(Debug, Clone, Copy, PartialEq)]
24#[repr(C)]
25pub struct BlockQ1_0G128 {
26    /// Scale factor (delta), FP16.
27    pub d: f16,
28    /// 128 sign bits packed into 16 bytes.
29    pub qs: [u8; QK1_0_G128 / 8],
30}
31
32const _: () = assert!(std::mem::size_of::<BlockQ1_0G128>() == BLOCK_SIZE_BYTES);
33
34impl BlockQ1_0G128 {
35    /// Interpret a raw byte slice as a block reference (zero-copy).
36    pub fn from_bytes(data: &[u8]) -> BonsaiResult<&Self> {
37        if data.len() < BLOCK_SIZE_BYTES {
38            return Err(BonsaiError::InvalidBlockSize { actual: data.len() });
39        }
40        // SAFETY: BlockQ1_0G128 is repr(C) with known layout, and we've validated
41        // the minimum size. The f16 type is repr(transparent) over u16.
42        let ptr = data.as_ptr() as *const BlockQ1_0G128;
43        Ok(unsafe { &*ptr })
44    }
45
46    /// Interpret a raw byte slice as a slice of blocks (zero-copy).
47    pub fn slice_from_bytes(data: &[u8]) -> BonsaiResult<&[Self]> {
48        if data.len() % BLOCK_SIZE_BYTES != 0 {
49            return Err(BonsaiError::InvalidBlockSize { actual: data.len() });
50        }
51        let count = data.len() / BLOCK_SIZE_BYTES;
52        let ptr = data.as_ptr() as *const BlockQ1_0G128;
53        // SAFETY: Same as above, plus we've checked alignment to block size.
54        Ok(unsafe { std::slice::from_raw_parts(ptr, count) })
55    }
56
57    /// Get the sign bit for weight at index `i` (0..127).
58    /// Returns `true` for +d, `false` for -d.
59    #[inline]
60    pub fn sign_bit(&self, i: usize) -> bool {
61        debug_assert!(i < QK1_0_G128);
62        let byte_index = i / 8;
63        let bit_offset = i % 8;
64        (self.qs[byte_index] >> bit_offset) & 1 != 0
65    }
66
67    /// Get the reconstructed weight value at index `i`.
68    #[inline]
69    pub fn weight(&self, i: usize) -> f32 {
70        let d = self.d.to_f32();
71        if self.sign_bit(i) {
72            d
73        } else {
74            -d
75        }
76    }
77}
78
79/// A 1-bit tensor backed by Q1\_0\_g128 blocks.
80///
81/// This wraps raw GGUF tensor data and provides typed access to blocks
82/// without copying or dequantizing the entire tensor.
83#[derive(Debug)]
84pub struct OneBitTensor<'a> {
85    /// Tensor name.
86    pub name: String,
87    /// Shape dimensions.
88    pub shape: Vec<u64>,
89    /// Raw block data.
90    blocks: &'a [BlockQ1_0G128],
91}
92
93impl<'a> OneBitTensor<'a> {
94    /// Create a 1-bit tensor from raw GGUF tensor data bytes.
95    pub fn from_raw(name: String, shape: Vec<u64>, data: &'a [u8]) -> BonsaiResult<Self> {
96        let blocks = BlockQ1_0G128::slice_from_bytes(data)?;
97        Ok(Self {
98            name,
99            shape,
100            blocks,
101        })
102    }
103
104    /// Number of blocks in this tensor.
105    pub fn num_blocks(&self) -> usize {
106        self.blocks.len()
107    }
108
109    /// Total number of elements (weights) in this tensor.
110    pub fn element_count(&self) -> usize {
111        self.blocks.len() * QK1_0_G128
112    }
113
114    /// Get a reference to the block at the given index.
115    pub fn block(&self, index: usize) -> &BlockQ1_0G128 {
116        &self.blocks[index]
117    }
118
119    /// Get all blocks as a slice.
120    pub fn blocks(&self) -> &[BlockQ1_0G128] {
121        self.blocks
122    }
123
124    /// Dequantize all blocks to FP32 values.
125    ///
126    /// For the full tensor, this allocates and fills an output vector.
127    /// For per-operation dequantization, use the kernel crate instead.
128    pub fn dequantize_all(&self) -> Vec<f32> {
129        let n = self.element_count();
130        let mut output = vec![0.0f32; n];
131        for (i, block) in self.blocks.iter().enumerate() {
132            let d = block.d.to_f32();
133            let base = i * QK1_0_G128;
134            for j in 0..QK1_0_G128 {
135                let byte_index = j / 8;
136                let bit_offset = j % 8;
137                let bit = (block.qs[byte_index] >> bit_offset) & 1;
138                output[base + j] = if bit != 0 { d } else { -d };
139            }
140        }
141        output
142    }
143}
144
#[cfg(test)]
mod tests {
    use super::*;

    /// Build a block with the given scale and packed sign bits.
    fn make_block(scale: f32, bits: [u8; 16]) -> BlockQ1_0G128 {
        BlockQ1_0G128 {
            d: f16::from_f32(scale),
            qs: bits,
        }
    }

    /// View a slice of blocks as its raw bytes.
    ///
    /// Borrowing from a `BlockQ1_0G128` slice, rather than copying into a
    /// `Vec<u8>` (which only guarantees alignment 1), keeps the bytes aligned
    /// for the block type.
    fn as_bytes(blocks: &[BlockQ1_0G128]) -> &[u8] {
        // SAFETY: `BlockQ1_0G128` is repr(C) with no padding (its size is asserted
        // to be 18 = 2 + 16), so every byte is initialized. The returned slice
        // borrows `blocks` and covers exactly its memory.
        unsafe {
            std::slice::from_raw_parts(blocks.as_ptr().cast::<u8>(), std::mem::size_of_val(blocks))
        }
    }

    #[test]
    fn block_size_is_18_bytes() {
        assert_eq!(std::mem::size_of::<BlockQ1_0G128>(), 18);
        assert_eq!(BLOCK_SIZE_BYTES, 18);
    }

    #[test]
    fn all_ones_dequantize_to_positive() {
        let block = make_block(2.0, [0xFF; 16]);
        for i in 0..128 {
            assert!(block.sign_bit(i));
            assert!((block.weight(i) - 2.0).abs() < 0.01);
        }
    }

    #[test]
    fn all_zeros_dequantize_to_negative() {
        let block = make_block(3.0, [0x00; 16]);
        for i in 0..128 {
            assert!(!block.sign_bit(i));
            assert!((block.weight(i) + 3.0).abs() < 0.01);
        }
    }

    #[test]
    fn alternating_bits() {
        // 0xAA = 0b1010_1010: bits 1, 3, 5, 7 set; bits 0, 2, 4, 6 clear.
        let block = make_block(1.0, [0xAA; 16]);
        for i in 0..128 {
            let expected = i % 2 == 1;
            assert_eq!(block.sign_bit(i), expected, "bit {i}");
            assert_eq!(block.weight(i), if expected { 1.0 } else { -1.0 }, "weight {i}");
        }
    }

    #[test]
    fn sign_bits_are_lsb_first() {
        let mut bits = [0u8; 16];
        bits[0] = 0x01; // weight 0
        bits[1] = 0x80; // weight 15
        let block = make_block(1.0, bits);
        for i in 0..128 {
            assert_eq!(block.sign_bit(i), i == 0 || i == 15, "bit {i}");
        }
    }

    #[test]
    fn from_bytes_roundtrip() {
        let block = make_block(1.5, [0xFF; 16]);
        let parsed = BlockQ1_0G128::from_bytes(as_bytes(std::slice::from_ref(&block)))
            .expect("block parse should succeed");
        assert_eq!(parsed, &block);
    }

    #[test]
    fn from_bytes_reads_only_first_block() {
        let blocks = [make_block(1.0, [0x0F; 16]), make_block(4.0, [0xF0; 16])];
        let parsed =
            BlockQ1_0G128::from_bytes(as_bytes(&blocks)).expect("block parse should succeed");
        assert_eq!(parsed, &blocks[0]);
    }

    #[test]
    fn from_bytes_rejects_short_input() {
        let block = make_block(1.0, [0x00; 16]);
        let bytes = as_bytes(std::slice::from_ref(&block));
        let err = BlockQ1_0G128::from_bytes(&bytes[..BLOCK_SIZE_BYTES - 1]).unwrap_err();
        assert!(matches!(err, BonsaiError::InvalidBlockSize { actual: 17 }));
    }

    #[test]
    fn slice_from_bytes_rejects_partial_block() {
        let blocks = [make_block(1.0, [0x00; 16]); 2];
        let bytes = as_bytes(&blocks);
        let err = BlockQ1_0G128::slice_from_bytes(&bytes[..bytes.len() - 1]).unwrap_err();
        assert!(matches!(err, BonsaiError::InvalidBlockSize { actual: 35 }));
    }

    #[test]
    fn one_bit_tensor_dequantize() {
        let blocks = [make_block(2.0, [0xFF; 16]), make_block(0.5, [0x00; 16])];
        let tensor = OneBitTensor::from_raw("test".to_string(), vec![256u64], as_bytes(&blocks))
            .expect("tensor creation should succeed");
        assert_eq!(tensor.name, "test");
        assert_eq!(tensor.shape, vec![256u64]);
        assert_eq!(tensor.num_blocks(), 2);
        assert_eq!(tensor.element_count(), 256);
        assert_eq!(tensor.block(1), &blocks[1]);

        let values = tensor.dequantize_all();
        assert_eq!(values.len(), 256);
        // Block order must be preserved: first block +2.0, second block -0.5.
        assert!(values[..128].iter().all(|&v| v == 2.0));
        assert!(values[128..].iter().all(|&v| v == -0.5));
    }
}