Skip to main content

embeddenator_retrieval/
search.rs

1//! Search strategies for semantic retrieval
2//!
3//! This module implements various search algorithms:
4//! - Exact search (brute force)
5//! - Approximate search (inverted index)
6//! - Beam search (hierarchical)
7//! - Two-stage search (candidate generation + reranking)
8//!
9//! All search functions support parallel execution via `SearchConfig::parallel`.
10
11use std::collections::HashMap;
12
13use rayon::prelude::*;
14
15use crate::retrieval::{SearchResult, TernaryInvertedIndex};
16use crate::similarity::{compute_similarity, SimilarityMetric};
17use embeddenator_vsa::SparseVec;
18
19/// Search strategy configuration
20#[derive(Debug, Clone)]
21pub struct SearchConfig {
22    /// Similarity metric for final ranking
23    pub metric: SimilarityMetric,
24    /// Number of candidates to generate before reranking
25    pub candidate_k: usize,
26    /// Beam width for hierarchical search.
27    ///
28    /// **Note:** This field is reserved for future hierarchical/beam search
29    /// implementations. It is not currently used by the search functions in
30    /// this module. Setting this value has no effect on search behavior.
31    pub beam_width: usize,
32    /// Enable parallel search
33    pub parallel: bool,
34}
35
36impl Default for SearchConfig {
37    fn default() -> Self {
38        Self {
39            metric: SimilarityMetric::Cosine,
40            candidate_k: 200,
41            beam_width: 10,
42            parallel: false,
43        }
44    }
45}
46
/// One hit produced by the ranking functions in this module.
///
/// Lists of these are returned best-first, and `rank` records each entry's
/// position in that list.
#[derive(Debug, Clone, PartialEq)]
pub struct RankedResult {
    /// Identifier of the matched document (its key in the vector collection).
    pub id: usize,
    /// Exact similarity between the query and the document under the chosen metric.
    pub score: f64,
    /// Integer score carried over from the approximate stage.
    ///
    /// [`two_stage_search`] stores the inverted-index score of the candidate here.
    /// [`exact_search`] has no approximate stage and stores `score * 1000.0`
    /// converted with `as i32` instead.
    pub approx_score: i32,
    /// 1-based position of this result within the returned list.
    pub rank: usize,
}
59
60/// Two-stage search: fast candidate generation + accurate reranking
61///
62/// This is the recommended strategy for most use cases. It combines the
63/// speed of inverted index search with the accuracy of exact similarity.
64///
65/// # Arguments
66/// * `query` - Query vector
67/// * `index` - Inverted index for candidate generation
68/// * `vectors` - Full vector collection for reranking
69/// * `config` - Search configuration
70/// * `k` - Number of final results to return
71///
72/// # Returns
73/// Top-k results ranked by exact similarity
74///
75/// # Examples
76///
77/// ```
78/// use embeddenator_retrieval::search::{two_stage_search, SearchConfig};
79/// use embeddenator_retrieval::TernaryInvertedIndex;
80/// use embeddenator_vsa::SparseVec;
81/// use std::collections::HashMap;
82///
83/// let mut index = TernaryInvertedIndex::new();
84/// let mut vectors = HashMap::new();
85///
86/// let vec1 = SparseVec::from_data(b"document one");
87/// let vec2 = SparseVec::from_data(b"document two");
88///
89/// index.add(1, &vec1);
90/// index.add(2, &vec2);
91/// index.finalize();
92///
93/// vectors.insert(1, vec1);
94/// vectors.insert(2, vec2);
95///
96/// let query = SparseVec::from_data(b"document");
97/// let config = SearchConfig::default();
98/// let results = two_stage_search(&query, &index, &vectors, &config, 5);
99///
100/// assert!(!results.is_empty());
101/// ```
102pub fn two_stage_search(
103    query: &SparseVec,
104    index: &TernaryInvertedIndex,
105    vectors: &HashMap<usize, SparseVec>,
106    config: &SearchConfig,
107    k: usize,
108) -> Vec<RankedResult> {
109    if k == 0 {
110        return Vec::new();
111    }
112
113    // Stage 1: Generate candidates using inverted index
114    let candidate_k = config.candidate_k.max(k);
115    let candidates = index.query_top_k(query, candidate_k);
116
117    // Stage 2: Rerank candidates with exact similarity
118    // Use parallel iteration when enabled for compute-intensive similarity calculations
119    let mut reranked: Vec<RankedResult> = if config.parallel {
120        // Collect candidates with their vectors first to enable parallel processing
121        let candidates_with_vecs: Vec<_> = candidates
122            .iter()
123            .filter_map(|cand| vectors.get(&cand.id).map(|vec| (cand, vec)))
124            .collect();
125
126        candidates_with_vecs
127            .par_iter()
128            .map(|(cand, vec)| {
129                let score = compute_similarity(query, vec, config.metric);
130                RankedResult {
131                    id: cand.id,
132                    score,
133                    approx_score: cand.score,
134                    rank: 0, // Will be set after sorting
135                }
136            })
137            .collect()
138    } else {
139        candidates
140            .iter()
141            .filter_map(|cand| {
142                vectors.get(&cand.id).map(|vec| {
143                    let score = compute_similarity(query, vec, config.metric);
144                    RankedResult {
145                        id: cand.id,
146                        score,
147                        approx_score: cand.score,
148                        rank: 0, // Will be set after sorting
149                    }
150                })
151            })
152            .collect()
153    };
154
155    // Sort by similarity score
156    reranked.sort_by(|a, b| {
157        b.score
158            .partial_cmp(&a.score)
159            .unwrap_or(std::cmp::Ordering::Equal)
160            .then_with(|| a.id.cmp(&b.id))
161    });
162
163    // Assign ranks and truncate
164    reranked.truncate(k);
165    for (idx, result) in reranked.iter_mut().enumerate() {
166        result.rank = idx + 1;
167    }
168
169    reranked
170}
171
172/// Exact search using brute force comparison
173///
174/// Computes similarity against all vectors in the collection.
175/// Use for small collections or ground truth evaluation.
176///
177/// # Arguments
178/// * `query` - Query vector
179/// * `vectors` - Vector collection
180/// * `metric` - Similarity metric to use
181/// * `k` - Number of results to return
182///
183/// # Returns
184/// Top-k results ranked by similarity
185///
186/// # Examples
187///
188/// ```
189/// use embeddenator_retrieval::search::{exact_search};
190/// use embeddenator_retrieval::similarity::SimilarityMetric;
191/// use embeddenator_vsa::SparseVec;
192/// use std::collections::HashMap;
193///
194/// let mut vectors = HashMap::new();
195/// vectors.insert(1, SparseVec::from_data(b"document one"));
196/// vectors.insert(2, SparseVec::from_data(b"document two"));
197///
198/// let query = SparseVec::from_data(b"document");
199/// let results = exact_search(&query, &vectors, SimilarityMetric::Cosine, 5);
200///
201/// assert!(!results.is_empty());
202/// ```
203pub fn exact_search(
204    query: &SparseVec,
205    vectors: &HashMap<usize, SparseVec>,
206    metric: SimilarityMetric,
207    k: usize,
208) -> Vec<RankedResult> {
209    exact_search_impl(query, vectors, metric, k, false)
210}
211
212/// Exact search with parallel option
213///
214/// Same as `exact_search` but allows enabling parallel processing
215/// for large vector collections.
216pub fn exact_search_parallel(
217    query: &SparseVec,
218    vectors: &HashMap<usize, SparseVec>,
219    metric: SimilarityMetric,
220    k: usize,
221    parallel: bool,
222) -> Vec<RankedResult> {
223    exact_search_impl(query, vectors, metric, k, parallel)
224}
225
226fn exact_search_impl(
227    query: &SparseVec,
228    vectors: &HashMap<usize, SparseVec>,
229    metric: SimilarityMetric,
230    k: usize,
231    parallel: bool,
232) -> Vec<RankedResult> {
233    if k == 0 || vectors.is_empty() {
234        return Vec::new();
235    }
236
237    let mut results: Vec<RankedResult> = if parallel {
238        // Collect to vec first for parallel iteration
239        let vec_entries: Vec<_> = vectors.iter().collect();
240        vec_entries
241            .par_iter()
242            .map(|(id, vec)| {
243                let score = compute_similarity(query, vec, metric);
244                RankedResult {
245                    id: **id,
246                    score,
247                    approx_score: (score * 1000.0) as i32,
248                    rank: 0,
249                }
250            })
251            .collect()
252    } else {
253        vectors
254            .iter()
255            .map(|(id, vec)| {
256                let score = compute_similarity(query, vec, metric);
257                RankedResult {
258                    id: *id,
259                    score,
260                    approx_score: (score * 1000.0) as i32,
261                    rank: 0,
262                }
263            })
264            .collect()
265    };
266
267    results.sort_by(|a, b| {
268        b.score
269            .partial_cmp(&a.score)
270            .unwrap_or(std::cmp::Ordering::Equal)
271            .then_with(|| a.id.cmp(&b.id))
272    });
273
274    results.truncate(k);
275    for (idx, result) in results.iter_mut().enumerate() {
276        result.rank = idx + 1;
277    }
278
279    results
280}
281
282/// Approximate search using only the inverted index
283///
284/// Fast but less accurate. Good for initial filtering or when
285/// speed is more important than perfect ranking.
286///
287/// # Arguments
288/// * `query` - Query vector
289/// * `index` - Inverted index
290/// * `k` - Number of results to return
291///
292/// # Returns
293/// Top-k results ranked by approximate score
294///
295/// # Examples
296///
297/// ```
298/// use embeddenator_retrieval::search::approximate_search;
299/// use embeddenator_retrieval::TernaryInvertedIndex;
300/// use embeddenator_vsa::SparseVec;
301///
302/// let mut index = TernaryInvertedIndex::new();
303/// let vec1 = SparseVec::from_data(b"document one");
304/// index.add(1, &vec1);
305/// index.finalize();
306///
307/// let query = SparseVec::from_data(b"document");
308/// let results = approximate_search(&query, &index, 5);
309///
310/// assert!(!results.is_empty());
311/// ```
312pub fn approximate_search(
313    query: &SparseVec,
314    index: &TernaryInvertedIndex,
315    k: usize,
316) -> Vec<SearchResult> {
317    index.query_top_k(query, k)
318}
319
320/// Batch search - process multiple queries efficiently
321///
322/// # Arguments
323/// * `queries` - Multiple query vectors
324/// * `index` - Inverted index
325/// * `vectors` - Vector collection
326/// * `config` - Search configuration
327/// * `k` - Number of results per query
328///
329/// # Returns
330/// Results for each query
331///
332/// # Examples
333///
334/// ```
335/// use embeddenator_retrieval::search::{batch_search, SearchConfig};
336/// use embeddenator_retrieval::TernaryInvertedIndex;
337/// use embeddenator_vsa::SparseVec;
338/// use std::collections::HashMap;
339///
340/// let mut index = TernaryInvertedIndex::new();
341/// let mut vectors = HashMap::new();
342/// let vec1 = SparseVec::from_data(b"doc one");
343/// index.add(1, &vec1);
344/// index.finalize();
345/// vectors.insert(1, vec1);
346///
347/// let queries = vec![
348///     SparseVec::from_data(b"query1"),
349///     SparseVec::from_data(b"query2"),
350/// ];
351/// let config = SearchConfig::default();
352/// let results = batch_search(&queries, &index, &vectors, &config, 5);
353///
354/// assert_eq!(results.len(), 2);
355/// ```
356pub fn batch_search(
357    queries: &[SparseVec],
358    index: &TernaryInvertedIndex,
359    vectors: &HashMap<usize, SparseVec>,
360    config: &SearchConfig,
361    k: usize,
362) -> Vec<Vec<RankedResult>> {
363    if config.parallel {
364        // Process multiple queries concurrently
365        queries
366            .par_iter()
367            .map(|query| two_stage_search(query, index, vectors, config, k))
368            .collect()
369    } else {
370        queries
371            .iter()
372            .map(|query| two_stage_search(query, index, vectors, config, k))
373            .collect()
374    }
375}
376
377/// Compute recall@k metric for search quality evaluation
378///
379/// Compares approximate search results against ground truth.
380///
381/// # Arguments
382/// * `approx_results` - Results from approximate search
383/// * `exact_results` - Ground truth from exact search
384/// * `k` - Number of top results to consider
385///
386/// # Returns
387/// Recall score in [0, 1]
388pub fn compute_recall_at_k(
389    approx_results: &[SearchResult],
390    exact_results: &[RankedResult],
391    k: usize,
392) -> f64 {
393    if k == 0 || exact_results.is_empty() {
394        return 0.0;
395    }
396
397    let exact_ids: std::collections::HashSet<usize> =
398        exact_results.iter().take(k).map(|r| r.id).collect();
399
400    let matches = approx_results
401        .iter()
402        .take(k)
403        .filter(|r| exact_ids.contains(&r.id))
404        .count();
405
406    matches as f64 / k.min(exact_results.len()) as f64
407}
408
#[cfg(test)]
mod tests {
    use super::*;
    use embeddenator_vsa::ReversibleVSAConfig;

