nodedb_codec/vector_quant/
rabitq.rs1use crate::vector_quant::codec::VectorCodec;
33use crate::vector_quant::hamming::hamming_distance;
34use crate::vector_quant::layout::{QuantHeader, QuantMode, UnifiedQuantizedVector};
35
/// Advances a xorshift64 generator by one step and returns the new state.
///
/// The three shift/xor rounds use the shift amounts 13 (left), 7 (right) and
/// 17 (left). The value written back to `state` is the same value that is
/// returned. A zero state maps to zero, so a zero seed never leaves zero.
#[inline]
fn xorshift64(state: &mut u64) -> u64 {
    let round1 = *state ^ (*state << 13);
    let round2 = round1 ^ (round1 >> 7);
    let next = round2 ^ (round2 << 17);
    *state = next;
    next
}
48
/// Returns the smallest power of two that is `>= n`.
///
/// Powers of two are returned unchanged; `0` maps to `1`.
#[inline]
fn next_pow2(n: usize) -> usize {
    // `next_power_of_two` already returns `n` itself when `n` is a power of two.
    n.next_power_of_two()
}
60
/// Applies an unnormalized fast Walsh–Hadamard transform to `buf` in place.
///
/// `buf.len()` must be a power of two (checked with `debug_assert!`). No
/// scaling is applied, so the output grows with the buffer length.
fn wht_inplace(buf: &mut [f32]) {
    let len = buf.len();
    debug_assert!(len.is_power_of_two());

    // `half` is the distance between the two partners of each butterfly.
    let mut half = 1usize;
    while half < len {
        for block_start in (0..len).step_by(half * 2) {
            for lo in block_start..block_start + half {
                let hi = lo + half;
                let (a, b) = (buf[lo], buf[hi]);
                buf[lo] = a + b;
                buf[hi] = a - b;
            }
        }
        half *= 2;
    }
}
82
/// Packs the signs of the first `dim` values of `rotated` into a bitmap.
///
/// The result has `ceil(dim / 8)` bytes. Bit `i % 8` of byte `i / 8` is set
/// when `rotated[i] < 0.0`; zeros (including `-0.0`), positives and NaNs leave
/// the bit clear. If `rotated` holds fewer than `dim` values the missing
/// positions stay clear.
fn sign_pack(rotated: &[f32], dim: usize) -> Vec<u8> {
    let used = &rotated[..rotated.len().min(dim)];
    let mut packed = vec![0u8; dim.div_ceil(8)];
    for (byte, lanes) in packed.iter_mut().zip(used.chunks(8)) {
        *byte = lanes.iter().enumerate().fold(0u8, |bits, (bit, &x)| {
            if x < 0.0 {
                bits | (1u8 << bit)
            } else {
                bits
            }
        });
    }
    packed
}
97
/// Expands a packed sign bitmap into `dim` floats of value `±1.0`.
///
/// A set bit yields `-1.0`, a clear bit `1.0`. Bit `i` is read from bit
/// `i % 8` of byte `i / 8`; indexing panics if `packed` is too short for
/// `dim` bits.
fn sign_unpack(packed: &[u8], dim: usize) -> Vec<f32> {
    let mut signs = Vec::with_capacity(dim);
    for idx in 0..dim {
        let negative = (packed[idx >> 3] >> (idx & 7)) & 1 == 1;
        signs.push(if negative { -1.0f32 } else { 1.0f32 });
    }
    signs
}
110
/// Sign-bit vector codec: centers on a centroid, applies a seeded rotation,
/// and keeps one sign bit per coordinate.
///
/// `Debug` and `Clone` are derived so the codec can be logged and duplicated
/// like any other configuration value.
#[derive(Debug, Clone)]
pub struct RaBitQCodec {
    /// Number of coordinates encoded per vector; also the divisor used when
    /// turning a Hamming distance into a dot-product estimate.
    pub dim: usize,
    /// Per-coordinate mean of the calibration vectors (all zeros when none
    /// were supplied); subtracted from every vector before encoding.
    centroid: Vec<f32>,
    /// Seed for the xorshift64 stream that picks the per-coordinate sign
    /// flips applied before the Walsh–Hadamard transform.
    rotation_seed: u64,
    /// When `true`, the asymmetric distance subtracts the stored
    /// `dot_quantized` header value from its estimate. `calibrate` sets this
    /// to `false`.
    pub bias_correct: bool,
}
126
127impl RaBitQCodec {
128 pub fn calibrate(vectors: &[&[f32]], dim: usize, rotation_seed: u64) -> Self {
137 let centroid = if vectors.is_empty() {
138 vec![0.0f32; dim]
139 } else {
140 let n = vectors.len() as f32;
141 let mut c = vec![0.0f32; dim];
142 for v in vectors {
143 for (ci, &vi) in c.iter_mut().zip(v.iter()) {
144 *ci += vi;
145 }
146 }
147 c.iter_mut().for_each(|x| *x /= n);
148 c
149 };
150 Self {
151 dim,
152 centroid,
153 rotation_seed,
154 bias_correct: false,
155 }
156 }
157
158 pub fn apply_rotation(&self, v: &[f32]) -> Vec<f32> {
166 let dim = self.dim;
167 let pow2 = next_pow2(dim);
168
169 let mut seed = self.rotation_seed;
171 let mut buf = vec![0.0f32; pow2];
172 for (i, &vi) in v.iter().take(dim).enumerate() {
173 let flip = if xorshift64(&mut seed) & 1 == 0 {
174 1.0f32
175 } else {
176 -1.0f32
177 };
178 buf[i] = vi * flip;
179 }
180 wht_inplace(&mut buf);
183 buf.truncate(dim);
184 buf
185 }
186
187 fn encode_inner(&self, v: &[f32]) -> UnifiedQuantizedVector {
198 let dim = self.dim;
199
200 let residual: Vec<f32> = v
202 .iter()
203 .zip(self.centroid.iter())
204 .map(|(&vi, &ci)| vi - ci)
205 .collect();
206
207 let residual_norm = residual.iter().map(|x| x * x).sum::<f32>().sqrt();
209
210 let rotated = self.apply_rotation(&residual);
212
213 let packed = sign_pack(&rotated, dim);
215
216 let signs_fp = sign_unpack(&packed, dim);
219 let pow2 = next_pow2(dim);
221 let mut sign_buf = vec![0.0f32; pow2];
222 for (i, &s) in signs_fp.iter().enumerate() {
223 sign_buf[i] = s;
224 }
225 wht_inplace(&mut sign_buf);
226 let mut seed = self.rotation_seed;
228 #[allow(clippy::needless_range_loop)]
229 for i in 0..dim {
230 let flip = if xorshift64(&mut seed) & 1 == 0 {
231 1.0f32
232 } else {
233 -1.0f32
234 };
235 sign_buf[i] *= flip;
236 }
237 let dot_raw: f32 = residual
238 .iter()
239 .zip(sign_buf.iter().take(dim))
240 .map(|(&r, &s)| r * s)
241 .sum();
242 let dot_quantized = if residual_norm > 0.0 {
243 dot_raw / residual_norm
244 } else {
245 0.0
246 };
247
248 let header = QuantHeader {
249 quant_mode: QuantMode::RaBitQ as u16,
250 dim: dim as u16,
251 global_scale: residual_norm,
252 residual_norm,
253 dot_quantized,
254 outlier_bitmask: 0,
255 reserved: [0u8; 8],
256 };
257
258 UnifiedQuantizedVector::new(header, &packed, &[])
259 .expect("RaBitQ encode: layout construction must succeed")
260 }
261}
262
263pub struct RaBitQQuantized(UnifiedQuantizedVector);
267
268impl AsRef<UnifiedQuantizedVector> for RaBitQQuantized {
269 #[inline]
270 fn as_ref(&self) -> &UnifiedQuantizedVector {
271 &self.0
272 }
273}
274
/// Query-side state prepared once and reused across distance evaluations.
///
/// `Debug` and `Clone` are derived so prepared queries can be logged and
/// duplicated.
#[derive(Debug, Clone)]
pub struct RaBitQQuery {
    /// Packed sign bits of the centered, transformed query: one bit per
    /// coordinate, set when the transformed value is negative.
    pub rotated_signs: Vec<u8>,
    /// Euclidean norm of the centered (centroid-subtracted) query.
    pub query_norm: f32,
}
282
283impl VectorCodec for RaBitQCodec {
286 type Quantized = RaBitQQuantized;
287 type Query = RaBitQQuery;
288
289 fn encode(&self, v: &[f32]) -> Self::Quantized {
290 RaBitQQuantized(self.encode_inner(v))
291 }
292
293 fn prepare_query(&self, q: &[f32]) -> Self::Query {
294 let dim = self.dim;
295 let residual: Vec<f32> = q
296 .iter()
297 .zip(self.centroid.iter())
298 .map(|(&qi, &ci)| qi - ci)
299 .collect();
300 let query_norm = residual.iter().map(|x| x * x).sum::<f32>().sqrt();
301 let rotated = self.apply_rotation(&residual);
302 let rotated_signs = sign_pack(&rotated, dim);
303 RaBitQQuery {
304 rotated_signs,
305 query_norm,
306 }
307 }
308
309 fn fast_symmetric_distance(&self, q: &Self::Quantized, v: &Self::Quantized) -> f32 {
316 let qh = q.0.header();
317 let vh = v.0.header();
318 let qb = q.0.packed_bits();
319 let vb = v.0.packed_bits();
320 let h = hamming_distance(qb, vb);
321 let dim = self.dim as f32;
322 let dot_estimate = 1.0 - 2.0 * h as f32 / dim;
323 let approx = qh.residual_norm * qh.residual_norm + vh.residual_norm * vh.residual_norm
324 - 2.0 * qh.residual_norm * vh.residual_norm * dot_estimate;
325 approx.max(0.0)
326 }
327
328 fn exact_asymmetric_distance(&self, q: &Self::Query, v: &Self::Quantized) -> f32 {
337 let vh = v.0.header();
338 let vb = v.0.packed_bits();
339 let h = hamming_distance(&q.rotated_signs, vb);
340 let dim = self.dim as f32;
341 let dot_estimate = 1.0 - 2.0 * h as f32 / dim;
342 let mut approx = q.query_norm * q.query_norm + vh.residual_norm * vh.residual_norm
343 - 2.0 * q.query_norm * vh.residual_norm * dot_estimate;
344 if self.bias_correct {
345 approx -= vh.dot_quantized;
346 }
347 approx.max(0.0)
348 }
349}
350
#[cfg(test)]
mod tests {
    use super::*;

