Skip to main content

nodedb_codec/vector_quant/
bbq.rs

// SPDX-License-Identifier: Apache-2.0

//! BBQ — Better Binary Quantization (Elasticsearch 9.1, mid-2025).
//!
//! Centroid-centered asymmetric 1-bit quantizer with 14 bytes of corrective
//! factors per vector.  Empirically beats raw FP32 NDCG on 9/10 BEIR datasets
//! via oversampling rerank.
//!
//! ## Corrective factor layout (14 bytes → unified header fields)
//!
//! | Bytes | Header field     | Meaning                                        |
//! |-------|------------------|------------------------------------------------|
//! | 4     | `residual_norm`  | ‖v′‖ = ‖v − c‖ (distance from the centroid)    |
//! | 4     | `dot_quantized`  | ⟨v′, sign(v′)⟩ / ‖v′‖ (quantization quality)   |
//! | 4     | `global_scale`   | ⟨v′, c⟩ / ‖c‖ (alignment with the centroid)    |
//! | 2     | `reserved[0..2]` | reserved for future correctives                |
//!
//! where c is the calibration centroid and v′ = v − c.  These 14 bytes are
//! stored alongside the 1-bit code so that rerank never needs the full FP32
//! vector (the current `exact_asymmetric_distance` reads only
//! `residual_norm`).
//!
//! ## Oversampling
//!
//! The codec does not execute rerank itself; it exposes `oversample` so the
//! caller fetches `oversample × top_k` coarse candidates and reruns
//! `exact_asymmetric_distance` on them.  The multiplier is supplied to
//! `BbqCodec::calibrate` (3× is the conventional choice).
26
27use crate::error::CodecError;
28use crate::vector_quant::codec::VectorCodec;
29use crate::vector_quant::codec_envelope;
30use crate::vector_quant::hamming::hamming_distance;
31use crate::vector_quant::layout::{QuantHeader, QuantMode, UnifiedQuantizedVector};
32use serde::{Deserialize, Serialize};
33
// ── BbqQuantized ────────────────────────────────────────────────────────────

36/// Owned quantized BBQ vector.  Wraps a [`UnifiedQuantizedVector`] with
37/// `QuantMode::Bbq`.
38pub struct BbqQuantized(pub UnifiedQuantizedVector);
39
40impl AsRef<UnifiedQuantizedVector> for BbqQuantized {
41    #[inline]
42    fn as_ref(&self) -> &UnifiedQuantizedVector {
43        &self.0
44    }
45}
46
// ── BbqQuery ────────────────────────────────────────────────────────────────

/// Prepared query for BBQ distance computation.
///
/// Built by `BbqCodec::prepare_query`.  The query is kept in full FP32 after
/// centroid subtraction, which is what makes the rerank distance asymmetric
/// (exact query vs. 1-bit stored code).
#[derive(Debug, Clone, PartialEq)]
pub struct BbqQuery {
    /// Query vector after centroid subtraction (FP32).
    pub centered: Vec<f32>,
    /// Sign bits of `centered`, packed MSB-first
    /// (length = `centered.len().div_ceil(8)`).
    pub signs: Vec<u8>,
    /// ‖centered‖₂, computed once so it can be reused across candidates.
    pub query_norm: f32,
    /// ⟨centered, sign(centered)⟩ / `query_norm` — quantization-quality
    /// factor (0.0 when `query_norm` is not strictly positive).
    pub query_dot_quantized: f32,
}
61
// ── BbqCodec ────────────────────────────────────────────────────────────────

