Skip to main content

oxibonsai_core/
quant_ternary.rs

1//! Ternary quantization block types for TQ2_0_g128 and TQ2_0 formats.
2//!
3//! Two ternary formats: `BlockTQ2_0_g128` (128 weights, 34 bytes, PrismML)
4//! and `BlockTQ2_0` (256 weights, 66 bytes, llama.cpp compat).
5//! Both use 2-bit coding: `00→-1`, `01→0`, `10→+1`, 4 weights per byte LSB-first.
6
7use half::f16;
8
9use crate::error::{BonsaiError, BonsaiResult};
10
11// ---------------------------------------------------------------------------
12// Constants
13// ---------------------------------------------------------------------------
14
/// Weights covered by one TQ2_0_g128 block.
pub const QK_TQ2_0_G128: usize = 128;

/// Weights covered by one TQ2_0 block.
pub const QK_TQ2_0: usize = 256;

/// Encoded size of one TQ2_0_g128 block: 128 / 4 = 32 packed code bytes plus a
/// 2-byte FP16 scale.
pub const BLOCK_TQ2_0_G128_BYTES: usize = QK_TQ2_0_G128 / 4 + 2;

/// Encoded size of one TQ2_0 block: 256 / 4 = 64 packed code bytes plus a
/// 2-byte FP16 scale.
pub const BLOCK_TQ2_0_BYTES: usize = QK_TQ2_0 / 4 + 2;
26
27// ---------------------------------------------------------------------------
28// TernaryCode
29// ---------------------------------------------------------------------------
30
/// 2-bit code assigned to a single ternary weight.
///
/// Each discriminant is the bit pattern stored in the packed `qs` bytes. The
/// pattern `0b11` has no variant.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
#[repr(u8)]
pub enum TernaryCode {
    /// Weight -1, stored as `0b00`.
    Neg = 0b00,
    /// Weight 0, stored as `0b01`.
    Zero = 0b01,
    /// Weight +1, stored as `0b10`.
    Pos = 0b10,
}

