Skip to main content

nodedb_vector/rerank/codecs/
bbq.rs

1// SPDX-License-Identifier: Apache-2.0
2
3//! `RerankCodec` wrapper for BBQ (Better Binary Quantization).
4//!
5//! BBQ is a training-based codec: `train()` calibrates a centroid from a
6//! sample of vectors. Until training is complete, `encode` and `prepare_query`
7//! return `RerankError::NotTrained`.
8//!
9//! Distance uses the asymmetric path from the BBQ paper: the query is kept in
10//! centered FP32; the stored vector is reconstructed from its 1-bit sign pack
11//! and `residual_norm` (≈ ±norm/√dim per dimension). The L2 distance between
12//! the exact centered query and the reconstructed candidate is returned.
13//!
14//! The prepared form is `PreparedQuery::Bytes` with the layout:
15//!   [0..4]         alpha (query_norm) as f32 little-endian
16//!   [4..4+dim*4]   centered f32 values, each as f32 little-endian
17
18use nodedb_codec::vector_quant::bbq::BbqCodec;
19use nodedb_codec::vector_quant::codec::VectorCodec as _;
20use nodedb_codec::vector_quant::layout::UnifiedQuantizedVectorRef;
21
22use crate::{
23    rerank::codec::{CodecName, PreparedQuery, RerankCodec},
24    rerank::types::RerankError,
25};
26
27// ── Payload helpers ───────────────────────────────────────────────────────────
28
/// Serialise a prepared query into the byte payload consumed by `decode_payload`.
///
/// Layout: the query norm as a little-endian `f32` (4 bytes), followed by every
/// centered component as a little-endian `f32`, in order.
fn encode_payload(query_norm: f32, centered: &[f32]) -> Vec<u8> {
    // Exact final size is known up front, so allocate once.
    let mut bytes = Vec::with_capacity(4 + centered.len() * 4);
    bytes.extend_from_slice(&query_norm.to_le_bytes());
    bytes.extend(centered.iter().flat_map(|component| component.to_le_bytes()));
    bytes
}
38
39fn decode_payload(payload: &[u8], dim: usize) -> Result<(f32, Vec<f32>), RerankError> {
40    let expected = 4 + dim * 4;
41    if payload.len() != expected {
42        return Err(RerankError::BadInput(format!(
43            "bbq distance: payload len {} != expected {} for dim {}",
44            payload.len(),
45            expected,
46            dim
47        )));
48    }
49    let query_norm = f32::from_le_bytes(
50        payload[..4]
51            .try_into()
52            .expect("slice of 4 bytes always converts to [u8;4]"),
53    );
54    let centered: Vec<f32> = payload[4..]
55        .chunks_exact(4)
56        .map(|b| f32::from_le_bytes(b.try_into().expect("chunks_exact(4) always 4 bytes")))
57        .collect();
58    Ok((query_norm, centered))
59}
60
61// ── Inline dequantize (mirrors BbqCodec::dequantize, which is private) ────────
62
/// Rebuild an approximate FP32 vector from BBQ sign bits and a residual norm.
///
/// Every dimension gets the same magnitude, `residual_norm / √dim` (or `0.0`
/// when `dim` is zero); only the sign varies. Bit `i` lives in byte `i / 8`,
/// most-significant bit first: a set bit yields the positive magnitude, a clear
/// bit the negative one.
///
/// # Panics
///
/// Panics if `packed` holds fewer than `dim.div_ceil(8)` bytes.
#[inline]
fn bbq_dequantize(packed: &[u8], residual_norm: f32, dim: usize) -> Vec<f32> {
    let magnitude = match dim {
        0 => 0.0,
        _ => residual_norm / (dim as f32).sqrt(),
    };

    let mut recon = Vec::with_capacity(dim);
    for idx in 0..dim {
        // MSB-first: dimension 0 is the top bit of byte 0.
        let mask = 0x80u8 >> (idx % 8);
        let positive = (packed[idx / 8] & mask) != 0;
        recon.push(if positive { magnitude } else { -magnitude });
    }
    recon
}
82
83// ── BbqRerank ─────────────────────────────────────────────────────────────────
84
85/// Default oversample multiplier used when the caller does not specify one.
86pub const DEFAULT_OVERSAMPLE: u8 = 4;
87
88/// Object-safe `RerankCodec` wrapper around `BbqCodec`.
89///
90/// The codec starts untrained. `encode` and `prepare_query` return
91/// `RerankError::NotTrained` until `train()` has been called with a
92/// representative sample of vectors.
93///
94/// `from_codec` accepts a pre-calibrated `BbqCodec` (used when restoring
95/// from a snapshot).
96pub struct BbqRerank {
97    codec: Option<BbqCodec>,
98    dim: usize,
99    oversample: u8,
100}
101
102impl BbqRerank {
103    /// Construct an untrained wrapper.
104    ///
105    /// `encode` / `distance_prepared` return `RerankError::NotTrained` until
106    /// `train()` is called.
107    pub fn new(dim: usize, oversample: u8) -> Self {
108        Self {
109            codec: None,
110            dim,
111            oversample,
112        }
113    }
114
115    /// Construct from a pre-calibrated codec (used when restoring from snapshot).
116    pub fn from_codec(codec: BbqCodec) -> Self {
117        let dim = codec.dim;
118        Self {
119            codec: Some(codec),
120            dim,
121            oversample: DEFAULT_OVERSAMPLE,
122        }
123    }
124}
125
126impl RerankCodec for BbqRerank {
127    /// Encode a full-precision vector to BBQ 1-bit bytes.
128    ///
129    /// The serialised form is the raw `UnifiedQuantizedVector` buffer
130    /// (`as_bytes()`): 32-byte `QuantHeader` followed by `dim.div_ceil(8)`
131    /// sign-packed bits plus 14 bytes of corrective factors in the header.
132    fn encode(&self, v: &[f32]) -> Result<Vec<u8>, RerankError> {
133        if v.len() != self.dim {
134            return Err(RerankError::BadInput(format!(
135                "bbq encode: vector len {} != codec dim {}",
136                v.len(),
137                self.dim
138            )));
139        }
140        let codec = self.codec.as_ref().ok_or_else(|| {
141            RerankError::NotTrained(
142                "bbq: codec must be trained before encoding (call train() with a sample of vectors)"
143                    .to_string(),
144            )
145        })?;
146        let quantized = codec.encode(v);
147        Ok(quantized.as_ref().as_bytes().to_vec())
148    }
149
150    /// Prepare the query by centering it and serialising the exact FP32 centered
151    /// vector alongside the query norm.
152    ///
153    /// The prepared form is `PreparedQuery::Bytes` with the layout:
154    ///   4 bytes query_norm (f32 LE) || dim × 4 bytes centered f32 LE.
155    fn prepare_query(&self, q: &[f32]) -> Result<PreparedQuery, RerankError> {
156        if q.len() != self.dim {
157            return Err(RerankError::BadInput(format!(
158                "bbq prepare_query: query len {} != codec dim {}",
159                q.len(),
160                self.dim
161            )));
162        }
163        let codec = self.codec.as_ref().ok_or_else(|| {
164            RerankError::NotTrained(
165                "bbq: codec must be trained before prepare_query (call train() with a sample of vectors)"
166                    .to_string(),
167            )
168        })?;
169        let query = codec.prepare_query(q);
170        Ok(PreparedQuery::Bytes(encode_payload(
171            query.query_norm,
172            &query.centered,
173        )))
174    }
175
176    /// Compute asymmetric L2 distance from a prepared query to a BBQ-encoded
177    /// candidate.
178    ///
179    /// The query is the exact centered FP32 vector. The stored candidate is
180    /// reconstructed from its sign bits and `residual_norm` (each dim ≈
181    /// ±norm/√dim). Returns L2 distance between them.
182    ///
183    /// Expects `PreparedQuery::Bytes` produced by `prepare_query`.
184    fn distance_prepared(
185        &self,
186        prepared: &PreparedQuery,
187        encoded: &[u8],
188    ) -> Result<f32, RerankError> {
189        let payload = match prepared {
190            PreparedQuery::Bytes(b) => b.as_slice(),
191            _ => {
192                return Err(RerankError::BadInput(
193                    "bbq distance: prepared query is not Bytes".to_string(),
194                ));
195            }
196        };
197
198        let (_query_norm, centered) = decode_payload(payload, self.dim)?;
199
200        let packed_len = self.dim.div_ceil(8);
201        let uqv_ref = UnifiedQuantizedVectorRef::from_bytes(encoded, packed_len).map_err(|e| {
202            RerankError::BadInput(format!("bbq distance: failed to parse encoded bytes: {e}"))
203        })?;
204
205        let header = uqv_ref.header();
206        let recon = bbq_dequantize(uqv_ref.packed_bits(), header.residual_norm, self.dim);
207        let dist = centered
208            .iter()
209            .zip(recon.iter())
210            .map(|(&a, &b)| (a - b) * (a - b))
211            .sum::<f32>()
212            .sqrt();
213        Ok(dist)
214    }
215
216    fn name(&self) -> CodecName {
217        CodecName::Bbq
218    }
219
220    fn to_bytes(&self) -> Result<Vec<u8>, RerankError> {
221        let codec = self.codec.as_ref().ok_or_else(|| {
222            RerankError::NotTrained("bbq sidecar serialize: codec not trained".to_string())
223        })?;
224        codec
225            .to_bytes()
226            .map_err(|e| RerankError::BadInput(format!("bbq to_bytes: {e}")))
227    }
228
229    /// Calibrate from a sample of vectors.
230    ///
231    /// Validates that:
232    /// - `samples` is non-empty.
233    /// - Every sample has length `self.dim`.
234    ///
235    /// On success, stores the calibrated codec; subsequent `encode` /
236    /// `distance_prepared` calls will succeed.
237    fn train(&mut self, samples: &[&[f32]]) -> Result<(), RerankError> {
238        if samples.is_empty() {
239            return Err(RerankError::BadInput(
240                "bbq train: empty sample set".to_string(),
241            ));
242        }
243        for s in samples {
244            if s.len() != self.dim {
245                return Err(RerankError::BadInput(format!(
246                    "bbq train: sample has len {} but codec dim is {}",
247                    s.len(),
248                    self.dim
249                )));
250            }
251        }
252        let codec = BbqCodec::calibrate(samples, self.dim, self.oversample);
253        self.codec = Some(codec);
254        Ok(())
255    }
256}
257
258// ── Tests ─────────────────────────────────────────────────────────────────────
259
#[cfg(test)]
mod tests {
    //! Unit tests for `BbqRerank` and its payload helpers.
    //!
    //! Error cases assert on the `RerankError` variant rather than on the
    //! rendered message, so they do not depend on the `Display` wording.

