1use alloc::vec;
26use alloc::vec::Vec;
27
28use crate::tier::TemperatureTier;
29use crate::traits::Quantizer;
30
/// Number of randomized Hadamard rounds used by `RabitqQuantizer::train`.
pub const DEFAULT_ROUNDS: u8 = 3;

/// Bytes of per-vector correction data stored after the packed sign bits:
/// `norm` followed by `dot_corr`, each a little-endian `f32`.
pub const CORRECTION_BYTES: usize = 2 * core::mem::size_of::<f32>();
37
/// RaBitQ-style 1-bit quantizer.
///
/// Inputs are centred on `centroid`, rotated by a seeded randomized Hadamard
/// transform of length `padded_dim`, and stored as one sign bit per rotated
/// coordinate plus two `f32` correction factors.
#[derive(Clone, Debug)]
pub struct RabitqQuantizer {
    /// Dimensionality of the input vectors.
    pub dim: usize,
    /// `dim` rounded up to a power of two: the length of the rotated space
    /// and the number of sign bits per code.
    pub padded_dim: usize,
    /// Seed for the pseudo-random sign flips of the rotation.
    pub seed: u64,
    /// Number of sign-flip + Hadamard rounds (clamped to at least 1 by
    /// `with_centroid`).
    pub rounds: u8,
    /// Vector subtracted from every input before rotation; `dim` long when
    /// built via `with_centroid`.
    pub centroid: Vec<f32>,
}
52
/// Compressed form of one vector, as produced by `RabitqQuantizer::encode_code`.
#[derive(Clone, Debug, PartialEq)]
pub struct RabitqCode {
    /// Packed sign bits of the centred, rotated vector: bit `d % 8` of byte
    /// `d / 8` is set when rotated coordinate `d` is `>= 0`.
    pub bits: Vec<u8>,
    /// Euclidean norm of the centred, rotated vector.
    pub norm: f32,
    /// Inner product between the unit residual and its normalized sign vector,
    /// used to de-bias distance estimates; 1.0 for a (near-)zero residual.
    pub dot_corr: f32,
}
65
66impl RabitqCode {
67 #[inline]
69 pub fn stored_bytes(&self) -> usize {
70 self.bits.len() + CORRECTION_BYTES
71 }
72}
73
/// A query pre-processed once by `RabitqQuantizer::prepare_query`, so it can be
/// compared against many codes with `estimate_l2_sq`.
#[derive(Clone, Debug)]
pub struct RabitqQuery {
    /// Centred and rotated query (`padded_dim` long).
    pub rotated: Vec<f32>,
    /// Squared Euclidean norm of `rotated`.
    pub norm_sq: f32,
}
83
/// SplitMix64 mixing step.
///
/// Adds the golden-ratio increment to `x`, then scrambles it with two
/// xor-shift/multiply rounds and a final xor-shift. Pure and deterministic;
/// used to derive the rotation's sign bits.
#[inline]
fn splitmix64(x: u64) -> u64 {
    const GOLDEN_GAMMA: u64 = 0x9E37_79B9_7F4A_7C15;
    const MIX1: u64 = 0xBF58_476D_1CE4_E5B9;
    const MIX2: u64 = 0x94D0_49BB_1331_11EB;

    let mut s = x.wrapping_add(GOLDEN_GAMMA);
    s ^= s >> 30;
    s = s.wrapping_mul(MIX1);
    s ^= s >> 27;
    s = s.wrapping_mul(MIX2);
    s ^ (s >> 31)
}
93
/// Smallest power of two that is `>= n`, with `n == 0` mapped to 1.
#[inline]
fn next_pow2(n: usize) -> usize {
    match n {
        0 => 1,
        _ => n.next_power_of_two(),
    }
}
99
/// In-place, unnormalized fast Walsh–Hadamard transform.
///
/// Each pass combines pairs `(lo, lo + half)` into `(a + b, a - b)`, doubling
/// `half` until it reaches `v.len()`. The length is expected to be a power of
/// two (callers pass `padded_dim`-long buffers); other lengths index past the
/// end and panic.
fn fwht(v: &mut [f32]) {
    let len = v.len();
    let mut half = 1;
    while half < len {
        for block in (0..len).step_by(half * 2) {
            for lo in block..block + half {
                let hi = lo + half;
                let (a, b) = (v[lo], v[hi]);
                v[lo] = a + b;
                v[hi] = a - b;
            }
        }
        half <<= 1;
    }
}
119
120impl RabitqQuantizer {
121 pub fn train(vectors: &[&[f32]], seed: u64) -> Self {
128 assert!(!vectors.is_empty(), "need at least one training vector");
129 let dim = vectors[0].len();
130 assert!(dim > 0, "vector dimensionality must be > 0");
131
132 let mut centroid = vec![0.0f64; dim];
133 for v in vectors {
134 assert_eq!(v.len(), dim, "dimension mismatch in training data");
135 for (acc, &x) in centroid.iter_mut().zip(v.iter()) {
136 *acc += x as f64;
137 }
138 }
139 let inv_n = 1.0 / vectors.len() as f64;
140 let centroid: Vec<f32> = centroid.iter().map(|&s| (s * inv_n) as f32).collect();
141
142 Self::with_centroid(dim, centroid, seed, DEFAULT_ROUNDS)
143 }
144
145 pub fn with_centroid(dim: usize, centroid: Vec<f32>, seed: u64, rounds: u8) -> Self {
147 assert_eq!(centroid.len(), dim, "centroid length must equal dim");
148 Self {
149 dim,
150 padded_dim: next_pow2(dim),
151 seed,
152 rounds: rounds.max(1),
153 centroid,
154 }
155 }
156
157 #[inline]
160 fn sign_flip(&self, round: u8, i: usize) -> bool {
161 let word = splitmix64(
164 self.seed
165 ^ (round as u64).wrapping_mul(0xA076_1D64_78BD_642F)
166 ^ ((i as u64) / 64).wrapping_mul(0xE703_7ED1_A0B4_28DB),
167 );
168 (word >> (i % 64)) & 1 == 1
169 }
170
171 pub fn rotate(&self, v: &[f32]) -> Vec<f32> {
174 debug_assert!(v.len() <= self.padded_dim);
175 let mut buf = vec![0.0f32; self.padded_dim];
176 buf[..v.len()].copy_from_slice(v);
177 let scale = 1.0 / (self.padded_dim as f32).sqrt();
178 for round in 0..self.rounds {
179 for (i, x) in buf.iter_mut().enumerate() {
180 if self.sign_flip(round, i) {
181 *x = -*x;
182 }
183 }
184 fwht(&mut buf);
185 for x in buf.iter_mut() {
186 *x *= scale;
187 }
188 }
189 buf
190 }
191
192 pub fn rotate_inverse(&self, v: &[f32]) -> Vec<f32> {
195 debug_assert_eq!(v.len(), self.padded_dim);
196 let mut buf = v.to_vec();
197 let scale = 1.0 / (self.padded_dim as f32).sqrt();
198 for round in (0..self.rounds).rev() {
199 fwht(&mut buf);
200 for x in buf.iter_mut() {
201 *x *= scale;
202 }
203 for (i, x) in buf.iter_mut().enumerate() {
204 if self.sign_flip(round, i) {
205 *x = -*x;
206 }
207 }
208 }
209 buf
210 }
211
212 pub fn encode_code(&self, vector: &[f32]) -> RabitqCode {
215 assert_eq!(vector.len(), self.dim, "vector dimension mismatch");
216 let centered: Vec<f32> = vector
217 .iter()
218 .zip(self.centroid.iter())
219 .map(|(&x, &c)| x - c)
220 .collect();
221 let rotated = self.rotate(¢ered);
222
223 let mut norm_sq = 0.0f32;
224 let mut abs_sum = 0.0f32;
225 let mut bits = vec![0u8; self.padded_dim.div_ceil(8)];
226 for (d, &x) in rotated.iter().enumerate() {
227 norm_sq += x * x;
228 abs_sum += x.abs();
229 if x >= 0.0 {
230 bits[d / 8] |= 1 << (d % 8);
231 }
232 }
233 let norm = norm_sq.sqrt();
234 let dot_corr = if norm > f32::EPSILON {
238 (abs_sum / (norm * (self.padded_dim as f32).sqrt())).max(f32::EPSILON)
239 } else {
240 1.0
241 };
242 RabitqCode {
243 bits,
244 norm,
245 dot_corr,
246 }
247 }
248
249 pub fn prepare_query(&self, query: &[f32]) -> RabitqQuery {
251 assert_eq!(query.len(), self.dim, "query dimension mismatch");
252 let centered: Vec<f32> = query
253 .iter()
254 .zip(self.centroid.iter())
255 .map(|(&x, &c)| x - c)
256 .collect();
257 let rotated = self.rotate(¢ered);
258 let norm_sq = rotated.iter().map(|&x| x * x).sum();
259 RabitqQuery { rotated, norm_sq }
260 }
261
262 pub fn estimate_l2_sq(&self, query: &RabitqQuery, code: &RabitqCode) -> f32 {
270 let mut signed_sum = 0.0f32;
271 for (d, &x) in query.rotated.iter().enumerate() {
272 if (code.bits[d / 8] >> (d % 8)) & 1 == 1 {
273 signed_sum += x;
274 } else {
275 signed_sum -= x;
276 }
277 }
278 let est_ip = code.norm * (signed_sum / (self.padded_dim as f32).sqrt()) / code.dot_corr;
279 code.norm * code.norm + query.norm_sq - 2.0 * est_ip
280 }
281
282 #[inline]
284 pub fn stored_bytes_per_vector(&self) -> usize {
285 self.padded_dim.div_ceil(8) + CORRECTION_BYTES
286 }
287
288 #[inline]
290 pub fn compression_ratio(&self) -> f32 {
291 (self.dim * 4) as f32 / self.stored_bytes_per_vector() as f32
292 }
293
294 pub fn code_to_bytes(&self, code: &RabitqCode) -> Vec<u8> {
296 let mut out = Vec::with_capacity(code.stored_bytes());
297 out.extend_from_slice(&code.bits);
298 out.extend_from_slice(&code.norm.to_le_bytes());
299 out.extend_from_slice(&code.dot_corr.to_le_bytes());
300 out
301 }
302
303 pub fn code_from_bytes(&self, data: &[u8]) -> Option<RabitqCode> {
306 let nbits = self.padded_dim.div_ceil(8);
307 if data.len() < nbits + CORRECTION_BYTES {
308 return None;
309 }
310 let bits = data[..nbits].to_vec();
311 let norm = f32::from_le_bytes(data[nbits..nbits + 4].try_into().ok()?);
312 let dot_corr = f32::from_le_bytes(data[nbits + 4..nbits + 8].try_into().ok()?);
313 Some(RabitqCode {
314 bits,
315 norm,
316 dot_corr,
317 })
318 }
319}
320
321impl Quantizer for RabitqQuantizer {
322 fn encode(&self, vector: &[f32]) -> Vec<u8> {
323 self.code_to_bytes(&self.encode_code(vector))
324 }
325
326 fn decode(&self, codes: &[u8]) -> Vec<f32> {
327 let code = match self.code_from_bytes(codes) {
328 Some(c) => c,
329 None => return vec![0.0; self.dim],
330 };
331 let scale = code.norm * code.dot_corr / (self.padded_dim as f32).sqrt();
335 let mut rotated = Vec::with_capacity(self.padded_dim);
336 for d in 0..self.padded_dim {
337 let sign = if (code.bits[d / 8] >> (d % 8)) & 1 == 1 {
338 1.0
339 } else {
340 -1.0
341 };
342 rotated.push(sign * scale);
343 }
344 let residual = self.rotate_inverse(&rotated);
345 residual
346 .iter()
347 .take(self.dim)
348 .zip(self.centroid.iter())
349 .map(|(&r, &c)| r + c)
350 .collect()
351 }
352
353 fn tier(&self) -> TemperatureTier {
354 TemperatureTier::Cold
355 }
356
357 fn dim(&self) -> usize {
358 self.dim
359 }
360}
361
#[cfg(test)]
mod tests {
    use super::*;

