1use std::cmp::Reverse;
4use std::collections::{BinaryHeap, HashMap, HashSet};
5
6use ndarray::Array1;
7use ndarray::{Array2, ArrayView2, Axis};
8use rayon::prelude::*;
9use serde::{Deserialize, Serialize};
10
11use crate::codec::CentroidStore;
12use crate::error::Result;
13use crate::maxsim;
14
/// Number of candidate document ids handed to each rayon work item when documents are
/// decompressed and rescored exactly (128 documents per chunk).
const DECOMPRESS_CHUNK_SIZE: usize = 1 << 7;
19
20#[derive(Debug, Clone, Serialize, Deserialize)]
22pub struct SearchParameters {
23 pub batch_size: usize,
25 pub n_full_scores: usize,
27 pub top_k: usize,
29 pub n_ivf_probe: usize,
31 #[serde(default = "default_centroid_batch_size")]
35 pub centroid_batch_size: usize,
36 #[serde(default = "default_centroid_score_threshold")]
41 pub centroid_score_threshold: Option<f32>,
42}
43
/// Serde default for `SearchParameters::centroid_batch_size`: 100 000 centroids per batch.
fn default_centroid_batch_size() -> usize {
    100 * 1_000
}
47
/// Serde default for `SearchParameters::centroid_score_threshold`: prune at 0.4.
fn default_centroid_score_threshold() -> Option<f32> {
    let threshold: f32 = 0.4;
    Some(threshold)
}
51
52impl Default for SearchParameters {
53 fn default() -> Self {
54 Self {
55 batch_size: 2000,
56 n_full_scores: 4096,
57 top_k: 10,
58 n_ivf_probe: 8,
59 centroid_batch_size: default_centroid_batch_size(),
60 centroid_score_threshold: default_centroid_score_threshold(),
61 }
62 }
63}
64
65#[derive(Debug, Clone, Serialize, Deserialize)]
67pub struct QueryResult {
68 pub query_id: usize,
70 pub passage_ids: Vec<i64>,
72 pub scores: Vec<f32>,
74}
75
/// Minimum `query tokens × document tokens` product before `colbert_score` tries the CUDA
/// kernel; smaller problems go straight to the CPU path (128 Ki = 131 072).
#[cfg(feature = "cuda")]
const CUDA_COLBERT_MIN_SIZE: usize = 128 << 10;
81
82fn colbert_score(query: &ArrayView2<f32>, doc: &ArrayView2<f32>) -> f32 {
88 #[cfg(feature = "cuda")]
90 {
91 let matrix_size = query.nrows() * doc.nrows();
92 if matrix_size >= CUDA_COLBERT_MIN_SIZE {
93 if let Some(ctx) = crate::cuda::get_global_context() {
94 match crate::cuda::colbert_score_cuda(ctx, query, doc) {
95 Ok(score) => return score,
96 Err(_) => {
97 }
99 }
100 }
101 }
102 }
103
104 colbert_score_cpu(query, doc)
106}
107
108fn colbert_score_cpu(query: &ArrayView2<f32>, doc: &ArrayView2<f32>) -> f32 {
111 maxsim::maxsim_score(query, doc)
112}
113
114#[allow(clippy::too_many_arguments)]
120fn compute_adaptive_ivf_probe_mmap(
121 query_centroid_scores: &Array2<f32>,
122 mmap_codes: &crate::mmap::MmapNpyArray1I64,
123 doc_offsets: &[usize],
124 num_docs: usize,
125 subset: &[i64],
126 top_k: usize,
127 n_ivf_probe: usize,
128 centroid_score_threshold: Option<f32>,
129) -> Vec<usize> {
130 let mut centroid_doc_counts: HashMap<usize, HashSet<i64>> = HashMap::new();
132 for &doc_id in subset {
133 let doc_idx = doc_id as usize;
134 if doc_idx < num_docs {
135 let start = doc_offsets[doc_idx];
136 let end = doc_offsets[doc_idx + 1];
137 let codes = mmap_codes.slice(start, end);
138 for &c in codes.iter() {
139 centroid_doc_counts
140 .entry(c as usize)
141 .or_default()
142 .insert(doc_id);
143 }
144 }
145 }
146
147 if centroid_doc_counts.is_empty() {
148 return vec![];
149 }
150
151 let mut scored_centroids: Vec<(usize, f32, usize)> = centroid_doc_counts
153 .into_iter()
154 .map(|(c, docs)| {
155 let max_score: f32 = query_centroid_scores
156 .axis_iter(Axis(0))
157 .map(|q| q[c])
158 .max_by(|a, b| a.partial_cmp(b).unwrap())
159 .unwrap_or(0.0);
160 (c, max_score, docs.len())
161 })
162 .collect();
163
164 if let Some(threshold) = centroid_score_threshold {
166 scored_centroids.retain(|(_, score, _)| *score >= threshold);
167 }
168
169 scored_centroids.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap());
171
172 let mut cumulative_docs = 0;
174 let mut n_probe = 0;
175
176 for (_, _, doc_count) in &scored_centroids {
177 cumulative_docs += doc_count;
178 n_probe += 1;
179 if cumulative_docs >= top_k && n_probe >= n_ivf_probe {
181 break;
182 }
183 }
184
185 n_probe = n_probe.max(n_ivf_probe.min(scored_centroids.len()));
187
188 scored_centroids
189 .iter()
190 .take(n_probe)
191 .map(|(c, _, _)| *c)
192 .collect()
193}
194
/// Totally ordered wrapper around `f32`, used as a `BinaryHeap` key.
///
/// Equality and ordering are both defined by [`f32::total_cmp`]. This keeps the `Eq` and
/// `Ord` contracts valid even for NaN, and keeps the impls mutually consistent
/// (`a == b` exactly when `a.cmp(&b) == Ordering::Equal`).
///
/// The previous derived `PartialEq` said `NaN != NaN`, while `cmp` treated NaN as equal to
/// everything. That ordering is not transitive and can silently corrupt a heap. Under
/// `total_cmp`, positive NaNs sort above `+inf`, negative NaNs below `-inf`, and `-0.0 < +0.0`.
#[derive(Debug, Clone, Copy)]
struct OrdF32(f32);

