Skip to main content

hermes_core/structures/vector/index/
rabitq.rs

//! Standalone RaBitQ index (without IVF).
//!
//! Intended for small datasets where the overhead of IVF is not worth it.
//! Search is a brute-force scan over all quantized vectors.
5
6use std::io;
7
8use serde::{Deserialize, Serialize};
9
10use crate::structures::vector::ivf::QuantizedCode;
11use crate::structures::vector::quantization::{
12    QuantizedQuery, QuantizedVector, RaBitQCodebook, RaBitQConfig,
13};
14
15/// Standalone RaBitQ index for small datasets
16///
17/// Uses brute-force search over all quantized vectors.
18/// For larger datasets, use `IVFRaBitQIndex` instead.
19#[derive(Debug, Clone, Serialize, Deserialize)]
20pub struct RaBitQIndex {
21    /// RaBitQ codebook (random transform parameters)
22    pub codebook: RaBitQCodebook,
23    /// Centroid of all indexed vectors
24    pub centroid: Vec<f32>,
25    /// Document IDs
26    pub doc_ids: Vec<u32>,
27    /// Element ordinals for multi-valued fields (0 for single-valued)
28    pub ordinals: Vec<u16>,
29    /// Quantized vectors
30    pub vectors: Vec<QuantizedVector>,
31}
32
33impl RaBitQIndex {
34    /// Create a new empty RaBitQ index
35    pub fn new(config: RaBitQConfig) -> Self {
36        let dim = config.dim;
37        let codebook = RaBitQCodebook::new(config);
38
39        Self {
40            codebook,
41            centroid: vec![0.0; dim],
42            doc_ids: Vec::new(),
43            ordinals: Vec::new(),
44            vectors: Vec::new(),
45        }
46    }
47
48    /// Build index from vectors with doc IDs and ordinals
49    pub fn build_with_ids(
50        config: RaBitQConfig,
51        vectors: &[(u32, u16, Vec<f32>)], // (doc_id, ordinal, vector)
52    ) -> Self {
53        let n = vectors.len();
54        let dim = config.dim;
55
56        assert!(n > 0, "Cannot build index from empty vector set");
57        assert!(vectors[0].2.len() == dim, "Vector dimension mismatch");
58
59        let mut index = Self::new(config);
60
61        // Compute centroid
62        index.centroid = vec![0.0; dim];
63        for (_, _, v) in vectors {
64            for (i, &val) in v.iter().enumerate() {
65                index.centroid[i] += val;
66            }
67        }
68        for c in &mut index.centroid {
69            *c /= n as f32;
70        }
71
72        // Store doc_ids, ordinals and quantize vectors
73        index.doc_ids = vectors.iter().map(|(doc_id, _, _)| *doc_id).collect();
74        index.ordinals = vectors.iter().map(|(_, ordinal, _)| *ordinal).collect();
75        index.vectors = vectors
76            .iter()
77            .map(|(_, _, v)| index.codebook.encode(v, Some(&index.centroid)))
78            .collect();
79
80        index
81    }
82
83    /// Build index from a set of vectors (legacy, uses doc_id = index, ordinal = 0)
84    pub fn build(config: RaBitQConfig, vectors: &[Vec<f32>]) -> Self {
85        let with_ids: Vec<(u32, u16, Vec<f32>)> = vectors
86            .iter()
87            .enumerate()
88            .map(|(i, v)| (i as u32, 0u16, v.clone()))
89            .collect();
90        Self::build_with_ids(config, &with_ids)
91    }
92
93    /// Add a single vector to the index
94    pub fn add_vector(&mut self, doc_id: u32, ordinal: u16, vector: &[f32]) {
95        self.doc_ids.push(doc_id);
96        self.ordinals.push(ordinal);
97        self.vectors
98            .push(self.codebook.encode(vector, Some(&self.centroid)));
99    }
100
101    /// Prepare a query for fast distance estimation
102    pub fn prepare_query(&self, query: &[f32]) -> QuantizedQuery {
103        self.codebook.prepare_query(query, Some(&self.centroid))
104    }
105
106    /// Estimate squared distance between query and a quantized vector
107    pub fn estimate_distance(&self, query: &QuantizedQuery, vec_idx: usize) -> f32 {
108        self.codebook
109            .estimate_distance(query, &self.vectors[vec_idx])
110    }
111
112    /// Search for k nearest neighbors, returns (doc_id, ordinal, distance)
113    pub fn search(&self, query: &[f32], k: usize) -> Vec<(u32, u16, f32)> {
114        let prepared = self.prepare_query(query);
115
116        // Phase 1: Estimate distances for all vectors
117        let mut candidates: Vec<(usize, f32)> = self
118            .vectors
119            .iter()
120            .enumerate()
121            .map(|(i, _)| (i, self.estimate_distance(&prepared, i)))
122            .collect();
123
124        // Partial sort: O(n + k log k) instead of O(n log n)
125        if candidates.len() > k {
126            candidates.select_nth_unstable_by(k, |a, b| a.1.total_cmp(&b.1));
127            candidates.truncate(k);
128        }
129        candidates.sort_unstable_by(|a, b| a.1.total_cmp(&b.1));
130
131        // Map indices to (doc_id, ordinal, dist)
132        candidates
133            .into_iter()
134            .map(|(idx, dist)| (self.doc_ids[idx], self.ordinals[idx], dist))
135            .collect()
136    }
137
138    /// Number of indexed vectors
139    pub fn len(&self) -> usize {
140        self.vectors.len()
141    }
142
143    pub fn is_empty(&self) -> bool {
144        self.vectors.is_empty()
145    }
146
147    /// Memory usage in bytes
148    pub fn size_bytes(&self) -> usize {
149        use std::mem::size_of;
150
151        let vectors_size: usize = self.vectors.iter().map(|v| v.size_bytes()).sum();
152        let centroid_size = self.centroid.len() * size_of::<f32>();
153        let doc_ids_size = self.doc_ids.len() * size_of::<u32>();
154        let ordinals_size = self.ordinals.len() * size_of::<u16>();
155        let codebook_size = self.codebook.size_bytes();
156        vectors_size + centroid_size + doc_ids_size + ordinals_size + codebook_size
157    }
158
159    /// Estimated memory usage in bytes (alias for size_bytes)
160    pub fn estimated_memory_bytes(&self) -> usize {
161        self.size_bytes()
162    }
163
164    /// Compression ratio compared to raw float32 vectors
165    pub fn compression_ratio(&self) -> f32 {
166        if self.vectors.is_empty() {
167            return 1.0;
168        }
169
170        let dim = self.codebook.config.dim;
171        let raw_size = self.vectors.len() * dim * 4;
172        let compressed_size: usize = self.vectors.iter().map(|v| v.size_bytes()).sum();
173
174        raw_size as f32 / compressed_size as f32
175    }
176
177    /// Serialize to bytes
178    pub fn to_bytes(&self) -> io::Result<Vec<u8>> {
179        serde_json::to_vec(self).map_err(|e| io::Error::new(io::ErrorKind::InvalidData, e))
180    }
181
182    /// Deserialize from bytes
183    pub fn from_bytes(data: &[u8]) -> io::Result<Self> {
184        serde_json::from_slice(data).map_err(|e| io::Error::new(io::ErrorKind::InvalidData, e))
185    }
186}
187
#[cfg(test)]
mod tests {
    use super::*;
    use rand::prelude::*;

