1use std::cmp::Reverse;
4use std::collections::{BinaryHeap, HashMap, HashSet};
5
6use ndarray::Array1;
7use ndarray::{Array2, ArrayView2, Axis};
8use rayon::prelude::*;
9use serde::{Deserialize, Serialize};
10
11use crate::codec::CentroidStore;
12use crate::error::Result;
13use crate::maxsim;
14
15type ProbePartial = (
17 Vec<BinaryHeap<(Reverse<OrdF32>, usize)>>,
18 HashMap<usize, f32>,
19);
20
21const DECOMPRESS_CHUNK_SIZE: usize = 128;
25
26#[derive(Debug, Clone, Serialize, Deserialize)]
28pub struct SearchParameters {
29 pub batch_size: usize,
31 pub n_full_scores: usize,
33 pub top_k: usize,
35 pub n_ivf_probe: usize,
37 #[serde(default = "default_centroid_batch_size")]
41 pub centroid_batch_size: usize,
42 #[serde(default = "default_centroid_score_threshold")]
47 pub centroid_score_threshold: Option<f32>,
48}
49
/// Serde fallback (and `Default`) value for `SearchParameters::centroid_batch_size`.
fn default_centroid_batch_size() -> usize {
    100 * 1_000
}

/// Serde fallback (and `Default`) value for `SearchParameters::centroid_score_threshold`.
fn default_centroid_score_threshold() -> Option<f32> {
    Option::Some(0.4_f32)
}
57
58impl Default for SearchParameters {
59 fn default() -> Self {
60 Self {
61 batch_size: 2000,
62 n_full_scores: 4096,
63 top_k: 10,
64 n_ivf_probe: 8,
65 centroid_batch_size: default_centroid_batch_size(),
66 centroid_score_threshold: default_centroid_score_threshold(),
67 }
68 }
69}
70
71#[derive(Debug, Clone, Serialize, Deserialize)]
73pub struct QueryResult {
74 pub query_id: usize,
76 pub passage_ids: Vec<i64>,
78 pub scores: Vec<f32>,
80}
81
82#[cfg(feature = "cuda")]
86const CUDA_COLBERT_MIN_SIZE: usize = 128 * 1024;
87
88fn colbert_score(query: &ArrayView2<f32>, doc: &ArrayView2<f32>) -> f32 {
94 #[cfg(feature = "cuda")]
96 {
97 let matrix_size = query.nrows() * doc.nrows();
98 if matrix_size >= CUDA_COLBERT_MIN_SIZE {
99 if let Some(ctx) = crate::cuda::get_global_context() {
100 match crate::cuda::colbert_score_cuda(ctx, query, doc) {
101 Ok(score) => return score,
102 Err(_) => {
103 }
105 }
106 }
107 }
108 }
109
110 colbert_score_cpu(query, doc)
112}
113
114fn colbert_score_cpu(query: &ArrayView2<f32>, doc: &ArrayView2<f32>) -> f32 {
117 maxsim::maxsim_score(query, doc)
118}
119
120#[allow(clippy::too_many_arguments)]
126fn compute_adaptive_ivf_probe_mmap(
127 query_centroid_scores: &Array2<f32>,
128 mmap_codes: &crate::mmap::MmapNpyArray1I64,
129 doc_offsets: &[usize],
130 num_docs: usize,
131 subset: &[i64],
132 top_k: usize,
133 n_ivf_probe: usize,
134 centroid_score_threshold: Option<f32>,
135) -> Vec<usize> {
136 let mut centroid_doc_counts: HashMap<usize, HashSet<i64>> = HashMap::new();
138 for &doc_id in subset {
139 let doc_idx = doc_id as usize;
140 if doc_idx < num_docs {
141 let start = doc_offsets[doc_idx];
142 let end = doc_offsets[doc_idx + 1];
143 let codes = mmap_codes.slice(start, end);
144 for &c in codes.iter() {
145 centroid_doc_counts
146 .entry(c as usize)
147 .or_default()
148 .insert(doc_id);
149 }
150 }
151 }
152
153 if centroid_doc_counts.is_empty() {
154 return vec![];
155 }
156
157 let mut scored_centroids: Vec<(usize, f32, usize)> = centroid_doc_counts
159 .into_iter()
160 .map(|(c, docs)| {
161 let max_score: f32 = query_centroid_scores
162 .axis_iter(Axis(0))
163 .map(|q| q[c])
164 .max_by(|a, b| a.partial_cmp(b).unwrap())
165 .unwrap_or(0.0);
166 (c, max_score, docs.len())
167 })
168 .collect();
169
170 if let Some(threshold) = centroid_score_threshold {
172 scored_centroids.retain(|(_, score, _)| *score >= threshold);
173 }
174
175 scored_centroids.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap());
177
178 let mut cumulative_docs = 0;
180 let mut n_probe = 0;
181
182 for (_, _, doc_count) in &scored_centroids {
183 cumulative_docs += doc_count;
184 n_probe += 1;
185 if cumulative_docs >= top_k && n_probe >= n_ivf_probe {
187 break;
188 }
189 }
190
191 n_probe = n_probe.max(n_ivf_probe.min(scored_centroids.len()));
193
194 scored_centroids
195 .iter()
196 .take(n_probe)
197 .map(|(c, _, _)| *c)
198 .collect()
199}
200
/// `f32` wrapper with a genuine total order, so scores can be stored in a `BinaryHeap`.
///
/// Ordering and equality both follow `f32::total_cmp` (IEEE 754 `totalOrder`).
/// Ordinary values compare numerically, `-0.0 < +0.0`, and NaNs sort at the extremes
/// by sign. NaN no longer compares "equal" to every value, which would break
/// transitivity and with it the heap invariants.
#[derive(Clone, Copy, Debug)]
struct OrdF32(f32);

