// nodedb_codec/vector_quant/bbq.rs
use crate::vector_quant::codec::VectorCodec;
use crate::vector_quant::hamming::hamming_distance;
use crate::vector_quant::layout::{QuantHeader, QuantMode, UnifiedQuantizedVector};

31pub struct BbqQuantized(pub UnifiedQuantizedVector);
36
37impl AsRef<UnifiedQuantizedVector> for BbqQuantized {
38 #[inline]
39 fn as_ref(&self) -> &UnifiedQuantizedVector {
40 &self.0
41 }
42}
43
/// A query pre-processed by the BBQ codec's `prepare_query` so that it can be
/// compared against encoded vectors.
pub struct BbqQuery {
    /// Query minus the codec centroid, kept at full `f32` precision.
    pub centered: Vec<f32>,
    /// Sign bits of `centered`, packed most-significant-bit first; a set bit
    /// means the component is `>= 0`.
    pub signs: Vec<u8>,
    /// Euclidean (L2) norm of `centered`.
    pub query_norm: f32,
    /// `sum(|centered_i|) / query_norm`, or `0.0` when the norm is not positive.
    pub query_dot_quantized: f32,
}

/// Sign-bit (BBQ) quantization codec.
///
/// Vectors are centered on a centroid computed by [`BbqCodec::calibrate`] and
/// encoded as one sign bit per dimension plus scalar correction terms.
pub struct BbqCodec {
    /// Dimensionality of the vectors this codec encodes.
    pub dim: usize,
    /// Per-dimension mean of the calibration vectors (all zeros when no
    /// calibration vectors were supplied).
    centroid: Vec<f32>,
    /// Oversampling factor stored verbatim by `calibrate`; nothing in this
    /// file's encoding or distance code reads it.
    // NOTE(review): presumably consumed by the search layer to widen the
    // candidate set before re-ranking — confirm against callers.
    pub oversample: u8,
}

