Skip to main content

nodedb_codec/vector_quant/
bbq.rs

1// SPDX-License-Identifier: Apache-2.0
2
3//! BBQ — Better Binary Quantization (Elasticsearch 9.1, mid-2025).
4//!
5//! Centroid-centered asymmetric 1-bit quantizer with 14-byte corrective
6//! factors per vector.  Empirically beats raw FP32 NDCG on 9/10 BEIR datasets
7//! via oversampling rerank.
8//!
9//! ## Corrective factor layout (14 bytes → unified header fields)
10//!
11//! | Bytes | Header field      | Meaning                                          |
12//! |-------|-------------------|--------------------------------------------------|
13//! | 4     | `residual_norm`   | ‖v − c‖ (centroid distance)                     |
14//! | 4     | `dot_quantized`   | ⟨v′, sign(v′)⟩ / ‖v′‖  (quantization quality)  |
15//! | 4     | `global_scale`    | ⟨v′, c⟩ / ‖c‖            (query alignment)      |
16//! | 2     | `reserved[0..2]`  | reserved for future correctives                  |
17//!
18//! where v′ = v − c.  Together these 14 bytes enable the asymmetric corrective
19//! distance used at rerank without storing full FP32 per vector.
20//!
21//! ## Oversampling
22//!
23//! The codec does not execute rerank itself; it stores the `oversample` value
24//! given to `calibrate` so the caller can fetch `oversample × top_k` coarse
25//! candidates and rerun `exact_asymmetric_distance` on them.  3× is customary.
26
27use crate::vector_quant::codec::VectorCodec;
28use crate::vector_quant::hamming::hamming_distance;
29use crate::vector_quant::layout::{QuantHeader, QuantMode, UnifiedQuantizedVector};
30
31// ── BbqQuantized ────────────────────────────────────────────────────────────
32
33/// Owned quantized BBQ vector.  Wraps a [`UnifiedQuantizedVector`] with
34/// `QuantMode::Bbq`.
35pub struct BbqQuantized(pub UnifiedQuantizedVector);
36
37impl AsRef<UnifiedQuantizedVector> for BbqQuantized {
38    #[inline]
39    fn as_ref(&self) -> &UnifiedQuantizedVector {
40        &self.0
41    }
42}
43
44// ── BbqQuery ────────────────────────────────────────────────────────────────
45
/// Prepared query for BBQ distance computation.
///
/// Built once per query by `prepare_query` so the centering, sign packing and
/// norm computation are not repeated for every candidate.
#[derive(Debug, Clone, PartialEq)]
pub struct BbqQuery {
    /// Query vector after centroid subtraction (FP32, normally length = dim).
    pub centered: Vec<f32>,
    /// Sign-packed bits of `centered` (length = `centered.len().div_ceil(8)`),
    /// one bit per dimension, MSB-first within each byte.  A set bit means the
    /// component is `>= 0.0`.
    pub signs: Vec<u8>,
    /// ‖centered‖₂ — the L2 norm of the centered query.
    pub query_norm: f32,
    /// ⟨centered, sign(centered)⟩ / query_norm — quantization quality factor.
    /// Set to `0.0` when `query_norm` is zero.
    pub query_dot_quantized: f32,
}
58
59// ── BbqCodec ────────────────────────────────────────────────────────────────
60
/// BBQ centroid-centered asymmetric 1-bit quantization codec.
///
/// Calibration computes the dataset centroid, which is subtracted from each
/// vector before sign quantization.  The resulting 1-bit codes can be compared
/// coarsely by Hamming distance; exact rerank uses the corrective factors
/// stored in the unified header.
#[derive(Debug, Clone)]
pub struct BbqCodec {
    /// Dimensionality the codec was calibrated for.
    pub dim: usize,
    /// Centroid of the training data (length = `dim`).  BBQ is
    /// centroid-asymmetric — every encode and query-preparation call
    /// subtracts this before quantizing.
    centroid: Vec<f32>,
    /// Oversample multiplier for the caller's rerank pass.  The caller fetches
    /// `oversample × top_k` coarse candidates from the Hamming coarse pass,
    /// then calls `exact_asymmetric_distance` on each.  The codec only stores
    /// this value; it never reads it.
    pub oversample: u8,
}
77
78impl BbqCodec {
79    /// Calibrate a new [`BbqCodec`] from a set of training vectors.
80    ///
81    /// Computes the centroid as the mean of all training vectors.
82    ///
83    /// # Panics
84    ///
85    /// Does not panic; returns a zero-centroid codec if `vectors` is empty.
86    pub fn calibrate(vectors: &[&[f32]], dim: usize, oversample: u8) -> Self {
87        let mut centroid = vec![0.0f32; dim];
88        if vectors.is_empty() {
89            return Self {
90                dim,
91                centroid,
92                oversample,
93            };
94        }
95        for v in vectors {
96            for (c, &x) in centroid.iter_mut().zip(v.iter()) {
97                *c += x;
98            }
99        }
100        let n = vectors.len() as f32;
101        for c in &mut centroid {
102            *c /= n;
103        }
104        Self {
105            dim,
106            centroid,
107            oversample,
108        }
109    }
110
111    // ── Internal helpers ─────────────────────────────────────────────────────
112
113    /// Subtract centroid from `v`, storing result in `out`.
114    fn center(&self, v: &[f32], out: &mut Vec<f32>) {
115        out.clear();
116        out.extend(v.iter().zip(self.centroid.iter()).map(|(&x, &c)| x - c));
117    }
118
119    /// Pack signs of `centered` into bytes (1 bit per dim, MSB-first within
120    /// each byte).  Returns byte vector of length `dim.div_ceil(8)`.
121    fn pack_signs(centered: &[f32]) -> Vec<u8> {
122        let nbytes = centered.len().div_ceil(8);
123        let mut bits = vec![0u8; nbytes];
124        for (i, &x) in centered.iter().enumerate() {
125            if x >= 0.0 {
126                bits[i / 8] |= 1 << (7 - (i % 8));
127            }
128        }
129        bits
130    }
131
132    /// Compute ‖v‖₂.
133    fn norm(v: &[f32]) -> f32 {
134        v.iter().map(|&x| x * x).sum::<f32>().sqrt()
135    }
136
137    /// Compute ⟨a, b⟩.
138    fn dot(a: &[f32], b: &[f32]) -> f32 {
139        a.iter().zip(b.iter()).map(|(&x, &y)| x * y).sum()
140    }
141
142    /// Reconstruct approximate FP32 vector from sign bits and `residual_norm`.
143    ///
144    /// Each dimension is approximated as ±residual_norm / √dim, with the sign
145    /// taken from the packed bit.  This is the "high-fidelity asymmetric" path
146    /// used at rerank: the query is exact FP32; the stored vector is
147    /// reconstructed from its 1-bit code and the scalar corrective.
148    fn dequantize(packed: &[u8], residual_norm: f32, dim: usize) -> Vec<f32> {
149        let scale = if dim > 0 {
150            residual_norm / (dim as f32).sqrt()
151        } else {
152            0.0
153        };
154        (0..dim)
155            .map(|i| {
156                let bit = (packed[i / 8] >> (7 - (i % 8))) & 1;
157                if bit != 0 { scale } else { -scale }
158            })
159            .collect()
160    }
161}
162
163impl VectorCodec for BbqCodec {
164    type Quantized = BbqQuantized;
165    type Query = BbqQuery;
166
167    fn encode(&self, v: &[f32]) -> BbqQuantized {
168        // Step 1: center.
169        let mut centered = Vec::with_capacity(self.dim);
170        self.center(v, &mut centered);
171
172        // Step 2: pack sign bits.
173        let packed = Self::pack_signs(&centered);
174
175        // Step 3: corrective factors.
176        //
177        // residual_norm (4 B → header.residual_norm):
178        //   ‖v′‖  where v′ = v − c.  Used in the symmetric distance estimate.
179        let residual_norm = Self::norm(&centered);
180
181        // dot_quantized (4 B → header.dot_quantized):
182        //   ⟨v′, sign(v′)⟩ / ‖v′‖.  Measures how well the sign quantization
183        //   captures the direction of v′.
184        let sign_fp: Vec<f32> = centered
185            .iter()
186            .map(|&x| if x >= 0.0 { 1.0 } else { -1.0 })
187            .collect();
188        let dot_vs = Self::dot(&centered, &sign_fp);
189        let dot_quantized = if residual_norm > 0.0 {
190            dot_vs / residual_norm
191        } else {
192            0.0
193        };
194
195        // global_scale (4 B → header.global_scale):
196        //   ⟨v′, c⟩ / ‖c‖.  Captures how aligned the centered vector is with
197        //   the centroid direction — used as a query-alignment corrective.
198        let centroid_norm = Self::norm(&self.centroid);
199        let dot_vc = Self::dot(&centered, &self.centroid);
200        let query_alignment = if centroid_norm > 0.0 {
201            dot_vc / centroid_norm
202        } else {
203            0.0
204        };
205
206        // reserved[0..2]: 2 bytes reserved for future correctives (zero-filled).
207        let reserved = [0u8; 8];
208        // reserved[0..2] are the 2 reserved corrective bytes; remainder is zero.
209
210        let header = QuantHeader {
211            quant_mode: QuantMode::Bbq as u16,
212            dim: self.dim as u16,
213            global_scale: query_alignment,
214            residual_norm,
215            dot_quantized,
216            outlier_bitmask: 0,
217            reserved,
218        };
219
220        let uqv = UnifiedQuantizedVector::new(header, &packed, &[]).expect(
221            "BBQ encode: UnifiedQuantizedVector construction must succeed with no outliers",
222        );
223        BbqQuantized(uqv)
224    }
225
226    fn prepare_query(&self, q: &[f32]) -> BbqQuery {
227        let mut centered = Vec::with_capacity(self.dim);
228        self.center(q, &mut centered);
229        let signs = Self::pack_signs(&centered);
230        let query_norm = Self::norm(&centered);
231        let sign_fp: Vec<f32> = centered
232            .iter()
233            .map(|&x| if x >= 0.0 { 1.0 } else { -1.0 })
234            .collect();
235        let dot_vs = Self::dot(&centered, &sign_fp);
236        let query_dot_quantized = if query_norm > 0.0 {
237            dot_vs / query_norm
238        } else {
239            0.0
240        };
241        BbqQuery {
242            centered,
243            signs,
244            query_norm,
245            query_dot_quantized,
246        }
247    }
248
249    /// Fast Hamming-based symmetric distance estimate.
250    ///
251    /// Uses the asymmetric corrective distance formula:
252    ///   approx = q_n² + v_n² − 2 · q_n · v_n · dot_estimate
253    /// where `dot_estimate = 1 − 2·hamming/dim` maps the Hamming count to
254    /// a normalised cosine-like similarity on {−1,+1} codes.
255    fn fast_symmetric_distance(&self, q: &BbqQuantized, v: &BbqQuantized) -> f32 {
256        let q_bits = q.0.packed_bits();
257        let v_bits = v.0.packed_bits();
258        let ham = hamming_distance(q_bits, v_bits);
259        let dim = self.dim as f32;
260        let dot_estimate = 1.0 - 2.0 * ham as f32 / dim;
261        let q_n = q.0.header().residual_norm;
262        let v_n = v.0.header().residual_norm;
263        (q_n * q_n + v_n * v_n - 2.0 * q_n * v_n * dot_estimate).max(0.0)
264    }
265
266    /// Exact asymmetric L2 distance using the dequantized stored vector.
267    ///
268    /// The query is exact centered FP32 (`q.centered`).  The stored vector is
269    /// reconstructed from its sign bits and `residual_norm` via
270    /// [`BbqCodec::dequantize`]: each dimension ≈ ±residual_norm / √dim.
271    /// This is the high-fidelity asymmetric path invoked during rerank on the
272    /// `oversample × top_k` candidates returned by the coarse Hamming pass.
273    fn exact_asymmetric_distance(&self, q: &BbqQuery, v: &BbqQuantized) -> f32 {
274        let header = v.0.header();
275        let recon = Self::dequantize(v.0.packed_bits(), header.residual_norm, self.dim);
276        // L2(q.centered, recon)
277        q.centered
278            .iter()
279            .zip(recon.iter())
280            .map(|(&a, &b)| (a - b) * (a - b))
281            .sum::<f32>()
282            .sqrt()
283    }
284}
285
286// ── Tests ────────────────────────────────────────────────────────────────────
287
#[cfg(test)]
mod tests {
    use super::*;

