1use std::cmp::Reverse;
4use std::collections::{BinaryHeap, HashMap, HashSet};
5
6use ndarray::Array1;
7use ndarray::{Array2, ArrayView2};
8use rayon::prelude::*;
9use serde::{Deserialize, Serialize};
10
11use crate::codec::CentroidStore;
12use crate::error::Result;
13use crate::maxsim;
14
/// Per-batch intermediate result of IVF probing: one top-`n_probe` min-heap
/// of `(score, centroid_id)` per query token, plus the best score observed
/// for each centroid that entered any of those heaps (used for threshold
/// filtering after the batches are merged).
type ProbePartial = (
    Vec<BinaryHeap<(Reverse<OrdF32>, usize)>>,
    HashMap<usize, f32>,
);
20
/// Number of documents decompressed and exactly scored per rayon work unit.
const DECOMPRESS_CHUNK_SIZE: usize = 128;
25
/// Tunable knobs for the search pipeline.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SearchParameters {
    // Not referenced in this module — presumably a query batching size
    // consumed by callers; TODO confirm against call sites.
    pub batch_size: usize,
    /// Size of the approximate-score shortlist considered for exact scoring.
    pub n_full_scores: usize,
    /// Number of final results returned per query.
    pub top_k: usize,
    /// Number of IVF cells probed per query token.
    pub n_ivf_probe: usize,
    /// Centroid-scoring batch size. When the index holds more centroids than
    /// this, probing switches to the batched, memory-bounded code path.
    #[serde(default = "default_centroid_batch_size")]
    pub centroid_batch_size: usize,
    /// Optional pruning threshold: probed centroids whose best score across
    /// all query tokens falls below this are dropped. `None` disables pruning.
    #[serde(default = "default_centroid_score_threshold")]
    pub centroid_score_threshold: Option<f32>,
}
49
/// Serde default for [`SearchParameters::centroid_batch_size`].
fn default_centroid_batch_size() -> usize {
    100_000
}
53
/// Serde default for [`SearchParameters::centroid_score_threshold`].
fn default_centroid_score_threshold() -> Option<f32> {
    Some(0.4)
}
57
58impl Default for SearchParameters {
59 fn default() -> Self {
60 Self {
61 batch_size: 2000,
62 n_full_scores: 4096,
63 top_k: 10,
64 n_ivf_probe: 8,
65 centroid_batch_size: default_centroid_batch_size(),
66 centroid_score_threshold: default_centroid_score_threshold(),
67 }
68 }
69}
70
/// Ranked results for a single query.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct QueryResult {
    /// Index of the query within its batch (0 for single-query searches;
    /// overwritten by `search_many_mmap`).
    pub query_id: usize,
    /// Document ids, best first.
    pub passage_ids: Vec<i64>,
    /// Scores aligned with `passage_ids`, in descending order.
    pub scores: Vec<f32>,
}
81
/// ColBERT late-interaction (MaxSim) score between one query's token
/// embeddings and one document's token embeddings. Delegates to the
/// implementation in `crate::maxsim`.
fn colbert_score(query: &ArrayView2<f32>, doc: &ArrayView2<f32>) -> f32 {
    maxsim::maxsim_score(query, doc)
}
91
/// Totally-ordered `f32` wrapper for use in `BinaryHeap`/sorting.
///
/// Uses `f32::total_cmp` (IEEE 754 totalOrder), so NaN participates in a
/// consistent total order instead of comparing "equal to everything", which
/// the previous `partial_cmp(..).unwrap_or(Equal)` did — that violated the
/// `Ord` contract (non-transitive) and made heap/sort behavior with NaN
/// scores unspecified. `PartialEq` is implemented via the same comparison so
/// `Eq`/`Ord` stay mutually consistent.
#[derive(Clone, Copy)]
struct OrdF32(f32);

impl PartialEq for OrdF32 {
    fn eq(&self, other: &Self) -> bool {
        self.0.total_cmp(&other.0) == std::cmp::Ordering::Equal
    }
}

impl Eq for OrdF32 {}

impl PartialOrd for OrdF32 {
    fn partial_cmp(&self, other: &Self) -> Option<std::cmp::Ordering> {
        Some(self.cmp(other))
    }
}

