Skip to main content

nodedb_codec/vector_quant/
rabitq.rs

1// SPDX-License-Identifier: Apache-2.0
2
3//! RaBitQ — 1-bit quantization with O(1/√D) MSE error bound (SIGMOD 2024).
4//!
5//! Algorithm outline
6//! -----------------
7//! 1. **Calibration**: compute centroid `c` over training vectors.
8//! 2. **Rotation**: apply a randomised Walsh-Hadamard transform (WHT) with a
9//!    seed-derived signed-diagonal matrix `D` to the residual `v - c`.
10//!    - WHT requires a power-of-2 length; dimensions that are not pow2 are
11//!      zero-padded before the transform then truncated after.
12//!    - `D` is a vector of `±1` scalars derived deterministically from
13//!      `rotation_seed` using an xorshift64 generator.
14//! 3. **Encoding**: `code = sign(R·(v-c))`, packed 1-bit-per-dimension.
15//! 4. **Distance estimation** (L2):
16//!    `‖q-v‖² ≈ ‖v-c‖² + ‖q-c‖² − 2‖v-c‖‖q-c‖·(1 − 2·hamming/D)`
17//!    with `O(1/√D)` MSE — see SIGMOD 2024 Theorem 4.
18//!
19//! IP / raw inner-product bias
20//! ---------------------------
21//! The angular estimator above is MSE-optimal for L2 and cosine. For raw IP
22//! it carries a systematic bias. When `bias_correct = true` the codec
23//! subtracts `dot_quantized` (stored in the [`QuantHeader`]) from the
24//! asymmetric distance to partially compensate, following the TurboQuant-style
25//! QJL residual correction. This is off by default; consumers that use raw IP
26//! (RAG, attention scores) should enable it.
27//!
//! `rand_xoshiro` is **not** a dependency of `nodedb-codec`. The signed-
//! diagonal flip vector is derived from `rotation_seed` via an inline
//! xorshift64 generator — no external crate required.
31
32use crate::error::CodecError;
33use crate::vector_quant::codec::VectorCodec;
34use crate::vector_quant::codec_envelope;
35use crate::vector_quant::hamming::hamming_distance;
36use crate::vector_quant::layout::{QuantHeader, QuantMode, UnifiedQuantizedVector};
37use serde::{Deserialize, Serialize};
38
39// ── Xorshift64 (inline PRNG) ────────────────────────────────────────────────
40
/// Advance a xorshift64 generator by one step and return the new state.
///
/// Uses the 13 / 7 / 17 shift triple. The value written back to `state` is
/// the same value that is returned. A state of `0` is a fixed point: every
/// step XORs the state with shifted copies of itself, so zero stays zero and
/// all outputs are `0`.
#[inline]
fn xorshift64(state: &mut u64) -> u64 {
    let mixed_left = *state ^ (*state << 13);
    let mixed_right = mixed_left ^ (mixed_left >> 7);
    let next = mixed_right ^ (mixed_right << 17);
    *state = next;
    next
}
51
52// ── WHT helpers ──────────────────────────────────────────────────────────────
53
/// Smallest power of two that is greater than or equal to `n`.
///
/// A value that is already a power of two is returned unchanged, and `0`
/// maps to `1`; both follow directly from `usize::next_power_of_two`.
#[inline]
fn next_pow2(n: usize) -> usize {
    n.next_power_of_two()
}
63
/// Unnormalised Walsh-Hadamard transform, computed in place.
///
/// `buf.len()` must be a power of two (checked with `debug_assert!`). The
/// transform runs log2(N) butterfly passes; in each pass, elements `half`
/// apart are replaced by their sum and difference. No `1/√N` scaling is
/// applied, so applying the transform twice multiplies the input by `N`.
fn wht_inplace(buf: &mut [f32]) {
    let len = buf.len();
    debug_assert!(len.is_power_of_two());

    let mut half = 1usize;
    while half < len {
        let stride = half * 2;
        // Each `stride`-wide block is split into a lower and an upper half.
        for start in (0..len).step_by(stride) {
            for lo in start..start + half {
                let hi = lo + half;
                let (x, y) = (buf[lo], buf[hi]);
                buf[lo] = x + y;
                buf[hi] = x - y;
            }
        }
        half = stride;
    }
}
85
86// ── Sign-pack / unpack helpers ───────────────────────────────────────────────
87
/// Pack the signs of the first `dim` entries of `rotated` into a bit vector.
///
/// The output has `ceil(dim / 8)` bytes. Bit `i % 8` of byte `i / 8`
/// (LSB-first) is set when `rotated[i] < 0.0`; zero, positive and NaN
/// entries leave the bit clear. If `rotated` holds fewer than `dim` values
/// the remaining bits stay clear.
fn sign_pack(rotated: &[f32], dim: usize) -> Vec<u8> {
    let mut bytes = vec![0u8; dim.div_ceil(8)];
    rotated
        .iter()
        .take(dim)
        .enumerate()
        .filter(|(_, value)| **value < 0.0)
        .for_each(|(idx, _)| bytes[idx / 8] |= 1 << (idx % 8));
    bytes
}
100
/// Expand an LSB-first sign bit vector into `dim` floating-point signs.
///
/// A set bit becomes `-1.0` and a clear bit becomes `1.0`, mirroring the
/// convention of `sign_pack`. Panics if `packed` holds fewer than
/// `ceil(dim / 8)` bytes, because bytes are read by index.
fn sign_unpack(packed: &[u8], dim: usize) -> Vec<f32> {
    let mut signs = Vec::with_capacity(dim);
    for idx in 0..dim {
        let bit_set = packed[idx / 8] & (1 << (idx % 8)) != 0;
        signs.push(if bit_set { -1.0f32 } else { 1.0f32 });
    }
    signs
}
113
114// ── RaBitQCodec ──────────────────────────────────────────────────────────────
115
116/// RaBitQ codec: 1-bit quantization with O(1/√D) MSE error bound.
117///
118/// See module-level documentation for algorithm details.
119#[derive(Serialize, Deserialize, zerompk::ToMessagePack, zerompk::FromMessagePack)]
120pub struct RaBitQCodec {
121    pub dim: usize,
122    /// Mean of training vectors; subtracted from each vector before rotation.
123    centroid: Vec<f32>,
124    /// Seed for the signed-diagonal flip vector used by the WHT rotation.
125    rotation_seed: u64,
126    /// If `true`, subtract `dot_quantized` from the asymmetric distance
127    /// estimate as a TurboQuant-style IP-bias correction term.
128    pub bias_correct: bool,
129}
130
131impl RaBitQCodec {
132    /// Calibrate a new codec from a set of training vectors.
133    ///
134    /// - Computes the centroid as the coordinate-wise mean of `vectors`.
135    /// - Stores `rotation_seed` for reproducible signed-diagonal generation.
136    ///
137    /// # Errors (none — returns Self directly)
138    ///
139    /// Returns a zero-centroid codec if `vectors` is empty.
140    pub fn calibrate(vectors: &[&[f32]], dim: usize, rotation_seed: u64) -> Self {
141        let centroid = if vectors.is_empty() {
142            vec![0.0f32; dim]
143        } else {
144            let n = vectors.len() as f32;
145            let mut c = vec![0.0f32; dim];
146            for v in vectors {
147                for (ci, &vi) in c.iter_mut().zip(v.iter()) {
148                    *ci += vi;
149                }
150            }
151            c.iter_mut().for_each(|x| *x /= n);
152            c
153        };
154        Self {
155            dim,
156            centroid,
157            rotation_seed,
158            bias_correct: false,
159        }
160    }
161
162    /// On-disk magic for RaBitQ codec envelopes.
163    pub const ENVELOPE_MAGIC: &'static [u8; codec_envelope::MAGIC_LEN] = b"NDRBQ";
164
165    /// Current on-disk envelope version for RaBitQ codecs.
166    pub const ENVELOPE_VERSION: u8 = 1;
167
168    /// Serialize this codec to a self-describing byte buffer.
169    pub fn to_bytes(&self) -> Result<Vec<u8>, CodecError> {
170        codec_envelope::encode(Self::ENVELOPE_MAGIC, Self::ENVELOPE_VERSION, self)
171    }
172
173    /// Deserialize a codec from a byte buffer produced by [`Self::to_bytes`].
174    pub fn from_bytes(buf: &[u8]) -> Result<Self, CodecError> {
175        codec_envelope::decode(Self::ENVELOPE_MAGIC, Self::ENVELOPE_VERSION, buf)
176    }
177
178    /// Apply the randomised WHT rotation to a residual vector.
179    ///
180    /// Steps:
181    /// 1. Apply signed-diagonal `D` (deterministic from `rotation_seed`).
182    /// 2. Zero-pad to the next power of two if `dim` is not pow2.
183    /// 3. WHT in-place.
184    /// 4. Truncate back to `dim`.
185    pub fn apply_rotation(&self, v: &[f32]) -> Vec<f32> {
186        let dim = self.dim;
187        let pow2 = next_pow2(dim);
188
189        // Signed-diagonal multiply: generate ±1 per dim from seed.
190        let mut seed = self.rotation_seed;
191        let mut buf = vec![0.0f32; pow2];
192        for (i, &vi) in v.iter().take(dim).enumerate() {
193            let flip = if xorshift64(&mut seed) & 1 == 0 {
194                1.0f32
195            } else {
196                -1.0f32
197            };
198            buf[i] = vi * flip;
199        }
200        // Trailing elements remain zero (pad).
201
202        wht_inplace(&mut buf);
203        buf.truncate(dim);
204        buf
205    }
206
207    /// Encode a single vector into a [`UnifiedQuantizedVector`] with
208    /// `QuantMode::RaBitQ`.
209    ///
210    /// The header fields populated are:
211    /// - `global_scale` = `residual_norm` (‖v−c‖); both store the same value
212    ///   so that consumers that use either field without context still have the
213    ///   magnitude available.
214    /// - `residual_norm` = ‖v−c‖.
215    /// - `dot_quantized` = ⟨residual, dequantised_sign_vector⟩ / ‖v−c‖;
216    ///   used for IP-bias correction when `bias_correct = true`.
217    fn encode_inner(&self, v: &[f32]) -> UnifiedQuantizedVector {
218        let dim = self.dim;
219
220        // Step 1: residual = v - centroid
221        let residual: Vec<f32> = v
222            .iter()
223            .zip(self.centroid.iter())
224            .map(|(&vi, &ci)| vi - ci)
225            .collect();
226
227        // Step 2: ‖residual‖
228        let residual_norm = residual.iter().map(|x| x * x).sum::<f32>().sqrt();
229
230        // Step 3: rotate
231        let rotated = self.apply_rotation(&residual);
232
233        // Step 4: sign-pack → 1-bit code
234        let packed = sign_pack(&rotated, dim);
235
236        // Step 5: compute dot_quantized = ⟨residual, R⁻¹·sign(rotated)⟩ / ‖residual‖
237        // Inverse rotation of the sign vector, then dot with original residual.
238        let signs_fp = sign_unpack(&packed, dim);
239        // Inverse WHT rotation: apply WHT again then re-apply D⁻¹ = D (since D² = I).
240        let pow2 = next_pow2(dim);
241        let mut sign_buf = vec![0.0f32; pow2];
242        for (i, &s) in signs_fp.iter().enumerate() {
243            sign_buf[i] = s;
244        }
245        wht_inplace(&mut sign_buf);
246        // Re-apply signed diagonal (D is its own inverse since flips are ±1)
247        let mut seed = self.rotation_seed;
248        #[allow(clippy::needless_range_loop)]
249        for i in 0..dim {
250            let flip = if xorshift64(&mut seed) & 1 == 0 {
251                1.0f32
252            } else {
253                -1.0f32
254            };
255            sign_buf[i] *= flip;
256        }
257        let dot_raw: f32 = residual
258            .iter()
259            .zip(sign_buf.iter().take(dim))
260            .map(|(&r, &s)| r * s)
261            .sum();
262        let dot_quantized = if residual_norm > 0.0 {
263            dot_raw / residual_norm
264        } else {
265            0.0
266        };
267
268        let header = QuantHeader {
269            quant_mode: QuantMode::RaBitQ as u16,
270            dim: dim as u16,
271            global_scale: residual_norm,
272            residual_norm,
273            dot_quantized,
274            outlier_bitmask: 0,
275            reserved: [0u8; 8],
276        };
277
278        UnifiedQuantizedVector::new(header, &packed, &[])
279            .expect("RaBitQ encode: layout construction must succeed")
280    }
281}
282
283// ── Quantized / Query newtypes ───────────────────────────────────────────────
284
285/// Packed 1-bit quantized vector produced by [`RaBitQCodec::encode`].
286pub struct RaBitQQuantized(UnifiedQuantizedVector);
287
288impl AsRef<UnifiedQuantizedVector> for RaBitQQuantized {
289    #[inline]
290    fn as_ref(&self) -> &UnifiedQuantizedVector {
291        &self.0
292    }
293}
294
/// Prepared query for [`RaBitQCodec`] distance computations.
///
/// Built by `RaBitQCodec::prepare_query`. Holds the query's 1-bit sign code
/// and the norm of its residual.
#[derive(Debug, Clone)]
pub struct RaBitQQuery {
    /// Sign-packed rotated query (same bit layout as the stored codes).
    pub rotated_signs: Vec<u8>,
    /// ‖q − centroid‖.
    pub query_norm: f32,
}
302
303// ── VectorCodec impl ─────────────────────────────────────────────────────────
304
305impl VectorCodec for RaBitQCodec {
306    type Quantized = RaBitQQuantized;
307    type Query = RaBitQQuery;
308
309    fn encode(&self, v: &[f32]) -> Self::Quantized {
310        RaBitQQuantized(self.encode_inner(v))
311    }
312
313    fn prepare_query(&self, q: &[f32]) -> Self::Query {
314        let dim = self.dim;
315        let residual: Vec<f32> = q
316            .iter()
317            .zip(self.centroid.iter())
318            .map(|(&qi, &ci)| qi - ci)
319            .collect();
320        let query_norm = residual.iter().map(|x| x * x).sum::<f32>().sqrt();
321        let rotated = self.apply_rotation(&residual);
322        let rotated_signs = sign_pack(&rotated, dim);
323        RaBitQQuery {
324            rotated_signs,
325            query_norm,
326        }
327    }
328
329    /// Symmetric distance estimate: both `q` and `v` are quantized.
330    ///
331    /// `approx_l2 = ‖v-c‖² + ‖q-c‖² − 2·‖v-c‖·‖q-c‖·(1 − 2·hamming/D)`
332    ///
333    /// The angular factor `1 − 2·hamming/D` approximates `cos(θ)` between
334    /// the two sign-vectors. Error bound: O(1/√D) MSE.
335    fn fast_symmetric_distance(&self, q: &Self::Quantized, v: &Self::Quantized) -> f32 {
336        let qh = q.0.header();
337        let vh = v.0.header();
338        let qb = q.0.packed_bits();
339        let vb = v.0.packed_bits();
340        let h = hamming_distance(qb, vb);
341        let dim = self.dim as f32;
342        let dot_estimate = 1.0 - 2.0 * h as f32 / dim;
343        let approx = qh.residual_norm * qh.residual_norm + vh.residual_norm * vh.residual_norm
344            - 2.0 * qh.residual_norm * vh.residual_norm * dot_estimate;
345        approx.max(0.0)
346    }
347
348    /// Asymmetric distance estimate: `q` is a prepared [`RaBitQQuery`], `v`
349    /// is a stored quantized vector.
350    ///
351    /// Uses `query_norm` (exact ‖q−c‖) against `v.residual_norm` (exact ‖v−c‖)
352    /// for higher fidelity than the symmetric variant.
353    ///
354    /// If `self.bias_correct = true`, subtract `v.dot_quantized` as a first-
355    /// order IP-bias correction term (TurboQuant-style).
356    fn exact_asymmetric_distance(&self, q: &Self::Query, v: &Self::Quantized) -> f32 {
357        let vh = v.0.header();
358        let vb = v.0.packed_bits();
359        let h = hamming_distance(&q.rotated_signs, vb);
360        let dim = self.dim as f32;
361        let dot_estimate = 1.0 - 2.0 * h as f32 / dim;
362        let mut approx = q.query_norm * q.query_norm + vh.residual_norm * vh.residual_norm
363            - 2.0 * q.query_norm * vh.residual_norm * dot_estimate;
364        if self.bias_correct {
365            approx -= vh.dot_quantized;
366        }
367        approx.max(0.0)
368    }
369}
370
371// ── Tests ─────────────────────────────────────────────────────────────────────
372
#[cfg(test)]
mod tests {
    use super::*;