    /// Deterministic pseudo-random vector with components in [-2, 2].
    fn rand_vec(seed: u64, dim: usize) -> Vec<f32> {
        // Simple deterministic LCG.
        let mut x = seed
            .wrapping_mul(6364136223846793005)
            .wrapping_add(1442695040888963407);
        (0..dim)
            .map(|_| {
                x = x
                    .wrapping_mul(6364136223846793005)
                    .wrapping_add(1442695040888963407);
                // Map to [-2, 2].  Keep the top 32 bits so the quotient spans
                // [0, 1]; shifting by 33 would leave only 31 bits and confine
                // every component to [-2, 0).
                ((x >> 32) as f32) / (u32::MAX as f32) * 4.0 - 2.0
            })
            .collect()
    }

    #[test]
    fn rand_vec_covers_both_signs() {
        let v = rand_vec(0, 256);
        assert_eq!(v.len(), 256);
        assert!(v.iter().all(|x| (-2.0..=2.0).contains(x)));
        assert!(v.iter().any(|&x| x > 0.0), "no positive component generated");
        assert!(v.iter().any(|&x| x < 0.0), "no negative component generated");
    }

    #[test]
    fn calibrate_centroid_mean() {
        let dim = 8;
        // Three simple vectors: centroid should be element-wise mean.
        let a = vec![1.0f32; dim];
        let b = vec![3.0f32; dim];
        let c = vec![2.0f32; dim];
        let refs: Vec<&[f32]> = vec![&a, &b, &c];
        let codec = BbqCodec::calibrate(&refs, dim, 3);
        for &x in &codec.centroid {
            assert!((x - 2.0).abs() < 1e-5, "expected centroid 2.0, got {x}");
        }
    }

