// next_plaid/search.rs
1//! Search functionality for PLAID
2
3use std::cmp::Reverse;
4use std::collections::{BinaryHeap, HashMap, HashSet};
5
6use ndarray::Array1;
7use ndarray::{Array2, ArrayView2};
8use rayon::prelude::*;
9use serde::{Deserialize, Serialize};
10
11use crate::codec::CentroidStore;
12use crate::error::Result;
13use crate::maxsim;
14
/// Partial result produced by one centroid batch during IVF probing:
/// - element 0: one min-heap per query token holding that batch's top-k
///   `(Reverse(score), centroid_id)` entries;
/// - element 1: per-centroid maximum score observed while entries were
///   admitted to the heaps (used later for threshold pruning).
type ProbePartial = (
    Vec<BinaryHeap<(Reverse<OrdF32>, usize)>>,
    HashMap<usize, f32>,
);
20
/// Maximum number of documents to decompress concurrently during exact scoring.
/// This limits peak memory usage from parallel decompression.
/// With 128 docs × ~300KB per doc = ~40MB max concurrent decompression memory.
/// (Used as the `par_chunks` size in both exact re-ranking paths below.)
const DECOMPRESS_CHUNK_SIZE: usize = 128;
25
/// Search parameters controlling the recall / latency / memory trade-offs
/// of a single search (IVF probing, approximate ranking, and exact re-ranking).
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SearchParameters {
    /// Number of queries per batch
    pub batch_size: usize,
    /// Number of documents to re-rank with exact scores
    /// (only the top quarter, at least `top_k`, is actually decompressed).
    pub n_full_scores: usize,
    /// Number of final results to return per query
    pub top_k: usize,
    /// Number of IVF cells to probe during search
    pub n_ivf_probe: usize,
    /// Batch size for centroid scoring during IVF probing (0 = exhaustive).
    /// Lower values use less memory but are slower. Default 100_000.
    /// Only used when num_centroids > centroid_batch_size.
    #[serde(default = "default_centroid_batch_size")]
    pub centroid_batch_size: usize,
    /// Centroid score threshold (t_cs) for centroid pruning.
    /// A centroid is only included if its maximum score across all query tokens
    /// meets or exceeds this threshold. Set to None to disable pruning.
    /// Default: Some(0.4)
    #[serde(default = "default_centroid_score_threshold")]
    pub centroid_score_threshold: Option<f32>,
}
49
/// Serde default for [`SearchParameters::centroid_batch_size`].
fn default_centroid_batch_size() -> usize {
    100_000
}
53
/// Serde default for [`SearchParameters::centroid_score_threshold`].
fn default_centroid_score_threshold() -> Option<f32> {
    Some(0.4)
}
57
58impl Default for SearchParameters {
59    fn default() -> Self {
60        Self {
61            batch_size: 2000,
62            n_full_scores: 4096,
63            top_k: 10,
64            n_ivf_probe: 8,
65            centroid_batch_size: default_centroid_batch_size(),
66            centroid_score_threshold: default_centroid_score_threshold(),
67        }
68    }
69}
70
/// Result of a single query.
///
/// `passage_ids` and `scores` are parallel vectors, sorted by descending
/// relevance; both have at most `SearchParameters::top_k` entries.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct QueryResult {
    /// Query ID (assigned by the caller; single-query search sets it to 0)
    pub query_id: usize,
    /// Retrieved document IDs (ranked by relevance)
    pub passage_ids: Vec<i64>,
    /// Relevance scores for each document
    pub scores: Vec<f32>,
}
81
/// Minimum matrix size (query_tokens * doc_tokens) to use CUDA.
/// Below this threshold, CPU is faster due to GPU transfer overhead.
/// Based on benchmarks: 128 * 1024 = 131072
#[cfg(feature = "cuda")]
const CUDA_COLBERT_MIN_SIZE: usize = 128 * 1024;
87
88/// ColBERT-style MaxSim scoring: for each query token, find the max similarity
89/// with any document token, then sum across query tokens.
90///
91/// When the `cuda` feature is enabled and matrices are large enough,
92/// this function automatically uses CUDA acceleration.
93fn colbert_score(query: &ArrayView2<f32>, doc: &ArrayView2<f32>) -> f32 {
94    // Try CUDA for large matrices
95    #[cfg(feature = "cuda")]
96    {
97        let matrix_size = query.nrows() * doc.nrows();
98        if matrix_size >= CUDA_COLBERT_MIN_SIZE {
99            if let Some(ctx) = crate::cuda::get_global_context() {
100                match crate::cuda::colbert_score_cuda(ctx, query, doc) {
101                    Ok(score) => return score,
102                    Err(_) => {
103                        // Silent fallback to CPU for scoring (happens frequently)
104                    }
105                }
106            }
107        }
108    }
109
110    // CPU implementation
111    colbert_score_cpu(query, doc)
112}
113
/// CPU implementation of ColBERT MaxSim scoring.
/// Uses SIMD-accelerated max reduction and BLAS for matrix multiplication
/// (delegates entirely to `crate::maxsim`).
fn colbert_score_cpu(query: &ArrayView2<f32>, doc: &ArrayView2<f32>) -> f32 {
    maxsim::maxsim_score(query, doc)
}
119
/// Total-order wrapper for `f32` so scores can be stored in a `BinaryHeap`.
///
/// Uses `f32::total_cmp` (IEEE-754 totalOrder): -NaN < -inf < … < +inf < +NaN.
/// The previous implementation mapped incomparable (NaN) pairs to `Equal`,
/// which violates the `Ord` contract (transitivity breaks: `a < b` while both
/// compare `Equal` to NaN) and can silently corrupt heap ordering. For all
/// non-NaN scores the ordering is unchanged.
#[derive(Clone, Copy)]
struct OrdF32(f32);

