1use half::f16;
8
9use crate::error::{BonsaiError, BonsaiResult};
10
/// Number of weights packed into one [`BlockTQ2_0_g128`].
pub const QK_TQ2_0_G128: usize = 128;

/// Number of weights packed into one [`BlockTQ2_0`].
pub const QK_TQ2_0: usize = 256;

/// Size in bytes of one [`BlockTQ2_0_g128`]: `QK_TQ2_0_G128` 2-bit codes
/// (four per byte) followed by a 2-byte `f16` scale.
pub const BLOCK_TQ2_0_G128_BYTES: usize = QK_TQ2_0_G128 / 4 + 2;

/// Size in bytes of one [`BlockTQ2_0`]: `QK_TQ2_0` 2-bit codes (four per
/// byte) followed by a 2-byte `f16` scale.
pub const BLOCK_TQ2_0_BYTES: usize = QK_TQ2_0 / 4 + 2;
26
/// One ternary weight as stored in the 2-bit packed TQ2_0 block formats.
///
/// The discriminant is the stored bit pattern; the pattern `0b11` has no
/// variant.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
#[repr(u8)]
pub enum TernaryCode {
    /// Weight `-1`, stored as `0b00`.
    Neg = 0b00,
    /// Weight `0`, stored as `0b01`.
    Zero = 0b01,
    /// Weight `+1`, stored as `0b10`.
    Pos = 0b10,
}

