Skip to main content

nodedb_vector/rerank/codecs/
rabitq.rs

// SPDX-License-Identifier: Apache-2.0

//! `RerankCodec` wrapper for RaBitQ 1-bit quantization.
//!
//! RaBitQ is a training-based codec: `train()` runs centroid calibration and
//! stores a randomised WHT rotation. Until training is complete, `encode` and
//! `prepare_query` return `RerankError::NotTrained`.
//!
//! Distance is computed by inlining the asymmetric Hamming-based formula from
//! `RaBitQCodec::exact_asymmetric_distance`:
//!
//! ```text
//! approx_l2 = q_norm² + v_norm² − 2 · q_norm · v_norm · (1 − 2·hamming/dim)
//! ```
//!
//! The prepared form is `PreparedQuery::Bytes` with the layout:
//!
//! ```text
//! [0..4]   query_norm as f32 little-endian
//! [4..]    rotated_signs bytes (length = dim.div_ceil(8))
//! ```
17
18use nodedb_codec::vector_quant::codec::VectorCodec as _;
19use nodedb_codec::vector_quant::hamming::hamming_distance;
20use nodedb_codec::vector_quant::layout::UnifiedQuantizedVectorRef;
21use nodedb_codec::vector_quant::rabitq::{RaBitQCodec, RaBitQQuery};
22
23use crate::{
24    rerank::codec::{CodecName, PreparedQuery, RerankCodec},
25    rerank::types::RerankError,
26};
27
// ── Payload helpers ───────────────────────────────────────────────────────────
29
30fn encode_payload(query: &RaBitQQuery) -> Vec<u8> {
31    // Layout: 4 bytes query_norm (f32 LE) || rotated_signs bytes
32    let mut buf = Vec::with_capacity(4 + query.rotated_signs.len());
33    buf.extend_from_slice(&query.query_norm.to_le_bytes());
34    buf.extend_from_slice(&query.rotated_signs);
35    buf
36}
37
38fn decode_payload(payload: &[u8], dim: usize) -> Result<(f32, Vec<u8>), RerankError> {
39    let sign_len = dim.div_ceil(8);
40    let expected = 4 + sign_len;
41    if payload.len() != expected {
42        return Err(RerankError::BadInput(format!(
43            "rabitq distance: payload len {} != expected {} for dim {}",
44            payload.len(),
45            expected,
46            dim
47        )));
48    }
49    let query_norm = f32::from_le_bytes(
50        payload[..4]
51            .try_into()
52            .expect("slice of 4 bytes always converts to [u8;4]"),
53    );
54    Ok((query_norm, payload[4..].to_vec()))
55}
56
// ── RaBitQRerank ──────────────────────────────────────────────────────────────
58
/// Rotation seed to use when the caller has no particular seed to request.
///
/// `RaBitQRerank::from_codec` also records this value as the wrapper's seed.
pub const DEFAULT_ROTATION_SEED: u64 = 0x00C0FFEE_00C0FFEE;
61
62/// Object-safe `RerankCodec` wrapper around `RaBitQCodec`.
63///
64/// The codec starts untrained. `encode` and `prepare_query` return
65/// `RerankError::NotTrained` until `train()` has been called with a
66/// representative sample of vectors.
67///
68/// `from_codec` accepts a pre-calibrated `RaBitQCodec` (used when restoring
69/// from a snapshot).
70pub struct RaBitQRerank {
71    codec: Option<RaBitQCodec>,
72    dim: usize,
73    rotation_seed: u64,
74}
75
76impl RaBitQRerank {
77    /// Construct an untrained wrapper.
78    ///
79    /// `encode` / `distance_prepared` return `RerankError::NotTrained` until
80    /// `train()` is called.
81    pub fn new(dim: usize, rotation_seed: u64) -> Self {
82        Self {
83            codec: None,
84            dim,
85            rotation_seed,
86        }
87    }
88
89    /// Construct from a pre-calibrated codec (used when restoring from snapshot).
90    pub fn from_codec(codec: RaBitQCodec) -> Self {
91        let dim = codec.dim;
92        Self {
93            codec: Some(codec),
94            dim,
95            rotation_seed: DEFAULT_ROTATION_SEED,
96        }
97    }
98}
99
100impl RerankCodec for RaBitQRerank {
101    /// Encode a full-precision vector to RaBitQ 1-bit bytes.
102    ///
103    /// The serialised form is the raw `UnifiedQuantizedVector` buffer
104    /// (`as_bytes()`): 32-byte `QuantHeader` followed by `dim.div_ceil(8)`
105    /// sign-packed bits.
106    fn encode(&self, v: &[f32]) -> Result<Vec<u8>, RerankError> {
107        if v.len() != self.dim {
108            return Err(RerankError::BadInput(format!(
109                "rabitq encode: vector len {} != codec dim {}",
110                v.len(),
111                self.dim
112            )));
113        }
114        let codec = self.codec.as_ref().ok_or_else(|| {
115            RerankError::NotTrained(
116                "rabitq: codec must be trained before encoding (call train() with a sample of vectors)"
117                    .to_string(),
118            )
119        })?;
120        let quantized = codec.encode(v);
121        Ok(quantized.as_ref().as_bytes().to_vec())
122    }
123
124    /// Prepare the query by computing its centroid-subtracted, rotated sign pack
125    /// and the exact query norm.
126    ///
127    /// The prepared form is `PreparedQuery::Bytes` with the layout:
128    ///   4 bytes query_norm (f32 LE) || sign bytes (dim.div_ceil(8)).
129    fn prepare_query(&self, q: &[f32]) -> Result<PreparedQuery, RerankError> {
130        if q.len() != self.dim {
131            return Err(RerankError::BadInput(format!(
132                "rabitq prepare_query: query len {} != codec dim {}",
133                q.len(),
134                self.dim
135            )));
136        }
137        let codec = self.codec.as_ref().ok_or_else(|| {
138            RerankError::NotTrained(
139                "rabitq: codec must be trained before prepare_query (call train() with a sample of vectors)"
140                    .to_string(),
141            )
142        })?;
143        let query = codec.prepare_query(q);
144        Ok(PreparedQuery::Bytes(encode_payload(&query)))
145    }
146
147    /// Compute asymmetric Hamming-based L2 distance from a prepared query to a
148    /// RaBitQ-encoded candidate.
149    ///
150    /// Inlines `RaBitQCodec::exact_asymmetric_distance` (without bias_correct)
151    /// using `UnifiedQuantizedVectorRef` to avoid a redundant allocation:
152    ///
153    ///   approx = q_norm² + v_norm² − 2·q_norm·v_norm·(1 − 2·hamming/dim)
154    ///
155    /// Expects `PreparedQuery::Bytes` produced by `prepare_query`.
156    fn distance_prepared(
157        &self,
158        prepared: &PreparedQuery,
159        encoded: &[u8],
160    ) -> Result<f32, RerankError> {
161        let payload = match prepared {
162            PreparedQuery::Bytes(b) => b.as_slice(),
163            _ => {
164                return Err(RerankError::BadInput(
165                    "rabitq distance: prepared query is not Bytes".to_string(),
166                ));
167            }
168        };
169
170        let (query_norm, rotated_signs) = decode_payload(payload, self.dim)?;
171
172        let packed_len = self.dim.div_ceil(8);
173        let uqv_ref = UnifiedQuantizedVectorRef::from_bytes(encoded, packed_len).map_err(|e| {
174            RerankError::BadInput(format!(
175                "rabitq distance: failed to parse encoded bytes: {e}"
176            ))
177        })?;
178
179        let vh = uqv_ref.header();
180        let vb = uqv_ref.packed_bits();
181        let h = hamming_distance(&rotated_signs, vb);
182        let dim = self.dim as f32;
183        let dot_estimate = 1.0 - 2.0 * h as f32 / dim;
184        let approx = query_norm * query_norm + vh.residual_norm * vh.residual_norm
185            - 2.0 * query_norm * vh.residual_norm * dot_estimate;
186        Ok(approx.max(0.0))
187    }
188
189    fn name(&self) -> CodecName {
190        CodecName::RaBitQ
191    }
192
193    fn to_bytes(&self) -> Result<Vec<u8>, RerankError> {
194        let codec = self.codec.as_ref().ok_or_else(|| {
195            RerankError::NotTrained("rabitq sidecar serialize: codec not trained".to_string())
196        })?;
197        codec
198            .to_bytes()
199            .map_err(|e| RerankError::BadInput(format!("rabitq to_bytes: {e}")))
200    }
201
202    /// Calibrate from a sample of vectors.
203    ///
204    /// Validates that:
205    /// - `samples` is non-empty.
206    /// - Every sample has length `self.dim`.
207    ///
208    /// On success, stores the calibrated codec; subsequent `encode` /
209    /// `distance_prepared` calls will succeed.
210    fn train(&mut self, samples: &[&[f32]]) -> Result<(), RerankError> {
211        if samples.is_empty() {
212            return Err(RerankError::BadInput(
213                "rabitq train: empty sample set".to_string(),
214            ));
215        }
216        for s in samples {
217            if s.len() != self.dim {
218                return Err(RerankError::BadInput(format!(
219                    "rabitq train: sample has len {} but codec dim is {}",
220                    s.len(),
221                    self.dim
222                )));
223            }
224        }
225        let codec = RaBitQCodec::calibrate(samples, self.dim, self.rotation_seed);
226        self.codec = Some(codec);
227        Ok(())
228    }
229}
230
// ── Tests ─────────────────────────────────────────────────────────────────────
232
#[cfg(test)]
mod tests {
    use super::*;