    /// Encodes every `(id, text)` pair and adds it to a fresh inverted index
    /// and to an id -> vector table, in iteration order. Finalizes the index
    /// before returning both.
    fn build_corpus<I>(
        vsa: &ReversibleVSAConfig,
        docs: I,
    ) -> (TernaryInvertedIndex, HashMap<usize, SparseVec>)
    where
        I: IntoIterator<Item = (usize, String)>,
    {
        let mut index = TernaryInvertedIndex::new();
        let mut table = HashMap::new();
        for (id, text) in docs {
            let encoded = SparseVec::encode_data(text.as_bytes(), vsa, None);
            index.add(id, &encoded);
            table.insert(id, encoded);
        }
        index.finalize();
        (index, table)
    }

    /// Asserts two result lists agree on length, ids and (within 1e-10) scores.
    fn assert_same_hits(expected: &[RankedResult], actual: &[RankedResult]) {
        assert_eq!(expected.len(), actual.len());
        for (e, a) in expected.iter().zip(actual) {
            assert_eq!(e.id, a.id);
            assert!((e.score - a.score).abs() < 1e-10);
        }
    }

    /// Returns the default search configuration as a `(sequential, parallel)` pair.
    fn seq_and_par_configs() -> (SearchConfig, SearchConfig) {
        let sequential = SearchConfig {
            parallel: false,
            ..SearchConfig::default()
        };
        let parallel = SearchConfig {
            parallel: true,
            ..SearchConfig::default()
        };
        (sequential, parallel)
    }