    /// Deterministic pseudo-random vector with components in `[-1, 1]`.
    ///
    /// The seed is first multiplied by an odd 64-bit constant (a bijection
    /// on `u64`) and only then OR-ed with 1 to keep the xorshift state
    /// non-zero. The earlier `seed | 1` alone mapped seeds `2k` and `2k + 1`
    /// to the same stream, so half of every "distinct" training set built
    /// from `0..n` was duplicated. The small consecutive seeds used by these
    /// tests now start from different, well-mixed states.
    fn random_vec(seed: u64, dim: usize) -> Vec<f32> {
        let mut s = seed.wrapping_mul(0x9E37_79B9_7F4A_7C15) | 1;
        (0..dim)
            .map(|_| {
                let v = xorshift64(&mut s);
                // Map to [-1, 1]
                (v as f32 / u64::MAX as f32) * 2.0 - 1.0
            })
            .collect()
    }

    /// Serializing and deserializing a codec preserves all of its fields.
    #[test]
    fn to_bytes_from_bytes_roundtrip() {
        let dim = 64;
        let vecs: Vec<Vec<f32>> = (0..4).map(|i| random_vec(i as u64, dim)).collect();
        let refs: Vec<&[f32]> = vecs.iter().map(|v| v.as_slice()).collect();
        let codec = RaBitQCodec::calibrate(&refs, dim, 0xABCD_1234_5678_EF01);
        let bytes = codec.to_bytes().expect("to_bytes should succeed");
        let restored = RaBitQCodec::from_bytes(&bytes).expect("from_bytes should succeed");
        assert_eq!(restored.dim, codec.dim);
        assert_eq!(restored.rotation_seed, codec.rotation_seed);
        assert_eq!(restored.bias_correct, codec.bias_correct);
        assert_eq!(restored.centroid.len(), codec.centroid.len());
        for (a, b) in restored.centroid.iter().zip(codec.centroid.iter()) {
            assert!((a - b).abs() < 1e-6, "centroid mismatch: {a} vs {b}");
        }
    }