impl PartialEq for OrdF32 {
    fn eq(&self, other: &Self) -> bool {
        // Keep equality consistent with `Ord` as the Eq/Ord contract requires
        // (the derived impl would make NaN != NaN while cmp says Equal).
        self.cmp(other) == std::cmp::Ordering::Equal
    }
}

impl Eq for OrdF32 {}

impl PartialOrd for OrdF32 {
    fn partial_cmp(&self, other: &Self) -> Option<std::cmp::Ordering> {
        Some(self.cmp(other))
    }
}

impl Ord for OrdF32 {
    fn cmp(&self, other: &Self) -> std::cmp::Ordering {
        // `total_cmp` is a true total order over all f32 bit patterns.
        self.0.total_cmp(&other.0)
    }
}
139
/// Batched IVF probing for memory-efficient centroid scoring.
///
/// Processes centroids in chunks of `batch_size`, keeping only the top
/// `n_probe` scores per query token in a min-heap, merges the per-batch
/// results, and returns the union of top centroids across all query tokens.
/// If `centroid_score_threshold` is provided, centroids whose recorded
/// maximum score is below the threshold are filtered out.
///
/// NOTE(review): `max_scores` only records a centroid's score at the moment
/// an entry is admitted to some per-token heap; a higher score on a token
/// where the centroid missed that token's heap is not recorded. Threshold
/// pruning here can therefore be slightly stricter than the exhaustive path
/// in `search_one_mmap` — confirm this approximation is intended.
fn ivf_probe_batched(
    query: &Array2<f32>,
    centroids: &CentroidStore,
    n_probe: usize,
    batch_size: usize,
    centroid_score_threshold: Option<f32>,
) -> Vec<usize> {
    let num_centroids = centroids.nrows();
    let num_tokens = query.nrows();

    // Build batch ranges for parallel processing
    let batch_ranges: Vec<(usize, usize)> = (0..num_centroids)
        .step_by(batch_size)
        .map(|start| (start, (start + batch_size).min(num_centroids)))
        .collect();

    // Process centroid batches in parallel. Each rayon thread computes a GEMM
    // (with single-threaded BLAS via OPENBLAS_NUM_THREADS=1) and maintains local
    // per-token top-k heaps. Memory is bounded: rayon's thread pool ensures at most
    // num_cpus batch_scores matrices (each batch_size × num_tokens × 4 bytes) exist
    // simultaneously, same as the sequential approach where num_cpus queries each
    // process one batch at a time.
    let local_results: Vec<ProbePartial> = batch_ranges
        .par_iter()
        .map(|&(batch_start, batch_end)| {
            // `Reverse` turns the max-heap into a min-heap, so `peek()` yields
            // the current k-th best score for cheap eviction decisions.
            let mut heaps: Vec<BinaryHeap<(Reverse<OrdF32>, usize)>> = (0..num_tokens)
                .map(|_| BinaryHeap::with_capacity(n_probe + 1))
                .collect();
            let mut max_scores: HashMap<usize, f32> = HashMap::new();

            // Get batch view (zero-copy from mmap)
            let batch_centroids = centroids.slice_rows(batch_start, batch_end);

            // Compute scores: [num_tokens, batch_size] — single-threaded BLAS
            let batch_scores = query.dot(&batch_centroids.t());

            // Update local heaps with this batch's scores
            for (q_idx, heap) in heaps.iter_mut().enumerate() {
                for (local_c, &score) in batch_scores.row(q_idx).iter().enumerate() {
                    let global_c = batch_start + local_c;
                    let entry = (Reverse(OrdF32(score)), global_c);

                    if heap.len() < n_probe {
                        // Heap not full yet: always admit.
                        heap.push(entry);
                        max_scores
                            .entry(global_c)
                            .and_modify(|s| *s = s.max(score))
                            .or_insert(score);
                    } else if let Some(&(Reverse(OrdF32(min_score)), _)) = heap.peek() {
                        // Heap full: replace the current minimum if beaten.
                        if score > min_score {
                            heap.pop();
                            heap.push(entry);
                            max_scores
                                .entry(global_c)
                                .and_modify(|s| *s = s.max(score))
                                .or_insert(score);
                        }
                    }
                }
            }

            (heaps, max_scores)
        })
        .collect();

    // Merge local heaps into final result (lightweight: each heap has at most
    // n_probe entries, and there are num_batches heaps per token to merge)
    let mut final_heaps: Vec<BinaryHeap<(Reverse<OrdF32>, usize)>> = (0..num_tokens)
        .map(|_| BinaryHeap::with_capacity(n_probe + 1))
        .collect();
    let mut final_max_scores: HashMap<usize, f32> = HashMap::new();

    for (local_heaps, local_max_scores) in local_results {
        for (q_idx, local_heap) in local_heaps.into_iter().enumerate() {
            // Same admit-or-replace logic as above, applied to merged entries.
            for entry in local_heap {
                let (Reverse(OrdF32(score)), _) = entry;
                if final_heaps[q_idx].len() < n_probe {
                    final_heaps[q_idx].push(entry);
                } else if let Some(&(Reverse(OrdF32(min_score)), _)) = final_heaps[q_idx].peek() {
                    if score > min_score {
                        final_heaps[q_idx].pop();
                        final_heaps[q_idx].push(entry);
                    }
                }
            }
        }
        // Merge per-centroid maxima across batches.
        for (c, score) in local_max_scores {
            final_max_scores
                .entry(c)
                .and_modify(|s| *s = s.max(score))
                .or_insert(score);
        }
    }

    // Union top centroids across all query tokens
    let mut selected: HashSet<usize> = HashSet::new();
    for heap in final_heaps {
        for (_, c) in heap {
            selected.insert(c);
        }
    }

    // Apply centroid score threshold if set
    if let Some(threshold) = centroid_score_threshold {
        selected.retain(|c| {
            final_max_scores
                .get(c)
                .copied()
                .unwrap_or(f32::NEG_INFINITY)
                >= threshold
        });
    }

    selected.into_iter().collect()
}
260
261/// Build sparse centroid scores for a set of centroid IDs.
262///
263/// Returns a HashMap mapping centroid_id -> query scores array.
264fn build_sparse_centroid_scores(
265    query: &Array2<f32>,
266    centroids: &CentroidStore,
267    centroid_ids: &HashSet<usize>,
268) -> HashMap<usize, Array1<f32>> {
269    centroid_ids
270        .iter()
271        .map(|&c| {
272            let centroid = centroids.row(c);
273            let scores: Array1<f32> = query.dot(&centroid);
274            (c, scores)
275        })
276        .collect()
277}
278
279/// Compute approximate scores using sparse centroid score lookup.
280fn approximate_score_sparse(
281    sparse_scores: &HashMap<usize, Array1<f32>>,
282    doc_codes: &[usize],
283    num_query_tokens: usize,
284) -> f32 {
285    let mut score = 0.0;
286
287    // For each query token
288    for q_idx in 0..num_query_tokens {
289        let mut max_score = f32::NEG_INFINITY;
290
291        // For each document token's code
292        for &code in doc_codes.iter() {
293            if let Some(centroid_scores) = sparse_scores.get(&code) {
294                let centroid_score = centroid_scores[q_idx];
295                if centroid_score > max_score {
296                    max_score = centroid_score;
297                }
298            }
299        }
300
301        if max_score > f32::NEG_INFINITY {
302            score += max_score;
303        }
304    }
305
306    score
307}
308
309/// Compute approximate scores for mmap index using code lookups.
310fn approximate_score_mmap(query_centroid_scores: &Array2<f32>, doc_codes: &[i64]) -> f32 {
311    let mut score = 0.0;
312
313    for q_idx in 0..query_centroid_scores.nrows() {
314        let mut max_score = f32::NEG_INFINITY;
315
316        for &code in doc_codes.iter() {
317            let centroid_score = query_centroid_scores[[q_idx, code as usize]];
318            if centroid_score > max_score {
319                max_score = centroid_score;
320            }
321        }
322
323        if max_score > f32::NEG_INFINITY {
324            score += max_score;
325        }
326    }
327
328    score
329}
330
/// Search a memory-mapped index for a single query.
///
/// Pipeline: IVF probing (select centroids to examine) → candidate document
/// gathering → approximate scoring from centroid codes → exact re-ranking on
/// decompressed embeddings → top-k selection.
///
/// `subset`, when provided, restricts candidates to the given document ids.
/// The returned `query_id` is always 0; batch callers overwrite it.
///
/// Delegates to `search_one_mmap_batched` when the centroid count exceeds
/// `params.centroid_batch_size` (and batching is enabled).
pub fn search_one_mmap(
    index: &crate::index::MmapIndex,
    query: &Array2<f32>,
    params: &SearchParameters,
    subset: Option<&[i64]>,
) -> Result<QueryResult> {
    let num_centroids = index.codec.num_centroids();
    let num_query_tokens = query.nrows();

    // Decide whether to use batched mode for memory efficiency
    let use_batched = params.centroid_batch_size > 0 && num_centroids > params.centroid_batch_size;

    if use_batched {
        // Batched path: memory-efficient IVF probing for large centroid counts
        return search_one_mmap_batched(index, query, params, subset);
    }

    // Standard path: compute full query-centroid scores upfront
    // (dense [num_query_tokens, num_centroids] matrix).
    let query_centroid_scores = query.dot(&index.codec.centroids_view().t());

    // When subset is provided, pre-compute eligible centroids: only those containing
    // at least one embedding from a subset document. Centroids without subset docs
    // can't contribute candidates, so skipping them is a pure optimization.
    let eligible_centroids: Option<HashSet<usize>> = subset.map(|subset_docs| {
        let mut centroids = HashSet::new();
        for &doc_id in subset_docs {
            let doc_idx = doc_id as usize;
            // Out-of-range ids in `subset` are silently skipped.
            if doc_idx < index.doc_lengths.len() {
                let start = index.doc_offsets[doc_idx];
                let end = index.doc_offsets[doc_idx + 1];
                let codes = index.mmap_codes.slice(start, end);
                for &c in codes.iter() {
                    centroids.insert(c as usize);
                }
            }
        }
        centroids
    });

    // When pre-filtering, scale n_ivf_probe by the document ratio to compensate
    // for candidates lost to filtering. If 50% of docs are filtered out, we probe
    // ~2x more centroids to find enough relevant candidates.
    // No filter: n_ivf_probe unchanged.
    let effective_n_ivf_probe = match (&eligible_centroids, subset) {
        (Some(eligible), Some(subset_docs)) if !eligible.is_empty() => {
            let num_docs = index.doc_lengths.len();
            let subset_len = subset_docs.len();
            // u64 arithmetic avoids overflow of n_ivf_probe * num_docs.
            let scaled = if subset_len > 0 {
                (params.n_ivf_probe as u64 * num_docs as u64 / subset_len as u64) as usize
            } else {
                params.n_ivf_probe
            };
            scaled.max(params.n_ivf_probe).min(eligible.len())
        }
        _ => params.n_ivf_probe,
    };

    // Find top IVF cells to probe using per-token top-k selection.
    // When pre-filtering, only score eligible centroids (same selection logic,
    // smaller pool). This can only improve recall for subset docs since
    // ineligible centroids would have wasted probe slots.
    let cells_to_probe: Vec<usize> = {
        let mut selected_centroids = HashSet::new();

        for q_idx in 0..num_query_tokens {
            let mut centroid_scores: Vec<(usize, f32)> = match &eligible_centroids {
                Some(eligible) => eligible
                    .iter()
                    .map(|&c| (c, query_centroid_scores[[q_idx, c]]))
                    .collect(),
                None => (0..num_centroids)
                    .map(|c| (c, query_centroid_scores[[q_idx, c]]))
                    .collect(),
            };

            // Partial selection: O(K) average instead of O(K log K) for full sort
            // After this, the top n elements are in positions 0..n
            // (but not sorted among themselves - which is fine since we use a HashSet)
            let n_probe = effective_n_ivf_probe.min(centroid_scores.len());
            if centroid_scores.len() > n_probe {
                centroid_scores.select_nth_unstable_by(n_probe - 1, |a, b| {
                    b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal)
                });
            }

            for (c, _) in centroid_scores.iter().take(n_probe) {
                selected_centroids.insert(*c);
            }
        }

        // Apply centroid score threshold: filter out centroids where max score < threshold
        if let Some(threshold) = params.centroid_score_threshold {
            selected_centroids.retain(|&c| {
                // NOTE(review): this `unwrap` panics if a score is NaN; the
                // sorts below share the same assumption of NaN-free scores.
                let max_score: f32 = (0..num_query_tokens)
                    .map(|q_idx| query_centroid_scores[[q_idx, c]])
                    .max_by(|a, b| a.partial_cmp(b).unwrap())
                    .unwrap_or(f32::NEG_INFINITY);
                max_score >= threshold
            });
        }

        selected_centroids.into_iter().collect()
    };

    // Get candidate documents from IVF
    let mut candidates = index.get_candidates(&cells_to_probe);

    // Filter by subset if provided
    if let Some(subset_docs) = subset {
        let subset_set: HashSet<i64> = subset_docs.iter().copied().collect();
        candidates.retain(|&c| subset_set.contains(&c));
    }

    if candidates.is_empty() {
        return Ok(QueryResult {
            query_id: 0,
            passage_ids: vec![],
            scores: vec![],
        });
    }

    // Compute approximate scores (cheap: centroid-code lookups only)
    let mut approx_scores: Vec<(i64, f32)> = candidates
        .par_iter()
        .map(|&doc_id| {
            let start = index.doc_offsets[doc_id as usize];
            let end = index.doc_offsets[doc_id as usize + 1];
            let codes = index.mmap_codes.slice(start, end);
            let score = approximate_score_mmap(&query_centroid_scores, &codes);
            (doc_id, score)
        })
        .collect();

    // Sort by approximate score and take top candidates
    approx_scores.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap());
    let top_candidates: Vec<i64> = approx_scores
        .iter()
        .take(params.n_full_scores)
        .map(|(id, _)| *id)
        .collect();

    // Further reduce for full decompression
    let n_decompress = (params.n_full_scores / 4).max(params.top_k);
    let to_decompress: Vec<i64> = top_candidates.into_iter().take(n_decompress).collect();

    if to_decompress.is_empty() {
        return Ok(QueryResult {
            query_id: 0,
            passage_ids: vec![],
            scores: vec![],
        });
    }

    // Compute exact scores with decompressed embeddings
    // Use chunked processing to limit concurrent memory from parallel decompression
    let mut exact_scores: Vec<(i64, f32)> = to_decompress
        .par_chunks(DECOMPRESS_CHUNK_SIZE)
        .flat_map(|chunk| {
            chunk
                .iter()
                .filter_map(|&doc_id| {
                    // Docs that fail to decompress are dropped from results.
                    let doc_embeddings = index.get_document_embeddings(doc_id as usize).ok()?;
                    let score = colbert_score(&query.view(), &doc_embeddings.view());
                    Some((doc_id, score))
                })
                .collect::<Vec<_>>()
        })
        .collect();

    // Sort by exact score
    exact_scores.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap());

    // Return top-k results
    let result_count = params.top_k.min(exact_scores.len());
    let passage_ids: Vec<i64> = exact_scores
        .iter()
        .take(result_count)
        .map(|(id, _)| *id)
        .collect();
    let scores: Vec<f32> = exact_scores
        .iter()
        .take(result_count)
        .map(|(_, s)| *s)
        .collect();

    Ok(QueryResult {
        query_id: 0,
        passage_ids,
        scores,
    })
}
523
/// Memory-efficient batched search for MmapIndex with large centroid counts.
///
/// Uses batched IVF probing and sparse centroid scoring to minimize memory
/// usage: instead of materializing the full `[tokens × centroids]` score
/// matrix, it scores centroids in chunks and then only keeps scores for the
/// centroids actually referenced by candidate documents.
///
/// Steps 2 onward (candidate gathering, approximate ranking, exact
/// re-ranking, top-k) mirror `search_one_mmap`; the returned `query_id`
/// is always 0 and is overwritten by batch callers.
fn search_one_mmap_batched(
    index: &crate::index::MmapIndex,
    query: &Array2<f32>,
    params: &SearchParameters,
    subset: Option<&[i64]>,
) -> Result<QueryResult> {
    let num_query_tokens = query.nrows();

    // Step 1: Batched IVF probing
    let cells_to_probe = ivf_probe_batched(
        query,
        &index.codec.centroids,
        params.n_ivf_probe,
        params.centroid_batch_size,
        params.centroid_score_threshold,
    );

    // Step 2: Get candidate documents from IVF
    let mut candidates = index.get_candidates(&cells_to_probe);

    // Filter by subset if provided
    if let Some(subset_docs) = subset {
        let subset_set: HashSet<i64> = subset_docs.iter().copied().collect();
        candidates.retain(|&c| subset_set.contains(&c));
    }

    if candidates.is_empty() {
        return Ok(QueryResult {
            query_id: 0,
            passage_ids: vec![],
            scores: vec![],
        });
    }

    // Step 3: Collect unique centroids from all candidate documents
    // (these are the only centroids whose scores we need to materialize)
    let mut unique_centroids: HashSet<usize> = HashSet::new();
    for &doc_id in &candidates {
        let start = index.doc_offsets[doc_id as usize];
        let end = index.doc_offsets[doc_id as usize + 1];
        let codes = index.mmap_codes.slice(start, end);
        for &code in codes.iter() {
            unique_centroids.insert(code as usize);
        }
    }

    // Step 4: Build sparse centroid scores
    let sparse_scores =
        build_sparse_centroid_scores(query, &index.codec.centroids, &unique_centroids);

    // Step 5: Compute approximate scores using sparse lookup
    let mut approx_scores: Vec<(i64, f32)> = candidates
        .par_iter()
        .map(|&doc_id| {
            let start = index.doc_offsets[doc_id as usize];
            let end = index.doc_offsets[doc_id as usize + 1];
            let codes = index.mmap_codes.slice(start, end);
            let doc_codes: Vec<usize> = codes.iter().map(|&c| c as usize).collect();
            let score = approximate_score_sparse(&sparse_scores, &doc_codes, num_query_tokens);
            (doc_id, score)
        })
        .collect();

    // Sort by approximate score and take top candidates
    // NOTE(review): `unwrap` panics on NaN scores, as in `search_one_mmap`.
    approx_scores.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap());
    let top_candidates: Vec<i64> = approx_scores
        .iter()
        .take(params.n_full_scores)
        .map(|(id, _)| *id)
        .collect();

    // Further reduce for full decompression
    let n_decompress = (params.n_full_scores / 4).max(params.top_k);
    let to_decompress: Vec<i64> = top_candidates.into_iter().take(n_decompress).collect();

    if to_decompress.is_empty() {
        return Ok(QueryResult {
            query_id: 0,
            passage_ids: vec![],
            scores: vec![],
        });
    }

    // Compute exact scores with decompressed embeddings
    // Use chunked processing to limit concurrent memory from parallel decompression
    let mut exact_scores: Vec<(i64, f32)> = to_decompress
        .par_chunks(DECOMPRESS_CHUNK_SIZE)
        .flat_map(|chunk| {
            chunk
                .iter()
                .filter_map(|&doc_id| {
                    // Docs that fail to decompress are dropped from results.
                    let doc_embeddings = index.get_document_embeddings(doc_id as usize).ok()?;
                    let score = colbert_score(&query.view(), &doc_embeddings.view());
                    Some((doc_id, score))
                })
                .collect::<Vec<_>>()
        })
        .collect();

    // Sort by exact score
    exact_scores.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap());

    // Return top-k results
    let result_count = params.top_k.min(exact_scores.len());
    let passage_ids: Vec<i64> = exact_scores
        .iter()
        .take(result_count)
        .map(|(id, _)| *id)
        .collect();
    let scores: Vec<f32> = exact_scores
        .iter()
        .take(result_count)
        .map(|(_, s)| *s)
        .collect();

    Ok(QueryResult {
        query_id: 0,
        passage_ids,
        scores,
    })
}
647
648/// Search a memory-mapped index for multiple queries.
649pub fn search_many_mmap(
650    index: &crate::index::MmapIndex,
651    queries: &[Array2<f32>],
652    params: &SearchParameters,
653    parallel: bool,
654    subset: Option<&[i64]>,
655) -> Result<Vec<QueryResult>> {
656    if parallel {
657        let results: Vec<QueryResult> = queries
658            .par_iter()
659            .enumerate()
660            .map(|(i, query)| {
661                let mut result =
662                    search_one_mmap(index, query, params, subset).unwrap_or_else(|_| QueryResult {
663                        query_id: i,
664                        passage_ids: vec![],
665                        scores: vec![],
666                    });
667                result.query_id = i;
668                result
669            })
670            .collect();
671        Ok(results)
672    } else {
673        let mut results = Vec::with_capacity(queries.len());
674        for (i, query) in queries.iter().enumerate() {
675            let mut result = search_one_mmap(index, query, params, subset)?;
676            result.query_id = i;
677            results.push(result);
678        }
679        Ok(results)
680    }
681}
682
/// Alias type for search result (kept for API compatibility with callers
/// that import `SearchResult` instead of `QueryResult`).
pub type SearchResult = QueryResult;
685
#[cfg(test)]
mod tests {
    use super::*;

