hermes_core/structures/vector/index/
rabitq.rs1use std::io;
7
8use serde::{Deserialize, Serialize};
9
10use crate::structures::vector::ivf::QuantizedCode;
11use crate::structures::vector::quantization::{
12 QuantizedQuery, QuantizedVector, RaBitQCodebook, RaBitQConfig,
13};
14
15#[derive(Debug, Clone, Serialize, Deserialize)]
20pub struct RaBitQIndex {
21 pub codebook: RaBitQCodebook,
23 pub centroid: Vec<f32>,
25 pub doc_ids: Vec<u32>,
27 pub ordinals: Vec<u16>,
29 pub vectors: Vec<QuantizedVector>,
31 pub raw_vectors: Option<Vec<Vec<f32>>>,
33}
34
35impl RaBitQIndex {
36 pub fn new(config: RaBitQConfig) -> Self {
38 let dim = config.dim;
39 let codebook = RaBitQCodebook::new(config);
40
41 Self {
42 codebook,
43 centroid: vec![0.0; dim],
44 doc_ids: Vec::new(),
45 ordinals: Vec::new(),
46 vectors: Vec::new(),
47 raw_vectors: None,
48 }
49 }
50
51 pub fn build_with_ids(
53 config: RaBitQConfig,
54 vectors: &[(u32, u16, Vec<f32>)], store_raw: bool,
56 ) -> Self {
57 let n = vectors.len();
58 let dim = config.dim;
59
60 assert!(n > 0, "Cannot build index from empty vector set");
61 assert!(vectors[0].2.len() == dim, "Vector dimension mismatch");
62
63 let mut index = Self::new(config);
64
65 index.centroid = vec![0.0; dim];
67 for (_, _, v) in vectors {
68 for (i, &val) in v.iter().enumerate() {
69 index.centroid[i] += val;
70 }
71 }
72 for c in &mut index.centroid {
73 *c /= n as f32;
74 }
75
76 index.doc_ids = vectors.iter().map(|(doc_id, _, _)| *doc_id).collect();
78 index.ordinals = vectors.iter().map(|(_, ordinal, _)| *ordinal).collect();
79 index.vectors = vectors
80 .iter()
81 .map(|(_, _, v)| index.codebook.encode(v, Some(&index.centroid)))
82 .collect();
83
84 if store_raw {
85 index.raw_vectors = Some(vectors.iter().map(|(_, _, v)| v.clone()).collect());
86 }
87
88 index
89 }
90
91 pub fn build(config: RaBitQConfig, vectors: &[Vec<f32>], store_raw: bool) -> Self {
93 let with_ids: Vec<(u32, u16, Vec<f32>)> = vectors
94 .iter()
95 .enumerate()
96 .map(|(i, v)| (i as u32, 0u16, v.clone()))
97 .collect();
98 Self::build_with_ids(config, &with_ids, store_raw)
99 }
100
101 pub fn add_vector(&mut self, doc_id: u32, ordinal: u16, vector: &[f32], raw: Option<Vec<f32>>) {
103 self.doc_ids.push(doc_id);
104 self.ordinals.push(ordinal);
105 self.vectors
106 .push(self.codebook.encode(vector, Some(&self.centroid)));
107 if let Some(ref mut raw_vectors) = self.raw_vectors
108 && let Some(r) = raw
109 {
110 raw_vectors.push(r);
111 }
112 }
113
114 pub fn prepare_query(&self, query: &[f32]) -> QuantizedQuery {
116 self.codebook.prepare_query(query, Some(&self.centroid))
117 }
118
119 pub fn estimate_distance(&self, query: &QuantizedQuery, vec_idx: usize) -> f32 {
121 self.codebook
122 .estimate_distance(query, &self.vectors[vec_idx])
123 }
124
125 pub fn search(&self, query: &[f32], k: usize, rerank_factor: usize) -> Vec<(u32, u16, f32)> {
127 let prepared = self.prepare_query(query);
128
129 let mut candidates: Vec<(usize, f32)> = self
131 .vectors
132 .iter()
133 .enumerate()
134 .map(|(i, _)| (i, self.estimate_distance(&prepared, i)))
135 .collect();
136
137 candidates.sort_by(|a, b| a.1.partial_cmp(&b.1).unwrap_or(std::cmp::Ordering::Equal));
139
140 let rerank_count = (k * rerank_factor).min(candidates.len());
142
143 let results = if let Some(ref raw_vectors) = self.raw_vectors {
144 let mut reranked: Vec<(usize, f32)> = candidates[..rerank_count]
145 .iter()
146 .map(|&(idx, _)| {
147 let exact_dist = euclidean_distance_squared(query, &raw_vectors[idx]);
148 (idx, exact_dist)
149 })
150 .collect();
151
152 reranked.sort_by(|a, b| a.1.partial_cmp(&b.1).unwrap_or(std::cmp::Ordering::Equal));
153 reranked.truncate(k);
154 reranked
155 } else {
156 candidates.truncate(k);
157 candidates
158 };
159
160 results
162 .into_iter()
163 .map(|(idx, dist)| (self.doc_ids[idx], self.ordinals[idx], dist))
164 .collect()
165 }
166
167 pub fn len(&self) -> usize {
169 self.vectors.len()
170 }
171
172 pub fn is_empty(&self) -> bool {
173 self.vectors.is_empty()
174 }
175
176 pub fn size_bytes(&self) -> usize {
178 use std::mem::size_of;
179
180 let vectors_size: usize = self.vectors.iter().map(|v| v.size_bytes()).sum();
181 let centroid_size = self.centroid.len() * size_of::<f32>();
182 let doc_ids_size = self.doc_ids.len() * size_of::<u32>();
183 let ordinals_size = self.ordinals.len() * size_of::<u16>();
184 let codebook_size = self.codebook.size_bytes();
185 let raw_size = self
186 .raw_vectors
187 .as_ref()
188 .map(|vecs| vecs.iter().map(|v| v.len() * size_of::<f32>()).sum())
189 .unwrap_or(0);
190
191 vectors_size + centroid_size + doc_ids_size + ordinals_size + codebook_size + raw_size
192 }
193
194 pub fn compression_ratio(&self) -> f32 {
196 if self.vectors.is_empty() {
197 return 1.0;
198 }
199
200 let dim = self.codebook.config.dim;
201 let raw_size = self.vectors.len() * dim * 4;
202 let compressed_size: usize = self.vectors.iter().map(|v| v.size_bytes()).sum();
203
204 raw_size as f32 / compressed_size as f32
205 }
206
207 pub fn to_bytes(&self) -> io::Result<Vec<u8>> {
209 serde_json::to_vec(self).map_err(|e| io::Error::new(io::ErrorKind::InvalidData, e))
210 }
211
212 pub fn from_bytes(data: &[u8]) -> io::Result<Self> {
214 serde_json::from_slice(data).map_err(|e| io::Error::new(io::ErrorKind::InvalidData, e))
215 }
216}
217
/// Squared Euclidean (L2) distance between `a` and `b`.
///
/// Coordinates are paired positionally; if the slices differ in length, only the first
/// `min(a.len(), b.len())` coordinates contribute.
#[inline]
fn euclidean_distance_squared(a: &[f32], b: &[f32]) -> f32 {
    std::iter::zip(a, b)
        .map(|(lhs, rhs)| lhs - rhs)
        .map(|delta| delta * delta)
        .sum()
}
229
#[cfg(test)]
mod tests {
    use super::*;
    use rand::prelude::*;

