Skip to main content

hermes_core/structures/vector/index/
rabitq.rs

1//! Standalone RaBitQ index (without IVF)
2//!
3//! For small datasets where IVF overhead isn't worth it.
4//! Uses brute-force search over all quantized vectors.
5
6use std::io;
7
8use serde::{Deserialize, Serialize};
9
10use crate::structures::vector::ivf::QuantizedCode;
11use crate::structures::vector::quantization::{
12    QuantizedQuery, QuantizedVector, RaBitQCodebook, RaBitQConfig,
13};
14
15/// Standalone RaBitQ index for small datasets
16///
17/// Uses brute-force search over all quantized vectors.
18/// For larger datasets, use `IVFRaBitQIndex` instead.
19#[derive(Debug, Clone, Serialize, Deserialize)]
20pub struct RaBitQIndex {
21    /// RaBitQ codebook (random transform parameters)
22    pub codebook: RaBitQCodebook,
23    /// Centroid of all indexed vectors
24    pub centroid: Vec<f32>,
25    /// Document IDs
26    pub doc_ids: Vec<u32>,
27    /// Element ordinals for multi-valued fields (0 for single-valued)
28    pub ordinals: Vec<u16>,
29    /// Quantized vectors
30    pub vectors: Vec<QuantizedVector>,
31    /// Raw vectors for re-ranking (optional)
32    pub raw_vectors: Option<Vec<Vec<f32>>>,
33}
34
35impl RaBitQIndex {
36    /// Create a new empty RaBitQ index
37    pub fn new(config: RaBitQConfig) -> Self {
38        let dim = config.dim;
39        let codebook = RaBitQCodebook::new(config);
40
41        Self {
42            codebook,
43            centroid: vec![0.0; dim],
44            doc_ids: Vec::new(),
45            ordinals: Vec::new(),
46            vectors: Vec::new(),
47            raw_vectors: None,
48        }
49    }
50
51    /// Build index from vectors with doc IDs and ordinals
52    pub fn build_with_ids(
53        config: RaBitQConfig,
54        vectors: &[(u32, u16, Vec<f32>)], // (doc_id, ordinal, vector)
55        store_raw: bool,
56    ) -> Self {
57        let n = vectors.len();
58        let dim = config.dim;
59
60        assert!(n > 0, "Cannot build index from empty vector set");
61        assert!(vectors[0].2.len() == dim, "Vector dimension mismatch");
62
63        let mut index = Self::new(config);
64
65        // Compute centroid
66        index.centroid = vec![0.0; dim];
67        for (_, _, v) in vectors {
68            for (i, &val) in v.iter().enumerate() {
69                index.centroid[i] += val;
70            }
71        }
72        for c in &mut index.centroid {
73            *c /= n as f32;
74        }
75
76        // Store doc_ids, ordinals and quantize vectors
77        index.doc_ids = vectors.iter().map(|(doc_id, _, _)| *doc_id).collect();
78        index.ordinals = vectors.iter().map(|(_, ordinal, _)| *ordinal).collect();
79        index.vectors = vectors
80            .iter()
81            .map(|(_, _, v)| index.codebook.encode(v, Some(&index.centroid)))
82            .collect();
83
84        if store_raw {
85            index.raw_vectors = Some(vectors.iter().map(|(_, _, v)| v.clone()).collect());
86        }
87
88        index
89    }
90
91    /// Build index from a set of vectors (legacy, uses doc_id = index, ordinal = 0)
92    pub fn build(config: RaBitQConfig, vectors: &[Vec<f32>], store_raw: bool) -> Self {
93        let with_ids: Vec<(u32, u16, Vec<f32>)> = vectors
94            .iter()
95            .enumerate()
96            .map(|(i, v)| (i as u32, 0u16, v.clone()))
97            .collect();
98        Self::build_with_ids(config, &with_ids, store_raw)
99    }
100
101    /// Add a single vector to the index
102    pub fn add_vector(&mut self, doc_id: u32, ordinal: u16, vector: &[f32], raw: Option<Vec<f32>>) {
103        self.doc_ids.push(doc_id);
104        self.ordinals.push(ordinal);
105        self.vectors
106            .push(self.codebook.encode(vector, Some(&self.centroid)));
107        if let Some(ref mut raw_vectors) = self.raw_vectors
108            && let Some(r) = raw
109        {
110            raw_vectors.push(r);
111        }
112    }
113
114    /// Prepare a query for fast distance estimation
115    pub fn prepare_query(&self, query: &[f32]) -> QuantizedQuery {
116        self.codebook.prepare_query(query, Some(&self.centroid))
117    }
118
119    /// Estimate squared distance between query and a quantized vector
120    pub fn estimate_distance(&self, query: &QuantizedQuery, vec_idx: usize) -> f32 {
121        self.codebook
122            .estimate_distance(query, &self.vectors[vec_idx])
123    }
124
125    /// Search for k nearest neighbors, returns (doc_id, ordinal, distance)
126    pub fn search(&self, query: &[f32], k: usize, rerank_factor: usize) -> Vec<(u32, u16, f32)> {
127        let prepared = self.prepare_query(query);
128
129        // Phase 1: Estimate distances for all vectors
130        let mut candidates: Vec<(usize, f32)> = self
131            .vectors
132            .iter()
133            .enumerate()
134            .map(|(i, _)| (i, self.estimate_distance(&prepared, i)))
135            .collect();
136
137        // Sort by estimated distance
138        candidates.sort_by(|a, b| a.1.partial_cmp(&b.1).unwrap_or(std::cmp::Ordering::Equal));
139
140        // Phase 2: Re-rank top candidates with exact distances
141        let rerank_count = (k * rerank_factor).min(candidates.len());
142
143        let results = if let Some(ref raw_vectors) = self.raw_vectors {
144            let mut reranked: Vec<(usize, f32)> = candidates[..rerank_count]
145                .iter()
146                .map(|&(idx, _)| {
147                    let exact_dist = euclidean_distance_squared(query, &raw_vectors[idx]);
148                    (idx, exact_dist)
149                })
150                .collect();
151
152            reranked.sort_by(|a, b| a.1.partial_cmp(&b.1).unwrap_or(std::cmp::Ordering::Equal));
153            reranked.truncate(k);
154            reranked
155        } else {
156            candidates.truncate(k);
157            candidates
158        };
159
160        // Map indices to (doc_id, ordinal, dist)
161        results
162            .into_iter()
163            .map(|(idx, dist)| (self.doc_ids[idx], self.ordinals[idx], dist))
164            .collect()
165    }
166
167    /// Number of indexed vectors
168    pub fn len(&self) -> usize {
169        self.vectors.len()
170    }
171
172    pub fn is_empty(&self) -> bool {
173        self.vectors.is_empty()
174    }
175
176    /// Memory usage in bytes
177    pub fn size_bytes(&self) -> usize {
178        use std::mem::size_of;
179
180        let vectors_size: usize = self.vectors.iter().map(|v| v.size_bytes()).sum();
181        let centroid_size = self.centroid.len() * size_of::<f32>();
182        let doc_ids_size = self.doc_ids.len() * size_of::<u32>();
183        let ordinals_size = self.ordinals.len() * size_of::<u16>();
184        let codebook_size = self.codebook.size_bytes();
185        let raw_size = self
186            .raw_vectors
187            .as_ref()
188            .map(|vecs| vecs.iter().map(|v| v.len() * size_of::<f32>()).sum())
189            .unwrap_or(0);
190
191        vectors_size + centroid_size + doc_ids_size + ordinals_size + codebook_size + raw_size
192    }
193
194    /// Compression ratio compared to raw float32 vectors
195    pub fn compression_ratio(&self) -> f32 {
196        if self.vectors.is_empty() {
197            return 1.0;
198        }
199
200        let dim = self.codebook.config.dim;
201        let raw_size = self.vectors.len() * dim * 4;
202        let compressed_size: usize = self.vectors.iter().map(|v| v.size_bytes()).sum();
203
204        raw_size as f32 / compressed_size as f32
205    }
206
207    /// Serialize to bytes
208    pub fn to_bytes(&self) -> io::Result<Vec<u8>> {
209        serde_json::to_vec(self).map_err(|e| io::Error::new(io::ErrorKind::InvalidData, e))
210    }
211
212    /// Deserialize from bytes
213    pub fn from_bytes(data: &[u8]) -> io::Result<Self> {
214        serde_json::from_slice(data).map_err(|e| io::Error::new(io::ErrorKind::InvalidData, e))
215    }
216}
217
/// Squared Euclidean distance between `a` and `b`.
///
/// Components are paired positionally; if the slices differ in length, the
/// extra trailing components of the longer one are ignored.
#[inline]
fn euclidean_distance_squared(a: &[f32], b: &[f32]) -> f32 {
    a.iter()
        .zip(b)
        .map(|(x, y)| x - y)
        .map(|diff| diff * diff)
        .sum::<f32>()
}
229
#[cfg(test)]
mod tests {
    use super::*;
    use rand::prelude::*;