64/// BBQ centroid-centered asymmetric 1-bit quantization codec.
65///
66/// Calibration computes the dataset centroid which is subtracted from each
67/// vector before sign quantization.  The resulting 1-bit codes are Hamming-
68/// coarse-comparable; exact rerank uses the 14-byte corrective factors stored
69/// in the unified header.
70#[derive(Serialize, Deserialize, zerompk::ToMessagePack, zerompk::FromMessagePack)]
71pub struct BbqCodec {
72    pub dim: usize,
73    /// Centroid of training data.  BBQ is centroid-asymmetric — all encode
74    /// calls subtract this before quantizing.
75    centroid: Vec<f32>,
76    /// Oversample multiplier for the caller's rerank pass.  Caller fetches
77    /// `oversample × top_k` coarse candidates from the Hamming coarse pass,
78    /// then calls `exact_asymmetric_distance` on each.
79    pub oversample: u8,
80}
81
82impl BbqCodec {
83    /// Calibrate a new [`BbqCodec`] from a set of training vectors.
84    ///
85    /// Computes the centroid as the mean of all training vectors.
86    ///
87    /// # Panics
88    ///
89    /// Does not panic; returns a zero-centroid codec if `vectors` is empty.
90    pub fn calibrate(vectors: &[&[f32]], dim: usize, oversample: u8) -> Self {
91        let mut centroid = vec![0.0f32; dim];
92        if vectors.is_empty() {
93            return Self {
94                dim,
95                centroid,
96                oversample,
97            };
98        }
99        for v in vectors {
100            for (c, &x) in centroid.iter_mut().zip(v.iter()) {
101                *c += x;
102            }
103        }
104        let n = vectors.len() as f32;
105        for c in &mut centroid {
106            *c /= n;
107        }
108        Self {
109            dim,
110            centroid,
111            oversample,
112        }
113    }
114
115    /// On-disk magic for BBQ codec envelopes.
116    pub const ENVELOPE_MAGIC: &'static [u8; codec_envelope::MAGIC_LEN] = b"NDBBQ";
117
118    /// Current on-disk envelope version for BBQ codecs.
119    pub const ENVELOPE_VERSION: u8 = 1;
120
121    /// Serialize this codec to a self-describing byte buffer.
122    pub fn to_bytes(&self) -> Result<Vec<u8>, CodecError> {
123        codec_envelope::encode(Self::ENVELOPE_MAGIC, Self::ENVELOPE_VERSION, self)
124    }
125
126    /// Deserialize a codec from a byte buffer produced by [`Self::to_bytes`].
127    pub fn from_bytes(buf: &[u8]) -> Result<Self, CodecError> {
128        codec_envelope::decode(Self::ENVELOPE_MAGIC, Self::ENVELOPE_VERSION, buf)
129    }
130
131    // ── Internal helpers ─────────────────────────────────────────────────────
132
133    /// Subtract centroid from `v`, storing result in `out`.
134    fn center(&self, v: &[f32], out: &mut Vec<f32>) {
135        out.clear();
136        out.extend(v.iter().zip(self.centroid.iter()).map(|(&x, &c)| x - c));
137    }
138
139    /// Pack signs of `centered` into bytes (1 bit per dim, MSB-first within
140    /// each byte).  Returns byte vector of length `dim.div_ceil(8)`.
141    fn pack_signs(centered: &[f32]) -> Vec<u8> {
142        let nbytes = centered.len().div_ceil(8);
143        let mut bits = vec![0u8; nbytes];
144        for (i, &x) in centered.iter().enumerate() {
145            if x >= 0.0 {
146                bits[i / 8] |= 1 << (7 - (i % 8));
147            }
148        }
149        bits
150    }
151
152    /// Compute ‖v‖₂.
153    fn norm(v: &[f32]) -> f32 {
154        v.iter().map(|&x| x * x).sum::<f32>().sqrt()
155    }
156
157    /// Compute ⟨a, b⟩.
158    fn dot(a: &[f32], b: &[f32]) -> f32 {
159        a.iter().zip(b.iter()).map(|(&x, &y)| x * y).sum()
160    }
161
162    /// Reconstruct approximate FP32 vector from sign bits and `residual_norm`.
163    ///
164    /// Each dimension is approximated as ±residual_norm / √dim, with the sign
165    /// taken from the packed bit.  This is the "high-fidelity asymmetric" path
166    /// used at rerank: the query is exact FP32; the stored vector is
167    /// reconstructed from its 1-bit code and the scalar corrective.
168    fn dequantize(packed: &[u8], residual_norm: f32, dim: usize) -> Vec<f32> {
169        let scale = if dim > 0 {
170            residual_norm / (dim as f32).sqrt()
171        } else {
172            0.0
173        };
174        (0..dim)
175            .map(|i| {
176                let bit = (packed[i / 8] >> (7 - (i % 8))) & 1;
177                if bit != 0 { scale } else { -scale }
178            })
179            .collect()
180    }
181}
182
183impl VectorCodec for BbqCodec {
184    type Quantized = BbqQuantized;
185    type Query = BbqQuery;
186
187    fn encode(&self, v: &[f32]) -> BbqQuantized {
188        // Step 1: center.
189        let mut centered = Vec::with_capacity(self.dim);
190        self.center(v, &mut centered);
191
192        // Step 2: pack sign bits.
193        let packed = Self::pack_signs(&centered);
194
195        // Step 3: corrective factors.
196        //
197        // residual_norm (4 B → header.residual_norm):
198        //   ‖v′‖  where v′ = v − c.  Used in the symmetric distance estimate.
199        let residual_norm = Self::norm(&centered);
200
201        // dot_quantized (4 B → header.dot_quantized):
202        //   ⟨v′, sign(v′)⟩ / ‖v′‖.  Measures how well the sign quantization
203        //   captures the direction of v′.
204        let sign_fp: Vec<f32> = centered
205            .iter()
206            .map(|&x| if x >= 0.0 { 1.0 } else { -1.0 })
207            .collect();
208        let dot_vs = Self::dot(&centered, &sign_fp);
209        let dot_quantized = if residual_norm > 0.0 {
210            dot_vs / residual_norm
211        } else {
212            0.0
213        };
214
215        // global_scale (4 B → header.global_scale):
216        //   ⟨v′, c⟩ / ‖c‖.  Captures how aligned the centered vector is with
217        //   the centroid direction — used as a query-alignment corrective.
218        let centroid_norm = Self::norm(&self.centroid);
219        let dot_vc = Self::dot(&centered, &self.centroid);
220        let query_alignment = if centroid_norm > 0.0 {
221            dot_vc / centroid_norm
222        } else {
223            0.0
224        };
225
226        // reserved[0..2]: 2 bytes reserved for future correctives (zero-filled).
227        let reserved = [0u8; 8];
228        // reserved[0..2] are the 2 reserved corrective bytes; remainder is zero.
229
230        let header = QuantHeader {
231            quant_mode: QuantMode::Bbq as u16,
232            dim: self.dim as u16,
233            global_scale: query_alignment,
234            residual_norm,
235            dot_quantized,
236            outlier_bitmask: 0,
237            reserved,
238        };
239
240        let uqv = UnifiedQuantizedVector::new(header, &packed, &[]).expect(
241            "BBQ encode: UnifiedQuantizedVector construction must succeed with no outliers",
242        );
243        BbqQuantized(uqv)
244    }
245
246    fn prepare_query(&self, q: &[f32]) -> BbqQuery {
247        let mut centered = Vec::with_capacity(self.dim);
248        self.center(q, &mut centered);
249        let signs = Self::pack_signs(&centered);
250        let query_norm = Self::norm(&centered);
251        let sign_fp: Vec<f32> = centered
252            .iter()
253            .map(|&x| if x >= 0.0 { 1.0 } else { -1.0 })
254            .collect();
255        let dot_vs = Self::dot(&centered, &sign_fp);
256        let query_dot_quantized = if query_norm > 0.0 {
257            dot_vs / query_norm
258        } else {
259            0.0
260        };
261        BbqQuery {
262            centered,
263            signs,
264            query_norm,
265            query_dot_quantized,
266        }
267    }
268
269    /// Fast Hamming-based symmetric distance estimate.
270    ///
271    /// Uses the asymmetric corrective distance formula:
272    ///   approx = q_n² + v_n² − 2 · q_n · v_n · dot_estimate
273    /// where `dot_estimate = 1 − 2·hamming/dim` maps the Hamming count to
274    /// a normalised cosine-like similarity on {−1,+1} codes.
275    fn fast_symmetric_distance(&self, q: &BbqQuantized, v: &BbqQuantized) -> f32 {
276        let q_bits = q.0.packed_bits();
277        let v_bits = v.0.packed_bits();
278        let ham = hamming_distance(q_bits, v_bits);
279        let dim = self.dim as f32;
280        let dot_estimate = 1.0 - 2.0 * ham as f32 / dim;
281        let q_n = q.0.header().residual_norm;
282        let v_n = v.0.header().residual_norm;
283        (q_n * q_n + v_n * v_n - 2.0 * q_n * v_n * dot_estimate).max(0.0)
284    }
285
286    /// Exact asymmetric L2 distance using the dequantized stored vector.
287    ///
288    /// The query is exact centered FP32 (`q.centered`).  The stored vector is
289    /// reconstructed from its sign bits and `residual_norm` via
290    /// [`BbqCodec::dequantize`]: each dimension ≈ ±residual_norm / √dim.
291    /// This is the high-fidelity asymmetric path invoked during rerank on the
292    /// `oversample × top_k` candidates returned by the coarse Hamming pass.
293    fn exact_asymmetric_distance(&self, q: &BbqQuery, v: &BbqQuantized) -> f32 {
294        let header = v.0.header();
295        let recon = Self::dequantize(v.0.packed_bits(), header.residual_norm, self.dim);
296        // L2(q.centered, recon)
297        q.centered
298            .iter()
299            .zip(recon.iter())
300            .map(|(&a, &b)| (a - b) * (a - b))
301            .sum::<f32>()
302            .sqrt()
303    }
304}
305
// ── Tests ────────────────────────────────────────────────────────────────────

