// next_plaid/search.rs
1//! Search functionality for PLAID
2
3use std::cmp::Reverse;
4use std::collections::{BinaryHeap, HashMap, HashSet};
5
6use ndarray::Array1;
7use ndarray::{Array2, ArrayView2, Axis};
8use rayon::prelude::*;
9use serde::{Deserialize, Serialize};
10
11use crate::codec::CentroidStore;
12use crate::error::Result;
13
/// Maximum number of documents to decompress concurrently during exact scoring.
/// This limits peak memory usage from parallel decompression.
/// With 128 docs × ~300KB per doc = ~40MB max concurrent decompression memory.
/// (NOTE(review): the ~300KB/doc figure is an assumption about typical token
/// counts and embedding dims — actual peak memory scales with both.)
const DECOMPRESS_CHUNK_SIZE: usize = 128;
18
/// Search parameters
///
/// Controls candidate generation (IVF probing), approximate-ranking depth,
/// and final re-ranking. The two `#[serde(default = ...)]` fields allow older
/// serialized configurations (without those fields) to still deserialize.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SearchParameters {
    /// Number of queries per batch
    pub batch_size: usize,
    /// Number of documents to re-rank with exact scores
    pub n_full_scores: usize,
    /// Number of final results to return per query
    pub top_k: usize,
    /// Number of IVF cells to probe during search
    pub n_ivf_probe: usize,
    /// Batch size for centroid scoring during IVF probing (0 = exhaustive).
    /// Lower values use less memory but are slower. Default 100_000.
    /// Only used when num_centroids > centroid_batch_size.
    #[serde(default = "default_centroid_batch_size")]
    pub centroid_batch_size: usize,
    /// Centroid score threshold (t_cs) for centroid pruning.
    /// A centroid is only included if its maximum score across all query tokens
    /// meets or exceeds this threshold. Set to None to disable pruning.
    /// Default: Some(0.4)
    #[serde(default = "default_centroid_score_threshold")]
    pub centroid_score_threshold: Option<f32>,
}
42
/// Serde default for [`SearchParameters::centroid_batch_size`]:
/// 100k centroids scored per batch.
fn default_centroid_batch_size() -> usize {
    100_000
}
46
/// Serde default for [`SearchParameters::centroid_score_threshold`]:
/// pruning enabled at t_cs = 0.4.
fn default_centroid_score_threshold() -> Option<f32> {
    Some(0.4)
}
50
51impl Default for SearchParameters {
52    fn default() -> Self {
53        Self {
54            batch_size: 2000,
55            n_full_scores: 4096,
56            top_k: 10,
57            n_ivf_probe: 8,
58            centroid_batch_size: default_centroid_batch_size(),
59            centroid_score_threshold: default_centroid_score_threshold(),
60        }
61    }
62}
63
/// Result of a single query
///
/// `passage_ids` and `scores` are parallel vectors of equal length, ordered
/// by descending relevance score.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct QueryResult {
    /// Query ID (position of the query within the submitted batch;
    /// single-query search functions leave this at 0 for the caller to set)
    pub query_id: usize,
    /// Retrieved document IDs (ranked by relevance)
    pub passage_ids: Vec<i64>,
    /// Relevance scores for each document
    pub scores: Vec<f32>,
}
74
75/// ColBERT-style MaxSim scoring: for each query token, find the max similarity
76/// with any document token, then sum across query tokens.
77fn colbert_score(query: &ArrayView2<f32>, doc: &ArrayView2<f32>) -> f32 {
78    let mut total_score = 0.0;
79
80    // For each query token
81    for q_row in query.axis_iter(Axis(0)) {
82        let mut max_sim = f32::NEG_INFINITY;
83
84        // Find max similarity with any document token
85        for d_row in doc.axis_iter(Axis(0)) {
86            let sim: f32 = q_row.dot(&d_row);
87            if sim > max_sim {
88                max_sim = sim;
89            }
90        }
91
92        if max_sim > f32::NEG_INFINITY {
93            total_score += max_sim;
94        }
95    }
96
97    total_score
98}
99
/// Compute adaptive IVF probe for filtered search on memory-mapped index.
///
/// Ensures enough centroids are probed to cover at least `top_k` candidates from the subset.
/// This function counts how many subset documents are in each centroid, then greedily
/// selects centroids (by query similarity) until the cumulative document count reaches `top_k`.
///
/// Returns selected centroid IDs, best-scoring first. Subset IDs that are out
/// of range (`>= num_docs`) are silently skipped. If `centroid_score_threshold`
/// is set, centroids below the threshold are dropped *before* the greedy
/// selection, so they can neither be probed nor count toward coverage.
#[allow(clippy::too_many_arguments)]
fn compute_adaptive_ivf_probe_mmap(
    query_centroid_scores: &Array2<f32>,
    mmap_codes: &crate::mmap::MmapNpyArray1I64,
    doc_offsets: &[usize],
    num_docs: usize,
    subset: &[i64],
    top_k: usize,
    n_ivf_probe: usize,
    centroid_score_threshold: Option<f32>,
) -> Vec<usize> {
    // Count unique docs per centroid for subset.
    // A doc may contain the same code several times; the HashSet deduplicates
    // so each doc counts once per centroid.
    let mut centroid_doc_counts: HashMap<usize, HashSet<i64>> = HashMap::new();
    for &doc_id in subset {
        let doc_idx = doc_id as usize;
        if doc_idx < num_docs {
            let start = doc_offsets[doc_idx];
            let end = doc_offsets[doc_idx + 1];
            let codes = mmap_codes.slice(start, end);
            for &c in codes.iter() {
                centroid_doc_counts
                    .entry(c as usize)
                    .or_default()
                    .insert(doc_id);
            }
        }
    }

    // Empty or fully out-of-range subset: nothing to probe.
    if centroid_doc_counts.is_empty() {
        return vec![];
    }

    // Score each centroid by max query-centroid similarity
    // (centroid_id, max_score, number of subset docs containing it)
    let mut scored_centroids: Vec<(usize, f32, usize)> = centroid_doc_counts
        .into_iter()
        .map(|(c, docs)| {
            let max_score: f32 = query_centroid_scores
                .axis_iter(Axis(0))
                .map(|q| q[c])
                .max_by(|a, b| a.partial_cmp(b).unwrap())
                .unwrap_or(0.0);
            (c, max_score, docs.len())
        })
        .collect();

    // Apply threshold if set
    if let Some(threshold) = centroid_score_threshold {
        scored_centroids.retain(|(_, score, _)| *score >= threshold);
    }

    // Sort by score descending
    scored_centroids.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap());

    // Greedily select centroids until we cover top_k candidates
    // (cumulative_docs over-counts docs spanning several centroids; it is an
    // upper bound on distinct coverage, which keeps the probe count small)
    let mut cumulative_docs = 0;
    let mut n_probe = 0;

    for (_, _, doc_count) in &scored_centroids {
        cumulative_docs += doc_count;
        n_probe += 1;
        // Stop when we have enough coverage AND met minimum probe requirement
        if cumulative_docs >= top_k && n_probe >= n_ivf_probe {
            break;
        }
    }

    // Ensure at least n_ivf_probe centroids (unless fewer exist)
    n_probe = n_probe.max(n_ivf_probe.min(scored_centroids.len()));

    scored_centroids
        .iter()
        .take(n_probe)
        .map(|(c, _, _)| *c)
        .collect()
}
180
/// Wrapper for f32 to use with BinaryHeap (implements Ord).
///
/// Comparison uses [`f32::total_cmp`], which defines a total order over all
/// float values (including NaN and signed zero), so the `Eq`/`Ord` contracts
/// genuinely hold. The previous `partial_cmp(..).unwrap_or(Equal)` made NaN
/// compare equal to *every* value, breaking transitivity and risking
/// inconsistent `BinaryHeap` ordering; the derived `PartialEq` also disagreed
/// with `Eq` for NaN. `PartialEq` is now derived from the same total order.
#[derive(Clone, Copy)]
struct OrdF32(f32);