    /// Building from random vectors indexes every one of them.
    #[test]
    fn test_rabitq_basic() {
        let dim = 128;
        let n = 100;

        // Fixed seed keeps the test deterministic.
        let mut rng = rand::rngs::StdRng::seed_from_u64(12345);
        let vectors: Vec<Vec<f32>> = (0..n)
            .map(|_| (0..dim).map(|_| rng.random::<f32>() - 0.5).collect())
            .collect();

        let config = RaBitQConfig::new(dim);
        let index = RaBitQIndex::build(config, &vectors, true);

        assert_eq!(index.len(), n);
        println!("Compression ratio: {:.1}x", index.compression_ratio());
    }

    /// A search returns exactly `k` hits ordered by ascending distance.
    #[test]
    fn test_rabitq_search() {
        let dim = 64;
        let n = 1000;
        let k = 10;

        // Fixed seed keeps the test deterministic.
        let mut rng = rand::rngs::StdRng::seed_from_u64(42);
        let vectors: Vec<Vec<f32>> = (0..n)
            .map(|_| (0..dim).map(|_| rng.random::<f32>() - 0.5).collect())
            .collect();

        let config = RaBitQConfig::new(dim);
        let index = RaBitQIndex::build(config, &vectors, true);

        let query: Vec<f32> = (0..dim).map(|_| rng.random::<f32>() - 0.5).collect();
        let results = index.search(&query, k, 10);

        assert_eq!(results.len(), k);

        // Results are `(doc_id, ordinal, distance)`, so the distance is field 2.
        // Field 1 is the ordinal, which `build` sets to 0 for every vector;
        // comparing it would pass no matter how the results were ordered.
        for pair in results.windows(2) {
            assert!(pair[1].2 >= pair[0].2);
        }
    }
}