    const DIM: usize = 16;
    const N: usize = 64;

    /// Deterministic vector whose component `j` is `((i * 31 + j) % 100) / 100`.
    fn det_vec(i: usize, dim: usize) -> Vec<f32> {
        (0..dim)
            .map(|j| ((i * 31 + j) % 100) as f32 / 100.0)
            .collect()
    }

    /// Wrapper trained on `N` deterministic `DIM`-dimensional vectors.
    fn trained() -> RaBitQRerank {
        let vecs: Vec<Vec<f32>> = (0..N).map(|i| det_vec(i, DIM)).collect();
        let refs: Vec<&[f32]> = vecs.iter().map(|v| v.as_slice()).collect();
        let mut codec = RaBitQRerank::new(DIM, DEFAULT_ROTATION_SEED);
        codec.train(&refs).expect("train must succeed");
        codec
    }

    #[test]
    fn train_then_encode_roundtrip() {
        let codec = trained();
        for i in [0, 1, 7, N - 1] {
            let v = det_vec(i, DIM);
            let enc = codec.encode(&v).expect("encode");
            let prep = codec.prepare_query(&v).expect("prepare_query");
            let dist = codec.distance_prepared(&prep, &enc).expect("distance");
            assert!(dist.is_finite(), "distance must be finite, got {dist}");
            assert!(dist >= 0.0, "distance must be non-negative, got {dist}");
        }
    }

