Skip to main content

embeddenator_retrieval/
index.rs

//! Index structures for efficient retrieval
//!
//! This module provides various index structures optimized for different
//! retrieval scenarios:
//! - In-memory indexes for fast queries
//! - Disk-backed indexes for large datasets
//! - Hierarchical indexes for multi-scale search
8
9use crate::retrieval::{RerankedResult, SearchResult};
10use crate::similarity::{compute_similarity, SimilarityMetric};
11use embeddenator_vsa::SparseVec;
12use std::collections::HashMap;
13
14/// Index configuration
15#[derive(Debug, Clone)]
16pub struct IndexConfig {
17    /// Similarity metric to use for reranking
18    pub metric: SimilarityMetric,
19    /// Whether to enable hierarchical indexing
20    pub hierarchical: bool,
21    /// Maximum entries per leaf node (for hierarchical)
22    pub leaf_size: usize,
23}
24
25impl Default for IndexConfig {
26    fn default() -> Self {
27        Self {
28            metric: SimilarityMetric::Cosine,
29            hierarchical: false,
30            leaf_size: 1000,
31        }
32    }
33}
34
35/// Abstract index trait for different retrieval strategies
36pub trait RetrievalIndex {
37    /// Add a vector to the index
38    fn add(&mut self, id: usize, vec: &SparseVec);
39
40    /// Finalize the index (sort, optimize, etc.)
41    fn finalize(&mut self);
42
43    /// Query for top-k candidates
44    fn query_top_k(&self, query: &SparseVec, k: usize) -> Vec<SearchResult>;
45
46    /// Query and rerank with exact similarity
47    fn query_top_k_reranked(
48        &self,
49        query: &SparseVec,
50        vectors: &HashMap<usize, SparseVec>,
51        candidate_k: usize,
52        k: usize,
53    ) -> Vec<RerankedResult>;
54}
55
56/// Brute force index - linear scan for ground truth
57///
58/// Useful for:
59/// - Small datasets (< 10k vectors)
60/// - Ground truth for accuracy testing
61/// - Baseline performance comparison
62#[derive(Clone, Debug)]
63pub struct BruteForceIndex {
64    vectors: HashMap<usize, SparseVec>,
65    config: IndexConfig,
66}
67
68impl BruteForceIndex {
69    pub fn new(config: IndexConfig) -> Self {
70        Self {
71            vectors: HashMap::new(),
72            config,
73        }
74    }
75
76    /// Build index from existing vectors
77    pub fn build_from_map(vectors: HashMap<usize, SparseVec>, config: IndexConfig) -> Self {
78        Self { vectors, config }
79    }
80}
81
82impl RetrievalIndex for BruteForceIndex {
83    fn add(&mut self, id: usize, vec: &SparseVec) {
84        self.vectors.insert(id, vec.clone());
85    }
86
87    fn finalize(&mut self) {
88        // Nothing to do for brute force
89    }
90
91    fn query_top_k(&self, query: &SparseVec, k: usize) -> Vec<SearchResult> {
92        if k == 0 || self.vectors.is_empty() {
93            return Vec::new();
94        }
95
96        let mut results: Vec<SearchResult> = self
97            .vectors
98            .iter()
99            .map(|(id, vec)| {
100                let score = (compute_similarity(query, vec, self.config.metric) * 1000.0) as i32;
101                SearchResult { id: *id, score }
102            })
103            .collect();
104
105        results.sort_by(|a, b| b.score.cmp(&a.score).then_with(|| a.id.cmp(&b.id)));
106        results.truncate(k);
107        results
108    }
109
110    fn query_top_k_reranked(
111        &self,
112        query: &SparseVec,
113        _vectors: &HashMap<usize, SparseVec>,
114        _candidate_k: usize,
115        k: usize,
116    ) -> Vec<RerankedResult> {
117        if k == 0 || self.vectors.is_empty() {
118            return Vec::new();
119        }
120
121        let mut results: Vec<RerankedResult> = self
122            .vectors
123            .iter()
124            .map(|(id, vec)| {
125                let cosine = query.cosine(vec);
126                let approx_score = (cosine * 1000.0) as i32;
127                RerankedResult {
128                    id: *id,
129                    approx_score,
130                    cosine,
131                }
132            })
133            .collect();
134
135        results.sort_by(|a, b| {
136            b.cosine
137                .partial_cmp(&a.cosine)
138                .unwrap_or(std::cmp::Ordering::Equal)
139                .then_with(|| a.id.cmp(&b.id))
140        });
141        results.truncate(k);
142        results
143    }
144}
145
146/// Hierarchical index using clustering for faster search
147///
148/// Divides the vector space into clusters and performs beam search
149/// through the hierarchy.
150#[derive(Clone, Debug)]
151pub struct HierarchicalIndex {
152    /// Cluster centroids at each level
153    clusters: Vec<Vec<SparseVec>>,
154    /// Mapping from cluster to member IDs
155    cluster_members: Vec<Vec<Vec<usize>>>,
156    /// All vectors (for reranking)
157    vectors: HashMap<usize, SparseVec>,
158    config: IndexConfig,
159}
160
161impl HierarchicalIndex {
162    pub fn new(config: IndexConfig) -> Self {
163        Self {
164            clusters: Vec::new(),
165            cluster_members: Vec::new(),
166            vectors: HashMap::new(),
167            config,
168        }
169    }
170
171    /// Build clusters from current vectors
172    fn build_hierarchy(&mut self) {
173        if self.vectors.is_empty() {
174            return;
175        }
176
177        // Simple k-means style clustering
178        // For production, use more sophisticated methods (HNSW, etc.)
179        let num_clusters = (self.vectors.len() as f64).sqrt() as usize + 1;
180        let mut cluster_assignment: HashMap<usize, usize> = HashMap::new();
181
182        // Initialize clusters with random vectors
183        let cluster_centers: Vec<SparseVec> =
184            self.vectors.values().take(num_clusters).cloned().collect();
185
186        // Assign each vector to nearest cluster
187        for (id, vec) in &self.vectors {
188            let mut best_cluster = 0;
189            let mut best_score = f64::NEG_INFINITY;
190
191            for (cluster_id, center) in cluster_centers.iter().enumerate() {
192                let score = vec.cosine(center);
193                if score > best_score {
194                    best_score = score;
195                    best_cluster = cluster_id;
196                }
197            }
198
199            cluster_assignment.insert(*id, best_cluster);
200        }
201
202        // Build cluster members lists
203        let mut members: Vec<Vec<usize>> = vec![Vec::new(); num_clusters];
204        for (id, cluster_id) in cluster_assignment {
205            members[cluster_id].push(id);
206        }
207
208        self.clusters = vec![cluster_centers];
209        self.cluster_members = vec![members];
210    }
211}
212
213impl RetrievalIndex for HierarchicalIndex {
214    fn add(&mut self, id: usize, vec: &SparseVec) {
215        self.vectors.insert(id, vec.clone());
216    }
217
218    fn finalize(&mut self) {
219        if self.config.hierarchical {
220            self.build_hierarchy();
221        }
222    }
223
224    fn query_top_k(&self, query: &SparseVec, k: usize) -> Vec<SearchResult> {
225        if !self.config.hierarchical || self.clusters.is_empty() {
226            // Fall back to brute force
227            let mut results: Vec<SearchResult> = self
228                .vectors
229                .iter()
230                .map(|(id, vec)| {
231                    let score = (query.cosine(vec) * 1000.0) as i32;
232                    SearchResult { id: *id, score }
233                })
234                .collect();
235
236            results.sort_by(|a, b| b.score.cmp(&a.score).then_with(|| a.id.cmp(&b.id)));
237            results.truncate(k);
238            return results;
239        }
240
241        // Hierarchical search: find best clusters first
242        let beam_width = k.max(10);
243        let mut candidate_ids: Vec<usize> = Vec::new();
244
245        // Use configured metric for cluster scoring
246        let metric = self.config.metric;
247        if let Some(top_level_clusters) = self.clusters.first() {
248            let mut cluster_scores: Vec<(usize, f64)> = top_level_clusters
249                .iter()
250                .enumerate()
251                .map(|(idx, center)| (idx, compute_similarity(query, center, metric)))
252                .collect();
253
254            cluster_scores
255                .sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
256
257            // Get candidates from top clusters
258            for (cluster_id, _score) in cluster_scores.iter().take(beam_width) {
259                if let Some(level_members) = self.cluster_members.first() {
260                    if let Some(members) = level_members.get(*cluster_id) {
261                        candidate_ids.extend(members);
262                    }
263                }
264            }
265        }
266
267        // Score all candidates using configured metric
268        let metric = self.config.metric;
269        let mut results: Vec<SearchResult> = candidate_ids
270            .into_iter()
271            .filter_map(|id| {
272                self.vectors.get(&id).map(|vec| {
273                    let score = (compute_similarity(query, vec, metric) * 1000.0) as i32;
274                    SearchResult { id, score }
275                })
276            })
277            .collect();
278
279        results.sort_by(|a, b| b.score.cmp(&a.score).then_with(|| a.id.cmp(&b.id)));
280        results.truncate(k);
281        results
282    }
283
284    fn query_top_k_reranked(
285        &self,
286        query: &SparseVec,
287        _vectors: &HashMap<usize, SparseVec>,
288        candidate_k: usize,
289        k: usize,
290    ) -> Vec<RerankedResult> {
291        let candidates = self.query_top_k(query, candidate_k);
292
293        // Rerank using configured metric
294        let metric = self.config.metric;
295        let mut results: Vec<RerankedResult> = candidates
296            .into_iter()
297            .filter_map(|cand| {
298                self.vectors.get(&cand.id).map(|vec| RerankedResult {
299                    id: cand.id,
300                    approx_score: cand.score,
301                    cosine: compute_similarity(query, vec, metric),
302                })
303            })
304            .collect();
305
306        results.sort_by(|a, b| {
307            b.cosine
308                .partial_cmp(&a.cosine)
309                .unwrap_or(std::cmp::Ordering::Equal)
310                .then_with(|| a.id.cmp(&b.id))
311        });
312        results.truncate(k);
313        results
314    }
315}
316
#[cfg(test)]
mod tests {
    use super::*;
    use embeddenator_vsa::ReversibleVSAConfig;

