1use crate::retrieval::{RerankedResult, SearchResult};
10use crate::similarity::{compute_similarity, SimilarityMetric};
11use embeddenator_vsa::SparseVec;
12use std::collections::HashMap;
13
14#[derive(Debug, Clone)]
16pub struct IndexConfig {
17 pub metric: SimilarityMetric,
19 pub hierarchical: bool,
21 pub leaf_size: usize,
23}
24
25impl Default for IndexConfig {
26 fn default() -> Self {
27 Self {
28 metric: SimilarityMetric::Cosine,
29 hierarchical: false,
30 leaf_size: 1000,
31 }
32 }
33}
34
35pub trait RetrievalIndex {
37 fn add(&mut self, id: usize, vec: &SparseVec);
39
40 fn finalize(&mut self);
42
43 fn query_top_k(&self, query: &SparseVec, k: usize) -> Vec<SearchResult>;
45
46 fn query_top_k_reranked(
48 &self,
49 query: &SparseVec,
50 vectors: &HashMap<usize, SparseVec>,
51 candidate_k: usize,
52 k: usize,
53 ) -> Vec<RerankedResult>;
54}
55
56#[derive(Clone, Debug)]
63pub struct BruteForceIndex {
64 vectors: HashMap<usize, SparseVec>,
65 config: IndexConfig,
66}
67
68impl BruteForceIndex {
69 pub fn new(config: IndexConfig) -> Self {
70 Self {
71 vectors: HashMap::new(),
72 config,
73 }
74 }
75
76 pub fn build_from_map(vectors: HashMap<usize, SparseVec>, config: IndexConfig) -> Self {
78 Self { vectors, config }
79 }
80}
81
82impl RetrievalIndex for BruteForceIndex {
83 fn add(&mut self, id: usize, vec: &SparseVec) {
84 self.vectors.insert(id, vec.clone());
85 }
86
87 fn finalize(&mut self) {
88 }
90
91 fn query_top_k(&self, query: &SparseVec, k: usize) -> Vec<SearchResult> {
92 if k == 0 || self.vectors.is_empty() {
93 return Vec::new();
94 }
95
96 let mut results: Vec<SearchResult> = self
97 .vectors
98 .iter()
99 .map(|(id, vec)| {
100 let score = (compute_similarity(query, vec, self.config.metric) * 1000.0) as i32;
101 SearchResult { id: *id, score }
102 })
103 .collect();
104
105 results.sort_by(|a, b| b.score.cmp(&a.score).then_with(|| a.id.cmp(&b.id)));
106 results.truncate(k);
107 results
108 }
109
110 fn query_top_k_reranked(
111 &self,
112 query: &SparseVec,
113 _vectors: &HashMap<usize, SparseVec>,
114 _candidate_k: usize,
115 k: usize,
116 ) -> Vec<RerankedResult> {
117 if k == 0 || self.vectors.is_empty() {
118 return Vec::new();
119 }
120
121 let mut results: Vec<RerankedResult> = self
122 .vectors
123 .iter()
124 .map(|(id, vec)| {
125 let cosine = query.cosine(vec);
126 let approx_score = (cosine * 1000.0) as i32;
127 RerankedResult {
128 id: *id,
129 approx_score,
130 cosine,
131 }
132 })
133 .collect();
134
135 results.sort_by(|a, b| {
136 b.cosine
137 .partial_cmp(&a.cosine)
138 .unwrap_or(std::cmp::Ordering::Equal)
139 .then_with(|| a.id.cmp(&b.id))
140 });
141 results.truncate(k);
142 results
143 }
144}
145
146#[derive(Clone, Debug)]
151pub struct HierarchicalIndex {
152 clusters: Vec<Vec<SparseVec>>,
154 cluster_members: Vec<Vec<Vec<usize>>>,
156 vectors: HashMap<usize, SparseVec>,
158 config: IndexConfig,
159}
160
161impl HierarchicalIndex {
162 pub fn new(config: IndexConfig) -> Self {
163 Self {
164 clusters: Vec::new(),
165 cluster_members: Vec::new(),
166 vectors: HashMap::new(),
167 config,
168 }
169 }
170
171 fn build_hierarchy(&mut self) {
173 if self.vectors.is_empty() {
174 return;
175 }
176
177 let num_clusters = (self.vectors.len() as f64).sqrt() as usize + 1;
180 let mut cluster_assignment: HashMap<usize, usize> = HashMap::new();
181
182 let cluster_centers: Vec<SparseVec> =
184 self.vectors.values().take(num_clusters).cloned().collect();
185
186 for (id, vec) in &self.vectors {
188 let mut best_cluster = 0;
189 let mut best_score = f64::NEG_INFINITY;
190
191 for (cluster_id, center) in cluster_centers.iter().enumerate() {
192 let score = vec.cosine(center);
193 if score > best_score {
194 best_score = score;
195 best_cluster = cluster_id;
196 }
197 }
198
199 cluster_assignment.insert(*id, best_cluster);
200 }
201
202 let mut members: Vec<Vec<usize>> = vec![Vec::new(); num_clusters];
204 for (id, cluster_id) in cluster_assignment {
205 members[cluster_id].push(id);
206 }
207
208 self.clusters = vec![cluster_centers];
209 self.cluster_members = vec![members];
210 }
211}
212
213impl RetrievalIndex for HierarchicalIndex {
214 fn add(&mut self, id: usize, vec: &SparseVec) {
215 self.vectors.insert(id, vec.clone());
216 }
217
218 fn finalize(&mut self) {
219 if self.config.hierarchical {
220 self.build_hierarchy();
221 }
222 }
223
224 fn query_top_k(&self, query: &SparseVec, k: usize) -> Vec<SearchResult> {
225 if !self.config.hierarchical || self.clusters.is_empty() {
226 let mut results: Vec<SearchResult> = self
228 .vectors
229 .iter()
230 .map(|(id, vec)| {
231 let score = (query.cosine(vec) * 1000.0) as i32;
232 SearchResult { id: *id, score }
233 })
234 .collect();
235
236 results.sort_by(|a, b| b.score.cmp(&a.score).then_with(|| a.id.cmp(&b.id)));
237 results.truncate(k);
238 return results;
239 }
240
241 let beam_width = k.max(10);
243 let mut candidate_ids: Vec<usize> = Vec::new();
244
245 if let Some(top_level_clusters) = self.clusters.first() {
246 let mut cluster_scores: Vec<(usize, f64)> = top_level_clusters
247 .iter()
248 .enumerate()
249 .map(|(idx, center)| (idx, query.cosine(center)))
250 .collect();
251
252 cluster_scores
253 .sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
254
255 for (cluster_id, _score) in cluster_scores.iter().take(beam_width) {
257 if let Some(level_members) = self.cluster_members.first() {
258 if let Some(members) = level_members.get(*cluster_id) {
259 candidate_ids.extend(members);
260 }
261 }
262 }
263 }
264
265 let mut results: Vec<SearchResult> = candidate_ids
267 .into_iter()
268 .filter_map(|id| {
269 self.vectors.get(&id).map(|vec| {
270 let score = (query.cosine(vec) * 1000.0) as i32;
271 SearchResult { id, score }
272 })
273 })
274 .collect();
275
276 results.sort_by(|a, b| b.score.cmp(&a.score).then_with(|| a.id.cmp(&b.id)));
277 results.truncate(k);
278 results
279 }
280
281 fn query_top_k_reranked(
282 &self,
283 query: &SparseVec,
284 _vectors: &HashMap<usize, SparseVec>,
285 candidate_k: usize,
286 k: usize,
287 ) -> Vec<RerankedResult> {
288 let candidates = self.query_top_k(query, candidate_k);
289
290 let mut results: Vec<RerankedResult> = candidates
291 .into_iter()
292 .filter_map(|cand| {
293 self.vectors.get(&cand.id).map(|vec| RerankedResult {
294 id: cand.id,
295 approx_score: cand.score,
296 cosine: query.cosine(vec),
297 })
298 })
299 .collect();
300
301 results.sort_by(|a, b| {
302 b.cosine
303 .partial_cmp(&a.cosine)
304 .unwrap_or(std::cmp::Ordering::Equal)
305 .then_with(|| a.id.cmp(&b.id))
306 });
307 results.truncate(k);
308 results
309 }
310}
311
#[cfg(test)]
mod tests {
    use super::*;
    use embeddenator_vsa::ReversibleVSAConfig;

