// hermes_core/structures/vector/index/rabitq.rs
use std::cmp::Ordering;
use std::io;

use serde::{Deserialize, Serialize};

use crate::structures::vector::ivf::QuantizedCode;
use crate::structures::vector::quantization::{
    QuantizedQuery, QuantizedVector, RaBitQCodebook, RaBitQConfig,
};

15#[derive(Debug, Clone, Serialize, Deserialize)]
20pub struct RaBitQIndex {
21 pub codebook: RaBitQCodebook,
23 pub centroid: Vec<f32>,
25 pub vectors: Vec<QuantizedVector>,
27 pub raw_vectors: Option<Vec<Vec<f32>>>,
29}
30
31impl RaBitQIndex {
32 pub fn new(config: RaBitQConfig) -> Self {
34 let dim = config.dim;
35 let codebook = RaBitQCodebook::new(config);
36
37 Self {
38 codebook,
39 centroid: vec![0.0; dim],
40 vectors: Vec::new(),
41 raw_vectors: None,
42 }
43 }
44
45 pub fn build(config: RaBitQConfig, vectors: &[Vec<f32>], store_raw: bool) -> Self {
47 let n = vectors.len();
48 let dim = config.dim;
49
50 assert!(n > 0, "Cannot build index from empty vector set");
51 assert!(vectors[0].len() == dim, "Vector dimension mismatch");
52
53 let mut index = Self::new(config);
54
55 index.centroid = vec![0.0; dim];
57 for v in vectors {
58 for (i, &val) in v.iter().enumerate() {
59 index.centroid[i] += val;
60 }
61 }
62 for c in &mut index.centroid {
63 *c /= n as f32;
64 }
65
66 index.vectors = vectors
68 .iter()
69 .map(|v| index.codebook.encode(v, Some(&index.centroid)))
70 .collect();
71
72 if store_raw {
73 index.raw_vectors = Some(vectors.to_vec());
74 }
75
76 index
77 }
78
79 pub fn prepare_query(&self, query: &[f32]) -> QuantizedQuery {
81 self.codebook.prepare_query(query, Some(&self.centroid))
82 }
83
84 pub fn estimate_distance(&self, query: &QuantizedQuery, vec_idx: usize) -> f32 {
86 self.codebook
87 .estimate_distance(query, &self.vectors[vec_idx])
88 }
89
90 pub fn search(&self, query: &[f32], k: usize, rerank_factor: usize) -> Vec<(usize, f32)> {
92 let prepared = self.prepare_query(query);
93
94 let mut candidates: Vec<(usize, f32)> = self
96 .vectors
97 .iter()
98 .enumerate()
99 .map(|(i, _)| (i, self.estimate_distance(&prepared, i)))
100 .collect();
101
102 candidates.sort_by(|a, b| a.1.partial_cmp(&b.1).unwrap_or(std::cmp::Ordering::Equal));
104
105 let rerank_count = (k * rerank_factor).min(candidates.len());
107
108 if let Some(ref raw_vectors) = self.raw_vectors {
109 let mut reranked: Vec<(usize, f32)> = candidates[..rerank_count]
110 .iter()
111 .map(|&(idx, _)| {
112 let exact_dist = euclidean_distance_squared(query, &raw_vectors[idx]);
113 (idx, exact_dist)
114 })
115 .collect();
116
117 reranked.sort_by(|a, b| a.1.partial_cmp(&b.1).unwrap_or(std::cmp::Ordering::Equal));
118 reranked.truncate(k);
119 reranked
120 } else {
121 candidates.truncate(k);
122 candidates
123 }
124 }
125
126 pub fn len(&self) -> usize {
128 self.vectors.len()
129 }
130
131 pub fn is_empty(&self) -> bool {
132 self.vectors.is_empty()
133 }
134
135 pub fn size_bytes(&self) -> usize {
137 let vectors_size: usize = self.vectors.iter().map(|v| v.size_bytes()).sum();
138 let centroid_size = self.centroid.len() * 4;
139 let codebook_size = self.codebook.size_bytes();
140 let raw_size = self
141 .raw_vectors
142 .as_ref()
143 .map(|vecs| vecs.iter().map(|v| v.len() * 4).sum())
144 .unwrap_or(0);
145
146 vectors_size + centroid_size + codebook_size + raw_size
147 }
148
149 pub fn compression_ratio(&self) -> f32 {
151 if self.vectors.is_empty() {
152 return 1.0;
153 }
154
155 let dim = self.codebook.config.dim;
156 let raw_size = self.vectors.len() * dim * 4;
157 let compressed_size: usize = self.vectors.iter().map(|v| v.size_bytes()).sum();
158
159 raw_size as f32 / compressed_size as f32
160 }
161
162 pub fn to_bytes(&self) -> io::Result<Vec<u8>> {
164 serde_json::to_vec(self).map_err(|e| io::Error::new(io::ErrorKind::InvalidData, e))
165 }
166
167 pub fn from_bytes(data: &[u8]) -> io::Result<Self> {
169 serde_json::from_slice(data).map_err(|e| io::Error::new(io::ErrorKind::InvalidData, e))
170 }
171}
172
/// Squared Euclidean distance between `a` and `b`.
///
/// Both slices are expected to have the same length. Debug builds assert this. In release
/// builds `zip` stops at the shorter slice, so a mismatch would silently ignore the tail of
/// the longer one.
#[inline]
fn euclidean_distance_squared(a: &[f32], b: &[f32]) -> f32 {
    debug_assert_eq!(a.len(), b.len(), "dimension mismatch in distance computation");
    a.iter()
        .zip(b)
        .map(|(&x, &y)| {
            let d = x - y;
            d * d
        })
        .sum()
}