#[cfg(test)]
mod tests {
    use super::*;

    /// Deterministic pseudo-random vector with components in [-2, 2).
    ///
    /// Advances a 64-bit LCG once per component and keeps the top 24 bits of
    /// the state, which convert to `f32` without rounding.
    fn rand_vec(seed: u64, dim: usize) -> Vec<f32> {
        let step = |s: u64| {
            s.wrapping_mul(6364136223846793005)
                .wrapping_add(1442695040888963407)
        };
        let mut x = step(seed);
        (0..dim)
            .map(|_| {
                x = step(x);
                // Top 24 bits → [0, 1), then scale and shift to [-2, 2).
                let unit = (x >> 40) as f32 / (1u32 << 24) as f32;
                unit * 4.0 - 2.0
            })
            .collect()
    }

    #[test]
    fn rand_vec_spans_both_signs() {
        let v = rand_vec(7, 256);
        assert!(v.iter().all(|x| (-2.0f32..2.0).contains(x)));
        assert!(v.iter().any(|&x| x > 0.0), "expected some positive components");
        assert!(v.iter().any(|&x| x < 0.0), "expected some negative components");
    }

    #[test]
    fn to_bytes_from_bytes_roundtrip() {
        let dim = 32;
        let vecs: Vec<Vec<f32>> = (0..4).map(|i| rand_vec(i, dim)).collect();
        let refs: Vec<&[f32]> = vecs.iter().map(|v| v.as_slice()).collect();
        let codec = BbqCodec::calibrate(&refs, dim, 5);
        let bytes = codec.to_bytes().expect("to_bytes should succeed");
        let restored = BbqCodec::from_bytes(&bytes).expect("from_bytes should succeed");
        assert_eq!(restored.dim, codec.dim);
        assert_eq!(restored.oversample, codec.oversample);
        assert_eq!(restored.centroid.len(), codec.centroid.len());
        for (a, b) in restored.centroid.iter().zip(codec.centroid.iter()) {
            assert!((a - b).abs() < 1e-6, "centroid mismatch: {a} vs {b}");
        }
    }

