1#[cfg(feature = "npy")]
4use std::cmp::Reverse;
5#[cfg(feature = "npy")]
6use std::collections::{BinaryHeap, HashMap, HashSet};
7
8#[cfg(feature = "npy")]
9use ndarray::Array1;
10use ndarray::{Array2, ArrayView1, ArrayView2, Axis};
11use rayon::prelude::*;
12use serde::{Deserialize, Serialize};
13
14#[cfg(feature = "npy")]
15use crate::codec::CentroidStore;
16use crate::error::Result;
17use crate::index::Index;
18
19#[derive(Debug, Clone, Serialize, Deserialize)]
21pub struct SearchParameters {
22 pub batch_size: usize,
24 pub n_full_scores: usize,
26 pub top_k: usize,
28 pub n_ivf_probe: usize,
30 #[serde(default = "default_centroid_batch_size")]
34 pub centroid_batch_size: usize,
35}
36
/// Serde fallback (and `Default` value) for
/// `SearchParameters::centroid_batch_size`.
fn default_centroid_batch_size() -> usize {
    const DEFAULT_BATCH: usize = 100_000;
    DEFAULT_BATCH
}
40
41impl Default for SearchParameters {
42 fn default() -> Self {
43 Self {
44 batch_size: 2000,
45 n_full_scores: 4096,
46 top_k: 10,
47 n_ivf_probe: 8,
48 centroid_batch_size: default_centroid_batch_size(),
49 }
50 }
51}
52
53#[derive(Debug, Clone, Serialize, Deserialize)]
55pub struct QueryResult {
56 pub query_id: usize,
58 pub passage_ids: Vec<i64>,
60 pub scores: Vec<f32>,
62}
63
64fn colbert_score(query: &ArrayView2<f32>, doc: &ArrayView2<f32>) -> f32 {
67 let mut total_score = 0.0;
68
69 for q_row in query.axis_iter(Axis(0)) {
71 let mut max_sim = f32::NEG_INFINITY;
72
73 for d_row in doc.axis_iter(Axis(0)) {
75 let sim: f32 = q_row.dot(&d_row);
76 if sim > max_sim {
77 max_sim = sim;
78 }
79 }
80
81 if max_sim > f32::NEG_INFINITY {
82 total_score += max_sim;
83 }
84 }
85
86 total_score
87}
88
89fn approximate_score(query_centroid_scores: &Array2<f32>, doc_codes: &ArrayView1<usize>) -> f32 {
91 let mut score = 0.0;
92
93 for q_idx in 0..query_centroid_scores.nrows() {
95 let mut max_score = f32::NEG_INFINITY;
96
97 for &code in doc_codes.iter() {
99 let centroid_score = query_centroid_scores[[q_idx, code]];
100 if centroid_score > max_score {
101 max_score = centroid_score;
102 }
103 }
104
105 if max_score > f32::NEG_INFINITY {
106 score += max_score;
107 }
108 }
109
110 score
111}
112
/// `f32` wrapper with a genuine total order, so scores can live inside a
/// `BinaryHeap`.
///
/// Equality and ordering both follow `f32::total_cmp`. Mapping incomparable
/// (NaN) pairs to `Ordering::Equal` would make the order non-transitive and
/// inconsistent with `==`, which violates the `Ord`/`Eq` contract that heap
/// and sort routines rely on.
#[cfg(feature = "npy")]
#[derive(Clone, Copy)]
struct OrdF32(f32);

#[cfg(feature = "npy")]
impl PartialEq for OrdF32 {
    /// Two values are equal exactly when `cmp` reports `Equal`.
    fn eq(&self, other: &Self) -> bool {
        self.cmp(other) == std::cmp::Ordering::Equal
    }
}

#[cfg(feature = "npy")]
impl Eq for OrdF32 {}

#[cfg(feature = "npy")]
impl PartialOrd for OrdF32 {
    fn partial_cmp(&self, other: &Self) -> Option<std::cmp::Ordering> {
        Some(self.cmp(other))
    }
}

