1use std::cmp::Reverse;
4use std::collections::{BinaryHeap, HashMap, HashSet};
5
6use ndarray::Array1;
7use ndarray::{Array2, ArrayView2};
8use rayon::prelude::*;
9use serde::{Deserialize, Serialize};
10
11use crate::codec::CentroidStore;
12use crate::error::Result;
13use crate::maxsim;
14
15type ProbePartial = (
17 Vec<BinaryHeap<(Reverse<OrdF32>, usize)>>,
18 HashMap<usize, f32>,
19);
20
21const DECOMPRESS_CHUNK_SIZE: usize = 128;
25
/// Tuning knobs for a single search.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SearchParameters {
    /// Batch size setting.
    /// NOTE(review): not read by the search functions in this file — confirm
    /// where (or whether) it is used.
    pub batch_size: usize,
    /// Number of best approximate-score candidates kept for re-ranking; the
    /// first `max(n_full_scores / 4, top_k)` of them are scored exactly.
    pub n_full_scores: usize,
    /// Maximum number of results returned per query.
    pub top_k: usize,
    /// Number of IVF cells (centroids) probed per query token.
    pub n_ivf_probe: usize,
    /// When non-zero and the index has more centroids than this, centroids are
    /// scored in batches of this many rows instead of as one dense matrix.
    /// Falls back to `default_centroid_batch_size()` when absent from input.
    #[serde(default = "default_centroid_batch_size")]
    pub centroid_batch_size: usize,
    /// Probed centroids whose best score over all query tokens is below this
    /// value are dropped; `None` disables the filter.
    /// Falls back to `default_centroid_score_threshold()` when absent from input.
    #[serde(default = "default_centroid_score_threshold")]
    pub centroid_score_threshold: Option<f32>,
}
49
/// Serde/`Default` fallback for `SearchParameters::centroid_batch_size`.
fn default_centroid_batch_size() -> usize {
    100 * 1_000
}
53
/// Serde/`Default` fallback for `SearchParameters::centroid_score_threshold`.
fn default_centroid_score_threshold() -> Option<f32> {
    let threshold: f32 = 0.4;
    Some(threshold)
}
57
58impl Default for SearchParameters {
59 fn default() -> Self {
60 Self {
61 batch_size: 2000,
62 n_full_scores: 4096,
63 top_k: 10,
64 n_ivf_probe: 8,
65 centroid_batch_size: default_centroid_batch_size(),
66 centroid_score_threshold: default_centroid_score_threshold(),
67 }
68 }
69}
70
/// Ranked results for one query.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct QueryResult {
    /// Position of the query in the slice passed to `search_many_mmap`
    /// (left at 0 by `search_one_mmap`).
    pub query_id: usize,
    /// Document ids, highest score first.
    pub passage_ids: Vec<i64>,
    /// Exact ColBERT scores, aligned index-for-index with `passage_ids`.
    pub scores: Vec<f32>,
}
81
82#[cfg(feature = "cuda")]
86const CUDA_COLBERT_MIN_SIZE: usize = 128 * 1024;
87
88fn colbert_score(query: &ArrayView2<f32>, doc: &ArrayView2<f32>) -> f32 {
94 #[cfg(feature = "cuda")]
96 {
97 let matrix_size = query.nrows() * doc.nrows();
98 if matrix_size >= CUDA_COLBERT_MIN_SIZE {
99 if let Some(ctx) = crate::cuda::get_global_context() {
100 match crate::cuda::colbert_score_cuda(ctx, query, doc) {
101 Ok(score) => return score,
102 Err(_) => {
103 }
105 }
106 }
107 }
108 }
109
110 colbert_score_cpu(query, doc)
112}
113
114fn colbert_score_cpu(query: &ArrayView2<f32>, doc: &ArrayView2<f32>) -> f32 {
117 maxsim::maxsim_score(query, doc)
118}
119
/// Totally ordered `f32` wrapper so scores can be stored in a `BinaryHeap`.
///
/// Both equality and ordering use [`f32::total_cmp`], so the `Eq`/`Ord`
/// contracts hold even when NaN is present. Positive NaN sorts above `+inf`,
/// negative NaN sorts below `-inf`, and `-0.0 < +0.0`. For all other values
/// this is the usual numeric order.
#[derive(Debug, Clone, Copy)]
struct OrdF32(f32);

