// next_plaid — search.rs
1//! Search functionality for PLAID
2
3use std::cmp::Reverse;
4use std::collections::{BinaryHeap, HashMap, HashSet};
5
6use ndarray::Array1;
7use ndarray::{Array2, ArrayView2};
8use rayon::prelude::*;
9use serde::{Deserialize, Serialize};
10
11use crate::codec::CentroidStore;
12use crate::error::Result;
13use crate::maxsim;
14
/// Per-token top-k heaps and per-centroid max scores from a batch of centroids.
///
/// - `.0`: one min-heap per query token (the `Reverse` wrapper makes `peek()`
///   return the smallest retained score, so the worst candidate can be evicted
///   cheaply), holding `(score, centroid_id)` pairs capped at `n_probe` entries.
/// - `.1`: centroid_id -> best score observed for that centroid while it was
///   being inserted into any token's heap (used for threshold pruning).
type ProbePartial = (
    Vec<BinaryHeap<(Reverse<OrdF32>, usize)>>,
    HashMap<usize, f32>,
);
20
/// Maximum number of documents to decompress concurrently during exact scoring.
/// This limits peak memory usage from parallel decompression.
/// With 128 docs × ~300KB per doc = ~40MB max concurrent decompression memory.
/// (The ~300KB per-doc figure is an estimate — actual size depends on document
/// length and embedding dimension; confirm against the codec's output size.)
const DECOMPRESS_CHUNK_SIZE: usize = 128;
25
/// Search parameters
///
/// Tunables for the two-stage (approximate then exact) retrieval pipeline.
/// `centroid_batch_size` and `centroid_score_threshold` carry serde defaults
/// so older serialized configs without those fields still deserialize.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SearchParameters {
    /// Number of queries per batch
    /// NOTE(review): not referenced in this file — presumably consumed by
    /// callers that batch queries; verify before removing.
    pub batch_size: usize,
    /// Number of documents to re-rank with exact scores
    pub n_full_scores: usize,
    /// Number of final results to return per query
    pub top_k: usize,
    /// Number of IVF cells to probe during search
    pub n_ivf_probe: usize,
    /// Batch size for centroid scoring during IVF probing (0 = exhaustive).
    /// Lower values use less memory but are slower. Default 100_000.
    /// Only used when num_centroids > centroid_batch_size.
    #[serde(default = "default_centroid_batch_size")]
    pub centroid_batch_size: usize,
    /// Centroid score threshold (t_cs) for centroid pruning.
    /// A centroid is only included if its maximum score across all query tokens
    /// meets or exceeds this threshold. Set to None to disable pruning.
    /// Default: Some(0.4)
    #[serde(default = "default_centroid_score_threshold")]
    pub centroid_score_threshold: Option<f32>,
}
49
/// Serde default for [`SearchParameters::centroid_batch_size`].
fn default_centroid_batch_size() -> usize {
    100_000
}
53
/// Serde default for [`SearchParameters::centroid_score_threshold`].
fn default_centroid_score_threshold() -> Option<f32> {
    Some(0.4)
}
57
58impl Default for SearchParameters {
59    fn default() -> Self {
60        Self {
61            batch_size: 2000,
62            n_full_scores: 4096,
63            top_k: 10,
64            n_ivf_probe: 8,
65            centroid_batch_size: default_centroid_batch_size(),
66            centroid_score_threshold: default_centroid_score_threshold(),
67        }
68    }
69}
70
/// Result of a single query
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct QueryResult {
    /// Query ID (the query's position in its batch; filled in by the caller,
    /// e.g. `search_many_mmap` — single-query search leaves it at 0)
    pub query_id: usize,
    /// Retrieved document IDs (ranked by relevance, best first)
    pub passage_ids: Vec<i64>,
    /// Relevance scores for each document (parallel to `passage_ids`)
    pub scores: Vec<f32>,
}
81
/// ColBERT-style MaxSim scoring: for each query token, find the max similarity
/// with any document token, then sum across query tokens.
///
/// `query` and `doc` are `[num_tokens, dim]` matrices with one embedding per
/// row; similarity is the dot product (assumes L2-normalized embeddings —
/// TODO confirm with the embedding pipeline).
///
/// Always uses the CPU implementation (BLAS GEMM + SIMD max reduction), which
/// benchmarks show is faster than CUDA for per-document scoring due to GPU
/// transfer overhead dominating at typical query/document sizes.
fn colbert_score(query: &ArrayView2<f32>, doc: &ArrayView2<f32>) -> f32 {
    maxsim::maxsim_score(query, doc)
}
91
/// Wrapper for f32 to use with BinaryHeap (implements Ord).
///
/// Uses `f32::total_cmp` (IEEE 754 totalOrder) so the ordering is a genuine
/// total order: NaN sorts above +inf instead of comparing `Equal` to every
/// value, which is what the previous `partial_cmp(..).unwrap_or(Equal)`
/// produced and which violates `Ord`'s transitivity and the `Eq`/`Ord`
/// consistency contract. `PartialEq` is implemented from the same comparison
/// so `a == b` ⇔ `a.cmp(&b) == Equal` (note: -0.0 < 0.0 under total_cmp).
#[derive(Debug, Clone, Copy)]
struct OrdF32(f32);