    /// A buffer that starts with the wrong magic bytes is rejected.
    #[test]
    fn from_bytes_rejects_bad_magic() {
        let mut bytes = b"WRONG".to_vec();
        bytes.push(1);
        bytes.extend_from_slice(&[0u8; 4]);
        assert!(RaBitQCodec::from_bytes(&bytes).is_err());
    }

    /// Overwriting the byte right after the 5-byte magic makes decode fail.
    #[test]
    fn from_bytes_rejects_bad_version() {
        let codec = RaBitQCodec::calibrate(&[], 4, 1);
        let mut bytes = codec.to_bytes().unwrap();
        bytes[5] = 99;
        assert!(RaBitQCodec::from_bytes(&bytes).is_err());
    }

    /// Two codecs that differ only in seed rotate the same input differently.
    #[test]
    fn apply_rotation_different_seeds_differ() {
        let dim = 64;
        let v: Vec<f32> = (0..dim).map(|i| i as f32 * 0.1).collect();
        let codec_a = RaBitQCodec::calibrate(&[], dim, 0xDEAD_BEEF_1234_5678);
        let codec_b = RaBitQCodec::calibrate(&[], dim, 0xCAFE_BABE_0000_0001);
        let rot_a = codec_a.apply_rotation(&v);
        let rot_b = codec_b.apply_rotation(&v);
        // Different seeds → different rotation outputs.
        let differ = rot_a
            .iter()
            .zip(rot_b.iter())
            .any(|(a, b)| (a - b).abs() > 1e-6);
        assert!(differ, "different seeds must produce different rotations");
    }