impl TernaryCode {
    /// Returns the signed weight (`-1`, `0`, or `+1`) this code stands for.
    pub fn to_i8(self) -> i8 {
        // Each bit pattern is the ternary value shifted up by one, so undo it.
        self as i8 - 1
    }
}
53
54#[derive(Debug, Clone, Copy, PartialEq)]
63#[repr(C)]
64pub struct BlockTQ2_0_g128 {
65 pub qs: [u8; 32],
67 pub d: f16,
69}
70
71const _: () = assert!(std::mem::size_of::<BlockTQ2_0_g128>() == BLOCK_TQ2_0_G128_BYTES);
72
73impl BlockTQ2_0_g128 {
74 pub fn dequant(blocks: &[Self], output: &mut [f32]) -> BonsaiResult<()> {
78 let expected_len = blocks.len() * QK_TQ2_0_G128;
79 if output.len() < expected_len {
80 return Err(BonsaiError::KQuantError {
81 reason: format!(
82 "TQ2_0_g128 dequant: output len {} < expected {}",
83 output.len(),
84 expected_len
85 ),
86 });
87 }
88 for (block_idx, block) in blocks.iter().enumerate() {
89 let d = block.d.to_f32();
90 let base = block_idx * QK_TQ2_0_G128;
91 for j in 0..QK_TQ2_0_G128 {
92 let byte_idx = j / 4;
93 let lane = j % 4;
94 let code_val = Self::ternary_decode(block.qs[byte_idx], lane);
95 output[base + j] = d * (code_val as f32);
96 }
97 }
98 Ok(())
99 }
100
101 pub fn quantize(input: &[f32]) -> BonsaiResult<Vec<Self>> {
105 if input.len() % QK_TQ2_0_G128 != 0 {
106 return Err(BonsaiError::KQuantError {
107 reason: format!(
108 "TQ2_0_g128 quantize: input len {} not a multiple of {}",
109 input.len(),
110 QK_TQ2_0_G128
111 ),
112 });
113 }
114 let num_blocks = input.len() / QK_TQ2_0_G128;
115 let mut blocks = Vec::with_capacity(num_blocks);
116
117 for block_idx in 0..num_blocks {
118 let base = block_idx * QK_TQ2_0_G128;
119 let chunk = &input[base..base + QK_TQ2_0_G128];
120
121 let absmax = chunk
122 .iter()
123 .copied()
124 .fold(0.0f32, |acc, x| acc.max(x.abs()));
125
126 let mut qs = [0u8; 32];
127
128 if absmax == 0.0 {
129 for b in qs.iter_mut() {
131 *b = 0x55;
132 }
133 blocks.push(BlockTQ2_0_g128 { qs, d: f16::ZERO });
134 continue;
135 }
136
137 let threshold = 0.5 * absmax;
138 for (j, &x) in chunk.iter().enumerate() {
139 let code: u8 = if x >= threshold {
140 TernaryCode::Pos as u8 } else if x <= -threshold {
142 TernaryCode::Neg as u8 } else {
144 TernaryCode::Zero as u8 };
146 let byte_idx = j / 4;
147 let shift = (j % 4) * 2;
148 qs[byte_idx] |= code << shift;
149 }
150
151 blocks.push(BlockTQ2_0_g128 {
152 qs,
153 d: f16::from_f32(absmax),
154 });
155 }
156 Ok(blocks)
157 }
158
159 pub fn slice_from_bytes(data: &[u8]) -> BonsaiResult<&[Self]> {
163 if data.len() % BLOCK_TQ2_0_G128_BYTES != 0 {
164 return Err(BonsaiError::KQuantError {
165 reason: format!(
166 "TQ2_0_g128 slice_from_bytes: byte len {} not a multiple of {}",
167 data.len(),
168 BLOCK_TQ2_0_G128_BYTES
169 ),
170 });
171 }
172 let align = std::mem::align_of::<Self>();
173 if data.as_ptr().align_offset(align) != 0 {
174 return Err(BonsaiError::KQuantError {
175 reason: format!(
176 "TQ2_0_g128 slice_from_bytes: pointer not {}-byte aligned",
177 align
178 ),
179 });
180 }
181 let count = data.len() / BLOCK_TQ2_0_G128_BYTES;
182 let ptr = data.as_ptr() as *const Self;
183 Ok(unsafe { std::slice::from_raw_parts(ptr, count) })
186 }
187
188 pub fn ternary_decode(byte: u8, lane: usize) -> i8 {
192 let shift = lane * 2;
193 let code = (byte >> shift) & 0x03;
194 match code {
195 0b00 => -1,
196 0b01 => 0,
197 0b10 => 1,
198 _ => 0, }
200 }
201}
202
203#[derive(Debug, Clone, Copy, PartialEq)]
212#[repr(C)]
213pub struct BlockTQ2_0 {
214 pub qs: [u8; 64],
216 pub d: f16,
218}
219
220const _: () = assert!(std::mem::size_of::<BlockTQ2_0>() == BLOCK_TQ2_0_BYTES);
221
222impl BlockTQ2_0 {
223 pub fn dequant(blocks: &[Self], output: &mut [f32]) -> BonsaiResult<()> {
227 let expected_len = blocks.len() * QK_TQ2_0;
228 if output.len() < expected_len {
229 return Err(BonsaiError::KQuantError {
230 reason: format!(
231 "TQ2_0 dequant: output len {} < expected {}",
232 output.len(),
233 expected_len
234 ),
235 });
236 }
237 for (block_idx, block) in blocks.iter().enumerate() {
238 let d = block.d.to_f32();
239 let base = block_idx * QK_TQ2_0;
240 for j in 0..QK_TQ2_0 {
241 let byte_idx = j / 4;
242 let lane = j % 4;
243 let code_val = ternary_decode_g256(block.qs[byte_idx], lane);
244 output[base + j] = d * (code_val as f32);
245 }
246 }
247 Ok(())
248 }
249
250 pub fn quantize(input: &[f32]) -> BonsaiResult<Vec<Self>> {
254 if input.len() % QK_TQ2_0 != 0 {
255 return Err(BonsaiError::KQuantError {
256 reason: format!(
257 "TQ2_0 quantize: input len {} not a multiple of {}",
258 input.len(),
259 QK_TQ2_0
260 ),
261 });
262 }
263 let num_blocks = input.len() / QK_TQ2_0;
264 let mut blocks = Vec::with_capacity(num_blocks);
265
266 for block_idx in 0..num_blocks {
267 let base = block_idx * QK_TQ2_0;
268 let chunk = &input[base..base + QK_TQ2_0];
269
270 let absmax = chunk
271 .iter()
272 .copied()
273 .fold(0.0f32, |acc, x| acc.max(x.abs()));
274
275 let mut qs = [0u8; 64];
276
277 if absmax == 0.0 {
278 for b in qs.iter_mut() {
279 *b = 0x55;
280 }
281 blocks.push(BlockTQ2_0 { qs, d: f16::ZERO });
282 continue;
283 }
284
285 let threshold = 0.5 * absmax;
286 for (j, &x) in chunk.iter().enumerate() {
287 let code: u8 = if x >= threshold {
288 TernaryCode::Pos as u8
289 } else if x <= -threshold {
290 TernaryCode::Neg as u8
291 } else {
292 TernaryCode::Zero as u8
293 };
294 let byte_idx = j / 4;
295 let shift = (j % 4) * 2;
296 qs[byte_idx] |= code << shift;
297 }
298
299 blocks.push(BlockTQ2_0 {
300 qs,
301 d: f16::from_f32(absmax),
302 });
303 }
304 Ok(blocks)
305 }
306
307 pub fn slice_from_bytes(data: &[u8]) -> BonsaiResult<&[Self]> {
311 if data.len() % BLOCK_TQ2_0_BYTES != 0 {
312 return Err(BonsaiError::KQuantError {
313 reason: format!(
314 "TQ2_0 slice_from_bytes: byte len {} not a multiple of {}",
315 data.len(),
316 BLOCK_TQ2_0_BYTES
317 ),
318 });
319 }
320 let align = std::mem::align_of::<Self>();
321 if data.as_ptr().align_offset(align) != 0 {
322 return Err(BonsaiError::KQuantError {
323 reason: format!("TQ2_0 slice_from_bytes: pointer not {}-byte aligned", align),
324 });
325 }
326 let count = data.len() / BLOCK_TQ2_0_BYTES;
327 let ptr = data.as_ptr() as *const Self;
328 Ok(unsafe { std::slice::from_raw_parts(ptr, count) })
331 }
332}
333
/// Decodes the 2-bit ternary code held in `lane` of `byte` for [`BlockTQ2_0`].
///
/// Lane `k` occupies bits `2k..2k + 2`. The patterns `0b00`, `0b01` and `0b10`
/// decode to `-1`, `0` and `+1`; the reserved pattern `0b11` decodes to `0`.
fn ternary_decode_g256(byte: u8, lane: usize) -> i8 {
    // Indexed directly by the raw 2-bit pattern.
    const DECODE: [i8; 4] = [-1, 0, 1, 0];
    let bits = (byte >> (lane * 2)) & 0b11;
    DECODE[usize::from(bits)]
}
347
#[cfg(test)]
mod tests {
    use super::*;