78impl BbqCodec {
79 pub fn calibrate(vectors: &[&[f32]], dim: usize, oversample: u8) -> Self {
87 let mut centroid = vec![0.0f32; dim];
88 if vectors.is_empty() {
89 return Self {
90 dim,
91 centroid,
92 oversample,
93 };
94 }
95 for v in vectors {
96 for (c, &x) in centroid.iter_mut().zip(v.iter()) {
97 *c += x;
98 }
99 }
100 let n = vectors.len() as f32;
101 for c in &mut centroid {
102 *c /= n;
103 }
104 Self {
105 dim,
106 centroid,
107 oversample,
108 }
109 }
110
111 fn center(&self, v: &[f32], out: &mut Vec<f32>) {
115 out.clear();
116 out.extend(v.iter().zip(self.centroid.iter()).map(|(&x, &c)| x - c));
117 }
118
119 fn pack_signs(centered: &[f32]) -> Vec<u8> {
122 let nbytes = centered.len().div_ceil(8);
123 let mut bits = vec![0u8; nbytes];
124 for (i, &x) in centered.iter().enumerate() {
125 if x >= 0.0 {
126 bits[i / 8] |= 1 << (7 - (i % 8));
127 }
128 }
129 bits
130 }
131
132 fn norm(v: &[f32]) -> f32 {
134 v.iter().map(|&x| x * x).sum::<f32>().sqrt()
135 }
136
137 fn dot(a: &[f32], b: &[f32]) -> f32 {
139 a.iter().zip(b.iter()).map(|(&x, &y)| x * y).sum()
140 }
141
142 fn dequantize(packed: &[u8], residual_norm: f32, dim: usize) -> Vec<f32> {
149 let scale = if dim > 0 {
150 residual_norm / (dim as f32).sqrt()
151 } else {
152 0.0
153 };
154 (0..dim)
155 .map(|i| {
156 let bit = (packed[i / 8] >> (7 - (i % 8))) & 1;
157 if bit != 0 { scale } else { -scale }
158 })
159 .collect()
160 }
161}
162
163impl VectorCodec for BbqCodec {
164 type Quantized = BbqQuantized;
165 type Query = BbqQuery;
166
167 fn encode(&self, v: &[f32]) -> BbqQuantized {
168 let mut centered = Vec::with_capacity(self.dim);
170 self.center(v, &mut centered);
171
172 let packed = Self::pack_signs(¢ered);
174
175 let residual_norm = Self::norm(¢ered);
180
181 let sign_fp: Vec<f32> = centered
185 .iter()
186 .map(|&x| if x >= 0.0 { 1.0 } else { -1.0 })
187 .collect();
188 let dot_vs = Self::dot(¢ered, &sign_fp);
189 let dot_quantized = if residual_norm > 0.0 {
190 dot_vs / residual_norm
191 } else {
192 0.0
193 };
194
195 let centroid_norm = Self::norm(&self.centroid);
199 let dot_vc = Self::dot(¢ered, &self.centroid);
200 let query_alignment = if centroid_norm > 0.0 {
201 dot_vc / centroid_norm
202 } else {
203 0.0
204 };
205
206 let reserved = [0u8; 8];
208 let header = QuantHeader {
211 quant_mode: QuantMode::Bbq as u16,
212 dim: self.dim as u16,
213 global_scale: query_alignment,
214 residual_norm,
215 dot_quantized,
216 outlier_bitmask: 0,
217 reserved,
218 };
219
220 let uqv = UnifiedQuantizedVector::new(header, &packed, &[]).expect(
221 "BBQ encode: UnifiedQuantizedVector construction must succeed with no outliers",
222 );
223 BbqQuantized(uqv)
224 }
225
226 fn prepare_query(&self, q: &[f32]) -> BbqQuery {
227 let mut centered = Vec::with_capacity(self.dim);
228 self.center(q, &mut centered);
229 let signs = Self::pack_signs(¢ered);
230 let query_norm = Self::norm(¢ered);
231 let sign_fp: Vec<f32> = centered
232 .iter()
233 .map(|&x| if x >= 0.0 { 1.0 } else { -1.0 })
234 .collect();
235 let dot_vs = Self::dot(¢ered, &sign_fp);
236 let query_dot_quantized = if query_norm > 0.0 {
237 dot_vs / query_norm
238 } else {
239 0.0
240 };
241 BbqQuery {
242 centered,
243 signs,
244 query_norm,
245 query_dot_quantized,
246 }
247 }
248
249 fn fast_symmetric_distance(&self, q: &BbqQuantized, v: &BbqQuantized) -> f32 {
256 let q_bits = q.0.packed_bits();
257 let v_bits = v.0.packed_bits();
258 let ham = hamming_distance(q_bits, v_bits);
259 let dim = self.dim as f32;
260 let dot_estimate = 1.0 - 2.0 * ham as f32 / dim;
261 let q_n = q.0.header().residual_norm;
262 let v_n = v.0.header().residual_norm;
263 (q_n * q_n + v_n * v_n - 2.0 * q_n * v_n * dot_estimate).max(0.0)
264 }
265
266 fn exact_asymmetric_distance(&self, q: &BbqQuery, v: &BbqQuantized) -> f32 {
274 let header = v.0.header();
275 let recon = Self::dequantize(v.0.packed_bits(), header.residual_norm, self.dim);
276 q.centered
278 .iter()
279 .zip(recon.iter())
280 .map(|(&a, &b)| (a - b) * (a - b))
281 .sum::<f32>()
282 .sqrt()
283 }
284}
285
#[cfg(test)]
mod tests {
    use super::*;

    /// Deterministic pseudo-random vector with components in `[-2.0, 2.0)`.
    ///
    /// Each component keeps the top 24 bits of a 64-bit LCG state. Those bits
    /// are exactly representable in `f32`, so `unit` is uniform over `[0, 1)`
    /// and the output covers both signs.
    fn rand_vec(seed: u64, dim: usize) -> Vec<f32> {
        const MUL: u64 = 6364136223846793005;
        const INC: u64 = 1442695040888963407;
        let mut state = seed.wrapping_mul(MUL).wrapping_add(INC);
        (0..dim)
            .map(|_| {
                state = state.wrapping_mul(MUL).wrapping_add(INC);
                let unit = (state >> 40) as f32 / (1u64 << 24) as f32;
                unit * 4.0 - 2.0
            })
            .collect()
    }

    #[test]
    fn rand_vec_spans_both_signs() {
        let v = rand_vec(7, 256);
        assert!(v.iter().all(|&x| (-2.0..2.0).contains(&x)));
        assert!(v.iter().any(|&x| x < 0.0));
        assert!(v.iter().any(|&x| x > 0.0));
    }

    #[test]
    fn calibrate_centroid_mean() {
        let dim = 8;
        let a = vec![1.0f32; dim];
        let b = vec![3.0f32; dim];
        let c = vec![2.0f32; dim];
        let refs: Vec<&[f32]> = vec![&a, &b, &c];
        let codec = BbqCodec::calibrate(&refs, dim, 3);
        for &x in &codec.centroid {
            assert!((x - 2.0).abs() < 1e-5, "expected centroid 2.0, got {x}");
        }
    }

