Skip to main content

embeddenator_retrieval/
index.rs

1//! Index structures for efficient retrieval
2//!
3//! This module provides various index structures optimized for different
4//! retrieval scenarios:
5//! - In-memory indexes for fast queries
6//! - Disk-backed indexes for large datasets
7//! - Hierarchical indexes for multi-scale search
8
9use crate::retrieval::{RerankedResult, SearchResult};
10use crate::similarity::{compute_similarity, SimilarityMetric};
11use embeddenator_vsa::SparseVec;
12use std::collections::HashMap;
13
14/// Index configuration
15#[derive(Debug, Clone)]
16pub struct IndexConfig {
17    /// Similarity metric to use for reranking
18    pub metric: SimilarityMetric,
19    /// Whether to enable hierarchical indexing
20    pub hierarchical: bool,
21    /// Maximum entries per leaf node (for hierarchical)
22    pub leaf_size: usize,
23}
24
25impl Default for IndexConfig {
26    fn default() -> Self {
27        Self {
28            metric: SimilarityMetric::Cosine,
29            hierarchical: false,
30            leaf_size: 1000,
31        }
32    }
33}
34
35/// Abstract index trait for different retrieval strategies
36pub trait RetrievalIndex {
37    /// Add a vector to the index
38    fn add(&mut self, id: usize, vec: &SparseVec);
39
40    /// Finalize the index (sort, optimize, etc.)
41    fn finalize(&mut self);
42
43    /// Query for top-k candidates
44    fn query_top_k(&self, query: &SparseVec, k: usize) -> Vec<SearchResult>;
45
46    /// Query and rerank with exact similarity
47    fn query_top_k_reranked(
48        &self,
49        query: &SparseVec,
50        vectors: &HashMap<usize, SparseVec>,
51        candidate_k: usize,
52        k: usize,
53    ) -> Vec<RerankedResult>;
54}
55
56/// Brute force index - linear scan for ground truth
57///
58/// Useful for:
59/// - Small datasets (< 10k vectors)
60/// - Ground truth for accuracy testing
61/// - Baseline performance comparison
62#[derive(Clone, Debug)]
63pub struct BruteForceIndex {
64    vectors: HashMap<usize, SparseVec>,
65    config: IndexConfig,
66}
67
68impl BruteForceIndex {
69    pub fn new(config: IndexConfig) -> Self {
70        Self {
71            vectors: HashMap::new(),
72            config,
73        }
74    }
75
76    /// Build index from existing vectors
77    pub fn build_from_map(vectors: HashMap<usize, SparseVec>, config: IndexConfig) -> Self {
78        Self { vectors, config }
79    }
80}
81
82impl RetrievalIndex for BruteForceIndex {
83    fn add(&mut self, id: usize, vec: &SparseVec) {
84        self.vectors.insert(id, vec.clone());
85    }
86
87    fn finalize(&mut self) {
88        // Nothing to do for brute force
89    }
90
91    fn query_top_k(&self, query: &SparseVec, k: usize) -> Vec<SearchResult> {
92        if k == 0 || self.vectors.is_empty() {
93            return Vec::new();
94        }
95
96        let mut results: Vec<SearchResult> = self
97            .vectors
98            .iter()
99            .map(|(id, vec)| {
100                let score = (compute_similarity(query, vec, self.config.metric) * 1000.0) as i32;
101                SearchResult { id: *id, score }
102            })
103            .collect();
104
105        results.sort_by(|a, b| b.score.cmp(&a.score).then_with(|| a.id.cmp(&b.id)));
106        results.truncate(k);
107        results
108    }
109
110    fn query_top_k_reranked(
111        &self,
112        query: &SparseVec,
113        _vectors: &HashMap<usize, SparseVec>,
114        _candidate_k: usize,
115        k: usize,
116    ) -> Vec<RerankedResult> {
117        if k == 0 || self.vectors.is_empty() {
118            return Vec::new();
119        }
120
121        let mut results: Vec<RerankedResult> = self
122            .vectors
123            .iter()
124            .map(|(id, vec)| {
125                let cosine = query.cosine(vec);
126                let approx_score = (cosine * 1000.0) as i32;
127                RerankedResult {
128                    id: *id,
129                    approx_score,
130                    cosine,
131                }
132            })
133            .collect();
134
135        results.sort_by(|a, b| {
136            b.cosine
137                .partial_cmp(&a.cosine)
138                .unwrap_or(std::cmp::Ordering::Equal)
139                .then_with(|| a.id.cmp(&b.id))
140        });
141        results.truncate(k);
142        results
143    }
144}
145
146/// Hierarchical index using clustering for faster search
147///
148/// Divides the vector space into clusters and performs beam search
149/// through the hierarchy.
150#[derive(Clone, Debug)]
151pub struct HierarchicalIndex {
152    /// Cluster centroids at each level
153    clusters: Vec<Vec<SparseVec>>,
154    /// Mapping from cluster to member IDs
155    cluster_members: Vec<Vec<Vec<usize>>>,
156    /// All vectors (for reranking)
157    vectors: HashMap<usize, SparseVec>,
158    config: IndexConfig,
159}
160
161impl HierarchicalIndex {
162    pub fn new(config: IndexConfig) -> Self {
163        Self {
164            clusters: Vec::new(),
165            cluster_members: Vec::new(),
166            vectors: HashMap::new(),
167            config,
168        }
169    }
170
171    /// Build clusters from current vectors
172    fn build_hierarchy(&mut self) {
173        if self.vectors.is_empty() {
174            return;
175        }
176
177        // Simple k-means style clustering
178        // For production, use more sophisticated methods (HNSW, etc.)
179        let num_clusters = (self.vectors.len() as f64).sqrt() as usize + 1;
180        let mut cluster_assignment: HashMap<usize, usize> = HashMap::new();
181
182        // Initialize clusters with random vectors
183        let cluster_centers: Vec<SparseVec> =
184            self.vectors.values().take(num_clusters).cloned().collect();
185
186        // Assign each vector to nearest cluster
187        for (id, vec) in &self.vectors {
188            let mut best_cluster = 0;
189            let mut best_score = f64::NEG_INFINITY;
190
191            for (cluster_id, center) in cluster_centers.iter().enumerate() {
192                let score = vec.cosine(center);
193                if score > best_score {
194                    best_score = score;
195                    best_cluster = cluster_id;
196                }
197            }
198
199            cluster_assignment.insert(*id, best_cluster);
200        }
201
202        // Build cluster members lists
203        let mut members: Vec<Vec<usize>> = vec![Vec::new(); num_clusters];
204        for (id, cluster_id) in cluster_assignment {
205            members[cluster_id].push(id);
206        }
207
208        self.clusters = vec![cluster_centers];
209        self.cluster_members = vec![members];
210    }
211}
212
213impl RetrievalIndex for HierarchicalIndex {
214    fn add(&mut self, id: usize, vec: &SparseVec) {
215        self.vectors.insert(id, vec.clone());
216    }
217
218    fn finalize(&mut self) {
219        if self.config.hierarchical {
220            self.build_hierarchy();
221        }
222    }
223
224    fn query_top_k(&self, query: &SparseVec, k: usize) -> Vec<SearchResult> {
225        if !self.config.hierarchical || self.clusters.is_empty() {
226            // Fall back to brute force
227            let mut results: Vec<SearchResult> = self
228                .vectors
229                .iter()
230                .map(|(id, vec)| {
231                    let score = (query.cosine(vec) * 1000.0) as i32;
232                    SearchResult { id: *id, score }
233                })
234                .collect();
235
236            results.sort_by(|a, b| b.score.cmp(&a.score).then_with(|| a.id.cmp(&b.id)));
237            results.truncate(k);
238            return results;
239        }
240
241        // Hierarchical search: find best clusters first
242        let beam_width = k.max(10);
243        let mut candidate_ids: Vec<usize> = Vec::new();
244
245        if let Some(top_level_clusters) = self.clusters.first() {
246            let mut cluster_scores: Vec<(usize, f64)> = top_level_clusters
247                .iter()
248                .enumerate()
249                .map(|(idx, center)| (idx, query.cosine(center)))
250                .collect();
251
252            cluster_scores
253                .sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
254
255            // Get candidates from top clusters
256            for (cluster_id, _score) in cluster_scores.iter().take(beam_width) {
257                if let Some(level_members) = self.cluster_members.first() {
258                    if let Some(members) = level_members.get(*cluster_id) {
259                        candidate_ids.extend(members);
260                    }
261                }
262            }
263        }
264
265        // Score all candidates
266        let mut results: Vec<SearchResult> = candidate_ids
267            .into_iter()
268            .filter_map(|id| {
269                self.vectors.get(&id).map(|vec| {
270                    let score = (query.cosine(vec) * 1000.0) as i32;
271                    SearchResult { id, score }
272                })
273            })
274            .collect();
275
276        results.sort_by(|a, b| b.score.cmp(&a.score).then_with(|| a.id.cmp(&b.id)));
277        results.truncate(k);
278        results
279    }
280
281    fn query_top_k_reranked(
282        &self,
283        query: &SparseVec,
284        _vectors: &HashMap<usize, SparseVec>,
285        candidate_k: usize,
286        k: usize,
287    ) -> Vec<RerankedResult> {
288        let candidates = self.query_top_k(query, candidate_k);
289
290        let mut results: Vec<RerankedResult> = candidates
291            .into_iter()
292            .filter_map(|cand| {
293                self.vectors.get(&cand.id).map(|vec| RerankedResult {
294                    id: cand.id,
295                    approx_score: cand.score,
296                    cosine: query.cosine(vec),
297                })
298            })
299            .collect();
300
301        results.sort_by(|a, b| {
302            b.cosine
303                .partial_cmp(&a.cosine)
304                .unwrap_or(std::cmp::Ordering::Equal)
305                .then_with(|| a.id.cmp(&b.id))
306        });
307        results.truncate(k);
308        results
309    }
310}
311
#[cfg(test)]
mod tests {
    use super::*;
    use embeddenator_vsa::ReversibleVSAConfig;

