Skip to main content

embeddenator_retrieval/
search.rs

1//! Search strategies for semantic retrieval
2//!
3//! This module implements various search algorithms:
4//! - Exact search (brute force)
5//! - Approximate search (inverted index)
6//! - Beam search (hierarchical)
7//! - Two-stage search (candidate generation + reranking)
8
9use crate::retrieval::{SearchResult, TernaryInvertedIndex};
10use crate::similarity::{compute_similarity, SimilarityMetric};
11use embeddenator_vsa::SparseVec;
12use std::collections::HashMap;
13
14/// Search strategy configuration
15#[derive(Debug, Clone)]
16pub struct SearchConfig {
17    /// Similarity metric for final ranking
18    pub metric: SimilarityMetric,
19    /// Number of candidates to generate before reranking
20    pub candidate_k: usize,
21    /// Beam width for hierarchical search
22    pub beam_width: usize,
23    /// Enable parallel search
24    pub parallel: bool,
25}
26
27impl Default for SearchConfig {
28    fn default() -> Self {
29        Self {
30            metric: SimilarityMetric::Cosine,
31            candidate_k: 200,
32            beam_width: 10,
33            parallel: false,
34        }
35    }
36}
37
/// A single ranked hit together with the scores that produced it.
#[derive(Debug, Clone, PartialEq)]
pub struct RankedResult {
    /// Identifier of the matched document.
    pub id: usize,
    /// Similarity score used for the final ordering.
    pub score: f64,
    /// First-stage score. Two-stage search copies the index score here;
    /// exact search stores `score * 1000` converted to `i32`.
    pub approx_score: i32,
    /// Position in the returned list, starting at 1.
    pub rank: usize,
}
50
51/// Two-stage search: fast candidate generation + accurate reranking
52///
53/// This is the recommended strategy for most use cases. It combines the
54/// speed of inverted index search with the accuracy of exact similarity.
55///
56/// # Arguments
57/// * `query` - Query vector
58/// * `index` - Inverted index for candidate generation
59/// * `vectors` - Full vector collection for reranking
60/// * `config` - Search configuration
61/// * `k` - Number of final results to return
62///
63/// # Returns
64/// Top-k results ranked by exact similarity
65///
66/// # Examples
67///
68/// ```
69/// use embeddenator_retrieval::search::{two_stage_search, SearchConfig};
70/// use embeddenator_retrieval::TernaryInvertedIndex;
71/// use embeddenator_vsa::SparseVec;
72/// use std::collections::HashMap;
73///
74/// let mut index = TernaryInvertedIndex::new();
75/// let mut vectors = HashMap::new();
76///
77/// let vec1 = SparseVec::from_data(b"document one");
78/// let vec2 = SparseVec::from_data(b"document two");
79///
80/// index.add(1, &vec1);
81/// index.add(2, &vec2);
82/// index.finalize();
83///
84/// vectors.insert(1, vec1);
85/// vectors.insert(2, vec2);
86///
87/// let query = SparseVec::from_data(b"document");
88/// let config = SearchConfig::default();
89/// let results = two_stage_search(&query, &index, &vectors, &config, 5);
90///
91/// assert!(!results.is_empty());
92/// ```
93pub fn two_stage_search(
94    query: &SparseVec,
95    index: &TernaryInvertedIndex,
96    vectors: &HashMap<usize, SparseVec>,
97    config: &SearchConfig,
98    k: usize,
99) -> Vec<RankedResult> {
100    if k == 0 {
101        return Vec::new();
102    }
103
104    // Stage 1: Generate candidates using inverted index
105    let candidate_k = config.candidate_k.max(k);
106    let candidates = index.query_top_k(query, candidate_k);
107
108    // Stage 2: Rerank candidates with exact similarity
109    let mut reranked: Vec<RankedResult> = candidates
110        .iter()
111        .filter_map(|cand| {
112            vectors.get(&cand.id).map(|vec| {
113                let score = compute_similarity(query, vec, config.metric);
114                RankedResult {
115                    id: cand.id,
116                    score,
117                    approx_score: cand.score,
118                    rank: 0, // Will be set after sorting
119                }
120            })
121        })
122        .collect();
123
124    // Sort by similarity score
125    reranked.sort_by(|a, b| {
126        b.score
127            .partial_cmp(&a.score)
128            .unwrap_or(std::cmp::Ordering::Equal)
129            .then_with(|| a.id.cmp(&b.id))
130    });
131
132    // Assign ranks and truncate
133    reranked.truncate(k);
134    for (idx, result) in reranked.iter_mut().enumerate() {
135        result.rank = idx + 1;
136    }
137
138    reranked
139}
140
141/// Exact search using brute force comparison
142///
143/// Computes similarity against all vectors in the collection.
144/// Use for small collections or ground truth evaluation.
145///
146/// # Arguments
147/// * `query` - Query vector
148/// * `vectors` - Vector collection
149/// * `metric` - Similarity metric to use
150/// * `k` - Number of results to return
151///
152/// # Returns
153/// Top-k results ranked by similarity
154///
155/// # Examples
156///
157/// ```
158/// use embeddenator_retrieval::search::{exact_search};
159/// use embeddenator_retrieval::similarity::SimilarityMetric;
160/// use embeddenator_vsa::SparseVec;
161/// use std::collections::HashMap;
162///
163/// let mut vectors = HashMap::new();
164/// vectors.insert(1, SparseVec::from_data(b"document one"));
165/// vectors.insert(2, SparseVec::from_data(b"document two"));
166///
167/// let query = SparseVec::from_data(b"document");
168/// let results = exact_search(&query, &vectors, SimilarityMetric::Cosine, 5);
169///
170/// assert!(!results.is_empty());
171/// ```
172pub fn exact_search(
173    query: &SparseVec,
174    vectors: &HashMap<usize, SparseVec>,
175    metric: SimilarityMetric,
176    k: usize,
177) -> Vec<RankedResult> {
178    if k == 0 || vectors.is_empty() {
179        return Vec::new();
180    }
181
182    let mut results: Vec<RankedResult> = vectors
183        .iter()
184        .map(|(id, vec)| {
185            let score = compute_similarity(query, vec, metric);
186            RankedResult {
187                id: *id,
188                score,
189                approx_score: (score * 1000.0) as i32,
190                rank: 0,
191            }
192        })
193        .collect();
194
195    results.sort_by(|a, b| {
196        b.score
197            .partial_cmp(&a.score)
198            .unwrap_or(std::cmp::Ordering::Equal)
199            .then_with(|| a.id.cmp(&b.id))
200    });
201
202    results.truncate(k);
203    for (idx, result) in results.iter_mut().enumerate() {
204        result.rank = idx + 1;
205    }
206
207    results
208}
209
210/// Approximate search using only the inverted index
211///
212/// Fast but less accurate. Good for initial filtering or when
213/// speed is more important than perfect ranking.
214///
215/// # Arguments
216/// * `query` - Query vector
217/// * `index` - Inverted index
218/// * `k` - Number of results to return
219///
220/// # Returns
221/// Top-k results ranked by approximate score
222///
223/// # Examples
224///
225/// ```
226/// use embeddenator_retrieval::search::approximate_search;
227/// use embeddenator_retrieval::TernaryInvertedIndex;
228/// use embeddenator_vsa::SparseVec;
229///
230/// let mut index = TernaryInvertedIndex::new();
231/// let vec1 = SparseVec::from_data(b"document one");
232/// index.add(1, &vec1);
233/// index.finalize();
234///
235/// let query = SparseVec::from_data(b"document");
236/// let results = approximate_search(&query, &index, 5);
237///
238/// assert!(!results.is_empty());
239/// ```
240pub fn approximate_search(
241    query: &SparseVec,
242    index: &TernaryInvertedIndex,
243    k: usize,
244) -> Vec<SearchResult> {
245    index.query_top_k(query, k)
246}
247
248/// Batch search - process multiple queries efficiently
249///
250/// # Arguments
251/// * `queries` - Multiple query vectors
252/// * `index` - Inverted index
253/// * `vectors` - Vector collection
254/// * `config` - Search configuration
255/// * `k` - Number of results per query
256///
257/// # Returns
258/// Results for each query
259///
260/// # Examples
261///
262/// ```
263/// use embeddenator_retrieval::search::{batch_search, SearchConfig};
264/// use embeddenator_retrieval::TernaryInvertedIndex;
265/// use embeddenator_vsa::SparseVec;
266/// use std::collections::HashMap;
267///
268/// let mut index = TernaryInvertedIndex::new();
269/// let mut vectors = HashMap::new();
270/// let vec1 = SparseVec::from_data(b"doc one");
271/// index.add(1, &vec1);
272/// index.finalize();
273/// vectors.insert(1, vec1);
274///
275/// let queries = vec![
276///     SparseVec::from_data(b"query1"),
277///     SparseVec::from_data(b"query2"),
278/// ];
279/// let config = SearchConfig::default();
280/// let results = batch_search(&queries, &index, &vectors, &config, 5);
281///
282/// assert_eq!(results.len(), 2);
283/// ```
284pub fn batch_search(
285    queries: &[SparseVec],
286    index: &TernaryInvertedIndex,
287    vectors: &HashMap<usize, SparseVec>,
288    config: &SearchConfig,
289    k: usize,
290) -> Vec<Vec<RankedResult>> {
291    queries
292        .iter()
293        .map(|query| two_stage_search(query, index, vectors, config, k))
294        .collect()
295}
296
297/// Compute recall@k metric for search quality evaluation
298///
299/// Compares approximate search results against ground truth.
300///
301/// # Arguments
302/// * `approx_results` - Results from approximate search
303/// * `exact_results` - Ground truth from exact search
304/// * `k` - Number of top results to consider
305///
306/// # Returns
307/// Recall score in [0, 1]
308pub fn compute_recall_at_k(
309    approx_results: &[SearchResult],
310    exact_results: &[RankedResult],
311    k: usize,
312) -> f64 {
313    if k == 0 || exact_results.is_empty() {
314        return 0.0;
315    }
316
317    let exact_ids: std::collections::HashSet<usize> =
318        exact_results.iter().take(k).map(|r| r.id).collect();
319
320    let matches = approx_results
321        .iter()
322        .take(k)
323        .filter(|r| exact_ids.contains(&r.id))
324        .count();
325
326    matches as f64 / k.min(exact_results.len()) as f64
327}
328
#[cfg(test)]
mod tests {
    use super::*;
    use embeddenator_vsa::ReversibleVSAConfig;

