// next_plaid/search.rs
1//! Search functionality for PLAID
2
3use std::cmp::Reverse;
4use std::collections::{BinaryHeap, HashMap, HashSet};
5
6use ndarray::Array1;
7use ndarray::{Array2, ArrayView2, Axis};
8use rayon::prelude::*;
9use serde::{Deserialize, Serialize};
10
11use crate::codec::CentroidStore;
12use crate::error::Result;
13use crate::maxsim;
14
/// Per-token top-k heaps and per-centroid max scores from a batch of centroids.
///
/// Element 0: one heap per query token holding that token's current top
/// `n_probe` `(score, centroid_id)` pairs. `Reverse` turns `BinaryHeap`'s
/// max-heap into a min-heap, so the weakest retained score sits at the top
/// and can be evicted cheaply.
/// Element 1: the best score recorded for each centroid while it was being
/// pushed into a heap, used later for threshold filtering.
type ProbePartial = (
    Vec<BinaryHeap<(Reverse<OrdF32>, usize)>>,
    HashMap<usize, f32>,
);

/// Maximum number of documents to decompress concurrently during exact scoring.
/// This limits peak memory usage from parallel decompression.
/// With 128 docs × ~300KB per doc = ~40MB max concurrent decompression memory.
const DECOMPRESS_CHUNK_SIZE: usize = 128;
25
/// Search parameters
///
/// Controls both the candidate-generation (IVF probing, approximate scoring)
/// and re-ranking (exact ColBERT scoring) stages of search.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SearchParameters {
    /// Number of queries per batch
    pub batch_size: usize,
    /// Number of documents to re-rank with exact scores
    pub n_full_scores: usize,
    /// Number of final results to return per query
    pub top_k: usize,
    /// Number of IVF cells to probe during search
    pub n_ivf_probe: usize,
    /// Batch size for centroid scoring during IVF probing (0 = exhaustive).
    /// Lower values use less memory but are slower. Default 100_000.
    /// Only used when num_centroids > centroid_batch_size.
    /// Setting this to 0 disables the batched probing path entirely.
    #[serde(default = "default_centroid_batch_size")]
    pub centroid_batch_size: usize,
    /// Centroid score threshold (t_cs) for centroid pruning.
    /// A centroid is only included if its maximum score across all query tokens
    /// meets or exceeds this threshold. Set to None to disable pruning.
    /// Default: Some(0.4)
    #[serde(default = "default_centroid_score_threshold")]
    pub centroid_score_threshold: Option<f32>,
}
49
/// Serde default for [`SearchParameters::centroid_batch_size`].
fn default_centroid_batch_size() -> usize {
    100_000
}

