next_plaid/
search.rs

1//! Search functionality for PLAID
2
3#[cfg(feature = "npy")]
4use std::cmp::Reverse;
5#[cfg(feature = "npy")]
6use std::collections::{BinaryHeap, HashMap, HashSet};
7
8#[cfg(feature = "npy")]
9use ndarray::Array1;
10use ndarray::{Array2, ArrayView1, ArrayView2, Axis};
11use rayon::prelude::*;
12use serde::{Deserialize, Serialize};
13
14#[cfg(feature = "npy")]
15use crate::codec::CentroidStore;
16use crate::error::Result;
17use crate::index::Index;
18
/// Tunable knobs for a search over the index.
///
/// When deserialized, a missing `centroid_batch_size` is filled in by
/// `default_centroid_batch_size()`; every other field is required.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SearchParameters {
    /// Number of queries per batch.
    /// NOTE(review): no search function in this file reads this field — confirm where it is consumed.
    pub batch_size: usize,
    /// Upper bound on the candidates kept after approximate (centroid-based) scoring.
    /// A quarter of this value, but never fewer than `top_k`, is then decompressed
    /// and re-ranked with exact scores.
    pub n_full_scores: usize,
    /// Number of final results to return per query
    pub top_k: usize,
    /// Number of IVF cells to probe during search.
    /// Without a subset filter this many cells are picked per query token and the
    /// union is probed; with a subset filter it caps the total number of cells.
    pub n_ivf_probe: usize,
    /// Batch size for centroid scoring during IVF probing (0 = exhaustive).
    /// Lower values use less memory but are slower. Default 100_000.
    /// Only used when num_centroids > centroid_batch_size.
    #[serde(default = "default_centroid_batch_size")]
    pub centroid_batch_size: usize,
}
36
/// Fallback for `SearchParameters::centroid_batch_size`, used both by serde when
/// the field is missing and by the `Default` implementation.
fn default_centroid_batch_size() -> usize {
    const DEFAULT_CENTROID_BATCH_SIZE: usize = 100_000;
    DEFAULT_CENTROID_BATCH_SIZE
}
40
41impl Default for SearchParameters {
42    fn default() -> Self {
43        Self {
44            batch_size: 2000,
45            n_full_scores: 4096,
46            top_k: 10,
47            n_ivf_probe: 8,
48            centroid_batch_size: default_centroid_batch_size(),
49        }
50    }
51}
52
/// Ranked hits for one query.
///
/// `passage_ids` and `scores` are parallel vectors: entry `i` of each describes
/// the same document, and entries are ordered from best to worst score.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct QueryResult {
    /// Query ID: the query's position in the input slice for multi-query searches,
    /// `0` for the single-query entry points in this file.
    pub query_id: usize,
    /// Retrieved document IDs (ranked by relevance)
    pub passage_ids: Vec<i64>,
    /// Relevance scores for each document
    pub scores: Vec<f32>,
}
63
64/// ColBERT-style MaxSim scoring: for each query token, find the max similarity
65/// with any document token, then sum across query tokens.
66fn colbert_score(query: &ArrayView2<f32>, doc: &ArrayView2<f32>) -> f32 {
67    let mut total_score = 0.0;
68
69    // For each query token
70    for q_row in query.axis_iter(Axis(0)) {
71        let mut max_sim = f32::NEG_INFINITY;
72
73        // Find max similarity with any document token
74        for d_row in doc.axis_iter(Axis(0)) {
75            let sim: f32 = q_row.dot(&d_row);
76            if sim > max_sim {
77                max_sim = sim;
78            }
79        }
80
81        if max_sim > f32::NEG_INFINITY {
82            total_score += max_sim;
83        }
84    }
85
86    total_score
87}
88
89/// Compute approximate scores using centroid similarities.
90fn approximate_score(query_centroid_scores: &Array2<f32>, doc_codes: &ArrayView1<usize>) -> f32 {
91    let mut score = 0.0;
92
93    // For each query token
94    for q_idx in 0..query_centroid_scores.nrows() {
95        let mut max_score = f32::NEG_INFINITY;
96
97        // For each document token's code
98        for &code in doc_codes.iter() {
99            let centroid_score = query_centroid_scores[[q_idx, code]];
100            if centroid_score > max_score {
101                max_score = centroid_score;
102            }
103        }
104
105        if max_score > f32::NEG_INFINITY {
106            score += max_score;
107        }
108    }
109
110    score
111}
112
/// Wrapper giving `f32` a total order so it can live in a `BinaryHeap`.
///
/// Equality and ordering both follow `f32::total_cmp` (IEEE 754 total order),
/// so the `Eq`/`Ord` contracts hold even for NaN: a NaN equals itself, positive
/// NaN sorts above every number and negative NaN below. Note that under this
/// order `-0.0` is strictly less than `+0.0`.
#[cfg(feature = "npy")]
#[derive(Clone, Copy)]
struct OrdF32(f32);

