1use std::cmp::Reverse;
4use std::collections::{BinaryHeap, HashMap, HashSet};
5
6use ndarray::Array1;
7use ndarray::{Array2, ArrayView2, Axis};
8use rayon::prelude::*;
9use serde::{Deserialize, Serialize};
10
11use crate::codec::CentroidStore;
12use crate::error::Result;
13
/// Number of documents handed to each rayon task when computing exact (decompressed) scores.
const DECOMPRESS_CHUNK_SIZE: usize = 128;
18
19#[derive(Debug, Clone, Serialize, Deserialize)]
21pub struct SearchParameters {
22 pub batch_size: usize,
24 pub n_full_scores: usize,
26 pub top_k: usize,
28 pub n_ivf_probe: usize,
30 #[serde(default = "default_centroid_batch_size")]
34 pub centroid_batch_size: usize,
35 #[serde(default = "default_centroid_score_threshold")]
40 pub centroid_score_threshold: Option<f32>,
41}
42
/// Serde and `Default` fallback for `SearchParameters::centroid_batch_size`.
fn default_centroid_batch_size() -> usize {
    const DEFAULT_BATCH: usize = 100_000;
    DEFAULT_BATCH
}
46
/// Serde and `Default` fallback for `SearchParameters::centroid_score_threshold`: the filter is
/// enabled with a cut-off of `0.4`.
fn default_centroid_score_threshold() -> Option<f32> {
    let threshold: f32 = 0.4;
    Some(threshold)
}
50
51impl Default for SearchParameters {
52 fn default() -> Self {
53 Self {
54 batch_size: 2000,
55 n_full_scores: 4096,
56 top_k: 10,
57 n_ivf_probe: 8,
58 centroid_batch_size: default_centroid_batch_size(),
59 centroid_score_threshold: default_centroid_score_threshold(),
60 }
61 }
62}
63
64#[derive(Debug, Clone, Serialize, Deserialize)]
66pub struct QueryResult {
67 pub query_id: usize,
69 pub passage_ids: Vec<i64>,
71 pub scores: Vec<f32>,
73}
74
75fn colbert_score(query: &ArrayView2<f32>, doc: &ArrayView2<f32>) -> f32 {
78 let mut total_score = 0.0;
79
80 for q_row in query.axis_iter(Axis(0)) {
82 let mut max_sim = f32::NEG_INFINITY;
83
84 for d_row in doc.axis_iter(Axis(0)) {
86 let sim: f32 = q_row.dot(&d_row);
87 if sim > max_sim {
88 max_sim = sim;
89 }
90 }
91
92 if max_sim > f32::NEG_INFINITY {
93 total_score += max_sim;
94 }
95 }
96
97 total_score
98}
99
100#[allow(clippy::too_many_arguments)]
106fn compute_adaptive_ivf_probe_mmap(
107 query_centroid_scores: &Array2<f32>,
108 mmap_codes: &crate::mmap::MmapNpyArray1I64,
109 doc_offsets: &[usize],
110 num_docs: usize,
111 subset: &[i64],
112 top_k: usize,
113 n_ivf_probe: usize,
114 centroid_score_threshold: Option<f32>,
115) -> Vec<usize> {
116 let mut centroid_doc_counts: HashMap<usize, HashSet<i64>> = HashMap::new();
118 for &doc_id in subset {
119 let doc_idx = doc_id as usize;
120 if doc_idx < num_docs {
121 let start = doc_offsets[doc_idx];
122 let end = doc_offsets[doc_idx + 1];
123 let codes = mmap_codes.slice(start, end);
124 for &c in codes.iter() {
125 centroid_doc_counts
126 .entry(c as usize)
127 .or_default()
128 .insert(doc_id);
129 }
130 }
131 }
132
133 if centroid_doc_counts.is_empty() {
134 return vec![];
135 }
136
137 let mut scored_centroids: Vec<(usize, f32, usize)> = centroid_doc_counts
139 .into_iter()
140 .map(|(c, docs)| {
141 let max_score: f32 = query_centroid_scores
142 .axis_iter(Axis(0))
143 .map(|q| q[c])
144 .max_by(|a, b| a.partial_cmp(b).unwrap())
145 .unwrap_or(0.0);
146 (c, max_score, docs.len())
147 })
148 .collect();
149
150 if let Some(threshold) = centroid_score_threshold {
152 scored_centroids.retain(|(_, score, _)| *score >= threshold);
153 }
154
155 scored_centroids.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap());
157
158 let mut cumulative_docs = 0;
160 let mut n_probe = 0;
161
162 for (_, _, doc_count) in &scored_centroids {
163 cumulative_docs += doc_count;
164 n_probe += 1;
165 if cumulative_docs >= top_k && n_probe >= n_ivf_probe {
167 break;
168 }
169 }
170
171 n_probe = n_probe.max(n_ivf_probe.min(scored_centroids.len()));
173
174 scored_centroids
175 .iter()
176 .take(n_probe)
177 .map(|(c, _, _)| *c)
178 .collect()
179}
180
/// Totally ordered `f32` wrapper so scores can be stored in a `BinaryHeap`.
///
/// Both equality and ordering use `f32::total_cmp`, so the `Eq`/`Ord` contracts hold even when
/// NaN shows up. NaN equals itself and is never "equal" to a real number, which keeps the
/// order transitive.
#[derive(Debug, Clone, Copy)]
struct OrdF32(f32);