    /// Deterministic pseudo-random test vector from a 64-bit LCG (no RNG crate).
    ///
    /// `x >> 33` is a 31-bit value but is divided by `u32::MAX`, so components
    /// fall in roughly `[-0.5, 0.0]` rather than `[-0.5, 0.5)`. The generated
    /// data therefore has a clearly non-zero mean.
    // NOTE(review): the accuracy thresholds below (especially the comparison
    // against naive sign bits) were presumably tuned on this biased
    // distribution — re-check them before changing the range.
    fn lcg_vector(dim: usize, seed: u64) -> Vec<f32> {
        let mut x = seed.wrapping_mul(0x9E37_79B9_7F4A_7C15).wrapping_add(1);
        (0..dim)
            .map(|_| {
                x = x
                    .wrapping_mul(6364136223846793005)
                    .wrapping_add(1442695040888963407);
                ((x >> 33) as f32) / (u32::MAX as f32) - 0.5
            })
            .collect()
    }

    /// Train a quantizer (seed `0xDEAD_BEEF`) on `n` LCG vectors of length
    /// `dim`, returning it together with the training data.
    fn make_quantizer(dim: usize, n: usize) -> (RabitqQuantizer, Vec<Vec<f32>>) {
        let data: Vec<Vec<f32>> = (0..n).map(|i| lcg_vector(dim, i as u64)).collect();
        let refs: Vec<&[f32]> = data.iter().map(|v| v.as_slice()).collect();
        (RabitqQuantizer::train(&refs, 0xDEAD_BEEF), data)
    }