    use super::*;

    const DIM: usize = 16;
    const N: usize = 64;

    /// Deterministic test vector: component `j` of vector `i` lies in `[0, 1)`.
    fn det_vec(i: usize, dim: usize) -> Vec<f32> {
        (0..dim)
            .map(|j| ((i * 31 + j) % 100) as f32 / 100.0)
            .collect()
    }

    /// A wrapper trained on `N` deterministic vectors of length `DIM`.
    fn trained() -> BbqRerank {
        let vecs: Vec<Vec<f32>> = (0..N).map(|i| det_vec(i, DIM)).collect();
        let refs: Vec<&[f32]> = vecs.iter().map(|v| v.as_slice()).collect();
        let mut codec = BbqRerank::new(DIM, DEFAULT_OVERSAMPLE);
        codec.train(&refs).expect("train must succeed");
        codec
    }

    #[test]
    fn train_then_encode_roundtrip() {
        let codec = trained();
        let v = det_vec(0, DIM);
        let enc = codec.encode(&v).expect("encode");
        let prep = codec.prepare_query(&v).expect("prepare_query");
        let dist = codec.distance_prepared(&prep, &enc).expect("distance");
        assert!(dist.is_finite(), "distance must be finite, got {dist}");
        assert!(dist >= 0.0, "distance must be non-negative, got {dist}");
    }