    /// Querying with the same bytes as a stored entry must rank that entry first.
    #[test]
    fn test_brute_force_index() {
        let vsa_config = ReversibleVSAConfig::default();
        let mut index = BruteForceIndex::new(IndexConfig::default());

        let fruits: [(usize, &[u8]); 3] = [(1, b"apple"), (2, b"banana"), (3, b"cherry")];
        for &(id, word) in &fruits {
            let encoded = SparseVec::encode_data(word, &vsa_config, None);
            index.add(id, &encoded);
        }
        index.finalize();

        let probe = SparseVec::encode_data(b"apple", &vsa_config, None);
        let hits = index.query_top_k(&probe, 2);

        assert!(!hits.is_empty());
        assert_eq!(hits[0].id, 1); // Should match apple best
    }

    /// A finalized hierarchical index must return at least one result.
    #[test]
    fn test_hierarchical_index() {
        let vsa_config = ReversibleVSAConfig::default();
        let mut index = HierarchicalIndex::new(IndexConfig {
            hierarchical: true,
            ..IndexConfig::default()
        });

        // Populate with twenty distinct documents.
        for doc_id in 0..20usize {
            let text = format!("doc-{}", doc_id);
            let encoded = SparseVec::encode_data(text.as_bytes(), &vsa_config, None);
            index.add(doc_id, &encoded);
        }
        index.finalize();

        let probe = SparseVec::encode_data(b"doc-5", &vsa_config, None);
        let hits = index.query_top_k(&probe, 5);

        assert!(!hits.is_empty());
    }
}