impl PartialEq for OrdF32 {
    fn eq(&self, other: &Self) -> bool {
        self.0.total_cmp(&other.0) == std::cmp::Ordering::Equal
    }
}

impl Eq for OrdF32 {}

impl PartialOrd for OrdF32 {
    fn partial_cmp(&self, other: &Self) -> Option<std::cmp::Ordering> {
        Some(self.cmp(other))
    }
}

impl Ord for OrdF32 {
    fn cmp(&self, other: &Self) -> std::cmp::Ordering {
        self.0.total_cmp(&other.0)
    }
}
111
/// Batched IVF probing for memory-efficient centroid scoring.
///
/// Processes centroids in chunks, keeping only top-k scores per query token in a heap.
/// Returns the union of top centroids across all query tokens.
/// If a threshold is provided, filters out centroids where max score < threshold.
///
/// Memory is bounded by `batch_size`: each chunk's score matrix is
/// `[num_tokens, batch_size]` and each heap holds at most `n_probe` entries.
fn ivf_probe_batched(
    query: &Array2<f32>,
    centroids: &CentroidStore,
    n_probe: usize,
    batch_size: usize,
    centroid_score_threshold: Option<f32>,
) -> Vec<usize> {
    let num_centroids = centroids.nrows();
    let num_tokens = query.nrows();

    // Build batch ranges for parallel processing
    let batch_ranges: Vec<(usize, usize)> = (0..num_centroids)
        .step_by(batch_size)
        .map(|start| (start, (start + batch_size).min(num_centroids)))
        .collect();

    // Process centroid batches in parallel. Each rayon thread computes a GEMM
    // (with single-threaded BLAS via OPENBLAS_NUM_THREADS=1) and maintains local
    // per-token top-k heaps. Memory is bounded: rayon's thread pool ensures at most
    // num_cpus batch_scores matrices (each batch_size × num_tokens × 4 bytes) exist
    // simultaneously, same as the sequential approach where num_cpus queries each
    // process one batch at a time.
    let local_results: Vec<ProbePartial> = batch_ranges
        .par_iter()
        .map(|&(batch_start, batch_end)| {
            // Min-heaps (via Reverse): peek() is the worst retained candidate.
            let mut heaps: Vec<BinaryHeap<(Reverse<OrdF32>, usize)>> = (0..num_tokens)
                .map(|_| BinaryHeap::with_capacity(n_probe + 1))
                .collect();
            let mut max_scores: HashMap<usize, f32> = HashMap::new();

            // Get batch view (zero-copy from mmap)
            let batch_centroids = centroids.slice_rows(batch_start, batch_end);

            // Compute scores: [num_tokens, batch_size] — single-threaded BLAS
            let batch_scores = query.dot(&batch_centroids.t());

            // Update local heaps with this batch's scores.
            // NOTE(review): `max_scores` is only updated when a score enters a
            // heap, and entries are never removed on eviction — so the recorded
            // value is a lower bound on a centroid's true max over all tokens.
            // The non-batched path computes the exact max; confirm the
            // difference is acceptable for threshold pruning.
            for (q_idx, heap) in heaps.iter_mut().enumerate() {
                for (local_c, &score) in batch_scores.row(q_idx).iter().enumerate() {
                    let global_c = batch_start + local_c;
                    let entry = (Reverse(OrdF32(score)), global_c);

                    if heap.len() < n_probe {
                        heap.push(entry);
                        max_scores
                            .entry(global_c)
                            .and_modify(|s| *s = s.max(score))
                            .or_insert(score);
                    } else if let Some(&(Reverse(OrdF32(min_score)), _)) = heap.peek() {
                        // Heap full: replace the current minimum if this score beats it.
                        if score > min_score {
                            heap.pop();
                            heap.push(entry);
                            max_scores
                                .entry(global_c)
                                .and_modify(|s| *s = s.max(score))
                                .or_insert(score);
                        }
                    }
                }
            }

            (heaps, max_scores)
        })
        .collect();

    // Merge local heaps into final result (lightweight: each heap has at most
    // n_probe entries, and there are num_batches heaps per token to merge)
    let mut final_heaps: Vec<BinaryHeap<(Reverse<OrdF32>, usize)>> = (0..num_tokens)
        .map(|_| BinaryHeap::with_capacity(n_probe + 1))
        .collect();
    let mut final_max_scores: HashMap<usize, f32> = HashMap::new();

    for (local_heaps, local_max_scores) in local_results {
        for (q_idx, local_heap) in local_heaps.into_iter().enumerate() {
            for entry in local_heap {
                let (Reverse(OrdF32(score)), _) = entry;
                if final_heaps[q_idx].len() < n_probe {
                    final_heaps[q_idx].push(entry);
                } else if let Some(&(Reverse(OrdF32(min_score)), _)) = final_heaps[q_idx].peek() {
                    if score > min_score {
                        final_heaps[q_idx].pop();
                        final_heaps[q_idx].push(entry);
                    }
                }
            }
        }
        // Merge per-centroid maxima by taking the larger of the two observations.
        for (c, score) in local_max_scores {
            final_max_scores
                .entry(c)
                .and_modify(|s| *s = s.max(score))
                .or_insert(score);
        }
    }

    // Union top centroids across all query tokens
    let mut selected: HashSet<usize> = HashSet::new();
    for heap in final_heaps {
        for (_, c) in heap {
            selected.insert(c);
        }
    }

    // Apply centroid score threshold if set.
    // Any selected centroid was pushed into some heap, so it has a
    // `final_max_scores` entry; NEG_INFINITY is a defensive fallback.
    if let Some(threshold) = centroid_score_threshold {
        selected.retain(|c| {
            final_max_scores
                .get(c)
                .copied()
                .unwrap_or(f32::NEG_INFINITY)
                >= threshold
        });
    }

    selected.into_iter().collect()
}
232
233/// Build sparse centroid scores for a set of centroid IDs.
234///
235/// Returns a HashMap mapping centroid_id -> query scores array.
236fn build_sparse_centroid_scores(
237    query: &Array2<f32>,
238    centroids: &CentroidStore,
239    centroid_ids: &HashSet<usize>,
240) -> HashMap<usize, Array1<f32>> {
241    centroid_ids
242        .iter()
243        .map(|&c| {
244            let centroid = centroids.row(c);
245            let scores: Array1<f32> = query.dot(&centroid);
246            (c, scores)
247        })
248        .collect()
249}
250
251/// Compute approximate scores using sparse centroid score lookup.
252fn approximate_score_sparse(
253    sparse_scores: &HashMap<usize, Array1<f32>>,
254    doc_codes: &[usize],
255    num_query_tokens: usize,
256) -> f32 {
257    let mut score = 0.0;
258
259    // For each query token
260    for q_idx in 0..num_query_tokens {
261        let mut max_score = f32::NEG_INFINITY;
262
263        // For each document token's code
264        for &code in doc_codes.iter() {
265            if let Some(centroid_scores) = sparse_scores.get(&code) {
266                let centroid_score = centroid_scores[q_idx];
267                if centroid_score > max_score {
268                    max_score = centroid_score;
269                }
270            }
271        }
272
273        if max_score > f32::NEG_INFINITY {
274            score += max_score;
275        }
276    }
277
278    score
279}
280
281/// Compute approximate scores for mmap index using code lookups.
282fn approximate_score_mmap(query_centroid_scores: &Array2<f32>, doc_codes: &[i64]) -> f32 {
283    let mut score = 0.0;
284
285    for q_idx in 0..query_centroid_scores.nrows() {
286        let mut max_score = f32::NEG_INFINITY;
287
288        for &code in doc_codes.iter() {
289            let centroid_score = query_centroid_scores[[q_idx, code as usize]];
290            if centroid_score > max_score {
291                max_score = centroid_score;
292            }
293        }
294
295        if max_score > f32::NEG_INFINITY {
296            score += max_score;
297        }
298    }
299
300    score
301}
302
303/// Search a memory-mapped index for a single query.
304pub fn search_one_mmap(
305    index: &crate::index::MmapIndex,
306    query: &Array2<f32>,
307    params: &SearchParameters,
308    subset: Option<&[i64]>,
309) -> Result<QueryResult> {
310    let num_centroids = index.codec.num_centroids();
311    let num_query_tokens = query.nrows();
312
313    // Decide whether to use batched mode for memory efficiency
314    let use_batched = params.centroid_batch_size > 0 && num_centroids > params.centroid_batch_size;
315
316    if use_batched {
317        // Batched path: memory-efficient IVF probing for large centroid counts
318        return search_one_mmap_batched(index, query, params, subset);
319    }
320
321    // Standard path: compute full query-centroid scores upfront
322    let query_centroid_scores = query.dot(&index.codec.centroids_view().t());
323
324    // When subset is provided, pre-compute eligible centroids: only those containing
325    // at least one embedding from a subset document. Centroids without subset docs
326    // can't contribute candidates, so skipping them is a pure optimization.
327    let eligible_centroids: Option<HashSet<usize>> = subset.map(|subset_docs| {
328        let mut centroids = HashSet::new();
329        for &doc_id in subset_docs {
330            let doc_idx = doc_id as usize;
331            if doc_idx < index.doc_lengths.len() {
332                let start = index.doc_offsets[doc_idx];
333                let end = index.doc_offsets[doc_idx + 1];
334                let codes = index.mmap_codes.slice(start, end);
335                for &c in codes.iter() {
336                    centroids.insert(c as usize);
337                }
338            }
339        }
340        centroids
341    });
342
343    // When pre-filtering, scale n_ivf_probe by the document ratio to compensate
344    // for candidates lost to filtering. If 50% of docs are filtered out, we probe
345    // ~2x more centroids to find enough relevant candidates.
346    // No filter: n_ivf_probe unchanged.
347    let effective_n_ivf_probe = match (&eligible_centroids, subset) {
348        (Some(eligible), Some(subset_docs)) if !eligible.is_empty() => {
349            let num_docs = index.doc_lengths.len();
350            let subset_len = subset_docs.len();
351            let scaled = if subset_len > 0 {
352                (params.n_ivf_probe as u64 * num_docs as u64 / subset_len as u64) as usize
353            } else {
354                params.n_ivf_probe
355            };
356            scaled.max(params.n_ivf_probe).min(eligible.len())
357        }
358        _ => params.n_ivf_probe,
359    };
360
361    // Find top IVF cells to probe using per-token top-k selection.
362    // When pre-filtering, only score eligible centroids (same selection logic,
363    // smaller pool). This can only improve recall for subset docs since
364    // ineligible centroids would have wasted probe slots.
365    let cells_to_probe: Vec<usize> = {
366        let mut selected_centroids = HashSet::new();
367
368        for q_idx in 0..num_query_tokens {
369            let mut centroid_scores: Vec<(usize, f32)> = match &eligible_centroids {
370                Some(eligible) => eligible
371                    .iter()
372                    .map(|&c| (c, query_centroid_scores[[q_idx, c]]))
373                    .collect(),
374                None => (0..num_centroids)
375                    .map(|c| (c, query_centroid_scores[[q_idx, c]]))
376                    .collect(),
377            };
378
379            // Partial selection: O(K) average instead of O(K log K) for full sort
380            // After this, the top n elements are in positions 0..n
381            // (but not sorted among themselves - which is fine since we use a HashSet)
382            let n_probe = effective_n_ivf_probe.min(centroid_scores.len());
383            if centroid_scores.len() > n_probe {
384                centroid_scores.select_nth_unstable_by(n_probe - 1, |a, b| {
385                    b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal)
386                });
387            }
388
389            for (c, _) in centroid_scores.iter().take(n_probe) {
390                selected_centroids.insert(*c);
391            }
392        }
393
394        // Apply centroid score threshold: filter out centroids where max score < threshold
395        if let Some(threshold) = params.centroid_score_threshold {
396            selected_centroids.retain(|&c| {
397                let max_score: f32 = (0..num_query_tokens)
398                    .map(|q_idx| query_centroid_scores[[q_idx, c]])
399                    .max_by(|a, b| a.partial_cmp(b).unwrap())
400                    .unwrap_or(f32::NEG_INFINITY);
401                max_score >= threshold
402            });
403        }
404
405        selected_centroids.into_iter().collect()
406    };
407
408    // Get candidate documents from IVF
409    let mut candidates = index.get_candidates(&cells_to_probe);
410
411    // Filter by subset if provided
412    if let Some(subset_docs) = subset {
413        let subset_set: HashSet<i64> = subset_docs.iter().copied().collect();
414        candidates.retain(|&c| subset_set.contains(&c));
415    }
416
417    if candidates.is_empty() {
418        return Ok(QueryResult {
419            query_id: 0,
420            passage_ids: vec![],
421            scores: vec![],
422        });
423    }
424
425    // Compute approximate scores
426    let mut approx_scores: Vec<(i64, f32)> = candidates
427        .par_iter()
428        .map(|&doc_id| {
429            let start = index.doc_offsets[doc_id as usize];
430            let end = index.doc_offsets[doc_id as usize + 1];
431            let codes = index.mmap_codes.slice(start, end);
432            let score = approximate_score_mmap(&query_centroid_scores, &codes);
433            (doc_id, score)
434        })
435        .collect();
436
437    // Sort by approximate score and take top candidates
438    approx_scores.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap());
439    let top_candidates: Vec<i64> = approx_scores
440        .iter()
441        .take(params.n_full_scores)
442        .map(|(id, _)| *id)
443        .collect();
444
445    // Further reduce for full decompression
446    let n_decompress = (params.n_full_scores / 4).max(params.top_k);
447    let to_decompress: Vec<i64> = top_candidates.into_iter().take(n_decompress).collect();
448
449    if to_decompress.is_empty() {
450        return Ok(QueryResult {
451            query_id: 0,
452            passage_ids: vec![],
453            scores: vec![],
454        });
455    }
456
457    // Compute exact scores with decompressed embeddings
458    // Use chunked processing to limit concurrent memory from parallel decompression
459    let mut exact_scores: Vec<(i64, f32)> = to_decompress
460        .par_chunks(DECOMPRESS_CHUNK_SIZE)
461        .flat_map(|chunk| {
462            chunk
463                .iter()
464                .filter_map(|&doc_id| {
465                    let doc_embeddings = index.get_document_embeddings(doc_id as usize).ok()?;
466                    let score = colbert_score(&query.view(), &doc_embeddings.view());
467                    Some((doc_id, score))
468                })
469                .collect::<Vec<_>>()
470        })
471        .collect();
472
473    // Sort by exact score
474    exact_scores.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap());
475
476    // Return top-k results
477    let result_count = params.top_k.min(exact_scores.len());
478    let passage_ids: Vec<i64> = exact_scores
479        .iter()
480        .take(result_count)
481        .map(|(id, _)| *id)
482        .collect();
483    let scores: Vec<f32> = exact_scores
484        .iter()
485        .take(result_count)
486        .map(|(_, s)| *s)
487        .collect();
488
489    Ok(QueryResult {
490        query_id: 0,
491        passage_ids,
492        scores,
493    })
494}
495
496/// Memory-efficient batched search for MmapIndex with large centroid counts.
497///
498/// Uses batched IVF probing and sparse centroid scoring to minimize memory usage.
499fn search_one_mmap_batched(
500    index: &crate::index::MmapIndex,
501    query: &Array2<f32>,
502    params: &SearchParameters,
503    subset: Option<&[i64]>,
504) -> Result<QueryResult> {
505    let num_query_tokens = query.nrows();
506
507    // Step 1: Batched IVF probing
508    let cells_to_probe = ivf_probe_batched(
509        query,
510        &index.codec.centroids,
511        params.n_ivf_probe,
512        params.centroid_batch_size,
513        params.centroid_score_threshold,
514    );
515
516    // Step 2: Get candidate documents from IVF
517    let mut candidates = index.get_candidates(&cells_to_probe);
518
519    // Filter by subset if provided
520    if let Some(subset_docs) = subset {
521        let subset_set: HashSet<i64> = subset_docs.iter().copied().collect();
522        candidates.retain(|&c| subset_set.contains(&c));
523    }
524
525    if candidates.is_empty() {
526        return Ok(QueryResult {
527            query_id: 0,
528            passage_ids: vec![],
529            scores: vec![],
530        });
531    }
532
533    // Step 3: Collect unique centroids from all candidate documents
534    let mut unique_centroids: HashSet<usize> = HashSet::new();
535    for &doc_id in &candidates {
536        let start = index.doc_offsets[doc_id as usize];
537        let end = index.doc_offsets[doc_id as usize + 1];
538        let codes = index.mmap_codes.slice(start, end);
539        for &code in codes.iter() {
540            unique_centroids.insert(code as usize);
541        }
542    }
543
544    // Step 4: Build sparse centroid scores
545    let sparse_scores =
546        build_sparse_centroid_scores(query, &index.codec.centroids, &unique_centroids);
547
548    // Step 5: Compute approximate scores using sparse lookup
549    let mut approx_scores: Vec<(i64, f32)> = candidates
550        .par_iter()
551        .map(|&doc_id| {
552            let start = index.doc_offsets[doc_id as usize];
553            let end = index.doc_offsets[doc_id as usize + 1];
554            let codes = index.mmap_codes.slice(start, end);
555            let doc_codes: Vec<usize> = codes.iter().map(|&c| c as usize).collect();
556            let score = approximate_score_sparse(&sparse_scores, &doc_codes, num_query_tokens);
557            (doc_id, score)
558        })
559        .collect();
560
561    // Sort by approximate score and take top candidates
562    approx_scores.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap());
563    let top_candidates: Vec<i64> = approx_scores
564        .iter()
565        .take(params.n_full_scores)
566        .map(|(id, _)| *id)
567        .collect();
568
569    // Further reduce for full decompression
570    let n_decompress = (params.n_full_scores / 4).max(params.top_k);
571    let to_decompress: Vec<i64> = top_candidates.into_iter().take(n_decompress).collect();
572
573    if to_decompress.is_empty() {
574        return Ok(QueryResult {
575            query_id: 0,
576            passage_ids: vec![],
577            scores: vec![],
578        });
579    }
580
581    // Compute exact scores with decompressed embeddings
582    // Use chunked processing to limit concurrent memory from parallel decompression
583    let mut exact_scores: Vec<(i64, f32)> = to_decompress
584        .par_chunks(DECOMPRESS_CHUNK_SIZE)
585        .flat_map(|chunk| {
586            chunk
587                .iter()
588                .filter_map(|&doc_id| {
589                    let doc_embeddings = index.get_document_embeddings(doc_id as usize).ok()?;
590                    let score = colbert_score(&query.view(), &doc_embeddings.view());
591                    Some((doc_id, score))
592                })
593                .collect::<Vec<_>>()
594        })
595        .collect();
596
597    // Sort by exact score
598    exact_scores.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap());
599
600    // Return top-k results
601    let result_count = params.top_k.min(exact_scores.len());
602    let passage_ids: Vec<i64> = exact_scores
603        .iter()
604        .take(result_count)
605        .map(|(id, _)| *id)
606        .collect();
607    let scores: Vec<f32> = exact_scores
608        .iter()
609        .take(result_count)
610        .map(|(_, s)| *s)
611        .collect();
612
613    Ok(QueryResult {
614        query_id: 0,
615        passage_ids,
616        scores,
617    })
618}
619
620/// Search a memory-mapped index for multiple queries.
621pub fn search_many_mmap(
622    index: &crate::index::MmapIndex,
623    queries: &[Array2<f32>],
624    params: &SearchParameters,
625    parallel: bool,
626    subset: Option<&[i64]>,
627) -> Result<Vec<QueryResult>> {
628    if parallel {
629        let results: Vec<QueryResult> = queries
630            .par_iter()
631            .enumerate()
632            .map(|(i, query)| {
633                let mut result =
634                    search_one_mmap(index, query, params, subset).unwrap_or_else(|_| QueryResult {
635                        query_id: i,
636                        passage_ids: vec![],
637                        scores: vec![],
638                    });
639                result.query_id = i;
640                result
641            })
642            .collect();
643        Ok(results)
644    } else {
645        let mut results = Vec::with_capacity(queries.len());
646        for (i, query) in queries.iter().enumerate() {
647            let mut result = search_one_mmap(index, query, params, subset)?;
648            result.query_id = i;
649            results.push(result);
650        }
651        Ok(results)
652    }
653}
654
/// Alias type for search result (for API compatibility)
/// — external callers may refer to `SearchResult` instead of `QueryResult`.
pub type SearchResult = QueryResult;
657
#[cfg(test)]
mod tests {
    use super::*;

