1use std::collections::HashMap;
12
13use rayon::prelude::*;
14
15use crate::retrieval::{SearchResult, TernaryInvertedIndex};
16use crate::similarity::{compute_similarity, SimilarityMetric};
17use embeddenator_vsa::SparseVec;
18
19#[derive(Debug, Clone)]
21pub struct SearchConfig {
22 pub metric: SimilarityMetric,
24 pub candidate_k: usize,
26 pub beam_width: usize,
32 pub parallel: bool,
34}
35
36impl Default for SearchConfig {
37 fn default() -> Self {
38 Self {
39 metric: SimilarityMetric::Cosine,
40 candidate_k: 200,
41 beam_width: 10,
42 parallel: false,
43 }
44 }
45}
46
/// A single search hit after exact scoring.
#[derive(Debug, Clone, PartialEq)]
pub struct RankedResult {
    /// Identifier of the matched vector (its key in the vector map / index).
    pub id: usize,
    /// Exact similarity between the query and this vector under the chosen metric.
    pub score: f64,
    /// Coarse integer score: the inverted index's stage-one score in
    /// [`two_stage_search`], or `score * 1000` cast to `i32` in the exact searches.
    pub approx_score: i32,
    /// 1-based position within the returned result list.
    pub rank: usize,
}
59
60pub fn two_stage_search(
103 query: &SparseVec,
104 index: &TernaryInvertedIndex,
105 vectors: &HashMap<usize, SparseVec>,
106 config: &SearchConfig,
107 k: usize,
108) -> Vec<RankedResult> {
109 if k == 0 {
110 return Vec::new();
111 }
112
113 let candidate_k = config.candidate_k.max(k);
115 let candidates = index.query_top_k(query, candidate_k);
116
117 let mut reranked: Vec<RankedResult> = if config.parallel {
120 let candidates_with_vecs: Vec<_> = candidates
122 .iter()
123 .filter_map(|cand| vectors.get(&cand.id).map(|vec| (cand, vec)))
124 .collect();
125
126 candidates_with_vecs
127 .par_iter()
128 .map(|(cand, vec)| {
129 let score = compute_similarity(query, vec, config.metric);
130 RankedResult {
131 id: cand.id,
132 score,
133 approx_score: cand.score,
134 rank: 0, }
136 })
137 .collect()
138 } else {
139 candidates
140 .iter()
141 .filter_map(|cand| {
142 vectors.get(&cand.id).map(|vec| {
143 let score = compute_similarity(query, vec, config.metric);
144 RankedResult {
145 id: cand.id,
146 score,
147 approx_score: cand.score,
148 rank: 0, }
150 })
151 })
152 .collect()
153 };
154
155 reranked.sort_by(|a, b| {
157 b.score
158 .partial_cmp(&a.score)
159 .unwrap_or(std::cmp::Ordering::Equal)
160 .then_with(|| a.id.cmp(&b.id))
161 });
162
163 reranked.truncate(k);
165 for (idx, result) in reranked.iter_mut().enumerate() {
166 result.rank = idx + 1;
167 }
168
169 reranked
170}
171
172pub fn exact_search(
204 query: &SparseVec,
205 vectors: &HashMap<usize, SparseVec>,
206 metric: SimilarityMetric,
207 k: usize,
208) -> Vec<RankedResult> {
209 exact_search_impl(query, vectors, metric, k, false)
210}
211
212pub fn exact_search_parallel(
217 query: &SparseVec,
218 vectors: &HashMap<usize, SparseVec>,
219 metric: SimilarityMetric,
220 k: usize,
221 parallel: bool,
222) -> Vec<RankedResult> {
223 exact_search_impl(query, vectors, metric, k, parallel)
224}
225
226fn exact_search_impl(
227 query: &SparseVec,
228 vectors: &HashMap<usize, SparseVec>,
229 metric: SimilarityMetric,
230 k: usize,
231 parallel: bool,
232) -> Vec<RankedResult> {
233 if k == 0 || vectors.is_empty() {
234 return Vec::new();
235 }
236
237 let mut results: Vec<RankedResult> = if parallel {
238 let vec_entries: Vec<_> = vectors.iter().collect();
240 vec_entries
241 .par_iter()
242 .map(|(id, vec)| {
243 let score = compute_similarity(query, vec, metric);
244 RankedResult {
245 id: **id,
246 score,
247 approx_score: (score * 1000.0) as i32,
248 rank: 0,
249 }
250 })
251 .collect()
252 } else {
253 vectors
254 .iter()
255 .map(|(id, vec)| {
256 let score = compute_similarity(query, vec, metric);
257 RankedResult {
258 id: *id,
259 score,
260 approx_score: (score * 1000.0) as i32,
261 rank: 0,
262 }
263 })
264 .collect()
265 };
266
267 results.sort_by(|a, b| {
268 b.score
269 .partial_cmp(&a.score)
270 .unwrap_or(std::cmp::Ordering::Equal)
271 .then_with(|| a.id.cmp(&b.id))
272 });
273
274 results.truncate(k);
275 for (idx, result) in results.iter_mut().enumerate() {
276 result.rank = idx + 1;
277 }
278
279 results
280}
281
282pub fn approximate_search(
313 query: &SparseVec,
314 index: &TernaryInvertedIndex,
315 k: usize,
316) -> Vec<SearchResult> {
317 index.query_top_k(query, k)
318}
319
320pub fn batch_search(
357 queries: &[SparseVec],
358 index: &TernaryInvertedIndex,
359 vectors: &HashMap<usize, SparseVec>,
360 config: &SearchConfig,
361 k: usize,
362) -> Vec<Vec<RankedResult>> {
363 if config.parallel {
364 queries
366 .par_iter()
367 .map(|query| two_stage_search(query, index, vectors, config, k))
368 .collect()
369 } else {
370 queries
371 .iter()
372 .map(|query| two_stage_search(query, index, vectors, config, k))
373 .collect()
374 }
375}
376
377pub fn compute_recall_at_k(
389 approx_results: &[SearchResult],
390 exact_results: &[RankedResult],
391 k: usize,
392) -> f64 {
393 if k == 0 || exact_results.is_empty() {
394 return 0.0;
395 }
396
397 let exact_ids: std::collections::HashSet<usize> =
398 exact_results.iter().take(k).map(|r| r.id).collect();
399
400 let matches = approx_results
401 .iter()
402 .take(k)
403 .filter(|r| exact_ids.contains(&r.id))
404 .count();
405
406 matches as f64 / k.min(exact_results.len()) as f64
407}
408
#[cfg(test)]
mod tests {
    use super::*;
    use embeddenator_vsa::ReversibleVSAConfig;