    /// Generate `n` vectors of length `dim`, each component drawn as
    /// `rng.random::<f32>() - 0.5`.
    fn random_vectors(rng: &mut rand::rngs::StdRng, n: usize, dim: usize) -> Vec<Vec<f32>> {
        (0..n)
            .map(|_| (0..dim).map(|_| rng.random::<f32>() - 0.5).collect())
            .collect()
    }

    /// Building from `n` vectors yields an index reporting `n` entries.
    #[test]
    fn test_rabitq_basic() {
        let dim = 128;
        let n = 100;

        let mut rng = rand::rngs::StdRng::seed_from_u64(12345);
        let vectors = random_vectors(&mut rng, n, dim);

        let index = RaBitQIndex::build(RaBitQConfig::new(dim), &vectors);

        assert_eq!(index.len(), n);
        println!("Compression ratio: {:.1}x", index.compression_ratio());
    }

    /// Search returns exactly `k` hits ordered by ascending distance.
    #[test]
    fn test_rabitq_search() {
        let dim = 64;
        let n = 1000;
        let k = 10;

        let mut rng = rand::rngs::StdRng::seed_from_u64(42);
        let vectors = random_vectors(&mut rng, n, dim);

        let index = RaBitQIndex::build(RaBitQConfig::new(dim), &vectors);

        // The query is drawn from the same generator, after the data vectors.
        let query: Vec<f32> = (0..dim).map(|_| rng.random::<f32>() - 0.5).collect();
        let results = index.search(&query, k);

        assert_eq!(results.len(), k);

        // Verify results are sorted by distance
        for pair in results.windows(2) {
            assert!(pair[1].2 >= pair[0].2);
        }
    }
}