    #[test]
    fn test_two_stage_search() {
        let vsa = ReversibleVSAConfig::default();
        let docs = [(1, "hello world"), (2, "goodbye world")]
            .iter()
            .map(|&(id, text)| (id, text.to_string()));
        let (index, table) = build_corpus(&vsa, docs);

        let probe = SparseVec::encode_data(b"hello", &vsa, None);
        let hits = two_stage_search(&probe, &index, &table, &SearchConfig::default(), 2);

        assert!(!hits.is_empty());
        assert_eq!(hits[0].rank, 1);
    }

    #[test]
    fn test_exact_search() {
        let vsa = ReversibleVSAConfig::default();
        let table: HashMap<usize, SparseVec> = [(1, "apple"), (2, "banana"), (3, "cherry")]
            .iter()
            .map(|&(id, fruit)| (id, SparseVec::encode_data(fruit.as_bytes(), &vsa, None)))
            .collect();

        let probe = SparseVec::encode_data(b"apple", &vsa, None);
        let hits = exact_search(&probe, &table, SimilarityMetric::Cosine, 3);

        assert_eq!(hits.len(), 3);
        assert_eq!(hits[0].id, 1); // "apple" should be its own best match
    }

    #[test]
    fn test_batch_search() {
        let vsa = ReversibleVSAConfig::default();
        let docs = [(1, "doc1"), (2, "doc2")]
            .iter()
            .map(|&(id, text)| (id, text.to_string()));
        let (index, table) = build_corpus(&vsa, docs);

        let probes: Vec<SparseVec> = ["query1", "query2"]
            .iter()
            .map(|text| SparseVec::encode_data(text.as_bytes(), &vsa, None))
            .collect();

        let batches = batch_search(&probes, &index, &table, &SearchConfig::default(), 2);

        assert_eq!(batches.len(), 2);
    }