    #[test]
    fn calibrate_empty_gives_zero_centroid() {
        let codec = BbqCodec::calibrate(&[], 4, 3);
        assert!(codec.centroid.iter().all(|&x| x == 0.0));
    }

    #[test]
    fn encode_packed_bits_length() {
        let dim = 128;
        let v: Vec<f32> = (0..dim).map(|i| i as f32).collect();
        let refs: Vec<&[f32]> = vec![v.as_slice()];
        let codec = BbqCodec::calibrate(&refs, dim, 3);
        let q = codec.encode(&v);
        let expected_bytes = dim.div_ceil(8);
        assert_eq!(
            q.0.packed_bits().len(),
            expected_bytes,
            "packed bits length should be dim.div_ceil(8)"
        );
    }

    #[test]
    fn encode_odd_dim_packed_bits_length() {
        // dim = 17 → ceil(17/8) = 3 bytes.
        let dim = 17;
        let v: Vec<f32> = (0..dim).map(|i| i as f32 - 8.0).collect();
        let refs: Vec<&[f32]> = vec![v.as_slice()];
        let codec = BbqCodec::calibrate(&refs, dim, 3);
        let q = codec.encode(&v);
        assert_eq!(q.0.packed_bits().len(), 3);
    }

    #[test]
    fn hamming_scalar_vs_self_zero() {
        let bits = vec![0b10101010u8, 0b11001100, 0b11110000];
        assert_eq!(hamming_distance(&bits, &bits), 0);
    }

