use rand::Rng;
use rand::SeedableRng;
use serde::{Deserialize, Serialize};

#[cfg(target_arch = "aarch64")]
#[allow(unused_imports)]
use std::arch::aarch64::*;

#[cfg(target_arch = "x86_64")]
#[allow(unused_imports)]
use std::arch::x86_64::*;

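/// Dispatches the LUT dot product to an architecture-specific path when one is
/// available (aarch64, or x86_64 with SSSE3 detected at runtime), falling back
/// to the portable scalar implementation otherwise.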
#[inline]
fn lut_dot_product_simd(bits: &[u8], luts: &[[u16; 16]]) -> u32 {
    #[cfg(target_arch = "aarch64")]
    {
        if let Some(result) = lut_dot_product_neon(bits, luts) {
            return result;
        }
    }

    #[cfg(target_arch = "x86_64")]
    {
        if is_x86_feature_detected!("ssse3") {
            unsafe {
                if let Some(result) = lut_dot_product_ssse3(bits, luts) {
                    return result;
                }
            }
        }
    }

    lut_dot_product_scalar(bits, luts)
}

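/// Portable reference implementation: for each group of 4 dimensions, the 4-bit
/// pattern taken from the data vector's packed sign bits indexes the query's
/// lookup table, and the selected partial sums are accumulated.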
#[inline]
fn lut_dot_product_scalar(bits: &[u8], luts: &[[u16; 16]]) -> u32 {
    let mut dot_sum = 0u32;

    for (lut_idx, lut) in luts.iter().enumerate() {
        // Each LUT covers 4 consecutive bits of the packed sign vector.
        let base_bit = lut_idx * 4;
        let byte_idx = base_bit / 8;
        let bit_offset = base_bit % 8;

        let byte = bits.get(byte_idx).copied().unwrap_or(0);
        let next_byte = bits.get(byte_idx + 1).copied().unwrap_or(0);

        // Extract the 4-bit pattern, handling a group that straddles a byte boundary.
        let pattern = if bit_offset <= 4 {
            (byte >> bit_offset) & 0x0F
        } else {
            ((byte >> bit_offset) | (next_byte << (8 - bit_offset))) & 0x0F
        };

        dot_sum += lut[pattern as usize] as u32;
    }

    dot_sum
}

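/// aarch64 path. Despite the name, this currently uses no NEON intrinsics: it is
/// the scalar lookup loop unrolled two LUTs at a time, and it returns `None` for
/// inputs with fewer than 8 LUTs so the caller falls back to the scalar path.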
#[cfg(target_arch = "aarch64")]
#[inline]
fn lut_dot_product_neon(bits: &[u8], luts: &[[u16; 16]]) -> Option<u32> {
    if luts.len() < 8 {
        return None;
    }

    let mut total = 0u32;
    let num_luts = luts.len();
    let mut lut_idx = 0;

    while lut_idx + 2 <= num_luts {
        let base_bit0 = lut_idx * 4;
        let base_bit1 = (lut_idx + 1) * 4;

        let byte_idx0 = base_bit0 / 8;
        let bit_offset0 = base_bit0 % 8;
        let byte_idx1 = base_bit1 / 8;
        let bit_offset1 = base_bit1 % 8;

        let byte0 = bits.get(byte_idx0).copied().unwrap_or(0);
        let next0 = bits.get(byte_idx0 + 1).copied().unwrap_or(0);
        let byte1 = bits.get(byte_idx1).copied().unwrap_or(0);
        let next1 = bits.get(byte_idx1 + 1).copied().unwrap_or(0);

        let pattern0 = if bit_offset0 <= 4 {
            (byte0 >> bit_offset0) & 0x0F
        } else {
            ((byte0 >> bit_offset0) | (next0 << (8 - bit_offset0))) & 0x0F
        };

        let pattern1 = if bit_offset1 <= 4 {
            (byte1 >> bit_offset1) & 0x0F
        } else {
            ((byte1 >> bit_offset1) | (next1 << (8 - bit_offset1))) & 0x0F
        };

        total += luts[lut_idx][pattern0 as usize] as u32;
        total += luts[lut_idx + 1][pattern1 as usize] as u32;

        lut_idx += 2;
    }

    while lut_idx < num_luts {
        let base_bit = lut_idx * 4;
        let byte_idx = base_bit / 8;
        let bit_offset = base_bit % 8;

        let byte = bits.get(byte_idx).copied().unwrap_or(0);
        let next_byte = bits.get(byte_idx + 1).copied().unwrap_or(0);

        let pattern = if bit_offset <= 4 {
            (byte >> bit_offset) & 0x0F
        } else {
            ((byte >> bit_offset) | (next_byte << (8 - bit_offset))) & 0x0F
        };

        total += luts[lut_idx][pattern as usize] as u32;
        lut_idx += 1;
    }

    Some(total)
}

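/// x86_64 path gated on SSSE3. Despite the `target_feature` attribute and the
/// `unsafe` signature, this currently just performs the same minimum-size check
/// and then delegates to the scalar implementation.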
#[cfg(target_arch = "x86_64")]
#[target_feature(enable = "ssse3")]
#[inline]
unsafe fn lut_dot_product_ssse3(bits: &[u8], luts: &[[u16; 16]]) -> Option<u32> {
    if luts.len() < 8 {
        return None;
    }

    Some(lut_dot_product_scalar(bits, luts))
}

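/// Configuration for a `RaBitQIndex`: the vector dimensionality, the number of
/// bits per query coordinate, and the seed for the random signs and permutation.
/// Note that `query_bits` is stored but the query path in this file currently
/// hard-codes 4-bit codes.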
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct RaBitQConfig {
    pub dim: usize,
    pub query_bits: u8,
    pub seed: u64,
}

impl RaBitQConfig {
    pub fn new(dim: usize) -> Self {
        Self {
            dim,
            query_bits: 4,
            seed: 42,
        }
    }

    pub fn with_seed(mut self, seed: u64) -> Self {
        self.seed = seed;
        self
    }
}

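/// A single quantized data vector: `bits` packs one sign bit per dimension
/// (8 per byte), `dist_to_centroid` is the norm of the residual from the
/// centroid, `self_dot` is the dot product between the normalized residual and
/// its sign-quantized counterpart, and `popcount` counts the positive dimensions.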
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct QuantizedVector {
    pub bits: Vec<u8>,
    pub dist_to_centroid: f32,
    pub self_dot: f32,
    pub popcount: u32,
}

impl QuantizedVector {
    pub fn size_bytes(&self) -> usize {
        // packed bits plus three 4-byte fields: dist_to_centroid, self_dot, popcount
        self.bits.len() + 4 + 4 + 4
    }
}

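/// A query prepared for scanning: `quantized` packs two 4-bit codes per byte,
/// `lower` and `width` are the affine dequantization parameters, `sum` is the
/// sum of all codes, and `luts` holds one 16-entry partial-sum table per group
/// of 4 dimensions.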
#[derive(Debug, Clone)]
pub struct QuantizedQuery {
    pub quantized: Vec<u8>,
    pub dist_to_centroid: f32,
    pub lower: f32,
    pub width: f32,
    pub sum: u32,
    pub luts: Vec<[u16; 16]>,
}

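/// Binary quantization index in the spirit of RaBitQ: vectors are centered on a
/// global centroid, randomly transformed (permutation plus sign flips), and
/// stored as one sign bit per dimension, optionally alongside the raw vectors
/// for exact re-ranking.
///
/// Minimal usage sketch (`load_vectors` and `query` are placeholders):
///
/// ```ignore
/// let vectors: Vec<Vec<f32>> = load_vectors();
/// let config = RaBitQConfig::new(vectors[0].len());
/// let index = RaBitQIndex::build(config, &vectors, true); // keep raw data for re-ranking
/// let top10 = index.search(&query, 10, 5);                // (index, squared distance) pairs
/// ```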
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct RaBitQIndex {
    pub config: RaBitQConfig,
    pub centroid: Vec<f32>,
    pub random_signs: Vec<i8>,
    pub random_perm: Vec<u32>,
    pub vectors: Vec<QuantizedVector>,
    pub raw_vectors: Option<Vec<Vec<f32>>>,
}

impl RaBitQIndex {
    pub fn new(config: RaBitQConfig) -> Self {
        let dim = config.dim;
        let mut rng = rand::rngs::StdRng::seed_from_u64(config.seed);

        // Random sign flips; together with the permutation below they apply a
        // cheap random orthogonal transform to the centered vectors.
        let random_signs: Vec<i8> = (0..dim)
            .map(|_| if rng.random::<bool>() { 1 } else { -1 })
            .collect();

        // Fisher-Yates shuffle driven by the seeded RNG.
        let mut random_perm: Vec<u32> = (0..dim as u32).collect();
        for i in (1..dim).rev() {
            let j = rng.random_range(0..=i);
            random_perm.swap(i, j);
        }

        Self {
            config,
            centroid: vec![0.0; dim],
            random_signs,
            random_perm,
            vectors: Vec::new(),
            raw_vectors: None,
        }
    }

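    /// Builds an index over `vectors`: the centroid is the per-dimension mean of
    /// the data, every vector is quantized against it, and the raw vectors are
    /// optionally kept for exact re-ranking during search. Panics if `vectors`
    /// is empty or the first vector does not match `config.dim`.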
    pub fn build(config: RaBitQConfig, vectors: &[Vec<f32>], store_raw: bool) -> Self {
        let n = vectors.len();
        let dim = config.dim;

        assert!(n > 0, "Cannot build index from empty vector set");
        assert!(vectors[0].len() == dim, "Vector dimension mismatch");

        let mut index = Self::new(config);

        index.centroid = vec![0.0; dim];
        for v in vectors {
            for (i, &val) in v.iter().enumerate() {
                index.centroid[i] += val;
            }
        }
        for c in &mut index.centroid {
            *c /= n as f32;
        }

        index.vectors = vectors.iter().map(|v| index.quantize_vector(v)).collect();

        if store_raw {
            index.raw_vectors = Some(vectors.to_vec());
        }

        index
    }

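    /// Quantizes a single vector: the residual from the centroid is normalized,
    /// passed through the random permutation and sign flips, and reduced to one
    /// sign bit per dimension. The residual norm, the dot product with the
    /// sign-quantized unit vector, and the popcount are kept for distance
    /// estimation.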
    fn quantize_vector(&self, raw: &[f32]) -> QuantizedVector {
        let dim = self.config.dim;

        // Center on the dataset centroid.
        let mut centered: Vec<f32> = raw
            .iter()
            .zip(&self.centroid)
            .map(|(&v, &c)| v - c)
            .collect();

        let norm: f32 = centered.iter().map(|x| x * x).sum::<f32>().sqrt();
        let dist_to_centroid = norm;

        // Normalize the residual to a unit vector (skip near-zero residuals).
        if norm > 1e-10 {
            for x in &mut centered {
                *x /= norm;
            }
        }

        // Apply the random permutation and sign flips.
        let transformed: Vec<f32> = (0..dim)
            .map(|i| {
                let src_idx = self.random_perm[i] as usize;
                centered[src_idx] * self.random_signs[src_idx] as f32
            })
            .collect();

        // Keep only the sign of each transformed coordinate, packed 8 per byte.
        let num_bytes = dim.div_ceil(8);
        let mut bits = vec![0u8; num_bytes];
        let mut popcount = 0u32;

        for i in 0..dim {
            if transformed[i] >= 0.0 {
                bits[i / 8] |= 1 << (i % 8);
                popcount += 1;
            }
        }

        // Dot product between the transformed unit vector and its sign-quantized
        // counterpart, whose coordinates are +/- 1/sqrt(dim).
        let scale = 1.0 / (dim as f32).sqrt();
        let mut self_dot = 0.0f32;
        for i in 0..dim {
            let o_bar_i = if (bits[i / 8] >> (i % 8)) & 1 == 1 {
                scale
            } else {
                -scale
            };
            self_dot += transformed[i] * o_bar_i;
        }

        QuantizedVector {
            bits,
            dist_to_centroid,
            self_dot,
            popcount,
        }
    }

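    /// Prepares a query for scanning: the query residual is normalized and
    /// transformed like the data vectors, scalar-quantized to 4-bit codes, and
    /// turned into one 16-entry lookup table per group of 4 dimensions so a data
    /// vector's bit pattern can be scored with a single table lookup per group.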
    pub fn prepare_query(&self, raw_query: &[f32]) -> QuantizedQuery {
        let dim = self.config.dim;

        // Center and normalize the query against the same centroid.
        let mut centered: Vec<f32> = raw_query
            .iter()
            .zip(&self.centroid)
            .map(|(&v, &c)| v - c)
            .collect();

        let norm: f32 = centered.iter().map(|x| x * x).sum::<f32>().sqrt();
        let dist_to_centroid = norm;

        if norm > 1e-10 {
            for x in &mut centered {
                *x /= norm;
            }
        }

        // Apply the same random permutation and sign flips used for data vectors.
        let transformed: Vec<f32> = (0..dim)
            .map(|i| {
                let src_idx = self.random_perm[i] as usize;
                centered[src_idx] * self.random_signs[src_idx] as f32
            })
            .collect();

        // Affine 4-bit scalar quantization: map [min, max] onto the codes 0..=15.
        let min_val = transformed.iter().cloned().fold(f32::INFINITY, f32::min);
        let max_val = transformed
            .iter()
            .cloned()
            .fold(f32::NEG_INFINITY, f32::max);
        let lower = min_val;
        let width = if max_val > min_val {
            max_val - min_val
        } else {
            1.0
        };

        let quantized_vals: Vec<u8> = transformed
            .iter()
            .map(|&x| {
                let normalized = (x - lower) / width;
                (normalized * 15.0).round().clamp(0.0, 15.0) as u8
            })
            .collect();

        // Pack two 4-bit codes per byte: even dims in the low nibble, odd dims in the high nibble.
        let num_bytes = dim.div_ceil(2);
        let mut quantized = vec![0u8; num_bytes];
        for i in 0..dim {
            if i % 2 == 0 {
                quantized[i / 2] |= quantized_vals[i];
            } else {
                quantized[i / 2] |= quantized_vals[i] << 4;
            }
        }

        let sum: u32 = quantized_vals.iter().map(|&x| x as u32).sum();

        // Build one 16-entry LUT per group of 4 dimensions: entry `pattern` holds
        // the sum of the query codes at the positions where `pattern` has a 1 bit.
        let num_luts = dim.div_ceil(4);
        let mut luts = vec![[0u16; 16]; num_luts];

        for (lut_idx, lut) in luts.iter_mut().enumerate() {
            let base_dim = lut_idx * 4;
            for pattern in 0u8..16 {
                let mut dot = 0u16;
                for bit in 0..4 {
                    let dim_idx = base_dim + bit;
                    if dim_idx < dim && (pattern >> bit) & 1 == 1 {
                        dot += quantized_vals[dim_idx] as u16;
                    }
                }
                lut[pattern as usize] = dot;
            }
        }

        QuantizedQuery {
            quantized,
            dist_to_centroid,
            lower,
            width,
            sum,
            luts,
        }
    }

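    /// Estimates the squared Euclidean distance between the prepared query and
    /// the quantized vector at `vec_idx` without touching the raw data:
    ///
    ///   dist^2 ≈ ||o - c||^2 + ||q - c||^2 - 2 * ||o - c|| * ||q - c|| * <o, q>
    ///
    /// where c is the centroid and <o, q> is the cosine between the residuals,
    /// recovered from the LUT sums and divided by the stored `self_dot` to
    /// correct for the sign-quantization error.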
    pub fn estimate_distance(&self, query: &QuantizedQuery, vec_idx: usize) -> f32 {
        let qv = &self.vectors[vec_idx];
        let dim = self.config.dim;

        // Sum of query codes over the dimensions where the data vector's sign bit is 1.
        let dot_sum = lut_dot_product_simd(&qv.bits, &query.luts);

        let scale = 1.0 / (dim as f32).sqrt();

        // Dequantize: a code c maps back to lower + c * width / 15.
        let sum_positive = qv.popcount as f32 * query.lower + dot_sum as f32 * query.width / 15.0;

        let sum_all = dim as f32 * query.lower + query.sum as f32 * query.width / 15.0;

        // <q, o_bar> with o_bar coordinates of +/- scale: positive dims contribute
        // +scale, the rest -scale, hence scale * (2 * sum_positive - sum_all).
        let q_obar_dot = scale * (2.0 * sum_positive - sum_all);

        // Correct for the quantization error of o_bar using the stored <o, o_bar>.
        let q_o_estimate = if qv.self_dot.abs() > 1e-6 {
            q_obar_dot / qv.self_dot
        } else {
            q_obar_dot
        };

        let q_o_clamped = q_o_estimate.clamp(-1.0, 1.0);

        // Law of cosines on the two centroid residuals.
        let dist_sq = qv.dist_to_centroid * qv.dist_to_centroid
            + query.dist_to_centroid * query.dist_to_centroid
            - 2.0 * qv.dist_to_centroid * query.dist_to_centroid * q_o_clamped;

        dist_sq.max(0.0)
    }

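    /// Approximate k-nearest-neighbor search: every vector is scored with
    /// `estimate_distance`; if raw vectors were stored at build time, the top
    /// `k * rerank_factor` candidates are re-ranked with exact squared
    /// distances. Returns up to `k` `(index, squared distance)` pairs, closest
    /// first.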
    pub fn search(&self, query: &[f32], k: usize, rerank_factor: usize) -> Vec<(usize, f32)> {
        let prepared = self.prepare_query(query);

        let mut candidates: Vec<(usize, f32)> = self
            .vectors
            .iter()
            .enumerate()
            .map(|(i, _)| (i, self.estimate_distance(&prepared, i)))
            .collect();

        candidates.sort_by(|a, b| a.1.partial_cmp(&b.1).unwrap_or(std::cmp::Ordering::Equal));

        let rerank_count = (k * rerank_factor).min(candidates.len());

        if let Some(ref raw_vectors) = self.raw_vectors {
            let mut reranked: Vec<(usize, f32)> = candidates[..rerank_count]
                .iter()
                .map(|&(idx, _)| {
                    let exact_dist = euclidean_distance_squared(query, &raw_vectors[idx]);
                    (idx, exact_dist)
                })
                .collect();

            reranked.sort_by(|a, b| a.1.partial_cmp(&b.1).unwrap_or(std::cmp::Ordering::Equal));
            reranked.truncate(k);
            reranked
        } else {
            candidates.truncate(k);
            candidates
        }
    }

    pub fn len(&self) -> usize {
        self.vectors.len()
    }

    pub fn is_empty(&self) -> bool {
        self.vectors.is_empty()
    }

    pub fn memory_usage(&self) -> usize {
        let vectors_size: usize = self.vectors.iter().map(|v| v.size_bytes()).sum();
        let centroid_size = self.centroid.len() * 4;
        let transform_size = self.random_signs.len() + self.random_perm.len() * 4;
        let raw_size = self
            .raw_vectors
            .as_ref()
            .map(|vecs| vecs.iter().map(|v| v.len() * 4).sum())
            .unwrap_or(0);

        vectors_size + centroid_size + transform_size + raw_size
    }

    pub fn compression_ratio(&self) -> f32 {
        if self.vectors.is_empty() {
            return 1.0;
        }

        let raw_size = self.vectors.len() * self.config.dim * 4;
        let compressed_size: usize = self.vectors.iter().map(|v| v.size_bytes()).sum();

        raw_size as f32 / compressed_size as f32
    }
}

#[inline]
fn euclidean_distance_squared(a: &[f32], b: &[f32]) -> f32 {
    a.iter()
        .zip(b.iter())
        .map(|(&x, &y)| {
            let d = x - y;
            d * d
        })
        .sum()
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_rabitq_basic() {
        let dim = 128;
        let n = 100;

        let mut rng = rand::rngs::StdRng::seed_from_u64(12345);
        let vectors: Vec<Vec<f32>> = (0..n)
            .map(|_| (0..dim).map(|_| rng.random::<f32>() - 0.5).collect())
            .collect();

        let config = RaBitQConfig::new(dim);
        let index = RaBitQIndex::build(config, &vectors, true);

        assert_eq!(index.len(), n);
        println!("Compression ratio: {:.1}x", index.compression_ratio());
    }

    #[test]
    fn test_rabitq_search() {
        let dim = 64;
        let n = 1000;
        let k = 10;

        let mut rng = rand::rngs::StdRng::seed_from_u64(42);
        let vectors: Vec<Vec<f32>> = (0..n)
            .map(|_| (0..dim).map(|_| rng.random::<f32>() - 0.5).collect())
            .collect();

        let config = RaBitQConfig::new(dim);
        let index = RaBitQIndex::build(config, &vectors, true);

        let query: Vec<f32> = (0..dim).map(|_| rng.random::<f32>() - 0.5).collect();
        let results = index.search(&query, k, 10);

        assert_eq!(results.len(), k);

        for i in 1..results.len() {
            assert!(results[i].1 >= results[i - 1].1);
        }

        let mut ground_truth: Vec<(usize, f32)> = vectors
            .iter()
            .enumerate()
            .map(|(i, v)| (i, euclidean_distance_squared(&query, v)))
            .collect();
        ground_truth.sort_by(|a, b| a.1.partial_cmp(&b.1).unwrap());

        let gt_set: std::collections::HashSet<usize> =
            ground_truth[..k].iter().map(|x| x.0).collect();
        let result_set: std::collections::HashSet<usize> = results.iter().map(|x| x.0).collect();
        let recall = gt_set.intersection(&result_set).count() as f32 / k as f32;

        println!("Recall@{}: {:.2}", k, recall);
        assert!(recall >= 0.8, "Recall too low: {}", recall);
    }

    #[test]
    fn test_quantized_vector_size() {
        let dim = 768;
        let config = RaBitQConfig::new(dim);
        let index = RaBitQIndex::new(config);

        let raw: Vec<f32> = (0..dim).map(|i| i as f32 * 0.01).collect();
        let qv = index.quantize_vector(&raw);

        let expected_bits = dim.div_ceil(8);
        assert_eq!(qv.bits.len(), expected_bits);

        let total = qv.size_bytes();
        let raw_size = dim * 4;

        println!(
            "Raw size: {} bytes, Quantized size: {} bytes",
            raw_size, total
        );
        println!("Compression: {:.1}x", raw_size as f32 / total as f32);

        assert!(raw_size as f32 / total as f32 > 20.0);
    }

    #[test]
    fn test_distance_estimation_accuracy() {
        let dim = 128;
        let n = 100;

        let mut rng = rand::rngs::StdRng::seed_from_u64(999);
        let vectors: Vec<Vec<f32>> = (0..n)
            .map(|_| (0..dim).map(|_| rng.random::<f32>() - 0.5).collect())
            .collect();

        let config = RaBitQConfig::new(dim);
        let index = RaBitQIndex::build(config, &vectors, false);

        let query: Vec<f32> = (0..dim).map(|_| rng.random::<f32>() - 0.5).collect();
        let prepared = index.prepare_query(&query);

        let mut errors = Vec::new();
        for (i, v) in vectors.iter().enumerate() {
            let estimated = index.estimate_distance(&prepared, i);
            let exact = euclidean_distance_squared(&query, v);
            let error = (estimated - exact).abs() / exact.max(1e-6);
            errors.push(error);
        }

        let mean_error: f32 = errors.iter().sum::<f32>() / errors.len() as f32;
        let max_error = errors.iter().cloned().fold(0.0f32, f32::max);

        println!("Mean relative error: {:.2}%", mean_error * 100.0);
        println!("Max relative error: {:.2}%", max_error * 100.0);

        assert!(mean_error < 0.5, "Mean error too high: {}", mean_error);
    }
}