    #[test]
    fn calibrate_empty_gives_zero_centroid() {
        let codec = BbqCodec::calibrate(&[], 4, 3);
        assert!(codec.centroid.iter().all(|&x| x == 0.0));
    }

    #[test]
    fn encode_packed_bits_length() {
        let dim = 128;
        let v: Vec<f32> = (0..dim).map(|i| i as f32).collect();
        let refs: Vec<&[f32]> = vec![v.as_slice()];
        let codec = BbqCodec::calibrate(&refs, dim, 3);
        let q = codec.encode(&v);
        let expected_bytes = dim.div_ceil(8);
        assert_eq!(
            q.0.packed_bits().len(),
            expected_bytes,
            "packed bits length should be dim.div_ceil(8)"
        );
    }

    #[test]
    fn encode_odd_dim_packed_bits_length() {
        // 17 dimensions need 3 bytes: two full bytes plus one padded byte.
        let dim = 17;
        let v: Vec<f32> = (0..dim).map(|i| i as f32 - 8.0).collect();
        let refs: Vec<&[f32]> = vec![v.as_slice()];
        let codec = BbqCodec::calibrate(&refs, dim, 3);
        let q = codec.encode(&v);
        assert_eq!(q.0.packed_bits().len(), 3);
    }

    #[test]
    fn hamming_scalar_vs_self_zero() {
        let bits = vec![0b10101010u8, 0b11001100, 0b11110000];
        assert_eq!(hamming_distance(&bits, &bits), 0);
    }

    #[test]
    fn hamming_scalar_known_distance() {
        let a = vec![0xFFu8];
        let b = vec![0x00u8];
        assert_eq!(hamming_distance(&a, &b), 8);
    }

    #[test]
    fn hamming_multi_byte_agreement() {
        // Each byte is compared with its bitwise complement, so all 8 bits differ.
        let nbytes: u8 = 64;
        let a: Vec<u8> = (0..nbytes).collect();
        let b: Vec<u8> = a.iter().map(|&x| !x).collect();
        assert_eq!(hamming_distance(&a, &b), 512);
    }

    #[test]
    fn distance_non_negative_finite() {
        let dim = 32;
        let vecs: Vec<Vec<f32>> = (0..8).map(|i| rand_vec(i, dim)).collect();
        let refs: Vec<&[f32]> = vecs.iter().map(|v| v.as_slice()).collect();
        let codec = BbqCodec::calibrate(&refs, dim, 3);

        // Encode and prepare each vector once instead of once per pair.
        let encoded: Vec<BbqQuantized> = vecs.iter().map(|v| codec.encode(v)).collect();
        let queries: Vec<BbqQuery> = vecs.iter().map(|v| codec.prepare_query(v)).collect();

        for (i, (qi, query)) in encoded.iter().zip(&queries).enumerate() {
            for (j, qj) in encoded.iter().enumerate() {
                let sym = codec.fast_symmetric_distance(qi, qj);
                assert!(
                    sym.is_finite() && sym >= 0.0,
                    "fast_symmetric_distance({i},{j}) = {sym}"
                );

                let asym = codec.exact_asymmetric_distance(query, qj);
                assert!(
                    asym.is_finite() && asym >= 0.0,
                    "exact_asymmetric_distance({i},{j}) = {asym}"
                );
            }
        }
    }

    #[test]
    fn fast_symmetric_self_distance_is_zero() {
        let dim = 32;
        let vecs: Vec<Vec<f32>> = (0..4).map(|i| rand_vec(i, dim)).collect();
        let refs: Vec<&[f32]> = vecs.iter().map(|v| v.as_slice()).collect();
        let codec = BbqCodec::calibrate(&refs, dim, 3);
        for v in &vecs {
            let q = codec.encode(v);
            assert_eq!(codec.fast_symmetric_distance(&q, &q), 0.0);
        }
    }

    #[test]
    fn oversample_is_stored() {
        let codec = BbqCodec::calibrate(&[], 4, 3);
        assert_eq!(codec.oversample, 3);
    }

    #[test]
    fn encode_quant_mode_is_bbq() {
        let dim = 16;
        let v: Vec<f32> = vec![1.0; dim];
        let refs: Vec<&[f32]> = vec![v.as_slice()];
        let codec = BbqCodec::calibrate(&refs, dim, 3);
        let q = codec.encode(&v);
        assert_eq!(q.0.header().quant_mode, QuantMode::Bbq as u16);
    }
}