    /// Encodes `text` the same way for every test in this module.
    fn encode(text: &str, vsa: &ReversibleVSAConfig) -> SparseVec {
        SparseVec::encode_data(text.as_bytes(), vsa, None)
    }

    /// An exact-match query must rank the matching document first.
    #[test]
    fn test_brute_force_index() {
        let vsa = ReversibleVSAConfig::default();
        let mut index = BruteForceIndex::new(IndexConfig::default());

        // Ids are 1-based: apple -> 1, banana -> 2, cherry -> 3.
        for (pos, word) in ["apple", "banana", "cherry"].iter().enumerate() {
            index.add(pos + 1, &encode(word, &vsa));
        }
        index.finalize();

        let hits = index.query_top_k(&encode("apple", &vsa), 2);

        assert!(!hits.is_empty());
        assert_eq!(hits[0].id, 1); // Should match apple best
    }

    /// A finalized hierarchical index must return results for a known document.
    #[test]
    fn test_hierarchical_index() {
        let vsa = ReversibleVSAConfig::default();
        let index_config = IndexConfig {
            hierarchical: true,
            ..IndexConfig::default()
        };
        let mut index = HierarchicalIndex::new(index_config);

        // Populate the index with twenty distinct documents.
        for i in 0..20 {
            index.add(i, &encode(&format!("doc-{}", i), &vsa));
        }
        index.finalize();

        let hits = index.query_top_k(&encode("doc-5", &vsa), 5);

        assert!(!hits.is_empty());
    }
}