#[cfg(feature = "npy")]
impl Ord for OrdF32 {
    /// IEEE 754 total order: negative NaN < -inf < ... < -0.0 < +0.0 < ...
    /// < +inf < positive NaN.
    fn cmp(&self, other: &Self) -> std::cmp::Ordering {
        self.0.total_cmp(&other.0)
    }
}
136
/// Selects the centroids to probe without materialising the full
/// query-by-centroid score matrix.
///
/// Centroids are scored against the query in slices of `batch_size` rows.
/// For every query token a bounded min-heap keeps the `n_probe` best
/// `(score, centroid)` pairs seen so far. The result is the de-duplicated
/// union of the centroids kept by all tokens, in unspecified order.
///
/// A `batch_size` of `0` is treated as "one batch covering every centroid"
/// (`step_by(0)` would otherwise panic). When `n_probe` is `0`, or there are
/// no centroids or no query tokens, the result is empty.
#[cfg(feature = "npy")]
fn ivf_probe_batched(
    query: &Array2<f32>,
    centroids: &CentroidStore,
    n_probe: usize,
    batch_size: usize,
) -> Vec<usize> {
    let num_centroids = centroids.nrows();
    let num_tokens = query.nrows();

    // Nothing can be selected in these cases; skip the matrix products.
    if n_probe == 0 || num_centroids == 0 || num_tokens == 0 {
        return Vec::new();
    }

    let batch_size = if batch_size == 0 {
        num_centroids
    } else {
        batch_size
    };

    // A heap never holds more than `n_probe` entries, nor more than there
    // are centroids; capping the capacity also avoids `n_probe + 1` overflow.
    let heap_capacity = n_probe.min(num_centroids);
    let mut heaps: Vec<BinaryHeap<(Reverse<OrdF32>, usize)>> = (0..num_tokens)
        .map(|_| BinaryHeap::with_capacity(heap_capacity))
        .collect();

    for batch_start in (0..num_centroids).step_by(batch_size) {
        let batch_end = batch_start.saturating_add(batch_size).min(num_centroids);

        let batch_centroids = centroids.slice_rows(batch_start, batch_end);

        // Rows are query tokens, columns are the centroids of this batch.
        let batch_scores = query.dot(&batch_centroids.t());

        for (q_idx, heap) in heaps.iter_mut().enumerate() {
            for (local_c, &score) in batch_scores.row(q_idx).iter().enumerate() {
                let global_c = batch_start + local_c;
                let entry = (Reverse(OrdF32(score)), global_c);

                if heap.len() < n_probe {
                    heap.push(entry);
                } else if let Some(&(Reverse(OrdF32(min_score)), _)) = heap.peek() {
                    // `Reverse` puts the current worst kept score on top;
                    // replace it only with a strictly better one.
                    if score > min_score {
                        heap.pop();
                        heap.push(entry);
                    }
                }
            }
        }
    }

    let mut selected: HashSet<usize> = HashSet::new();
    for heap in heaps {
        selected.extend(heap.into_iter().map(|(_, c)| c));
    }
    selected.into_iter().collect()
}
193
/// Computes query-versus-centroid scores for a chosen set of centroids only.
///
/// For every id in `centroid_ids` the centroid row is fetched from
/// `centroids` and multiplied with `query`, giving one score per query
/// token. The result maps centroid id to that per-token score vector.
#[cfg(feature = "npy")]
fn build_sparse_centroid_scores(
    query: &Array2<f32>,
    centroids: &CentroidStore,
    centroid_ids: &HashSet<usize>,
) -> HashMap<usize, Array1<f32>> {
    centroid_ids
        .iter()
        .map(|&c| {
            let centroid = centroids.row(c);
            let scores: Array1<f32> = query.dot(&centroid);
            (c, scores)
        })
        .collect()
}
212
/// Approximate document score using a sparse centroid-score table.
///
/// `sparse_scores` maps a centroid id to its per-query-token scores. For
/// each of the `num_query_tokens` tokens the best score among the document's
/// codes is taken, and the per-token maxima are summed. Codes missing from
/// `sparse_scores` are ignored; a token with no maximum above negative
/// infinity contributes nothing.
///
/// # Panics
/// Panics if a score vector found for one of the codes has fewer than
/// `num_query_tokens` entries.
#[cfg(feature = "npy")]
fn approximate_score_sparse(
    sparse_scores: &HashMap<usize, Array1<f32>>,
    doc_codes: &[usize],
    num_query_tokens: usize,
) -> f32 {
    // Resolve every code once instead of hashing it again per query token.
    let known: Vec<&Array1<f32>> = doc_codes
        .iter()
        .filter_map(|code| sparse_scores.get(code))
        .collect();

    (0..num_query_tokens)
        .map(|q_idx| {
            known
                .iter()
                .map(|scores| scores[q_idx])
                .fold(
                    f32::NEG_INFINITY,
                    |best, s| if s > best { s } else { best },
                )
        })
        .filter(|&best| best > f32::NEG_INFINITY)
        .fold(0.0_f32, |acc, best| acc + best)
}
243
244pub fn search_one(
246 query: &Array2<f32>,
247 index: &Index,
248 params: &SearchParameters,
249 subset: Option<&[i64]>,
250) -> Result<(Vec<i64>, Vec<f32>)> {
251 let query_centroid_scores = query.dot(&index.codec.centroids_view().t());
255
256 let cells_to_probe: Vec<usize> = if let Some(subset_docs) = subset {
258 let mut subset_centroids: Vec<usize> = Vec::new();
260 for &doc_id in subset_docs {
261 if (doc_id as usize) < index.doc_codes.len() {
262 subset_centroids.extend(index.doc_codes[doc_id as usize].iter().copied());
263 }
264 }
265 subset_centroids.sort_unstable();
266 subset_centroids.dedup();
267
268 if subset_centroids.is_empty() {
269 return Ok((vec![], vec![]));
270 }
271
272 let mut centroid_scores: Vec<(usize, f32)> = subset_centroids
274 .iter()
275 .map(|&c| {
276 let score: f32 = query_centroid_scores
277 .axis_iter(Axis(0))
278 .map(|q| q[c])
279 .max_by(|a, b| a.partial_cmp(b).unwrap())
280 .unwrap_or(0.0);
281 (c, score)
282 })
283 .collect();
284
285 centroid_scores.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap());
286 centroid_scores
287 .iter()
288 .take(params.n_ivf_probe)
289 .map(|(c, _)| *c)
290 .collect()
291 } else {
292 let num_centroids = index.codec.num_centroids();
295 let num_query_tokens = query_centroid_scores.nrows();
296
297 let mut selected_centroids = std::collections::HashSet::new();
299
300 for q_idx in 0..num_query_tokens {
301 let mut centroid_scores: Vec<(usize, f32)> = (0..num_centroids)
303 .map(|c| (c, query_centroid_scores[[q_idx, c]]))
304 .collect();
305
306 centroid_scores.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap());
308
309 for (c, _) in centroid_scores.iter().take(params.n_ivf_probe) {
310 selected_centroids.insert(*c);
311 }
312 }
313
314 selected_centroids.into_iter().collect()
315 };
316
317 let mut candidates = index.get_candidates(&cells_to_probe);
319
320 if let Some(subset_docs) = subset {
322 let subset_set: std::collections::HashSet<i64> = subset_docs.iter().copied().collect();
323 candidates.retain(|&c| subset_set.contains(&c));
324 }
325
326 if candidates.is_empty() {
327 return Ok((vec![], vec![]));
328 }
329
330 let mut approx_scores: Vec<(i64, f32)> = candidates
332 .par_iter()
333 .map(|&doc_id| {
334 let codes = &index.doc_codes[doc_id as usize];
335 let score = approximate_score(&query_centroid_scores, &codes.view());
336 (doc_id, score)
337 })
338 .collect();
339
340 approx_scores.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap());
342 let top_candidates: Vec<i64> = approx_scores
343 .iter()
344 .take(params.n_full_scores)
345 .map(|(id, _)| *id)
346 .collect();
347
348 let n_decompress = (params.n_full_scores / 4).max(params.top_k);
350 let to_decompress: Vec<i64> = top_candidates.into_iter().take(n_decompress).collect();
351
352 if to_decompress.is_empty() {
353 return Ok((vec![], vec![]));
354 }
355
356 let mut exact_scores: Vec<(i64, f32)> = to_decompress
358 .par_iter()
359 .filter_map(|&doc_id| {
360 let doc_embeddings = index.get_document_embeddings(doc_id as usize).ok()?;
361 let score = colbert_score(&query.view(), &doc_embeddings.view());
362 Some((doc_id, score))
363 })
364 .collect();
365
366 exact_scores.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap());
368
369 let result_count = params.top_k.min(exact_scores.len());
371 let passage_ids: Vec<i64> = exact_scores
372 .iter()
373 .take(result_count)
374 .map(|(id, _)| *id)
375 .collect();
376 let scores: Vec<f32> = exact_scores
377 .iter()
378 .take(result_count)
379 .map(|(_, s)| *s)
380 .collect();
381
382 Ok((passage_ids, scores))
383}
384
385pub fn search_many(
387 queries: &[Array2<f32>],
388 index: &Index,
389 params: &SearchParameters,
390 show_progress: bool,
391 subsets: Option<&[Vec<i64>]>,
392) -> Result<Vec<QueryResult>> {
393 let progress = if show_progress {
394 let bar = indicatif::ProgressBar::new(queries.len() as u64);
395 bar.set_message("Searching...");
396 Some(bar)
397 } else {
398 None
399 };
400
401 let results: Vec<QueryResult> = queries
402 .par_iter()
403 .enumerate()
404 .map(|(i, query)| {
405 let subset = subsets.and_then(|s| s.get(i).map(|v| v.as_slice()));
406 let (passage_ids, scores) =
407 search_one(query, index, params, subset).unwrap_or_default();
408
409 if let Some(ref bar) = progress {
410 bar.inc(1);
411 }
412
413 QueryResult {
414 query_id: i,
415 passage_ids,
416 scores,
417 }
418 })
419 .collect();
420
421 if let Some(bar) = progress {
422 bar.finish();
423 }
424
425 Ok(results)
426}
427
428impl Index {
430 pub fn search(
432 &self,
433 query: &Array2<f32>,
434 params: &SearchParameters,
435 subset: Option<&[i64]>,
436 ) -> Result<QueryResult> {
437 let (passage_ids, scores) = search_one(query, self, params, subset)?;
438 Ok(QueryResult {
439 query_id: 0,
440 passage_ids,
441 scores,
442 })
443 }
444
445 pub fn search_batch(
447 &self,
448 queries: &[Array2<f32>],
449 params: &SearchParameters,
450 show_progress: bool,
451 subsets: Option<&[Vec<i64>]>,
452 ) -> Result<Vec<QueryResult>> {
453 search_many(queries, self, params, show_progress, subsets)
454 }
455}
456
/// Approximate document score for codes stored as `i64`.
///
/// Same computation as the in-memory variant: for each query token (row of
/// `query_centroid_scores`) take the best score among the document's
/// centroid codes, then sum those maxima. Each code is cast to `usize`
/// before indexing. Tokens with no maximum above negative infinity add
/// nothing.
///
/// # Panics
/// Panics if a code, after the cast, is not a valid column index.
#[cfg(feature = "npy")]
fn approximate_score_mmap(query_centroid_scores: &Array2<f32>, doc_codes: &[i64]) -> f32 {
    query_centroid_scores
        .axis_iter(Axis(0))
        .map(|token_scores| {
            doc_codes
                .iter()
                .map(|&code| token_scores[code as usize])
                .fold(
                    f32::NEG_INFINITY,
                    |best, s| if s > best { s } else { best },
                )
        })
        .filter(|&best| best > f32::NEG_INFINITY)
        .fold(0.0_f32, |acc, best| acc + best)
}
483
/// Runs one query against a memory-mapped index.
///
/// When `params.centroid_batch_size` is non-zero, the index has more
/// centroids than that, and no `subset` is given, the work is handed to
/// `search_one_mmap_batched`. Otherwise the full query-by-centroid score
/// matrix is computed and the search proceeds as follows:
///
/// 1. Pick the centroids to probe: with a `subset`, the `n_ivf_probe`
///    best-scoring centroids among those used by the subset documents;
///    otherwise the union of each token's `n_ivf_probe` best centroids.
/// 2. Collect candidates from those centroids (restricted to `subset` when
///    given) and rank them by approximate score.
/// 3. Re-score the best `min(n_full_scores, max(n_full_scores / 4, top_k))`
///    candidates exactly and return at most `top_k` of them.
///
/// The returned `QueryResult` always has `query_id` `0`; its vectors are
/// sorted by descending exact score and are empty when nothing matches.
/// Subset ids that are negative or beyond the index are ignored, and
/// documents whose embeddings cannot be loaded are silently dropped.
///
/// NaN scores no longer cause a panic: they are ranked after every real
/// score.
#[cfg(feature = "npy")]
pub fn search_one_mmap(
    index: &crate::index::MmapIndex,
    query: &Array2<f32>,
    params: &SearchParameters,
    subset: Option<&[i64]>,
) -> Result<QueryResult> {
    // Descending order on scores with NaN placed last. This is a proper
    // total preorder, so sorting can never panic on NaN input.
    fn score_desc(a: f32, b: f32) -> std::cmp::Ordering {
        use std::cmp::Ordering;
        match (a.is_nan(), b.is_nan()) {
            (true, true) => Ordering::Equal,
            (true, false) => Ordering::Greater,
            (false, true) => Ordering::Less,
            (false, false) => b.partial_cmp(&a).unwrap_or(Ordering::Equal),
        }
    }

    // Result returned whenever a stage leaves nothing to score.
    fn no_hits() -> QueryResult {
        QueryResult {
            query_id: 0,
            passage_ids: vec![],
            scores: vec![],
        }
    }

    let num_centroids = index.codec.num_centroids();
    let num_query_tokens = query.nrows();

    let use_batched = params.centroid_batch_size > 0
        && num_centroids > params.centroid_batch_size
        && subset.is_none();

    if use_batched {
        return search_one_mmap_batched(index, query, params);
    }

    // Rows are query tokens, columns are centroids.
    let query_centroid_scores = query.dot(&index.codec.centroids_view().t());

    let cells_to_probe: Vec<usize> = if let Some(subset_docs) = subset {
        // Centroids actually used by the requested documents.
        let mut subset_centroids: Vec<usize> = Vec::new();
        for &doc_id in subset_docs {
            if doc_id >= 0 && (doc_id as usize) < index.doc_lengths.len() {
                let start = index.doc_offsets[doc_id as usize];
                let end = index.doc_offsets[doc_id as usize + 1];
                let codes = index.mmap_codes.slice(start, end);
                subset_centroids.extend(codes.iter().map(|&c| c as usize));
            }
        }
        subset_centroids.sort_unstable();
        subset_centroids.dedup();

        if subset_centroids.is_empty() {
            return Ok(no_hits());
        }

        // Score each centroid by its best match over all query tokens
        // (NaN entries are skipped by `f32::max`).
        let mut centroid_scores: Vec<(usize, f32)> = subset_centroids
            .iter()
            .map(|&c| {
                let best = query_centroid_scores
                    .axis_iter(Axis(0))
                    .map(|q| q[c])
                    .reduce(f32::max)
                    .unwrap_or(0.0);
                (c, best)
            })
            .collect();

        centroid_scores.sort_by(|a, b| score_desc(a.1, b.1));
        centroid_scores
            .iter()
            .take(params.n_ivf_probe)
            .map(|&(c, _)| c)
            .collect()
    } else {
        let mut selected_centroids = HashSet::new();

        for q_idx in 0..num_query_tokens {
            let mut centroid_scores: Vec<(usize, f32)> = (0..num_centroids)
                .map(|c| (c, query_centroid_scores[[q_idx, c]]))
                .collect();

            centroid_scores.sort_by(|a, b| score_desc(a.1, b.1));

            selected_centroids.extend(
                centroid_scores
                    .iter()
                    .take(params.n_ivf_probe)
                    .map(|&(c, _)| c),
            );
        }

        selected_centroids.into_iter().collect()
    };

    let mut candidates = index.get_candidates(&cells_to_probe);

    if let Some(subset_docs) = subset {
        let subset_set: HashSet<i64> = subset_docs.iter().copied().collect();
        candidates.retain(|c| subset_set.contains(c));
    }

    if candidates.is_empty() {
        return Ok(no_hits());
    }

    // Cheap ranking from centroid scores only.
    let mut approx_scores: Vec<(i64, f32)> = candidates
        .par_iter()
        .map(|&doc_id| {
            let start = index.doc_offsets[doc_id as usize];
            let end = index.doc_offsets[doc_id as usize + 1];
            let codes = index.mmap_codes.slice(start, end);
            let score = approximate_score_mmap(&query_centroid_scores, codes);
            (doc_id, score)
        })
        .collect();

    approx_scores.sort_by(|a, b| score_desc(a.1, b.1));

    // Keep at most `n_full_scores` candidates, then decompress a quarter of
    // that budget (but never fewer than `top_k`).
    let n_decompress = (params.n_full_scores / 4).max(params.top_k);
    let n_rerank = params.n_full_scores.min(n_decompress);
    let to_decompress: Vec<i64> = approx_scores
        .iter()
        .take(n_rerank)
        .map(|&(id, _)| id)
        .collect();

    if to_decompress.is_empty() {
        return Ok(no_hits());
    }

    // Exact scoring; documents whose embeddings fail to load are skipped.
    let mut exact_scores: Vec<(i64, f32)> = to_decompress
        .par_iter()
        .filter_map(|&doc_id| {
            let doc_embeddings = index.get_document_embeddings(doc_id as usize).ok()?;
            let score = colbert_score(&query.view(), &doc_embeddings.view());
            Some((doc_id, score))
        })
        .collect();

    exact_scores.sort_by(|a, b| score_desc(a.1, b.1));
    exact_scores.truncate(params.top_k);

    let (passage_ids, scores): (Vec<i64>, Vec<f32>) = exact_scores.into_iter().unzip();

    Ok(QueryResult {
        query_id: 0,
        passage_ids,
        scores,
    })
}
652
/// Memory-frugal variant of the mmap search that never builds the full
/// query-by-centroid score matrix.
///
/// 1. Centroids to probe are chosen with `ivf_probe_batched`, scoring the
///    centroids in slices of `params.centroid_batch_size`.
/// 2. Candidate documents are collected from the probed centroids.
/// 3. Query scores are computed only for the centroids that occur in the
///    candidates' codes, and used to rank candidates approximately.
/// 4. The best `min(n_full_scores, max(n_full_scores / 4, top_k))`
///    candidates are re-scored exactly; at most `top_k` are returned.
///
/// The returned `QueryResult` has `query_id` `0`; its vectors are sorted by
/// descending exact score and are empty when nothing matches. Documents
/// whose embeddings cannot be loaded are silently dropped.
///
/// NaN scores no longer cause a panic: they are ranked after every real
/// score.
#[cfg(feature = "npy")]
fn search_one_mmap_batched(
    index: &crate::index::MmapIndex,
    query: &Array2<f32>,
    params: &SearchParameters,
) -> Result<QueryResult> {
    // Descending order on scores with NaN placed last. This is a proper
    // total preorder, so sorting can never panic on NaN input.
    fn score_desc(a: f32, b: f32) -> std::cmp::Ordering {
        use std::cmp::Ordering;
        match (a.is_nan(), b.is_nan()) {
            (true, true) => Ordering::Equal,
            (true, false) => Ordering::Greater,
            (false, true) => Ordering::Less,
            (false, false) => b.partial_cmp(&a).unwrap_or(Ordering::Equal),
        }
    }

    // Result returned whenever a stage leaves nothing to score.
    fn no_hits() -> QueryResult {
        QueryResult {
            query_id: 0,
            passage_ids: vec![],
            scores: vec![],
        }
    }

    let num_query_tokens = query.nrows();

    let cells_to_probe = ivf_probe_batched(
        query,
        &index.codec.centroids,
        params.n_ivf_probe,
        params.centroid_batch_size,
    );

    let candidates = index.get_candidates(&cells_to_probe);

    if candidates.is_empty() {
        return Ok(no_hits());
    }

    // Every centroid referenced by at least one candidate document.
    let mut unique_centroids: HashSet<usize> = HashSet::new();
    for &doc_id in &candidates {
        let start = index.doc_offsets[doc_id as usize];
        let end = index.doc_offsets[doc_id as usize + 1];
        let codes = index.mmap_codes.slice(start, end);
        unique_centroids.extend(codes.iter().map(|&code| code as usize));
    }

    // Query scores for just those centroids.
    let sparse_scores =
        build_sparse_centroid_scores(query, &index.codec.centroids, &unique_centroids);

    let mut approx_scores: Vec<(i64, f32)> = candidates
        .par_iter()
        .map(|&doc_id| {
            let start = index.doc_offsets[doc_id as usize];
            let end = index.doc_offsets[doc_id as usize + 1];
            let codes = index.mmap_codes.slice(start, end);
            let doc_codes: Vec<usize> = codes.iter().map(|&c| c as usize).collect();
            let score = approximate_score_sparse(&sparse_scores, &doc_codes, num_query_tokens);
            (doc_id, score)
        })
        .collect();

    approx_scores.sort_by(|a, b| score_desc(a.1, b.1));

    // Keep at most `n_full_scores` candidates, then decompress a quarter of
    // that budget (but never fewer than `top_k`).
    let n_decompress = (params.n_full_scores / 4).max(params.top_k);
    let n_rerank = params.n_full_scores.min(n_decompress);
    let to_decompress: Vec<i64> = approx_scores
        .iter()
        .take(n_rerank)
        .map(|&(id, _)| id)
        .collect();

    if to_decompress.is_empty() {
        return Ok(no_hits());
    }

    // Exact scoring; documents whose embeddings fail to load are skipped.
    let mut exact_scores: Vec<(i64, f32)> = to_decompress
        .par_iter()
        .filter_map(|&doc_id| {
            let doc_embeddings = index.get_document_embeddings(doc_id as usize).ok()?;
            let score = colbert_score(&query.view(), &doc_embeddings.view());
            Some((doc_id, score))
        })
        .collect();

    exact_scores.sort_by(|a, b| score_desc(a.1, b.1));
    exact_scores.truncate(params.top_k);

    let (passage_ids, scores): (Vec<i64>, Vec<f32>) = exact_scores.into_iter().unzip();

    Ok(QueryResult {
        query_id: 0,
        passage_ids,
        scores,
    })
}
763
/// Runs several queries against a memory-mapped index.
///
/// Every query shares the same optional `subset` filter, and each result's
/// `query_id` is set to the query's position in `queries`.
///
/// * `parallel == true`: queries run on the rayon pool; a query whose search
///   returns an error yields an empty result instead of failing the batch.
/// * `parallel == false`: queries run one after another and the first error
///   is returned to the caller.
#[cfg(feature = "npy")]
pub fn search_many_mmap(
    index: &crate::index::MmapIndex,
    queries: &[Array2<f32>],
    params: &SearchParameters,
    parallel: bool,
    subset: Option<&[i64]>,
) -> Result<Vec<QueryResult>> {
    if !parallel {
        let mut collected = Vec::with_capacity(queries.len());
        for (query_id, query) in queries.iter().enumerate() {
            let found = search_one_mmap(index, query, params, subset)?;
            collected.push(QueryResult { query_id, ..found });
        }
        return Ok(collected);
    }

    let collected: Vec<QueryResult> = queries
        .par_iter()
        .enumerate()
        .map(
            |(query_id, query)| match search_one_mmap(index, query, params, subset) {
                Ok(found) => QueryResult { query_id, ..found },
                // Best effort in parallel mode: failures become empty hits.
                Err(_) => QueryResult {
                    query_id,
                    passage_ids: vec![],
                    scores: vec![],
                },
            },
        )
        .collect();

    Ok(collected)
}
799
800pub type SearchResult = QueryResult;
802
#[cfg(test)]
mod tests {
    use super::*;

    /// MaxSim on a tiny hand-checked example: the two query tokens are the
    /// unit vectors e0 and e1, whose best matches in the document are 0.8
    /// and 0.9 respectively.
    #[test]
    fn test_colbert_score() {
        let query_values = vec![
            1.0, 0.0, 0.0, 0.0, // token 0
            0.0, 1.0, 0.0, 0.0, // token 1
        ];
        let query = Array2::from_shape_vec((2, 4), query_values).unwrap();

        let doc_values = vec![
            0.5, 0.5, 0.0, 0.0, // token 0
            0.8, 0.2, 0.0, 0.0, // token 1
            0.0, 0.9, 0.1, 0.0, // token 2
        ];
        let doc = Array2::from_shape_vec((3, 4), doc_values).unwrap();

        let got = colbert_score(&query.view(), &doc.view());
        let error = (got - 1.7).abs();
        assert!(error < 1e-5);
    }

    /// The documented default values must stay stable.
    #[test]
    fn test_search_params_default() {
        let defaults = SearchParameters::default();
        assert_eq!(defaults.batch_size, 2000);
        assert_eq!(defaults.n_full_scores, 4096);
        assert_eq!(defaults.top_k, 10);
        assert_eq!(defaults.n_ivf_probe, 8);
    }
}