impl Ord for OrdF32 {
    fn cmp(&self, other: &Self) -> std::cmp::Ordering {
        // total_cmp: -NaN < -inf < ... < -0 < +0 < ... < +inf < +NaN
        self.0.total_cmp(&other.0)
    }
}
111
112fn ivf_probe_batched(
118 query: &Array2<f32>,
119 centroids: &CentroidStore,
120 n_probe: usize,
121 batch_size: usize,
122 centroid_score_threshold: Option<f32>,
123) -> Vec<usize> {
124 let num_centroids = centroids.nrows();
125 let num_tokens = query.nrows();
126
127 let batch_ranges: Vec<(usize, usize)> = (0..num_centroids)
129 .step_by(batch_size)
130 .map(|start| (start, (start + batch_size).min(num_centroids)))
131 .collect();
132
133 let local_results: Vec<ProbePartial> = batch_ranges
140 .par_iter()
141 .map(|&(batch_start, batch_end)| {
142 let mut heaps: Vec<BinaryHeap<(Reverse<OrdF32>, usize)>> = (0..num_tokens)
143 .map(|_| BinaryHeap::with_capacity(n_probe + 1))
144 .collect();
145 let mut max_scores: HashMap<usize, f32> = HashMap::new();
146
147 let batch_centroids = centroids.slice_rows(batch_start, batch_end);
149
150 let batch_scores = query.dot(&batch_centroids.t());
152
153 for (q_idx, heap) in heaps.iter_mut().enumerate() {
155 for (local_c, &score) in batch_scores.row(q_idx).iter().enumerate() {
156 let global_c = batch_start + local_c;
157 let entry = (Reverse(OrdF32(score)), global_c);
158
159 if heap.len() < n_probe {
160 heap.push(entry);
161 max_scores
162 .entry(global_c)
163 .and_modify(|s| *s = s.max(score))
164 .or_insert(score);
165 } else if let Some(&(Reverse(OrdF32(min_score)), _)) = heap.peek() {
166 if score > min_score {
167 heap.pop();
168 heap.push(entry);
169 max_scores
170 .entry(global_c)
171 .and_modify(|s| *s = s.max(score))
172 .or_insert(score);
173 }
174 }
175 }
176 }
177
178 (heaps, max_scores)
179 })
180 .collect();
181
182 let mut final_heaps: Vec<BinaryHeap<(Reverse<OrdF32>, usize)>> = (0..num_tokens)
185 .map(|_| BinaryHeap::with_capacity(n_probe + 1))
186 .collect();
187 let mut final_max_scores: HashMap<usize, f32> = HashMap::new();
188
189 for (local_heaps, local_max_scores) in local_results {
190 for (q_idx, local_heap) in local_heaps.into_iter().enumerate() {
191 for entry in local_heap {
192 let (Reverse(OrdF32(score)), _) = entry;
193 if final_heaps[q_idx].len() < n_probe {
194 final_heaps[q_idx].push(entry);
195 } else if let Some(&(Reverse(OrdF32(min_score)), _)) = final_heaps[q_idx].peek() {
196 if score > min_score {
197 final_heaps[q_idx].pop();
198 final_heaps[q_idx].push(entry);
199 }
200 }
201 }
202 }
203 for (c, score) in local_max_scores {
204 final_max_scores
205 .entry(c)
206 .and_modify(|s| *s = s.max(score))
207 .or_insert(score);
208 }
209 }
210
211 let mut selected: HashSet<usize> = HashSet::new();
213 for heap in final_heaps {
214 for (_, c) in heap {
215 selected.insert(c);
216 }
217 }
218
219 if let Some(threshold) = centroid_score_threshold {
221 selected.retain(|c| {
222 final_max_scores
223 .get(c)
224 .copied()
225 .unwrap_or(f32::NEG_INFINITY)
226 >= threshold
227 });
228 }
229
230 selected.into_iter().collect()
231}
232
233fn build_sparse_centroid_scores(
237 query: &Array2<f32>,
238 centroids: &CentroidStore,
239 centroid_ids: &HashSet<usize>,
240) -> HashMap<usize, Array1<f32>> {
241 centroid_ids
242 .iter()
243 .map(|&c| {
244 let centroid = centroids.row(c);
245 let scores: Array1<f32> = query.dot(¢roid);
246 (c, scores)
247 })
248 .collect()
249}
250
251fn approximate_score_sparse(
253 sparse_scores: &HashMap<usize, Array1<f32>>,
254 doc_codes: &[usize],
255 num_query_tokens: usize,
256) -> f32 {
257 let mut score = 0.0;
258
259 for q_idx in 0..num_query_tokens {
261 let mut max_score = f32::NEG_INFINITY;
262
263 for &code in doc_codes.iter() {
265 if let Some(centroid_scores) = sparse_scores.get(&code) {
266 let centroid_score = centroid_scores[q_idx];
267 if centroid_score > max_score {
268 max_score = centroid_score;
269 }
270 }
271 }
272
273 if max_score > f32::NEG_INFINITY {
274 score += max_score;
275 }
276 }
277
278 score
279}
280
281fn approximate_score_mmap(query_centroid_scores: &Array2<f32>, doc_codes: &[i64]) -> f32 {
283 let mut score = 0.0;
284
285 for q_idx in 0..query_centroid_scores.nrows() {
286 let mut max_score = f32::NEG_INFINITY;
287
288 for &code in doc_codes.iter() {
289 let centroid_score = query_centroid_scores[[q_idx, code as usize]];
290 if centroid_score > max_score {
291 max_score = centroid_score;
292 }
293 }
294
295 if max_score > f32::NEG_INFINITY {
296 score += max_score;
297 }
298 }
299
300 score
301}
302
303pub fn search_one_mmap(
305 index: &crate::index::MmapIndex,
306 query: &Array2<f32>,
307 params: &SearchParameters,
308 subset: Option<&[i64]>,
309) -> Result<QueryResult> {
310 let num_centroids = index.codec.num_centroids();
311 let num_query_tokens = query.nrows();
312
313 let use_batched = params.centroid_batch_size > 0 && num_centroids > params.centroid_batch_size;
315
316 if use_batched {
317 return search_one_mmap_batched(index, query, params, subset);
319 }
320
321 let query_centroid_scores = query.dot(&index.codec.centroids_view().t());
323
324 let eligible_centroids: Option<HashSet<usize>> = subset.map(|subset_docs| {
328 let mut centroids = HashSet::new();
329 for &doc_id in subset_docs {
330 let doc_idx = doc_id as usize;
331 if doc_idx < index.doc_lengths.len() {
332 let start = index.doc_offsets[doc_idx];
333 let end = index.doc_offsets[doc_idx + 1];
334 let codes = index.mmap_codes.slice(start, end);
335 for &c in codes.iter() {
336 centroids.insert(c as usize);
337 }
338 }
339 }
340 centroids
341 });
342
343 let effective_n_ivf_probe = match (&eligible_centroids, subset) {
348 (Some(eligible), Some(subset_docs)) if !eligible.is_empty() => {
349 let num_docs = index.doc_lengths.len();
350 let subset_len = subset_docs.len();
351 let scaled = if subset_len > 0 {
352 (params.n_ivf_probe as u64 * num_docs as u64 / subset_len as u64) as usize
353 } else {
354 params.n_ivf_probe
355 };
356 scaled.max(params.n_ivf_probe).min(eligible.len())
357 }
358 _ => params.n_ivf_probe,
359 };
360
361 let cells_to_probe: Vec<usize> = {
366 let mut selected_centroids = HashSet::new();
367
368 for q_idx in 0..num_query_tokens {
369 let mut centroid_scores: Vec<(usize, f32)> = match &eligible_centroids {
370 Some(eligible) => eligible
371 .iter()
372 .map(|&c| (c, query_centroid_scores[[q_idx, c]]))
373 .collect(),
374 None => (0..num_centroids)
375 .map(|c| (c, query_centroid_scores[[q_idx, c]]))
376 .collect(),
377 };
378
379 let n_probe = effective_n_ivf_probe.min(centroid_scores.len());
383 if centroid_scores.len() > n_probe {
384 centroid_scores.select_nth_unstable_by(n_probe - 1, |a, b| {
385 b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal)
386 });
387 }
388
389 for (c, _) in centroid_scores.iter().take(n_probe) {
390 selected_centroids.insert(*c);
391 }
392 }
393
394 if let Some(threshold) = params.centroid_score_threshold {
396 selected_centroids.retain(|&c| {
397 let max_score: f32 = (0..num_query_tokens)
398 .map(|q_idx| query_centroid_scores[[q_idx, c]])
399 .max_by(|a, b| a.partial_cmp(b).unwrap())
400 .unwrap_or(f32::NEG_INFINITY);
401 max_score >= threshold
402 });
403 }
404
405 selected_centroids.into_iter().collect()
406 };
407
408 let mut candidates = index.get_candidates(&cells_to_probe);
410
411 if let Some(subset_docs) = subset {
413 let subset_set: HashSet<i64> = subset_docs.iter().copied().collect();
414 candidates.retain(|&c| subset_set.contains(&c));
415 }
416
417 if candidates.is_empty() {
418 return Ok(QueryResult {
419 query_id: 0,
420 passage_ids: vec![],
421 scores: vec![],
422 });
423 }
424
425 let mut approx_scores: Vec<(i64, f32)> = candidates
427 .par_iter()
428 .map(|&doc_id| {
429 let start = index.doc_offsets[doc_id as usize];
430 let end = index.doc_offsets[doc_id as usize + 1];
431 let codes = index.mmap_codes.slice(start, end);
432 let score = approximate_score_mmap(&query_centroid_scores, &codes);
433 (doc_id, score)
434 })
435 .collect();
436
437 approx_scores.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap());
439 let top_candidates: Vec<i64> = approx_scores
440 .iter()
441 .take(params.n_full_scores)
442 .map(|(id, _)| *id)
443 .collect();
444
445 let n_decompress = (params.n_full_scores / 4).max(params.top_k);
447 let to_decompress: Vec<i64> = top_candidates.into_iter().take(n_decompress).collect();
448
449 if to_decompress.is_empty() {
450 return Ok(QueryResult {
451 query_id: 0,
452 passage_ids: vec![],
453 scores: vec![],
454 });
455 }
456
457 let mut exact_scores: Vec<(i64, f32)> = to_decompress
460 .par_chunks(DECOMPRESS_CHUNK_SIZE)
461 .flat_map(|chunk| {
462 chunk
463 .iter()
464 .filter_map(|&doc_id| {
465 let doc_embeddings = index.get_document_embeddings(doc_id as usize).ok()?;
466 let score = colbert_score(&query.view(), &doc_embeddings.view());
467 Some((doc_id, score))
468 })
469 .collect::<Vec<_>>()
470 })
471 .collect();
472
473 exact_scores.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap());
475
476 let result_count = params.top_k.min(exact_scores.len());
478 let passage_ids: Vec<i64> = exact_scores
479 .iter()
480 .take(result_count)
481 .map(|(id, _)| *id)
482 .collect();
483 let scores: Vec<f32> = exact_scores
484 .iter()
485 .take(result_count)
486 .map(|(_, s)| *s)
487 .collect();
488
489 Ok(QueryResult {
490 query_id: 0,
491 passage_ids,
492 scores,
493 })
494}
495
496fn search_one_mmap_batched(
500 index: &crate::index::MmapIndex,
501 query: &Array2<f32>,
502 params: &SearchParameters,
503 subset: Option<&[i64]>,
504) -> Result<QueryResult> {
505 let num_query_tokens = query.nrows();
506
507 let cells_to_probe = ivf_probe_batched(
509 query,
510 &index.codec.centroids,
511 params.n_ivf_probe,
512 params.centroid_batch_size,
513 params.centroid_score_threshold,
514 );
515
516 let mut candidates = index.get_candidates(&cells_to_probe);
518
519 if let Some(subset_docs) = subset {
521 let subset_set: HashSet<i64> = subset_docs.iter().copied().collect();
522 candidates.retain(|&c| subset_set.contains(&c));
523 }
524
525 if candidates.is_empty() {
526 return Ok(QueryResult {
527 query_id: 0,
528 passage_ids: vec![],
529 scores: vec![],
530 });
531 }
532
533 let mut unique_centroids: HashSet<usize> = HashSet::new();
535 for &doc_id in &candidates {
536 let start = index.doc_offsets[doc_id as usize];
537 let end = index.doc_offsets[doc_id as usize + 1];
538 let codes = index.mmap_codes.slice(start, end);
539 for &code in codes.iter() {
540 unique_centroids.insert(code as usize);
541 }
542 }
543
544 let sparse_scores =
546 build_sparse_centroid_scores(query, &index.codec.centroids, &unique_centroids);
547
548 let mut approx_scores: Vec<(i64, f32)> = candidates
550 .par_iter()
551 .map(|&doc_id| {
552 let start = index.doc_offsets[doc_id as usize];
553 let end = index.doc_offsets[doc_id as usize + 1];
554 let codes = index.mmap_codes.slice(start, end);
555 let doc_codes: Vec<usize> = codes.iter().map(|&c| c as usize).collect();
556 let score = approximate_score_sparse(&sparse_scores, &doc_codes, num_query_tokens);
557 (doc_id, score)
558 })
559 .collect();
560
561 approx_scores.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap());
563 let top_candidates: Vec<i64> = approx_scores
564 .iter()
565 .take(params.n_full_scores)
566 .map(|(id, _)| *id)
567 .collect();
568
569 let n_decompress = (params.n_full_scores / 4).max(params.top_k);
571 let to_decompress: Vec<i64> = top_candidates.into_iter().take(n_decompress).collect();
572
573 if to_decompress.is_empty() {
574 return Ok(QueryResult {
575 query_id: 0,
576 passage_ids: vec![],
577 scores: vec![],
578 });
579 }
580
581 let mut exact_scores: Vec<(i64, f32)> = to_decompress
584 .par_chunks(DECOMPRESS_CHUNK_SIZE)
585 .flat_map(|chunk| {
586 chunk
587 .iter()
588 .filter_map(|&doc_id| {
589 let doc_embeddings = index.get_document_embeddings(doc_id as usize).ok()?;
590 let score = colbert_score(&query.view(), &doc_embeddings.view());
591 Some((doc_id, score))
592 })
593 .collect::<Vec<_>>()
594 })
595 .collect();
596
597 exact_scores.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap());
599
600 let result_count = params.top_k.min(exact_scores.len());
602 let passage_ids: Vec<i64> = exact_scores
603 .iter()
604 .take(result_count)
605 .map(|(id, _)| *id)
606 .collect();
607 let scores: Vec<f32> = exact_scores
608 .iter()
609 .take(result_count)
610 .map(|(_, s)| *s)
611 .collect();
612
613 Ok(QueryResult {
614 query_id: 0,
615 passage_ids,
616 scores,
617 })
618}
619
620pub fn search_many_mmap(
622 index: &crate::index::MmapIndex,
623 queries: &[Array2<f32>],
624 params: &SearchParameters,
625 parallel: bool,
626 subset: Option<&[i64]>,
627) -> Result<Vec<QueryResult>> {
628 if parallel {
629 let results: Vec<QueryResult> = queries
630 .par_iter()
631 .enumerate()
632 .map(|(i, query)| {
633 let mut result =
634 search_one_mmap(index, query, params, subset).unwrap_or_else(|_| QueryResult {
635 query_id: i,
636 passage_ids: vec![],
637 scores: vec![],
638 });
639 result.query_id = i;
640 result
641 })
642 .collect();
643 Ok(results)
644 } else {
645 let mut results = Vec::with_capacity(queries.len());
646 for (i, query) in queries.iter().enumerate() {
647 let mut result = search_one_mmap(index, query, params, subset)?;
648 result.query_id = i;
649 results.push(result);
650 }
651 Ok(results)
652 }
653}
654
/// Convenience alias for [`QueryResult`] — presumably kept for backward
/// compatibility with older callers; TODO confirm before removing.
pub type SearchResult = QueryResult;
657
#[cfg(test)]
mod tests {
    use super::*;