impl PartialEq for OrdF32 {
    fn eq(&self, other: &Self) -> bool {
        self.cmp(other) == std::cmp::Ordering::Equal
    }
}

impl Eq for OrdF32 {}

impl PartialOrd for OrdF32 {
    fn partial_cmp(&self, other: &Self) -> Option<std::cmp::Ordering> {
        Some(self.cmp(other))
    }
}

impl Ord for OrdF32 {
    fn cmp(&self, other: &Self) -> std::cmp::Ordering {
        self.0.total_cmp(&other.0)
    }
}
214
215fn ivf_probe_batched(
221 query: &Array2<f32>,
222 centroids: &CentroidStore,
223 n_probe: usize,
224 batch_size: usize,
225 centroid_score_threshold: Option<f32>,
226) -> Vec<usize> {
227 let num_centroids = centroids.nrows();
228 let num_tokens = query.nrows();
229
230 let mut heaps: Vec<BinaryHeap<(Reverse<OrdF32>, usize)>> = (0..num_tokens)
233 .map(|_| BinaryHeap::with_capacity(n_probe + 1))
234 .collect();
235
236 let mut max_scores: HashMap<usize, f32> = HashMap::new();
238
239 for batch_start in (0..num_centroids).step_by(batch_size) {
240 let batch_end = (batch_start + batch_size).min(num_centroids);
241
242 let batch_centroids = centroids.slice_rows(batch_start, batch_end);
244
245 let batch_scores = query.dot(&batch_centroids.t());
247
248 for (q_idx, heap) in heaps.iter_mut().enumerate() {
250 for (local_c, &score) in batch_scores.row(q_idx).iter().enumerate() {
251 let global_c = batch_start + local_c;
252 let entry = (Reverse(OrdF32(score)), global_c);
253
254 if heap.len() < n_probe {
255 heap.push(entry);
256 max_scores
258 .entry(global_c)
259 .and_modify(|s| *s = s.max(score))
260 .or_insert(score);
261 } else if let Some(&(Reverse(OrdF32(min_score)), _)) = heap.peek() {
262 if score > min_score {
263 heap.pop();
264 heap.push(entry);
265 max_scores
267 .entry(global_c)
268 .and_modify(|s| *s = s.max(score))
269 .or_insert(score);
270 }
271 }
272 }
273 }
274 }
275
276 let mut selected: HashSet<usize> = HashSet::new();
278 for heap in heaps {
279 for (_, c) in heap {
280 selected.insert(c);
281 }
282 }
283
284 if let Some(threshold) = centroid_score_threshold {
286 selected.retain(|c| max_scores.get(c).copied().unwrap_or(f32::NEG_INFINITY) >= threshold);
287 }
288
289 selected.into_iter().collect()
290}
291
292fn build_sparse_centroid_scores(
296 query: &Array2<f32>,
297 centroids: &CentroidStore,
298 centroid_ids: &HashSet<usize>,
299) -> HashMap<usize, Array1<f32>> {
300 centroid_ids
301 .iter()
302 .map(|&c| {
303 let centroid = centroids.row(c);
304 let scores: Array1<f32> = query.dot(¢roid);
305 (c, scores)
306 })
307 .collect()
308}
309
310fn approximate_score_sparse(
312 sparse_scores: &HashMap<usize, Array1<f32>>,
313 doc_codes: &[usize],
314 num_query_tokens: usize,
315) -> f32 {
316 let mut score = 0.0;
317
318 for q_idx in 0..num_query_tokens {
320 let mut max_score = f32::NEG_INFINITY;
321
322 for &code in doc_codes.iter() {
324 if let Some(centroid_scores) = sparse_scores.get(&code) {
325 let centroid_score = centroid_scores[q_idx];
326 if centroid_score > max_score {
327 max_score = centroid_score;
328 }
329 }
330 }
331
332 if max_score > f32::NEG_INFINITY {
333 score += max_score;
334 }
335 }
336
337 score
338}
339
340fn approximate_score_mmap(query_centroid_scores: &Array2<f32>, doc_codes: &[i64]) -> f32 {
342 let mut score = 0.0;
343
344 for q_idx in 0..query_centroid_scores.nrows() {
345 let mut max_score = f32::NEG_INFINITY;
346
347 for &code in doc_codes.iter() {
348 let centroid_score = query_centroid_scores[[q_idx, code as usize]];
349 if centroid_score > max_score {
350 max_score = centroid_score;
351 }
352 }
353
354 if max_score > f32::NEG_INFINITY {
355 score += max_score;
356 }
357 }
358
359 score
360}
361
362pub fn search_one_mmap(
364 index: &crate::index::MmapIndex,
365 query: &Array2<f32>,
366 params: &SearchParameters,
367 subset: Option<&[i64]>,
368) -> Result<QueryResult> {
369 let num_centroids = index.codec.num_centroids();
370 let num_query_tokens = query.nrows();
371
372 let use_batched = params.centroid_batch_size > 0
374 && num_centroids > params.centroid_batch_size
375 && subset.is_none();
376
377 if use_batched {
378 return search_one_mmap_batched(index, query, params);
380 }
381
382 let query_centroid_scores = query.dot(&index.codec.centroids_view().t());
384
385 let cells_to_probe: Vec<usize> = if let Some(subset_docs) = subset {
387 compute_adaptive_ivf_probe_mmap(
389 &query_centroid_scores,
390 &index.mmap_codes,
391 index.doc_offsets.as_slice().unwrap(),
392 index.doc_lengths.len(),
393 subset_docs,
394 params.top_k,
395 params.n_ivf_probe,
396 params.centroid_score_threshold,
397 )
398 } else {
399 let mut selected_centroids = HashSet::new();
401
402 for q_idx in 0..num_query_tokens {
403 let mut centroid_scores: Vec<(usize, f32)> = (0..num_centroids)
404 .map(|c| (c, query_centroid_scores[[q_idx, c]]))
405 .collect();
406
407 if centroid_scores.len() > params.n_ivf_probe {
411 centroid_scores.select_nth_unstable_by(params.n_ivf_probe - 1, |a, b| {
412 b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal)
413 });
414 }
415
416 for (c, _) in centroid_scores.iter().take(params.n_ivf_probe) {
417 selected_centroids.insert(*c);
418 }
419 }
420
421 if let Some(threshold) = params.centroid_score_threshold {
423 selected_centroids.retain(|&c| {
424 let max_score: f32 = (0..num_query_tokens)
425 .map(|q_idx| query_centroid_scores[[q_idx, c]])
426 .max_by(|a, b| a.partial_cmp(b).unwrap())
427 .unwrap_or(f32::NEG_INFINITY);
428 max_score >= threshold
429 });
430 }
431
432 selected_centroids.into_iter().collect()
433 };
434
435 let mut candidates = index.get_candidates(&cells_to_probe);
437
438 if let Some(subset_docs) = subset {
440 let subset_set: HashSet<i64> = subset_docs.iter().copied().collect();
441 candidates.retain(|&c| subset_set.contains(&c));
442 }
443
444 if candidates.is_empty() {
445 return Ok(QueryResult {
446 query_id: 0,
447 passage_ids: vec![],
448 scores: vec![],
449 });
450 }
451
452 let mut approx_scores: Vec<(i64, f32)> = candidates
454 .par_iter()
455 .map(|&doc_id| {
456 let start = index.doc_offsets[doc_id as usize];
457 let end = index.doc_offsets[doc_id as usize + 1];
458 let codes = index.mmap_codes.slice(start, end);
459 let score = approximate_score_mmap(&query_centroid_scores, &codes);
460 (doc_id, score)
461 })
462 .collect();
463
464 approx_scores.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap());
466 let top_candidates: Vec<i64> = approx_scores
467 .iter()
468 .take(params.n_full_scores)
469 .map(|(id, _)| *id)
470 .collect();
471
472 let n_decompress = (params.n_full_scores / 4).max(params.top_k);
474 let to_decompress: Vec<i64> = top_candidates.into_iter().take(n_decompress).collect();
475
476 if to_decompress.is_empty() {
477 return Ok(QueryResult {
478 query_id: 0,
479 passage_ids: vec![],
480 scores: vec![],
481 });
482 }
483
484 let mut exact_scores: Vec<(i64, f32)> = to_decompress
487 .par_chunks(DECOMPRESS_CHUNK_SIZE)
488 .flat_map(|chunk| {
489 chunk
490 .iter()
491 .filter_map(|&doc_id| {
492 let doc_embeddings = index.get_document_embeddings(doc_id as usize).ok()?;
493 let score = colbert_score(&query.view(), &doc_embeddings.view());
494 Some((doc_id, score))
495 })
496 .collect::<Vec<_>>()
497 })
498 .collect();
499
500 exact_scores.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap());
502
503 let result_count = params.top_k.min(exact_scores.len());
505 let passage_ids: Vec<i64> = exact_scores
506 .iter()
507 .take(result_count)
508 .map(|(id, _)| *id)
509 .collect();
510 let scores: Vec<f32> = exact_scores
511 .iter()
512 .take(result_count)
513 .map(|(_, s)| *s)
514 .collect();
515
516 Ok(QueryResult {
517 query_id: 0,
518 passage_ids,
519 scores,
520 })
521}
522
523fn search_one_mmap_batched(
527 index: &crate::index::MmapIndex,
528 query: &Array2<f32>,
529 params: &SearchParameters,
530) -> Result<QueryResult> {
531 let num_query_tokens = query.nrows();
532
533 let cells_to_probe = ivf_probe_batched(
535 query,
536 &index.codec.centroids,
537 params.n_ivf_probe,
538 params.centroid_batch_size,
539 params.centroid_score_threshold,
540 );
541
542 let candidates = index.get_candidates(&cells_to_probe);
544
545 if candidates.is_empty() {
546 return Ok(QueryResult {
547 query_id: 0,
548 passage_ids: vec![],
549 scores: vec![],
550 });
551 }
552
553 let mut unique_centroids: HashSet<usize> = HashSet::new();
555 for &doc_id in &candidates {
556 let start = index.doc_offsets[doc_id as usize];
557 let end = index.doc_offsets[doc_id as usize + 1];
558 let codes = index.mmap_codes.slice(start, end);
559 for &code in codes.iter() {
560 unique_centroids.insert(code as usize);
561 }
562 }
563
564 let sparse_scores =
566 build_sparse_centroid_scores(query, &index.codec.centroids, &unique_centroids);
567
568 let mut approx_scores: Vec<(i64, f32)> = candidates
570 .par_iter()
571 .map(|&doc_id| {
572 let start = index.doc_offsets[doc_id as usize];
573 let end = index.doc_offsets[doc_id as usize + 1];
574 let codes = index.mmap_codes.slice(start, end);
575 let doc_codes: Vec<usize> = codes.iter().map(|&c| c as usize).collect();
576 let score = approximate_score_sparse(&sparse_scores, &doc_codes, num_query_tokens);
577 (doc_id, score)
578 })
579 .collect();
580
581 approx_scores.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap());
583 let top_candidates: Vec<i64> = approx_scores
584 .iter()
585 .take(params.n_full_scores)
586 .map(|(id, _)| *id)
587 .collect();
588
589 let n_decompress = (params.n_full_scores / 4).max(params.top_k);
591 let to_decompress: Vec<i64> = top_candidates.into_iter().take(n_decompress).collect();
592
593 if to_decompress.is_empty() {
594 return Ok(QueryResult {
595 query_id: 0,
596 passage_ids: vec![],
597 scores: vec![],
598 });
599 }
600
601 let mut exact_scores: Vec<(i64, f32)> = to_decompress
604 .par_chunks(DECOMPRESS_CHUNK_SIZE)
605 .flat_map(|chunk| {
606 chunk
607 .iter()
608 .filter_map(|&doc_id| {
609 let doc_embeddings = index.get_document_embeddings(doc_id as usize).ok()?;
610 let score = colbert_score(&query.view(), &doc_embeddings.view());
611 Some((doc_id, score))
612 })
613 .collect::<Vec<_>>()
614 })
615 .collect();
616
617 exact_scores.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap());
619
620 let result_count = params.top_k.min(exact_scores.len());
622 let passage_ids: Vec<i64> = exact_scores
623 .iter()
624 .take(result_count)
625 .map(|(id, _)| *id)
626 .collect();
627 let scores: Vec<f32> = exact_scores
628 .iter()
629 .take(result_count)
630 .map(|(_, s)| *s)
631 .collect();
632
633 Ok(QueryResult {
634 query_id: 0,
635 passage_ids,
636 scores,
637 })
638}
639
640pub fn search_many_mmap(
642 index: &crate::index::MmapIndex,
643 queries: &[Array2<f32>],
644 params: &SearchParameters,
645 parallel: bool,
646 subset: Option<&[i64]>,
647) -> Result<Vec<QueryResult>> {
648 if parallel {
649 let results: Vec<QueryResult> = queries
650 .par_iter()
651 .enumerate()
652 .map(|(i, query)| {
653 let mut result =
654 search_one_mmap(index, query, params, subset).unwrap_or_else(|_| QueryResult {
655 query_id: i,
656 passage_ids: vec![],
657 scores: vec![],
658 });
659 result.query_id = i;
660 result
661 })
662 .collect();
663 Ok(results)
664 } else {
665 let mut results = Vec::with_capacity(queries.len());
666 for (i, query) in queries.iter().enumerate() {
667 let mut result = search_one_mmap(index, query, params, subset)?;
668 result.query_id = i;
669 results.push(result);
670 }
671 Ok(results)
672 }
673}
674
/// Alternative name for [`QueryResult`].
pub type SearchResult = QueryResult;
677
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_colbert_score() {
        // Two one-hot query tokens against three document tokens.
        let query =
            Array2::from_shape_vec((2, 4), vec![1.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0]).unwrap();

        let doc = Array2::from_shape_vec(
            (3, 4),
            vec![
                0.5, 0.5, 0.0, 0.0, //
                0.8, 0.2, 0.0, 0.0, //
                0.0, 0.9, 0.1, 0.0, //
            ],
        )
        .unwrap();

        // Token 0 best matches doc row 1 (0.8); token 1 best matches doc row 2 (0.9).
        let score = colbert_score(&query.view(), &doc.view());
        assert!((score - 1.7).abs() < 1e-5);
    }

    #[test]
    fn test_search_params_default() {
        let params = SearchParameters::default();
        assert_eq!(params.batch_size, 2000);
        assert_eq!(params.n_full_scores, 4096);
        assert_eq!(params.top_k, 10);
        assert_eq!(params.n_ivf_probe, 8);
        assert_eq!(params.centroid_batch_size, 100_000);
        assert_eq!(params.centroid_score_threshold, Some(0.4));
    }

    #[test]
    fn test_approximate_score_mmap() {
        let scores =
            Array2::from_shape_vec((2, 3), vec![0.1, 0.5, 0.3, 0.7, 0.2, 0.4]).unwrap();
        // Token 0: max(0.1, 0.3) = 0.3; token 1: max(0.7, 0.4) = 0.7.
        assert!((approximate_score_mmap(&scores, &[0, 2]) - 1.0).abs() < 1e-6);
        // Without codes, no token contributes.
        assert_eq!(approximate_score_mmap(&scores, &[]), 0.0);
    }

    #[test]
    fn test_approximate_score_sparse_skips_unknown_codes() {
        let mut sparse: HashMap<usize, Array1<f32>> = HashMap::new();
        sparse.insert(0, Array1::from(vec![0.1, 0.7]));
        sparse.insert(2, Array1::from(vec![0.3, 0.4]));
        // Code 1 has no entry and is ignored.
        let score = approximate_score_sparse(&sparse, &[0, 1, 2], 2);
        assert!((score - 1.0).abs() < 1e-6);
        assert_eq!(approximate_score_sparse(&sparse, &[1], 2), 0.0);
    }

    #[test]
    fn test_ordf32_reverse_heap_pops_smallest() {
        let mut heap = BinaryHeap::new();
        for s in vec![0.5f32, -1.0, 2.0] {
            heap.push(Reverse(OrdF32(s)));
        }
        let popped: Vec<f32> =
            std::iter::from_fn(|| heap.pop().map(|Reverse(OrdF32(s))| s)).collect();
        assert_eq!(popped, vec![-1.0, 0.5, 2.0]);
    }
}