    /// Exact squared Euclidean distance.
    fn l2_sq(a: &[f32], b: &[f32]) -> f32 {
        a.iter().zip(b.iter()).map(|(x, y)| (x - y) * (x - y)).sum()
    }

    #[test]
    fn rotation_is_orthonormal_and_deterministic() {
        let (rq, data) = make_quantizer(100, 8);
        // 100 is not a power of two, so this also exercises zero-padding.
        assert_eq!(rq.padded_dim, 128);
        for v in &data {
            let r1 = rq.rotate(v);
            let r2 = rq.rotate(v);
            assert_eq!(r1, r2, "rotation must be deterministic");

            let norm_in: f32 = v.iter().map(|x| x * x).sum();
            let norm_out: f32 = r1.iter().map(|x| x * x).sum();
            assert!(
                (norm_in - norm_out).abs() < 1e-3 * norm_in.max(1.0),
                "rotation must preserve norms: {norm_in} vs {norm_out}"
            );

            // The inverse must recover the input and map the padding back to ~0.
            let back = rq.rotate_inverse(&r1);
            for (d, (&orig, &rec)) in v.iter().zip(back.iter()).enumerate() {
                assert!(
                    (orig - rec).abs() < 1e-4,
                    "dim {d}: {orig} != {rec} after inverse rotation"
                );
            }
            for &pad in &back[v.len()..] {
                assert!(pad.abs() < 1e-4, "padding must invert to ~0");
            }
        }
    }