    /// Mean squared error between two equal-length slices.
    fn mse(expected: &[f32], actual: &[f32]) -> f32 {
        let sum: f32 = expected
            .iter()
            .zip(actual)
            .map(|(a, b)| (a - b) * (a - b))
            .sum();
        sum / expected.len() as f32
    }

    #[test]
    fn tq2_0_g128_block_size_correct() {
        assert_eq!(
            std::mem::size_of::<BlockTQ2_0_g128>(),
            BLOCK_TQ2_0_G128_BYTES
        );
        assert_eq!(BLOCK_TQ2_0_G128_BYTES, 34);
    }

    #[test]
    fn tq2_0_block_size_correct() {
        assert_eq!(std::mem::size_of::<BlockTQ2_0>(), BLOCK_TQ2_0_BYTES);
        assert_eq!(BLOCK_TQ2_0_BYTES, 66);
    }

    #[test]
    fn ternary_code_to_i8() {
        assert_eq!(TernaryCode::Neg.to_i8(), -1);
        assert_eq!(TernaryCode::Zero.to_i8(), 0);
        assert_eq!(TernaryCode::Pos.to_i8(), 1);
    }

    #[test]
    fn tq2_0_g128_roundtrip_uniform() {
        // Inputs already on the ternary grid {-0.5, 0, 0.5} round-trip exactly.
        let input: Vec<f32> = (0..QK_TQ2_0_G128)
            .map(|i| match i % 3 {
                0 => 0.5,
                1 => -0.5,
                _ => 0.0,
            })
            .collect();
        let blocks = BlockTQ2_0_g128::quantize(&input).expect("quantize should succeed");
        let mut output = vec![0.0f32; QK_TQ2_0_G128];
        BlockTQ2_0_g128::dequant(&blocks, &mut output).expect("dequant should succeed");
        let err = mse(&input, &output);
        assert!(err < 1e-3, "MSE {err} too high");
    }

    #[test]
    fn tq2_0_g128_all_zero_input() {
        let input = [0.0f32; QK_TQ2_0_G128];
        let blocks = BlockTQ2_0_g128::quantize(&input).expect("quantize should succeed");
        assert_eq!(blocks.len(), 1);
        assert_eq!(blocks[0].d, f16::ZERO);
        assert!(
            blocks[0].qs.iter().all(|&b| b == 0x55),
            "every lane should hold the Zero code"
        );
        // Pre-fill with non-zero values so the test proves dequant writes them.
        let mut output = vec![1.0f32; QK_TQ2_0_G128];
        BlockTQ2_0_g128::dequant(&blocks, &mut output).expect("dequant should succeed");
        assert!(
            output.iter().all(|&v| v == 0.0),
            "all outputs should be zero"
        );
    }