    #[test]
    fn from_bytes_rejects_bad_magic() {
        let mut bytes = b"WRONG".to_vec();
        bytes.push(1);
        bytes.extend_from_slice(&[0u8; 4]);
        assert!(BbqCodec::from_bytes(&bytes).is_err());
    }

    #[test]
    fn from_bytes_rejects_bad_version() {
        let codec = BbqCodec::calibrate(&[], 4, 3);
        let mut bytes = codec.to_bytes().unwrap();
        // Assumes the envelope places the version byte right after the magic.
        bytes[BbqCodec::ENVELOPE_MAGIC.len()] = 42;
        assert!(BbqCodec::from_bytes(&bytes).is_err());
    }

    #[test]
    fn calibrate_centroid_mean() {
        let dim = 8;
        // Three simple vectors: centroid should be the element-wise mean.
        let a = vec![1.0f32; dim];
        let b = vec![3.0f32; dim];
        let c = vec![2.0f32; dim];
        let refs: Vec<&[f32]> = vec![&a, &b, &c];
        let codec = BbqCodec::calibrate(&refs, dim, 3);
        for &x in &codec.centroid {
            assert!((x - 2.0).abs() < 1e-5, "expected centroid 2.0, got {x}");
        }
    }

    #[test]
    fn calibrate_empty_gives_zero_centroid() {
        let codec = BbqCodec::calibrate(&[], 4, 3);
        assert_eq!(codec.centroid.len(), 4);
        assert!(codec.centroid.iter().all(|&x| x == 0.0));
    }

    #[test]
    fn calibrate_preserves_oversample() {
        assert_eq!(BbqCodec::calibrate(&[], 4, 3).oversample, 3);
        assert_eq!(BbqCodec::calibrate(&[], 4, 7).oversample, 7);
    }

    #[test]
    fn pack_signs_is_msb_first_with_zero_padding() {
        // Non-negative (including 0.0) → 1, negative → 0; 9 components → 2 bytes.
        let centered = [1.0f32, -1.0, 0.0, -0.5, 2.0, -3.0, 4.0, -5.0, 1.0];
        assert_eq!(
            BbqCodec::pack_signs(&centered),
            vec![0b1010_1010, 0b1000_0000]
        );
    }

    #[test]
    fn dequantize_scales_by_norm_over_sqrt_dim() {
        // dim 4, residual_norm 2.0 → scale = 2 / √4 = 1.0; bits 1010.
        let recon = BbqCodec::dequantize(&[0b1010_0000], 2.0, 4);
        assert_eq!(recon, vec![1.0, -1.0, 1.0, -1.0]);
        assert!(BbqCodec::dequantize(&[], 1.0, 0).is_empty());
    }

