1use half::f16;
7
8use crate::error::{BonsaiError, BonsaiResult};
9
/// Number of 1-bit weights stored in a single quantization block.
pub const QK1_0_G128: usize = 128;

/// Size in bytes of one serialized block: a 2-byte scale followed by one sign bit per weight.
pub const BLOCK_SIZE_BYTES: usize = std::mem::size_of::<u16>() + QK1_0_G128 / 8;
15
16#[derive(Debug, Clone, Copy, PartialEq)]
24#[repr(C)]
25pub struct BlockQ1_0G128 {
26 pub d: f16,
28 pub qs: [u8; QK1_0_G128 / 8],
30}
31
32const _: () = assert!(std::mem::size_of::<BlockQ1_0G128>() == BLOCK_SIZE_BYTES);
33
34impl BlockQ1_0G128 {
35 pub fn from_bytes(data: &[u8]) -> BonsaiResult<&Self> {
37 if data.len() < BLOCK_SIZE_BYTES {
38 return Err(BonsaiError::InvalidBlockSize { actual: data.len() });
39 }
40 let ptr = data.as_ptr() as *const BlockQ1_0G128;
43 Ok(unsafe { &*ptr })
44 }
45
46 pub fn slice_from_bytes(data: &[u8]) -> BonsaiResult<&[Self]> {
48 if data.len() % BLOCK_SIZE_BYTES != 0 {
49 return Err(BonsaiError::InvalidBlockSize { actual: data.len() });
50 }
51 let count = data.len() / BLOCK_SIZE_BYTES;
52 let ptr = data.as_ptr() as *const BlockQ1_0G128;
53 Ok(unsafe { std::slice::from_raw_parts(ptr, count) })
55 }
56
57 #[inline]
60 pub fn sign_bit(&self, i: usize) -> bool {
61 debug_assert!(i < QK1_0_G128);
62 let byte_index = i / 8;
63 let bit_offset = i % 8;
64 (self.qs[byte_index] >> bit_offset) & 1 != 0
65 }
66
67 #[inline]
69 pub fn weight(&self, i: usize) -> f32 {
70 let d = self.d.to_f32();
71 if self.sign_bit(i) {
72 d
73 } else {
74 -d
75 }
76 }
77}
78
79#[derive(Debug)]
84pub struct OneBitTensor<'a> {
85 pub name: String,
87 pub shape: Vec<u64>,
89 blocks: &'a [BlockQ1_0G128],
91}
92
93impl<'a> OneBitTensor<'a> {
94 pub fn from_raw(name: String, shape: Vec<u64>, data: &'a [u8]) -> BonsaiResult<Self> {
96 let blocks = BlockQ1_0G128::slice_from_bytes(data)?;
97 Ok(Self {
98 name,
99 shape,
100 blocks,
101 })
102 }
103
104 pub fn num_blocks(&self) -> usize {
106 self.blocks.len()
107 }
108
109 pub fn element_count(&self) -> usize {
111 self.blocks.len() * QK1_0_G128
112 }
113
114 pub fn block(&self, index: usize) -> &BlockQ1_0G128 {
116 &self.blocks[index]
117 }
118
119 pub fn blocks(&self) -> &[BlockQ1_0G128] {
121 self.blocks
122 }
123
124 pub fn dequantize_all(&self) -> Vec<f32> {
129 let n = self.element_count();
130 let mut output = vec![0.0f32; n];
131 for (i, block) in self.blocks.iter().enumerate() {
132 let d = block.d.to_f32();
133 let base = i * QK1_0_G128;
134 for j in 0..QK1_0_G128 {
135 let byte_index = j / 8;
136 let bit_offset = j % 8;
137 let bit = (block.qs[byte_index] >> bit_offset) & 1;
138 output[base + j] = if bit != 0 { d } else { -d };
139 }
140 }
141 output
142 }
143}
144
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a block with the given scale and packed sign bits.
    fn make_block(scale: f32, bits: [u8; 16]) -> BlockQ1_0G128 {
        BlockQ1_0G128 {
            d: f16::from_f32(scale),
            qs: bits,
        }
    }

    /// Borrows the raw bytes of `block`.
    ///
    /// Borrowing (rather than copying into a fresh buffer) keeps the bytes at the block's own
    /// address, so they are guaranteed to be aligned for `BlockQ1_0G128` when parsed back.
    fn as_bytes(block: &BlockQ1_0G128) -> &[u8] {
        // SAFETY: `block` is a valid reference covering `size_of::<BlockQ1_0G128>()` bytes,
        // which equals `BLOCK_SIZE_BYTES`, so the struct has no padding after its 2-byte scale
        // and 16 bit-field bytes; every byte is initialized and `u8` has alignment 1.
        unsafe {
            std::slice::from_raw_parts(
                (block as *const BlockQ1_0G128).cast::<u8>(),
                BLOCK_SIZE_BYTES,
            )
        }
    }

    #[test]
    fn block_size_is_18_bytes() {
        assert_eq!(std::mem::size_of::<BlockQ1_0G128>(), 18);
        assert_eq!(BLOCK_SIZE_BYTES, 18);
    }

    #[test]
    fn all_ones_dequantize_to_positive() {
        let block = make_block(2.0, [0xFF; 16]);
        for i in 0..QK1_0_G128 {
            assert!(block.sign_bit(i));
            assert!((block.weight(i) - 2.0).abs() < 0.01);
        }
    }

    #[test]
    fn all_zeros_dequantize_to_negative() {
        let block = make_block(3.0, [0x00; 16]);
        for i in 0..QK1_0_G128 {
            assert!(!block.sign_bit(i));
            assert!((block.weight(i) + 3.0).abs() < 0.01);
        }
    }

    #[test]
    fn alternating_bits() {
        // 0xAA = 0b1010_1010: odd bit positions are set, even ones are clear.
        let block = make_block(1.0, [0xAA; 16]);
        for i in 0..QK1_0_G128 {
            if i % 2 == 0 {
                assert!(!block.sign_bit(i), "bit {i} should be 0");
                assert!((block.weight(i) + 1.0).abs() < 0.01);
            } else {
                assert!(block.sign_bit(i), "bit {i} should be 1");
                assert!((block.weight(i) - 1.0).abs() < 0.01);
            }
        }
    }

    #[test]
    fn from_bytes_roundtrip() {
        let block = make_block(1.5, [0xFF; 16]);
        let parsed =
            BlockQ1_0G128::from_bytes(as_bytes(&block)).expect("block parse should succeed");
        assert_eq!(parsed, &block);
    }

    #[test]
    fn from_bytes_rejects_short_input() {
        let short = [0u8; BLOCK_SIZE_BYTES - 1];
        assert!(matches!(
            BlockQ1_0G128::from_bytes(&short),
            Err(BonsaiError::InvalidBlockSize { actual }) if actual == BLOCK_SIZE_BYTES - 1
        ));
    }

    #[test]
    fn slice_from_bytes_rejects_partial_block() {
        let ragged = [0u8; BLOCK_SIZE_BYTES + 1];
        assert!(matches!(
            BlockQ1_0G128::slice_from_bytes(&ragged),
            Err(BonsaiError::InvalidBlockSize { actual }) if actual == BLOCK_SIZE_BYTES + 1
        ));
    }

    #[test]
    fn one_bit_tensor_dequantize() {
        let block = make_block(2.0, [0xFF; 16]);
        let tensor = OneBitTensor::from_raw("test".to_string(), vec![128], as_bytes(&block))
            .expect("tensor creation should succeed");
        assert_eq!(tensor.num_blocks(), 1);
        assert_eq!(tensor.element_count(), 128);

        let values = tensor.dequantize_all();
        assert_eq!(values.len(), 128);
        for &v in &values {
            assert!((v - 2.0).abs() < 0.01);
        }
    }
}