    // MaxSim sanity check: each query token takes its best dot product over
    // doc tokens, and the maxima are summed. Query token 0 best-matches doc
    // row 1 (0.8); query token 1 best-matches doc row 2 (0.9); total = 1.7.
    #[test]
    fn test_colbert_score() {
        let query =
            Array2::from_shape_vec((2, 4), vec![1.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0]).unwrap();

        let doc = Array2::from_shape_vec(
            (3, 4),
            vec![
                0.5, 0.5, 0.0, 0.0, // doc token 0
                0.8, 0.2, 0.0, 0.0, // doc token 1
                0.0, 0.9, 0.1, 0.0, // doc token 2
            ],
        )
        .unwrap();

        let score = colbert_score(&query.view(), &doc.view());
        assert!((score - 1.7).abs() < 1e-5);
    }

    // Pins the programmatic defaults; must stay in sync with
    // `SearchParameters::default` and the serde default helpers.
    #[test]
    fn test_search_params_default() {
        let params = SearchParameters::default();
        assert_eq!(params.batch_size, 2000);
        assert_eq!(params.n_full_scores, 4096);
        assert_eq!(params.top_k, 10);
        assert_eq!(params.n_ivf_probe, 8);
        assert_eq!(params.centroid_score_threshold, Some(0.4));
    }
}