    /// Querying with the same bytes as a stored entry must rank that entry
    /// first in the brute-force index.
    #[test]
    fn test_brute_force_index() {
        let vsa_config = ReversibleVSAConfig::default();
        let mut index = BruteForceIndex::new(IndexConfig::default());

        let corpus: [(usize, &[u8]); 3] = [(1, b"apple"), (2, b"banana"), (3, b"cherry")];
        for &(id, bytes) in &corpus {
            let encoded = SparseVec::encode_data(bytes, &vsa_config, None);
            index.add(id, &encoded);
        }
        index.finalize();

        let probe = SparseVec::encode_data(b"apple", &vsa_config, None);
        let hits = index.query_top_k(&probe, 2);

        assert!(!hits.is_empty());
        assert_eq!(hits[0].id, 1);
    }

    /// A finalized hierarchical index over 20 documents must return at least
    /// one result for a query.
    #[test]
    fn test_hierarchical_index() {
        let vsa_config = ReversibleVSAConfig::default();
        let index_config = IndexConfig {
            hierarchical: true,
            ..IndexConfig::default()
        };
        let mut index = HierarchicalIndex::new(index_config);

        (0..20).for_each(|i| {
            let text = format!("doc-{}", i);
            let encoded = SparseVec::encode_data(text.as_bytes(), &vsa_config, None);
            index.add(i, &encoded);
        });
        index.finalize();

        let probe = SparseVec::encode_data(b"doc-5", &vsa_config, None);
        let hits = index.query_top_k(&probe, 5);

        assert!(!hits.is_empty());
    }
}