impl PartialEq for OrdF32 {
    fn eq(&self, other: &Self) -> bool {
        // Must agree with `Ord::cmp`; derived `==` would make NaN != NaN.
        self.cmp(other) == std::cmp::Ordering::Equal
    }
}

impl Eq for OrdF32 {}

impl PartialOrd for OrdF32 {
    fn partial_cmp(&self, other: &Self) -> Option<std::cmp::Ordering> {
        Some(self.cmp(other))
    }
}

impl Ord for OrdF32 {
    fn cmp(&self, other: &Self) -> std::cmp::Ordering {
        self.0.total_cmp(&other.0)
    }
}
139
140fn ivf_probe_batched(
146 query: &Array2<f32>,
147 centroids: &CentroidStore,
148 n_probe: usize,
149 batch_size: usize,
150 centroid_score_threshold: Option<f32>,
151) -> Vec<usize> {
152 let num_centroids = centroids.nrows();
153 let num_tokens = query.nrows();
154
155 let batch_ranges: Vec<(usize, usize)> = (0..num_centroids)
157 .step_by(batch_size)
158 .map(|start| (start, (start + batch_size).min(num_centroids)))
159 .collect();
160
161 let local_results: Vec<ProbePartial> = batch_ranges
168 .par_iter()
169 .map(|&(batch_start, batch_end)| {
170 let mut heaps: Vec<BinaryHeap<(Reverse<OrdF32>, usize)>> = (0..num_tokens)
171 .map(|_| BinaryHeap::with_capacity(n_probe + 1))
172 .collect();
173 let mut max_scores: HashMap<usize, f32> = HashMap::new();
174
175 let batch_centroids = centroids.slice_rows(batch_start, batch_end);
177
178 let batch_scores = query.dot(&batch_centroids.t());
180
181 for (q_idx, heap) in heaps.iter_mut().enumerate() {
183 for (local_c, &score) in batch_scores.row(q_idx).iter().enumerate() {
184 let global_c = batch_start + local_c;
185 let entry = (Reverse(OrdF32(score)), global_c);
186
187 if heap.len() < n_probe {
188 heap.push(entry);
189 max_scores
190 .entry(global_c)
191 .and_modify(|s| *s = s.max(score))
192 .or_insert(score);
193 } else if let Some(&(Reverse(OrdF32(min_score)), _)) = heap.peek() {
194 if score > min_score {
195 heap.pop();
196 heap.push(entry);
197 max_scores
198 .entry(global_c)
199 .and_modify(|s| *s = s.max(score))
200 .or_insert(score);
201 }
202 }
203 }
204 }
205
206 (heaps, max_scores)
207 })
208 .collect();
209
210 let mut final_heaps: Vec<BinaryHeap<(Reverse<OrdF32>, usize)>> = (0..num_tokens)
213 .map(|_| BinaryHeap::with_capacity(n_probe + 1))
214 .collect();
215 let mut final_max_scores: HashMap<usize, f32> = HashMap::new();
216
217 for (local_heaps, local_max_scores) in local_results {
218 for (q_idx, local_heap) in local_heaps.into_iter().enumerate() {
219 for entry in local_heap {
220 let (Reverse(OrdF32(score)), _) = entry;
221 if final_heaps[q_idx].len() < n_probe {
222 final_heaps[q_idx].push(entry);
223 } else if let Some(&(Reverse(OrdF32(min_score)), _)) = final_heaps[q_idx].peek() {
224 if score > min_score {
225 final_heaps[q_idx].pop();
226 final_heaps[q_idx].push(entry);
227 }
228 }
229 }
230 }
231 for (c, score) in local_max_scores {
232 final_max_scores
233 .entry(c)
234 .and_modify(|s| *s = s.max(score))
235 .or_insert(score);
236 }
237 }
238
239 let mut selected: HashSet<usize> = HashSet::new();
241 for heap in final_heaps {
242 for (_, c) in heap {
243 selected.insert(c);
244 }
245 }
246
247 if let Some(threshold) = centroid_score_threshold {
249 selected.retain(|c| {
250 final_max_scores
251 .get(c)
252 .copied()
253 .unwrap_or(f32::NEG_INFINITY)
254 >= threshold
255 });
256 }
257
258 selected.into_iter().collect()
259}
260
261fn build_sparse_centroid_scores(
265 query: &Array2<f32>,
266 centroids: &CentroidStore,
267 centroid_ids: &HashSet<usize>,
268) -> HashMap<usize, Array1<f32>> {
269 centroid_ids
270 .iter()
271 .map(|&c| {
272 let centroid = centroids.row(c);
273 let scores: Array1<f32> = query.dot(¢roid);
274 (c, scores)
275 })
276 .collect()
277}
278
279fn approximate_score_sparse(
281 sparse_scores: &HashMap<usize, Array1<f32>>,
282 doc_codes: &[usize],
283 num_query_tokens: usize,
284) -> f32 {
285 let mut score = 0.0;
286
287 for q_idx in 0..num_query_tokens {
289 let mut max_score = f32::NEG_INFINITY;
290
291 for &code in doc_codes.iter() {
293 if let Some(centroid_scores) = sparse_scores.get(&code) {
294 let centroid_score = centroid_scores[q_idx];
295 if centroid_score > max_score {
296 max_score = centroid_score;
297 }
298 }
299 }
300
301 if max_score > f32::NEG_INFINITY {
302 score += max_score;
303 }
304 }
305
306 score
307}
308
309fn approximate_score_mmap(query_centroid_scores: &Array2<f32>, doc_codes: &[i64]) -> f32 {
311 let mut score = 0.0;
312
313 for q_idx in 0..query_centroid_scores.nrows() {
314 let mut max_score = f32::NEG_INFINITY;
315
316 for &code in doc_codes.iter() {
317 let centroid_score = query_centroid_scores[[q_idx, code as usize]];
318 if centroid_score > max_score {
319 max_score = centroid_score;
320 }
321 }
322
323 if max_score > f32::NEG_INFINITY {
324 score += max_score;
325 }
326 }
327
328 score
329}
330
331pub fn search_one_mmap(
333 index: &crate::index::MmapIndex,
334 query: &Array2<f32>,
335 params: &SearchParameters,
336 subset: Option<&[i64]>,
337) -> Result<QueryResult> {
338 let num_centroids = index.codec.num_centroids();
339 let num_query_tokens = query.nrows();
340
341 let use_batched = params.centroid_batch_size > 0 && num_centroids > params.centroid_batch_size;
343
344 if use_batched {
345 return search_one_mmap_batched(index, query, params, subset);
347 }
348
349 let query_centroid_scores = query.dot(&index.codec.centroids_view().t());
351
352 let eligible_centroids: Option<HashSet<usize>> = subset.map(|subset_docs| {
356 let mut centroids = HashSet::new();
357 for &doc_id in subset_docs {
358 let doc_idx = doc_id as usize;
359 if doc_idx < index.doc_lengths.len() {
360 let start = index.doc_offsets[doc_idx];
361 let end = index.doc_offsets[doc_idx + 1];
362 let codes = index.mmap_codes.slice(start, end);
363 for &c in codes.iter() {
364 centroids.insert(c as usize);
365 }
366 }
367 }
368 centroids
369 });
370
371 let effective_n_ivf_probe = match (&eligible_centroids, subset) {
376 (Some(eligible), Some(subset_docs)) if !eligible.is_empty() => {
377 let num_docs = index.doc_lengths.len();
378 let subset_len = subset_docs.len();
379 let scaled = if subset_len > 0 {
380 (params.n_ivf_probe as u64 * num_docs as u64 / subset_len as u64) as usize
381 } else {
382 params.n_ivf_probe
383 };
384 scaled.max(params.n_ivf_probe).min(eligible.len())
385 }
386 _ => params.n_ivf_probe,
387 };
388
389 let cells_to_probe: Vec<usize> = {
394 let mut selected_centroids = HashSet::new();
395
396 for q_idx in 0..num_query_tokens {
397 let mut centroid_scores: Vec<(usize, f32)> = match &eligible_centroids {
398 Some(eligible) => eligible
399 .iter()
400 .map(|&c| (c, query_centroid_scores[[q_idx, c]]))
401 .collect(),
402 None => (0..num_centroids)
403 .map(|c| (c, query_centroid_scores[[q_idx, c]]))
404 .collect(),
405 };
406
407 let n_probe = effective_n_ivf_probe.min(centroid_scores.len());
411 if centroid_scores.len() > n_probe {
412 centroid_scores.select_nth_unstable_by(n_probe - 1, |a, b| {
413 b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal)
414 });
415 }
416
417 for (c, _) in centroid_scores.iter().take(n_probe) {
418 selected_centroids.insert(*c);
419 }
420 }
421
422 if let Some(threshold) = params.centroid_score_threshold {
424 selected_centroids.retain(|&c| {
425 let max_score: f32 = (0..num_query_tokens)
426 .map(|q_idx| query_centroid_scores[[q_idx, c]])
427 .max_by(|a, b| a.partial_cmp(b).unwrap())
428 .unwrap_or(f32::NEG_INFINITY);
429 max_score >= threshold
430 });
431 }
432
433 selected_centroids.into_iter().collect()
434 };
435
436 let mut candidates = index.get_candidates(&cells_to_probe);
438
439 if let Some(subset_docs) = subset {
441 let subset_set: HashSet<i64> = subset_docs.iter().copied().collect();
442 candidates.retain(|&c| subset_set.contains(&c));
443 }
444
445 if candidates.is_empty() {
446 return Ok(QueryResult {
447 query_id: 0,
448 passage_ids: vec![],
449 scores: vec![],
450 });
451 }
452
453 let mut approx_scores: Vec<(i64, f32)> = candidates
455 .par_iter()
456 .map(|&doc_id| {
457 let start = index.doc_offsets[doc_id as usize];
458 let end = index.doc_offsets[doc_id as usize + 1];
459 let codes = index.mmap_codes.slice(start, end);
460 let score = approximate_score_mmap(&query_centroid_scores, &codes);
461 (doc_id, score)
462 })
463 .collect();
464
465 approx_scores.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap());
467 let top_candidates: Vec<i64> = approx_scores
468 .iter()
469 .take(params.n_full_scores)
470 .map(|(id, _)| *id)
471 .collect();
472
473 let n_decompress = (params.n_full_scores / 4).max(params.top_k);
475 let to_decompress: Vec<i64> = top_candidates.into_iter().take(n_decompress).collect();
476
477 if to_decompress.is_empty() {
478 return Ok(QueryResult {
479 query_id: 0,
480 passage_ids: vec![],
481 scores: vec![],
482 });
483 }
484
485 let mut exact_scores: Vec<(i64, f32)> = to_decompress
488 .par_chunks(DECOMPRESS_CHUNK_SIZE)
489 .flat_map(|chunk| {
490 chunk
491 .iter()
492 .filter_map(|&doc_id| {
493 let doc_embeddings = index.get_document_embeddings(doc_id as usize).ok()?;
494 let score = colbert_score(&query.view(), &doc_embeddings.view());
495 Some((doc_id, score))
496 })
497 .collect::<Vec<_>>()
498 })
499 .collect();
500
501 exact_scores.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap());
503
504 let result_count = params.top_k.min(exact_scores.len());
506 let passage_ids: Vec<i64> = exact_scores
507 .iter()
508 .take(result_count)
509 .map(|(id, _)| *id)
510 .collect();
511 let scores: Vec<f32> = exact_scores
512 .iter()
513 .take(result_count)
514 .map(|(_, s)| *s)
515 .collect();
516
517 Ok(QueryResult {
518 query_id: 0,
519 passage_ids,
520 scores,
521 })
522}
523
524fn search_one_mmap_batched(
528 index: &crate::index::MmapIndex,
529 query: &Array2<f32>,
530 params: &SearchParameters,
531 subset: Option<&[i64]>,
532) -> Result<QueryResult> {
533 let num_query_tokens = query.nrows();
534
535 let cells_to_probe = ivf_probe_batched(
537 query,
538 &index.codec.centroids,
539 params.n_ivf_probe,
540 params.centroid_batch_size,
541 params.centroid_score_threshold,
542 );
543
544 let mut candidates = index.get_candidates(&cells_to_probe);
546
547 if let Some(subset_docs) = subset {
549 let subset_set: HashSet<i64> = subset_docs.iter().copied().collect();
550 candidates.retain(|&c| subset_set.contains(&c));
551 }
552
553 if candidates.is_empty() {
554 return Ok(QueryResult {
555 query_id: 0,
556 passage_ids: vec![],
557 scores: vec![],
558 });
559 }
560
561 let mut unique_centroids: HashSet<usize> = HashSet::new();
563 for &doc_id in &candidates {
564 let start = index.doc_offsets[doc_id as usize];
565 let end = index.doc_offsets[doc_id as usize + 1];
566 let codes = index.mmap_codes.slice(start, end);
567 for &code in codes.iter() {
568 unique_centroids.insert(code as usize);
569 }
570 }
571
572 let sparse_scores =
574 build_sparse_centroid_scores(query, &index.codec.centroids, &unique_centroids);
575
576 let mut approx_scores: Vec<(i64, f32)> = candidates
578 .par_iter()
579 .map(|&doc_id| {
580 let start = index.doc_offsets[doc_id as usize];
581 let end = index.doc_offsets[doc_id as usize + 1];
582 let codes = index.mmap_codes.slice(start, end);
583 let doc_codes: Vec<usize> = codes.iter().map(|&c| c as usize).collect();
584 let score = approximate_score_sparse(&sparse_scores, &doc_codes, num_query_tokens);
585 (doc_id, score)
586 })
587 .collect();
588
589 approx_scores.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap());
591 let top_candidates: Vec<i64> = approx_scores
592 .iter()
593 .take(params.n_full_scores)
594 .map(|(id, _)| *id)
595 .collect();
596
597 let n_decompress = (params.n_full_scores / 4).max(params.top_k);
599 let to_decompress: Vec<i64> = top_candidates.into_iter().take(n_decompress).collect();
600
601 if to_decompress.is_empty() {
602 return Ok(QueryResult {
603 query_id: 0,
604 passage_ids: vec![],
605 scores: vec![],
606 });
607 }
608
609 let mut exact_scores: Vec<(i64, f32)> = to_decompress
612 .par_chunks(DECOMPRESS_CHUNK_SIZE)
613 .flat_map(|chunk| {
614 chunk
615 .iter()
616 .filter_map(|&doc_id| {
617 let doc_embeddings = index.get_document_embeddings(doc_id as usize).ok()?;
618 let score = colbert_score(&query.view(), &doc_embeddings.view());
619 Some((doc_id, score))
620 })
621 .collect::<Vec<_>>()
622 })
623 .collect();
624
625 exact_scores.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap());
627
628 let result_count = params.top_k.min(exact_scores.len());
630 let passage_ids: Vec<i64> = exact_scores
631 .iter()
632 .take(result_count)
633 .map(|(id, _)| *id)
634 .collect();
635 let scores: Vec<f32> = exact_scores
636 .iter()
637 .take(result_count)
638 .map(|(_, s)| *s)
639 .collect();
640
641 Ok(QueryResult {
642 query_id: 0,
643 passage_ids,
644 scores,
645 })
646}
647
648pub fn search_many_mmap(
650 index: &crate::index::MmapIndex,
651 queries: &[Array2<f32>],
652 params: &SearchParameters,
653 parallel: bool,
654 subset: Option<&[i64]>,
655) -> Result<Vec<QueryResult>> {
656 if parallel {
657 let results: Vec<QueryResult> = queries
658 .par_iter()
659 .enumerate()
660 .map(|(i, query)| {
661 let mut result =
662 search_one_mmap(index, query, params, subset).unwrap_or_else(|_| QueryResult {
663 query_id: i,
664 passage_ids: vec![],
665 scores: vec![],
666 });
667 result.query_id = i;
668 result
669 })
670 .collect();
671 Ok(results)
672 } else {
673 let mut results = Vec::with_capacity(queries.len());
674 for (i, query) in queries.iter().enumerate() {
675 let mut result = search_one_mmap(index, query, params, subset)?;
676 result.query_id = i;
677 results.push(result);
678 }
679 Ok(results)
680 }
681}
682
/// Alias of [`QueryResult`].
pub type SearchResult = QueryResult;
685
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_colbert_score() {
        // Query tokens are the first two unit vectors, so each token's MaxSim
        // picks the largest value in the matching document column:
        // max(0.5, 0.8, 0.0) + max(0.5, 0.2, 0.9) = 0.8 + 0.9 = 1.7.
        let query =
            Array2::from_shape_vec((2, 4), vec![1.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0]).unwrap();

        let doc = Array2::from_shape_vec(
            (3, 4),
            vec![
                0.5, 0.5, 0.0, 0.0, // doc token 0
                0.8, 0.2, 0.0, 0.0, // doc token 1
                0.0, 0.9, 0.1, 0.0, // doc token 2
            ],
        )
        .unwrap();

        let score = colbert_score(&query.view(), &doc.view());
        assert!((score - 1.7).abs() < 1e-5);
    }

    #[test]
    fn test_search_params_default() {
        let params = SearchParameters::default();
        assert_eq!(params.batch_size, 2000);
        assert_eq!(params.n_full_scores, 4096);
        assert_eq!(params.top_k, 10);
        assert_eq!(params.n_ivf_probe, 8);
        assert_eq!(params.centroid_batch_size, 100_000);
        assert_eq!(params.centroid_score_threshold, Some(0.4));
    }

    #[test]
    fn test_ord_f32_reverse_heap_is_min_heap() {
        let mut heap = BinaryHeap::new();
        for &s in &[0.5_f32, -1.0, 2.0] {
            heap.push(Reverse(OrdF32(s)));
        }
        // `Reverse` turns the max-heap into a min-heap over scores.
        assert_eq!(heap.peek().map(|r| (r.0).0), Some(-1.0));
    }

    #[test]
    fn test_approximate_scores_dense_and_sparse_agree() {
        // Rows are query tokens, columns are centroids.
        let dense = Array2::from_shape_vec((2, 3), vec![0.1, 0.7, 0.3, 0.6, 0.2, 0.4]).unwrap();
        // Document uses centroids 0 and 2: max(0.1, 0.3) + max(0.6, 0.4) = 0.9.
        let dense_score = approximate_score_mmap(&dense, &[0, 2]);
        assert!((dense_score - 0.9).abs() < 1e-6);

        // Same scores stored per centroid (columns 0 and 2 of `dense`).
        let mut sparse: HashMap<usize, Array1<f32>> = HashMap::new();
        sparse.insert(0, Array1::from(vec![0.1, 0.6]));
        sparse.insert(2, Array1::from(vec![0.3, 0.4]));
        // Code 5 has no sparse scores and is ignored.
        let sparse_score = approximate_score_sparse(&sparse, &[0, 2, 5], 2);
        assert!((sparse_score - dense_score).abs() < 1e-6);

        // Empty documents score zero on both paths.
        assert_eq!(approximate_score_sparse(&sparse, &[], 2), 0.0);
        assert_eq!(approximate_score_mmap(&dense, &[]), 0.0);
    }
}