#[cfg(feature = "npy")]
impl PartialEq for OrdF32 {
    fn eq(&self, other: &Self) -> bool {
        // Defined through `cmp` so equality can never disagree with ordering.
        self.cmp(other) == std::cmp::Ordering::Equal
    }
}

#[cfg(feature = "npy")]
impl Eq for OrdF32 {}

#[cfg(feature = "npy")]
impl PartialOrd for OrdF32 {
    fn partial_cmp(&self, other: &Self) -> Option<std::cmp::Ordering> {
        Some(self.cmp(other))
    }
}

#[cfg(feature = "npy")]
impl Ord for OrdF32 {
    fn cmp(&self, other: &Self) -> std::cmp::Ordering {
        self.0.total_cmp(&other.0)
    }
}
136
/// Batched IVF probing for memory-efficient centroid scoring.
///
/// Scores the query against the centroids one chunk of rows at a time, keeping
/// only the `n_probe` best centroids per query token in a bounded min-heap, so
/// the full `[num_tokens, num_centroids]` score matrix is never materialized.
///
/// * `query` - one row per query token.
/// * `centroids` - centroid store, read in row slices of `batch_size`.
/// * `n_probe` - centroids to keep per query token; `0` yields an empty result.
/// * `batch_size` - centroid rows scored per chunk; `0` means "everything in a
///   single chunk" instead of panicking inside `step_by`.
///
/// Returns the union of the selected centroid indices across all query tokens,
/// sorted ascending so the output is deterministic.
#[cfg(feature = "npy")]
fn ivf_probe_batched(
    query: &Array2<f32>,
    centroids: &CentroidStore,
    n_probe: usize,
    batch_size: usize,
) -> Vec<usize> {
    let num_centroids = centroids.nrows();
    let num_tokens = query.nrows();

    // Nothing can be selected in any of these cases.
    if n_probe == 0 || num_centroids == 0 || num_tokens == 0 {
        return Vec::new();
    }

    // `step_by(0)` panics; treat 0 as "one batch covering every centroid".
    let step = if batch_size == 0 {
        num_centroids
    } else {
        batch_size
    };

    // A heap never holds more than `n_probe` entries, nor more than there are
    // centroids; sizing by the smaller avoids huge or overflowing allocations.
    let heap_capacity = n_probe.min(num_centroids);

    // Min-heap per query token to track top centroids
    // Entry: (Reverse(score), centroid_id) - Reverse for min-heap behavior
    let mut heaps: Vec<BinaryHeap<(Reverse<OrdF32>, usize)>> = (0..num_tokens)
        .map(|_| BinaryHeap::with_capacity(heap_capacity))
        .collect();

    for batch_start in (0..num_centroids).step_by(step) {
        let batch_end = batch_start.saturating_add(step).min(num_centroids);

        // Row slice of the centroid store for this batch
        let batch_centroids = centroids.slice_rows(batch_start, batch_end);

        // Compute scores: [num_tokens, batch_len]
        let batch_scores = query.dot(&batch_centroids.t());

        // Update heaps with this batch's scores
        for (q_idx, heap) in heaps.iter_mut().enumerate() {
            for (local_c, &score) in batch_scores.row(q_idx).iter().enumerate() {
                let global_c = batch_start + local_c;
                let entry = (Reverse(OrdF32(score)), global_c);

                if heap.len() < n_probe {
                    heap.push(entry);
                } else if let Some(&(Reverse(OrdF32(min_score)), _)) = heap.peek() {
                    // Only a strictly better score evicts the current weakest entry.
                    if score > min_score {
                        heap.pop();
                        heap.push(entry);
                    }
                }
            }
        }
    }

    // Union top centroids across all query tokens
    let selected: HashSet<usize> = heaps
        .into_iter()
        .flat_map(|heap| heap.into_iter().map(|(_, c)| c))
        .collect();

    let mut cells: Vec<usize> = selected.into_iter().collect();
    cells.sort_unstable();
    cells
}
193
/// Score the query against just the listed centroids.
///
/// For each ID in `centroid_ids` the matching centroid row is fetched and
/// multiplied with `query`, giving one score per query token. The result maps
/// centroid ID to that per-token score vector.
#[cfg(feature = "npy")]
fn build_sparse_centroid_scores(
    query: &Array2<f32>,
    centroids: &CentroidStore,
    centroid_ids: &HashSet<usize>,
) -> HashMap<usize, Array1<f32>> {
    let mut table = HashMap::with_capacity(centroid_ids.len());
    for &id in centroid_ids {
        let row = centroids.row(id);
        let per_token: Array1<f32> = query.dot(&row);
        table.insert(id, per_token);
    }
    table
}
212
/// Compute approximate scores using sparse centroid score lookup.
///
/// * `sparse_scores` - centroid ID -> per-query-token score vector.
/// * `doc_codes` - centroid ID of each document token; codes missing from
///   `sparse_scores` are ignored.
/// * `num_query_tokens` - number of query tokens to score.
///
/// For every query token the best score among the document's known centroids
/// is taken and the maxima are summed. Returns `0.0` when no code is known.
///
/// Panics if a looked-up score vector has fewer than `num_query_tokens` entries.
#[cfg(feature = "npy")]
fn approximate_score_sparse(
    sparse_scores: &HashMap<usize, Array1<f32>>,
    doc_codes: &[usize],
    num_query_tokens: usize,
) -> f32 {
    // Resolve every code once up front; the lookups do not depend on the query
    // token, so repeating them inside the token loop would only redo the hashing.
    let known: Vec<&Array1<f32>> = doc_codes
        .iter()
        .filter_map(|code| sparse_scores.get(code))
        .collect();

    let mut score = 0.0;

    for q_idx in 0..num_query_tokens {
        let best = known
            .iter()
            .map(|centroid_scores| centroid_scores[q_idx])
            // Strict `>` so NaN entries never become the maximum.
            .fold(f32::NEG_INFINITY, |best, s| if s > best { s } else { best });

        if best > f32::NEG_INFINITY {
            score += best;
        }
    }

    score
}
243
244/// Search for a single query.
245pub fn search_one(
246    query: &Array2<f32>,
247    index: &Index,
248    params: &SearchParameters,
249    subset: Option<&[i64]>,
250) -> Result<(Vec<i64>, Vec<f32>)> {
251    // Compute query-centroid scores
252    // query: [num_tokens, dim], centroids: [num_centroids, dim]
253    // scores: [num_tokens, num_centroids]
254    let query_centroid_scores = query.dot(&index.codec.centroids_view().t());
255
256    // Find top IVF cells to probe
257    let cells_to_probe: Vec<usize> = if let Some(subset_docs) = subset {
258        // When filtering by subset, only probe centroids that contain subset documents
259        let mut subset_centroids: Vec<usize> = Vec::new();
260        for &doc_id in subset_docs {
261            if (doc_id as usize) < index.doc_codes.len() {
262                subset_centroids.extend(index.doc_codes[doc_id as usize].iter().copied());
263            }
264        }
265        subset_centroids.sort_unstable();
266        subset_centroids.dedup();
267
268        if subset_centroids.is_empty() {
269            return Ok((vec![], vec![]));
270        }
271
272        // Compute scores for subset centroids and take top-k
273        let mut centroid_scores: Vec<(usize, f32)> = subset_centroids
274            .iter()
275            .map(|&c| {
276                let score: f32 = query_centroid_scores
277                    .axis_iter(Axis(0))
278                    .map(|q| q[c])
279                    .max_by(|a, b| a.partial_cmp(b).unwrap())
280                    .unwrap_or(0.0);
281                (c, score)
282            })
283            .collect();
284
285        centroid_scores.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap());
286        centroid_scores
287            .iter()
288            .take(params.n_ivf_probe)
289            .map(|(c, _)| *c)
290            .collect()
291    } else {
292        // Standard path: select top-k centroids PER query token, then take union
293        // This matches fast-plaid's algorithm: for each query token, find the best centroids
294        let num_centroids = index.codec.num_centroids();
295        let num_query_tokens = query_centroid_scores.nrows();
296
297        // Collect all centroid indices from per-token top-k
298        let mut selected_centroids = std::collections::HashSet::new();
299
300        for q_idx in 0..num_query_tokens {
301            // Get scores for this query token
302            let mut centroid_scores: Vec<(usize, f32)> = (0..num_centroids)
303                .map(|c| (c, query_centroid_scores[[q_idx, c]]))
304                .collect();
305
306            // Sort by score descending and take top n_ivf_probe
307            centroid_scores.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap());
308
309            for (c, _) in centroid_scores.iter().take(params.n_ivf_probe) {
310                selected_centroids.insert(*c);
311            }
312        }
313
314        selected_centroids.into_iter().collect()
315    };
316
317    // Get candidate documents from IVF
318    let mut candidates = index.get_candidates(&cells_to_probe);
319
320    // Filter by subset if provided
321    if let Some(subset_docs) = subset {
322        let subset_set: std::collections::HashSet<i64> = subset_docs.iter().copied().collect();
323        candidates.retain(|&c| subset_set.contains(&c));
324    }
325
326    if candidates.is_empty() {
327        return Ok((vec![], vec![]));
328    }
329
330    // Compute approximate scores
331    let mut approx_scores: Vec<(i64, f32)> = candidates
332        .par_iter()
333        .map(|&doc_id| {
334            let codes = &index.doc_codes[doc_id as usize];
335            let score = approximate_score(&query_centroid_scores, &codes.view());
336            (doc_id, score)
337        })
338        .collect();
339
340    // Sort by approximate score and take top candidates
341    approx_scores.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap());
342    let top_candidates: Vec<i64> = approx_scores
343        .iter()
344        .take(params.n_full_scores)
345        .map(|(id, _)| *id)
346        .collect();
347
348    // Further reduce for full decompression
349    let n_decompress = (params.n_full_scores / 4).max(params.top_k);
350    let to_decompress: Vec<i64> = top_candidates.into_iter().take(n_decompress).collect();
351
352    if to_decompress.is_empty() {
353        return Ok((vec![], vec![]));
354    }
355
356    // Compute exact scores with decompressed embeddings
357    let mut exact_scores: Vec<(i64, f32)> = to_decompress
358        .par_iter()
359        .filter_map(|&doc_id| {
360            let doc_embeddings = index.get_document_embeddings(doc_id as usize).ok()?;
361            let score = colbert_score(&query.view(), &doc_embeddings.view());
362            Some((doc_id, score))
363        })
364        .collect();
365
366    // Sort by exact score
367    exact_scores.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap());
368
369    // Return top-k results
370    let result_count = params.top_k.min(exact_scores.len());
371    let passage_ids: Vec<i64> = exact_scores
372        .iter()
373        .take(result_count)
374        .map(|(id, _)| *id)
375        .collect();
376    let scores: Vec<f32> = exact_scores
377        .iter()
378        .take(result_count)
379        .map(|(_, s)| *s)
380        .collect();
381
382    Ok((passage_ids, scores))
383}
384
385/// Search for multiple queries.
386pub fn search_many(
387    queries: &[Array2<f32>],
388    index: &Index,
389    params: &SearchParameters,
390    show_progress: bool,
391    subsets: Option<&[Vec<i64>]>,
392) -> Result<Vec<QueryResult>> {
393    let progress = if show_progress {
394        let bar = indicatif::ProgressBar::new(queries.len() as u64);
395        bar.set_message("Searching...");
396        Some(bar)
397    } else {
398        None
399    };
400
401    let results: Vec<QueryResult> = queries
402        .par_iter()
403        .enumerate()
404        .map(|(i, query)| {
405            let subset = subsets.and_then(|s| s.get(i).map(|v| v.as_slice()));
406            let (passage_ids, scores) =
407                search_one(query, index, params, subset).unwrap_or_default();
408
409            if let Some(ref bar) = progress {
410                bar.inc(1);
411            }
412
413            QueryResult {
414                query_id: i,
415                passage_ids,
416                scores,
417            }
418        })
419        .collect();
420
421    if let Some(bar) = progress {
422        bar.finish();
423    }
424
425    Ok(results)
426}
427
428/// Convenience function to search the index.
429impl Index {
430    /// Search the index with a single query.
431    pub fn search(
432        &self,
433        query: &Array2<f32>,
434        params: &SearchParameters,
435        subset: Option<&[i64]>,
436    ) -> Result<QueryResult> {
437        let (passage_ids, scores) = search_one(query, self, params, subset)?;
438        Ok(QueryResult {
439            query_id: 0,
440            passage_ids,
441            scores,
442        })
443    }
444
445    /// Search the index with multiple queries.
446    pub fn search_batch(
447        &self,
448        queries: &[Array2<f32>],
449        params: &SearchParameters,
450        show_progress: bool,
451        subsets: Option<&[Vec<i64>]>,
452    ) -> Result<Vec<QueryResult>> {
453        search_many(queries, self, params, show_progress, subsets)
454    }
455}
456
457// ============================================================================
458// Memory-Mapped Index Search
459// ============================================================================
460
/// Approximate MaxSim score for a document whose codes are stored as `i64`.
///
/// `query_centroid_scores` has one row per query token and one column per
/// centroid. Each code is cast to `usize` and used as a column index; for every
/// query token the best score over the document's codes is taken and the maxima
/// are summed. A document without codes scores `0.0`.
///
/// Panics if a code (after the cast) is not a valid column index.
#[cfg(feature = "npy")]
fn approximate_score_mmap(query_centroid_scores: &Array2<f32>, doc_codes: &[i64]) -> f32 {
    query_centroid_scores
        .axis_iter(Axis(0))
        .map(|token_scores| {
            doc_codes
                .iter()
                .map(|&code| token_scores[code as usize])
                // Strict `>` so NaN entries never become the maximum.
                .fold(f32::NEG_INFINITY, |best, s| if s > best { s } else { best })
        })
        .filter(|&best| best > f32::NEG_INFINITY)
        .fold(0.0, |acc, best| acc + best)
}
483
/// Search a memory-mapped index for a single query.
///
/// When `params.centroid_batch_size` is non-zero, the index has more centroids
/// than that, and no `subset` is given, the work is delegated to
/// `search_one_mmap_batched`. Otherwise the full query-centroid score matrix is
/// computed and the search proceeds as:
/// 1. Choose IVF cells: top `n_ivf_probe` centroids per query token (union),
///    or, with a `subset`, the top `n_ivf_probe` among the centroids used by
///    the subset's documents.
/// 2. Gather candidates from those cells (restricted to `subset`).
/// 3. Rank by approximate score, keep at most
///    `min(n_full_scores, max(n_full_scores / 4, top_k))`.
/// 4. Re-rank with exact MaxSim and return the best `top_k`.
///
/// The returned `query_id` is always `0`. Subset IDs outside the index are
/// ignored, and documents whose embeddings cannot be loaded are skipped.
///
/// NaN scores no longer cause a panic while sorting: they rank after every
/// real score.
#[cfg(feature = "npy")]
pub fn search_one_mmap(
    index: &crate::index::MmapIndex,
    query: &Array2<f32>,
    params: &SearchParameters,
    subset: Option<&[i64]>,
) -> Result<QueryResult> {
    /// Descending score order that places NaN after every real score.
    /// Non-NaN values compare exactly as `partial_cmp` would.
    fn score_desc(a: f32, b: f32) -> std::cmp::Ordering {
        use std::cmp::Ordering;
        match (a.is_nan(), b.is_nan()) {
            (true, true) => Ordering::Equal,
            (true, false) => Ordering::Greater,
            (false, true) => Ordering::Less,
            (false, false) => b.partial_cmp(&a).unwrap_or(Ordering::Equal),
        }
    }

    /// Result returned whenever a stage leaves nothing to rank.
    fn no_hits() -> QueryResult {
        QueryResult {
            query_id: 0,
            passage_ids: vec![],
            scores: vec![],
        }
    }

    let num_centroids = index.codec.num_centroids();
    let num_query_tokens = query.nrows();

    // Decide whether to use batched mode for memory efficiency
    let use_batched = params.centroid_batch_size > 0
        && num_centroids > params.centroid_batch_size
        && subset.is_none();

    if use_batched {
        // Batched path: memory-efficient IVF probing for large centroid counts
        return search_one_mmap_batched(index, query, params);
    }

    // Standard path: compute full query-centroid scores upfront
    let query_centroid_scores = query.dot(&index.codec.centroids_view().t());

    // Find top IVF cells to probe
    let cells_to_probe: Vec<usize> = if let Some(subset_docs) = subset {
        // When filtering by subset, only probe centroids that contain subset documents
        let mut subset_centroids: Vec<usize> = Vec::new();
        for &doc_id in subset_docs {
            // Out-of-range IDs are skipped; a negative ID becomes a huge usize
            // under `as` and fails this check as well.
            if (doc_id as usize) < index.doc_lengths.len() {
                let start = index.doc_offsets[doc_id as usize];
                let end = index.doc_offsets[doc_id as usize + 1];
                let codes = index.mmap_codes.slice(start, end);
                subset_centroids.extend(codes.iter().map(|&c| c as usize));
            }
        }
        subset_centroids.sort_unstable();
        subset_centroids.dedup();

        if subset_centroids.is_empty() {
            return Ok(no_hits());
        }

        // Score each subset centroid by its best similarity to any query token
        let mut centroid_scores: Vec<(usize, f32)> = subset_centroids
            .iter()
            .map(|&c| {
                let score: f32 = query_centroid_scores
                    .axis_iter(Axis(0))
                    .map(|q| q[c])
                    .max_by(|a, b| a.total_cmp(b))
                    .unwrap_or(0.0);
                (c, score)
            })
            .collect();

        centroid_scores.sort_by(|a, b| score_desc(a.1, b.1));
        centroid_scores
            .iter()
            .take(params.n_ivf_probe)
            .map(|&(c, _)| c)
            .collect()
    } else {
        // Standard path: select top-k centroids per query token
        let mut selected_centroids = HashSet::new();

        for q_idx in 0..num_query_tokens {
            let mut centroid_scores: Vec<(usize, f32)> = (0..num_centroids)
                .map(|c| (c, query_centroid_scores[[q_idx, c]]))
                .collect();

            centroid_scores.sort_by(|a, b| score_desc(a.1, b.1));

            selected_centroids.extend(
                centroid_scores
                    .iter()
                    .take(params.n_ivf_probe)
                    .map(|&(c, _)| c),
            );
        }

        selected_centroids.into_iter().collect()
    };

    // Get candidate documents from IVF
    let mut candidates = index.get_candidates(&cells_to_probe);

    // Filter by subset if provided
    if let Some(subset_docs) = subset {
        let subset_set: HashSet<i64> = subset_docs.iter().copied().collect();
        candidates.retain(|&c| subset_set.contains(&c));
    }

    if candidates.is_empty() {
        return Ok(no_hits());
    }

    // Compute approximate scores
    let mut approx_scores: Vec<(i64, f32)> = candidates
        .par_iter()
        .map(|&doc_id| {
            let start = index.doc_offsets[doc_id as usize];
            let end = index.doc_offsets[doc_id as usize + 1];
            let codes = index.mmap_codes.slice(start, end);
            let score = approximate_score_mmap(&query_centroid_scores, codes);
            (doc_id, score)
        })
        .collect();

    // Sort by approximate score, then keep only what will be decompressed:
    // at most n_full_scores candidates, further cut to max(n_full_scores / 4, top_k).
    approx_scores.sort_by(|a, b| score_desc(a.1, b.1));
    let n_decompress = (params.n_full_scores / 4).max(params.top_k);
    let to_decompress: Vec<i64> = approx_scores
        .iter()
        .take(params.n_full_scores.min(n_decompress))
        .map(|&(id, _)| id)
        .collect();

    if to_decompress.is_empty() {
        return Ok(no_hits());
    }

    // Compute exact scores with decompressed embeddings.
    // Documents whose embeddings fail to load are dropped (best effort).
    let mut exact_scores: Vec<(i64, f32)> = to_decompress
        .par_iter()
        .filter_map(|&doc_id| {
            let doc_embeddings = index.get_document_embeddings(doc_id as usize).ok()?;
            let score = colbert_score(&query.view(), &doc_embeddings.view());
            Some((doc_id, score))
        })
        .collect();

    // Sort by exact score
    exact_scores.sort_by(|a, b| score_desc(a.1, b.1));

    // Return top-k results as parallel id / score vectors
    let (passage_ids, scores): (Vec<i64>, Vec<f32>) =
        exact_scores.into_iter().take(params.top_k).unzip();

    Ok(QueryResult {
        query_id: 0,
        passage_ids,
        scores,
    })
}
652
/// Memory-efficient batched search for MmapIndex with large centroid counts.
///
/// Avoids building the full query-centroid score matrix:
/// 1. IVF cells are chosen with `ivf_probe_batched`, which scores centroids in
///    chunks of `params.centroid_batch_size`.
/// 2. Candidate documents are gathered from those cells.
/// 3. Only the centroids that actually occur in the candidates' codes are
///    scored (`build_sparse_centroid_scores`), and candidates are ranked with
///    `approximate_score_sparse`.
/// 4. At most `min(n_full_scores, max(n_full_scores / 4, top_k))` candidates
///    are re-ranked with exact MaxSim; the best `top_k` are returned.
///
/// The returned `query_id` is always `0`. Documents whose embeddings cannot be
/// loaded are skipped.
///
/// NaN scores no longer cause a panic while sorting: they rank after every
/// real score.
#[cfg(feature = "npy")]
fn search_one_mmap_batched(
    index: &crate::index::MmapIndex,
    query: &Array2<f32>,
    params: &SearchParameters,
) -> Result<QueryResult> {
    /// Descending score order that places NaN after every real score.
    /// Non-NaN values compare exactly as `partial_cmp` would.
    fn score_desc(a: f32, b: f32) -> std::cmp::Ordering {
        use std::cmp::Ordering;
        match (a.is_nan(), b.is_nan()) {
            (true, true) => Ordering::Equal,
            (true, false) => Ordering::Greater,
            (false, true) => Ordering::Less,
            (false, false) => b.partial_cmp(&a).unwrap_or(Ordering::Equal),
        }
    }

    /// Result returned whenever a stage leaves nothing to rank.
    fn no_hits() -> QueryResult {
        QueryResult {
            query_id: 0,
            passage_ids: vec![],
            scores: vec![],
        }
    }

    let num_query_tokens = query.nrows();

    // Step 1: Batched IVF probing
    let cells_to_probe = ivf_probe_batched(
        query,
        &index.codec.centroids,
        params.n_ivf_probe,
        params.centroid_batch_size,
    );

    // Step 2: Get candidate documents from IVF
    let candidates = index.get_candidates(&cells_to_probe);

    if candidates.is_empty() {
        return Ok(no_hits());
    }

    // Step 3: Collect unique centroids from all candidate documents
    let mut unique_centroids: HashSet<usize> = HashSet::new();
    for &doc_id in &candidates {
        let start = index.doc_offsets[doc_id as usize];
        let end = index.doc_offsets[doc_id as usize + 1];
        let codes = index.mmap_codes.slice(start, end);
        unique_centroids.extend(codes.iter().map(|&code| code as usize));
    }

    // Step 4: Build sparse centroid scores
    let sparse_scores =
        build_sparse_centroid_scores(query, &index.codec.centroids, &unique_centroids);

    // Step 5: Compute approximate scores using sparse lookup
    let mut approx_scores: Vec<(i64, f32)> = candidates
        .par_iter()
        .map(|&doc_id| {
            let start = index.doc_offsets[doc_id as usize];
            let end = index.doc_offsets[doc_id as usize + 1];
            let codes = index.mmap_codes.slice(start, end);
            // The sparse scorer takes `usize` codes, so convert this document's codes.
            let doc_codes: Vec<usize> = codes.iter().map(|&c| c as usize).collect();
            let score = approximate_score_sparse(&sparse_scores, &doc_codes, num_query_tokens);
            (doc_id, score)
        })
        .collect();

    // Sort by approximate score, then keep only what will be decompressed:
    // at most n_full_scores candidates, further cut to max(n_full_scores / 4, top_k).
    approx_scores.sort_by(|a, b| score_desc(a.1, b.1));
    let n_decompress = (params.n_full_scores / 4).max(params.top_k);
    let to_decompress: Vec<i64> = approx_scores
        .iter()
        .take(params.n_full_scores.min(n_decompress))
        .map(|&(id, _)| id)
        .collect();

    if to_decompress.is_empty() {
        return Ok(no_hits());
    }

    // Compute exact scores with decompressed embeddings.
    // Documents whose embeddings fail to load are dropped (best effort).
    let mut exact_scores: Vec<(i64, f32)> = to_decompress
        .par_iter()
        .filter_map(|&doc_id| {
            let doc_embeddings = index.get_document_embeddings(doc_id as usize).ok()?;
            let score = colbert_score(&query.view(), &doc_embeddings.view());
            Some((doc_id, score))
        })
        .collect();

    // Sort by exact score
    exact_scores.sort_by(|a, b| score_desc(a.1, b.1));

    // Return top-k results as parallel id / score vectors
    let (passage_ids, scores): (Vec<i64>, Vec<f32>) =
        exact_scores.into_iter().take(params.top_k).unzip();

    Ok(QueryResult {
        query_id: 0,
        passage_ids,
        scores,
    })
}
763
/// Search a memory-mapped index for multiple queries.
///
/// Every query is run through `search_one_mmap` with the same `subset`, and
/// each result's `query_id` is set to the query's position in `queries`.
///
/// * `parallel == true` - queries run on the rayon pool; a query whose search
///   fails yields an empty result and the batch still succeeds.
/// * `parallel == false` - queries run in order and the first error is returned.
#[cfg(feature = "npy")]
pub fn search_many_mmap(
    index: &crate::index::MmapIndex,
    queries: &[Array2<f32>],
    params: &SearchParameters,
    parallel: bool,
    subset: Option<&[i64]>,
) -> Result<Vec<QueryResult>> {
    if !parallel {
        // Sequential mode: stop at the first failing query.
        return queries
            .iter()
            .enumerate()
            .map(|(position, query)| -> Result<QueryResult> {
                let mut hit = search_one_mmap(index, query, params, subset)?;
                hit.query_id = position;
                Ok(hit)
            })
            .collect();
    }

    // Parallel mode: a failing query degrades to "no hits".
    let hits: Vec<QueryResult> = queries
        .par_iter()
        .enumerate()
        .map(
            |(position, query)| match search_one_mmap(index, query, params, subset) {
                Ok(hit) => QueryResult {
                    query_id: position,
                    ..hit
                },
                Err(_) => QueryResult {
                    query_id: position,
                    passage_ids: vec![],
                    scores: vec![],
                },
            },
        )
        .collect();

    Ok(hits)
}
799
/// Alias of [`QueryResult`], kept for API compatibility: both names refer to
/// the same type and can be used interchangeably.
pub type SearchResult = QueryResult;
802
#[cfg(test)]
mod tests {
    use super::*;