    #[test]
    fn tq2_0_g128_all_positive() {
        let input = [1.0f32; QK_TQ2_0_G128];
        let blocks = BlockTQ2_0_g128::quantize(&input).expect("quantize should succeed");
        assert_eq!(blocks.len(), 1);
        assert!(
            (blocks[0].d.to_f32() - 1.0).abs() < 1e-3,
            "d should be ~1.0"
        );
        // 0xAA == 0b10_10_10_10: `Pos` in every lane.
        for &b in &blocks[0].qs {
            assert_eq!(b, 0xAA, "all bytes should be 0xAA for all-positive");
        }
    }

    #[test]
    fn tq2_0_g128_all_negative() {
        let input = [-1.0f32; QK_TQ2_0_G128];
        let blocks = BlockTQ2_0_g128::quantize(&input).expect("quantize should succeed");
        assert_eq!(blocks.len(), 1);
        assert!(
            (blocks[0].d.to_f32() - 1.0).abs() < 1e-3,
            "d should be ~1.0"
        );
        // 0x00: `Neg` in every lane.
        for &b in &blocks[0].qs {
            assert_eq!(b, 0x00, "all bytes should be 0x00 for all-negative");
        }
    }

    #[test]
    fn tq2_0_g128_mixed_threshold() {
        // absmax = 2.0, so the threshold is 1.0 and +/-0.9 must quantize to zero.
        let pattern = [2.0f32, 0.9, 0.0, -0.9, -2.0];
        let expected_pattern = [2.0f32, 0.0, 0.0, 0.0, -2.0];
        let input: Vec<f32> = (0..QK_TQ2_0_G128).map(|i| pattern[i % 5]).collect();
        let blocks = BlockTQ2_0_g128::quantize(&input).expect("quantize should succeed");
        let mut output = vec![0.0f32; QK_TQ2_0_G128];
        BlockTQ2_0_g128::dequant(&blocks, &mut output).expect("dequant should succeed");

        for (i, &got) in output.iter().enumerate() {
            let expected = expected_pattern[i % 5];
            assert!(
                (got - expected).abs() < 1e-3,
                "index {i}: expected {expected}, got {got}"
            );
        }
    }

    #[test]
    fn tq2_0_g128_quantize_rejects_partial_block() {
        assert!(BlockTQ2_0_g128::quantize(&[1.0f32; 100]).is_err());
    }

    #[test]
    fn tq2_0_g128_dequant_rejects_short_output() {
        let blocks = BlockTQ2_0_g128::quantize(&[1.0f32; QK_TQ2_0_G128])
            .expect("quantize should succeed");
        let mut output = vec![0.0f32; QK_TQ2_0_G128 - 1];
        assert!(BlockTQ2_0_g128::dequant(&blocks, &mut output).is_err());
    }

    #[test]
    fn tq2_0_g128_slice_from_bytes_bad_length() {
        // 35 bytes is not a whole number of 34-byte blocks.
        let data = vec![0u8; 35];
        let result = BlockTQ2_0_g128::slice_from_bytes(&data);
        assert!(result.is_err(), "35-byte slice should fail");
    }

    #[test]
    fn tq2_0_g128_slice_from_bytes_misaligned() {
        let align = std::mem::align_of::<BlockTQ2_0_g128>();
        assert!(align > 1, "f16 field should force alignment > 1");
        // Choose an offset that makes the view start on an odd address,
        // which is misaligned for any alignment > 1.
        let buf = vec![0u8; BLOCK_TQ2_0_G128_BYTES + 1];
        let offset = if (buf.as_ptr() as usize) % 2 == 0 { 1 } else { 0 };
        let data = &buf[offset..offset + BLOCK_TQ2_0_G128_BYTES];
        assert!(
            BlockTQ2_0_g128::slice_from_bytes(data).is_err(),
            "misaligned slice should fail"
        );
    }

    #[test]
    fn tq2_0_g128_slice_from_bytes_aligned() {
        let block = BlockTQ2_0_g128 {
            qs: [0u8; 32],
            d: f16::from_f32(1.0),
        };
        // SAFETY: `block` is a live, fully initialised `repr(C)` value with no
        // padding, so its storage can be viewed as `BLOCK_TQ2_0_G128_BYTES` bytes.
        let bytes: &[u8] = unsafe {
            std::slice::from_raw_parts(
                &block as *const BlockTQ2_0_g128 as *const u8,
                BLOCK_TQ2_0_G128_BYTES,
            )
        };
        let result =
            BlockTQ2_0_g128::slice_from_bytes(bytes).expect("aligned slice should succeed");
        assert_eq!(result.len(), 1);
        assert_eq!(result[0].d, f16::from_f32(1.0));
    }

