// nodedb_codec/vector_quant/bbq.rs
use crate::error::CodecError;
28use crate::vector_quant::codec::VectorCodec;
29use crate::vector_quant::codec_envelope;
30use crate::vector_quant::hamming::hamming_distance;
31use crate::vector_quant::layout::{QuantHeader, QuantMode, UnifiedQuantizedVector};
32use serde::{Deserialize, Serialize};
33
/// A vector encoded by [`BbqCodec`]: a thin newtype over the
/// [`UnifiedQuantizedVector`] that carries the header and packed sign bits.
pub struct BbqQuantized(pub UnifiedQuantizedVector);
39
40impl AsRef<UnifiedQuantizedVector> for BbqQuantized {
41 #[inline]
42 fn as_ref(&self) -> &UnifiedQuantizedVector {
43 &self.0
44 }
45}
46
/// Query-side state built once by `prepare_query` and reused for every
/// distance evaluation against encoded vectors.
pub struct BbqQuery {
    /// The query with the codec centroid subtracted component-wise.
    pub centered: Vec<f32>,
    /// Sign bits of `centered`, packed MSB-first (bit set when the component is `>= 0.0`).
    pub signs: Vec<u8>,
    /// Euclidean norm of `centered`.
    pub query_norm: f32,
    /// Dot product of `centered` with its own ±1 sign vector, divided by
    /// `query_norm`; `0.0` when `query_norm` is not positive.
    pub query_dot_quantized: f32,
}
61
/// Calibrated state of the BBQ sign-bit codec.
///
/// Vectors are centered on `centroid` before their signs are packed, so the
/// same codec instance must be used for encoding and for querying.
#[derive(Serialize, Deserialize, zerompk::ToMessagePack, zerompk::FromMessagePack)]
pub struct BbqCodec {
    /// Number of components per vector, as passed to `calibrate`.
    pub dim: usize,
    /// Component-wise mean of the calibration vectors (all zeros when the
    /// calibration set was empty).
    centroid: Vec<f32>,
    /// Caller-supplied oversampling setting. It is stored and serialized but
    /// not read by the codec methods visible here; presumably consumed by
    /// callers during candidate re-ranking — TODO confirm.
    pub oversample: u8,
}
81
82impl BbqCodec {
83 pub fn calibrate(vectors: &[&[f32]], dim: usize, oversample: u8) -> Self {
91 let mut centroid = vec![0.0f32; dim];
92 if vectors.is_empty() {
93 return Self {
94 dim,
95 centroid,
96 oversample,
97 };
98 }
99 for v in vectors {
100 for (c, &x) in centroid.iter_mut().zip(v.iter()) {
101 *c += x;
102 }
103 }
104 let n = vectors.len() as f32;
105 for c in &mut centroid {
106 *c /= n;
107 }
108 Self {
109 dim,
110 centroid,
111 oversample,
112 }
113 }
114
/// Magic bytes passed to `codec_envelope` when serializing and deserializing
/// this codec; `from_bytes` is expected to reject buffers that start with a
/// different magic.
pub const ENVELOPE_MAGIC: &'static [u8; codec_envelope::MAGIC_LEN] = b"NDBBQ";
117
/// Envelope format version passed to `codec_envelope` by `to_bytes` and
/// `from_bytes`.
pub const ENVELOPE_VERSION: u8 = 1;
120
/// Serializes the codec by handing it to `codec_envelope::encode` together
/// with [`Self::ENVELOPE_MAGIC`] and [`Self::ENVELOPE_VERSION`].
///
/// # Errors
///
/// Returns whatever [`CodecError`] `codec_envelope::encode` reports.
pub fn to_bytes(&self) -> Result<Vec<u8>, CodecError> {
    codec_envelope::encode(Self::ENVELOPE_MAGIC, Self::ENVELOPE_VERSION, self)
}
125
/// Restores a codec from a buffer produced by `to_bytes`, delegating to
/// `codec_envelope::decode` with [`Self::ENVELOPE_MAGIC`] and
/// [`Self::ENVELOPE_VERSION`].
///
/// # Errors
///
/// Returns whatever [`CodecError`] `codec_envelope::decode` reports; the
/// unit tests in this file expect an error for a wrong magic prefix and for
/// a changed version byte.
pub fn from_bytes(buf: &[u8]) -> Result<Self, CodecError> {
    codec_envelope::decode(Self::ENVELOPE_MAGIC, Self::ENVELOPE_VERSION, buf)
}
130
131 fn center(&self, v: &[f32], out: &mut Vec<f32>) {
135 out.clear();
136 out.extend(v.iter().zip(self.centroid.iter()).map(|(&x, &c)| x - c));
137 }
138
/// Packs one sign bit per component, MSB-first, eight components per byte.
///
/// A bit is set when the component compares `>= 0.0`, so `-0.0` sets the bit
/// and NaN leaves it clear. Unused low bits of the final byte stay zero.
fn pack_signs(centered: &[f32]) -> Vec<u8> {
    centered
        .chunks(8)
        .map(|group| {
            group
                .iter()
                .enumerate()
                .fold(0u8, |byte, (pos, &value)| {
                    if value >= 0.0 {
                        // `pos` is at most 7, so the mask walks from 0x80 down to 0x01.
                        byte | (0x80u8 >> pos)
                    } else {
                        byte
                    }
                })
        })
        .collect()
}
151
/// Euclidean (L2) norm of `v`.
fn norm(v: &[f32]) -> f32 {
    let sum_of_squares: f32 = v.iter().map(|component| component * component).sum();
    sum_of_squares.sqrt()
}
156
/// Dot product of `a` and `b`; components past the shorter slice are ignored.
fn dot(a: &[f32], b: &[f32]) -> f32 {
    a.iter().zip(b).map(|(lhs, rhs)| lhs * rhs).sum::<f32>()
}
161
/// Reconstructs a `dim`-component vector from MSB-first packed sign bits.
///
/// Every component has magnitude `residual_norm / sqrt(dim)`; it is positive
/// where the corresponding bit is set and negative where it is clear.
///
/// The result always has exactly `dim` components. If `packed` holds fewer
/// than `dim` bits (for example because the vector was encoded from a shorter
/// input), the missing bits are read as clear instead of panicking on an
/// out-of-bounds index. This matches the zero padding `pack_signs` leaves in
/// the last byte.
fn dequantize(packed: &[u8], residual_norm: f32, dim: usize) -> Vec<f32> {
    if dim == 0 {
        return Vec::new();
    }
    let scale = residual_norm / (dim as f32).sqrt();
    (0..dim)
        .map(|i| {
            // Checked access: a truncated buffer must not take the caller down.
            let byte = packed.get(i / 8).copied().unwrap_or(0);
            if (byte >> (7 - (i % 8))) & 1 != 0 {
                scale
            } else {
                -scale
            }
        })
        .collect()
}
181}
182
183impl VectorCodec for BbqCodec {
184 type Quantized = BbqQuantized;
185 type Query = BbqQuery;
186
187 fn encode(&self, v: &[f32]) -> BbqQuantized {
188 let mut centered = Vec::with_capacity(self.dim);
190 self.center(v, &mut centered);
191
192 let packed = Self::pack_signs(¢ered);
194
195 let residual_norm = Self::norm(¢ered);
200
201 let sign_fp: Vec<f32> = centered
205 .iter()
206 .map(|&x| if x >= 0.0 { 1.0 } else { -1.0 })
207 .collect();
208 let dot_vs = Self::dot(¢ered, &sign_fp);
209 let dot_quantized = if residual_norm > 0.0 {
210 dot_vs / residual_norm
211 } else {
212 0.0
213 };
214
215 let centroid_norm = Self::norm(&self.centroid);
219 let dot_vc = Self::dot(¢ered, &self.centroid);
220 let query_alignment = if centroid_norm > 0.0 {
221 dot_vc / centroid_norm
222 } else {
223 0.0
224 };
225
226 let reserved = [0u8; 8];
228 let header = QuantHeader {
231 quant_mode: QuantMode::Bbq as u16,
232 dim: self.dim as u16,
233 global_scale: query_alignment,
234 residual_norm,
235 dot_quantized,
236 outlier_bitmask: 0,
237 reserved,
238 };
239
240 let uqv = UnifiedQuantizedVector::new(header, &packed, &[]).expect(
241 "BBQ encode: UnifiedQuantizedVector construction must succeed with no outliers",
242 );
243 BbqQuantized(uqv)
244 }
245
246 fn prepare_query(&self, q: &[f32]) -> BbqQuery {
247 let mut centered = Vec::with_capacity(self.dim);
248 self.center(q, &mut centered);
249 let signs = Self::pack_signs(¢ered);
250 let query_norm = Self::norm(¢ered);
251 let sign_fp: Vec<f32> = centered
252 .iter()
253 .map(|&x| if x >= 0.0 { 1.0 } else { -1.0 })
254 .collect();
255 let dot_vs = Self::dot(¢ered, &sign_fp);
256 let query_dot_quantized = if query_norm > 0.0 {
257 dot_vs / query_norm
258 } else {
259 0.0
260 };
261 BbqQuery {
262 centered,
263 signs,
264 query_norm,
265 query_dot_quantized,
266 }
267 }
268
269 fn fast_symmetric_distance(&self, q: &BbqQuantized, v: &BbqQuantized) -> f32 {
276 let q_bits = q.0.packed_bits();
277 let v_bits = v.0.packed_bits();
278 let ham = hamming_distance(q_bits, v_bits);
279 let dim = self.dim as f32;
280 let dot_estimate = 1.0 - 2.0 * ham as f32 / dim;
281 let q_n = q.0.header().residual_norm;
282 let v_n = v.0.header().residual_norm;
283 (q_n * q_n + v_n * v_n - 2.0 * q_n * v_n * dot_estimate).max(0.0)
284 }
285
286 fn exact_asymmetric_distance(&self, q: &BbqQuery, v: &BbqQuantized) -> f32 {
294 let header = v.0.header();
295 let recon = Self::dequantize(v.0.packed_bits(), header.residual_norm, self.dim);
296 q.centered
298 .iter()
299 .zip(recon.iter())
300 .map(|(&a, &b)| (a - b) * (a - b))
301 .sum::<f32>()
302 .sqrt()
303 }
304}
305
306#[cfg(test)]
309mod tests {
310 use super::*;
311
/// Deterministic pseudo-random test vector with components in `[-2.0, 2.0]`.
///
/// Uses a 64-bit LCG seeded with `seed` and maps the top 32 bits of each
/// state onto the output range. The previous `>> 33` kept only 31 bits, so
/// after dividing by `u32::MAX` every component landed in `[-2.0, 0.0]`.
fn rand_vec(seed: u64, dim: usize) -> Vec<f32> {
    const MULTIPLIER: u64 = 6364136223846793005;
    const INCREMENT: u64 = 1442695040888963407;
    let step = |state: u64| state.wrapping_mul(MULTIPLIER).wrapping_add(INCREMENT);

    let mut state = step(seed);
    (0..dim)
        .map(|_| {
            state = step(state);
            let high = (state >> 32) as u32;
            (high as f32) / (u32::MAX as f32) * 4.0 - 2.0
        })
        .collect()
}
327
/// A calibrated codec survives `to_bytes` -> `from_bytes` with every field intact.
#[test]
fn to_bytes_from_bytes_roundtrip() {
    const DIM: usize = 32;
    let samples: Vec<Vec<f32>> = (0..4).map(|seed| rand_vec(seed, DIM)).collect();
    let views: Vec<&[f32]> = samples.iter().map(Vec::as_slice).collect();
    let original = BbqCodec::calibrate(&views, DIM, 5);

    let encoded = original.to_bytes().expect("to_bytes should succeed");
    let decoded = BbqCodec::from_bytes(&encoded).expect("from_bytes should succeed");

    assert_eq!(decoded.dim, original.dim);
    assert_eq!(decoded.oversample, original.oversample);
    assert_eq!(decoded.centroid.len(), original.centroid.len());
    for (a, b) in decoded.centroid.iter().zip(&original.centroid) {
        assert!((a - b).abs() < 1e-6, "centroid mismatch: {a} vs {b}");
    }
}
343
/// A buffer that starts with the wrong magic must not decode.
#[test]
fn from_bytes_rejects_bad_magic() {
    // "WRONG" magic, a version byte of 1, then four zero bytes.
    let forged = [b'W', b'R', b'O', b'N', b'G', 1, 0, 0, 0, 0];
    let outcome = BbqCodec::from_bytes(&forged);
    assert!(outcome.is_err());
}
351
/// Overwriting the byte right after the 5-byte magic with an unknown version
/// must make decoding fail.
#[test]
fn from_bytes_rejects_bad_version() {
    let mut tampered = BbqCodec::calibrate(&[], 4, 3).to_bytes().unwrap();
    tampered[5] = 42;
    let outcome = BbqCodec::from_bytes(&tampered);
    assert!(outcome.is_err());
}
359
/// Three constant vectors (1, 3, 2) must average to a centroid of 2 everywhere.
#[test]
fn calibrate_centroid_mean() {
    const DIM: usize = 8;
    let (a, b, c) = (vec![1.0f32; DIM], vec![3.0f32; DIM], vec![2.0f32; DIM]);
    let samples = [a.as_slice(), b.as_slice(), c.as_slice()];
    let codec = BbqCodec::calibrate(&samples, DIM, 3);
    for x in codec.centroid.iter().copied() {
        assert!((x - 2.0).abs() < 1e-5, "expected centroid 2.0, got {x}");
    }
}
373
/// Calibrating on no vectors leaves every centroid component at zero.
#[test]
fn calibrate_empty_gives_zero_centroid() {
    let codec = BbqCodec::calibrate(&[], 4, 3);
    let has_nonzero = codec.centroid.iter().any(|&component| component != 0.0);
    assert!(!has_nonzero);
}
379
/// A 128-dimensional vector packs into exactly 128 / 8 bytes.
#[test]
fn encode_packed_bits_length() {
    let dim = 128usize;
    let ramp: Vec<f32> = (0..dim).map(|i| i as f32).collect();
    let codec = BbqCodec::calibrate(&[ramp.as_slice()], dim, 3);
    let encoded = codec.encode(&ramp);
    assert_eq!(
        encoded.0.packed_bits().len(),
        dim.div_ceil(8),
        "packed bits length should be dim.div_ceil(8)"
    );
}
394
/// 17 components do not fill whole bytes and must round up to 3 packed bytes.
#[test]
fn encode_odd_dim_packed_bits_length() {
    let dim = 17usize;
    let ramp: Vec<f32> = (0..dim).map(|i| i as f32 - 8.0).collect();
    let codec = BbqCodec::calibrate(&[ramp.as_slice()], dim, 3);
    let packed_len = codec.encode(&ramp).0.packed_bits().len();
    assert_eq!(packed_len, 3);
}
405
/// Any bit pattern is at Hamming distance zero from itself.
#[test]
fn hamming_scalar_vs_self_zero() {
    let pattern: Vec<u8> = vec![0xAA, 0xCC, 0xF0];
    let distance = hamming_distance(&pattern, &pattern);
    assert_eq!(distance, 0);
}
411
/// An all-ones byte differs from an all-zeros byte in all eight bits.
#[test]
fn hamming_scalar_known_distance() {
    let all_ones: Vec<u8> = vec![u8::MAX];
    let all_zeros: Vec<u8> = vec![u8::MIN];
    let distance = hamming_distance(&all_ones, &all_zeros);
    assert_eq!(distance, 8);
}
419
/// 64 bytes against their bitwise complements differ in all 64 * 8 bits.
#[test]
fn hamming_multi_byte_agreement() {
    let bytes: Vec<u8> = (0u8..64).collect();
    let complements: Vec<u8> = bytes.iter().map(|byte| byte ^ 0xFF).collect();
    let distance = hamming_distance(&bytes, &complements);
    assert_eq!(distance, 512);
}
428
/// Both distance functions must return a finite, non-negative value for
/// every ordered pair of calibration vectors (including a vector with itself).
#[test]
fn distance_non_negative_finite() {
    let dim = 32;
    let vecs: Vec<Vec<f32>> = (0..8).map(|seed| rand_vec(seed, dim)).collect();
    let refs: Vec<&[f32]> = vecs.iter().map(Vec::as_slice).collect();
    let codec = BbqCodec::calibrate(&refs, dim, 3);

    // Encoding is deterministic, so encode each vector once up front.
    let encoded: Vec<BbqQuantized> = vecs.iter().map(|v| codec.encode(v)).collect();

    for (i, query_vec) in vecs.iter().enumerate() {
        let query = codec.prepare_query(query_vec);
        for (j, target) in encoded.iter().enumerate() {
            let sym = codec.fast_symmetric_distance(&encoded[i], target);
            assert!(
                sym.is_finite() && sym >= 0.0,
                "fast_symmetric_distance({i},{j}) = {sym}"
            );

            let asym = codec.exact_asymmetric_distance(&query, target);
            assert!(
                asym.is_finite() && asym >= 0.0,
                "exact_asymmetric_distance({i},{j}) = {asym}"
            );
        }
    }
}
455
/// The oversample value handed to `calibrate` is stored unchanged.
#[test]
fn oversample_default_is_three() {
    let stored = BbqCodec::calibrate(&[], 4, 3).oversample;
    assert_eq!(stored, 3);
}
461
/// Encoded vectors are tagged with the BBQ quantization mode in their header.
#[test]
fn encode_quant_mode_is_bbq() {
    let ones = [1.0f32; 16];
    let codec = BbqCodec::calibrate(&[ones.as_slice()], ones.len(), 3);
    let encoded = codec.encode(&ones);
    assert_eq!(encoded.0.header().quant_mode, QuantMode::Bbq as u16);
}
471}