// next_plaid/search.rs
1//! Search functionality for PLAID
2
3use std::cmp::Reverse;
4use std::collections::{BinaryHeap, HashMap, HashSet};
5
6use ndarray::Array1;
7use ndarray::{Array2, ArrayView2, Axis};
8use rayon::prelude::*;
9use serde::{Deserialize, Serialize};
10
11use crate::codec::CentroidStore;
12use crate::error::Result;
13use crate::maxsim;
14
/// Maximum number of documents to decompress concurrently during exact scoring.
/// This limits peak memory usage from parallel decompression: candidates are
/// processed via `par_chunks(DECOMPRESS_CHUNK_SIZE)` in the search functions.
/// With 128 docs × ~300KB per doc = ~40MB max concurrent decompression memory.
const DECOMPRESS_CHUNK_SIZE: usize = 128;
19
/// Search parameters controlling the PLAID retrieval pipeline.
///
/// See [`Default`] for the standard settings; `Deserialize` fills missing
/// batch-size/threshold fields from the `default_*` helper functions.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SearchParameters {
    /// Number of queries per batch
    pub batch_size: usize,
    /// Number of documents to re-rank with exact scores
    pub n_full_scores: usize,
    /// Number of final results to return per query
    pub top_k: usize,
    /// Number of IVF cells to probe during search
    pub n_ivf_probe: usize,
    /// Batch size for centroid scoring during IVF probing (0 = exhaustive).
    /// Lower values use less memory but are slower. Default 100_000.
    /// Only used when num_centroids > centroid_batch_size.
    #[serde(default = "default_centroid_batch_size")]
    pub centroid_batch_size: usize,
    /// Centroid score threshold (t_cs) for centroid pruning.
    /// A centroid is only included if its maximum score across all query tokens
    /// meets or exceeds this threshold. Set to None to disable pruning.
    /// Default: Some(0.4)
    #[serde(default = "default_centroid_score_threshold")]
    pub centroid_score_threshold: Option<f32>,
}
43
/// Serde default for `SearchParameters::centroid_batch_size`: score 100k
/// centroids per batch during IVF probing.
fn default_centroid_batch_size() -> usize {
    100_000
}
47
/// Serde default for `SearchParameters::centroid_score_threshold`: prune
/// centroids whose best query-token score is below 0.4.
fn default_centroid_score_threshold() -> Option<f32> {
    Some(0.4)
}
51
52impl Default for SearchParameters {
53    fn default() -> Self {
54        Self {
55            batch_size: 2000,
56            n_full_scores: 4096,
57            top_k: 10,
58            n_ivf_probe: 8,
59            centroid_batch_size: default_centroid_batch_size(),
60            centroid_score_threshold: default_centroid_score_threshold(),
61        }
62    }
63}
64
/// Result of a single query.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct QueryResult {
    /// Query ID (set by batch search to the query's index in the batch)
    pub query_id: usize,
    /// Retrieved document IDs, ranked by descending relevance
    pub passage_ids: Vec<i64>,
    /// Relevance scores, parallel to `passage_ids`
    pub scores: Vec<f32>,
}
75
/// Minimum matrix size (query_tokens * doc_tokens) to use CUDA.
/// Below this threshold, CPU is faster due to GPU transfer overhead.
/// Based on benchmarks: 128 * 1024 = 131072.
/// Only compiled when the `cuda` feature is enabled.
#[cfg(feature = "cuda")]
const CUDA_COLBERT_MIN_SIZE: usize = 128 * 1024;
81
82/// ColBERT-style MaxSim scoring: for each query token, find the max similarity
83/// with any document token, then sum across query tokens.
84///
85/// When the `cuda` feature is enabled and matrices are large enough,
86/// this function automatically uses CUDA acceleration.
87fn colbert_score(query: &ArrayView2<f32>, doc: &ArrayView2<f32>) -> f32 {
88    // Try CUDA for large matrices
89    #[cfg(feature = "cuda")]
90    {
91        let matrix_size = query.nrows() * doc.nrows();
92        if matrix_size >= CUDA_COLBERT_MIN_SIZE {
93            if let Some(ctx) = crate::cuda::get_global_context() {
94                match crate::cuda::colbert_score_cuda(ctx, query, doc) {
95                    Ok(score) => return score,
96                    Err(_) => {
97                        // Silent fallback to CPU for scoring (happens frequently)
98                    }
99                }
100            }
101        }
102    }
103
104    // CPU implementation
105    colbert_score_cpu(query, doc)
106}
107
108/// CPU implementation of ColBERT MaxSim scoring.
109/// Uses SIMD-accelerated max reduction and BLAS for matrix multiplication.
110fn colbert_score_cpu(query: &ArrayView2<f32>, doc: &ArrayView2<f32>) -> f32 {
111    maxsim::maxsim_score(query, doc)
112}
113
114/// Compute adaptive IVF probe for filtered search on memory-mapped index.
115///
116/// Ensures enough centroids are probed to cover at least `top_k` candidates from the subset.
117/// This function counts how many subset documents are in each centroid, then greedily
118/// selects centroids (by query similarity) until the cumulative document count reaches `top_k`.
119#[allow(clippy::too_many_arguments)]
120fn compute_adaptive_ivf_probe_mmap(
121    query_centroid_scores: &Array2<f32>,
122    mmap_codes: &crate::mmap::MmapNpyArray1I64,
123    doc_offsets: &[usize],
124    num_docs: usize,
125    subset: &[i64],
126    top_k: usize,
127    n_ivf_probe: usize,
128    centroid_score_threshold: Option<f32>,
129) -> Vec<usize> {
130    // Count unique docs per centroid for subset
131    let mut centroid_doc_counts: HashMap<usize, HashSet<i64>> = HashMap::new();
132    for &doc_id in subset {
133        let doc_idx = doc_id as usize;
134        if doc_idx < num_docs {
135            let start = doc_offsets[doc_idx];
136            let end = doc_offsets[doc_idx + 1];
137            let codes = mmap_codes.slice(start, end);
138            for &c in codes.iter() {
139                centroid_doc_counts
140                    .entry(c as usize)
141                    .or_default()
142                    .insert(doc_id);
143            }
144        }
145    }
146
147    if centroid_doc_counts.is_empty() {
148        return vec![];
149    }
150
151    // Score each centroid by max query-centroid similarity
152    let mut scored_centroids: Vec<(usize, f32, usize)> = centroid_doc_counts
153        .into_iter()
154        .map(|(c, docs)| {
155            let max_score: f32 = query_centroid_scores
156                .axis_iter(Axis(0))
157                .map(|q| q[c])
158                .max_by(|a, b| a.partial_cmp(b).unwrap())
159                .unwrap_or(0.0);
160            (c, max_score, docs.len())
161        })
162        .collect();
163
164    // Apply threshold if set
165    if let Some(threshold) = centroid_score_threshold {
166        scored_centroids.retain(|(_, score, _)| *score >= threshold);
167    }
168
169    // Sort by score descending
170    scored_centroids.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap());
171
172    // Greedily select centroids until we cover top_k candidates
173    let mut cumulative_docs = 0;
174    let mut n_probe = 0;
175
176    for (_, _, doc_count) in &scored_centroids {
177        cumulative_docs += doc_count;
178        n_probe += 1;
179        // Stop when we have enough coverage AND met minimum probe requirement
180        if cumulative_docs >= top_k && n_probe >= n_ivf_probe {
181            break;
182        }
183    }
184
185    // Ensure at least n_ivf_probe centroids (unless fewer exist)
186    n_probe = n_probe.max(n_ivf_probe.min(scored_centroids.len()));
187
188    scored_centroids
189        .iter()
190        .take(n_probe)
191        .map(|(c, _, _)| *c)
192        .collect()
193}
194
/// Wrapper for f32 usable in a `BinaryHeap` (implements `Ord`).
///
/// All comparisons are derived from `f32::total_cmp` (IEEE 754 totalOrder),
/// so `PartialEq`, `Eq`, `PartialOrd`, and `Ord` are mutually consistent —
/// including for NaN. The previous version derived `PartialEq` on the raw
/// f32 while `Ord` mapped incomparable values to `Equal`; that made
/// `NaN != NaN` under `Eq` (violating reflexivity) while `Ord` called them
/// equal, an inconsistency that can corrupt ordered-collection invariants.
#[derive(Clone, Copy)]
struct OrdF32(f32);