#[cfg(test)]
mod tests {
    use super::*;
    use rand::prelude::*;

    /// Generates `n` vectors of length `dim` with components uniform in [-0.5, 0.5).
    fn random_vectors(rng: &mut rand::rngs::StdRng, n: usize, dim: usize) -> Vec<Vec<f32>> {
        (0..n)
            .map(|_| (0..dim).map(|_| rng.random::<f32>() - 0.5).collect())
            .collect()
    }

    #[test]
    fn test_rabitq_basic() {
        let dim = 128;
        let n = 100;

        let mut rng = rand::rngs::StdRng::seed_from_u64(12345);
        let vectors = random_vectors(&mut rng, n, dim);

        let config = RaBitQConfig::new(dim);
        let index = RaBitQIndex::build(config, &vectors, true);

        assert_eq!(index.len(), n);
        assert!(!index.is_empty());
        assert_eq!(index.raw_vectors.as_ref().map(Vec::len), Some(n));

        // The centroid must be the component-wise mean of the input.
        assert_eq!(index.centroid.len(), dim);
        for (d, &c) in index.centroid.iter().enumerate() {
            let expected = vectors.iter().map(|v| f64::from(v[d])).sum::<f64>() / n as f64;
            assert!((f64::from(c) - expected).abs() < 1e-5, "centroid[{d}] = {c}");
        }

        let ratio = index.compression_ratio();
        assert!(ratio.is_finite() && ratio > 0.0, "unexpected compression ratio {ratio}");
    }

    #[test]
    fn test_rabitq_search() {
        let dim = 64;
        let n = 1000;
        let k = 10;

        let mut rng = rand::rngs::StdRng::seed_from_u64(42);
        let vectors = random_vectors(&mut rng, n, dim);

        let config = RaBitQConfig::new(dim);
        let index = RaBitQIndex::build(config, &vectors, true);

        let query: Vec<f32> = (0..dim).map(|_| rng.random::<f32>() - 0.5).collect();
        let results = index.search(&query, k, 10);

        assert_eq!(results.len(), k);
        for pair in results.windows(2) {
            assert!(pair[1].1 >= pair[0].1);
        }

        // With raw vectors stored, every reported distance is the exact one.
        for &(idx, dist) in &results {
            assert!(idx < n);
            assert_eq!(dist, euclidean_distance_squared(&query, &vectors[idx]));
        }

        let mut ids: Vec<usize> = results.iter().map(|&(idx, _)| idx).collect();
        ids.sort_unstable();
        ids.dedup();
        assert_eq!(ids.len(), k, "search returned duplicate ids");
    }

    #[test]
    fn test_rabitq_search_without_raw_vectors() {
        let dim = 64;
        let n = 200;
        let k = 5;

        let mut rng = rand::rngs::StdRng::seed_from_u64(7);
        let vectors = random_vectors(&mut rng, n, dim);

        let index = RaBitQIndex::build(RaBitQConfig::new(dim), &vectors, false);
        assert!(index.raw_vectors.is_none());

        let query: Vec<f32> = (0..dim).map(|_| rng.random::<f32>() - 0.5).collect();

        let results = index.search(&query, k, 10);
        assert_eq!(results.len(), k);
        for pair in results.windows(2) {
            assert!(pair[1].1 >= pair[0].1);
        }
        assert!(results.iter().all(|&(idx, _)| idx < n));

        assert!(index.search(&query, 0, 10).is_empty());
    }
}