hermes_core/structures/vector/index/
rabitq.rs1use std::io;
7
8use serde::{Deserialize, Serialize};
9
10use crate::structures::vector::ivf::QuantizedCode;
11use crate::structures::vector::quantization::{
12 QuantizedQuery, QuantizedVector, RaBitQCodebook, RaBitQConfig,
13};
14
15#[derive(Debug, Clone, Serialize, Deserialize)]
20pub struct RaBitQIndex {
21 pub codebook: RaBitQCodebook,
23 pub centroid: Vec<f32>,
25 pub doc_ids: Vec<u32>,
27 pub ordinals: Vec<u16>,
29 pub vectors: Vec<QuantizedVector>,
31}
32
33impl RaBitQIndex {
34 pub fn new(config: RaBitQConfig) -> Self {
36 let dim = config.dim;
37 let codebook = RaBitQCodebook::new(config);
38
39 Self {
40 codebook,
41 centroid: vec![0.0; dim],
42 doc_ids: Vec::new(),
43 ordinals: Vec::new(),
44 vectors: Vec::new(),
45 }
46 }
47
48 pub fn build_with_ids(
50 config: RaBitQConfig,
51 vectors: &[(u32, u16, Vec<f32>)], ) -> Self {
53 let n = vectors.len();
54 let dim = config.dim;
55
56 assert!(n > 0, "Cannot build index from empty vector set");
57 assert!(vectors[0].2.len() == dim, "Vector dimension mismatch");
58
59 let mut index = Self::new(config);
60
61 index.centroid = vec![0.0; dim];
63 for (_, _, v) in vectors {
64 for (i, &val) in v.iter().enumerate() {
65 index.centroid[i] += val;
66 }
67 }
68 for c in &mut index.centroid {
69 *c /= n as f32;
70 }
71
72 index.doc_ids = vectors.iter().map(|(doc_id, _, _)| *doc_id).collect();
74 index.ordinals = vectors.iter().map(|(_, ordinal, _)| *ordinal).collect();
75 index.vectors = vectors
76 .iter()
77 .map(|(_, _, v)| index.codebook.encode(v, Some(&index.centroid)))
78 .collect();
79
80 index
81 }
82
83 pub fn build(config: RaBitQConfig, vectors: &[Vec<f32>]) -> Self {
85 let with_ids: Vec<(u32, u16, Vec<f32>)> = vectors
86 .iter()
87 .enumerate()
88 .map(|(i, v)| (i as u32, 0u16, v.clone()))
89 .collect();
90 Self::build_with_ids(config, &with_ids)
91 }
92
93 pub fn add_vector(&mut self, doc_id: u32, ordinal: u16, vector: &[f32]) {
95 self.doc_ids.push(doc_id);
96 self.ordinals.push(ordinal);
97 self.vectors
98 .push(self.codebook.encode(vector, Some(&self.centroid)));
99 }
100
101 pub fn prepare_query(&self, query: &[f32]) -> QuantizedQuery {
103 self.codebook.prepare_query(query, Some(&self.centroid))
104 }
105
106 pub fn estimate_distance(&self, query: &QuantizedQuery, vec_idx: usize) -> f32 {
108 self.codebook
109 .estimate_distance(query, &self.vectors[vec_idx])
110 }
111
112 pub fn search(&self, query: &[f32], k: usize, _rerank_factor: usize) -> Vec<(u32, u16, f32)> {
114 let prepared = self.prepare_query(query);
115
116 let mut candidates: Vec<(usize, f32)> = self
118 .vectors
119 .iter()
120 .enumerate()
121 .map(|(i, _)| (i, self.estimate_distance(&prepared, i)))
122 .collect();
123
124 candidates.sort_by(|a, b| a.1.partial_cmp(&b.1).unwrap_or(std::cmp::Ordering::Equal));
126
127 candidates.truncate(k);
128
129 candidates
131 .into_iter()
132 .map(|(idx, dist)| (self.doc_ids[idx], self.ordinals[idx], dist))
133 .collect()
134 }
135
136 pub fn len(&self) -> usize {
138 self.vectors.len()
139 }
140
141 pub fn is_empty(&self) -> bool {
142 self.vectors.is_empty()
143 }
144
145 pub fn size_bytes(&self) -> usize {
147 use std::mem::size_of;
148
149 let vectors_size: usize = self.vectors.iter().map(|v| v.size_bytes()).sum();
150 let centroid_size = self.centroid.len() * size_of::<f32>();
151 let doc_ids_size = self.doc_ids.len() * size_of::<u32>();
152 let ordinals_size = self.ordinals.len() * size_of::<u16>();
153 let codebook_size = self.codebook.size_bytes();
154 vectors_size + centroid_size + doc_ids_size + ordinals_size + codebook_size
155 }
156
157 pub fn estimated_memory_bytes(&self) -> usize {
159 self.size_bytes()
160 }
161
162 pub fn compression_ratio(&self) -> f32 {
164 if self.vectors.is_empty() {
165 return 1.0;
166 }
167
168 let dim = self.codebook.config.dim;
169 let raw_size = self.vectors.len() * dim * 4;
170 let compressed_size: usize = self.vectors.iter().map(|v| v.size_bytes()).sum();
171
172 raw_size as f32 / compressed_size as f32
173 }
174
175 pub fn to_bytes(&self) -> io::Result<Vec<u8>> {
177 serde_json::to_vec(self).map_err(|e| io::Error::new(io::ErrorKind::InvalidData, e))
178 }
179
180 pub fn from_bytes(data: &[u8]) -> io::Result<Self> {
182 serde_json::from_slice(data).map_err(|e| io::Error::new(io::ErrorKind::InvalidData, e))
183 }
184}
185
#[cfg(test)]
mod tests {
    use super::*;
    use rand::prelude::*;

