Skip to main content

hermes_core/structures/vector/index/
rabitq.rs

1//! Standalone RaBitQ index (without IVF)
2//!
3//! For small datasets where IVF overhead isn't worth it.
4//! Uses brute-force search over all quantized vectors.
5
6use std::io;
7
8use serde::{Deserialize, Serialize};
9
10use crate::structures::vector::ivf::QuantizedCode;
11use crate::structures::vector::quantization::{
12    QuantizedQuery, QuantizedVector, RaBitQCodebook, RaBitQConfig,
13};
14
15/// Standalone RaBitQ index for small datasets
16///
17/// Uses brute-force search over all quantized vectors.
18/// For larger datasets, use `IVFRaBitQIndex` instead.
19#[derive(Debug, Clone, Serialize, Deserialize)]
20pub struct RaBitQIndex {
21    /// RaBitQ codebook (random transform parameters)
22    pub codebook: RaBitQCodebook,
23    /// Centroid of all indexed vectors
24    pub centroid: Vec<f32>,
25    /// Quantized vectors
26    pub vectors: Vec<QuantizedVector>,
27    /// Raw vectors for re-ranking (optional)
28    pub raw_vectors: Option<Vec<Vec<f32>>>,
29}
30
31impl RaBitQIndex {
32    /// Create a new empty RaBitQ index
33    pub fn new(config: RaBitQConfig) -> Self {
34        let dim = config.dim;
35        let codebook = RaBitQCodebook::new(config);
36
37        Self {
38            codebook,
39            centroid: vec![0.0; dim],
40            vectors: Vec::new(),
41            raw_vectors: None,
42        }
43    }
44
45    /// Build index from a set of vectors
46    pub fn build(config: RaBitQConfig, vectors: &[Vec<f32>], store_raw: bool) -> Self {
47        let n = vectors.len();
48        let dim = config.dim;
49
50        assert!(n > 0, "Cannot build index from empty vector set");
51        assert!(vectors[0].len() == dim, "Vector dimension mismatch");
52
53        let mut index = Self::new(config);
54
55        // Compute centroid
56        index.centroid = vec![0.0; dim];
57        for v in vectors {
58            for (i, &val) in v.iter().enumerate() {
59                index.centroid[i] += val;
60            }
61        }
62        for c in &mut index.centroid {
63            *c /= n as f32;
64        }
65
66        // Quantize each vector relative to centroid
67        index.vectors = vectors
68            .iter()
69            .map(|v| index.codebook.encode(v, Some(&index.centroid)))
70            .collect();
71
72        if store_raw {
73            index.raw_vectors = Some(vectors.to_vec());
74        }
75
76        index
77    }
78
79    /// Prepare a query for fast distance estimation
80    pub fn prepare_query(&self, query: &[f32]) -> QuantizedQuery {
81        self.codebook.prepare_query(query, Some(&self.centroid))
82    }
83
84    /// Estimate squared distance between query and a quantized vector
85    pub fn estimate_distance(&self, query: &QuantizedQuery, vec_idx: usize) -> f32 {
86        self.codebook
87            .estimate_distance(query, &self.vectors[vec_idx])
88    }
89
90    /// Search for k nearest neighbors
91    pub fn search(&self, query: &[f32], k: usize, rerank_factor: usize) -> Vec<(usize, f32)> {
92        let prepared = self.prepare_query(query);
93
94        // Phase 1: Estimate distances for all vectors
95        let mut candidates: Vec<(usize, f32)> = self
96            .vectors
97            .iter()
98            .enumerate()
99            .map(|(i, _)| (i, self.estimate_distance(&prepared, i)))
100            .collect();
101
102        // Sort by estimated distance
103        candidates.sort_by(|a, b| a.1.partial_cmp(&b.1).unwrap_or(std::cmp::Ordering::Equal));
104
105        // Phase 2: Re-rank top candidates with exact distances
106        let rerank_count = (k * rerank_factor).min(candidates.len());
107
108        if let Some(ref raw_vectors) = self.raw_vectors {
109            let mut reranked: Vec<(usize, f32)> = candidates[..rerank_count]
110                .iter()
111                .map(|&(idx, _)| {
112                    let exact_dist = euclidean_distance_squared(query, &raw_vectors[idx]);
113                    (idx, exact_dist)
114                })
115                .collect();
116
117            reranked.sort_by(|a, b| a.1.partial_cmp(&b.1).unwrap_or(std::cmp::Ordering::Equal));
118            reranked.truncate(k);
119            reranked
120        } else {
121            candidates.truncate(k);
122            candidates
123        }
124    }
125
126    /// Number of indexed vectors
127    pub fn len(&self) -> usize {
128        self.vectors.len()
129    }
130
131    pub fn is_empty(&self) -> bool {
132        self.vectors.is_empty()
133    }
134
135    /// Memory usage in bytes
136    pub fn size_bytes(&self) -> usize {
137        let vectors_size: usize = self.vectors.iter().map(|v| v.size_bytes()).sum();
138        let centroid_size = self.centroid.len() * 4;
139        let codebook_size = self.codebook.size_bytes();
140        let raw_size = self
141            .raw_vectors
142            .as_ref()
143            .map(|vecs| vecs.iter().map(|v| v.len() * 4).sum())
144            .unwrap_or(0);
145
146        vectors_size + centroid_size + codebook_size + raw_size
147    }
148
149    /// Compression ratio compared to raw float32 vectors
150    pub fn compression_ratio(&self) -> f32 {
151        if self.vectors.is_empty() {
152            return 1.0;
153        }
154
155        let dim = self.codebook.config.dim;
156        let raw_size = self.vectors.len() * dim * 4;
157        let compressed_size: usize = self.vectors.iter().map(|v| v.size_bytes()).sum();
158
159        raw_size as f32 / compressed_size as f32
160    }
161
162    /// Serialize to bytes
163    pub fn to_bytes(&self) -> io::Result<Vec<u8>> {
164        serde_json::to_vec(self).map_err(|e| io::Error::new(io::ErrorKind::InvalidData, e))
165    }
166
167    /// Deserialize from bytes
168    pub fn from_bytes(data: &[u8]) -> io::Result<Self> {
169        serde_json::from_slice(data).map_err(|e| io::Error::new(io::ErrorKind::InvalidData, e))
170    }
171}
172
/// Squared Euclidean (L2²) distance between `a` and `b`.
///
/// Elements are paired with `zip`, so when the slices differ in length only
/// the common prefix contributes to the result.
#[inline]
fn euclidean_distance_squared(a: &[f32], b: &[f32]) -> f32 {
    let squared_diffs = a
        .iter()
        .zip(b)
        .map(|(&lhs, &rhs)| (lhs - rhs) * (lhs - rhs));
    squared_diffs.sum()
}
184
#[cfg(test)]
mod tests {
    use super::*;
    use rand::prelude::*;