    #[test]
    fn rotation_preserves_inner_products() {
        let (rq, data) = make_quantizer(64, 4);
        let a = &data[0];
        let b = &data[1];
        let ip: f32 = a.iter().zip(b.iter()).map(|(x, y)| x * y).sum();
        let ra = rq.rotate(a);
        let rb = rq.rotate(b);
        let rip: f32 = ra.iter().zip(rb.iter()).map(|(x, y)| x * y).sum();
        assert!((ip - rip).abs() < 1e-3, "ip {ip} vs rotated ip {rip}");
    }

    #[test]
    fn different_seeds_give_different_rotations() {
        let v = lcg_vector(32, 7);
        let a = RabitqQuantizer::with_centroid(32, vec![0.0; 32], 1, DEFAULT_ROUNDS);
        let b = RabitqQuantizer::with_centroid(32, vec![0.0; 32], 2, DEFAULT_ROUNDS);
        assert_ne!(a.rotate(&v), b.rotate(&v));
    }

    #[test]
    fn code_round_trip_bytes() {
        let (rq, data) = make_quantizer(48, 16);
        for v in &data {
            let code = rq.encode_code(v);
            let bytes = rq.code_to_bytes(&code);
            assert_eq!(bytes.len(), rq.stored_bytes_per_vector());
            let back = rq.code_from_bytes(&bytes).expect("decode");
            assert_eq!(back, code);
        }
        // Truncated and empty inputs must be rejected rather than misparsed.
        let code = rq.encode_code(&data[0]);
        let bytes = rq.code_to_bytes(&code);
        assert!(rq.code_from_bytes(&bytes[..bytes.len() - 1]).is_none());
        assert!(rq.code_from_bytes(&[]).is_none());
    }