impl PartialEq for OrdF32 {
    fn eq(&self, other: &Self) -> bool {
        // Keep equality consistent with Ord (total_cmp), including NaN.
        self.cmp(other) == std::cmp::Ordering::Equal
    }
}

impl Eq for OrdF32 {}

impl PartialOrd for OrdF32 {
    fn partial_cmp(&self, other: &Self) -> Option<std::cmp::Ordering> {
        Some(self.cmp(other))
    }
}

impl Ord for OrdF32 {
    fn cmp(&self, other: &Self) -> std::cmp::Ordering {
        // IEEE 754 totalOrder: -NaN < -inf < ... < -0.0 < +0.0 < ... < +inf < +NaN
        self.0.total_cmp(&other.0)
    }
}
200
201/// Batched IVF probing for memory-efficient centroid scoring.
202///
203/// Processes centroids in chunks, keeping only top-k scores per query token in a heap.
204/// Returns the union of top centroids across all query tokens.
205/// If a threshold is provided, filters out centroids where max score < threshold.
206fn ivf_probe_batched(
207    query: &Array2<f32>,
208    centroids: &CentroidStore,
209    n_probe: usize,
210    batch_size: usize,
211    centroid_score_threshold: Option<f32>,
212) -> Vec<usize> {
213    let num_centroids = centroids.nrows();
214    let num_tokens = query.nrows();
215
216    // Min-heap per query token to track top centroids
217    // Entry: (Reverse(score), centroid_id) - Reverse for min-heap behavior
218    let mut heaps: Vec<BinaryHeap<(Reverse<OrdF32>, usize)>> = (0..num_tokens)
219        .map(|_| BinaryHeap::with_capacity(n_probe + 1))
220        .collect();
221
222    // Track max score per centroid for threshold filtering
223    let mut max_scores: HashMap<usize, f32> = HashMap::new();
224
225    for batch_start in (0..num_centroids).step_by(batch_size) {
226        let batch_end = (batch_start + batch_size).min(num_centroids);
227
228        // Get batch view (zero-copy from mmap)
229        let batch_centroids = centroids.slice_rows(batch_start, batch_end);
230
231        // Compute scores: [num_tokens, batch_size]
232        let batch_scores = query.dot(&batch_centroids.t());
233
234        // Update heaps with this batch's scores
235        for (q_idx, heap) in heaps.iter_mut().enumerate() {
236            for (local_c, &score) in batch_scores.row(q_idx).iter().enumerate() {
237                let global_c = batch_start + local_c;
238                let entry = (Reverse(OrdF32(score)), global_c);
239
240                if heap.len() < n_probe {
241                    heap.push(entry);
242                    // Track max score for threshold filtering
243                    max_scores
244                        .entry(global_c)
245                        .and_modify(|s| *s = s.max(score))
246                        .or_insert(score);
247                } else if let Some(&(Reverse(OrdF32(min_score)), _)) = heap.peek() {
248                    if score > min_score {
249                        heap.pop();
250                        heap.push(entry);
251                        // Track max score for threshold filtering
252                        max_scores
253                            .entry(global_c)
254                            .and_modify(|s| *s = s.max(score))
255                            .or_insert(score);
256                    }
257                }
258            }
259        }
260    }
261
262    // Union top centroids across all query tokens
263    let mut selected: HashSet<usize> = HashSet::new();
264    for heap in heaps {
265        for (_, c) in heap {
266            selected.insert(c);
267        }
268    }
269
270    // Apply centroid score threshold if set
271    if let Some(threshold) = centroid_score_threshold {
272        selected.retain(|c| max_scores.get(c).copied().unwrap_or(f32::NEG_INFINITY) >= threshold);
273    }
274
275    selected.into_iter().collect()
276}
277
278/// Build sparse centroid scores for a set of centroid IDs.
279///
280/// Returns a HashMap mapping centroid_id -> query scores array.
281fn build_sparse_centroid_scores(
282    query: &Array2<f32>,
283    centroids: &CentroidStore,
284    centroid_ids: &HashSet<usize>,
285) -> HashMap<usize, Array1<f32>> {
286    centroid_ids
287        .iter()
288        .map(|&c| {
289            let centroid = centroids.row(c);
290            let scores: Array1<f32> = query.dot(&centroid);
291            (c, scores)
292        })
293        .collect()
294}
295
296/// Compute approximate scores using sparse centroid score lookup.
297fn approximate_score_sparse(
298    sparse_scores: &HashMap<usize, Array1<f32>>,
299    doc_codes: &[usize],
300    num_query_tokens: usize,
301) -> f32 {
302    let mut score = 0.0;
303
304    // For each query token
305    for q_idx in 0..num_query_tokens {
306        let mut max_score = f32::NEG_INFINITY;
307
308        // For each document token's code
309        for &code in doc_codes.iter() {
310            if let Some(centroid_scores) = sparse_scores.get(&code) {
311                let centroid_score = centroid_scores[q_idx];
312                if centroid_score > max_score {
313                    max_score = centroid_score;
314                }
315            }
316        }
317
318        if max_score > f32::NEG_INFINITY {
319            score += max_score;
320        }
321    }
322
323    score
324}
325
326/// Compute approximate scores for mmap index using code lookups.
327fn approximate_score_mmap(query_centroid_scores: &Array2<f32>, doc_codes: &[i64]) -> f32 {
328    let mut score = 0.0;
329
330    for q_idx in 0..query_centroid_scores.nrows() {
331        let mut max_score = f32::NEG_INFINITY;
332
333        for &code in doc_codes.iter() {
334            let centroid_score = query_centroid_scores[[q_idx, code as usize]];
335            if centroid_score > max_score {
336                max_score = centroid_score;
337            }
338        }
339
340        if max_score > f32::NEG_INFINITY {
341            score += max_score;
342        }
343    }
344
345    score
346}
347
/// Search a memory-mapped index for a single query.
///
/// Pipeline:
/// 1. Score query tokens against centroids.
/// 2. Probe the top IVF cells to gather candidate documents.
/// 3. Rank candidates by a cheap centroid-based approximate score.
/// 4. Decompress only the best candidates and re-rank with exact ColBERT
///    MaxSim scores; return the top `params.top_k`.
///
/// When `subset` is provided, only those document IDs are eligible and an
/// adaptive probe count ensures the subset can still yield enough candidates.
/// The returned `query_id` is always 0; batch callers overwrite it.
pub fn search_one_mmap(
    index: &crate::index::MmapIndex,
    query: &Array2<f32>,
    params: &SearchParameters,
    subset: Option<&[i64]>,
) -> Result<QueryResult> {
    let num_centroids = index.codec.num_centroids();
    let num_query_tokens = query.nrows();

    // Decide whether to use batched mode for memory efficiency.
    // Batched mode does not support subset filtering, so it is skipped when a
    // subset is given.
    let use_batched = params.centroid_batch_size > 0
        && num_centroids > params.centroid_batch_size
        && subset.is_none();

    if use_batched {
        // Batched path: memory-efficient IVF probing for large centroid counts
        return search_one_mmap_batched(index, query, params);
    }

    // Standard path: compute full query-centroid scores upfront
    // (dense [num_query_tokens, num_centroids] matrix)
    let query_centroid_scores = query.dot(&index.codec.centroids_view().t());

    // Find top IVF cells to probe
    let cells_to_probe: Vec<usize> = if let Some(subset_docs) = subset {
        // Use adaptive IVF probing that ensures enough centroids to cover top_k candidates
        compute_adaptive_ivf_probe_mmap(
            &query_centroid_scores,
            &index.mmap_codes,
            index.doc_offsets.as_slice().unwrap(),
            index.doc_lengths.len(),
            subset_docs,
            params.top_k,
            params.n_ivf_probe,
            params.centroid_score_threshold,
        )
    } else {
        // Standard path: select top-k centroids per query token
        let mut selected_centroids = HashSet::new();

        for q_idx in 0..num_query_tokens {
            let mut centroid_scores: Vec<(usize, f32)> = (0..num_centroids)
                .map(|c| (c, query_centroid_scores[[q_idx, c]]))
                .collect();

            // Partial selection: O(K) average instead of O(K log K) for full sort
            // After this, the top n_ivf_probe elements are in positions 0..n_ivf_probe
            // (but not sorted among themselves - which is fine since we use a HashSet)
            if centroid_scores.len() > params.n_ivf_probe {
                centroid_scores.select_nth_unstable_by(params.n_ivf_probe - 1, |a, b| {
                    b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal)
                });
            }

            for (c, _) in centroid_scores.iter().take(params.n_ivf_probe) {
                selected_centroids.insert(*c);
            }
        }

        // Apply centroid score threshold: filter out centroids where max score < threshold
        if let Some(threshold) = params.centroid_score_threshold {
            selected_centroids.retain(|&c| {
                let max_score: f32 = (0..num_query_tokens)
                    .map(|q_idx| query_centroid_scores[[q_idx, c]])
                    .max_by(|a, b| a.partial_cmp(b).unwrap())
                    .unwrap_or(f32::NEG_INFINITY);
                max_score >= threshold
            });
        }

        selected_centroids.into_iter().collect()
    };

    // Get candidate documents from IVF
    let mut candidates = index.get_candidates(&cells_to_probe);

    // Filter by subset if provided
    if let Some(subset_docs) = subset {
        let subset_set: HashSet<i64> = subset_docs.iter().copied().collect();
        candidates.retain(|&c| subset_set.contains(&c));
    }

    // No candidates survived probing/filtering: return an empty result.
    if candidates.is_empty() {
        return Ok(QueryResult {
            query_id: 0,
            passage_ids: vec![],
            scores: vec![],
        });
    }

    // Compute approximate scores (parallel over candidates; each doc's codes
    // are read directly from the mmap without decompression)
    let mut approx_scores: Vec<(i64, f32)> = candidates
        .par_iter()
        .map(|&doc_id| {
            let start = index.doc_offsets[doc_id as usize];
            let end = index.doc_offsets[doc_id as usize + 1];
            let codes = index.mmap_codes.slice(start, end);
            let score = approximate_score_mmap(&query_centroid_scores, &codes);
            (doc_id, score)
        })
        .collect();

    // Sort by approximate score and take top candidates
    approx_scores.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap());
    let top_candidates: Vec<i64> = approx_scores
        .iter()
        .take(params.n_full_scores)
        .map(|(id, _)| *id)
        .collect();

    // Further reduce for full decompression: only the top quarter of the
    // approximately-ranked pool (but never fewer than top_k) is decompressed.
    let n_decompress = (params.n_full_scores / 4).max(params.top_k);
    let to_decompress: Vec<i64> = top_candidates.into_iter().take(n_decompress).collect();

    if to_decompress.is_empty() {
        return Ok(QueryResult {
            query_id: 0,
            passage_ids: vec![],
            scores: vec![],
        });
    }

    // Compute exact scores with decompressed embeddings
    // Use chunked processing to limit concurrent memory from parallel decompression
    // (docs whose decompression fails are silently dropped via ok()?)
    let mut exact_scores: Vec<(i64, f32)> = to_decompress
        .par_chunks(DECOMPRESS_CHUNK_SIZE)
        .flat_map(|chunk| {
            chunk
                .iter()
                .filter_map(|&doc_id| {
                    let doc_embeddings = index.get_document_embeddings(doc_id as usize).ok()?;
                    let score = colbert_score(&query.view(), &doc_embeddings.view());
                    Some((doc_id, score))
                })
                .collect::<Vec<_>>()
        })
        .collect();

    // Sort by exact score
    exact_scores.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap());

    // Return top-k results
    let result_count = params.top_k.min(exact_scores.len());
    let passage_ids: Vec<i64> = exact_scores
        .iter()
        .take(result_count)
        .map(|(id, _)| *id)
        .collect();
    let scores: Vec<f32> = exact_scores
        .iter()
        .take(result_count)
        .map(|(_, s)| *s)
        .collect();

    Ok(QueryResult {
        query_id: 0,
        passage_ids,
        scores,
    })
}
508
/// Memory-efficient batched search for MmapIndex with large centroid counts.
///
/// Uses batched IVF probing and sparse centroid scoring to minimize memory
/// usage: instead of materializing the full [num_tokens, num_centroids] score
/// matrix, it scores centroids chunk by chunk and afterwards only computes
/// scores for centroids actually referenced by candidate documents.
///
/// Invoked by `search_one_mmap` when `centroid_batch_size > 0`,
/// `num_centroids > centroid_batch_size`, and no subset filter is given.
/// The returned `query_id` is always 0; batch callers overwrite it.
fn search_one_mmap_batched(
    index: &crate::index::MmapIndex,
    query: &Array2<f32>,
    params: &SearchParameters,
) -> Result<QueryResult> {
    let num_query_tokens = query.nrows();

    // Step 1: Batched IVF probing
    let cells_to_probe = ivf_probe_batched(
        query,
        &index.codec.centroids,
        params.n_ivf_probe,
        params.centroid_batch_size,
        params.centroid_score_threshold,
    );

    // Step 2: Get candidate documents from IVF
    let candidates = index.get_candidates(&cells_to_probe);

    if candidates.is_empty() {
        return Ok(QueryResult {
            query_id: 0,
            passage_ids: vec![],
            scores: vec![],
        });
    }

    // Step 3: Collect unique centroids from all candidate documents
    // (these are the only centroids whose scores the approximate pass needs)
    let mut unique_centroids: HashSet<usize> = HashSet::new();
    for &doc_id in &candidates {
        let start = index.doc_offsets[doc_id as usize];
        let end = index.doc_offsets[doc_id as usize + 1];
        let codes = index.mmap_codes.slice(start, end);
        for &code in codes.iter() {
            unique_centroids.insert(code as usize);
        }
    }

    // Step 4: Build sparse centroid scores
    let sparse_scores =
        build_sparse_centroid_scores(query, &index.codec.centroids, &unique_centroids);

    // Step 5: Compute approximate scores using sparse lookup (parallel)
    let mut approx_scores: Vec<(i64, f32)> = candidates
        .par_iter()
        .map(|&doc_id| {
            let start = index.doc_offsets[doc_id as usize];
            let end = index.doc_offsets[doc_id as usize + 1];
            let codes = index.mmap_codes.slice(start, end);
            let doc_codes: Vec<usize> = codes.iter().map(|&c| c as usize).collect();
            let score = approximate_score_sparse(&sparse_scores, &doc_codes, num_query_tokens);
            (doc_id, score)
        })
        .collect();

    // Sort by approximate score and take top candidates
    approx_scores.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap());
    let top_candidates: Vec<i64> = approx_scores
        .iter()
        .take(params.n_full_scores)
        .map(|(id, _)| *id)
        .collect();

    // Further reduce for full decompression: only the top quarter of the
    // approximately-ranked pool (but never fewer than top_k) is decompressed.
    let n_decompress = (params.n_full_scores / 4).max(params.top_k);
    let to_decompress: Vec<i64> = top_candidates.into_iter().take(n_decompress).collect();

    if to_decompress.is_empty() {
        return Ok(QueryResult {
            query_id: 0,
            passage_ids: vec![],
            scores: vec![],
        });
    }

    // Compute exact scores with decompressed embeddings
    // Use chunked processing to limit concurrent memory from parallel decompression
    // (docs whose decompression fails are silently dropped via ok()?)
    let mut exact_scores: Vec<(i64, f32)> = to_decompress
        .par_chunks(DECOMPRESS_CHUNK_SIZE)
        .flat_map(|chunk| {
            chunk
                .iter()
                .filter_map(|&doc_id| {
                    let doc_embeddings = index.get_document_embeddings(doc_id as usize).ok()?;
                    let score = colbert_score(&query.view(), &doc_embeddings.view());
                    Some((doc_id, score))
                })
                .collect::<Vec<_>>()
        })
        .collect();

    // Sort by exact score
    exact_scores.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap());

    // Return top-k results
    let result_count = params.top_k.min(exact_scores.len());
    let passage_ids: Vec<i64> = exact_scores
        .iter()
        .take(result_count)
        .map(|(id, _)| *id)
        .collect();
    let scores: Vec<f32> = exact_scores
        .iter()
        .take(result_count)
        .map(|(_, s)| *s)
        .collect();

    Ok(QueryResult {
        query_id: 0,
        passage_ids,
        scores,
    })
}
625
626/// Search a memory-mapped index for multiple queries.
627pub fn search_many_mmap(
628    index: &crate::index::MmapIndex,
629    queries: &[Array2<f32>],
630    params: &SearchParameters,
631    parallel: bool,
632    subset: Option<&[i64]>,
633) -> Result<Vec<QueryResult>> {
634    if parallel {
635        let results: Vec<QueryResult> = queries
636            .par_iter()
637            .enumerate()
638            .map(|(i, query)| {
639                let mut result =
640                    search_one_mmap(index, query, params, subset).unwrap_or_else(|_| QueryResult {
641                        query_id: i,
642                        passage_ids: vec![],
643                        scores: vec![],
644                    });
645                result.query_id = i;
646                result
647            })
648            .collect();
649        Ok(results)
650    } else {
651        let mut results = Vec::with_capacity(queries.len());
652        for (i, query) in queries.iter().enumerate() {
653            let mut result = search_one_mmap(index, query, params, subset)?;
654            result.query_id = i;
655            results.push(result);
656        }
657        Ok(results)
658    }
659}
660
/// Alias for [`QueryResult`], kept for API compatibility.
pub type SearchResult = QueryResult;
663
#[cfg(test)]
mod tests {
    use super::*;