    /// One vector of length `dim` with components `rng.random::<f32>() - 0.5`.
    fn random_vector(rng: &mut rand::rngs::StdRng, dim: usize) -> Vec<f32> {
        (0..dim).map(|_| rng.random::<f32>() - 0.5).collect()
    }

    /// `count` vectors drawn sequentially from `rng` via [`random_vector`].
    fn random_vectors(rng: &mut rand::rngs::StdRng, count: usize, dim: usize) -> Vec<Vec<f32>> {
        (0..count).map(|_| random_vector(rng, dim)).collect()
    }

    #[test]
    fn test_rabitq_basic() {
        let (dim, n) = (128, 100);
        let mut rng = rand::rngs::StdRng::seed_from_u64(12345);
        let data = random_vectors(&mut rng, n, dim);

        let index = RaBitQIndex::build(RaBitQConfig::new(dim), &data, true);

        assert_eq!(index.len(), n);
        println!("Compression ratio: {:.1}x", index.compression_ratio());
    }

    #[test]
    fn test_rabitq_search() {
        let (dim, n, k) = (64, 1000, 10);
        let mut rng = rand::rngs::StdRng::seed_from_u64(42);
        let data = random_vectors(&mut rng, n, dim);
        let index = RaBitQIndex::build(RaBitQConfig::new(dim), &data, true);

        // Drawn after the data, from the same RNG stream.
        let query = random_vector(&mut rng, dim);
        let results = index.search(&query, k, 10);

        assert_eq!(results.len(), k);
        // Results must come back in ascending distance order.
        assert!(results.windows(2).all(|pair| pair[0].1 <= pair[1].1));
    }
}