    #[test]
    fn decode_reconstruction_beats_naive_sign_bits() {
        let (rq, data) = make_quantizer(128, 64);
        // Compare total reconstruction error against the crate's plain
        // sign-bit quantizer on the same vectors.
        let mut rabitq_err = 0.0f64;
        let mut naive_err = 0.0f64;
        for v in &data {
            let rec = rq.decode(&rq.encode(v));
            rabitq_err += l2_sq(v, &rec) as f64;

            let bits = crate::binary::encode_binary(v);
            let nrec = crate::binary::decode_binary(&bits, v.len());
            naive_err += l2_sq(v, &nrec) as f64;
        }
        assert!(
            rabitq_err < naive_err,
            "RaBitQ reconstruction error {rabitq_err} must beat naive {naive_err}"
        );
    }

    #[test]
    fn estimator_correlates_with_true_distances() {
        let dim = 128;
        // 200 data vectors x 20 queries = 4000 (estimate, truth) pairs.
        let (rq, data) = make_quantizer(dim, 200);
        let codes: Vec<RabitqCode> = data.iter().map(|v| rq.encode_code(v)).collect();

        let mut est = Vec::new();
        let mut truth = Vec::new();
        for qi in 0..20u64 {
            let q = lcg_vector(dim, 5_000 + qi);
            let prepared = rq.prepare_query(&q);
            for (v, code) in data.iter().zip(codes.iter()) {
                est.push(rq.estimate_l2_sq(&prepared, code) as f64);
                truth.push(l2_sq(&q, v) as f64);
            }
        }

        // Pearson correlation between estimated and true distances.
        let n = est.len() as f64;
        let me = est.iter().sum::<f64>() / n;
        let mt = truth.iter().sum::<f64>() / n;
        let mut cov = 0.0;
        let mut ve = 0.0;
        let mut vt = 0.0;
        for (&e, &t) in est.iter().zip(truth.iter()) {
            cov += (e - me) * (t - mt);
            ve += (e - me) * (e - me);
            vt += (t - mt) * (t - mt);
        }
        let corr = cov / (ve.sqrt() * vt.sqrt());
        #[cfg(feature = "std")]
        std::eprintln!("estimator/true distance correlation (128d): {corr:.4}");
        assert!(
            corr > 0.8,
            "estimator correlation {corr:.3} too weak (expected > 0.8)"
        );

        // Mean relative error, guarding against division by a zero distance.
        let mean_rel: f64 = est
            .iter()
            .zip(truth.iter())
            .map(|(&e, &t)| ((e - t) / t.max(1e-9)).abs())
            .sum::<f64>()
            / n;
        #[cfg(feature = "std")]
        std::eprintln!("estimator mean relative distance error (128d): {mean_rel:.4}");
        assert!(
            mean_rel < 0.25,
            "mean relative error {mean_rel:.3} too large"
        );
    }

    #[test]
    fn compression_ratio_targets() {
        // 128 dims: 16 bytes of sign bits + 8 correction bytes = 24 bytes,
        // versus 512 bytes of raw f32.
        let rq128 = RabitqQuantizer::with_centroid(128, vec![0.0; 128], 1, DEFAULT_ROUNDS);
        assert_eq!(rq128.padded_dim, 128);
        assert_eq!((rq128.dim * 4) / (rq128.padded_dim / 8), 32);
        assert!(rq128.compression_ratio() >= 20.0);

        let rq1024 = RabitqQuantizer::with_centroid(1024, vec![0.0; 1024], 1, DEFAULT_ROUNDS);
        assert!(rq1024.compression_ratio() >= 30.0);
    }

    #[test]
    fn zero_residual_vector_is_safe() {
        let (rq, _) = make_quantizer(16, 4);
        // Encoding the centroid itself yields a zero residual (norm 0).
        let code = rq.encode_code(&rq.centroid);
        assert!(code.norm <= 1e-6);
        let q = rq.prepare_query(&lcg_vector(16, 99));
        let est = rq.estimate_l2_sq(&q, &code);
        assert!(est.is_finite());
    }
}