    /// Draws `n` vectors of length `dim` with components uniform in `[-0.5, 0.5)`.
    fn random_vectors(rng: &mut rand::rngs::StdRng, n: usize, dim: usize) -> Vec<Vec<f32>> {
        (0..n)
            .map(|_| (0..dim).map(|_| rng.random::<f32>() - 0.5).collect())
            .collect()
    }

    #[test]
    fn test_rabitq_basic() {
        let dim = 128;
        let n = 100;

        let mut rng = rand::rngs::StdRng::seed_from_u64(12345);
        let vectors = random_vectors(&mut rng, n, dim);

        let index = RaBitQIndex::build(RaBitQConfig::new(dim), &vectors);

        assert_eq!(index.len(), n);
        assert!(!index.is_empty());
        assert_eq!(index.centroid.len(), dim);

        let ratio = index.compression_ratio();
        assert!(ratio.is_finite() && ratio > 0.0, "bad compression ratio {}", ratio);
        println!("Compression ratio: {:.1}x", ratio);
    }

    #[test]
    fn test_rabitq_search() {
        let dim = 64;
        let n = 1000;
        let k = 10;

        let mut rng = rand::rngs::StdRng::seed_from_u64(42);
        let vectors = random_vectors(&mut rng, n, dim);
        let index = RaBitQIndex::build(RaBitQConfig::new(dim), &vectors);

        let query: Vec<f32> = (0..dim).map(|_| rng.random::<f32>() - 0.5).collect();
        let results = index.search(&query, k, 10);

        assert_eq!(results.len(), k);

        // Result tuples are (doc_id, ordinal, distance): check the distance
        // column (.2), not the ordinal (.1), which `build` always sets to 0.
        for pair in results.windows(2) {
            assert!(pair[0].2 <= pair[1].2, "results not sorted by distance: {:?}", results);
        }
        for &(doc_id, ordinal, _) in &results {
            assert!((doc_id as usize) < n);
            assert_eq!(ordinal, 0);
        }
    }

    #[test]
    fn test_build_with_ids_preserves_ids() {
        let dim = 64;
        let n = 50;

        let mut rng = rand::rngs::StdRng::seed_from_u64(7);
        let triples: Vec<(u32, u16, Vec<f32>)> = random_vectors(&mut rng, n, dim)
            .into_iter()
            .enumerate()
            .map(|(i, v)| (1000 + i as u32, (i % 3) as u16, v))
            .collect();

        let index = RaBitQIndex::build_with_ids(RaBitQConfig::new(dim), &triples);
        assert_eq!(index.doc_ids, triples.iter().map(|t| t.0).collect::<Vec<_>>());
        assert_eq!(index.ordinals, triples.iter().map(|t| t.1).collect::<Vec<_>>());

        let query: Vec<f32> = (0..dim).map(|_| rng.random::<f32>() - 0.5).collect();
        let mut got: Vec<(u32, u16)> = index
            .search(&query, n, 1)
            .into_iter()
            .map(|(doc_id, ordinal, _)| (doc_id, ordinal))
            .collect();
        got.sort_unstable();
        let mut expected: Vec<(u32, u16)> = triples.iter().map(|t| (t.0, t.1)).collect();
        expected.sort_unstable();
        assert_eq!(got, expected);
    }

    #[test]
    fn test_search_k_bounds() {
        let dim = 64;
        let n = 20;

        let mut rng = rand::rngs::StdRng::seed_from_u64(99);
        let vectors = random_vectors(&mut rng, n, dim);
        let index = RaBitQIndex::build(RaBitQConfig::new(dim), &vectors);
        let query: Vec<f32> = (0..dim).map(|_| rng.random::<f32>() - 0.5).collect();

        assert!(index.search(&query, 0, 1).is_empty());
        assert_eq!(index.search(&query, n + 5, 1).len(), n);
    }

    #[test]
    fn test_serialization_roundtrip() {
        let dim = 64;
        let n = 30;

        let mut rng = rand::rngs::StdRng::seed_from_u64(2024);
        let vectors = random_vectors(&mut rng, n, dim);
        let index = RaBitQIndex::build(RaBitQConfig::new(dim), &vectors);

        let bytes = index.to_bytes().expect("serialize");
        let restored = RaBitQIndex::from_bytes(&bytes).expect("deserialize");
        assert_eq!(restored.len(), index.len());
        assert_eq!(restored.doc_ids, index.doc_ids);
        assert_eq!(restored.ordinals, index.ordinals);

        assert!(RaBitQIndex::from_bytes(b"not json").is_err());
    }

    #[test]
    #[should_panic(expected = "Cannot build index from empty vector set")]
    fn test_build_empty_panics() {
        let _ = RaBitQIndex::build(RaBitQConfig::new(64), &[]);
    }

    #[test]
    #[should_panic]
    fn test_build_rejects_mismatched_dimension() {
        let vectors = vec![vec![0.0f32; 64], vec![0.0f32; 65]];
        let _ = RaBitQIndex::build(RaBitQConfig::new(64), &vectors);
    }
}