    /// Two indexed documents: the top hit of a two-stage search has rank 1.
    #[test]
    fn test_two_stage_search() {
        let vsa = ReversibleVSAConfig::default();
        let docs = vec![
            (1, SparseVec::encode_data(b"hello world", &vsa, None)),
            (2, SparseVec::encode_data(b"goodbye world", &vsa, None)),
        ];

        let mut index = TernaryInvertedIndex::new();
        for (id, vec) in &docs {
            index.add(*id, vec);
        }
        index.finalize();

        let vectors: HashMap<usize, SparseVec> = docs.into_iter().collect();

        let query = SparseVec::encode_data(b"hello", &vsa, None);
        let results = two_stage_search(&query, &index, &vectors, &SearchConfig::default(), 2);

        assert!(!results.is_empty());
        assert_eq!(results[0].rank, 1);
    }

    /// Querying with the same bytes as a stored document ranks it first.
    #[test]
    fn test_exact_search() {
        let vsa = ReversibleVSAConfig::default();
        let vectors: HashMap<usize, SparseVec> = vec![
            (1, SparseVec::encode_data(b"apple", &vsa, None)),
            (2, SparseVec::encode_data(b"banana", &vsa, None)),
            (3, SparseVec::encode_data(b"cherry", &vsa, None)),
        ]
        .into_iter()
        .collect();

        let query = SparseVec::encode_data(b"apple", &vsa, None);
        let results = exact_search(&query, &vectors, SimilarityMetric::Cosine, 3);

        assert_eq!(results.len(), 3);
        assert_eq!(results[0].id, 1); // Should match apple best
    }