    /// MaxSim over a tiny hand-computed example: each query token takes the
    /// best-matching document token, and the per-token maxima are summed.
    #[test]
    fn test_colbert_score() {
        // Query with 2 tokens, dim 4
        let query =
            Array2::from_shape_vec((2, 4), vec![1.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0]).unwrap();

        // Document with 3 tokens
        let doc = Array2::from_shape_vec(
            (3, 4),
            vec![
                0.5, 0.5, 0.0, 0.0, // sim with q0: 0.5, sim with q1: 0.5
                0.8, 0.2, 0.0, 0.0, // sim with q0: 0.8, sim with q1: 0.2
                0.0, 0.9, 0.1, 0.0, // sim with q0: 0.0, sim with q1: 0.9
            ],
        )
        .unwrap();

        let score = colbert_score(&query.view(), &doc.view());
        // q0 max: 0.8 (from token 1), q1 max: 0.9 (from token 2)
        // Total: 0.8 + 0.9 = 1.7
        assert!((score - 1.7).abs() < 1e-5);
    }

    /// Defaults must stay in sync with the serde default helpers so that a
    /// config deserialized with missing fields matches `Default::default()`.
    #[test]
    fn test_search_params_default() {
        let params = SearchParameters::default();
        assert_eq!(params.batch_size, 2000);
        assert_eq!(params.n_full_scores, 4096);
        assert_eq!(params.top_k, 10);
        assert_eq!(params.n_ivf_probe, 8);
        assert_eq!(params.centroid_score_threshold, Some(0.4));
    }
}