Skip to main content

nodedb_codec/vector_quant/ternary/
codec.rs

1// SPDX-License-Identifier: Apache-2.0
2
3//! [`TernaryCodec`] — BitNet b1.58 ternary [`VectorCodec`] implementation.
4
5use crate::vector_quant::codec::VectorCodec;
6use crate::vector_quant::layout::{QuantHeader, QuantMode, UnifiedQuantizedVector};
7
8use super::packing::{cold_to_hot, pack_hot, quantize, unpack_hot};
9use super::simd::ternary_dot;
10
/// BitNet b1.58 ternary codec.
///
/// Stateless aside from `dim`. Per-vector scaling (absmean) is computed in
/// [`VectorCodec::encode`] and stored in the header's `global_scale` field.
#[derive(Debug, Clone, Copy)]
#[non_exhaustive]
pub struct TernaryCodec {
    /// Vector dimensionality: written into each encoded header and used as the
    /// trit count when computing distances.
    pub dim: usize,
}

impl TernaryCodec {
    /// Creates a codec for vectors with `dim` components.
    #[must_use]
    pub fn new(dim: usize) -> Self {
        Self { dim }
    }
}
25
26/// Owned ternary-quantized vector (wraps [`UnifiedQuantizedVector`]).
27pub struct TernaryQuantized(pub UnifiedQuantizedVector);
28
29impl AsRef<UnifiedQuantizedVector> for TernaryQuantized {
30    #[inline]
31    fn as_ref(&self) -> &UnifiedQuantizedVector {
32        &self.0
33    }
34}
35
/// Prepared ternary query: hot-packed trits plus the per-vector scale.
///
/// Produced by [`VectorCodec::prepare_query`]. `trits_hot` uses the same hot
/// packing as an encoded vector's bits, so it can be compared against stored
/// vectors without any layout conversion.
#[derive(Debug, Clone, PartialEq)]
pub struct TernaryQuery {
    /// Hot-packed ternary trits of the query.
    pub trits_hot: Vec<u8>,
    /// Scale the trits are multiplied by to approximate the original query.
    pub scale: f32,
}
41
42/// Squared L2 norm of ternary trits scaled by `scale`.
43fn scaled_norm_sq(hot: &[u8], dim: usize, scale: f32) -> f32 {
44    let trits = unpack_hot(hot, dim);
45    let count: i32 = trits.iter().map(|&t| (t != 0) as i32).sum();
46    scale * scale * count as f32
47}
48
/// Reconstructs the squared L2 distance between two scaled ternary vectors
/// from their integer dot product:
///
/// `‖a·sa - b·sb‖² = sa²·‖a‖² + sb²·‖b‖² - 2·sa·sb·dot(a,b)`
///
/// `norm_a` and `norm_b` must already include their scale, i.e. be
/// `sa²·‖a‖²` and `sb²·‖b‖²`. The result is clamped at zero so float rounding
/// can never produce a negative distance.
fn l2_from_dot(dot: i32, norm_a: f32, norm_b: f32, sa: f32, sb: f32) -> f32 {
    let cross = 2.0 * sa * sb * dot as f32;
    let raw = norm_a + norm_b - cross;
    raw.max(0.0)
}
53
54/// Expand cold-packed bits to hot if needed.
55fn ensure_hot(packed: &[u8], quant_mode: u16, dim: usize) -> std::borrow::Cow<'_, [u8]> {
56    if quant_mode == QuantMode::TernaryPacked as u16 {
57        std::borrow::Cow::Owned(cold_to_hot(packed, dim))
58    } else {
59        std::borrow::Cow::Borrowed(packed)
60    }
61}
62
63impl VectorCodec for TernaryCodec {
64    type Quantized = TernaryQuantized;
65    type Query = TernaryQuery;
66
67    fn encode(&self, v: &[f32]) -> TernaryQuantized {
68        let (trits, scale) = quantize(v);
69        let hot = pack_hot(&trits);
70        let header = QuantHeader {
71            quant_mode: QuantMode::TernarySimd as u16,
72            dim: self.dim as u16,
73            global_scale: scale,
74            residual_norm: 0.0,
75            dot_quantized: 0.0,
76            outlier_bitmask: 0,
77            reserved: [0; 8],
78        };
79        let uqv = UnifiedQuantizedVector::new(header, &hot, &[])
80            .expect("ternary encode: layout must be valid");
81        TernaryQuantized(uqv)
82    }
83
84    fn prepare_query(&self, q: &[f32]) -> TernaryQuery {
85        let (trits, scale) = quantize(q);
86        TernaryQuery {
87            trits_hot: pack_hot(&trits),
88            scale,
89        }
90    }
91
92    fn fast_symmetric_distance(&self, q: &Self::Quantized, v: &Self::Quantized) -> f32 {
93        let qh = q.0.header();
94        let vh = v.0.header();
95
96        let q_hot = ensure_hot(q.0.packed_bits(), qh.quant_mode, self.dim);
97        let v_hot = ensure_hot(v.0.packed_bits(), vh.quant_mode, self.dim);
98
99        let dot = ternary_dot(&q_hot, &v_hot, self.dim);
100        let norm_q = scaled_norm_sq(&q_hot, self.dim, qh.global_scale);
101        let norm_v = scaled_norm_sq(&v_hot, self.dim, vh.global_scale);
102        l2_from_dot(dot, norm_q, norm_v, qh.global_scale, vh.global_scale)
103    }
104
105    fn exact_asymmetric_distance(&self, q: &TernaryQuery, v: &Self::Quantized) -> f32 {
106        let vh = v.0.header();
107
108        let v_hot = ensure_hot(v.0.packed_bits(), vh.quant_mode, self.dim);
109
110        let dot = ternary_dot(&q.trits_hot, &v_hot, self.dim);
111        let norm_q = scaled_norm_sq(&q.trits_hot, self.dim, q.scale);
112        let norm_v = scaled_norm_sq(&v_hot, self.dim, vh.global_scale);
113        l2_from_dot(dot, norm_q, norm_v, q.scale, vh.global_scale)
114    }
115}
116
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a `dim`-length vector cycling through 1.0, -1.0, 0.0.
    fn cyclic(dim: usize) -> Vec<f32> {
        (0..dim)
            .map(|i| match i % 3 {
                0 => 1.0,
                1 => -1.0,
                _ => 0.0,
            })
            .collect()
    }

    #[test]
    fn encode_produces_hot_format_by_default() {
        let codec = TernaryCodec::new(8);
        let v = vec![1.0f32, -1.0, 0.5, -0.5, 0.1, -0.1, 0.9, -0.9];
        let q = codec.encode(&v);
        assert_eq!(q.0.header().quant_mode, QuantMode::TernarySimd as u16);
    }

    /// Symmetric distance from an encoded vector to itself is (numerically) zero.
    #[test]
    fn dot_product_self_approx_norm_sq() {
        let dim = 16;
        let codec = TernaryCodec::new(dim);
        let qv = codec.encode(&cyclic(dim));
        let dist = codec.fast_symmetric_distance(&qv, &qv);
        assert!(dist < 1e-4, "self-distance should be ~0, got {dist}");
    }

    /// A prepared query against the encoding of the same vector is ~0 apart.
    #[test]
    fn asymmetric_self_distance_is_zero() {
        let dim = 16;
        let codec = TernaryCodec::new(dim);
        let v = cyclic(dim);
        let dist = codec.exact_asymmetric_distance(&codec.prepare_query(&v), &codec.encode(&v));
        assert!(dist < 1e-4, "asymmetric self-distance should be ~0, got {dist}");
    }

    /// Opposite vectors are strictly apart, and the distance is symmetric.
    #[test]
    fn opposite_vectors_are_far_apart_and_symmetric() {
        let dim = 8;
        let codec = TernaryCodec::new(dim);
        let a = codec.encode(&vec![1.0f32; dim]);
        let b = codec.encode(&vec![-1.0f32; dim]);
        let ab = codec.fast_symmetric_distance(&a, &b);
        let ba = codec.fast_symmetric_distance(&b, &a);
        assert!(ab > 0.0, "opposite vectors should be apart, got {ab}");
        assert!(
            (ab - ba).abs() <= 1e-6 * ab.max(1.0),
            "distance should be symmetric: {ab} vs {ba}"
        );
    }

    /// Both distance entry points agree when the query is quantized identically.
    #[test]
    fn asymmetric_agrees_with_symmetric() {
        let dim = 16;
        let codec = TernaryCodec::new(dim);
        let x = cyclic(dim);
        let y: Vec<f32> = x.iter().rev().copied().collect();
        let sym = codec.fast_symmetric_distance(&codec.encode(&x), &codec.encode(&y));
        let asym = codec.exact_asymmetric_distance(&codec.prepare_query(&x), &codec.encode(&y));
        assert!(
            (sym - asym).abs() <= 1e-4 * sym.max(1.0),
            "symmetric {sym} and asymmetric {asym} should agree"
        );
    }
}
148}