1use crate::retrieval::{RerankedResult, SearchResult};
10use crate::similarity::{compute_similarity, SimilarityMetric};
11use embeddenator_vsa::SparseVec;
12use std::collections::HashMap;
13
14#[derive(Debug, Clone)]
16pub struct IndexConfig {
17 pub metric: SimilarityMetric,
19 pub hierarchical: bool,
21 pub leaf_size: usize,
23}
24
25impl Default for IndexConfig {
26 fn default() -> Self {
27 Self {
28 metric: SimilarityMetric::Cosine,
29 hierarchical: false,
30 leaf_size: 1000,
31 }
32 }
33}
34
35pub trait RetrievalIndex {
37 fn add(&mut self, id: usize, vec: &SparseVec);
39
40 fn finalize(&mut self);
42
43 fn query_top_k(&self, query: &SparseVec, k: usize) -> Vec<SearchResult>;
45
46 fn query_top_k_reranked(
48 &self,
49 query: &SparseVec,
50 vectors: &HashMap<usize, SparseVec>,
51 candidate_k: usize,
52 k: usize,
53 ) -> Vec<RerankedResult>;
54}
55
56#[derive(Clone, Debug)]
63pub struct BruteForceIndex {
64 vectors: HashMap<usize, SparseVec>,
65 config: IndexConfig,
66}
67
68impl BruteForceIndex {
69 pub fn new(config: IndexConfig) -> Self {
70 Self {
71 vectors: HashMap::new(),
72 config,
73 }
74 }
75
76 pub fn build_from_map(vectors: HashMap<usize, SparseVec>, config: IndexConfig) -> Self {
78 Self { vectors, config }
79 }
80}
81
82impl RetrievalIndex for BruteForceIndex {
83 fn add(&mut self, id: usize, vec: &SparseVec) {
84 self.vectors.insert(id, vec.clone());
85 }
86
87 fn finalize(&mut self) {
88 }
90
91 fn query_top_k(&self, query: &SparseVec, k: usize) -> Vec<SearchResult> {
92 if k == 0 || self.vectors.is_empty() {
93 return Vec::new();
94 }
95
96 let mut results: Vec<SearchResult> = self
97 .vectors
98 .iter()
99 .map(|(id, vec)| {
100 let score = (compute_similarity(query, vec, self.config.metric) * 1000.0) as i32;
101 SearchResult { id: *id, score }
102 })
103 .collect();
104
105 results.sort_by(|a, b| b.score.cmp(&a.score).then_with(|| a.id.cmp(&b.id)));
106 results.truncate(k);
107 results
108 }
109
110 fn query_top_k_reranked(
111 &self,
112 query: &SparseVec,
113 _vectors: &HashMap<usize, SparseVec>,
114 _candidate_k: usize,
115 k: usize,
116 ) -> Vec<RerankedResult> {
117 if k == 0 || self.vectors.is_empty() {
118 return Vec::new();
119 }
120
121 let mut results: Vec<RerankedResult> = self
122 .vectors
123 .iter()
124 .map(|(id, vec)| {
125 let cosine = query.cosine(vec);
126 let approx_score = (cosine * 1000.0) as i32;
127 RerankedResult {
128 id: *id,
129 approx_score,
130 cosine,
131 }
132 })
133 .collect();
134
135 results.sort_by(|a, b| {
136 b.cosine
137 .partial_cmp(&a.cosine)
138 .unwrap_or(std::cmp::Ordering::Equal)
139 .then_with(|| a.id.cmp(&b.id))
140 });
141 results.truncate(k);
142 results
143 }
144}
145
146#[derive(Clone, Debug)]
151pub struct HierarchicalIndex {
152 clusters: Vec<Vec<SparseVec>>,
154 cluster_members: Vec<Vec<Vec<usize>>>,
156 vectors: HashMap<usize, SparseVec>,
158 config: IndexConfig,
159}
160
161impl HierarchicalIndex {
162 pub fn new(config: IndexConfig) -> Self {
163 Self {
164 clusters: Vec::new(),
165 cluster_members: Vec::new(),
166 vectors: HashMap::new(),
167 config,
168 }
169 }
170
171 fn build_hierarchy(&mut self) {
173 if self.vectors.is_empty() {
174 return;
175 }
176
177 let num_clusters = (self.vectors.len() as f64).sqrt() as usize + 1;
180 let mut cluster_assignment: HashMap<usize, usize> = HashMap::new();
181
182 let cluster_centers: Vec<SparseVec> =
184 self.vectors.values().take(num_clusters).cloned().collect();
185
186 for (id, vec) in &self.vectors {
188 let mut best_cluster = 0;
189 let mut best_score = f64::NEG_INFINITY;
190
191 for (cluster_id, center) in cluster_centers.iter().enumerate() {
192 let score = vec.cosine(center);
193 if score > best_score {
194 best_score = score;
195 best_cluster = cluster_id;
196 }
197 }
198
199 cluster_assignment.insert(*id, best_cluster);
200 }
201
202 let mut members: Vec<Vec<usize>> = vec![Vec::new(); num_clusters];
204 for (id, cluster_id) in cluster_assignment {
205 members[cluster_id].push(id);
206 }
207
208 self.clusters = vec![cluster_centers];
209 self.cluster_members = vec![members];
210 }
211}
212
213impl RetrievalIndex for HierarchicalIndex {
214 fn add(&mut self, id: usize, vec: &SparseVec) {
215 self.vectors.insert(id, vec.clone());
216 }
217
218 fn finalize(&mut self) {
219 if self.config.hierarchical {
220 self.build_hierarchy();
221 }
222 }
223
224 fn query_top_k(&self, query: &SparseVec, k: usize) -> Vec<SearchResult> {
225 if !self.config.hierarchical || self.clusters.is_empty() {
226 let mut results: Vec<SearchResult> = self
228 .vectors
229 .iter()
230 .map(|(id, vec)| {
231 let score = (query.cosine(vec) * 1000.0) as i32;
232 SearchResult { id: *id, score }
233 })
234 .collect();
235
236 results.sort_by(|a, b| b.score.cmp(&a.score).then_with(|| a.id.cmp(&b.id)));
237 results.truncate(k);
238 return results;
239 }
240
241 let beam_width = k.max(10);
243 let mut candidate_ids: Vec<usize> = Vec::new();
244
245 let metric = self.config.metric;
247 if let Some(top_level_clusters) = self.clusters.first() {
248 let mut cluster_scores: Vec<(usize, f64)> = top_level_clusters
249 .iter()
250 .enumerate()
251 .map(|(idx, center)| (idx, compute_similarity(query, center, metric)))
252 .collect();
253
254 cluster_scores
255 .sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
256
257 for (cluster_id, _score) in cluster_scores.iter().take(beam_width) {
259 if let Some(level_members) = self.cluster_members.first() {
260 if let Some(members) = level_members.get(*cluster_id) {
261 candidate_ids.extend(members);
262 }
263 }
264 }
265 }
266
267 let metric = self.config.metric;
269 let mut results: Vec<SearchResult> = candidate_ids
270 .into_iter()
271 .filter_map(|id| {
272 self.vectors.get(&id).map(|vec| {
273 let score = (compute_similarity(query, vec, metric) * 1000.0) as i32;
274 SearchResult { id, score }
275 })
276 })
277 .collect();
278
279 results.sort_by(|a, b| b.score.cmp(&a.score).then_with(|| a.id.cmp(&b.id)));
280 results.truncate(k);
281 results
282 }
283
284 fn query_top_k_reranked(
285 &self,
286 query: &SparseVec,
287 _vectors: &HashMap<usize, SparseVec>,
288 candidate_k: usize,
289 k: usize,
290 ) -> Vec<RerankedResult> {
291 let candidates = self.query_top_k(query, candidate_k);
292
293 let metric = self.config.metric;
295 let mut results: Vec<RerankedResult> = candidates
296 .into_iter()
297 .filter_map(|cand| {
298 self.vectors.get(&cand.id).map(|vec| RerankedResult {
299 id: cand.id,
300 approx_score: cand.score,
301 cosine: compute_similarity(query, vec, metric),
302 })
303 })
304 .collect();
305
306 results.sort_by(|a, b| {
307 b.cosine
308 .partial_cmp(&a.cosine)
309 .unwrap_or(std::cmp::Ordering::Equal)
310 .then_with(|| a.id.cmp(&b.id))
311 });
312 results.truncate(k);
313 results
314 }
315}
316
#[cfg(test)]
mod tests {
    use super::*;
    use embeddenator_vsa::ReversibleVSAConfig;