/// Serde default for [`SearchParameters::centroid_score_threshold`].
fn default_centroid_score_threshold() -> Option<f32> {
    Some(0.4)
}
57
58impl Default for SearchParameters {
59    fn default() -> Self {
60        Self {
61            batch_size: 2000,
62            n_full_scores: 4096,
63            top_k: 10,
64            n_ivf_probe: 8,
65            centroid_batch_size: default_centroid_batch_size(),
66            centroid_score_threshold: default_centroid_score_threshold(),
67        }
68    }
69}
70
/// Result of a single query
///
/// `passage_ids` and `scores` are parallel vectors, ordered by descending
/// relevance score.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct QueryResult {
    /// Query ID
    pub query_id: usize,
    /// Retrieved document IDs (ranked by relevance)
    pub passage_ids: Vec<i64>,
    /// Relevance scores for each document
    pub scores: Vec<f32>,
}
81
/// Minimum matrix size (query_tokens * doc_tokens) to use CUDA.
/// Below this threshold, CPU is faster due to GPU transfer overhead.
/// Based on benchmarks: 128 * 1024 = 131072
/// Only compiled when the `cuda` feature is enabled (see `colbert_score`).
#[cfg(feature = "cuda")]
const CUDA_COLBERT_MIN_SIZE: usize = 128 * 1024;
87
/// ColBERT-style MaxSim scoring: for each query token, find the max similarity
/// with any document token, then sum across query tokens.
///
/// When the `cuda` feature is enabled and matrices are large enough,
/// this function automatically uses CUDA acceleration; any CUDA failure
/// falls back silently to the CPU path.
fn colbert_score(query: &ArrayView2<f32>, doc: &ArrayView2<f32>) -> f32 {
    // Try CUDA for large matrices
    #[cfg(feature = "cuda")]
    {
        let matrix_size = query.nrows() * doc.nrows();
        if matrix_size >= CUDA_COLBERT_MIN_SIZE {
            // A missing global context (e.g. no device) also falls through to CPU.
            if let Some(ctx) = crate::cuda::get_global_context() {
                match crate::cuda::colbert_score_cuda(ctx, query, doc) {
                    Ok(score) => return score,
                    Err(_) => {
                        // Silent fallback to CPU for scoring (happens frequently)
                    }
                }
            }
        }
    }

    // CPU implementation
    colbert_score_cpu(query, doc)
}
113
/// CPU implementation of ColBERT MaxSim scoring.
/// Uses SIMD-accelerated max reduction and BLAS for matrix multiplication.
/// Thin wrapper around [`maxsim::maxsim_score`]; kept separate so the CUDA
/// path in `colbert_score` has an explicit CPU fallback target.
fn colbert_score_cpu(query: &ArrayView2<f32>, doc: &ArrayView2<f32>) -> f32 {
    maxsim::maxsim_score(query, doc)
}
119
120/// Compute adaptive IVF probe for filtered search on memory-mapped index.
121///
122/// Ensures enough centroids are probed to cover at least `top_k` candidates from the subset.
123/// This function counts how many subset documents are in each centroid, then greedily
124/// selects centroids (by query similarity) until the cumulative document count reaches `top_k`.
125#[allow(clippy::too_many_arguments)]
126fn compute_adaptive_ivf_probe_mmap(
127    query_centroid_scores: &Array2<f32>,
128    mmap_codes: &crate::mmap::MmapNpyArray1I64,
129    doc_offsets: &[usize],
130    num_docs: usize,
131    subset: &[i64],
132    top_k: usize,
133    n_ivf_probe: usize,
134    centroid_score_threshold: Option<f32>,
135) -> Vec<usize> {
136    // Count unique docs per centroid for subset
137    let mut centroid_doc_counts: HashMap<usize, HashSet<i64>> = HashMap::new();
138    for &doc_id in subset {
139        let doc_idx = doc_id as usize;
140        if doc_idx < num_docs {
141            let start = doc_offsets[doc_idx];
142            let end = doc_offsets[doc_idx + 1];
143            let codes = mmap_codes.slice(start, end);
144            for &c in codes.iter() {
145                centroid_doc_counts
146                    .entry(c as usize)
147                    .or_default()
148                    .insert(doc_id);
149            }
150        }
151    }
152
153    if centroid_doc_counts.is_empty() {
154        return vec![];
155    }
156
157    // Score each centroid by max query-centroid similarity
158    let mut scored_centroids: Vec<(usize, f32, usize)> = centroid_doc_counts
159        .into_iter()
160        .map(|(c, docs)| {
161            let max_score: f32 = query_centroid_scores
162                .axis_iter(Axis(0))
163                .map(|q| q[c])
164                .max_by(|a, b| a.partial_cmp(b).unwrap())
165                .unwrap_or(0.0);
166            (c, max_score, docs.len())
167        })
168        .collect();
169
170    // Apply threshold if set
171    if let Some(threshold) = centroid_score_threshold {
172        scored_centroids.retain(|(_, score, _)| *score >= threshold);
173    }
174
175    // Sort by score descending
176    scored_centroids.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap());
177
178    // Greedily select centroids until we cover top_k candidates
179    let mut cumulative_docs = 0;
180    let mut n_probe = 0;
181
182    for (_, _, doc_count) in &scored_centroids {
183        cumulative_docs += doc_count;
184        n_probe += 1;
185        // Stop when we have enough coverage AND met minimum probe requirement
186        if cumulative_docs >= top_k && n_probe >= n_ivf_probe {
187            break;
188        }
189    }
190
191    // Ensure at least n_ivf_probe centroids (unless fewer exist)
192    n_probe = n_probe.max(n_ivf_probe.min(scored_centroids.len()));
193
194    scored_centroids
195        .iter()
196        .take(n_probe)
197        .map(|(c, _, _)| *c)
198        .collect()
199}
200
/// Total-ordering wrapper for `f32` so scores can be stored in a `BinaryHeap`.
///
/// Uses [`f32::total_cmp`] (IEEE-754 totalOrder). The previous implementation
/// derived `PartialEq` (where `NaN != NaN`) but made `Ord::cmp` report NaN as
/// `Equal` to everything — an inconsistent `Eq`/`Ord` pair that violates the
/// `Ord` contract relied on by `BinaryHeap`. `total_cmp` gives a consistent,
/// panic-free ordering even when NaN scores appear.
#[derive(Clone, Copy)]
struct OrdF32(f32);