    #[test]
    fn test_rabitq_basic() {
        let dim = 128;
        let n = 100;

        let mut rng = rand::rngs::StdRng::seed_from_u64(12345);
        let vectors: Vec<Vec<f32>> = (0..n)
            .map(|_| (0..dim).map(|_| rng.random::<f32>() - 0.5).collect())
            .collect();

        let config = RaBitQConfig::new(dim);
        let index = RaBitQIndex::build(config, &vectors, true);

        assert_eq!(index.len(), n);
        // `build` assigns sequential doc ids and ordinal 0.
        assert_eq!(index.doc_ids, (0..n as u32).collect::<Vec<_>>());
        assert!(index.ordinals.iter().all(|&o| o == 0));
        println!("Compression ratio: {:.1}x", index.compression_ratio());
    }

    #[test]
    fn test_rabitq_search() {
        let dim = 64;
        let n = 1000;
        let k = 10;

        let mut rng = rand::rngs::StdRng::seed_from_u64(42);
        let vectors: Vec<Vec<f32>> = (0..n)
            .map(|_| (0..dim).map(|_| rng.random::<f32>() - 0.5).collect())
            .collect();

        let config = RaBitQConfig::new(dim);
        let index = RaBitQIndex::build(config, &vectors, true);

        let query: Vec<f32> = (0..dim).map(|_| rng.random::<f32>() - 0.5).collect();
        let results = index.search(&query, k, 10);

        assert_eq!(results.len(), k);

        // Results must be sorted by ascending distance (tuple field 2); field 1 is the
        // ordinal, which is always 0 here and so would make this check vacuous.
        for pair in results.windows(2) {
            assert!(pair[1].2 >= pair[0].2);
        }

        // Every returned doc id refers to an indexed vector and appears only once.
        let mut ids: Vec<u32> = results.iter().map(|r| r.0).collect();
        assert!(ids.iter().all(|&id| (id as usize) < n));
        ids.sort_unstable();
        ids.dedup();
        assert_eq!(ids.len(), k);
    }
}