    #[test]
    fn test_two_stage_search() {
        let config = ReversibleVSAConfig::default();
        let mut index = TernaryInvertedIndex::new();
        let mut vectors = HashMap::new();

        let vec1 = SparseVec::encode_data(b"hello world", &config, None);
        let vec2 = SparseVec::encode_data(b"goodbye world", &config, None);

        index.add(1, &vec1);
        index.add(2, &vec2);
        index.finalize();

        vectors.insert(1, vec1);
        vectors.insert(2, vec2);

        let query = SparseVec::encode_data(b"hello", &config, None);
        let search_config = SearchConfig::default();
        let results = two_stage_search(&query, &index, &vectors, &search_config, 2);

        assert!(!results.is_empty());
        assert_eq!(results[0].rank, 1);
    }

    #[test]
    fn test_exact_search() {
        let config = ReversibleVSAConfig::default();
        let mut vectors = HashMap::new();

        vectors.insert(1, SparseVec::encode_data(b"apple", &config, None));
        vectors.insert(2, SparseVec::encode_data(b"banana", &config, None));
        vectors.insert(3, SparseVec::encode_data(b"cherry", &config, None));

        let query = SparseVec::encode_data(b"apple", &config, None);
        let results = exact_search(&query, &vectors, SimilarityMetric::Cosine, 3);

        assert_eq!(results.len(), 3);
        assert_eq!(results[0].id, 1);
    }