impl PartialEq for OrdF32 {
    fn eq(&self, other: &Self) -> bool {
        // Defined via `cmp` so equality always agrees with the ordering.
        self.cmp(other) == std::cmp::Ordering::Equal
    }
}

impl Eq for OrdF32 {}

impl PartialOrd for OrdF32 {
    fn partial_cmp(&self, other: &Self) -> Option<std::cmp::Ordering> {
        Some(self.cmp(other))
    }
}

impl Ord for OrdF32 {
    fn cmp(&self, other: &Self) -> std::cmp::Ordering {
        self.0.total_cmp(&other.0)
    }
}
220
/// Batched IVF probing for memory-efficient centroid scoring.
///
/// Processes centroids in chunks, keeping only top-k scores per query token in a heap.
/// Returns the union of top centroids across all query tokens.
/// If a threshold is provided, filters out centroids where max score < threshold.
///
/// NOTE(review): `max_scores` only records a centroid's score at the moments it
/// is pushed into some per-token heap, so the recorded value is a lower bound
/// on the centroid's true max score across all tokens. Threshold filtering may
/// therefore be slightly more aggressive than an exhaustive max — presumably an
/// accepted approximation; confirm against the non-batched path's semantics.
fn ivf_probe_batched(
    query: &Array2<f32>,
    centroids: &CentroidStore,
    n_probe: usize,
    batch_size: usize,
    centroid_score_threshold: Option<f32>,
) -> Vec<usize> {
    let num_centroids = centroids.nrows();
    let num_tokens = query.nrows();

    // Build batch ranges for parallel processing
    let batch_ranges: Vec<(usize, usize)> = (0..num_centroids)
        .step_by(batch_size)
        .map(|start| (start, (start + batch_size).min(num_centroids)))
        .collect();

    // Process centroid batches in parallel. Each rayon thread computes a GEMM
    // (with single-threaded BLAS via OPENBLAS_NUM_THREADS=1) and maintains local
    // per-token top-k heaps. Memory is bounded: rayon's thread pool ensures at most
    // num_cpus batch_scores matrices (each batch_size × num_tokens × 4 bytes) exist
    // simultaneously, same as the sequential approach where num_cpus queries each
    // process one batch at a time.
    let local_results: Vec<ProbePartial> = batch_ranges
        .par_iter()
        .map(|&(batch_start, batch_end)| {
            // `Reverse` makes each heap a min-heap: peek() is the weakest
            // retained score, which is what gets evicted.
            let mut heaps: Vec<BinaryHeap<(Reverse<OrdF32>, usize)>> = (0..num_tokens)
                .map(|_| BinaryHeap::with_capacity(n_probe + 1))
                .collect();
            let mut max_scores: HashMap<usize, f32> = HashMap::new();

            // Get batch view (zero-copy from mmap)
            let batch_centroids = centroids.slice_rows(batch_start, batch_end);

            // Compute scores: [num_tokens, batch_size] — single-threaded BLAS
            let batch_scores = query.dot(&batch_centroids.t());

            // Update local heaps with this batch's scores
            for (q_idx, heap) in heaps.iter_mut().enumerate() {
                for (local_c, &score) in batch_scores.row(q_idx).iter().enumerate() {
                    let global_c = batch_start + local_c;
                    let entry = (Reverse(OrdF32(score)), global_c);

                    if heap.len() < n_probe {
                        // Heap not yet full: always admit.
                        heap.push(entry);
                        max_scores
                            .entry(global_c)
                            .and_modify(|s| *s = s.max(score))
                            .or_insert(score);
                    } else if let Some(&(Reverse(OrdF32(min_score)), _)) = heap.peek() {
                        // Heap full: replace the current minimum if beaten.
                        if score > min_score {
                            heap.pop();
                            heap.push(entry);
                            max_scores
                                .entry(global_c)
                                .and_modify(|s| *s = s.max(score))
                                .or_insert(score);
                        }
                    }
                }
            }

            (heaps, max_scores)
        })
        .collect();

    // Merge local heaps into final result (lightweight: each heap has at most
    // n_probe entries, and there are num_batches heaps per token to merge)
    let mut final_heaps: Vec<BinaryHeap<(Reverse<OrdF32>, usize)>> = (0..num_tokens)
        .map(|_| BinaryHeap::with_capacity(n_probe + 1))
        .collect();
    let mut final_max_scores: HashMap<usize, f32> = HashMap::new();

    for (local_heaps, local_max_scores) in local_results {
        for (q_idx, local_heap) in local_heaps.into_iter().enumerate() {
            for entry in local_heap {
                let (Reverse(OrdF32(score)), _) = entry;
                if final_heaps[q_idx].len() < n_probe {
                    final_heaps[q_idx].push(entry);
                } else if let Some(&(Reverse(OrdF32(min_score)), _)) = final_heaps[q_idx].peek() {
                    if score > min_score {
                        final_heaps[q_idx].pop();
                        final_heaps[q_idx].push(entry);
                    }
                }
            }
        }
        // Merge per-centroid max scores across batches (take the larger).
        for (c, score) in local_max_scores {
            final_max_scores
                .entry(c)
                .and_modify(|s| *s = s.max(score))
                .or_insert(score);
        }
    }

    // Union top centroids across all query tokens
    let mut selected: HashSet<usize> = HashSet::new();
    for heap in final_heaps {
        for (_, c) in heap {
            selected.insert(c);
        }
    }

    // Apply centroid score threshold if set. Centroids missing from
    // final_max_scores (should not happen for selected ones) are dropped.
    if let Some(threshold) = centroid_score_threshold {
        selected.retain(|c| {
            final_max_scores
                .get(c)
                .copied()
                .unwrap_or(f32::NEG_INFINITY)
                >= threshold
        });
    }

    selected.into_iter().collect()
}
341
342/// Build sparse centroid scores for a set of centroid IDs.
343///
344/// Returns a HashMap mapping centroid_id -> query scores array.
345fn build_sparse_centroid_scores(
346    query: &Array2<f32>,
347    centroids: &CentroidStore,
348    centroid_ids: &HashSet<usize>,
349) -> HashMap<usize, Array1<f32>> {
350    centroid_ids
351        .iter()
352        .map(|&c| {
353            let centroid = centroids.row(c);
354            let scores: Array1<f32> = query.dot(&centroid);
355            (c, scores)
356        })
357        .collect()
358}
359
360/// Compute approximate scores using sparse centroid score lookup.
361fn approximate_score_sparse(
362    sparse_scores: &HashMap<usize, Array1<f32>>,
363    doc_codes: &[usize],
364    num_query_tokens: usize,
365) -> f32 {
366    let mut score = 0.0;
367
368    // For each query token
369    for q_idx in 0..num_query_tokens {
370        let mut max_score = f32::NEG_INFINITY;
371
372        // For each document token's code
373        for &code in doc_codes.iter() {
374            if let Some(centroid_scores) = sparse_scores.get(&code) {
375                let centroid_score = centroid_scores[q_idx];
376                if centroid_score > max_score {
377                    max_score = centroid_score;
378                }
379            }
380        }
381
382        if max_score > f32::NEG_INFINITY {
383            score += max_score;
384        }
385    }
386
387    score
388}
389
390/// Compute approximate scores for mmap index using code lookups.
391fn approximate_score_mmap(query_centroid_scores: &Array2<f32>, doc_codes: &[i64]) -> f32 {
392    let mut score = 0.0;
393
394    for q_idx in 0..query_centroid_scores.nrows() {
395        let mut max_score = f32::NEG_INFINITY;
396
397        for &code in doc_codes.iter() {
398            let centroid_score = query_centroid_scores[[q_idx, code as usize]];
399            if centroid_score > max_score {
400                max_score = centroid_score;
401            }
402        }
403
404        if max_score > f32::NEG_INFINITY {
405            score += max_score;
406        }
407    }
408
409    score
410}
411
/// Search a memory-mapped index for a single query.
///
/// Pipeline: IVF probe (centroid selection) → candidate gathering → cheap
/// approximate scoring from centroid codes → exact ColBERT re-ranking on the
/// decompressed embeddings of the top candidates. The returned `query_id` is
/// always 0; callers such as `search_many_mmap` overwrite it.
///
/// When `subset` is provided, only those documents are considered and an
/// adaptive probe guarantees coverage of at least `top_k` subset candidates.
pub fn search_one_mmap(
    index: &crate::index::MmapIndex,
    query: &Array2<f32>,
    params: &SearchParameters,
    subset: Option<&[i64]>,
) -> Result<QueryResult> {
    let num_centroids = index.codec.num_centroids();
    let num_query_tokens = query.nrows();

    // Decide whether to use batched mode for memory efficiency.
    // Subset searches need the full score matrix for adaptive probing, so
    // they never take the batched path.
    let use_batched = params.centroid_batch_size > 0
        && num_centroids > params.centroid_batch_size
        && subset.is_none();

    if use_batched {
        // Batched path: memory-efficient IVF probing for large centroid counts
        return search_one_mmap_batched(index, query, params);
    }

    // Standard path: compute full query-centroid scores upfront
    let query_centroid_scores = query.dot(&index.codec.centroids_view().t());

    // Find top IVF cells to probe
    let cells_to_probe: Vec<usize> = if let Some(subset_docs) = subset {
        // Use adaptive IVF probing that ensures enough centroids to cover top_k candidates
        compute_adaptive_ivf_probe_mmap(
            &query_centroid_scores,
            &index.mmap_codes,
            // NOTE(review): `as_slice().unwrap()` assumes doc_offsets is
            // contiguous/standard layout — confirm this invariant holds.
            index.doc_offsets.as_slice().unwrap(),
            index.doc_lengths.len(),
            subset_docs,
            params.top_k,
            params.n_ivf_probe,
            params.centroid_score_threshold,
        )
    } else {
        // Standard path: select top-k centroids per query token
        let mut selected_centroids = HashSet::new();

        for q_idx in 0..num_query_tokens {
            let mut centroid_scores: Vec<(usize, f32)> = (0..num_centroids)
                .map(|c| (c, query_centroid_scores[[q_idx, c]]))
                .collect();

            // Partial selection: O(K) average instead of O(K log K) for full sort
            // After this, the top n_ivf_probe elements are in positions 0..n_ivf_probe
            // (but not sorted among themselves - which is fine since we use a HashSet)
            if centroid_scores.len() > params.n_ivf_probe {
                centroid_scores.select_nth_unstable_by(params.n_ivf_probe - 1, |a, b| {
                    b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal)
                });
            }

            for (c, _) in centroid_scores.iter().take(params.n_ivf_probe) {
                selected_centroids.insert(*c);
            }
        }

        // Apply centroid score threshold: filter out centroids where max score < threshold
        if let Some(threshold) = params.centroid_score_threshold {
            selected_centroids.retain(|&c| {
                // NOTE(review): this `unwrap` panics if a NaN score appears
                // (unlike the `unwrap_or` used in the selection loop above).
                let max_score: f32 = (0..num_query_tokens)
                    .map(|q_idx| query_centroid_scores[[q_idx, c]])
                    .max_by(|a, b| a.partial_cmp(b).unwrap())
                    .unwrap_or(f32::NEG_INFINITY);
                max_score >= threshold
            });
        }

        selected_centroids.into_iter().collect()
    };

    // Get candidate documents from IVF
    let mut candidates = index.get_candidates(&cells_to_probe);

    // Filter by subset if provided
    if let Some(subset_docs) = subset {
        let subset_set: HashSet<i64> = subset_docs.iter().copied().collect();
        candidates.retain(|&c| subset_set.contains(&c));
    }

    if candidates.is_empty() {
        return Ok(QueryResult {
            query_id: 0,
            passage_ids: vec![],
            scores: vec![],
        });
    }

    // Compute approximate scores (cheap: centroid-code lookups only)
    let mut approx_scores: Vec<(i64, f32)> = candidates
        .par_iter()
        .map(|&doc_id| {
            let start = index.doc_offsets[doc_id as usize];
            let end = index.doc_offsets[doc_id as usize + 1];
            let codes = index.mmap_codes.slice(start, end);
            let score = approximate_score_mmap(&query_centroid_scores, &codes);
            (doc_id, score)
        })
        .collect();

    // Sort by approximate score and take top candidates
    approx_scores.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap());
    let top_candidates: Vec<i64> = approx_scores
        .iter()
        .take(params.n_full_scores)
        .map(|(id, _)| *id)
        .collect();

    // Further reduce for full decompression (only the top quarter, but never
    // fewer than top_k, gets the expensive exact scoring)
    let n_decompress = (params.n_full_scores / 4).max(params.top_k);
    let to_decompress: Vec<i64> = top_candidates.into_iter().take(n_decompress).collect();

    if to_decompress.is_empty() {
        return Ok(QueryResult {
            query_id: 0,
            passage_ids: vec![],
            scores: vec![],
        });
    }

    // Compute exact scores with decompressed embeddings
    // Use chunked processing to limit concurrent memory from parallel decompression.
    // Documents whose embeddings fail to decompress are silently skipped (ok()?).
    let mut exact_scores: Vec<(i64, f32)> = to_decompress
        .par_chunks(DECOMPRESS_CHUNK_SIZE)
        .flat_map(|chunk| {
            chunk
                .iter()
                .filter_map(|&doc_id| {
                    let doc_embeddings = index.get_document_embeddings(doc_id as usize).ok()?;
                    let score = colbert_score(&query.view(), &doc_embeddings.view());
                    Some((doc_id, score))
                })
                .collect::<Vec<_>>()
        })
        .collect();

    // Sort by exact score
    exact_scores.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap());

    // Return top-k results
    let result_count = params.top_k.min(exact_scores.len());
    let passage_ids: Vec<i64> = exact_scores
        .iter()
        .take(result_count)
        .map(|(id, _)| *id)
        .collect();
    let scores: Vec<f32> = exact_scores
        .iter()
        .take(result_count)
        .map(|(_, s)| *s)
        .collect();

    Ok(QueryResult {
        query_id: 0,
        passage_ids,
        scores,
    })
}
572
/// Memory-efficient batched search for MmapIndex with large centroid counts.
///
/// Uses batched IVF probing and sparse centroid scoring to minimize memory usage:
/// instead of the full `[num_tokens, num_centroids]` score matrix, only the
/// centroids actually referenced by candidate documents are scored.
/// The returned `query_id` is always 0; callers overwrite it.
fn search_one_mmap_batched(
    index: &crate::index::MmapIndex,
    query: &Array2<f32>,
    params: &SearchParameters,
) -> Result<QueryResult> {
    let num_query_tokens = query.nrows();

    // Step 1: Batched IVF probing
    let cells_to_probe = ivf_probe_batched(
        query,
        &index.codec.centroids,
        params.n_ivf_probe,
        params.centroid_batch_size,
        params.centroid_score_threshold,
    );

    // Step 2: Get candidate documents from IVF
    let candidates = index.get_candidates(&cells_to_probe);

    if candidates.is_empty() {
        return Ok(QueryResult {
            query_id: 0,
            passage_ids: vec![],
            scores: vec![],
        });
    }

    // Step 3: Collect unique centroids from all candidate documents
    let mut unique_centroids: HashSet<usize> = HashSet::new();
    for &doc_id in &candidates {
        let start = index.doc_offsets[doc_id as usize];
        let end = index.doc_offsets[doc_id as usize + 1];
        let codes = index.mmap_codes.slice(start, end);
        for &code in codes.iter() {
            unique_centroids.insert(code as usize);
        }
    }

    // Step 4: Build sparse centroid scores (only for centroids seen in Step 3)
    let sparse_scores =
        build_sparse_centroid_scores(query, &index.codec.centroids, &unique_centroids);

    // Step 5: Compute approximate scores using sparse lookup
    let mut approx_scores: Vec<(i64, f32)> = candidates
        .par_iter()
        .map(|&doc_id| {
            let start = index.doc_offsets[doc_id as usize];
            let end = index.doc_offsets[doc_id as usize + 1];
            let codes = index.mmap_codes.slice(start, end);
            let doc_codes: Vec<usize> = codes.iter().map(|&c| c as usize).collect();
            let score = approximate_score_sparse(&sparse_scores, &doc_codes, num_query_tokens);
            (doc_id, score)
        })
        .collect();

    // Sort by approximate score and take top candidates
    approx_scores.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap());
    let top_candidates: Vec<i64> = approx_scores
        .iter()
        .take(params.n_full_scores)
        .map(|(id, _)| *id)
        .collect();

    // Further reduce for full decompression (top quarter, at least top_k —
    // mirrors the non-batched path in `search_one_mmap`)
    let n_decompress = (params.n_full_scores / 4).max(params.top_k);
    let to_decompress: Vec<i64> = top_candidates.into_iter().take(n_decompress).collect();

    if to_decompress.is_empty() {
        return Ok(QueryResult {
            query_id: 0,
            passage_ids: vec![],
            scores: vec![],
        });
    }

    // Compute exact scores with decompressed embeddings
    // Use chunked processing to limit concurrent memory from parallel decompression.
    // Documents whose embeddings fail to decompress are silently skipped (ok()?).
    let mut exact_scores: Vec<(i64, f32)> = to_decompress
        .par_chunks(DECOMPRESS_CHUNK_SIZE)
        .flat_map(|chunk| {
            chunk
                .iter()
                .filter_map(|&doc_id| {
                    let doc_embeddings = index.get_document_embeddings(doc_id as usize).ok()?;
                    let score = colbert_score(&query.view(), &doc_embeddings.view());
                    Some((doc_id, score))
                })
                .collect::<Vec<_>>()
        })
        .collect();

    // Sort by exact score
    exact_scores.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap());

    // Return top-k results
    let result_count = params.top_k.min(exact_scores.len());
    let passage_ids: Vec<i64> = exact_scores
        .iter()
        .take(result_count)
        .map(|(id, _)| *id)
        .collect();
    let scores: Vec<f32> = exact_scores
        .iter()
        .take(result_count)
        .map(|(_, s)| *s)
        .collect();

    Ok(QueryResult {
        query_id: 0,
        passage_ids,
        scores,
    })
}
689
690/// Search a memory-mapped index for multiple queries.
691pub fn search_many_mmap(
692    index: &crate::index::MmapIndex,
693    queries: &[Array2<f32>],
694    params: &SearchParameters,
695    parallel: bool,
696    subset: Option<&[i64]>,
697) -> Result<Vec<QueryResult>> {
698    if parallel {
699        let results: Vec<QueryResult> = queries
700            .par_iter()
701            .enumerate()
702            .map(|(i, query)| {
703                let mut result =
704                    search_one_mmap(index, query, params, subset).unwrap_or_else(|_| QueryResult {
705                        query_id: i,
706                        passage_ids: vec![],
707                        scores: vec![],
708                    });
709                result.query_id = i;
710                result
711            })
712            .collect();
713        Ok(results)
714    } else {
715        let mut results = Vec::with_capacity(queries.len());
716        for (i, query) in queries.iter().enumerate() {
717            let mut result = search_one_mmap(index, query, params, subset)?;
718            result.query_id = i;
719            results.push(result);
720        }
721        Ok(results)
722    }
723}
724
/// Alias for [`QueryResult`], kept for API compatibility.
pub type SearchResult = QueryResult;
727
#[cfg(test)]
mod tests {
    use super::*;