    #[test]
    fn encode_before_train_returns_not_trained() {
        let codec = BbqRerank::new(DIM, DEFAULT_OVERSAMPLE);
        let err = codec.encode(&det_vec(0, DIM)).unwrap_err();
        assert!(
            matches!(err, RerankError::NotTrained(_)),
            "expected NotTrained, got: {err}"
        );
    }

    #[test]
    fn prepare_query_before_train_returns_not_trained() {
        let codec = BbqRerank::new(DIM, DEFAULT_OVERSAMPLE);
        let err = codec
            .prepare_query(&det_vec(0, DIM))
            .err()
            .expect("untrained prepare_query must fail");
        assert!(
            matches!(err, RerankError::NotTrained(_)),
            "expected NotTrained, got: {err}"
        );
    }

    #[test]
    fn encode_wrong_dim_fails() {
        let codec = trained();
        let err = codec.encode(&det_vec(0, DIM + 1)).unwrap_err();
        assert!(
            matches!(err, RerankError::BadInput(_)),
            "expected BadInput, got: {err}"
        );
    }

    #[test]
    fn train_with_empty_samples_fails() {
        let mut codec = BbqRerank::new(DIM, DEFAULT_OVERSAMPLE);
        let err = codec.train(&[]).unwrap_err();
        assert!(
            matches!(&err, RerankError::BadInput(m) if m.contains("empty")),
            "expected BadInput mentioning the empty sample set, got: {err}"
        );
    }