impl TernaryCode {
    /// Signed weight value of this code: `Neg` → -1, `Zero` → 0, `Pos` → +1.
    pub fn to_i8(self) -> i8 {
        // Discriminants are 0, 1, 2, so shifting down by one yields -1, 0, +1.
        self as i8 - 1
    }
}
53
54// ---------------------------------------------------------------------------
55// BlockTQ2_0_g128
56// ---------------------------------------------------------------------------
57
58/// TQ2_0_g128 block: 128 weights at 2 bits each, PrismML format.
59///
60/// Layout (34 bytes): `qs[32]` packed codes + `d` FP16 scale.
61/// Bit coding: `00→-1`, `01→0`, `10→+1`, 4 weights per byte LSB-first.
62#[derive(Debug, Clone, Copy, PartialEq)]
63#[repr(C)]
64pub struct BlockTQ2_0_g128 {
65    /// 128 × 2-bit quantized weights, 4 per byte, LSB-first.
66    pub qs: [u8; 32],
67    /// Block scale (FP16).
68    pub d: f16,
69}
70
71const _: () = assert!(std::mem::size_of::<BlockTQ2_0_g128>() == BLOCK_TQ2_0_G128_BYTES);
72
73impl BlockTQ2_0_g128 {
74    /// Dequantize a slice of TQ2_0_g128 blocks into f32 output.
75    ///
76    /// `output` must have length >= `blocks.len() * 128`.
77    pub fn dequant(blocks: &[Self], output: &mut [f32]) -> BonsaiResult<()> {
78        let expected_len = blocks.len() * QK_TQ2_0_G128;
79        if output.len() < expected_len {
80            return Err(BonsaiError::KQuantError {
81                reason: format!(
82                    "TQ2_0_g128 dequant: output len {} < expected {}",
83                    output.len(),
84                    expected_len
85                ),
86            });
87        }
88        for (block_idx, block) in blocks.iter().enumerate() {
89            let d = block.d.to_f32();
90            let base = block_idx * QK_TQ2_0_G128;
91            for j in 0..QK_TQ2_0_G128 {
92                let byte_idx = j / 4;
93                let lane = j % 4;
94                let code_val = Self::ternary_decode(block.qs[byte_idx], lane);
95                output[base + j] = d * (code_val as f32);
96            }
97        }
98        Ok(())
99    }
100
101    /// Quantize f32 input into TQ2_0_g128 blocks.
102    ///
103    /// Input length must be a multiple of 128.
104    pub fn quantize(input: &[f32]) -> BonsaiResult<Vec<Self>> {
105        if input.len() % QK_TQ2_0_G128 != 0 {
106            return Err(BonsaiError::KQuantError {
107                reason: format!(
108                    "TQ2_0_g128 quantize: input len {} not a multiple of {}",
109                    input.len(),
110                    QK_TQ2_0_G128
111                ),
112            });
113        }
114        let num_blocks = input.len() / QK_TQ2_0_G128;
115        let mut blocks = Vec::with_capacity(num_blocks);
116
117        for block_idx in 0..num_blocks {
118            let base = block_idx * QK_TQ2_0_G128;
119            let chunk = &input[base..base + QK_TQ2_0_G128];
120
121            let absmax = chunk
122                .iter()
123                .copied()
124                .fold(0.0f32, |acc, x| acc.max(x.abs()));
125
126            let mut qs = [0u8; 32];
127
128            if absmax == 0.0 {
129                // All zero: code = 0b01 (Zero), qs bytes = 0b01_01_01_01 = 0x55
130                for b in qs.iter_mut() {
131                    *b = 0x55;
132                }
133                blocks.push(BlockTQ2_0_g128 { qs, d: f16::ZERO });
134                continue;
135            }
136
137            let threshold = 0.5 * absmax;
138            for (j, &x) in chunk.iter().enumerate() {
139                let code: u8 = if x >= threshold {
140                    TernaryCode::Pos as u8 // 0b10
141                } else if x <= -threshold {
142                    TernaryCode::Neg as u8 // 0b00
143                } else {
144                    TernaryCode::Zero as u8 // 0b01
145                };
146                let byte_idx = j / 4;
147                let shift = (j % 4) * 2;
148                qs[byte_idx] |= code << shift;
149            }
150
151            blocks.push(BlockTQ2_0_g128 {
152                qs,
153                d: f16::from_f32(absmax),
154            });
155        }
156        Ok(blocks)
157    }
158
159    /// Zero-copy cast of a byte slice to a slice of TQ2_0_g128 blocks.
160    ///
161    /// Returns error if length is not a multiple of 34 or pointer is misaligned.
162    pub fn slice_from_bytes(data: &[u8]) -> BonsaiResult<&[Self]> {
163        if data.len() % BLOCK_TQ2_0_G128_BYTES != 0 {
164            return Err(BonsaiError::KQuantError {
165                reason: format!(
166                    "TQ2_0_g128 slice_from_bytes: byte len {} not a multiple of {}",
167                    data.len(),
168                    BLOCK_TQ2_0_G128_BYTES
169                ),
170            });
171        }
172        let align = std::mem::align_of::<Self>();
173        if data.as_ptr().align_offset(align) != 0 {
174            return Err(BonsaiError::KQuantError {
175                reason: format!(
176                    "TQ2_0_g128 slice_from_bytes: pointer not {}-byte aligned",
177                    align
178                ),
179            });
180        }
181        let count = data.len() / BLOCK_TQ2_0_G128_BYTES;
182        let ptr = data.as_ptr() as *const Self;
183        // SAFETY: repr(C) layout validated by compile-time assert; length and alignment
184        // checked above; lifetime tied to input slice.
185        Ok(unsafe { std::slice::from_raw_parts(ptr, count) })
186    }
187
188    /// Decode a 2-bit code at `lane` (0..4) from `byte`, returning the weight as i8.
189    ///
190    /// Code map: `0b00→-1`, `0b01→0`, `0b10→+1`, `0b11→0` (reserved treated as zero).
191    pub fn ternary_decode(byte: u8, lane: usize) -> i8 {
192        let shift = lane * 2;
193        let code = (byte >> shift) & 0x03;
194        match code {
195            0b00 => -1,
196            0b01 => 0,
197            0b10 => 1,
198            _ => 0, // 0b11 reserved → zero
199        }
200    }
201}
202
203// ---------------------------------------------------------------------------
204// BlockTQ2_0
205// ---------------------------------------------------------------------------
206
207/// TQ2_0 block: 256 weights at 2 bits each, llama.cpp compat format.
208///
209/// Layout (66 bytes): `qs[64]` packed codes + `d` FP16 scale.
210/// Same 2-bit coding as TQ2_0_g128.
211#[derive(Debug, Clone, Copy, PartialEq)]
212#[repr(C)]
213pub struct BlockTQ2_0 {
214    /// 256 × 2-bit quantized weights, 4 per byte, LSB-first.
215    pub qs: [u8; 64],
216    /// Block scale (FP16).
217    pub d: f16,
218}
219
220const _: () = assert!(std::mem::size_of::<BlockTQ2_0>() == BLOCK_TQ2_0_BYTES);
221
222impl BlockTQ2_0 {
223    /// Dequantize a slice of TQ2_0 blocks into f32 output.
224    ///
225    /// `output` must have length >= `blocks.len() * 256`.
226    pub fn dequant(blocks: &[Self], output: &mut [f32]) -> BonsaiResult<()> {
227        let expected_len = blocks.len() * QK_TQ2_0;
228        if output.len() < expected_len {
229            return Err(BonsaiError::KQuantError {
230                reason: format!(
231                    "TQ2_0 dequant: output len {} < expected {}",
232                    output.len(),
233                    expected_len
234                ),
235            });
236        }
237        for (block_idx, block) in blocks.iter().enumerate() {
238            let d = block.d.to_f32();
239            let base = block_idx * QK_TQ2_0;
240            for j in 0..QK_TQ2_0 {
241                let byte_idx = j / 4;
242                let lane = j % 4;
243                let code_val = ternary_decode_g256(block.qs[byte_idx], lane);
244                output[base + j] = d * (code_val as f32);
245            }
246        }
247        Ok(())
248    }
249
250    /// Quantize f32 input into TQ2_0 blocks.
251    ///
252    /// Input length must be a multiple of 256.
253    pub fn quantize(input: &[f32]) -> BonsaiResult<Vec<Self>> {
254        if input.len() % QK_TQ2_0 != 0 {
255            return Err(BonsaiError::KQuantError {
256                reason: format!(
257                    "TQ2_0 quantize: input len {} not a multiple of {}",
258                    input.len(),
259                    QK_TQ2_0
260                ),
261            });
262        }
263        let num_blocks = input.len() / QK_TQ2_0;
264        let mut blocks = Vec::with_capacity(num_blocks);
265
266        for block_idx in 0..num_blocks {
267            let base = block_idx * QK_TQ2_0;
268            let chunk = &input[base..base + QK_TQ2_0];
269
270            let absmax = chunk
271                .iter()
272                .copied()
273                .fold(0.0f32, |acc, x| acc.max(x.abs()));
274
275            let mut qs = [0u8; 64];
276
277            if absmax == 0.0 {
278                for b in qs.iter_mut() {
279                    *b = 0x55;
280                }
281                blocks.push(BlockTQ2_0 { qs, d: f16::ZERO });
282                continue;
283            }
284
285            let threshold = 0.5 * absmax;
286            for (j, &x) in chunk.iter().enumerate() {
287                let code: u8 = if x >= threshold {
288                    TernaryCode::Pos as u8
289                } else if x <= -threshold {
290                    TernaryCode::Neg as u8
291                } else {
292                    TernaryCode::Zero as u8
293                };
294                let byte_idx = j / 4;
295                let shift = (j % 4) * 2;
296                qs[byte_idx] |= code << shift;
297            }
298
299            blocks.push(BlockTQ2_0 {
300                qs,
301                d: f16::from_f32(absmax),
302            });
303        }
304        Ok(blocks)
305    }
306
307    /// Zero-copy cast of a byte slice to a slice of TQ2_0 blocks.
308    ///
309    /// Returns error if length is not a multiple of 66 or pointer is misaligned.
310    pub fn slice_from_bytes(data: &[u8]) -> BonsaiResult<&[Self]> {
311        if data.len() % BLOCK_TQ2_0_BYTES != 0 {
312            return Err(BonsaiError::KQuantError {
313                reason: format!(
314                    "TQ2_0 slice_from_bytes: byte len {} not a multiple of {}",
315                    data.len(),
316                    BLOCK_TQ2_0_BYTES
317                ),
318            });
319        }
320        let align = std::mem::align_of::<Self>();
321        if data.as_ptr().align_offset(align) != 0 {
322            return Err(BonsaiError::KQuantError {
323                reason: format!("TQ2_0 slice_from_bytes: pointer not {}-byte aligned", align),
324            });
325        }
326        let count = data.len() / BLOCK_TQ2_0_BYTES;
327        let ptr = data.as_ptr() as *const Self;
328        // SAFETY: repr(C) layout validated by compile-time assert; length and alignment
329        // checked above; lifetime tied to input slice.
330        Ok(unsafe { std::slice::from_raw_parts(ptr, count) })
331    }
332}
333
/// Map the 2-bit code in lane `lane` (0..4, lane 0 = least-significant bits) of
/// `byte` to its ternary weight, for use by `BlockTQ2_0`.
///
/// `0b00` → -1, `0b01` → 0, `0b10` → +1. The reserved pattern `0b11` reads as 0.
fn ternary_decode_g256(byte: u8, lane: usize) -> i8 {
    // Indexed by the raw 2-bit code; the last entry covers the reserved 0b11.
    const WEIGHT_OF_CODE: [i8; 4] = [-1, 0, 1, 0];
    let bits = byte >> (lane * 2);
    WEIGHT_OF_CODE[usize::from(bits & 0b11)]
}
347
348// ---------------------------------------------------------------------------
349// Tests
350// ---------------------------------------------------------------------------
351
#[cfg(test)]
mod tests {
    use super::*;