    /// The encoded header carries a finite, non-negative residual norm that
    /// is mirrored in `global_scale`.
    #[test]
    fn encode_roundtrip_preserves_residual_norm() {
        let dim = 128;
        let vecs: Vec<Vec<f32>> = (0..16).map(|i| random_vec(i as u64, dim)).collect();
        let refs: Vec<&[f32]> = vecs.iter().map(|v| v.as_slice()).collect();
        let codec = RaBitQCodec::calibrate(&refs, dim, 42);
        let v = random_vec(99, dim);
        let q = codec.encode(&v);
        let h = q.0.header();
        // residual_norm should be finite and equal to global_scale.
        assert!(h.residual_norm.is_finite() && h.residual_norm >= 0.0);
        assert!((h.global_scale - h.residual_norm).abs() < 1e-6);
    }

    /// Both distance estimators return finite, non-negative values.
    #[test]
    fn distance_non_negative_finite() {
        let dim = 64;
        let vecs: Vec<Vec<f32>> = (0..8).map(|i| random_vec(i as u64, dim)).collect();
        let refs: Vec<&[f32]> = vecs.iter().map(|v| v.as_slice()).collect();
        let codec = RaBitQCodec::calibrate(&refs, dim, 7);
        let v1 = codec.encode(&random_vec(100, dim));
        let v2 = codec.encode(&random_vec(200, dim));
        let sym = codec.fast_symmetric_distance(&v1, &v2);
        assert!(sym.is_finite() && sym >= 0.0, "sym distance: {sym}");
        let q = codec.prepare_query(&random_vec(300, dim));
        let asym = codec.exact_asymmetric_distance(&q, &v2);
        assert!(asym.is_finite() && asym >= 0.0, "asym distance: {asym}");
    }

    /// Calibrating on copies of one vector makes that vector's residual zero.
    #[test]
    fn calibrate_identical_vectors_zero_residual() {
        let dim = 32;
        let v: Vec<f32> = (0..dim).map(|i| i as f32).collect();
        let refs = vec![v.as_slice(); 16];
        let codec = RaBitQCodec::calibrate(&refs, dim, 1);
        // Centroid == v, so residual = 0 → residual_norm = 0.
        let q = codec.encode(&v);
        assert!(
            q.0.header().residual_norm < 1e-5,
            "residual_norm should be ~0 for vector equal to centroid"
        );
    }
}
479}