    #[test]
    fn prepared_query_payload_matches_documented_layout() {
        let codec = trained();
        match codec.prepare_query(&det_vec(3, DIM)).expect("prepare_query") {
            PreparedQuery::Bytes(b) => assert_eq!(b.len(), 4 + DIM.div_ceil(8)),
            _ => panic!("rabitq must produce PreparedQuery::Bytes"),
        }
    }

    #[test]
    fn encode_before_train_returns_not_trained() {
        let codec = RaBitQRerank::new(DIM, DEFAULT_ROTATION_SEED);
        let err = codec.encode(&det_vec(0, DIM)).unwrap_err();
        assert!(
            matches!(&err, RerankError::NotTrained(_)),
            "expected NotTrained, got: {err}"
        );
    }

    #[test]
    fn prepare_query_before_train_returns_not_trained() {
        let codec = RaBitQRerank::new(DIM, DEFAULT_ROTATION_SEED);
        match codec.prepare_query(&det_vec(0, DIM)) {
            Err(RerankError::NotTrained(_)) => {}
            Err(e) => panic!("expected NotTrained, got: {e}"),
            Ok(_) => panic!("prepare_query must fail before train()"),
        }
    }

    #[test]
    fn to_bytes_before_train_returns_not_trained() {
        let codec = RaBitQRerank::new(DIM, DEFAULT_ROTATION_SEED);
        let err = codec.to_bytes().unwrap_err();
        assert!(
            matches!(&err, RerankError::NotTrained(_)),
            "expected NotTrained, got: {err}"
        );
    }

