// nodedb_codec/vector_quant/ternary/codec.rs
use crate::vector_quant::codec::VectorCodec;
use crate::vector_quant::layout::{QuantHeader, QuantMode, UnifiedQuantizedVector};

use super::packing::{cold_to_hot, pack_hot, quantize, unpack_hot};
use super::simd::ternary_dot;
/// Ternary vector codec.
///
/// Each `f32` vector is quantized to trits plus a single scale factor. See
/// the `VectorCodec` impl for the encode and distance paths.
#[non_exhaustive]
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct TernaryCodec {
    /// Number of components in every vector handled by this codec.
    pub dim: usize,
}

impl TernaryCodec {
    /// Creates a codec for vectors with `dim` components.
    #[must_use]
    pub fn new(dim: usize) -> Self {
        Self { dim }
    }
}
25
26pub struct TernaryQuantized(pub UnifiedQuantizedVector);
28
29impl AsRef<UnifiedQuantizedVector> for TernaryQuantized {
30 #[inline]
31 fn as_ref(&self) -> &UnifiedQuantizedVector {
32 &self.0
33 }
34}
35
/// A query prepared by `TernaryCodec::prepare_query`.
///
/// It holds the query quantized to trits and packed in the hot layout,
/// together with the scale that goes with it.
#[derive(Debug, Clone, PartialEq)]
pub struct TernaryQuery {
    /// Query trits packed with `pack_hot`.
    pub trits_hot: Vec<u8>,
    /// Scale factor returned by `quantize` for this query.
    pub scale: f32,
}
41
42fn scaled_norm_sq(hot: &[u8], dim: usize, scale: f32) -> f32 {
44 let trits = unpack_hot(hot, dim);
45 let count: i32 = trits.iter().map(|&t| (t != 0) as i32).sum();
46 scale * scale * count as f32
47}
48
/// Reconstructs a squared L2 distance from an integer trit dot product.
///
/// Uses ‖a − b‖² = ‖a‖² + ‖b‖² − 2·⟨a, b⟩, where `norm_a` and `norm_b` are
/// the already-squared norms. `⟨a, b⟩ = sa · sb · dot`, with `sa` and `sb`
/// the per-vector scales. The result is clamped at `0.0` so that floating
/// rounding cannot produce a negative distance.
fn l2_from_dot(dot: i32, norm_a: f32, norm_b: f32, sa: f32, sb: f32) -> f32 {
    let cross = 2.0 * sa * sb * dot as f32;
    let raw = norm_a + norm_b - cross;
    f32::max(raw, 0.0)
}
53
54fn ensure_hot(packed: &[u8], quant_mode: u16, dim: usize) -> std::borrow::Cow<'_, [u8]> {
56 if quant_mode == QuantMode::TernaryPacked as u16 {
57 std::borrow::Cow::Owned(cold_to_hot(packed, dim))
58 } else {
59 std::borrow::Cow::Borrowed(packed)
60 }
61}
62
63impl VectorCodec for TernaryCodec {
64 type Quantized = TernaryQuantized;
65 type Query = TernaryQuery;
66
67 fn encode(&self, v: &[f32]) -> TernaryQuantized {
68 let (trits, scale) = quantize(v);
69 let hot = pack_hot(&trits);
70 let header = QuantHeader {
71 quant_mode: QuantMode::TernarySimd as u16,
72 dim: self.dim as u16,
73 global_scale: scale,
74 residual_norm: 0.0,
75 dot_quantized: 0.0,
76 outlier_bitmask: 0,
77 reserved: [0; 8],
78 };
79 let uqv = UnifiedQuantizedVector::new(header, &hot, &[])
80 .expect("ternary encode: layout must be valid");
81 TernaryQuantized(uqv)
82 }
83
84 fn prepare_query(&self, q: &[f32]) -> TernaryQuery {
85 let (trits, scale) = quantize(q);
86 TernaryQuery {
87 trits_hot: pack_hot(&trits),
88 scale,
89 }
90 }
91
92 fn fast_symmetric_distance(&self, q: &Self::Quantized, v: &Self::Quantized) -> f32 {
93 let qh = q.0.header();
94 let vh = v.0.header();
95
96 let q_hot = ensure_hot(q.0.packed_bits(), qh.quant_mode, self.dim);
97 let v_hot = ensure_hot(v.0.packed_bits(), vh.quant_mode, self.dim);
98
99 let dot = ternary_dot(&q_hot, &v_hot, self.dim);
100 let norm_q = scaled_norm_sq(&q_hot, self.dim, qh.global_scale);
101 let norm_v = scaled_norm_sq(&v_hot, self.dim, vh.global_scale);
102 l2_from_dot(dot, norm_q, norm_v, qh.global_scale, vh.global_scale)
103 }
104
105 fn exact_asymmetric_distance(&self, q: &TernaryQuery, v: &Self::Quantized) -> f32 {
106 let vh = v.0.header();
107
108 let v_hot = ensure_hot(v.0.packed_bits(), vh.quant_mode, self.dim);
109
110 let dot = ternary_dot(&q.trits_hot, &v_hot, self.dim);
111 let norm_q = scaled_norm_sq(&q.trits_hot, self.dim, q.scale);
112 let norm_v = scaled_norm_sq(&v_hot, self.dim, vh.global_scale);
113 l2_from_dot(dot, norm_q, norm_v, q.scale, vh.global_scale)
114 }
115}
116
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn encode_produces_hot_format_by_default() {
        let codec = TernaryCodec::new(8);
        let v = vec![1.0f32, -1.0, 0.5, -0.5, 0.1, -0.1, 0.9, -0.9];
        let q = codec.encode(&v);
        assert_eq!(q.0.header().quant_mode, QuantMode::TernarySimd as u16);
    }

    #[test]
    fn dot_product_self_approx_norm_sq() {
        let dim = 16;
        let codec = TernaryCodec::new(dim);
        let v: Vec<f32> = (0..dim)
            .map(|i| {
                if i % 3 == 0 {
                    1.0
                } else if i % 3 == 1 {
                    -1.0
                } else {
                    0.0
                }
            })
            .collect();
        let qv = codec.encode(&v);
        let dist = codec.fast_symmetric_distance(&qv, &qv);
        assert!(dist < 1e-4, "self-distance should be ~0, got {dist}");
    }

    /// A query prepared from the same vector that was encoded should be at
    /// distance ~0 from it.
    #[test]
    fn asymmetric_self_distance_is_zero() {
        let codec = TernaryCodec::new(8);
        let v = vec![1.0f32, -1.0, 0.5, -0.5, 0.1, -0.1, 0.9, -0.9];
        let stored = codec.encode(&v);
        let query = codec.prepare_query(&v);
        let dist = codec.exact_asymmetric_distance(&query, &stored);
        assert!(dist < 1e-4, "asymmetric self-distance should be ~0, got {dist}");
    }

    /// The symmetric distance should not depend on argument order and should
    /// never be negative (the reconstruction clamps at zero).
    #[test]
    fn symmetric_distance_is_symmetric_and_non_negative() {
        let codec = TernaryCodec::new(8);
        let a = codec.encode(&[1.0, -1.0, 0.0, 0.0, 1.0, -1.0, 0.0, 0.0]);
        let b = codec.encode(&[-1.0, 1.0, 0.0, 1.0, 0.0, 0.0, -1.0, 1.0]);
        let ab = codec.fast_symmetric_distance(&a, &b);
        let ba = codec.fast_symmetric_distance(&b, &a);
        assert!(ab >= 0.0, "distance must be non-negative, got {ab}");
        assert!((ab - ba).abs() < 1e-6, "distance must be symmetric: {ab} vs {ba}");
        assert!(ab > 0.0, "distinct vectors should be apart, got {ab}");
    }
}