Skip to main content

rvf_quant/
rabitq.rs

1//! RaBitQ-style binary quantization (Gao & Long, SIGMOD 2024).
2//!
3//! Improves on naive sign-bit binary quantization with:
4//!
5//! 1. **Centroid centering** — vectors are encoded relative to a single
6//!    global centroid, so codes capture the residual geometry instead of
7//!    the (often uninformative) absolute signs.
8//! 2. **Deterministic pseudo-random rotation** — a seeded randomized
9//!    Hadamard transform (sign flips + fast Walsh–Hadamard, repeated for
10//!    [`DEFAULT_ROUNDS`] rounds, dims padded to a power of two) spreads
11//!    energy uniformly across dimensions so 1 bit per dimension carries
12//!    near-optimal information. The rotation is orthonormal and fully
13//!    reproducible from the stored seed.
14//! 3. **Per-vector correction scalars** — the residual norm and the
15//!    dot-correction `<o_unit, s_unit>` from the RaBitQ estimator, which
16//!    turn Hamming-style codes into an unbiased inner-product estimator.
17//! 4. **Asymmetric distance estimation** — full-precision queries are
18//!    compared against binary codes via the correction scalars, giving a
19//!    far better candidate ranking than symmetric Hamming distance.
20//!
21//! Intended use is a two-stage search: scan codes with the estimator to
22//! collect `oversample * k` candidates, then rescore those candidates with
23//! exact f32 distances.
24
25use alloc::vec;
26use alloc::vec::Vec;
27
28use crate::tier::TemperatureTier;
29use crate::traits::Quantizer;
30
/// Default number of randomized-Hadamard rounds. Two rounds already mix
/// well; three gives near-Gaussian rotations at negligible cost.
///
/// [`RabitqQuantizer::train`] always uses this value; a different round
/// count can be supplied through [`RabitqQuantizer::with_centroid`].
pub const DEFAULT_ROUNDS: u8 = 3;
34
/// Number of bytes of per-vector correction scalars (norm + dot_corr):
/// two `f32` values of 4 bytes each, written little-endian after the sign
/// bits by [`RabitqQuantizer::code_to_bytes`].
pub const CORRECTION_BYTES: usize = 8;
37
/// RaBitQ quantizer parameters (shared across all encoded vectors).
///
/// Instances are normally built with [`RabitqQuantizer::train`] or
/// [`RabitqQuantizer::with_centroid`], which keep the fields mutually
/// consistent (`padded_dim` derived from `dim`, `centroid.len() == dim`,
/// `rounds >= 1`). The fields are public, so code that mutates them
/// directly is responsible for preserving those relationships.
#[derive(Clone, Debug)]
pub struct RabitqQuantizer {
    /// Original vector dimensionality.
    pub dim: usize,
    /// Dimensionality after padding to the next power of two; this is the
    /// length of every rotated vector and the number of sign bits per code.
    pub padded_dim: usize,
    /// Seed for the deterministic pseudo-random rotation.
    pub seed: u64,
    /// Number of randomized-Hadamard rounds (`with_centroid` raises 0 to 1).
    pub rounds: u8,
    /// Global centroid (length `dim`) subtracted before rotation.
    pub centroid: Vec<f32>,
}
52
/// A single encoded vector: 1-bit sign code plus correction scalars.
///
/// Produced by [`RabitqQuantizer::encode_code`] and consumed by
/// [`RabitqQuantizer::estimate_l2_sq`] and the `Quantizer::decode` impl.
#[derive(Clone, Debug, PartialEq)]
pub struct RabitqCode {
    /// Sign bits of the rotated centered vector (`padded_dim` bits,
    /// dimension `d` maps to bit `d % 8` of byte `d / 8`). A set bit means
    /// the rotated coordinate was `>= 0.0`.
    pub bits: Vec<u8>,
    /// Residual norm `||v - centroid||`.
    pub norm: f32,
    /// Dot correction `<o_unit, s_unit>` where `o_unit` is the unit
    /// rotated residual and `s_unit = signs / sqrt(padded_dim)`.
    /// `encode_code` never stores a value below `f32::EPSILON` and uses
    /// `1.0` as a placeholder when the residual norm is (near) zero.
    pub dot_corr: f32,
}
65
66impl RabitqCode {
67    /// Total stored bytes for this code (bits + correction scalars).
68    #[inline]
69    pub fn stored_bytes(&self) -> usize {
70        self.bits.len() + CORRECTION_BYTES
71    }
72}
73
/// A query prepared for asymmetric distance estimation (computed once
/// per query, reused across all codes).
///
/// Built by [`RabitqQuantizer::prepare_query`] and passed to
/// [`RabitqQuantizer::estimate_l2_sq`].
#[derive(Clone, Debug)]
pub struct RabitqQuery {
    /// Rotated centered query, length `padded_dim`.
    pub rotated: Vec<f32>,
    /// Squared residual norm `||q - centroid||^2` (summed over the rotated
    /// coordinates).
    pub norm_sq: f32,
}
83
/// SplitMix64 mixer (same constants as the runtime's deterministic
/// leveling): adds the golden-ratio increment to `x`, then applies two
/// xor-shift/multiply steps and a final xor-shift. The rotation uses it
/// to derive reproducible sign-flip words.
#[inline]
fn splitmix64(x: u64) -> u64 {
    const GAMMA: u64 = 0x9E37_79B9_7F4A_7C15;
    const MIX_A: u64 = 0xBF58_476D_1CE4_E5B9;
    const MIX_B: u64 = 0x94D0_49BB_1331_11EB;

    let s = x.wrapping_add(GAMMA);
    let s = (s ^ (s >> 30)).wrapping_mul(MIX_A);
    let s = (s ^ (s >> 27)).wrapping_mul(MIX_B);
    s ^ (s >> 31)
}
93
/// Round `n` up to a power of two; both `0` and `1` map to `1`.
#[inline]
fn next_pow2(n: usize) -> usize {
    if n <= 1 {
        1
    } else {
        n.next_power_of_two()
    }
}
99
/// Unnormalized fast Walsh–Hadamard transform, performed in place.
///
/// Runs butterfly passes with half-widths 1, 2, 4, ... below `v.len()`;
/// each butterfly replaces a pair `(a, b)` with `(a + b, a - b)`.
/// `v.len()` must be a power of two (empty and single-element slices are
/// left untouched).
fn fwht(v: &mut [f32]) {
    let len = v.len();
    let mut half = 1;
    while half < len {
        let span = half * 2;
        for base in (0..len).step_by(span) {
            for lo in base..base + half {
                let hi = lo + half;
                let (a, b) = (v[lo], v[hi]);
                v[lo] = a + b;
                v[hi] = a - b;
            }
        }
        half = span;
    }
}
119
120impl RabitqQuantizer {
121    /// Train a RaBitQ quantizer over the given vectors: the centroid is
122    /// the per-dimension mean. The rotation is derived from `seed`.
123    ///
124    /// # Panics
125    ///
126    /// Panics if `vectors` is empty or dimensionality is inconsistent.
127    pub fn train(vectors: &[&[f32]], seed: u64) -> Self {
128        assert!(!vectors.is_empty(), "need at least one training vector");
129        let dim = vectors[0].len();
130        assert!(dim > 0, "vector dimensionality must be > 0");
131
132        let mut centroid = vec![0.0f64; dim];
133        for v in vectors {
134            assert_eq!(v.len(), dim, "dimension mismatch in training data");
135            for (acc, &x) in centroid.iter_mut().zip(v.iter()) {
136                *acc += x as f64;
137            }
138        }
139        let inv_n = 1.0 / vectors.len() as f64;
140        let centroid: Vec<f32> = centroid.iter().map(|&s| (s * inv_n) as f32).collect();
141
142        Self::with_centroid(dim, centroid, seed, DEFAULT_ROUNDS)
143    }
144
145    /// Construct from explicit parameters (used by the QUANT_SEG codec).
146    pub fn with_centroid(dim: usize, centroid: Vec<f32>, seed: u64, rounds: u8) -> Self {
147        assert_eq!(centroid.len(), dim, "centroid length must equal dim");
148        Self {
149            dim,
150            padded_dim: next_pow2(dim),
151            seed,
152            rounds: rounds.max(1),
153            centroid,
154        }
155    }
156
157    /// Deterministic sign flip for dimension `i` of rotation round `round`:
158    /// returns `true` for negate.
159    #[inline]
160    fn sign_flip(&self, round: u8, i: usize) -> bool {
161        // One SplitMix64 word covers 64 dimensions; counter-based so any
162        // (round, word) is independently addressable and reproducible.
163        let word = splitmix64(
164            self.seed
165                ^ (round as u64).wrapping_mul(0xA076_1D64_78BD_642F)
166                ^ ((i as u64) / 64).wrapping_mul(0xE703_7ED1_A0B4_28DB),
167        );
168        (word >> (i % 64)) & 1 == 1
169    }
170
171    /// Apply the seeded orthonormal rotation: pad to `padded_dim`, then
172    /// `rounds` of (sign flips, normalized Walsh–Hadamard).
173    pub fn rotate(&self, v: &[f32]) -> Vec<f32> {
174        debug_assert!(v.len() <= self.padded_dim);
175        let mut buf = vec![0.0f32; self.padded_dim];
176        buf[..v.len()].copy_from_slice(v);
177        let scale = 1.0 / (self.padded_dim as f32).sqrt();
178        for round in 0..self.rounds {
179            for (i, x) in buf.iter_mut().enumerate() {
180                if self.sign_flip(round, i) {
181                    *x = -*x;
182                }
183            }
184            fwht(&mut buf);
185            for x in buf.iter_mut() {
186                *x *= scale;
187            }
188        }
189        buf
190    }
191
192    /// Inverse of [`Self::rotate`] (rounds in reverse: Hadamard, then
193    /// sign flips — both are their own inverses).
194    pub fn rotate_inverse(&self, v: &[f32]) -> Vec<f32> {
195        debug_assert_eq!(v.len(), self.padded_dim);
196        let mut buf = v.to_vec();
197        let scale = 1.0 / (self.padded_dim as f32).sqrt();
198        for round in (0..self.rounds).rev() {
199            fwht(&mut buf);
200            for x in buf.iter_mut() {
201                *x *= scale;
202            }
203            for (i, x) in buf.iter_mut().enumerate() {
204                if self.sign_flip(round, i) {
205                    *x = -*x;
206                }
207            }
208        }
209        buf
210    }
211
212    /// Encode a vector: center, rotate, take sign bits, and compute the
213    /// RaBitQ correction scalars.
214    pub fn encode_code(&self, vector: &[f32]) -> RabitqCode {
215        assert_eq!(vector.len(), self.dim, "vector dimension mismatch");
216        let centered: Vec<f32> = vector
217            .iter()
218            .zip(self.centroid.iter())
219            .map(|(&x, &c)| x - c)
220            .collect();
221        let rotated = self.rotate(&centered);
222
223        let mut norm_sq = 0.0f32;
224        let mut abs_sum = 0.0f32;
225        let mut bits = vec![0u8; self.padded_dim.div_ceil(8)];
226        for (d, &x) in rotated.iter().enumerate() {
227            norm_sq += x * x;
228            abs_sum += x.abs();
229            if x >= 0.0 {
230                bits[d / 8] |= 1 << (d % 8);
231            }
232        }
233        let norm = norm_sq.sqrt();
234        // <o_unit, s_unit> = sum |r_i| / (||r|| * sqrt(D)). For a zero
235        // residual (vector == centroid) the estimator multiplies by
236        // norm = 0 anyway, so any positive placeholder is fine.
237        let dot_corr = if norm > f32::EPSILON {
238            (abs_sum / (norm * (self.padded_dim as f32).sqrt())).max(f32::EPSILON)
239        } else {
240            1.0
241        };
242        RabitqCode {
243            bits,
244            norm,
245            dot_corr,
246        }
247    }
248
249    /// Prepare a query for repeated asymmetric distance estimation.
250    pub fn prepare_query(&self, query: &[f32]) -> RabitqQuery {
251        assert_eq!(query.len(), self.dim, "query dimension mismatch");
252        let centered: Vec<f32> = query
253            .iter()
254            .zip(self.centroid.iter())
255            .map(|(&x, &c)| x - c)
256            .collect();
257        let rotated = self.rotate(&centered);
258        let norm_sq = rotated.iter().map(|&x| x * x).sum();
259        RabitqQuery { rotated, norm_sq }
260    }
261
262    /// Estimate the squared L2 distance between the (full-precision)
263    /// prepared query and an encoded vector.
264    ///
265    /// Uses the RaBitQ estimator: `<o_unit, x> ~ <s_unit, x> / <s_unit,
266    /// o_unit>`, so `<v-c, q-c> ~ norm * (<s, rq> / sqrt(D)) / dot_corr`,
267    /// and `||v-q||^2 = ||v-c||^2 + ||q-c||^2 - 2<v-c, q-c>` (rotation
268    /// preserves norms and inner products).
269    pub fn estimate_l2_sq(&self, query: &RabitqQuery, code: &RabitqCode) -> f32 {
270        let mut signed_sum = 0.0f32;
271        for (d, &x) in query.rotated.iter().enumerate() {
272            if (code.bits[d / 8] >> (d % 8)) & 1 == 1 {
273                signed_sum += x;
274            } else {
275                signed_sum -= x;
276            }
277        }
278        let est_ip = code.norm * (signed_sum / (self.padded_dim as f32).sqrt()) / code.dot_corr;
279        code.norm * code.norm + query.norm_sq - 2.0 * est_ip
280    }
281
282    /// Bytes stored per encoded vector (sign bits + correction scalars).
283    #[inline]
284    pub fn stored_bytes_per_vector(&self) -> usize {
285        self.padded_dim.div_ceil(8) + CORRECTION_BYTES
286    }
287
288    /// Compression ratio versus raw f32 storage of the original vector.
289    #[inline]
290    pub fn compression_ratio(&self) -> f32 {
291        (self.dim * 4) as f32 / self.stored_bytes_per_vector() as f32
292    }
293
294    /// Serialize a code to bytes: `[bits][norm: f32 LE][dot_corr: f32 LE]`.
295    pub fn code_to_bytes(&self, code: &RabitqCode) -> Vec<u8> {
296        let mut out = Vec::with_capacity(code.stored_bytes());
297        out.extend_from_slice(&code.bits);
298        out.extend_from_slice(&code.norm.to_le_bytes());
299        out.extend_from_slice(&code.dot_corr.to_le_bytes());
300        out
301    }
302
303    /// Deserialize a code produced by [`Self::code_to_bytes`].
304    /// Returns `None` if `data` is too short (panic-free on bad input).
305    pub fn code_from_bytes(&self, data: &[u8]) -> Option<RabitqCode> {
306        let nbits = self.padded_dim.div_ceil(8);
307        if data.len() < nbits + CORRECTION_BYTES {
308            return None;
309        }
310        let bits = data[..nbits].to_vec();
311        let norm = f32::from_le_bytes(data[nbits..nbits + 4].try_into().ok()?);
312        let dot_corr = f32::from_le_bytes(data[nbits + 4..nbits + 8].try_into().ok()?);
313        Some(RabitqCode {
314            bits,
315            norm,
316            dot_corr,
317        })
318    }
319}
320
321impl Quantizer for RabitqQuantizer {
322    fn encode(&self, vector: &[f32]) -> Vec<u8> {
323        self.code_to_bytes(&self.encode_code(vector))
324    }
325
326    fn decode(&self, codes: &[u8]) -> Vec<f32> {
327        let code = match self.code_from_bytes(codes) {
328            Some(c) => c,
329            None => return vec![0.0; self.dim],
330        };
331        // Best rank-1 reconstruction: project onto the code direction,
332        // r_hat = norm * dot_corr * s / sqrt(D), then invert the rotation
333        // and re-add the centroid.
334        let scale = code.norm * code.dot_corr / (self.padded_dim as f32).sqrt();
335        let mut rotated = Vec::with_capacity(self.padded_dim);
336        for d in 0..self.padded_dim {
337            let sign = if (code.bits[d / 8] >> (d % 8)) & 1 == 1 {
338                1.0
339            } else {
340                -1.0
341            };
342            rotated.push(sign * scale);
343        }
344        let residual = self.rotate_inverse(&rotated);
345        residual
346            .iter()
347            .take(self.dim)
348            .zip(self.centroid.iter())
349            .map(|(&r, &c)| r + c)
350            .collect()
351    }
352
353    fn tier(&self) -> TemperatureTier {
354        TemperatureTier::Cold
355    }
356
357    fn dim(&self) -> usize {
358        self.dim
359    }
360}
361
362#[cfg(test)]
363mod tests {
364    use super::*;
365
366    fn lcg_vector(dim: usize, seed: u64) -> Vec<f32> {
367        let mut x = seed.wrapping_mul(0x9E37_79B9_7F4A_7C15).wrapping_add(1);
368        (0..dim)
369            .map(|_| {
370                x = x
371                    .wrapping_mul(6364136223846793005)
372                    .wrapping_add(1442695040888963407);
373                ((x >> 33) as f32) / (u32::MAX as f32) - 0.5
374            })
375            .collect()
376    }
377
378    fn make_quantizer(dim: usize, n: usize) -> (RabitqQuantizer, Vec<Vec<f32>>) {
379        let data: Vec<Vec<f32>> = (0..n).map(|i| lcg_vector(dim, i as u64)).collect();
380        let refs: Vec<&[f32]> = data.iter().map(|v| v.as_slice()).collect();
381        (RabitqQuantizer::train(&refs, 0xDEAD_BEEF), data)
382    }
383
384    fn l2_sq(a: &[f32], b: &[f32]) -> f32 {
385        a.iter().zip(b.iter()).map(|(x, y)| (x - y) * (x - y)).sum()
386    }
387
388    #[test]
389    fn rotation_is_orthonormal_and_deterministic() {
390        let (rq, data) = make_quantizer(100, 8); // non-pow2 dim -> padded 128
391        assert_eq!(rq.padded_dim, 128);
392        for v in &data {
393            let r1 = rq.rotate(v);
394            let r2 = rq.rotate(v);
395            assert_eq!(r1, r2, "rotation must be deterministic");
396
397            let norm_in: f32 = v.iter().map(|x| x * x).sum();
398            let norm_out: f32 = r1.iter().map(|x| x * x).sum();
399            assert!(
400                (norm_in - norm_out).abs() < 1e-3 * norm_in.max(1.0),
401                "rotation must preserve norms: {norm_in} vs {norm_out}"
402            );
403
404            // Inverse round-trips back to the padded input.
405            let back = rq.rotate_inverse(&r1);
406            for (d, (&orig, &rec)) in v.iter().zip(back.iter()).enumerate() {
407                assert!(
408                    (orig - rec).abs() < 1e-4,
409                    "dim {d}: {orig} != {rec} after inverse rotation"
410                );
411            }
412            for &pad in &back[v.len()..] {
413                assert!(pad.abs() < 1e-4, "padding must invert to ~0");
414            }
415        }
416    }
417
418    #[test]
419    fn rotation_preserves_inner_products() {
420        let (rq, data) = make_quantizer(64, 4);
421        let a = &data[0];
422        let b = &data[1];
423        let ip: f32 = a.iter().zip(b.iter()).map(|(x, y)| x * y).sum();
424        let ra = rq.rotate(a);
425        let rb = rq.rotate(b);
426        let rip: f32 = ra.iter().zip(rb.iter()).map(|(x, y)| x * y).sum();
427        assert!((ip - rip).abs() < 1e-3, "ip {ip} vs rotated ip {rip}");
428    }
429
430    #[test]
431    fn different_seeds_give_different_rotations() {
432        let v = lcg_vector(32, 7);
433        let a = RabitqQuantizer::with_centroid(32, vec![0.0; 32], 1, DEFAULT_ROUNDS);
434        let b = RabitqQuantizer::with_centroid(32, vec![0.0; 32], 2, DEFAULT_ROUNDS);
435        assert_ne!(a.rotate(&v), b.rotate(&v));
436    }
437
438    #[test]
439    fn code_round_trip_bytes() {
440        let (rq, data) = make_quantizer(48, 16);
441        for v in &data {
442            let code = rq.encode_code(v);
443            let bytes = rq.code_to_bytes(&code);
444            assert_eq!(bytes.len(), rq.stored_bytes_per_vector());
445            let back = rq.code_from_bytes(&bytes).expect("decode");
446            assert_eq!(back, code);
447        }
448        // Truncated input must be rejected, not panic.
449        let code = rq.encode_code(&data[0]);
450        let bytes = rq.code_to_bytes(&code);
451        assert!(rq.code_from_bytes(&bytes[..bytes.len() - 1]).is_none());
452        assert!(rq.code_from_bytes(&[]).is_none());
453    }
454
455    #[test]
456    fn decode_reconstruction_beats_naive_sign_bits() {
457        // The corrected reconstruction must be closer to the original than
458        // the naive +-1 sign decode is (sanity that corrections help).
459        let (rq, data) = make_quantizer(128, 64);
460        let mut rabitq_err = 0.0f64;
461        let mut naive_err = 0.0f64;
462        for v in &data {
463            let rec = rq.decode(&rq.encode(v));
464            rabitq_err += l2_sq(v, &rec) as f64;
465
466            let bits = crate::binary::encode_binary(v);
467            let nrec = crate::binary::decode_binary(&bits, v.len());
468            naive_err += l2_sq(v, &nrec) as f64;
469        }
470        assert!(
471            rabitq_err < naive_err,
472            "RaBitQ reconstruction error {rabitq_err} must beat naive {naive_err}"
473        );
474    }
475
    /// Quality gate for the asymmetric estimator: over 20 queries x 200
    /// encoded vectors at 128 dimensions, estimated squared L2 distances
    /// must correlate strongly with the exact ones and stay close to them
    /// on average.
    #[test]
    fn estimator_correlates_with_true_distances() {
        // Pearson correlation between estimated and true squared L2 over
        // many (query, vector) pairs must be strong.
        let dim = 128;
        let (rq, data) = make_quantizer(dim, 200);
        let codes: Vec<RabitqCode> = data.iter().map(|v| rq.encode_code(v)).collect();

        // Paired samples: est[i] is the estimate for the pair whose exact
        // distance is truth[i].
        let mut est = Vec::new();
        let mut truth = Vec::new();
        for qi in 0..20u64 {
            // Query seeds start at 5_000, well clear of the training seeds
            // 0..200 used by `make_quantizer`.
            let q = lcg_vector(dim, 5_000 + qi);
            let prepared = rq.prepare_query(&q);
            for (v, code) in data.iter().zip(codes.iter()) {
                est.push(rq.estimate_l2_sq(&prepared, code) as f64);
                truth.push(l2_sq(&q, v) as f64);
            }
        }

        // Sample means, then covariance and variances (all left
        // unnormalized: the 1/n factors cancel in the correlation).
        let n = est.len() as f64;
        let me = est.iter().sum::<f64>() / n;
        let mt = truth.iter().sum::<f64>() / n;
        let mut cov = 0.0;
        let mut ve = 0.0;
        let mut vt = 0.0;
        for (&e, &t) in est.iter().zip(truth.iter()) {
            cov += (e - me) * (t - mt);
            ve += (e - me) * (e - me);
            vt += (t - mt) * (t - mt);
        }
        let corr = cov / (ve.sqrt() * vt.sqrt());
        // Diagnostics are printed only when the `std` feature is enabled.
        #[cfg(feature = "std")]
        std::eprintln!("estimator/true distance correlation (128d): {corr:.4}");
        assert!(
            corr > 0.8,
            "estimator correlation {corr:.3} too weak (expected > 0.8)"
        );

        // The estimator must also be roughly unbiased: mean relative error
        // of estimated vs true distance stays small.
        // The denominator is floored at 1e-9 to avoid dividing by zero.
        let mean_rel: f64 = est
            .iter()
            .zip(truth.iter())
            .map(|(&e, &t)| ((e - t) / t.max(1e-9)).abs())
            .sum::<f64>()
            / n;
        #[cfg(feature = "std")]
        std::eprintln!("estimator mean relative distance error (128d): {mean_rel:.4}");
        assert!(
            mean_rel < 0.25,
            "mean relative error {mean_rel:.3} too large"
        );
    }
529
530    #[test]
531    fn compression_ratio_targets() {
532        // Code-only payload is exactly 32x; with the 8 correction bytes the
533        // total is ~21x at 128 dims and approaches 32x as dims grow.
534        let rq128 = RabitqQuantizer::with_centroid(128, vec![0.0; 128], 1, DEFAULT_ROUNDS);
535        assert_eq!(rq128.padded_dim, 128);
536        assert_eq!((rq128.dim * 4) / (rq128.padded_dim / 8), 32);
537        assert!(rq128.compression_ratio() >= 20.0);
538
539        let rq1024 = RabitqQuantizer::with_centroid(1024, vec![0.0; 1024], 1, DEFAULT_ROUNDS);
540        assert!(rq1024.compression_ratio() >= 30.0);
541    }
542
543    #[test]
544    fn zero_residual_vector_is_safe() {
545        let (rq, _) = make_quantizer(16, 4);
546        let code = rq.encode_code(&rq.centroid.clone());
547        assert!(code.norm <= 1e-6);
548        let q = rq.prepare_query(&lcg_vector(16, 99));
549        let est = rq.estimate_l2_sq(&q, &code);
550        assert!(est.is_finite());
551    }
552}