    #[test]
    fn tq2_0_roundtrip_ramp() {
        // Deterministic ramp over [-1, 1); ternary rounding error stays bounded.
        let input: Vec<f32> = (0..QK_TQ2_0)
            .map(|i| ((i as f32) / 128.0 - 1.0).clamp(-1.0, 1.0))
            .collect();
        let blocks = BlockTQ2_0::quantize(&input).expect("quantize should succeed");
        let mut output = vec![0.0f32; QK_TQ2_0];
        BlockTQ2_0::dequant(&blocks, &mut output).expect("dequant should succeed");
        let err = mse(&input, &output);
        assert!(err < 0.15, "MSE {err} too high for TQ2_0 roundtrip");
    }

    #[test]
    fn tq2_0_all_zero_input() {
        let input = [0.0f32; QK_TQ2_0];
        let blocks = BlockTQ2_0::quantize(&input).expect("quantize should succeed");
        assert_eq!(blocks.len(), 1);
        assert_eq!(blocks[0].d, f16::ZERO);
        assert!(blocks[0].qs.iter().all(|&b| b == 0x55));
        let mut output = vec![1.0f32; QK_TQ2_0];
        BlockTQ2_0::dequant(&blocks, &mut output).expect("dequant should succeed");
        assert!(output.iter().all(|&v| v == 0.0));
    }

    #[test]
    fn tq2_0_all_positive() {
        let input = [1.0f32; QK_TQ2_0];
        let blocks = BlockTQ2_0::quantize(&input).expect("quantize should succeed");
        assert_eq!(blocks.len(), 1);
        assert!((blocks[0].d.to_f32() - 1.0).abs() < 1e-3);
        assert!(blocks[0].qs.iter().all(|&b| b == 0xAA));
    }

    #[test]
    fn tq2_0_quantize_rejects_partial_block() {
        assert!(BlockTQ2_0::quantize(&[1.0f32; 100]).is_err());
    }

    #[test]
    fn tq2_0_dequant_rejects_short_output() {
        let blocks =
            BlockTQ2_0::quantize(&[1.0f32; QK_TQ2_0]).expect("quantize should succeed");
        let mut output = vec![0.0f32; QK_TQ2_0 - 1];
        assert!(BlockTQ2_0::dequant(&blocks, &mut output).is_err());
    }

    #[test]
    fn tq2_0_slice_from_bytes_bad_length() {
        let data = vec![0u8; BLOCK_TQ2_0_BYTES + 1];
        assert!(BlockTQ2_0::slice_from_bytes(&data).is_err());
    }

    #[test]
    fn tq2_0_slice_from_bytes_aligned() {
        let block = BlockTQ2_0 {
            qs: [0xAA; 64],
            d: f16::from_f32(2.0),
        };
        // SAFETY: `block` is a live, fully initialised `repr(C)` value with no
        // padding, so its storage can be viewed as `BLOCK_TQ2_0_BYTES` bytes.
        let bytes: &[u8] = unsafe {
            std::slice::from_raw_parts(
                &block as *const BlockTQ2_0 as *const u8,
                BLOCK_TQ2_0_BYTES,
            )
        };
        let result = BlockTQ2_0::slice_from_bytes(bytes).expect("aligned slice should succeed");
        assert_eq!(result, &[block][..]);
    }

    #[test]
    fn ternary_decode_all_lanes() {
        // Lanes from least significant bits: 00, 11, 01, 10.
        let byte: u8 = 0b10011100;
        assert_eq!(
            BlockTQ2_0_g128::ternary_decode(byte, 0),
            -1,
            "lane 0: 0b00 → -1"
        );
        assert_eq!(
            BlockTQ2_0_g128::ternary_decode(byte, 1),
            0,
            "lane 1: 0b11 → 0 (reserved)"
        );
        assert_eq!(
            BlockTQ2_0_g128::ternary_decode(byte, 2),
            0,
            "lane 2: 0b01 → 0"
        );
        assert_eq!(
            BlockTQ2_0_g128::ternary_decode(byte, 3),
            1,
            "lane 3: 0b10 → +1"
        );
    }
}