// embeddenator_retrieval/search.rs
//! Search strategies for semantic retrieval
//!
//! This module implements various search algorithms:
//! - Exact search (brute force)
//! - Approximate search (inverted index)
//! - Beam search (hierarchical)
//! - Two-stage search (candidate generation + reranking)

9use crate::retrieval::{SearchResult, TernaryInvertedIndex};
10use crate::similarity::{compute_similarity, SimilarityMetric};
11use embeddenator_vsa::SparseVec;
12use std::collections::HashMap;
13
/// Search strategy configuration.
///
/// Consumed by [`two_stage_search`] (and [`batch_search`], which delegates
/// to it) to control candidate generation and reranking.
#[derive(Debug, Clone)]
pub struct SearchConfig {
    /// Similarity metric for final ranking.
    pub metric: SimilarityMetric,
    /// Number of candidates to generate before reranking.
    /// `two_stage_search` uses `candidate_k.max(k)` so stage two always
    /// has at least `k` candidates to work with.
    pub candidate_k: usize,
    /// Beam width for hierarchical search.
    /// NOTE(review): not read by any function in this module — presumably
    /// consumed by a beam-search implementation elsewhere; confirm.
    pub beam_width: usize,
    /// Enable parallel search.
    /// NOTE(review): not read by any function in this module — confirm
    /// where (or whether) this flag takes effect.
    pub parallel: bool,
}
26
27impl Default for SearchConfig {
28 fn default() -> Self {
29 Self {
30 metric: SimilarityMetric::Cosine,
31 candidate_k: 200,
32 beam_width: 10,
33 parallel: false,
34 }
35 }
36}
37
/// Search result with additional metadata, produced by the reranking stage.
#[derive(Debug, Clone, PartialEq)]
pub struct RankedResult {
    /// Document ID.
    pub id: usize,
    /// Final similarity score, computed with the configured exact metric.
    pub score: f64,
    /// Approximate score carried over from the first (inverted-index) stage.
    /// In `exact_search` this is synthesized as `(score * 1000.0) as i32`.
    pub approx_score: i32,
    /// Rank in results (1-indexed); 0 only transiently, before ranks are
    /// assigned after sorting.
    pub rank: usize,
}
50
51/// Two-stage search: fast candidate generation + accurate reranking
52///
53/// This is the recommended strategy for most use cases. It combines the
54/// speed of inverted index search with the accuracy of exact similarity.
55///
56/// # Arguments
57/// * `query` - Query vector
58/// * `index` - Inverted index for candidate generation
59/// * `vectors` - Full vector collection for reranking
60/// * `config` - Search configuration
61/// * `k` - Number of final results to return
62///
63/// # Returns
64/// Top-k results ranked by exact similarity
65///
66/// # Examples
67///
68/// ```
69/// use embeddenator_retrieval::search::{two_stage_search, SearchConfig};
70/// use embeddenator_retrieval::TernaryInvertedIndex;
71/// use embeddenator_vsa::SparseVec;
72/// use std::collections::HashMap;
73///
74/// let mut index = TernaryInvertedIndex::new();
75/// let mut vectors = HashMap::new();
76///
77/// let vec1 = SparseVec::from_data(b"document one");
78/// let vec2 = SparseVec::from_data(b"document two");
79///
80/// index.add(1, &vec1);
81/// index.add(2, &vec2);
82/// index.finalize();
83///
84/// vectors.insert(1, vec1);
85/// vectors.insert(2, vec2);
86///
87/// let query = SparseVec::from_data(b"document");
88/// let config = SearchConfig::default();
89/// let results = two_stage_search(&query, &index, &vectors, &config, 5);
90///
91/// assert!(!results.is_empty());
92/// ```
93pub fn two_stage_search(
94 query: &SparseVec,
95 index: &TernaryInvertedIndex,
96 vectors: &HashMap<usize, SparseVec>,
97 config: &SearchConfig,
98 k: usize,
99) -> Vec<RankedResult> {
100 if k == 0 {
101 return Vec::new();
102 }
103
104 // Stage 1: Generate candidates using inverted index
105 let candidate_k = config.candidate_k.max(k);
106 let candidates = index.query_top_k(query, candidate_k);
107
108 // Stage 2: Rerank candidates with exact similarity
109 let mut reranked: Vec<RankedResult> = candidates
110 .iter()
111 .filter_map(|cand| {
112 vectors.get(&cand.id).map(|vec| {
113 let score = compute_similarity(query, vec, config.metric);
114 RankedResult {
115 id: cand.id,
116 score,
117 approx_score: cand.score,
118 rank: 0, // Will be set after sorting
119 }
120 })
121 })
122 .collect();
123
124 // Sort by similarity score
125 reranked.sort_by(|a, b| {
126 b.score
127 .partial_cmp(&a.score)
128 .unwrap_or(std::cmp::Ordering::Equal)
129 .then_with(|| a.id.cmp(&b.id))
130 });
131
132 // Assign ranks and truncate
133 reranked.truncate(k);
134 for (idx, result) in reranked.iter_mut().enumerate() {
135 result.rank = idx + 1;
136 }
137
138 reranked
139}
140
141/// Exact search using brute force comparison
142///
143/// Computes similarity against all vectors in the collection.
144/// Use for small collections or ground truth evaluation.
145///
146/// # Arguments
147/// * `query` - Query vector
148/// * `vectors` - Vector collection
149/// * `metric` - Similarity metric to use
150/// * `k` - Number of results to return
151///
152/// # Returns
153/// Top-k results ranked by similarity
154///
155/// # Examples
156///
157/// ```
158/// use embeddenator_retrieval::search::{exact_search};
159/// use embeddenator_retrieval::similarity::SimilarityMetric;
160/// use embeddenator_vsa::SparseVec;
161/// use std::collections::HashMap;
162///
163/// let mut vectors = HashMap::new();
164/// vectors.insert(1, SparseVec::from_data(b"document one"));
165/// vectors.insert(2, SparseVec::from_data(b"document two"));
166///
167/// let query = SparseVec::from_data(b"document");
168/// let results = exact_search(&query, &vectors, SimilarityMetric::Cosine, 5);
169///
170/// assert!(!results.is_empty());
171/// ```
172pub fn exact_search(
173 query: &SparseVec,
174 vectors: &HashMap<usize, SparseVec>,
175 metric: SimilarityMetric,
176 k: usize,
177) -> Vec<RankedResult> {
178 if k == 0 || vectors.is_empty() {
179 return Vec::new();
180 }
181
182 let mut results: Vec<RankedResult> = vectors
183 .iter()
184 .map(|(id, vec)| {
185 let score = compute_similarity(query, vec, metric);
186 RankedResult {
187 id: *id,
188 score,
189 approx_score: (score * 1000.0) as i32,
190 rank: 0,
191 }
192 })
193 .collect();
194
195 results.sort_by(|a, b| {
196 b.score
197 .partial_cmp(&a.score)
198 .unwrap_or(std::cmp::Ordering::Equal)
199 .then_with(|| a.id.cmp(&b.id))
200 });
201
202 results.truncate(k);
203 for (idx, result) in results.iter_mut().enumerate() {
204 result.rank = idx + 1;
205 }
206
207 results
208}
209
/// Approximate search using only the inverted index
///
/// Fast but less accurate. Good for initial filtering or when
/// speed is more important than perfect ranking. This is a thin wrapper
/// over [`TernaryInvertedIndex::query_top_k`]; no exact reranking is done.
///
/// # Arguments
/// * `query` - Query vector
/// * `index` - Inverted index
/// * `k` - Number of results to return
///
/// # Returns
/// Top-k results ranked by approximate (index) score
///
/// # Examples
///
/// ```
/// use embeddenator_retrieval::search::approximate_search;
/// use embeddenator_retrieval::TernaryInvertedIndex;
/// use embeddenator_vsa::SparseVec;
///
/// let mut index = TernaryInvertedIndex::new();
/// let vec1 = SparseVec::from_data(b"document one");
/// index.add(1, &vec1);
/// index.finalize();
///
/// let query = SparseVec::from_data(b"document");
/// let results = approximate_search(&query, &index, 5);
///
/// assert!(!results.is_empty());
/// ```
pub fn approximate_search(
    query: &SparseVec,
    index: &TernaryInvertedIndex,
    k: usize,
) -> Vec<SearchResult> {
    index.query_top_k(query, k)
}
247
248/// Batch search - process multiple queries efficiently
249///
250/// # Arguments
251/// * `queries` - Multiple query vectors
252/// * `index` - Inverted index
253/// * `vectors` - Vector collection
254/// * `config` - Search configuration
255/// * `k` - Number of results per query
256///
257/// # Returns
258/// Results for each query
259///
260/// # Examples
261///
262/// ```
263/// use embeddenator_retrieval::search::{batch_search, SearchConfig};
264/// use embeddenator_retrieval::TernaryInvertedIndex;
265/// use embeddenator_vsa::SparseVec;
266/// use std::collections::HashMap;
267///
268/// let mut index = TernaryInvertedIndex::new();
269/// let mut vectors = HashMap::new();
270/// let vec1 = SparseVec::from_data(b"doc one");
271/// index.add(1, &vec1);
272/// index.finalize();
273/// vectors.insert(1, vec1);
274///
275/// let queries = vec![
276/// SparseVec::from_data(b"query1"),
277/// SparseVec::from_data(b"query2"),
278/// ];
279/// let config = SearchConfig::default();
280/// let results = batch_search(&queries, &index, &vectors, &config, 5);
281///
282/// assert_eq!(results.len(), 2);
283/// ```
284pub fn batch_search(
285 queries: &[SparseVec],
286 index: &TernaryInvertedIndex,
287 vectors: &HashMap<usize, SparseVec>,
288 config: &SearchConfig,
289 k: usize,
290) -> Vec<Vec<RankedResult>> {
291 queries
292 .iter()
293 .map(|query| two_stage_search(query, index, vectors, config, k))
294 .collect()
295}
296
297/// Compute recall@k metric for search quality evaluation
298///
299/// Compares approximate search results against ground truth.
300///
301/// # Arguments
302/// * `approx_results` - Results from approximate search
303/// * `exact_results` - Ground truth from exact search
304/// * `k` - Number of top results to consider
305///
306/// # Returns
307/// Recall score in [0, 1]
308pub fn compute_recall_at_k(
309 approx_results: &[SearchResult],
310 exact_results: &[RankedResult],
311 k: usize,
312) -> f64 {
313 if k == 0 || exact_results.is_empty() {
314 return 0.0;
315 }
316
317 let exact_ids: std::collections::HashSet<usize> =
318 exact_results.iter().take(k).map(|r| r.id).collect();
319
320 let matches = approx_results
321 .iter()
322 .take(k)
323 .filter(|r| exact_ids.contains(&r.id))
324 .count();
325
326 matches as f64 / k.min(exact_results.len()) as f64
327}
328
#[cfg(test)]
mod tests {
    use super::*;
    use embeddenator_vsa::ReversibleVSAConfig;