impl PartialEq for OrdF32 {
    fn eq(&self, other: &Self) -> bool {
        // Defined via the total order so equality agrees with Ord.
        self.cmp(other) == std::cmp::Ordering::Equal
    }
}

impl Eq for OrdF32 {}

impl PartialOrd for OrdF32 {
    fn partial_cmp(&self, other: &Self) -> Option<std::cmp::Ordering> {
        Some(self.cmp(other))
    }
}

impl Ord for OrdF32 {
    fn cmp(&self, other: &Self) -> std::cmp::Ordering {
        // total_cmp is a total order: no panics, NaN handled deterministically.
        self.0.total_cmp(&other.0)
    }
}
214
215/// Batched IVF probing for memory-efficient centroid scoring.
216///
217/// Processes centroids in chunks, keeping only top-k scores per query token in a heap.
218/// Returns the union of top centroids across all query tokens.
219/// If a threshold is provided, filters out centroids where max score < threshold.
220fn ivf_probe_batched(
221    query: &Array2<f32>,
222    centroids: &CentroidStore,
223    n_probe: usize,
224    batch_size: usize,
225    centroid_score_threshold: Option<f32>,
226) -> Vec<usize> {
227    let num_centroids = centroids.nrows();
228    let num_tokens = query.nrows();
229
230    // Min-heap per query token to track top centroids
231    // Entry: (Reverse(score), centroid_id) - Reverse for min-heap behavior
232    let mut heaps: Vec<BinaryHeap<(Reverse<OrdF32>, usize)>> = (0..num_tokens)
233        .map(|_| BinaryHeap::with_capacity(n_probe + 1))
234        .collect();
235
236    // Track max score per centroid for threshold filtering
237    let mut max_scores: HashMap<usize, f32> = HashMap::new();
238
239    for batch_start in (0..num_centroids).step_by(batch_size) {
240        let batch_end = (batch_start + batch_size).min(num_centroids);
241
242        // Get batch view (zero-copy from mmap)
243        let batch_centroids = centroids.slice_rows(batch_start, batch_end);
244
245        // Compute scores: [num_tokens, batch_size]
246        let batch_scores = query.dot(&batch_centroids.t());
247
248        // Update heaps with this batch's scores
249        for (q_idx, heap) in heaps.iter_mut().enumerate() {
250            for (local_c, &score) in batch_scores.row(q_idx).iter().enumerate() {
251                let global_c = batch_start + local_c;
252                let entry = (Reverse(OrdF32(score)), global_c);
253
254                if heap.len() < n_probe {
255                    heap.push(entry);
256                    // Track max score for threshold filtering
257                    max_scores
258                        .entry(global_c)
259                        .and_modify(|s| *s = s.max(score))
260                        .or_insert(score);
261                } else if let Some(&(Reverse(OrdF32(min_score)), _)) = heap.peek() {
262                    if score > min_score {
263                        heap.pop();
264                        heap.push(entry);
265                        // Track max score for threshold filtering
266                        max_scores
267                            .entry(global_c)
268                            .and_modify(|s| *s = s.max(score))
269                            .or_insert(score);
270                    }
271                }
272            }
273        }
274    }
275
276    // Union top centroids across all query tokens
277    let mut selected: HashSet<usize> = HashSet::new();
278    for heap in heaps {
279        for (_, c) in heap {
280            selected.insert(c);
281        }
282    }
283
284    // Apply centroid score threshold if set
285    if let Some(threshold) = centroid_score_threshold {
286        selected.retain(|c| max_scores.get(c).copied().unwrap_or(f32::NEG_INFINITY) >= threshold);
287    }
288
289    selected.into_iter().collect()
290}
291
292/// Build sparse centroid scores for a set of centroid IDs.
293///
294/// Returns a HashMap mapping centroid_id -> query scores array.
295fn build_sparse_centroid_scores(
296    query: &Array2<f32>,
297    centroids: &CentroidStore,
298    centroid_ids: &HashSet<usize>,
299) -> HashMap<usize, Array1<f32>> {
300    centroid_ids
301        .iter()
302        .map(|&c| {
303            let centroid = centroids.row(c);
304            let scores: Array1<f32> = query.dot(&centroid);
305            (c, scores)
306        })
307        .collect()
308}
309
310/// Compute approximate scores using sparse centroid score lookup.
311fn approximate_score_sparse(
312    sparse_scores: &HashMap<usize, Array1<f32>>,
313    doc_codes: &[usize],
314    num_query_tokens: usize,
315) -> f32 {
316    let mut score = 0.0;
317
318    // For each query token
319    for q_idx in 0..num_query_tokens {
320        let mut max_score = f32::NEG_INFINITY;
321
322        // For each document token's code
323        for &code in doc_codes.iter() {
324            if let Some(centroid_scores) = sparse_scores.get(&code) {
325                let centroid_score = centroid_scores[q_idx];
326                if centroid_score > max_score {
327                    max_score = centroid_score;
328                }
329            }
330        }
331
332        if max_score > f32::NEG_INFINITY {
333            score += max_score;
334        }
335    }
336
337    score
338}
339
340/// Compute approximate scores for mmap index using code lookups.
341fn approximate_score_mmap(query_centroid_scores: &Array2<f32>, doc_codes: &[i64]) -> f32 {
342    let mut score = 0.0;
343
344    for q_idx in 0..query_centroid_scores.nrows() {
345        let mut max_score = f32::NEG_INFINITY;
346
347        for &code in doc_codes.iter() {
348            let centroid_score = query_centroid_scores[[q_idx, code as usize]];
349            if centroid_score > max_score {
350                max_score = centroid_score;
351            }
352        }
353
354        if max_score > f32::NEG_INFINITY {
355            score += max_score;
356        }
357    }
358
359    score
360}
361
362/// Search a memory-mapped index for a single query.
363pub fn search_one_mmap(
364    index: &crate::index::MmapIndex,
365    query: &Array2<f32>,
366    params: &SearchParameters,
367    subset: Option<&[i64]>,
368) -> Result<QueryResult> {
369    let num_centroids = index.codec.num_centroids();
370    let num_query_tokens = query.nrows();
371
372    // Decide whether to use batched mode for memory efficiency
373    let use_batched = params.centroid_batch_size > 0
374        && num_centroids > params.centroid_batch_size
375        && subset.is_none();
376
377    if use_batched {
378        // Batched path: memory-efficient IVF probing for large centroid counts
379        return search_one_mmap_batched(index, query, params);
380    }
381
382    // Standard path: compute full query-centroid scores upfront
383    let query_centroid_scores = query.dot(&index.codec.centroids_view().t());
384
385    // Find top IVF cells to probe
386    let cells_to_probe: Vec<usize> = if let Some(subset_docs) = subset {
387        // Use adaptive IVF probing that ensures enough centroids to cover top_k candidates
388        compute_adaptive_ivf_probe_mmap(
389            &query_centroid_scores,
390            &index.mmap_codes,
391            index.doc_offsets.as_slice().unwrap(),
392            index.doc_lengths.len(),
393            subset_docs,
394            params.top_k,
395            params.n_ivf_probe,
396            params.centroid_score_threshold,
397        )
398    } else {
399        // Standard path: select top-k centroids per query token
400        let mut selected_centroids = HashSet::new();
401
402        for q_idx in 0..num_query_tokens {
403            let mut centroid_scores: Vec<(usize, f32)> = (0..num_centroids)
404                .map(|c| (c, query_centroid_scores[[q_idx, c]]))
405                .collect();
406
407            // Partial selection: O(K) average instead of O(K log K) for full sort
408            // After this, the top n_ivf_probe elements are in positions 0..n_ivf_probe
409            // (but not sorted among themselves - which is fine since we use a HashSet)
410            if centroid_scores.len() > params.n_ivf_probe {
411                centroid_scores.select_nth_unstable_by(params.n_ivf_probe - 1, |a, b| {
412                    b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal)
413                });
414            }
415
416            for (c, _) in centroid_scores.iter().take(params.n_ivf_probe) {
417                selected_centroids.insert(*c);
418            }
419        }
420
421        // Apply centroid score threshold: filter out centroids where max score < threshold
422        if let Some(threshold) = params.centroid_score_threshold {
423            selected_centroids.retain(|&c| {
424                let max_score: f32 = (0..num_query_tokens)
425                    .map(|q_idx| query_centroid_scores[[q_idx, c]])
426                    .max_by(|a, b| a.partial_cmp(b).unwrap())
427                    .unwrap_or(f32::NEG_INFINITY);
428                max_score >= threshold
429            });
430        }
431
432        selected_centroids.into_iter().collect()
433    };
434
435    // Get candidate documents from IVF
436    let mut candidates = index.get_candidates(&cells_to_probe);
437
438    // Filter by subset if provided
439    if let Some(subset_docs) = subset {
440        let subset_set: HashSet<i64> = subset_docs.iter().copied().collect();
441        candidates.retain(|&c| subset_set.contains(&c));
442    }
443
444    if candidates.is_empty() {
445        return Ok(QueryResult {
446            query_id: 0,
447            passage_ids: vec![],
448            scores: vec![],
449        });
450    }
451
452    // Compute approximate scores
453    let mut approx_scores: Vec<(i64, f32)> = candidates
454        .par_iter()
455        .map(|&doc_id| {
456            let start = index.doc_offsets[doc_id as usize];
457            let end = index.doc_offsets[doc_id as usize + 1];
458            let codes = index.mmap_codes.slice(start, end);
459            let score = approximate_score_mmap(&query_centroid_scores, &codes);
460            (doc_id, score)
461        })
462        .collect();
463
464    // Sort by approximate score and take top candidates
465    approx_scores.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap());
466    let top_candidates: Vec<i64> = approx_scores
467        .iter()
468        .take(params.n_full_scores)
469        .map(|(id, _)| *id)
470        .collect();
471
472    // Further reduce for full decompression
473    let n_decompress = (params.n_full_scores / 4).max(params.top_k);
474    let to_decompress: Vec<i64> = top_candidates.into_iter().take(n_decompress).collect();
475
476    if to_decompress.is_empty() {
477        return Ok(QueryResult {
478            query_id: 0,
479            passage_ids: vec![],
480            scores: vec![],
481        });
482    }
483
484    // Compute exact scores with decompressed embeddings
485    // Use chunked processing to limit concurrent memory from parallel decompression
486    let mut exact_scores: Vec<(i64, f32)> = to_decompress
487        .par_chunks(DECOMPRESS_CHUNK_SIZE)
488        .flat_map(|chunk| {
489            chunk
490                .iter()
491                .filter_map(|&doc_id| {
492                    let doc_embeddings = index.get_document_embeddings(doc_id as usize).ok()?;
493                    let score = colbert_score(&query.view(), &doc_embeddings.view());
494                    Some((doc_id, score))
495                })
496                .collect::<Vec<_>>()
497        })
498        .collect();
499
500    // Sort by exact score
501    exact_scores.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap());
502
503    // Return top-k results
504    let result_count = params.top_k.min(exact_scores.len());
505    let passage_ids: Vec<i64> = exact_scores
506        .iter()
507        .take(result_count)
508        .map(|(id, _)| *id)
509        .collect();
510    let scores: Vec<f32> = exact_scores
511        .iter()
512        .take(result_count)
513        .map(|(_, s)| *s)
514        .collect();
515
516    Ok(QueryResult {
517        query_id: 0,
518        passage_ids,
519        scores,
520    })
521}
522
523/// Memory-efficient batched search for MmapIndex with large centroid counts.
524///
525/// Uses batched IVF probing and sparse centroid scoring to minimize memory usage.
526fn search_one_mmap_batched(
527    index: &crate::index::MmapIndex,
528    query: &Array2<f32>,
529    params: &SearchParameters,
530) -> Result<QueryResult> {
531    let num_query_tokens = query.nrows();
532
533    // Step 1: Batched IVF probing
534    let cells_to_probe = ivf_probe_batched(
535        query,
536        &index.codec.centroids,
537        params.n_ivf_probe,
538        params.centroid_batch_size,
539        params.centroid_score_threshold,
540    );
541
542    // Step 2: Get candidate documents from IVF
543    let candidates = index.get_candidates(&cells_to_probe);
544
545    if candidates.is_empty() {
546        return Ok(QueryResult {
547            query_id: 0,
548            passage_ids: vec![],
549            scores: vec![],
550        });
551    }
552
553    // Step 3: Collect unique centroids from all candidate documents
554    let mut unique_centroids: HashSet<usize> = HashSet::new();
555    for &doc_id in &candidates {
556        let start = index.doc_offsets[doc_id as usize];
557        let end = index.doc_offsets[doc_id as usize + 1];
558        let codes = index.mmap_codes.slice(start, end);
559        for &code in codes.iter() {
560            unique_centroids.insert(code as usize);
561        }
562    }
563
564    // Step 4: Build sparse centroid scores
565    let sparse_scores =
566        build_sparse_centroid_scores(query, &index.codec.centroids, &unique_centroids);
567
568    // Step 5: Compute approximate scores using sparse lookup
569    let mut approx_scores: Vec<(i64, f32)> = candidates
570        .par_iter()
571        .map(|&doc_id| {
572            let start = index.doc_offsets[doc_id as usize];
573            let end = index.doc_offsets[doc_id as usize + 1];
574            let codes = index.mmap_codes.slice(start, end);
575            let doc_codes: Vec<usize> = codes.iter().map(|&c| c as usize).collect();
576            let score = approximate_score_sparse(&sparse_scores, &doc_codes, num_query_tokens);
577            (doc_id, score)
578        })
579        .collect();
580
581    // Sort by approximate score and take top candidates
582    approx_scores.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap());
583    let top_candidates: Vec<i64> = approx_scores
584        .iter()
585        .take(params.n_full_scores)
586        .map(|(id, _)| *id)
587        .collect();
588
589    // Further reduce for full decompression
590    let n_decompress = (params.n_full_scores / 4).max(params.top_k);
591    let to_decompress: Vec<i64> = top_candidates.into_iter().take(n_decompress).collect();
592
593    if to_decompress.is_empty() {
594        return Ok(QueryResult {
595            query_id: 0,
596            passage_ids: vec![],
597            scores: vec![],
598        });
599    }
600
601    // Compute exact scores with decompressed embeddings
602    // Use chunked processing to limit concurrent memory from parallel decompression
603    let mut exact_scores: Vec<(i64, f32)> = to_decompress
604        .par_chunks(DECOMPRESS_CHUNK_SIZE)
605        .flat_map(|chunk| {
606            chunk
607                .iter()
608                .filter_map(|&doc_id| {
609                    let doc_embeddings = index.get_document_embeddings(doc_id as usize).ok()?;
610                    let score = colbert_score(&query.view(), &doc_embeddings.view());
611                    Some((doc_id, score))
612                })
613                .collect::<Vec<_>>()
614        })
615        .collect();
616
617    // Sort by exact score
618    exact_scores.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap());
619
620    // Return top-k results
621    let result_count = params.top_k.min(exact_scores.len());
622    let passage_ids: Vec<i64> = exact_scores
623        .iter()
624        .take(result_count)
625        .map(|(id, _)| *id)
626        .collect();
627    let scores: Vec<f32> = exact_scores
628        .iter()
629        .take(result_count)
630        .map(|(_, s)| *s)
631        .collect();
632
633    Ok(QueryResult {
634        query_id: 0,
635        passage_ids,
636        scores,
637    })
638}
639
640/// Search a memory-mapped index for multiple queries.
641pub fn search_many_mmap(
642    index: &crate::index::MmapIndex,
643    queries: &[Array2<f32>],
644    params: &SearchParameters,
645    parallel: bool,
646    subset: Option<&[i64]>,
647) -> Result<Vec<QueryResult>> {
648    if parallel {
649        let results: Vec<QueryResult> = queries
650            .par_iter()
651            .enumerate()
652            .map(|(i, query)| {
653                let mut result =
654                    search_one_mmap(index, query, params, subset).unwrap_or_else(|_| QueryResult {
655                        query_id: i,
656                        passage_ids: vec![],
657                        scores: vec![],
658                    });
659                result.query_id = i;
660                result
661            })
662            .collect();
663        Ok(results)
664    } else {
665        let mut results = Vec::with_capacity(queries.len());
666        for (i, query) in queries.iter().enumerate() {
667            let mut result = search_one_mmap(index, query, params, subset)?;
668            result.query_id = i;
669            results.push(result);
670        }
671        Ok(results)
672    }
673}
674
/// Alias for [`QueryResult`], kept for backwards API compatibility.
pub type SearchResult = QueryResult;
677
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_colbert_score() {
        // Two orthogonal unit-vector query tokens in a 4-dim space.
        let query = ndarray::arr2(&[[1.0f32, 0.0, 0.0, 0.0], [0.0, 1.0, 0.0, 0.0]]);

        // Three document tokens; per-token similarities noted inline.
        let doc = ndarray::arr2(&[
            [0.5f32, 0.5, 0.0, 0.0], // sim with q0: 0.5, sim with q1: 0.5
            [0.8, 0.2, 0.0, 0.0],    // sim with q0: 0.8, sim with q1: 0.2
            [0.0, 0.9, 0.1, 0.0],    // sim with q0: 0.0, sim with q1: 0.9
        ]);

        // MaxSim: q0's best match is 0.8, q1's best is 0.9 => 1.7 total.
        let score = colbert_score(&query.view(), &doc.view());
        assert!((score - 1.7).abs() < 1e-5);
    }

    #[test]
    fn test_search_params_default() {
        let SearchParameters {
            batch_size,
            n_full_scores,
            top_k,
            n_ivf_probe,
            centroid_score_threshold,
            ..
        } = SearchParameters::default();
        assert_eq!(batch_size, 2000);
        assert_eq!(n_full_scores, 4096);
        assert_eq!(top_k, 10);
        assert_eq!(n_ivf_probe, 8);
        assert_eq!(centroid_score_threshold, Some(0.4));
    }
}