    /// Mean squared error between two equal-length slices.
    fn mse(a: &[f32], b: &[f32]) -> f32 {
        a.iter().zip(b).map(|(x, y)| (x - y) * (x - y)).sum::<f32>() / a.len() as f32
    }

    /// Return a `len`-byte subslice of `buf` that starts at an odd address, so it
    /// is misaligned for any type with alignment >= 2. `buf` must hold `len + 1` bytes.
    fn odd_address_slice(buf: &[u8], len: usize) -> &[u8] {
        let off = if buf.as_ptr() as usize % 2 == 0 { 1 } else { 0 };
        &buf[off..off + len]
    }

    #[test]
    fn tq2_0_g128_block_size_correct() {
        assert_eq!(
            std::mem::size_of::<BlockTQ2_0_g128>(),
            BLOCK_TQ2_0_G128_BYTES
        );
        assert_eq!(BLOCK_TQ2_0_G128_BYTES, 34);
    }

    #[test]
    fn tq2_0_block_size_correct() {
        assert_eq!(std::mem::size_of::<BlockTQ2_0>(), BLOCK_TQ2_0_BYTES);
        assert_eq!(BLOCK_TQ2_0_BYTES, 66);
    }

    #[test]
    fn ternary_code_to_i8() {
        assert_eq!(TernaryCode::Neg.to_i8(), -1);
        assert_eq!(TernaryCode::Zero.to_i8(), 0);
        assert_eq!(TernaryCode::Pos.to_i8(), 1);
    }