    #[test]
    fn test_recall_computation() {
        let approx: Vec<SearchResult> = [(1, 100), (2, 90), (5, 80)]
            .iter()
            .map(|&(id, score)| SearchResult { id, score })
            .collect();

        let exact: Vec<RankedResult> = [(1, 0.95, 100), (3, 0.90, 95), (2, 0.85, 90)]
            .iter()
            .enumerate()
            .map(|(position, &(id, score, approx_score))| RankedResult {
                id,
                score,
                approx_score,
                rank: position + 1,
            })
            .collect();

        let recall = compute_recall_at_k(&approx, &exact, 3);
        assert!((recall - 0.666).abs() < 0.01); // ids 1 and 2 of {1, 3, 2} are found
    }

    #[test]
    fn test_parallel_two_stage_search_matches_sequential() {
        let vsa = ReversibleVSAConfig::default();
        // 50 documents give the parallel path a meaningful amount of work.
        let docs = (0..50).map(|n| (n, format!("document number {} with some content", n)));
        let (index, table) = build_corpus(&vsa, docs);

        let probe = SparseVec::encode_data(b"document number 25", &vsa, None);
        let (sequential, parallel) = seq_and_par_configs();

        let expected = two_stage_search(&probe, &index, &table, &sequential, 10);
        let actual = two_stage_search(&probe, &index, &table, &parallel, 10);

        assert_same_hits(&expected, &actual);
        for (e, a) in expected.iter().zip(&actual) {
            assert_eq!(e.rank, a.rank);
        }
    }