    /// Deterministic pseudo-random vector built from the xorshift64 stream
    /// seeded with `seed | 1` (so the stream never starts at zero).
    fn random_vec(seed: u64, dim: usize) -> Vec<f32> {
        let mut state = seed | 1;
        let mut out = Vec::with_capacity(dim);
        for _ in 0..dim {
            let raw = xorshift64(&mut state);
            out.push((raw as f32 / u64::MAX as f32) * 2.0 - 1.0);
        }
        out
    }

    /// Calibrates a codec on `count` pseudo-random vectors seeded `0..count`.
    fn calibrated_codec(count: u64, dim: usize, rotation_seed: u64) -> RaBitQCodec {
        let samples: Vec<Vec<f32>> = (0..count).map(|seed| random_vec(seed, dim)).collect();
        let views: Vec<&[f32]> = samples.iter().map(Vec::as_slice).collect();
        RaBitQCodec::calibrate(&views, dim, rotation_seed)
    }

    #[test]
    fn apply_rotation_different_seeds_differ() {
        let dim = 64;
        let input: Vec<f32> = (0..dim).map(|i| i as f32 * 0.1).collect();
        let rot_a = RaBitQCodec::calibrate(&[], dim, 0xDEAD_BEEF_1234_5678).apply_rotation(&input);
        let rot_b = RaBitQCodec::calibrate(&[], dim, 0xCAFE_BABE_0000_0001).apply_rotation(&input);
        let mut differ = false;
        for (a, b) in rot_a.iter().zip(rot_b.iter()) {
            if (a - b).abs() > 1e-6 {
                differ = true;
                break;
            }
        }
        assert!(differ, "different seeds must produce different rotations");
    }

