Skip to main content

nodedb_codec/vector_quant/
rabitq.rs

1// SPDX-License-Identifier: Apache-2.0
2
3//! RaBitQ — 1-bit quantization with O(1/√D) MSE error bound (SIGMOD 2024).
4//!
5//! Algorithm outline
6//! -----------------
7//! 1. **Calibration**: compute centroid `c` over training vectors.
8//! 2. **Rotation**: apply a randomised Walsh-Hadamard transform (WHT) with a
9//!    seed-derived signed-diagonal matrix `D` to the residual `v - c`.
10//!    - WHT requires a power-of-2 length; dimensions that are not pow2 are
11//!      zero-padded before the transform then truncated after.
12//!    - `D` is a vector of `±1` scalars derived deterministically from
13//!      `rotation_seed` using an xorshift64 generator.
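//! 3. **Encoding**: `code = sign(R·(v-c))`, where `R` is the rotation from
//!    step 2 (signed diagonal first, then WHT), packed at 1 bit per dimension.
//! 4. **Distance estimation** (squared L2):
//!    `‖q-v‖² ≈ ‖v-c‖² + ‖q-c‖² − 2‖v-c‖‖q-c‖·(1 − 2·hamming/dim)`
//!    with `O(1/√dim)` MSE — see SIGMOD 2024 Theorem 4. Here `dim` is the
//!    vector dimensionality (the `D` of the `O(1/√D)` bound), not the
//!    signed-diagonal matrix `D` of step 2.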
18//!
19//! IP / raw inner-product bias
20//! ---------------------------
21//! The angular estimator above is MSE-optimal for L2 and cosine. For raw IP
22//! it carries a systematic bias. When `bias_correct = true` the codec
23//! subtracts `dot_quantized` (stored in the [`QuantHeader`]) from the
24//! asymmetric distance to partially compensate, following the TurboQuant-style
25//! QJL residual correction. This is off by default; consumers that use raw IP
26//! (RAG, attention scores) should enable it.
27//!
28//! [`rand_xoshiro`] is **not** a dependency of `nodedb-codec`. The signed-
29//! diagonal flip vector is derived from `rotation_seed` via an inline
30//! xorshift64 generator — no external crate required.
31
32use crate::vector_quant::codec::VectorCodec;
33use crate::vector_quant::hamming::hamming_distance;
34use crate::vector_quant::layout::{QuantHeader, QuantMode, UnifiedQuantizedVector};
35
36// ── Xorshift64 (inline PRNG) ────────────────────────────────────────────────
37
/// One step of the xorshift64 generator (shift triple 13 / 7 / 17).
///
/// Advances `state` in place and returns the value that was just written
/// back. The sequence is fully determined by the starting state, which is
/// what makes the signed-diagonal flips reproducible from a seed. A zero
/// state is a fixed point: it stays zero and every call returns `0`.
#[inline]
fn xorshift64(state: &mut u64) -> u64 {
    let s0 = *state;
    let s1 = s0 ^ (s0 << 13);
    let s2 = s1 ^ (s1 >> 7);
    let s3 = s2 ^ (s2 << 17);
    *state = s3;
    s3
}
48
49// ── WHT helpers ──────────────────────────────────────────────────────────────
50
/// Smallest power of two that is `>= n`.
///
/// A value that is already a power of two is returned unchanged, and `0`
/// maps to `1`.
#[inline]
fn next_pow2(n: usize) -> usize {
    // `next_power_of_two` already hands back its argument when that argument
    // is a power of two, so no separate fast path is needed.
    n.next_power_of_two()
}
60
/// Unnormalised Walsh-Hadamard transform, computed in place.
///
/// `buf.len()` must be a power of two (checked with `debug_assert!`). The
/// butterfly runs `log2(N)` passes over the buffer, O(N log N) overall.
/// No `1/√N` factor is applied, so applying the transform twice yields the
/// input scaled by `N`; the sign code derived from the output is unaffected
/// by that scale.
fn wht_inplace(buf: &mut [f32]) {
    let len = buf.len();
    debug_assert!(len.is_power_of_two());
    let mut half = 1usize;
    while half < len {
        let span = half * 2;
        // Each `span`-wide group is combined as (lo, hi) -> (lo + hi, lo - hi).
        for base in (0..len).step_by(span) {
            for lo in base..base + half {
                let hi = lo + half;
                let (sum, diff) = (buf[lo] + buf[hi], buf[lo] - buf[hi]);
                buf[lo] = sum;
                buf[hi] = diff;
            }
        }
        half = span;
    }
}
82
83// ── Sign-pack / unpack helpers ───────────────────────────────────────────────
84
/// Collapse the first `dim` entries of `rotated` into a sign bitmap.
///
/// Bit `i % 8` of byte `i / 8` (LSB-first) is set when `rotated[i] < 0.0`.
/// Every other value — positives, both zeros and NaN — leaves its bit clear.
/// The result always has `ceil(dim / 8)` bytes; if `rotated` holds fewer
/// than `dim` values the positions without a value stay `0`.
fn sign_pack(rotated: &[f32], dim: usize) -> Vec<u8> {
    let mut bytes = vec![0u8; dim.div_ceil(8)];
    let used = &rotated[..dim.min(rotated.len())];
    // One output byte per group of up to eight input lanes.
    for (byte, lanes) in bytes.iter_mut().zip(used.chunks(8)) {
        *byte = lanes
            .iter()
            .enumerate()
            .filter(|&(_, &x)| x < 0.0)
            .fold(0u8, |acc, (bit, _)| acc | (1u8 << bit));
    }
    bytes
}
97
/// Expand a sign bitmap into `dim` values of `±1.0`.
///
/// Bits are read LSB-first, mirroring [`sign_pack`]: a set bit becomes
/// `-1.0` and a clear bit becomes `1.0`. Used when computing
/// `dot_quantized`.
///
/// # Panics
///
/// Panics if `packed` holds fewer than `ceil(dim / 8)` bytes.
fn sign_unpack(packed: &[u8], dim: usize) -> Vec<f32> {
    let mut signs = Vec::with_capacity(dim);
    for idx in 0..dim {
        let bit_set = (packed[idx >> 3] >> (idx & 7)) & 1 == 1;
        signs.push(if bit_set { -1.0f32 } else { 1.0f32 });
    }
    signs
}
110
111// ── RaBitQCodec ──────────────────────────────────────────────────────────────
112
/// RaBitQ codec: 1-bit quantization with O(1/√D) MSE error bound.
///
/// Holds everything needed to encode vectors and prepare queries: the
/// dimensionality, the calibration centroid and the rotation seed. See the
/// module-level documentation for algorithm details.
#[derive(Debug, Clone)]
pub struct RaBitQCodec {
    /// Number of components per vector; also the number of code bits.
    pub dim: usize,
    /// Mean of training vectors; subtracted from each vector before rotation.
    centroid: Vec<f32>,
    /// Seed for the signed-diagonal flip vector used by the WHT rotation.
    rotation_seed: u64,
    /// If `true`, subtract `dot_quantized` from the asymmetric distance
    /// estimate as a TurboQuant-style IP-bias correction term.
    pub bias_correct: bool,
}
126
127impl RaBitQCodec {
128    /// Calibrate a new codec from a set of training vectors.
129    ///
130    /// - Computes the centroid as the coordinate-wise mean of `vectors`.
131    /// - Stores `rotation_seed` for reproducible signed-diagonal generation.
132    ///
133    /// # Errors (none — returns Self directly)
134    ///
135    /// Returns a zero-centroid codec if `vectors` is empty.
136    pub fn calibrate(vectors: &[&[f32]], dim: usize, rotation_seed: u64) -> Self {
137        let centroid = if vectors.is_empty() {
138            vec![0.0f32; dim]
139        } else {
140            let n = vectors.len() as f32;
141            let mut c = vec![0.0f32; dim];
142            for v in vectors {
143                for (ci, &vi) in c.iter_mut().zip(v.iter()) {
144                    *ci += vi;
145                }
146            }
147            c.iter_mut().for_each(|x| *x /= n);
148            c
149        };
150        Self {
151            dim,
152            centroid,
153            rotation_seed,
154            bias_correct: false,
155        }
156    }
157
158    /// Apply the randomised WHT rotation to a residual vector.
159    ///
160    /// Steps:
161    /// 1. Apply signed-diagonal `D` (deterministic from `rotation_seed`).
162    /// 2. Zero-pad to the next power of two if `dim` is not pow2.
163    /// 3. WHT in-place.
164    /// 4. Truncate back to `dim`.
165    pub fn apply_rotation(&self, v: &[f32]) -> Vec<f32> {
166        let dim = self.dim;
167        let pow2 = next_pow2(dim);
168
169        // Signed-diagonal multiply: generate ±1 per dim from seed.
170        let mut seed = self.rotation_seed;
171        let mut buf = vec![0.0f32; pow2];
172        for (i, &vi) in v.iter().take(dim).enumerate() {
173            let flip = if xorshift64(&mut seed) & 1 == 0 {
174                1.0f32
175            } else {
176                -1.0f32
177            };
178            buf[i] = vi * flip;
179        }
180        // Trailing elements remain zero (pad).
181
182        wht_inplace(&mut buf);
183        buf.truncate(dim);
184        buf
185    }
186
187    /// Encode a single vector into a [`UnifiedQuantizedVector`] with
188    /// `QuantMode::RaBitQ`.
189    ///
190    /// The header fields populated are:
191    /// - `global_scale` = `residual_norm` (‖v−c‖); both store the same value
192    ///   so that consumers that use either field without context still have the
193    ///   magnitude available.
194    /// - `residual_norm` = ‖v−c‖.
195    /// - `dot_quantized` = ⟨residual, dequantised_sign_vector⟩ / ‖v−c‖;
196    ///   used for IP-bias correction when `bias_correct = true`.
197    fn encode_inner(&self, v: &[f32]) -> UnifiedQuantizedVector {
198        let dim = self.dim;
199
200        // Step 1: residual = v - centroid
201        let residual: Vec<f32> = v
202            .iter()
203            .zip(self.centroid.iter())
204            .map(|(&vi, &ci)| vi - ci)
205            .collect();
206
207        // Step 2: ‖residual‖
208        let residual_norm = residual.iter().map(|x| x * x).sum::<f32>().sqrt();
209
210        // Step 3: rotate
211        let rotated = self.apply_rotation(&residual);
212
213        // Step 4: sign-pack → 1-bit code
214        let packed = sign_pack(&rotated, dim);
215
216        // Step 5: compute dot_quantized = ⟨residual, R⁻¹·sign(rotated)⟩ / ‖residual‖
217        // Inverse rotation of the sign vector, then dot with original residual.
218        let signs_fp = sign_unpack(&packed, dim);
219        // Inverse WHT rotation: apply WHT again then re-apply D⁻¹ = D (since D² = I).
220        let pow2 = next_pow2(dim);
221        let mut sign_buf = vec![0.0f32; pow2];
222        for (i, &s) in signs_fp.iter().enumerate() {
223            sign_buf[i] = s;
224        }
225        wht_inplace(&mut sign_buf);
226        // Re-apply signed diagonal (D is its own inverse since flips are ±1)
227        let mut seed = self.rotation_seed;
228        #[allow(clippy::needless_range_loop)]
229        for i in 0..dim {
230            let flip = if xorshift64(&mut seed) & 1 == 0 {
231                1.0f32
232            } else {
233                -1.0f32
234            };
235            sign_buf[i] *= flip;
236        }
237        let dot_raw: f32 = residual
238            .iter()
239            .zip(sign_buf.iter().take(dim))
240            .map(|(&r, &s)| r * s)
241            .sum();
242        let dot_quantized = if residual_norm > 0.0 {
243            dot_raw / residual_norm
244        } else {
245            0.0
246        };
247
248        let header = QuantHeader {
249            quant_mode: QuantMode::RaBitQ as u16,
250            dim: dim as u16,
251            global_scale: residual_norm,
252            residual_norm,
253            dot_quantized,
254            outlier_bitmask: 0,
255            reserved: [0u8; 8],
256        };
257
258        UnifiedQuantizedVector::new(header, &packed, &[])
259            .expect("RaBitQ encode: layout construction must succeed")
260    }
261}
262
263// ── Quantized / Query newtypes ───────────────────────────────────────────────
264
265/// Packed 1-bit quantized vector produced by [`RaBitQCodec::encode`].
266pub struct RaBitQQuantized(UnifiedQuantizedVector);
267
268impl AsRef<UnifiedQuantizedVector> for RaBitQQuantized {
269    #[inline]
270    fn as_ref(&self) -> &UnifiedQuantizedVector {
271        &self.0
272    }
273}
274
/// Prepared query for [`RaBitQCodec`] distance computations.
///
/// Built once per query by `prepare_query` and then compared against any
/// number of stored codes.
#[derive(Debug, Clone)]
pub struct RaBitQQuery {
    /// Sign-packed rotated query (same bit layout as the stored codes).
    pub rotated_signs: Vec<u8>,
    /// ‖q − centroid‖.
    pub query_norm: f32,
}
282
283// ── VectorCodec impl ─────────────────────────────────────────────────────────
284
285impl VectorCodec for RaBitQCodec {
286    type Quantized = RaBitQQuantized;
287    type Query = RaBitQQuery;
288
289    fn encode(&self, v: &[f32]) -> Self::Quantized {
290        RaBitQQuantized(self.encode_inner(v))
291    }
292
293    fn prepare_query(&self, q: &[f32]) -> Self::Query {
294        let dim = self.dim;
295        let residual: Vec<f32> = q
296            .iter()
297            .zip(self.centroid.iter())
298            .map(|(&qi, &ci)| qi - ci)
299            .collect();
300        let query_norm = residual.iter().map(|x| x * x).sum::<f32>().sqrt();
301        let rotated = self.apply_rotation(&residual);
302        let rotated_signs = sign_pack(&rotated, dim);
303        RaBitQQuery {
304            rotated_signs,
305            query_norm,
306        }
307    }
308
309    /// Symmetric distance estimate: both `q` and `v` are quantized.
310    ///
311    /// `approx_l2 = ‖v-c‖² + ‖q-c‖² − 2·‖v-c‖·‖q-c‖·(1 − 2·hamming/D)`
312    ///
313    /// The angular factor `1 − 2·hamming/D` approximates `cos(θ)` between
314    /// the two sign-vectors. Error bound: O(1/√D) MSE.
315    fn fast_symmetric_distance(&self, q: &Self::Quantized, v: &Self::Quantized) -> f32 {
316        let qh = q.0.header();
317        let vh = v.0.header();
318        let qb = q.0.packed_bits();
319        let vb = v.0.packed_bits();
320        let h = hamming_distance(qb, vb);
321        let dim = self.dim as f32;
322        let dot_estimate = 1.0 - 2.0 * h as f32 / dim;
323        let approx = qh.residual_norm * qh.residual_norm + vh.residual_norm * vh.residual_norm
324            - 2.0 * qh.residual_norm * vh.residual_norm * dot_estimate;
325        approx.max(0.0)
326    }
327
328    /// Asymmetric distance estimate: `q` is a prepared [`RaBitQQuery`], `v`
329    /// is a stored quantized vector.
330    ///
331    /// Uses `query_norm` (exact ‖q−c‖) against `v.residual_norm` (exact ‖v−c‖)
332    /// for higher fidelity than the symmetric variant.
333    ///
334    /// If `self.bias_correct = true`, subtract `v.dot_quantized` as a first-
335    /// order IP-bias correction term (TurboQuant-style).
336    fn exact_asymmetric_distance(&self, q: &Self::Query, v: &Self::Quantized) -> f32 {
337        let vh = v.0.header();
338        let vb = v.0.packed_bits();
339        let h = hamming_distance(&q.rotated_signs, vb);
340        let dim = self.dim as f32;
341        let dot_estimate = 1.0 - 2.0 * h as f32 / dim;
342        let mut approx = q.query_norm * q.query_norm + vh.residual_norm * vh.residual_norm
343            - 2.0 * q.query_norm * vh.residual_norm * dot_estimate;
344        if self.bias_correct {
345            approx -= vh.dot_quantized;
346        }
347        approx.max(0.0)
348    }
349}
350
351// ── Tests ─────────────────────────────────────────────────────────────────────
352
#[cfg(test)]
mod tests {
    use super::*;

    /// Deterministic pseudo-random vector with components in `[-1, 1]`.
    fn random_vec(seed: u64, dim: usize) -> Vec<f32> {
        // Scramble the seed with an odd multiplier (a bijection on u64) so
        // that neighbouring seeds start unrelated streams. Forcing the low
        // bit instead (`seed | 1`) would map seeds 2k and 2k+1 to the same
        // vector.
        let mut s = seed.wrapping_add(1).wrapping_mul(0x9E37_79B9_7F4A_7C15);
        if s == 0 {
            // xorshift64 never leaves the zero state.
            s = 0x9E37_79B9_7F4A_7C15;
        }
        (0..dim)
            .map(|_| {
                let v = xorshift64(&mut s);
                // Map to [-1, 1]
                (v as f32 / u64::MAX as f32) * 2.0 - 1.0
            })
            .collect()
    }

    #[test]
    fn random_vec_adjacent_seeds_differ() {
        assert_ne!(random_vec(0, 8), random_vec(1, 8));
        assert_ne!(random_vec(2, 8), random_vec(3, 8));
        assert!(random_vec(7, 64).iter().all(|x| (-1.0..=1.0).contains(x)));
    }

    #[test]
    fn next_pow2_boundaries() {
        assert_eq!(next_pow2(0), 1);
        assert_eq!(next_pow2(1), 1);
        assert_eq!(next_pow2(3), 4);
        assert_eq!(next_pow2(64), 64);
        assert_eq!(next_pow2(65), 128);
    }

    #[test]
    fn wht_matches_hand_computed_values() {
        let mut buf = [1.0f32, 2.0, 3.0, 4.0];
        wht_inplace(&mut buf);
        assert_eq!(buf, [10.0, -2.0, -4.0, 0.0]);
        // Unnormalised transform applied twice scales by N.
        wht_inplace(&mut buf);
        assert_eq!(buf, [4.0, 8.0, 12.0, 16.0]);
    }

    #[test]
    fn sign_pack_unpack_roundtrip() {
        let values = [-1.0f32, 1.0, -2.0, 0.0, 0.5, -0.0, 3.0, 4.0, -1.0];
        let packed = sign_pack(&values, values.len());
        assert_eq!(packed, vec![0b0000_0101u8, 0b0000_0001]);
        let signs = sign_unpack(&packed, values.len());
        assert_eq!(signs, vec![-1.0f32, 1.0, -1.0, 1.0, 1.0, 1.0, 1.0, 1.0, -1.0]);
    }

    #[test]
    fn apply_rotation_different_seeds_differ() {
        let dim = 64;
        let v: Vec<f32> = (0..dim).map(|i| i as f32 * 0.1).collect();
        let codec_a = RaBitQCodec::calibrate(&[], dim, 0xDEAD_BEEF_1234_5678);
        let codec_b = RaBitQCodec::calibrate(&[], dim, 0xCAFE_BABE_0000_0001);
        let rot_a = codec_a.apply_rotation(&v);
        let rot_b = codec_b.apply_rotation(&v);
        // Different seeds → different rotation outputs.
        let differ = rot_a
            .iter()
            .zip(rot_b.iter())
            .any(|(a, b)| (a - b).abs() > 1e-6);
        assert!(differ, "different seeds must produce different rotations");
    }

    #[test]
    fn apply_rotation_keeps_dim_for_non_pow2() {
        let dim = 48;
        let codec = RaBitQCodec::calibrate(&[], dim, 3);
        let rotated = codec.apply_rotation(&random_vec(1, dim));
        assert_eq!(rotated.len(), dim);
        assert!(rotated.iter().all(|x| x.is_finite()));
    }

    #[test]
    fn encode_roundtrip_preserves_residual_norm() {
        let dim = 128;
        let vecs: Vec<Vec<f32>> = (0..16).map(|i| random_vec(i as u64, dim)).collect();
        let refs: Vec<&[f32]> = vecs.iter().map(|v| v.as_slice()).collect();
        let codec = RaBitQCodec::calibrate(&refs, dim, 42);
        let v = random_vec(99, dim);
        let q = codec.encode(&v);
        let h = q.0.header();
        // residual_norm should be finite and equal to global_scale.
        assert!(h.residual_norm.is_finite() && h.residual_norm >= 0.0);
        assert!((h.global_scale - h.residual_norm).abs() < 1e-6);
    }

    #[test]
    fn dot_quantized_is_l1_of_rotated_residual_over_norm() {
        let dim = 64;
        let vecs: Vec<Vec<f32>> = (0..8).map(|i| random_vec(i as u64, dim)).collect();
        let refs: Vec<&[f32]> = vecs.iter().map(|v| v.as_slice()).collect();
        let codec = RaBitQCodec::calibrate(&refs, dim, 11);
        let v = random_vec(500, dim);

        let residual: Vec<f32> = v
            .iter()
            .zip(codec.centroid.iter())
            .map(|(a, b)| a - b)
            .collect();
        let norm = residual.iter().map(|x| x * x).sum::<f32>().sqrt();
        let l1: f32 = codec
            .apply_rotation(&residual)
            .iter()
            .map(|x| x.abs())
            .sum();
        let expected = l1 / norm;

        let q = codec.encode(&v);
        let got = q.0.header().dot_quantized;
        assert!(
            (got - expected).abs() <= 1e-3 * expected.abs(),
            "dot_quantized {got} vs expected {expected}"
        );
    }

    #[test]
    fn symmetric_distance_to_self_is_zero() {
        let dim = 64;
        let codec = RaBitQCodec::calibrate(&[], dim, 5);
        let v = random_vec(123, dim);
        let a = codec.encode(&v);
        let b = codec.encode(&v);
        let d = codec.fast_symmetric_distance(&a, &b);
        assert!(d.abs() < 1e-4, "self distance should be ~0, got {d}");
    }

    #[test]
    fn distance_non_negative_finite() {
        let dim = 64;
        let vecs: Vec<Vec<f32>> = (0..8).map(|i| random_vec(i as u64, dim)).collect();
        let refs: Vec<&[f32]> = vecs.iter().map(|v| v.as_slice()).collect();
        let codec = RaBitQCodec::calibrate(&refs, dim, 7);
        let v1 = codec.encode(&random_vec(100, dim));
        let v2 = codec.encode(&random_vec(200, dim));
        let sym = codec.fast_symmetric_distance(&v1, &v2);
        assert!(sym.is_finite() && sym >= 0.0, "sym distance: {sym}");
        let q = codec.prepare_query(&random_vec(300, dim));
        let asym = codec.exact_asymmetric_distance(&q, &v2);
        assert!(asym.is_finite() && asym >= 0.0, "asym distance: {asym}");
    }

    #[test]
    fn calibrate_identical_vectors_zero_residual() {
        let dim = 32;
        let v: Vec<f32> = (0..dim).map(|i| i as f32).collect();
        let refs = vec![v.as_slice(); 16];
        let codec = RaBitQCodec::calibrate(&refs, dim, 1);
        // Centroid == v, so residual = 0 → residual_norm = 0.
        let q = codec.encode(&v);
        assert!(
            q.0.header().residual_norm < 1e-5,
            "residual_norm should be ~0 for vector equal to centroid"
        );
    }
}