impl PartialEq for OrdF32 {
    // Defined via `cmp` so `==` and `Ord` can never disagree (required by `Eq`/`Ord`).
    fn eq(&self, other: &Self) -> bool {
        self.cmp(other) == std::cmp::Ordering::Equal
    }
}

impl Eq for OrdF32 {}

impl PartialOrd for OrdF32 {
    fn partial_cmp(&self, other: &Self) -> Option<std::cmp::Ordering> {
        Some(self.cmp(other))
    }
}

impl Ord for OrdF32 {
    fn cmp(&self, other: &Self) -> std::cmp::Ordering {
        self.0.total_cmp(&other.0)
    }
}
220
221fn ivf_probe_batched(
227 query: &Array2<f32>,
228 centroids: &CentroidStore,
229 n_probe: usize,
230 batch_size: usize,
231 centroid_score_threshold: Option<f32>,
232) -> Vec<usize> {
233 let num_centroids = centroids.nrows();
234 let num_tokens = query.nrows();
235
236 let batch_ranges: Vec<(usize, usize)> = (0..num_centroids)
238 .step_by(batch_size)
239 .map(|start| (start, (start + batch_size).min(num_centroids)))
240 .collect();
241
242 let local_results: Vec<ProbePartial> = batch_ranges
249 .par_iter()
250 .map(|&(batch_start, batch_end)| {
251 let mut heaps: Vec<BinaryHeap<(Reverse<OrdF32>, usize)>> = (0..num_tokens)
252 .map(|_| BinaryHeap::with_capacity(n_probe + 1))
253 .collect();
254 let mut max_scores: HashMap<usize, f32> = HashMap::new();
255
256 let batch_centroids = centroids.slice_rows(batch_start, batch_end);
258
259 let batch_scores = query.dot(&batch_centroids.t());
261
262 for (q_idx, heap) in heaps.iter_mut().enumerate() {
264 for (local_c, &score) in batch_scores.row(q_idx).iter().enumerate() {
265 let global_c = batch_start + local_c;
266 let entry = (Reverse(OrdF32(score)), global_c);
267
268 if heap.len() < n_probe {
269 heap.push(entry);
270 max_scores
271 .entry(global_c)
272 .and_modify(|s| *s = s.max(score))
273 .or_insert(score);
274 } else if let Some(&(Reverse(OrdF32(min_score)), _)) = heap.peek() {
275 if score > min_score {
276 heap.pop();
277 heap.push(entry);
278 max_scores
279 .entry(global_c)
280 .and_modify(|s| *s = s.max(score))
281 .or_insert(score);
282 }
283 }
284 }
285 }
286
287 (heaps, max_scores)
288 })
289 .collect();
290
291 let mut final_heaps: Vec<BinaryHeap<(Reverse<OrdF32>, usize)>> = (0..num_tokens)
294 .map(|_| BinaryHeap::with_capacity(n_probe + 1))
295 .collect();
296 let mut final_max_scores: HashMap<usize, f32> = HashMap::new();
297
298 for (local_heaps, local_max_scores) in local_results {
299 for (q_idx, local_heap) in local_heaps.into_iter().enumerate() {
300 for entry in local_heap {
301 let (Reverse(OrdF32(score)), _) = entry;
302 if final_heaps[q_idx].len() < n_probe {
303 final_heaps[q_idx].push(entry);
304 } else if let Some(&(Reverse(OrdF32(min_score)), _)) = final_heaps[q_idx].peek() {
305 if score > min_score {
306 final_heaps[q_idx].pop();
307 final_heaps[q_idx].push(entry);
308 }
309 }
310 }
311 }
312 for (c, score) in local_max_scores {
313 final_max_scores
314 .entry(c)
315 .and_modify(|s| *s = s.max(score))
316 .or_insert(score);
317 }
318 }
319
320 let mut selected: HashSet<usize> = HashSet::new();
322 for heap in final_heaps {
323 for (_, c) in heap {
324 selected.insert(c);
325 }
326 }
327
328 if let Some(threshold) = centroid_score_threshold {
330 selected.retain(|c| {
331 final_max_scores
332 .get(c)
333 .copied()
334 .unwrap_or(f32::NEG_INFINITY)
335 >= threshold
336 });
337 }
338
339 selected.into_iter().collect()
340}
341
342fn build_sparse_centroid_scores(
346 query: &Array2<f32>,
347 centroids: &CentroidStore,
348 centroid_ids: &HashSet<usize>,
349) -> HashMap<usize, Array1<f32>> {
350 centroid_ids
351 .iter()
352 .map(|&c| {
353 let centroid = centroids.row(c);
354 let scores: Array1<f32> = query.dot(¢roid);
355 (c, scores)
356 })
357 .collect()
358}
359
360fn approximate_score_sparse(
362 sparse_scores: &HashMap<usize, Array1<f32>>,
363 doc_codes: &[usize],
364 num_query_tokens: usize,
365) -> f32 {
366 let mut score = 0.0;
367
368 for q_idx in 0..num_query_tokens {
370 let mut max_score = f32::NEG_INFINITY;
371
372 for &code in doc_codes.iter() {
374 if let Some(centroid_scores) = sparse_scores.get(&code) {
375 let centroid_score = centroid_scores[q_idx];
376 if centroid_score > max_score {
377 max_score = centroid_score;
378 }
379 }
380 }
381
382 if max_score > f32::NEG_INFINITY {
383 score += max_score;
384 }
385 }
386
387 score
388}
389
390fn approximate_score_mmap(query_centroid_scores: &Array2<f32>, doc_codes: &[i64]) -> f32 {
392 let mut score = 0.0;
393
394 for q_idx in 0..query_centroid_scores.nrows() {
395 let mut max_score = f32::NEG_INFINITY;
396
397 for &code in doc_codes.iter() {
398 let centroid_score = query_centroid_scores[[q_idx, code as usize]];
399 if centroid_score > max_score {
400 max_score = centroid_score;
401 }
402 }
403
404 if max_score > f32::NEG_INFINITY {
405 score += max_score;
406 }
407 }
408
409 score
410}
411
412pub fn search_one_mmap(
414 index: &crate::index::MmapIndex,
415 query: &Array2<f32>,
416 params: &SearchParameters,
417 subset: Option<&[i64]>,
418) -> Result<QueryResult> {
419 let num_centroids = index.codec.num_centroids();
420 let num_query_tokens = query.nrows();
421
422 let use_batched = params.centroid_batch_size > 0
424 && num_centroids > params.centroid_batch_size
425 && subset.is_none();
426
427 if use_batched {
428 return search_one_mmap_batched(index, query, params);
430 }
431
432 let query_centroid_scores = query.dot(&index.codec.centroids_view().t());
434
435 let cells_to_probe: Vec<usize> = if let Some(subset_docs) = subset {
437 compute_adaptive_ivf_probe_mmap(
439 &query_centroid_scores,
440 &index.mmap_codes,
441 index.doc_offsets.as_slice().unwrap(),
442 index.doc_lengths.len(),
443 subset_docs,
444 params.top_k,
445 params.n_ivf_probe,
446 params.centroid_score_threshold,
447 )
448 } else {
449 let mut selected_centroids = HashSet::new();
451
452 for q_idx in 0..num_query_tokens {
453 let mut centroid_scores: Vec<(usize, f32)> = (0..num_centroids)
454 .map(|c| (c, query_centroid_scores[[q_idx, c]]))
455 .collect();
456
457 if centroid_scores.len() > params.n_ivf_probe {
461 centroid_scores.select_nth_unstable_by(params.n_ivf_probe - 1, |a, b| {
462 b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal)
463 });
464 }
465
466 for (c, _) in centroid_scores.iter().take(params.n_ivf_probe) {
467 selected_centroids.insert(*c);
468 }
469 }
470
471 if let Some(threshold) = params.centroid_score_threshold {
473 selected_centroids.retain(|&c| {
474 let max_score: f32 = (0..num_query_tokens)
475 .map(|q_idx| query_centroid_scores[[q_idx, c]])
476 .max_by(|a, b| a.partial_cmp(b).unwrap())
477 .unwrap_or(f32::NEG_INFINITY);
478 max_score >= threshold
479 });
480 }
481
482 selected_centroids.into_iter().collect()
483 };
484
485 let mut candidates = index.get_candidates(&cells_to_probe);
487
488 if let Some(subset_docs) = subset {
490 let subset_set: HashSet<i64> = subset_docs.iter().copied().collect();
491 candidates.retain(|&c| subset_set.contains(&c));
492 }
493
494 if candidates.is_empty() {
495 return Ok(QueryResult {
496 query_id: 0,
497 passage_ids: vec![],
498 scores: vec![],
499 });
500 }
501
502 let mut approx_scores: Vec<(i64, f32)> = candidates
504 .par_iter()
505 .map(|&doc_id| {
506 let start = index.doc_offsets[doc_id as usize];
507 let end = index.doc_offsets[doc_id as usize + 1];
508 let codes = index.mmap_codes.slice(start, end);
509 let score = approximate_score_mmap(&query_centroid_scores, &codes);
510 (doc_id, score)
511 })
512 .collect();
513
514 approx_scores.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap());
516 let top_candidates: Vec<i64> = approx_scores
517 .iter()
518 .take(params.n_full_scores)
519 .map(|(id, _)| *id)
520 .collect();
521
522 let n_decompress = (params.n_full_scores / 4).max(params.top_k);
524 let to_decompress: Vec<i64> = top_candidates.into_iter().take(n_decompress).collect();
525
526 if to_decompress.is_empty() {
527 return Ok(QueryResult {
528 query_id: 0,
529 passage_ids: vec![],
530 scores: vec![],
531 });
532 }
533
534 let mut exact_scores: Vec<(i64, f32)> = to_decompress
537 .par_chunks(DECOMPRESS_CHUNK_SIZE)
538 .flat_map(|chunk| {
539 chunk
540 .iter()
541 .filter_map(|&doc_id| {
542 let doc_embeddings = index.get_document_embeddings(doc_id as usize).ok()?;
543 let score = colbert_score(&query.view(), &doc_embeddings.view());
544 Some((doc_id, score))
545 })
546 .collect::<Vec<_>>()
547 })
548 .collect();
549
550 exact_scores.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap());
552
553 let result_count = params.top_k.min(exact_scores.len());
555 let passage_ids: Vec<i64> = exact_scores
556 .iter()
557 .take(result_count)
558 .map(|(id, _)| *id)
559 .collect();
560 let scores: Vec<f32> = exact_scores
561 .iter()
562 .take(result_count)
563 .map(|(_, s)| *s)
564 .collect();
565
566 Ok(QueryResult {
567 query_id: 0,
568 passage_ids,
569 scores,
570 })
571}
572
573fn search_one_mmap_batched(
577 index: &crate::index::MmapIndex,
578 query: &Array2<f32>,
579 params: &SearchParameters,
580) -> Result<QueryResult> {
581 let num_query_tokens = query.nrows();
582
583 let cells_to_probe = ivf_probe_batched(
585 query,
586 &index.codec.centroids,
587 params.n_ivf_probe,
588 params.centroid_batch_size,
589 params.centroid_score_threshold,
590 );
591
592 let candidates = index.get_candidates(&cells_to_probe);
594
595 if candidates.is_empty() {
596 return Ok(QueryResult {
597 query_id: 0,
598 passage_ids: vec![],
599 scores: vec![],
600 });
601 }
602
603 let mut unique_centroids: HashSet<usize> = HashSet::new();
605 for &doc_id in &candidates {
606 let start = index.doc_offsets[doc_id as usize];
607 let end = index.doc_offsets[doc_id as usize + 1];
608 let codes = index.mmap_codes.slice(start, end);
609 for &code in codes.iter() {
610 unique_centroids.insert(code as usize);
611 }
612 }
613
614 let sparse_scores =
616 build_sparse_centroid_scores(query, &index.codec.centroids, &unique_centroids);
617
618 let mut approx_scores: Vec<(i64, f32)> = candidates
620 .par_iter()
621 .map(|&doc_id| {
622 let start = index.doc_offsets[doc_id as usize];
623 let end = index.doc_offsets[doc_id as usize + 1];
624 let codes = index.mmap_codes.slice(start, end);
625 let doc_codes: Vec<usize> = codes.iter().map(|&c| c as usize).collect();
626 let score = approximate_score_sparse(&sparse_scores, &doc_codes, num_query_tokens);
627 (doc_id, score)
628 })
629 .collect();
630
631 approx_scores.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap());
633 let top_candidates: Vec<i64> = approx_scores
634 .iter()
635 .take(params.n_full_scores)
636 .map(|(id, _)| *id)
637 .collect();
638
639 let n_decompress = (params.n_full_scores / 4).max(params.top_k);
641 let to_decompress: Vec<i64> = top_candidates.into_iter().take(n_decompress).collect();
642
643 if to_decompress.is_empty() {
644 return Ok(QueryResult {
645 query_id: 0,
646 passage_ids: vec![],
647 scores: vec![],
648 });
649 }
650
651 let mut exact_scores: Vec<(i64, f32)> = to_decompress
654 .par_chunks(DECOMPRESS_CHUNK_SIZE)
655 .flat_map(|chunk| {
656 chunk
657 .iter()
658 .filter_map(|&doc_id| {
659 let doc_embeddings = index.get_document_embeddings(doc_id as usize).ok()?;
660 let score = colbert_score(&query.view(), &doc_embeddings.view());
661 Some((doc_id, score))
662 })
663 .collect::<Vec<_>>()
664 })
665 .collect();
666
667 exact_scores.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap());
669
670 let result_count = params.top_k.min(exact_scores.len());
672 let passage_ids: Vec<i64> = exact_scores
673 .iter()
674 .take(result_count)
675 .map(|(id, _)| *id)
676 .collect();
677 let scores: Vec<f32> = exact_scores
678 .iter()
679 .take(result_count)
680 .map(|(_, s)| *s)
681 .collect();
682
683 Ok(QueryResult {
684 query_id: 0,
685 passage_ids,
686 scores,
687 })
688}
689
690pub fn search_many_mmap(
692 index: &crate::index::MmapIndex,
693 queries: &[Array2<f32>],
694 params: &SearchParameters,
695 parallel: bool,
696 subset: Option<&[i64]>,
697) -> Result<Vec<QueryResult>> {
698 if parallel {
699 let results: Vec<QueryResult> = queries
700 .par_iter()
701 .enumerate()
702 .map(|(i, query)| {
703 let mut result =
704 search_one_mmap(index, query, params, subset).unwrap_or_else(|_| QueryResult {
705 query_id: i,
706 passage_ids: vec![],
707 scores: vec![],
708 });
709 result.query_id = i;
710 result
711 })
712 .collect();
713 Ok(results)
714 } else {
715 let mut results = Vec::with_capacity(queries.len());
716 for (i, query) in queries.iter().enumerate() {
717 let mut result = search_one_mmap(index, query, params, subset)?;
718 result.query_id = i;
719 results.push(result);
720 }
721 Ok(results)
722 }
723}
724
725pub type SearchResult = QueryResult;
727
#[cfg(test)]
mod tests {
    use super::*;