    #[test]
    fn test_parallel_exact_search_matches_sequential() {
        let vsa = ReversibleVSAConfig::default();
        let table: HashMap<usize, SparseVec> = (0..100)
            .map(|n| {
                let text = format!("item {} for testing parallel exact search", n);
                (n, SparseVec::encode_data(text.as_bytes(), &vsa, None))
            })
            .collect();

        let probe = SparseVec::encode_data(b"item 50 for testing", &vsa, None);

        let expected = exact_search_parallel(&probe, &table, SimilarityMetric::Cosine, 20, false);
        let actual = exact_search_parallel(&probe, &table, SimilarityMetric::Cosine, 20, true);

        assert_same_hits(&expected, &actual);
    }

    #[test]
    fn test_parallel_batch_search_matches_sequential() {
        let vsa = ReversibleVSAConfig::default();
        let docs = (0..30).map(|n| (n, format!("batch doc {}", n)));
        let (index, table) = build_corpus(&vsa, docs);

        let probes: Vec<SparseVec> = (0..10)
            .map(|n| {
                let text = format!("query {}", n);
                SparseVec::encode_data(text.as_bytes(), &vsa, None)
            })
            .collect();

        let (sequential, parallel) = seq_and_par_configs();

        let expected = batch_search(&probes, &index, &table, &sequential, 5);
        let actual = batch_search(&probes, &index, &table, &parallel, 5);

        assert_eq!(expected.len(), actual.len());
        for (e_batch, a_batch) in expected.iter().zip(&actual) {
            assert_same_hits(e_batch, a_batch);
        }
    }
}