    /// Querying with the same bytes as a stored item must rank it first.
    #[test]
    fn test_brute_force_index() {
        let vsa_config = ReversibleVSAConfig::default();
        let encode = |bytes: &[u8]| SparseVec::encode_data(bytes, &vsa_config, None);

        let mut index = BruteForceIndex::new(IndexConfig::default());
        index.add(1, &encode(b"apple"));
        index.add(2, &encode(b"banana"));
        index.add(3, &encode(b"cherry"));
        index.finalize();

        let hits = index.query_top_k(&encode(b"apple"), 2);

        assert!(!hits.is_empty());
        assert_eq!(hits[0].id, 1);
    }

    /// A finalized hierarchical index must return results for a query.
    #[test]
    fn test_hierarchical_index() {
        let vsa_config = ReversibleVSAConfig::default();
        let encode = |bytes: &[u8]| SparseVec::encode_data(bytes, &vsa_config, None);

        let mut index = HierarchicalIndex::new(IndexConfig {
            hierarchical: true,
            ..IndexConfig::default()
        });

        for doc_id in 0..20 {
            let text = format!("doc-{}", doc_id);
            index.add(doc_id, &encode(text.as_bytes()));
        }
        index.finalize();

        let hits = index.query_top_k(&encode(b"doc-5"), 5);

        assert!(!hits.is_empty());
    }
}