    #[test]
    fn train_with_dim_mismatch_fails() {
        let vecs: Vec<Vec<f32>> = (0..N).map(|i| det_vec(i, DIM)).collect();
        let mut refs: Vec<&[f32]> = vecs.iter().map(|v| v.as_slice()).collect();
        let bad = det_vec(0, DIM + 4);
        refs.push(bad.as_slice());
        let mut codec = BbqRerank::new(DIM, DEFAULT_OVERSAMPLE);
        let err = codec.train(&refs).unwrap_err();
        assert!(
            matches!(err, RerankError::BadInput(_)),
            "expected BadInput, got: {err}"
        );

        // A rejected sample set must not leave a calibrated codec behind.
        let err = codec.encode(&det_vec(0, DIM)).unwrap_err();
        assert!(
            matches!(err, RerankError::NotTrained(_)),
            "failed train must leave the codec untrained, got: {err}"
        );
    }

    #[test]
    fn prepare_query_wrong_dim_fails() {
        let codec = trained();
        let err = codec
            .prepare_query(&det_vec(0, DIM + 2))
            .err()
            .expect("expected an error for wrong dim");
        assert!(
            matches!(err, RerankError::BadInput(_)),
            "expected BadInput, got: {err}"
        );
    }

    #[test]
    fn distance_prepared_wrong_variant_fails() {
        let codec = trained();
        let enc = codec.encode(&det_vec(0, DIM)).expect("encode");
        let bad_prepared = PreparedQuery::Raw(vec![0.0f32; DIM]);
        let err = codec.distance_prepared(&bad_prepared, &enc).unwrap_err();
        assert!(
            matches!(&err, RerankError::BadInput(m) if m.contains("not Bytes")),
            "error should be BadInput mentioning the Bytes variant, got: {err}"
        );
    }

    #[test]
    fn distance_prepared_malformed_payload_fails() {
        let codec = trained();
        let enc = codec.encode(&det_vec(0, DIM)).expect("encode");
        // Three bytes cannot even hold the leading f32 norm.
        let truncated = PreparedQuery::Bytes(vec![0u8; 3]);
        let err = codec.distance_prepared(&truncated, &enc).unwrap_err();
        assert!(
            matches!(err, RerankError::BadInput(_)),
            "expected BadInput, got: {err}"
        );
    }

    #[test]
    fn payload_roundtrip_preserves_values() {
        let centered = [0.5f32, -1.25, 3.0];
        let bytes = encode_payload(2.0, &centered);
        assert_eq!(bytes.len(), 4 + centered.len() * 4);
        let (norm, decoded) = decode_payload(&bytes, centered.len()).expect("decode");
        assert_eq!(norm, 2.0);
        assert_eq!(decoded, centered);
    }

    #[test]
    fn dequantize_reads_signs_msb_first() {
        // norm 2 over 4 dims gives magnitude 1; bits 1,0,1,0 from the top of the byte.
        let recon = bbq_dequantize(&[0b1010_0000], 2.0, 4);
        assert_eq!(recon, [1.0f32, -1.0, 1.0, -1.0]);
    }

    #[test]
    fn name_is_expected() {
        let codec = BbqRerank::new(DIM, DEFAULT_OVERSAMPLE);
        assert_eq!(codec.name(), CodecName::Bbq);
    }
}
365}