    #[test]
    fn hamming_scalar_known_distance() {
        // 0xFF ^ 0x00 = 8 bits set.
        let a = vec![0xFFu8];
        let b = vec![0x00u8];
        assert_eq!(hamming_distance(&a, &b), 8);
    }

    #[test]
    fn hamming_multi_byte_agreement() {
        let dim = 64;
        let a: Vec<u8> = (0..dim as u8).collect();
        let b: Vec<u8> = a.iter().map(|&x| !x).collect();
        // Every byte is fully flipped → all 64 × 8 = 512 bits differ.
        assert_eq!(hamming_distance(&a, &b), 512);
    }

    #[test]
    fn distance_non_negative_finite() {
        let dim = 32;
        let vecs: Vec<Vec<f32>> = (0..8).map(|i| rand_vec(i, dim)).collect();
        let refs: Vec<&[f32]> = vecs.iter().map(|v| v.as_slice()).collect();
        let codec = BbqCodec::calibrate(&refs, dim, 3);

        for i in 0..vecs.len() {
            for j in 0..vecs.len() {
                let qi = codec.encode(&vecs[i]);
                let qj = codec.encode(&vecs[j]);
                let sym = codec.fast_symmetric_distance(&qi, &qj);
                assert!(
                    sym.is_finite() && sym >= 0.0,
                    "fast_symmetric_distance({i},{j}) = {sym}"
                );

                let query = codec.prepare_query(&vecs[i]);
                let asym = codec.exact_asymmetric_distance(&query, &qj);
                assert!(
                    asym.is_finite() && asym >= 0.0,
                    "exact_asymmetric_distance({i},{j}) = {asym}"
                );
            }
        }
    }

    #[test]
    fn oversample_default_is_three() {
        // `calibrate` has no built-in default; this only checks that the
        // customary value 3 is stored unchanged on the codec.
        let codec = BbqCodec::calibrate(&[], 4, 3);
        assert_eq!(codec.oversample, 3);
    }

    #[test]
    fn encode_quant_mode_is_bbq() {
        let dim = 16;
        let v: Vec<f32> = vec![1.0; dim];
        let refs: Vec<&[f32]> = vec![v.as_slice()];
        let codec = BbqCodec::calibrate(&refs, dim, 3);
        let q = codec.encode(&v);
        assert_eq!(q.0.header().quant_mode, QuantMode::Bbq as u16);
    }
}