    #[test]
    fn tq2_0_g128_roundtrip_uniform() {
        // Repeating 0.5, -0.5, 0.0 pattern for 128 values; all exactly representable.
        let mut input = [0.0f32; 128];
        for (i, x) in input.iter_mut().enumerate() {
            *x = match i % 3 {
                0 => 0.5,
                1 => -0.5,
                _ => 0.0,
            };
        }
        let blocks = BlockTQ2_0_g128::quantize(&input).expect("quantize should succeed");
        let mut output = vec![0.0f32; 128];
        BlockTQ2_0_g128::dequant(&blocks, &mut output).expect("dequant should succeed");
        let err = mse(&input, &output);
        assert!(err < 1e-3, "MSE {err} too high");
    }

    #[test]
    fn tq2_0_g128_all_zero_input() {
        let input = [0.0f32; 128];
        let blocks = BlockTQ2_0_g128::quantize(&input).expect("quantize should succeed");
        assert_eq!(blocks.len(), 1);
        assert_eq!(blocks[0].d, f16::ZERO);
        let mut output = vec![0.0f32; 128];
        BlockTQ2_0_g128::dequant(&blocks, &mut output).expect("dequant should succeed");
        for &v in &output {
            assert_eq!(v, 0.0, "all outputs should be zero");
        }
    }