    #[test]
    fn to_bytes_after_train_succeeds() {
        let codec = trained();
        codec.to_bytes().expect("trained codec must serialise");
    }

    #[test]
    fn train_with_empty_samples_fails() {
        let mut codec = RaBitQRerank::new(DIM, DEFAULT_ROTATION_SEED);
        let err = codec.train(&[]).unwrap_err();
        assert!(
            matches!(&err, RerankError::BadInput(_)),
            "expected BadInput, got: {err}"
        );
    }

    #[test]
    fn train_with_dim_mismatch_fails() {
        let vecs: Vec<Vec<f32>> = (0..N).map(|i| det_vec(i, DIM)).collect();
        let mut refs: Vec<&[f32]> = vecs.iter().map(|v| v.as_slice()).collect();
        let bad = det_vec(0, DIM + 4);
        refs.push(bad.as_slice());
        let mut codec = RaBitQRerank::new(DIM, DEFAULT_ROTATION_SEED);
        let err = codec.train(&refs).unwrap_err();
        assert!(
            matches!(&err, RerankError::BadInput(_)),
            "expected BadInput, got: {err}"
        );
    }

    #[test]
    fn encode_wrong_dim_fails() {
        let codec = trained();
        let err = codec.encode(&det_vec(0, DIM + 1)).unwrap_err();
        assert!(
            matches!(&err, RerankError::BadInput(_)),
            "expected BadInput, got: {err}"
        );
    }

    #[test]
    fn prepare_query_wrong_dim_fails() {
        let codec = trained();
        match codec.prepare_query(&det_vec(0, DIM + 2)) {
            Err(RerankError::BadInput(_)) => {}
            Err(e) => panic!("expected BadInput, got: {e}"),
            Ok(_) => panic!("expected an error for wrong dim"),
        }
    }

    #[test]
    fn distance_prepared_wrong_variant_fails() {
        let codec = trained();
        let enc = codec.encode(&det_vec(0, DIM)).expect("encode");
        let bad_prepared = PreparedQuery::Raw(vec![0.0f32; DIM]);
        let err = codec.distance_prepared(&bad_prepared, &enc).unwrap_err();
        assert!(
            matches!(&err, RerankError::BadInput(_)),
            "expected BadInput, got: {err}"
        );
        let msg = format!("{err}");
        assert!(
            msg.contains("Bytes"),
            "error message should mention Bytes variant, got: {msg}"
        );
    }

    #[test]
    fn distance_prepared_rejects_wrong_payload_len() {
        let codec = trained();
        let enc = codec.encode(&det_vec(0, DIM)).expect("encode");
        // Shorter than even the 4-byte norm prefix.
        let truncated = PreparedQuery::Bytes(vec![0u8; 3]);
        let err = codec.distance_prepared(&truncated, &enc).unwrap_err();
        assert!(
            matches!(&err, RerankError::BadInput(_)),
            "expected BadInput, got: {err}"
        );
    }

    #[test]
    fn name_is_expected() {
        let codec = RaBitQRerank::new(DIM, DEFAULT_ROTATION_SEED);
        assert_eq!(codec.name(), CodecName::RaBitQ);
    }
}