    #[test]
    fn test_two_stage_search() {
        let vsa_cfg = ReversibleVSAConfig::default();
        let mut idx = TernaryInvertedIndex::new();
        let mut store = HashMap::new();

        let hello = SparseVec::encode_data(b"hello world", &vsa_cfg, None);
        let goodbye = SparseVec::encode_data(b"goodbye world", &vsa_cfg, None);

        idx.add(1, &hello);
        idx.add(2, &goodbye);
        idx.finalize();

        store.insert(1, hello);
        store.insert(2, goodbye);

        let probe = SparseVec::encode_data(b"hello", &vsa_cfg, None);
        let hits = two_stage_search(&probe, &idx, &store, &SearchConfig::default(), 2);

        assert!(!hits.is_empty());
        assert_eq!(hits[0].rank, 1);
    }

    #[test]
    fn test_exact_search() {
        let vsa_cfg = ReversibleVSAConfig::default();
        let mut store = HashMap::new();

        for (doc_id, data) in [(1, &b"apple"[..]), (2, &b"banana"[..]), (3, &b"cherry"[..])] {
            store.insert(doc_id, SparseVec::encode_data(data, &vsa_cfg, None));
        }

        let probe = SparseVec::encode_data(b"apple", &vsa_cfg, None);
        let hits = exact_search(&probe, &store, SimilarityMetric::Cosine, 3);

        assert_eq!(hits.len(), 3);
        assert_eq!(hits[0].id, 1); // Should match apple best
    }