    /// MaxSim: per-query-token max similarity, summed over tokens.
    #[test]
    fn test_colbert_score() {
        // Query with 2 tokens, dim 4
        let query =
            Array2::from_shape_vec((2, 4), vec![1.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0]).unwrap();

        // Document with 3 tokens
        let doc = Array2::from_shape_vec(
            (3, 4),
            vec![
                0.5, 0.5, 0.0, 0.0, // sim with q0: 0.5, sim with q1: 0.5
                0.8, 0.2, 0.0, 0.0, // sim with q0: 0.8, sim with q1: 0.2
                0.0, 0.9, 0.1, 0.0, // sim with q0: 0.0, sim with q1: 0.9
            ],
        )
        .unwrap();

        let score = colbert_score(&query.view(), &doc.view());
        // q0 max: 0.8 (from token 1), q1 max: 0.9 (from token 2)
        // Total: 0.8 + 0.9 = 1.7
        assert!((score - 1.7).abs() < 1e-5);
    }

    /// Defaults must match the documented values, including the
    /// serde-default-backed fields.
    #[test]
    fn test_search_params_default() {
        let params = SearchParameters::default();
        assert_eq!(params.batch_size, 2000);
        assert_eq!(params.n_full_scores, 4096);
        assert_eq!(params.top_k, 10);
        assert_eq!(params.n_ivf_probe, 8);
        // Previously untested field: centroid_batch_size.
        assert_eq!(params.centroid_batch_size, 100_000);
        assert_eq!(params.centroid_score_threshold, Some(0.4));
    }
}