    #[test]
    fn test_batch_search() {
        let config = ReversibleVSAConfig::default();
        let mut index = TernaryInvertedIndex::new();
        let mut vectors = HashMap::new();

        let vec1 = SparseVec::encode_data(b"doc1", &config, None);
        let vec2 = SparseVec::encode_data(b"doc2", &config, None);

        index.add(1, &vec1);
        index.add(2, &vec2);
        index.finalize();

        vectors.insert(1, vec1);
        vectors.insert(2, vec2);

        let queries = vec![
            SparseVec::encode_data(b"query1", &config, None),
            SparseVec::encode_data(b"query2", &config, None),
        ];

        let search_config = SearchConfig::default();
        let results = batch_search(&queries, &index, &vectors, &search_config, 2);

        assert_eq!(results.len(), 2);
    }

    #[test]
    fn test_recall_computation() {
        let approx = vec![
            SearchResult { id: 1, score: 100 },
            SearchResult { id: 2, score: 90 },
            SearchResult { id: 5, score: 80 },
        ];

        let exact = vec![
            RankedResult {
                id: 1,
                score: 0.95,
                approx_score: 100,
                rank: 1,
            },
            RankedResult {
                id: 3,
                score: 0.90,
                approx_score: 95,
                rank: 2,
            },
            RankedResult {
                id: 2,
                score: 0.85,
                approx_score: 90,
                rank: 3,
            },
        ];

        let recall = compute_recall_at_k(&approx, &exact, 3);
        // Two of the three exact ids (1 and 2) are recovered.
        assert!((recall - 0.666).abs() < 0.01);
    }

    #[test]
    fn test_recall_edge_cases() {
        let approx = vec![SearchResult { id: 1, score: 10 }];
        let exact = vec![RankedResult {
            id: 1,
            score: 1.0,
            approx_score: 1000,
            rank: 1,
        }];

        // k == 0 and an empty ground truth are both defined as zero recall.
        assert!(compute_recall_at_k(&approx, &exact, 0).abs() < 1e-12);
        assert!(compute_recall_at_k(&approx, &[], 3).abs() < 1e-12);
        // The denominator is min(k, exact.len()), so a perfect match is 1.0 even when k > len.
        assert!((compute_recall_at_k(&approx, &exact, 3) - 1.0).abs() < 1e-12);
    }

    #[test]
    fn test_zero_k_returns_empty() {
        let config = ReversibleVSAConfig::default();
        let mut index = TernaryInvertedIndex::new();
        let mut vectors = HashMap::new();

        let vec = SparseVec::encode_data(b"only doc", &config, None);
        index.add(1, &vec);
        index.finalize();
        vectors.insert(1, vec);

        let query = SparseVec::encode_data(b"only", &config, None);

        assert!(two_stage_search(&query, &index, &vectors, &SearchConfig::default(), 0).is_empty());
        assert!(exact_search(&query, &vectors, SimilarityMetric::Cosine, 0).is_empty());
    }

    #[test]
    fn test_exact_search_empty_corpus() {
        let config = ReversibleVSAConfig::default();
        let vectors: HashMap<usize, SparseVec> = HashMap::new();
        let query = SparseVec::encode_data(b"anything", &config, None);

        assert!(exact_search(&query, &vectors, SimilarityMetric::Cosine, 5).is_empty());
    }

    #[test]
    fn test_exact_search_k_exceeds_corpus_assigns_contiguous_ranks() {
        let config = ReversibleVSAConfig::default();
        let mut vectors = HashMap::new();

        for i in 0..5usize {
            let data = format!("entry {}", i);
            vectors.insert(i, SparseVec::encode_data(data.as_bytes(), &config, None));
        }

        let query = SparseVec::encode_data(b"entry 3", &config, None);
        let results = exact_search(&query, &vectors, SimilarityMetric::Cosine, 10);

        // Asking for more results than exist returns the whole corpus, fully ranked.
        assert_eq!(results.len(), 5);
        for (idx, r) in results.iter().enumerate() {
            assert_eq!(r.rank, idx + 1);
        }
        assert!(results.windows(2).all(|w| w[0].score >= w[1].score));
    }