    #[test]
    fn encode_roundtrip_preserves_residual_norm() {
        let dim = 128;
        let codec = calibrated_codec(16, dim, 42);
        let encoded = codec.encode(&random_vec(99, dim));
        let header = encoded.0.header();
        assert!(header.residual_norm.is_finite() && header.residual_norm >= 0.0);
        // `global_scale` is written from the same value as `residual_norm`.
        assert!((header.global_scale - header.residual_norm).abs() < 1e-6);
    }

    #[test]
    fn distance_non_negative_finite() {
        let dim = 64;
        let codec = calibrated_codec(8, dim, 7);
        let v1 = codec.encode(&random_vec(100, dim));
        let v2 = codec.encode(&random_vec(200, dim));

        let sym = codec.fast_symmetric_distance(&v1, &v2);
        assert!(sym.is_finite() && sym >= 0.0, "sym distance: {sym}");

        let query = codec.prepare_query(&random_vec(300, dim));
        let asym = codec.exact_asymmetric_distance(&query, &v2);
        assert!(asym.is_finite() && asym >= 0.0, "asym distance: {asym}");
    }

    #[test]
    fn calibrate_identical_vectors_zero_residual() {
        let dim = 32;
        let v: Vec<f32> = (0..dim).map(|i| i as f32).collect();
        let views: Vec<&[f32]> = std::iter::repeat(v.as_slice()).take(16).collect();
        let codec = RaBitQCodec::calibrate(&views, dim, 1);
        let encoded = codec.encode(&v);
        assert!(
            encoded.0.header().residual_norm < 1e-5,
            "residual_norm should be ~0 for vector equal to centroid"
        );
    }
}