    #[test]
    fn tq2_0_g128_all_positive() {
        let input = [1.0f32; 128];
        let blocks = BlockTQ2_0_g128::quantize(&input).expect("quantize should succeed");
        assert_eq!(blocks.len(), 1);
        // absmax = 1.0 → d = f16(1.0)
        assert!(
            (blocks[0].d.to_f32() - 1.0).abs() < 1e-3,
            "d should be ~1.0"
        );
        // All codes should be Pos (0b10), so each byte = 0b10101010 = 0xAA
        for &b in &blocks[0].qs {
            assert_eq!(b, 0xAA, "all bytes should be 0xAA for all-positive");
        }
    }

    #[test]
    fn tq2_0_g128_all_negative() {
        let input = [-1.0f32; 128];
        let blocks = BlockTQ2_0_g128::quantize(&input).expect("quantize should succeed");
        assert_eq!(blocks.len(), 1);
        // absmax = 1.0 → d = f16(1.0)
        assert!(
            (blocks[0].d.to_f32() - 1.0).abs() < 1e-3,
            "d should be ~1.0"
        );
        // All codes should be Neg (0b00), so each byte = 0b00000000 = 0x00
        for &b in &blocks[0].qs {
            assert_eq!(b, 0x00, "all bytes should be 0x00 for all-negative");
        }
    }

    #[test]
    fn tq2_0_g128_mixed_threshold() {
        // Pattern: [2.0, 0.9, 0.0, -0.9, -2.0] repeating to fill 128 elements.
        // absmax=2.0, threshold=1.0:
        //   2.0 ≥ 1.0  → Pos (+d = 2.0)
        //   0.9 < 1.0  → Zero (0.0)
        //   0.0 < 1.0  → Zero (0.0)
        //  -0.9: abs=0.9 < 1.0 → Zero (0.0)
        //  -2.0 ≤ -1.0 → Neg (-d = -2.0)
        let pattern = [2.0f32, 0.9, 0.0, -0.9, -2.0];
        let mut input = [0.0f32; 128];
        for (x, &p) in input.iter_mut().zip(pattern.iter().cycle()) {
            *x = p;
        }
        let blocks = BlockTQ2_0_g128::quantize(&input).expect("quantize should succeed");
        let mut output = vec![0.0f32; 128];
        BlockTQ2_0_g128::dequant(&blocks, &mut output).expect("dequant should succeed");

        let expected_pattern = [2.0f32, 0.0, 0.0, 0.0, -2.0];
        for (i, (&got, &expected)) in output
            .iter()
            .zip(expected_pattern.iter().cycle())
            .enumerate()
        {
            assert!(
                (got - expected).abs() < 1e-3,
                "index {i}: expected {expected}, got {got}"
            );
        }
    }

    #[test]
    fn tq2_0_g128_quantize_bad_length_errors() {
        let input = [0.0f32; 100];
        assert!(BlockTQ2_0_g128::quantize(&input).is_err());
    }

    #[test]
    fn tq2_0_g128_dequant_short_output_errors() {
        let blocks = BlockTQ2_0_g128::quantize(&[1.0f32; 128]).expect("quantize should succeed");
        let mut output = vec![0.0f32; 127];
        assert!(BlockTQ2_0_g128::dequant(&blocks, &mut output).is_err());
    }

    #[test]
    fn tq2_0_g128_slice_from_bytes_bad_length() {
        // 35 bytes is not a multiple of 34 → should return Err.
        let data = vec![0u8; 35];
        let result = BlockTQ2_0_g128::slice_from_bytes(&data);
        assert!(result.is_err(), "35-byte slice should fail");
    }

