Skip to main content

hermes_core/structures/vector/index/rabitq.rs

1//! Standalone RaBitQ index (without IVF)
2//!
3//! For small datasets where IVF overhead isn't worth it.
4//! Uses brute-force search over all quantized vectors.
5
6use std::io;
7
8use serde::{Deserialize, Serialize};
9
10use crate::structures::vector::ivf::QuantizedCode;
11use crate::structures::vector::quantization::{
12    QuantizedQuery, QuantizedVector, RaBitQCodebook, RaBitQConfig,
13};
14
15/// Standalone RaBitQ index for small datasets
16///
17/// Uses brute-force search over all quantized vectors.
18/// For larger datasets, use `IVFRaBitQIndex` instead.
19#[derive(Debug, Clone, Serialize, Deserialize)]
20pub struct RaBitQIndex {
21    /// RaBitQ codebook (random transform parameters)
22    pub codebook: RaBitQCodebook,
23    /// Centroid of all indexed vectors
24    pub centroid: Vec<f32>,
25    /// Document IDs
26    pub doc_ids: Vec<u32>,
27    /// Element ordinals for multi-valued fields (0 for single-valued)
28    pub ordinals: Vec<u16>,
29    /// Quantized vectors
30    pub vectors: Vec<QuantizedVector>,
31}
32
33impl RaBitQIndex {
34    /// Create a new empty RaBitQ index
35    pub fn new(config: RaBitQConfig) -> Self {
36        let dim = config.dim;
37        let codebook = RaBitQCodebook::new(config);
38
39        Self {
40            codebook,
41            centroid: vec![0.0; dim],
42            doc_ids: Vec::new(),
43            ordinals: Vec::new(),
44            vectors: Vec::new(),
45        }
46    }
47
48    /// Build index from vectors with doc IDs and ordinals
49    pub fn build_with_ids(
50        config: RaBitQConfig,
51        vectors: &[(u32, u16, Vec<f32>)], // (doc_id, ordinal, vector)
52    ) -> Self {
53        let n = vectors.len();
54        let dim = config.dim;
55
56        assert!(n > 0, "Cannot build index from empty vector set");
57        assert!(vectors[0].2.len() == dim, "Vector dimension mismatch");
58
59        let mut index = Self::new(config);
60
61        // Compute centroid
62        index.centroid = vec![0.0; dim];
63        for (_, _, v) in vectors {
64            for (i, &val) in v.iter().enumerate() {
65                index.centroid[i] += val;
66            }
67        }
68        for c in &mut index.centroid {
69            *c /= n as f32;
70        }
71
72        // Store doc_ids, ordinals and quantize vectors
73        index.doc_ids = vectors.iter().map(|(doc_id, _, _)| *doc_id).collect();
74        index.ordinals = vectors.iter().map(|(_, ordinal, _)| *ordinal).collect();
75        index.vectors = vectors
76            .iter()
77            .map(|(_, _, v)| index.codebook.encode(v, Some(&index.centroid)))
78            .collect();
79
80        index
81    }
82
83    /// Build index from a set of vectors (legacy, uses doc_id = index, ordinal = 0)
84    pub fn build(config: RaBitQConfig, vectors: &[Vec<f32>]) -> Self {
85        let with_ids: Vec<(u32, u16, Vec<f32>)> = vectors
86            .iter()
87            .enumerate()
88            .map(|(i, v)| (i as u32, 0u16, v.clone()))
89            .collect();
90        Self::build_with_ids(config, &with_ids)
91    }
92
93    /// Add a single vector to the index
94    pub fn add_vector(&mut self, doc_id: u32, ordinal: u16, vector: &[f32]) {
95        self.doc_ids.push(doc_id);
96        self.ordinals.push(ordinal);
97        self.vectors
98            .push(self.codebook.encode(vector, Some(&self.centroid)));
99    }
100
101    /// Prepare a query for fast distance estimation
102    pub fn prepare_query(&self, query: &[f32]) -> QuantizedQuery {
103        self.codebook.prepare_query(query, Some(&self.centroid))
104    }
105
106    /// Estimate squared distance between query and a quantized vector
107    pub fn estimate_distance(&self, query: &QuantizedQuery, vec_idx: usize) -> f32 {
108        self.codebook
109            .estimate_distance(query, &self.vectors[vec_idx])
110    }
111
112    /// Search for k nearest neighbors, returns (doc_id, ordinal, distance)
113    pub fn search(&self, query: &[f32], k: usize, _rerank_factor: usize) -> Vec<(u32, u16, f32)> {
114        let prepared = self.prepare_query(query);
115
116        // Phase 1: Estimate distances for all vectors
117        let mut candidates: Vec<(usize, f32)> = self
118            .vectors
119            .iter()
120            .enumerate()
121            .map(|(i, _)| (i, self.estimate_distance(&prepared, i)))
122            .collect();
123
124        // Sort by estimated distance
125        candidates.sort_by(|a, b| a.1.partial_cmp(&b.1).unwrap_or(std::cmp::Ordering::Equal));
126
127        candidates.truncate(k);
128
129        // Map indices to (doc_id, ordinal, dist)
130        candidates
131            .into_iter()
132            .map(|(idx, dist)| (self.doc_ids[idx], self.ordinals[idx], dist))
133            .collect()
134    }
135
136    /// Number of indexed vectors
137    pub fn len(&self) -> usize {
138        self.vectors.len()
139    }
140
141    pub fn is_empty(&self) -> bool {
142        self.vectors.is_empty()
143    }
144
145    /// Memory usage in bytes
146    pub fn size_bytes(&self) -> usize {
147        use std::mem::size_of;
148
149        let vectors_size: usize = self.vectors.iter().map(|v| v.size_bytes()).sum();
150        let centroid_size = self.centroid.len() * size_of::<f32>();
151        let doc_ids_size = self.doc_ids.len() * size_of::<u32>();
152        let ordinals_size = self.ordinals.len() * size_of::<u16>();
153        let codebook_size = self.codebook.size_bytes();
154        vectors_size + centroid_size + doc_ids_size + ordinals_size + codebook_size
155    }
156
157    /// Estimated memory usage in bytes (alias for size_bytes)
158    pub fn estimated_memory_bytes(&self) -> usize {
159        self.size_bytes()
160    }
161
162    /// Compression ratio compared to raw float32 vectors
163    pub fn compression_ratio(&self) -> f32 {
164        if self.vectors.is_empty() {
165            return 1.0;
166        }
167
168        let dim = self.codebook.config.dim;
169        let raw_size = self.vectors.len() * dim * 4;
170        let compressed_size: usize = self.vectors.iter().map(|v| v.size_bytes()).sum();
171
172        raw_size as f32 / compressed_size as f32
173    }
174
175    /// Serialize to bytes
176    pub fn to_bytes(&self) -> io::Result<Vec<u8>> {
177        serde_json::to_vec(self).map_err(|e| io::Error::new(io::ErrorKind::InvalidData, e))
178    }
179
180    /// Deserialize from bytes
181    pub fn from_bytes(data: &[u8]) -> io::Result<Self> {
182        serde_json::from_slice(data).map_err(|e| io::Error::new(io::ErrorKind::InvalidData, e))
183    }
184}
185
#[cfg(test)]
mod tests {
    use super::*;
    use rand::prelude::*;