impl PartialEq for OrdF32 {
    fn eq(&self, other: &Self) -> bool {
        self.cmp(other) == std::cmp::Ordering::Equal
    }
}

impl Eq for OrdF32 {}

impl PartialOrd for OrdF32 {
    fn partial_cmp(&self, other: &Self) -> Option<std::cmp::Ordering> {
        Some(self.cmp(other))
    }
}

impl Ord for OrdF32 {
    fn cmp(&self, other: &Self) -> std::cmp::Ordering {
        self.0.total_cmp(&other.0)
    }
}
200
201fn ivf_probe_batched(
207 query: &Array2<f32>,
208 centroids: &CentroidStore,
209 n_probe: usize,
210 batch_size: usize,
211 centroid_score_threshold: Option<f32>,
212) -> Vec<usize> {
213 let num_centroids = centroids.nrows();
214 let num_tokens = query.nrows();
215
216 let mut heaps: Vec<BinaryHeap<(Reverse<OrdF32>, usize)>> = (0..num_tokens)
219 .map(|_| BinaryHeap::with_capacity(n_probe + 1))
220 .collect();
221
222 let mut max_scores: HashMap<usize, f32> = HashMap::new();
224
225 for batch_start in (0..num_centroids).step_by(batch_size) {
226 let batch_end = (batch_start + batch_size).min(num_centroids);
227
228 let batch_centroids = centroids.slice_rows(batch_start, batch_end);
230
231 let batch_scores = query.dot(&batch_centroids.t());
233
234 for (q_idx, heap) in heaps.iter_mut().enumerate() {
236 for (local_c, &score) in batch_scores.row(q_idx).iter().enumerate() {
237 let global_c = batch_start + local_c;
238 let entry = (Reverse(OrdF32(score)), global_c);
239
240 if heap.len() < n_probe {
241 heap.push(entry);
242 max_scores
244 .entry(global_c)
245 .and_modify(|s| *s = s.max(score))
246 .or_insert(score);
247 } else if let Some(&(Reverse(OrdF32(min_score)), _)) = heap.peek() {
248 if score > min_score {
249 heap.pop();
250 heap.push(entry);
251 max_scores
253 .entry(global_c)
254 .and_modify(|s| *s = s.max(score))
255 .or_insert(score);
256 }
257 }
258 }
259 }
260 }
261
262 let mut selected: HashSet<usize> = HashSet::new();
264 for heap in heaps {
265 for (_, c) in heap {
266 selected.insert(c);
267 }
268 }
269
270 if let Some(threshold) = centroid_score_threshold {
272 selected.retain(|c| max_scores.get(c).copied().unwrap_or(f32::NEG_INFINITY) >= threshold);
273 }
274
275 selected.into_iter().collect()
276}
277
278fn build_sparse_centroid_scores(
282 query: &Array2<f32>,
283 centroids: &CentroidStore,
284 centroid_ids: &HashSet<usize>,
285) -> HashMap<usize, Array1<f32>> {
286 centroid_ids
287 .iter()
288 .map(|&c| {
289 let centroid = centroids.row(c);
290 let scores: Array1<f32> = query.dot(¢roid);
291 (c, scores)
292 })
293 .collect()
294}
295
296fn approximate_score_sparse(
298 sparse_scores: &HashMap<usize, Array1<f32>>,
299 doc_codes: &[usize],
300 num_query_tokens: usize,
301) -> f32 {
302 let mut score = 0.0;
303
304 for q_idx in 0..num_query_tokens {
306 let mut max_score = f32::NEG_INFINITY;
307
308 for &code in doc_codes.iter() {
310 if let Some(centroid_scores) = sparse_scores.get(&code) {
311 let centroid_score = centroid_scores[q_idx];
312 if centroid_score > max_score {
313 max_score = centroid_score;
314 }
315 }
316 }
317
318 if max_score > f32::NEG_INFINITY {
319 score += max_score;
320 }
321 }
322
323 score
324}
325
326fn approximate_score_mmap(query_centroid_scores: &Array2<f32>, doc_codes: &[i64]) -> f32 {
328 let mut score = 0.0;
329
330 for q_idx in 0..query_centroid_scores.nrows() {
331 let mut max_score = f32::NEG_INFINITY;
332
333 for &code in doc_codes.iter() {
334 let centroid_score = query_centroid_scores[[q_idx, code as usize]];
335 if centroid_score > max_score {
336 max_score = centroid_score;
337 }
338 }
339
340 if max_score > f32::NEG_INFINITY {
341 score += max_score;
342 }
343 }
344
345 score
346}
347
348pub fn search_one_mmap(
350 index: &crate::index::MmapIndex,
351 query: &Array2<f32>,
352 params: &SearchParameters,
353 subset: Option<&[i64]>,
354) -> Result<QueryResult> {
355 let num_centroids = index.codec.num_centroids();
356 let num_query_tokens = query.nrows();
357
358 let use_batched = params.centroid_batch_size > 0
360 && num_centroids > params.centroid_batch_size
361 && subset.is_none();
362
363 if use_batched {
364 return search_one_mmap_batched(index, query, params);
366 }
367
368 let query_centroid_scores = query.dot(&index.codec.centroids_view().t());
370
371 let cells_to_probe: Vec<usize> = if let Some(subset_docs) = subset {
373 compute_adaptive_ivf_probe_mmap(
375 &query_centroid_scores,
376 &index.mmap_codes,
377 index.doc_offsets.as_slice().unwrap(),
378 index.doc_lengths.len(),
379 subset_docs,
380 params.top_k,
381 params.n_ivf_probe,
382 params.centroid_score_threshold,
383 )
384 } else {
385 let mut selected_centroids = HashSet::new();
387
388 for q_idx in 0..num_query_tokens {
389 let mut centroid_scores: Vec<(usize, f32)> = (0..num_centroids)
390 .map(|c| (c, query_centroid_scores[[q_idx, c]]))
391 .collect();
392
393 if centroid_scores.len() > params.n_ivf_probe {
397 centroid_scores.select_nth_unstable_by(params.n_ivf_probe - 1, |a, b| {
398 b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal)
399 });
400 }
401
402 for (c, _) in centroid_scores.iter().take(params.n_ivf_probe) {
403 selected_centroids.insert(*c);
404 }
405 }
406
407 if let Some(threshold) = params.centroid_score_threshold {
409 selected_centroids.retain(|&c| {
410 let max_score: f32 = (0..num_query_tokens)
411 .map(|q_idx| query_centroid_scores[[q_idx, c]])
412 .max_by(|a, b| a.partial_cmp(b).unwrap())
413 .unwrap_or(f32::NEG_INFINITY);
414 max_score >= threshold
415 });
416 }
417
418 selected_centroids.into_iter().collect()
419 };
420
421 let mut candidates = index.get_candidates(&cells_to_probe);
423
424 if let Some(subset_docs) = subset {
426 let subset_set: HashSet<i64> = subset_docs.iter().copied().collect();
427 candidates.retain(|&c| subset_set.contains(&c));
428 }
429
430 if candidates.is_empty() {
431 return Ok(QueryResult {
432 query_id: 0,
433 passage_ids: vec![],
434 scores: vec![],
435 });
436 }
437
438 let mut approx_scores: Vec<(i64, f32)> = candidates
440 .par_iter()
441 .map(|&doc_id| {
442 let start = index.doc_offsets[doc_id as usize];
443 let end = index.doc_offsets[doc_id as usize + 1];
444 let codes = index.mmap_codes.slice(start, end);
445 let score = approximate_score_mmap(&query_centroid_scores, &codes);
446 (doc_id, score)
447 })
448 .collect();
449
450 approx_scores.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap());
452 let top_candidates: Vec<i64> = approx_scores
453 .iter()
454 .take(params.n_full_scores)
455 .map(|(id, _)| *id)
456 .collect();
457
458 let n_decompress = (params.n_full_scores / 4).max(params.top_k);
460 let to_decompress: Vec<i64> = top_candidates.into_iter().take(n_decompress).collect();
461
462 if to_decompress.is_empty() {
463 return Ok(QueryResult {
464 query_id: 0,
465 passage_ids: vec![],
466 scores: vec![],
467 });
468 }
469
470 let mut exact_scores: Vec<(i64, f32)> = to_decompress
473 .par_chunks(DECOMPRESS_CHUNK_SIZE)
474 .flat_map(|chunk| {
475 chunk
476 .iter()
477 .filter_map(|&doc_id| {
478 let doc_embeddings = index.get_document_embeddings(doc_id as usize).ok()?;
479 let score = colbert_score(&query.view(), &doc_embeddings.view());
480 Some((doc_id, score))
481 })
482 .collect::<Vec<_>>()
483 })
484 .collect();
485
486 exact_scores.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap());
488
489 let result_count = params.top_k.min(exact_scores.len());
491 let passage_ids: Vec<i64> = exact_scores
492 .iter()
493 .take(result_count)
494 .map(|(id, _)| *id)
495 .collect();
496 let scores: Vec<f32> = exact_scores
497 .iter()
498 .take(result_count)
499 .map(|(_, s)| *s)
500 .collect();
501
502 Ok(QueryResult {
503 query_id: 0,
504 passage_ids,
505 scores,
506 })
507}
508
509fn search_one_mmap_batched(
513 index: &crate::index::MmapIndex,
514 query: &Array2<f32>,
515 params: &SearchParameters,
516) -> Result<QueryResult> {
517 let num_query_tokens = query.nrows();
518
519 let cells_to_probe = ivf_probe_batched(
521 query,
522 &index.codec.centroids,
523 params.n_ivf_probe,
524 params.centroid_batch_size,
525 params.centroid_score_threshold,
526 );
527
528 let candidates = index.get_candidates(&cells_to_probe);
530
531 if candidates.is_empty() {
532 return Ok(QueryResult {
533 query_id: 0,
534 passage_ids: vec![],
535 scores: vec![],
536 });
537 }
538
539 let mut unique_centroids: HashSet<usize> = HashSet::new();
541 for &doc_id in &candidates {
542 let start = index.doc_offsets[doc_id as usize];
543 let end = index.doc_offsets[doc_id as usize + 1];
544 let codes = index.mmap_codes.slice(start, end);
545 for &code in codes.iter() {
546 unique_centroids.insert(code as usize);
547 }
548 }
549
550 let sparse_scores =
552 build_sparse_centroid_scores(query, &index.codec.centroids, &unique_centroids);
553
554 let mut approx_scores: Vec<(i64, f32)> = candidates
556 .par_iter()
557 .map(|&doc_id| {
558 let start = index.doc_offsets[doc_id as usize];
559 let end = index.doc_offsets[doc_id as usize + 1];
560 let codes = index.mmap_codes.slice(start, end);
561 let doc_codes: Vec<usize> = codes.iter().map(|&c| c as usize).collect();
562 let score = approximate_score_sparse(&sparse_scores, &doc_codes, num_query_tokens);
563 (doc_id, score)
564 })
565 .collect();
566
567 approx_scores.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap());
569 let top_candidates: Vec<i64> = approx_scores
570 .iter()
571 .take(params.n_full_scores)
572 .map(|(id, _)| *id)
573 .collect();
574
575 let n_decompress = (params.n_full_scores / 4).max(params.top_k);
577 let to_decompress: Vec<i64> = top_candidates.into_iter().take(n_decompress).collect();
578
579 if to_decompress.is_empty() {
580 return Ok(QueryResult {
581 query_id: 0,
582 passage_ids: vec![],
583 scores: vec![],
584 });
585 }
586
587 let mut exact_scores: Vec<(i64, f32)> = to_decompress
590 .par_chunks(DECOMPRESS_CHUNK_SIZE)
591 .flat_map(|chunk| {
592 chunk
593 .iter()
594 .filter_map(|&doc_id| {
595 let doc_embeddings = index.get_document_embeddings(doc_id as usize).ok()?;
596 let score = colbert_score(&query.view(), &doc_embeddings.view());
597 Some((doc_id, score))
598 })
599 .collect::<Vec<_>>()
600 })
601 .collect();
602
603 exact_scores.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap());
605
606 let result_count = params.top_k.min(exact_scores.len());
608 let passage_ids: Vec<i64> = exact_scores
609 .iter()
610 .take(result_count)
611 .map(|(id, _)| *id)
612 .collect();
613 let scores: Vec<f32> = exact_scores
614 .iter()
615 .take(result_count)
616 .map(|(_, s)| *s)
617 .collect();
618
619 Ok(QueryResult {
620 query_id: 0,
621 passage_ids,
622 scores,
623 })
624}
625
626pub fn search_many_mmap(
628 index: &crate::index::MmapIndex,
629 queries: &[Array2<f32>],
630 params: &SearchParameters,
631 parallel: bool,
632 subset: Option<&[i64]>,
633) -> Result<Vec<QueryResult>> {
634 if parallel {
635 let results: Vec<QueryResult> = queries
636 .par_iter()
637 .enumerate()
638 .map(|(i, query)| {
639 let mut result =
640 search_one_mmap(index, query, params, subset).unwrap_or_else(|_| QueryResult {
641 query_id: i,
642 passage_ids: vec![],
643 scores: vec![],
644 });
645 result.query_id = i;
646 result
647 })
648 .collect();
649 Ok(results)
650 } else {
651 let mut results = Vec::with_capacity(queries.len());
652 for (i, query) in queries.iter().enumerate() {
653 let mut result = search_one_mmap(index, query, params, subset)?;
654 result.query_id = i;
655 results.push(result);
656 }
657 Ok(results)
658 }
659}
660
/// Alternative name for [`QueryResult`].
pub type SearchResult = QueryResult;
663
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_colbert_score() {
        let query =
            Array2::from_shape_vec((2, 4), vec![1.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0]).unwrap();

        let doc = Array2::from_shape_vec(
            (3, 4),
            vec![
                0.5, 0.5, 0.0, 0.0, //
                0.8, 0.2, 0.0, 0.0, //
                0.0, 0.9, 0.1, 0.0, //
            ],
        )
        .unwrap();

        // Token 0 best match: 0.8; token 1 best match: 0.9.
        let score = colbert_score(&query.view(), &doc.view());
        assert!((score - 1.7).abs() < 1e-5);
    }

    #[test]
    fn test_colbert_score_empty_document() {
        let query = Array2::from_shape_vec((1, 2), vec![1.0f32, 0.0]).unwrap();
        let doc = Array2::<f32>::zeros((0, 2));
        assert_eq!(colbert_score(&query.view(), &doc.view()), 0.0);
    }

    #[test]
    fn test_approximate_score_mmap_takes_best_code_per_token() {
        // 2 query tokens x 3 centroids.
        let scores =
            Array2::from_shape_vec((2, 3), vec![0.1f32, 0.7, 0.3, 0.6, 0.2, 0.9]).unwrap();
        // Codes {0, 2}: token 0 -> max(0.1, 0.3) = 0.3, token 1 -> max(0.6, 0.9) = 0.9.
        let score = approximate_score_mmap(&scores, &[0, 2]);
        assert!((score - 1.2).abs() < 1e-5);
    }

    #[test]
    fn test_approximate_score_sparse_ignores_unknown_codes() {
        let mut sparse: HashMap<usize, Array1<f32>> = HashMap::new();
        sparse.insert(4, Array1::from_vec(vec![0.5, 0.25]));
        let score = approximate_score_sparse(&sparse, &[4, 7], 2);
        assert!((score - 0.75).abs() < 1e-6);
        assert_eq!(approximate_score_sparse(&sparse, &[7], 2), 0.0);
    }

    #[test]
    fn test_search_params_default() {
        let params = SearchParameters::default();
        assert_eq!(params.batch_size, 2000);
        assert_eq!(params.n_full_scores, 4096);
        assert_eq!(params.top_k, 10);
        assert_eq!(params.n_ivf_probe, 8);
        assert_eq!(params.centroid_batch_size, 100_000);
        assert_eq!(params.centroid_score_threshold, Some(0.4));
    }
}