    /// MaxSim picks the best document token per query token and sums the maxima.
    #[test]
    fn test_colbert_score() {
        // Query with 2 tokens, dim 4
        let query =
            Array2::from_shape_vec((2, 4), vec![1.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0]).unwrap();

        // Document with 3 tokens
        let doc = Array2::from_shape_vec(
            (3, 4),
            vec![
                0.5, 0.5, 0.0, 0.0, // sim with q0: 0.5, sim with q1: 0.5
                0.8, 0.2, 0.0, 0.0, // sim with q0: 0.8, sim with q1: 0.2
                0.0, 0.9, 0.1, 0.0, // sim with q0: 0.0, sim with q1: 0.9
            ],
        )
        .unwrap();

        let score = colbert_score(&query.view(), &doc.view());
        // q0 max: 0.8 (from token 1), q1 max: 0.9 (from token 2)
        // Total: 0.8 + 0.9 = 1.7
        assert!((score - 1.7).abs() < 1e-5);
    }

    /// A document with no tokens contributes nothing for any query token.
    #[test]
    fn test_colbert_score_empty_document() {
        let query = Array2::<f32>::from_elem((2, 4), 1.0);
        let doc = Array2::<f32>::zeros((0, 4));

        let score = colbert_score(&query.view(), &doc.view());
        assert_eq!(score, 0.0);
    }

    /// Approximate scoring looks up centroid scores by document code.
    #[test]
    fn test_approximate_score() {
        // 2 query tokens x 3 centroids
        let centroid_scores =
            Array2::from_shape_vec((2, 3), vec![0.1, 0.7, 0.3, 0.9, 0.2, 0.4]).unwrap();
        // Document tokens assigned to centroids 1 and 2
        let codes = ndarray::Array1::from_vec(vec![1usize, 2]);

        let score = approximate_score(&centroid_scores, &codes.view());
        // q0 max: max(0.7, 0.3) = 0.7, q1 max: max(0.2, 0.4) = 0.4
        // Total: 0.7 + 0.4 = 1.1
        assert!((score - 1.1).abs() < 1e-5);
    }

    /// Every default, including the serde-backed centroid batch size, is pinned.
    #[test]
    fn test_search_params_default() {
        let params = SearchParameters::default();
        assert_eq!(params.batch_size, 2000);
        assert_eq!(params.n_full_scores, 4096);
        assert_eq!(params.top_k, 10);
        assert_eq!(params.n_ivf_probe, 8);
        assert_eq!(params.centroid_batch_size, 100_000);
        assert_eq!(params.centroid_batch_size, default_centroid_batch_size());
    }
}