    /// Generate `n` random vectors of length `dim` with components in [-0.5, 0.5).
    fn random_vectors<R: Rng>(rng: &mut R, n: usize, dim: usize) -> Vec<Vec<f32>> {
        (0..n)
            .map(|_| (0..dim).map(|_| rng.random::<f32>() - 0.5).collect())
            .collect()
    }

    #[test]
    fn test_rabitq_basic() {
        let dim = 128;
        let n = 100;

        let mut rng = rand::rngs::StdRng::seed_from_u64(12345);
        let vectors = random_vectors(&mut rng, n, dim);

        let config = RaBitQConfig::new(dim);
        let index = RaBitQIndex::build(config, &vectors);

        assert_eq!(index.len(), n);
        println!("Compression ratio: {:.1}x", index.compression_ratio());
    }

    #[test]
    fn test_rabitq_search() {
        let dim = 64;
        let n = 1000;
        let k = 10;

        let mut rng = rand::rngs::StdRng::seed_from_u64(42);
        let vectors = random_vectors(&mut rng, n, dim);

        let config = RaBitQConfig::new(dim);
        let index = RaBitQIndex::build(config, &vectors);

        let query: Vec<f32> = (0..dim).map(|_| rng.random::<f32>() - 0.5).collect();
        let results = index.search(&query, k, 10);

        assert_eq!(results.len(), k);

        // `build` assigns doc_id = position and ordinal = 0.
        for &(doc_id, ordinal, _) in &results {
            assert!((doc_id as usize) < n);
            assert_eq!(ordinal, 0);
        }

        // Verify results are sorted by distance (field 2; field 1 is the ordinal).
        for pair in results.windows(2) {
            assert!(pair[1].2 >= pair[0].2, "results not sorted: {:?}", pair);
        }
    }

    #[test]
    fn test_rabitq_build_with_ids() {
        let dim = 64;
        let mut rng = rand::rngs::StdRng::seed_from_u64(7);
        let entries: Vec<(u32, u16, Vec<f32>)> = (0..20u32)
            .map(|i| {
                let v: Vec<f32> = (0..dim).map(|_| rng.random::<f32>() - 0.5).collect();
                (100 + i, (i % 3) as u16, v)
            })
            .collect();

        let index = RaBitQIndex::build_with_ids(RaBitQConfig::new(dim), &entries);

        assert_eq!(index.len(), entries.len());
        assert_eq!(
            index.doc_ids,
            entries.iter().map(|e| e.0).collect::<Vec<_>>()
        );
        assert_eq!(
            index.ordinals,
            entries.iter().map(|e| e.1).collect::<Vec<_>>()
        );

        // The centroid is the component-wise mean of the inputs.
        assert_eq!(index.centroid.len(), dim);
        for (j, &c) in index.centroid.iter().enumerate() {
            let mean = entries.iter().map(|e| e.2[j]).sum::<f32>() / entries.len() as f32;
            assert!((c - mean).abs() < 1e-5, "centroid[{}] = {}, expected {}", j, c, mean);
        }
    }

    #[test]
    fn test_rabitq_search_k_exceeds_len() {
        let dim = 64;
        let n = 5;

        let mut rng = rand::rngs::StdRng::seed_from_u64(99);
        let vectors = random_vectors(&mut rng, n, dim);
        let index = RaBitQIndex::build(RaBitQConfig::new(dim), &vectors);

        let query: Vec<f32> = (0..dim).map(|_| rng.random::<f32>() - 0.5).collect();

        assert_eq!(index.search(&query, n + 10, 1).len(), n);
        assert!(index.search(&query, 0, 1).is_empty());
    }
}