    /// Each query token keeps its best dot product with the document tokens:
    /// token 0 -> 0.8 (doc row 1), token 1 -> 0.9 (doc row 2), so the total is 1.7.
    #[test]
    fn test_colbert_score() {
        let query = Array2::from_shape_vec(
            (2, 4),
            vec![
                1.0, 0.0, 0.0, 0.0,
                0.0, 1.0, 0.0, 0.0,
            ],
        )
        .unwrap();
        let doc = Array2::from_shape_vec(
            (3, 4),
            vec![
                0.5, 0.5, 0.0, 0.0,
                0.8, 0.2, 0.0, 0.0,
                0.0, 0.9, 0.1, 0.0,
            ],
        )
        .unwrap();

        let score = colbert_score(&query.view(), &doc.view());
        assert!((score - 1.7).abs() < 1e-5, "unexpected score {}", score);
    }

    #[test]
    fn test_search_params_default() {
        let SearchParameters {
            batch_size,
            n_full_scores,
            top_k,
            n_ivf_probe,
            centroid_score_threshold,
            ..
        } = SearchParameters::default();

        assert_eq!(batch_size, 2000);
        assert_eq!(n_full_scores, 4096);
        assert_eq!(top_k, 10);
        assert_eq!(n_ivf_probe, 8);
        assert_eq!(centroid_score_threshold, Some(0.4));
    }
}