    // Verifies MaxSim on a hand-computed 2-token query vs 3-token document.
    #[test]
    fn test_colbert_score() {
        // Query with 2 tokens, dim 4
        let query =
            Array2::from_shape_vec((2, 4), vec![1.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0]).unwrap();

        // Document with 3 tokens
        let doc = Array2::from_shape_vec(
            (3, 4),
            vec![
                0.5, 0.5, 0.0, 0.0, // sim with q0: 0.5, sim with q1: 0.5
                0.8, 0.2, 0.0, 0.0, // sim with q0: 0.8, sim with q1: 0.2
                0.0, 0.9, 0.1, 0.0, // sim with q0: 0.0, sim with q1: 0.9
            ],
        )
        .unwrap();

        let score = colbert_score(&query.view(), &doc.view());
        // q0 max: 0.8 (from token 1), q1 max: 0.9 (from token 2)
        // Total: 0.8 + 0.9 = 1.7
        assert!((score - 1.7).abs() < 1e-5);
    }

    // Confirms Default::default() matches the documented default values,
    // including the two serde-default helper functions.
    #[test]
    fn test_search_params_default() {
        let params = SearchParameters::default();
        assert_eq!(params.batch_size, 2000);
        assert_eq!(params.n_full_scores, 4096);
        assert_eq!(params.top_k, 10);
        assert_eq!(params.n_ivf_probe, 8);
        assert_eq!(params.centroid_score_threshold, Some(0.4));
    }
}