    #[test]
    fn encode_packed_bits_length() {
        let dim = 128;
        let v: Vec<f32> = (0..dim).map(|i| i as f32).collect();
        let refs: Vec<&[f32]> = vec![v.as_slice()];
        let codec = BbqCodec::calibrate(&refs, dim, 3);
        let q = codec.encode(&v);
        let expected_bytes = dim.div_ceil(8);
        assert_eq!(
            q.0.packed_bits().len(),
            expected_bytes,
            "packed bits length should be dim.div_ceil(8)"
        );
    }

    #[test]
    fn encode_odd_dim_packed_bits_length() {
        // dim = 17 → ceil(17/8) = 3 bytes.
        let dim = 17;
        let v: Vec<f32> = (0..dim).map(|i| i as f32 - 8.0).collect();
        let refs: Vec<&[f32]> = vec![v.as_slice()];
        let codec = BbqCodec::calibrate(&refs, dim, 3);
        let q = codec.encode(&v);
        assert_eq!(q.0.packed_bits().len(), 3);
    }

    #[test]
    fn encode_centroid_has_zero_correctives() {
        // Calibrating on a single vector makes it the centroid, so encoding it
        // yields an all-zero residual.
        let dim = 8;
        let v = rand_vec(3, dim);
        let refs: Vec<&[f32]> = vec![v.as_slice()];
        let codec = BbqCodec::calibrate(&refs, dim, 3);
        let q = codec.encode(&v);
        let header = q.0.header();
        assert_eq!(header.residual_norm, 0.0);
        assert_eq!(header.dot_quantized, 0.0);
        assert_eq!(header.global_scale, 0.0);
    }

    #[test]
    fn hamming_scalar_vs_self_zero() {
        let bits = vec![0b10101010u8, 0b11001100, 0b11110000];
        assert_eq!(hamming_distance(&bits, &bits), 0);
    }

    #[test]
    fn hamming_scalar_known_distance() {
        // 0xFF ^ 0x00 = 8 bits set.
        let a = vec![0xFFu8];
        let b = vec![0x00u8];
        assert_eq!(hamming_distance(&a, &b), 8);
    }

    #[test]
    fn hamming_multi_byte_agreement() {
        let nbytes = 64u8;
        let a: Vec<u8> = (0..nbytes).collect();
        let b: Vec<u8> = a.iter().map(|&x| !x).collect();
        // Every byte is fully flipped → all 64 × 8 = 512 bits differ.
        assert_eq!(hamming_distance(&a, &b), 512);
    }

    #[test]
    fn fast_symmetric_distance_to_self_is_zero() {
        let dim = 32;
        let vecs: Vec<Vec<f32>> = (0..4).map(|i| rand_vec(i, dim)).collect();
        let refs: Vec<&[f32]> = vecs.iter().map(|v| v.as_slice()).collect();
        let codec = BbqCodec::calibrate(&refs, dim, 3);
        for v in &vecs {
            let q = codec.encode(v);
            let d = codec.fast_symmetric_distance(&q, &q);
            assert!(d.abs() < 1e-3, "self-distance should be ~0, got {d}");
        }
    }

    #[test]
    fn distance_non_negative_finite() {
        let dim = 32;
        let vecs: Vec<Vec<f32>> = (0..8).map(|i| rand_vec(i, dim)).collect();
        let refs: Vec<&[f32]> = vecs.iter().map(|v| v.as_slice()).collect();
        let codec = BbqCodec::calibrate(&refs, dim, 3);

        for i in 0..vecs.len() {
            for j in 0..vecs.len() {
                let qi = codec.encode(&vecs[i]);
                let qj = codec.encode(&vecs[j]);
                let sym = codec.fast_symmetric_distance(&qi, &qj);
                assert!(
                    sym.is_finite() && sym >= 0.0,
                    "fast_symmetric_distance({i},{j}) = {sym}"
                );

                let query = codec.prepare_query(&vecs[i]);
                let asym = codec.exact_asymmetric_distance(&query, &qj);
                assert!(
                    asym.is_finite() && asym >= 0.0,
                    "exact_asymmetric_distance({i},{j}) = {asym}"
                );
            }
        }
    }

    #[test]
    fn encode_quant_mode_is_bbq() {
        let dim = 16;
        let v: Vec<f32> = vec![1.0; dim];
        let refs: Vec<&[f32]> = vec![v.as_slice()];
        let codec = BbqCodec::calibrate(&refs, dim, 3);
        let q = codec.encode(&v);
        assert_eq!(q.0.header().quant_mode, QuantMode::Bbq as u16);
    }
}