    #[test]
    fn test_batch_search() {
        let vsa_cfg = ReversibleVSAConfig::default();
        let mut idx = TernaryInvertedIndex::new();
        let mut store = HashMap::new();

        let first = SparseVec::encode_data(b"doc1", &vsa_cfg, None);
        let second = SparseVec::encode_data(b"doc2", &vsa_cfg, None);

        idx.add(1, &first);
        idx.add(2, &second);
        idx.finalize();

        store.insert(1, first);
        store.insert(2, second);

        let probes = vec![
            SparseVec::encode_data(b"query1", &vsa_cfg, None),
            SparseVec::encode_data(b"query2", &vsa_cfg, None),
        ];

        let hits = batch_search(&probes, &idx, &store, &SearchConfig::default(), 2);

        assert_eq!(hits.len(), 2);
    }

    #[test]
    fn test_recall_computation() {
        let approx = vec![
            SearchResult { id: 1, score: 100 },
            SearchResult { id: 2, score: 90 },
            SearchResult { id: 5, score: 80 },
        ];

        // Ground truth shares ids 1 and 2 with the approximate list; id 3
        // is missed, so recall@3 should be 2/3.
        let exact: Vec<RankedResult> = [(1, 0.95, 100, 1), (3, 0.90, 95, 2), (2, 0.85, 90, 3)]
            .into_iter()
            .map(|(id, score, approx_score, rank)| RankedResult {
                id,
                score,
                approx_score,
                rank,
            })
            .collect();

        let recall = compute_recall_at_k(&approx, &exact, 3);
        assert!((recall - 0.666).abs() < 0.01); // 2/3 match
    }
}
433}