    #[test]
    fn test_two_stage_search_respects_k_and_ranks() {
        let config = ReversibleVSAConfig::default();
        let mut index = TernaryInvertedIndex::new();
        let mut vectors = HashMap::new();

        for i in 0..20 {
            let data = format!("ranked doc {}", i);
            let vec = SparseVec::encode_data(data.as_bytes(), &config, None);
            index.add(i, &vec);
            vectors.insert(i, vec);
        }
        index.finalize();

        let query = SparseVec::encode_data(b"ranked doc 7", &config, None);
        let results = two_stage_search(&query, &index, &vectors, &SearchConfig::default(), 5);

        assert!(results.len() <= 5);
        for (idx, r) in results.iter().enumerate() {
            assert_eq!(r.rank, idx + 1);
        }
        assert!(results.windows(2).all(|w| w[0].score >= w[1].score));
    }

    #[test]
    fn test_parallel_two_stage_search_matches_sequential() {
        let config = ReversibleVSAConfig::default();
        let mut index = TernaryInvertedIndex::new();
        let mut vectors = HashMap::new();

        for i in 0..50 {
            let data = format!("document number {} with some content", i);
            let vec = SparseVec::encode_data(data.as_bytes(), &config, None);
            index.add(i, &vec);
            vectors.insert(i, vec);
        }
        index.finalize();

        let query = SparseVec::encode_data(b"document number 25", &config, None);

        let seq_config = SearchConfig {
            parallel: false,
            ..SearchConfig::default()
        };
        let par_config = SearchConfig {
            parallel: true,
            ..SearchConfig::default()
        };

        let seq_results = two_stage_search(&query, &index, &vectors, &seq_config, 10);
        let par_results = two_stage_search(&query, &index, &vectors, &par_config, 10);

        assert_eq!(seq_results.len(), par_results.len());
        for (seq, par) in seq_results.iter().zip(par_results.iter()) {
            assert_eq!(seq.id, par.id);
            assert!((seq.score - par.score).abs() < 1e-10);
            assert_eq!(seq.rank, par.rank);
        }
    }

    #[test]
    fn test_parallel_exact_search_matches_sequential() {
        let config = ReversibleVSAConfig::default();
        let mut vectors = HashMap::new();

        for i in 0..100 {
            let data = format!("item {} for testing parallel exact search", i);
            vectors.insert(i, SparseVec::encode_data(data.as_bytes(), &config, None));
        }

        let query = SparseVec::encode_data(b"item 50 for testing", &config, None);

        let seq_results =
            exact_search_parallel(&query, &vectors, SimilarityMetric::Cosine, 20, false);
        let par_results =
            exact_search_parallel(&query, &vectors, SimilarityMetric::Cosine, 20, true);

        assert_eq!(seq_results.len(), par_results.len());
        for (seq, par) in seq_results.iter().zip(par_results.iter()) {
            assert_eq!(seq.id, par.id);
            assert!((seq.score - par.score).abs() < 1e-10);
        }
    }

    #[test]
    fn test_parallel_batch_search_matches_sequential() {
        let config = ReversibleVSAConfig::default();
        let mut index = TernaryInvertedIndex::new();
        let mut vectors = HashMap::new();

        for i in 0..30 {
            let data = format!("batch doc {}", i);
            let vec = SparseVec::encode_data(data.as_bytes(), &config, None);
            index.add(i, &vec);
            vectors.insert(i, vec);
        }
        index.finalize();

        let queries: Vec<SparseVec> = (0..10)
            .map(|i| {
                let data = format!("query {}", i);
                SparseVec::encode_data(data.as_bytes(), &config, None)
            })
            .collect();

        let seq_config = SearchConfig {
            parallel: false,
            ..SearchConfig::default()
        };
        let par_config = SearchConfig {
            parallel: true,
            ..SearchConfig::default()
        };

        let seq_results = batch_search(&queries, &index, &vectors, &seq_config, 5);
        let par_results = batch_search(&queries, &index, &vectors, &par_config, 5);

        assert_eq!(seq_results.len(), par_results.len());
        for (seq_batch, par_batch) in seq_results.iter().zip(par_results.iter()) {
            assert_eq!(seq_batch.len(), par_batch.len());
            for (seq, par) in seq_batch.iter().zip(par_batch.iter()) {
                assert_eq!(seq.id, par.id);
                assert!((seq.score - par.score).abs() < 1e-10);
            }
        }
    }
}