    /// MaxSim: per query token, best dot product over doc tokens, summed.
    #[test]
    fn test_colbert_score() {
        // Query with 2 tokens, dim 4
        let query =
            Array2::from_shape_vec((2, 4), vec![1.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0]).unwrap();

        // Document with 3 tokens
        let doc = Array2::from_shape_vec(
            (3, 4),
            vec![
                0.5, 0.5, 0.0, 0.0, // sim with q0: 0.5, sim with q1: 0.5
                0.8, 0.2, 0.0, 0.0, // sim with q0: 0.8, sim with q1: 0.2
                0.0, 0.9, 0.1, 0.0, // sim with q0: 0.0, sim with q1: 0.9
            ],
        )
        .unwrap();

        let score = colbert_score(&query.view(), &doc.view());
        // q0 max: 0.8 (from token 1), q1 max: 0.9 (from token 2)
        // Total: 0.8 + 0.9 = 1.7
        assert!((score - 1.7).abs() < 1e-5);
    }

    /// Every field of the default parameters, including the serde-defaulted ones.
    #[test]
    fn test_search_params_default() {
        let params = SearchParameters::default();
        assert_eq!(params.batch_size, 2000);
        assert_eq!(params.n_full_scores, 4096);
        assert_eq!(params.top_k, 10);
        assert_eq!(params.n_ivf_probe, 8);
        // Previously untested: must stay in sync with default_centroid_batch_size().
        assert_eq!(params.centroid_batch_size, 100_000);
        assert_eq!(params.centroid_score_threshold, Some(0.4));
    }
}