    /// Batch search yields one result list per query.
    #[test]
    fn test_batch_search() {
        let vsa = ReversibleVSAConfig::default();
        let docs = vec![
            (1, SparseVec::encode_data(b"doc1", &vsa, None)),
            (2, SparseVec::encode_data(b"doc2", &vsa, None)),
        ];

        let mut index = TernaryInvertedIndex::new();
        for (id, vec) in &docs {
            index.add(*id, vec);
        }
        index.finalize();

        let vectors: HashMap<usize, SparseVec> = docs.into_iter().collect();

        let queries = vec![
            SparseVec::encode_data(b"query1", &vsa, None),
            SparseVec::encode_data(b"query2", &vsa, None),
        ];

        let results = batch_search(&queries, &index, &vectors, &SearchConfig::default(), 2);

        assert_eq!(results.len(), 2);
    }

    /// Two of the three ground-truth ids are retrieved, so recall@3 is 2/3.
    #[test]
    fn test_recall_computation() {
        let approx: Vec<SearchResult> = [(1, 100), (2, 90), (5, 80)]
            .iter()
            .map(|&(id, score)| SearchResult { id, score })
            .collect();

        let exact: Vec<RankedResult> = [(1, 0.95, 100), (3, 0.90, 95), (2, 0.85, 90)]
            .iter()
            .enumerate()
            .map(|(pos, &(id, score, approx_score))| RankedResult {
                id,
                score,
                approx_score,
                rank: pos + 1,
            })
            .collect();

        let recall = compute_recall_at_k(&approx, &exact, 3);
        assert!((recall - 0.666).abs() < 0.01); // 2/3 match
    }
}