    #[test]
    fn tq2_0_g128_slice_from_bytes_misaligned() {
        // Valid length (34) but an odd start address → must be rejected.
        assert!(std::mem::align_of::<BlockTQ2_0_g128>() >= 2);
        let buf = vec![0u8; BLOCK_TQ2_0_G128_BYTES + 1];
        let data = odd_address_slice(&buf, BLOCK_TQ2_0_G128_BYTES);
        assert!(
            BlockTQ2_0_g128::slice_from_bytes(data).is_err(),
            "misaligned slice should fail"
        );
    }

    #[test]
    fn tq2_0_g128_slice_from_bytes_aligned() {
        // Build a real block and reinterpret as bytes (guaranteed alignment).
        let block = BlockTQ2_0_g128 {
            qs: [0u8; 32],
            d: f16::from_f32(1.0),
        };
        // SAFETY: `block` is a live, initialized repr(C) value of exactly
        // BLOCK_TQ2_0_G128_BYTES bytes with no padding, and u8 has alignment 1.
        let bytes: &[u8] = unsafe {
            std::slice::from_raw_parts(
                &block as *const BlockTQ2_0_g128 as *const u8,
                BLOCK_TQ2_0_G128_BYTES,
            )
        };
        let result =
            BlockTQ2_0_g128::slice_from_bytes(bytes).expect("aligned slice should succeed");
        assert_eq!(result.len(), 1);
        assert_eq!(result[0].d, f16::from_f32(1.0));
    }

    #[test]
    fn tq2_0_roundtrip_ramp() {
        // Deterministic ramp of 256 values covering [-1, 1).
        let mut input = [0.0f32; 256];
        for (i, x) in input.iter_mut().enumerate() {
            *x = ((i as f32) / 128.0 - 1.0).clamp(-1.0, 1.0);
        }
        let blocks = BlockTQ2_0::quantize(&input).expect("quantize should succeed");
        let mut output = vec![0.0f32; 256];
        BlockTQ2_0::dequant(&blocks, &mut output).expect("dequant should succeed");
        let err = mse(&input, &output);
        // TQ2_0 is a 3-level ternary quantizer. On a continuous ramp in [-1,1],
        // a large fraction of values are zeroed (|x| < 0.5 * absmax), so an MSE
        // of about 0.08–0.10 is expected. Require < 0.15 to catch regressions.
        assert!(err < 0.15, "MSE {err} too high for TQ2_0 roundtrip");
    }

    #[test]
    fn tq2_0_all_zero_input() {
        let input = [0.0f32; 256];
        let blocks = BlockTQ2_0::quantize(&input).expect("quantize should succeed");
        assert_eq!(blocks.len(), 1);
        assert_eq!(blocks[0].d, f16::ZERO);
        let mut output = vec![1.0f32; 256];
        BlockTQ2_0::dequant(&blocks, &mut output).expect("dequant should succeed");
        assert!(output.iter().all(|&v| v == 0.0), "all outputs should be zero");
    }

    #[test]
    fn tq2_0_slice_from_bytes_bad_length() {
        let data = vec![0u8; BLOCK_TQ2_0_BYTES + 1];
        assert!(BlockTQ2_0::slice_from_bytes(&data).is_err());
    }

    #[test]
    fn ternary_decode_all_lanes() {
        // Construct a byte to test all four lanes:
        //   lane 0 (bits 1:0): 0b00 → -1
        //   lane 1 (bits 3:2): 0b11 → 0 (reserved)
        //   lane 2 (bits 5:4): 0b01 → 0
        //   lane 3 (bits 7:6): 0b10 → +1
        // Byte = 0b10_01_11_00 = 0b10011100 = 0x9C
        let byte: u8 = 0b10011100;
        assert_eq!(
            BlockTQ2_0_g128::ternary_decode(byte, 0),
            -1,
            "lane 0: 0b00 → -1"
        );
        assert_eq!(
            BlockTQ2_0_g128::ternary_decode(byte, 1),
            0,
            "lane 1: 0b11 → 0 (reserved)"
        );
        assert_eq!(
            BlockTQ2_0_g128::ternary_decode(byte, 2),
            0,
            "lane 2: 0b01 → 0"
        );
        assert_eq!(
            BlockTQ2_0_g128::ternary_decode(byte, 3),
            1,
            "lane 3: 0b10 → +1"
        );
    }
}