1use nodedb_codec::vector_quant::{
12 codec::{AdcLut, VectorCodec},
13 layout::{QuantHeader, QuantMode, UnifiedQuantizedVector},
14};
15
16use crate::quantize::pq::PqCodec;
17
18pub struct PqQuantized(pub UnifiedQuantizedVector);
22
23impl AsRef<UnifiedQuantizedVector> for PqQuantized {
24 #[inline]
25 fn as_ref(&self) -> &UnifiedQuantizedVector {
26 &self.0
27 }
28}
29
/// Prepared query state for product-quantization distance evaluation.
///
/// Holds only plain `f32` data, so it derives the common traits and can be
/// logged, duplicated and compared by callers and tests.
#[derive(Debug, Clone, PartialEq)]
pub struct PqQuery {
    /// Distance table built for the query, indexed as
    /// `distance_table[subspace][centroid]`.
    pub distance_table: Vec<Vec<f32>>,
    /// Copy of the original, un-quantized query vector.
    pub raw: Vec<f32>,
}
43
44#[inline]
47fn packed_bits_of(q: &PqQuantized) -> &[u8] {
48 q.0.packed_bits()
49}
50
51impl VectorCodec for PqCodec {
54 type Quantized = PqQuantized;
55 type Query = PqQuery;
57
58 fn encode(&self, v: &[f32]) -> Self::Quantized {
67 let codes = self.encode(v);
68 let header = QuantHeader {
69 quant_mode: QuantMode::Pq as u16,
70 dim: self.dim as u16,
71 global_scale: 0.0,
72 residual_norm: 0.0,
73 dot_quantized: 0.0,
74 outlier_bitmask: 0,
75 reserved: [0; 8],
76 };
77 let uqv = UnifiedQuantizedVector::new(header, &codes, &[])
78 .expect("PqCodec::encode: layout construction is infallible (no outliers)");
79 PqQuantized(uqv)
80 }
81
82 fn prepare_query(&self, q: &[f32]) -> Self::Query {
92 let distance_table = self.build_distance_table(q).expect(
93 "PqCodec::prepare_query: build_distance_table failed; \
94 if a governor is attached ensure it has sufficient budget",
95 );
96 PqQuery {
97 distance_table,
98 raw: q.to_vec(),
99 }
100 }
101
102 fn adc_lut(&self, q: &Self::Query) -> Option<AdcLut> {
105 let m = self.m as u16;
106 let k = self.k as u16;
107 let mut lut = AdcLut::new(m, k);
108 for (sub, sub_table) in q.distance_table.iter().enumerate() {
109 for (centroid, &dist) in sub_table.iter().enumerate() {
110 let idx = sub * self.k + centroid;
111 lut.table[idx] = dist;
112 }
113 }
114 Some(lut)
115 }
116
117 #[inline]
124 fn fast_symmetric_distance(&self, q: &Self::Quantized, v: &Self::Quantized) -> f32 {
125 let dq_a = self
126 .decode(packed_bits_of(q))
127 .expect("PqCodec::fast_symmetric_distance: decode failed");
128 let dq_b = self
129 .decode(packed_bits_of(v))
130 .expect("PqCodec::fast_symmetric_distance: decode failed");
131 dq_a.iter()
132 .zip(dq_b.iter())
133 .map(|(&a, &b)| {
134 let d = a - b;
135 d * d
136 })
137 .sum()
138 }
139
140 #[inline]
144 fn exact_asymmetric_distance(&self, q: &Self::Query, v: &Self::Quantized) -> f32 {
145 self.asymmetric_distance(&q.distance_table, packed_bits_of(v))
146 }
147}
148
#[cfg(test)]
mod tests {
    use super::*;

    /// Query-side vector shared by the distance tests.
    const QUERY: [f32; 4] = [0.5, 0.1, 1.0, -0.5];
    /// Database-side vector shared by the distance tests.
    const TARGET: [f32; 4] = [5.0, 4.0, 6.0, 4.5];

    /// Trains a small codec on 80 four-dimensional vectors whose offsets fall
    /// into four groups (5.0 apart, 20 vectors per group).
    fn make_codec() -> PqCodec {
        let mut training: Vec<Vec<f32>> = Vec::with_capacity(80);
        for i in 0..80 {
            let base = (i / 20) as f32 * 5.0;
            let step = i as f32;
            training.push(vec![
                base + step * 0.1,
                base - step * 0.05,
                base + 1.0,
                base - 1.0,
            ]);
        }
        let refs: Vec<&[f32]> = training.iter().map(Vec::as_slice).collect();
        PqCodec::train(&refs, 4, 2, 8, 10)
    }

    /// Asserts that `d` is finite and not negative.
    fn assert_valid_distance(d: f32) {
        assert!(d.is_finite(), "expected finite distance, got {d}");
        assert!(d >= 0.0, "expected non-negative distance, got {d}");
    }

    /// The trait-level `encode` must store exactly the codes produced by the
    /// inherent `encode`.
    #[test]
    fn encode_packed_bits_matches_raw_encode() {
        let codec = make_codec();
        let input = vec![2.0f32, 1.0, 3.0, -1.0];
        let raw = codec.encode(&input);
        let wrapped = <PqCodec as VectorCodec>::encode(&codec, &input);
        assert_eq!(wrapped.as_ref().packed_bits(), raw.as_slice());
    }

    #[test]
    fn fast_symmetric_distance_is_non_negative_finite() {
        let codec = make_codec();
        let lhs = <PqCodec as VectorCodec>::encode(&codec, &QUERY);
        let rhs = <PqCodec as VectorCodec>::encode(&codec, &TARGET);
        assert_valid_distance(codec.fast_symmetric_distance(&lhs, &rhs));
    }

    #[test]
    fn exact_asymmetric_distance_is_non_negative_finite() {
        let codec = make_codec();
        let prepared = codec.prepare_query(&QUERY);
        let stored = <PqCodec as VectorCodec>::encode(&codec, &TARGET);
        assert_valid_distance(codec.exact_asymmetric_distance(&prepared, &stored));
    }

    #[test]
    fn adc_lut_has_correct_shape() {
        let codec = make_codec();
        let prepared = codec.prepare_query(&QUERY);
        let lut = <PqCodec as VectorCodec>::adc_lut(&codec, &prepared)
            .expect("PqCodec must produce an AdcLut");
        assert_eq!(lut.subspace_count, codec.m as u16);
        assert_eq!(lut.centroids_per_subspace, codec.k as u16);
        assert_eq!(lut.table.len(), codec.m * codec.k);
        assert!(lut.table.iter().all(|entry| entry.is_finite()));
    }

    /// Drives a codec purely through the `VectorCodec` bound, so the trait
    /// methods (not inherent ones) are what get exercised.
    fn use_vector_codec<C: VectorCodec>(c: &C, q: &[f32], v: &[f32]) -> f32 {
        let stored = c.encode(v);
        let prepared = c.prepare_query(q);
        let symmetric = c.fast_symmetric_distance(&stored, &stored);
        let asymmetric = c.exact_asymmetric_distance(&prepared, &stored);
        symmetric + asymmetric
    }

    #[test]
    fn trait_bounds_compile() {
        let codec = make_codec();
        let total = use_vector_codec(&codec, &QUERY, &TARGET);
        assert!(total.is_finite());
    }
}