use std::cmp::Reverse;
use std::collections::{BinaryHeap, HashMap, HashSet};
use ndarray::{Array1, Array2, ArrayView2};
use rayon::prelude::*;
use serde::{Deserialize, Serialize};
use crate::codec::CentroidStore;
use crate::error::Result;
use crate::maxsim;
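/// Per-batch partial from parallel IVF probing: one top-`n_probe` min-heap of
/// `(score, centroid)` per query token, plus the best score recorded for each
/// centroid (consumed later by the threshold filter).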
type ProbePartial = (
Vec<BinaryHeap<(Reverse<OrdF32>, usize)>>,
HashMap<usize, f32>,
);
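/// Documents decompressed and exactly re-scored per parallel work unit.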
const DECOMPRESS_CHUNK_SIZE: usize = 128;
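/// Tuning knobs for the multi-stage search pipeline: `n_ivf_probe` IVF cells
/// are probed per query token, the top `n_full_scores` approximately scored
/// candidates feed exact re-scoring, and `top_k` results come back. Indexes
/// with more than `centroid_batch_size` centroids take the batched path, and
/// `centroid_score_threshold` optionally prunes weakly matching cells.
///
/// A minimal sketch of a call site (assumes a query matrix and an already
/// loaded `MmapIndex`):
/// ```ignore
/// let params = SearchParameters { top_k: 100, ..Default::default() };
/// let result = search_one_mmap(&index, &query, &params, None)?;
/// ```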
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SearchParameters {
pub batch_size: usize,
pub n_full_scores: usize,
pub top_k: usize,
pub n_ivf_probe: usize,
#[serde(default = "default_centroid_batch_size")]
pub centroid_batch_size: usize,
#[serde(default = "default_centroid_score_threshold")]
pub centroid_score_threshold: Option<f32>,
}
fn default_centroid_batch_size() -> usize {
100_000
}
fn default_centroid_score_threshold() -> Option<f32> {
Some(0.4)
}
impl Default for SearchParameters {
fn default() -> Self {
Self {
batch_size: 2000,
n_full_scores: 4096,
top_k: 10,
n_ivf_probe: 8,
centroid_batch_size: default_centroid_batch_size(),
centroid_score_threshold: default_centroid_score_threshold(),
}
}
}
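/// Results for one query: `passage_ids[i]` has score `scores[i]`, ordered by
/// descending score.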
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct QueryResult {
pub query_id: usize,
pub passage_ids: Vec<i64>,
pub scores: Vec<f32>,
}
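/// Late-interaction (ColBERT) score: for every query token take the maximum
/// dot product over document tokens, then sum across query tokens.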
fn colbert_score(query: &ArrayView2<f32>, doc: &ArrayView2<f32>) -> f32 {
maxsim::maxsim_score(query, doc)
}
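/// Total-order wrapper so `f32` scores can live in a `BinaryHeap`; NaN
/// compares equal to everything instead of panicking.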
#[derive(Clone, Copy, PartialEq)]
struct OrdF32(f32);
impl Eq for OrdF32 {}
impl PartialOrd for OrdF32 {
fn partial_cmp(&self, other: &Self) -> Option<std::cmp::Ordering> {
Some(self.cmp(other))
}
}
impl Ord for OrdF32 {
fn cmp(&self, other: &Self) -> std::cmp::Ordering {
self.0
.partial_cmp(&other.0)
.unwrap_or(std::cmp::Ordering::Equal)
}
}
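/// Probes IVF cells in centroid batches so the full query-by-centroid score
/// matrix is never materialized. Batches are scored in parallel; the per-token
/// top-`n_probe` heaps and per-centroid max scores are merged afterwards, and
/// the union of selected cells (optionally threshold-filtered) is returned.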
fn ivf_probe_batched(
query: &Array2<f32>,
centroids: &CentroidStore,
n_probe: usize,
batch_size: usize,
centroid_score_threshold: Option<f32>,
) -> Vec<usize> {
let num_centroids = centroids.nrows();
let num_tokens = query.nrows();
let batch_ranges: Vec<(usize, usize)> = (0..num_centroids)
.step_by(batch_size)
.map(|start| (start, (start + batch_size).min(num_centroids)))
.collect();
let local_results: Vec<ProbePartial> = batch_ranges
.par_iter()
.map(|&(batch_start, batch_end)| {
let mut heaps: Vec<BinaryHeap<(Reverse<OrdF32>, usize)>> = (0..num_tokens)
.map(|_| BinaryHeap::with_capacity(n_probe + 1))
.collect();
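            // Best score per centroid, recorded only when it enters a heap;
            // this can underestimate the true max across tokens, so the
            // threshold filter is slightly stricter here than in the dense path.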
let mut max_scores: HashMap<usize, f32> = HashMap::new();
let batch_centroids = centroids.slice_rows(batch_start, batch_end);
let batch_scores = query.dot(&batch_centroids.t());
for (q_idx, heap) in heaps.iter_mut().enumerate() {
for (local_c, &score) in batch_scores.row(q_idx).iter().enumerate() {
let global_c = batch_start + local_c;
let entry = (Reverse(OrdF32(score)), global_c);
if heap.len() < n_probe {
heap.push(entry);
max_scores
.entry(global_c)
.and_modify(|s| *s = s.max(score))
.or_insert(score);
} else if let Some(&(Reverse(OrdF32(min_score)), _)) = heap.peek() {
if score > min_score {
heap.pop();
heap.push(entry);
max_scores
.entry(global_c)
.and_modify(|s| *s = s.max(score))
.or_insert(score);
}
}
}
}
(heaps, max_scores)
})
.collect();
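    // Merge per-batch partials into global top-n_probe heaps per query token.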
let mut final_heaps: Vec<BinaryHeap<(Reverse<OrdF32>, usize)>> = (0..num_tokens)
.map(|_| BinaryHeap::with_capacity(n_probe + 1))
.collect();
let mut final_max_scores: HashMap<usize, f32> = HashMap::new();
for (local_heaps, local_max_scores) in local_results {
for (q_idx, local_heap) in local_heaps.into_iter().enumerate() {
for entry in local_heap {
let (Reverse(OrdF32(score)), _) = entry;
if final_heaps[q_idx].len() < n_probe {
final_heaps[q_idx].push(entry);
} else if let Some(&(Reverse(OrdF32(min_score)), _)) = final_heaps[q_idx].peek() {
if score > min_score {
final_heaps[q_idx].pop();
final_heaps[q_idx].push(entry);
}
}
}
}
for (c, score) in local_max_scores {
final_max_scores
.entry(c)
.and_modify(|s| *s = s.max(score))
.or_insert(score);
}
}
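    // Take the union of selected cells across all query tokens.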
let mut selected: HashSet<usize> = HashSet::new();
for heap in final_heaps {
for (_, c) in heap {
selected.insert(c);
}
}
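    // Drop cells whose best score against any query token falls below threshold.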
if let Some(threshold) = centroid_score_threshold {
selected.retain(|c| {
final_max_scores
.get(c)
.copied()
.unwrap_or(f32::NEG_INFINITY)
>= threshold
});
}
selected.into_iter().collect()
}
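/// Scores the query against just the given centroids, returning a sparse map
/// from centroid id to its per-query-token score vector.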
fn build_sparse_centroid_scores(
query: &Array2<f32>,
centroids: &CentroidStore,
centroid_ids: &HashSet<usize>,
) -> HashMap<usize, Array1<f32>> {
centroid_ids
.iter()
.map(|&c| {
let centroid = centroids.row(c);
            let scores: Array1<f32> = query.dot(&centroid);
(c, scores)
})
.collect()
}
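/// Approximate MaxSim where each document token is stood in for by its
/// centroid; tokens whose centroid is absent from the sparse table contribute
/// nothing.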
fn approximate_score_sparse(
sparse_scores: &HashMap<usize, Array1<f32>>,
doc_codes: &[usize],
num_query_tokens: usize,
) -> f32 {
let mut score = 0.0;
for q_idx in 0..num_query_tokens {
let mut max_score = f32::NEG_INFINITY;
for &code in doc_codes.iter() {
if let Some(centroid_scores) = sparse_scores.get(&code) {
let centroid_score = centroid_scores[q_idx];
if centroid_score > max_score {
max_score = centroid_score;
}
}
}
if max_score > f32::NEG_INFINITY {
score += max_score;
}
}
score
}
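/// Approximate MaxSim against the dense query-by-centroid score matrix: each
/// document token contributes its centroid's score.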
fn approximate_score_mmap(query_centroid_scores: &Array2<f32>, doc_codes: &[i64]) -> f32 {
let mut score = 0.0;
for q_idx in 0..query_centroid_scores.nrows() {
let mut max_score = f32::NEG_INFINITY;
for &code in doc_codes.iter() {
let centroid_score = query_centroid_scores[[q_idx, code as usize]];
if centroid_score > max_score {
max_score = centroid_score;
}
}
if max_score > f32::NEG_INFINITY {
score += max_score;
}
}
score
}
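/// Searches a memory-mapped index for a single query.
///
/// Pipeline: probe IVF cells per query token, gather candidate documents from
/// those cells, rank them with a cheap centroid-approximation score, then
/// decompress and exactly re-score the best candidates with full MaxSim. An
/// optional `subset` restricts results to the given passage ids.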
pub fn search_one_mmap(
index: &crate::index::MmapIndex,
query: &Array2<f32>,
params: &SearchParameters,
subset: Option<&[i64]>,
) -> Result<QueryResult> {
let num_centroids = index.codec.num_centroids();
let num_query_tokens = query.nrows();
let use_batched = params.centroid_batch_size > 0 && num_centroids > params.centroid_batch_size;
if use_batched {
return search_one_mmap_batched(index, query, params, subset);
}
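    // Dense path: the full query-by-centroid score matrix fits in memory.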
let query_centroid_scores = query.dot(&index.codec.centroids_view().t());
let eligible_centroids: Option<HashSet<usize>> = subset.map(|subset_docs| {
let mut centroids = HashSet::new();
for &doc_id in subset_docs {
let doc_idx = doc_id as usize;
if doc_idx < index.doc_lengths.len() {
let start = index.doc_offsets[doc_idx];
let end = index.doc_offsets[doc_idx + 1];
let codes = index.mmap_codes.slice(start, end);
for &c in codes.iter() {
centroids.insert(c as usize);
}
}
}
centroids
});
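    // When restricted to a small subset, scale the probe count up by the
    // inverse of the subset fraction so the expected number of in-subset
    // candidates stays roughly constant, capped at the eligible-centroid count.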
let effective_n_ivf_probe = match (&eligible_centroids, subset) {
(Some(eligible), Some(subset_docs)) if !eligible.is_empty() => {
let num_docs = index.doc_lengths.len();
let subset_len = subset_docs.len();
let scaled = if subset_len > 0 {
(params.n_ivf_probe as u64 * num_docs as u64 / subset_len as u64) as usize
} else {
params.n_ivf_probe
};
scaled.max(params.n_ivf_probe).min(eligible.len())
}
_ => params.n_ivf_probe,
};
let cells_to_probe: Vec<usize> = {
let mut selected_centroids = HashSet::new();
for q_idx in 0..num_query_tokens {
let mut centroid_scores: Vec<(usize, f32)> = match &eligible_centroids {
Some(eligible) => eligible
.iter()
.map(|&c| (c, query_centroid_scores[[q_idx, c]]))
.collect(),
None => (0..num_centroids)
.map(|c| (c, query_centroid_scores[[q_idx, c]]))
.collect(),
};
let n_probe = effective_n_ivf_probe.min(centroid_scores.len());
            if n_probe > 0 && centroid_scores.len() > n_probe {
centroid_scores.select_nth_unstable_by(n_probe - 1, |a, b| {
b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal)
});
}
for (c, _) in centroid_scores.iter().take(n_probe) {
selected_centroids.insert(*c);
}
}
if let Some(threshold) = params.centroid_score_threshold {
selected_centroids.retain(|&c| {
                let max_score: f32 = (0..num_query_tokens)
                    .map(|q_idx| query_centroid_scores[[q_idx, c]])
                    .fold(f32::NEG_INFINITY, f32::max);
max_score >= threshold
});
}
selected_centroids.into_iter().collect()
};
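    // Gather candidate documents from the probed cells.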
let mut candidates = index.get_candidates(&cells_to_probe);
if let Some(subset_docs) = subset {
let subset_set: HashSet<i64> = subset_docs.iter().copied().collect();
candidates.retain(|&c| subset_set.contains(&c));
}
if candidates.is_empty() {
return Ok(QueryResult {
query_id: 0,
passage_ids: vec![],
scores: vec![],
});
}
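    // First pass: rank all candidates by the cheap approximate score, in parallel.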
let mut approx_scores: Vec<(i64, f32)> = candidates
.par_iter()
.map(|&doc_id| {
let start = index.doc_offsets[doc_id as usize];
let end = index.doc_offsets[doc_id as usize + 1];
let codes = index.mmap_codes.slice(start, end);
let score = approximate_score_mmap(&query_centroid_scores, &codes);
(doc_id, score)
})
.collect();
    approx_scores.sort_by(|a, b| b.1.total_cmp(&a.1));
let top_candidates: Vec<i64> = approx_scores
.iter()
.take(params.n_full_scores)
.map(|(id, _)| *id)
.collect();
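    // Decompress only the top quarter of the approximate pool (never fewer
    // than top_k) for exact scoring.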
let n_decompress = (params.n_full_scores / 4).max(params.top_k);
let to_decompress: Vec<i64> = top_candidates.into_iter().take(n_decompress).collect();
if to_decompress.is_empty() {
return Ok(QueryResult {
query_id: 0,
passage_ids: vec![],
scores: vec![],
});
}
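    // Second pass: exact MaxSim over decompressed embeddings, in parallel chunks.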
let mut exact_scores: Vec<(i64, f32)> = to_decompress
.par_chunks(DECOMPRESS_CHUNK_SIZE)
.flat_map(|chunk| {
chunk
.iter()
.filter_map(|&doc_id| {
let doc_embeddings = index.get_document_embeddings(doc_id as usize).ok()?;
let score = colbert_score(&query.view(), &doc_embeddings.view());
Some((doc_id, score))
})
.collect::<Vec<_>>()
})
.collect();
    exact_scores.sort_by(|a, b| b.1.total_cmp(&a.1));
let result_count = params.top_k.min(exact_scores.len());
let passage_ids: Vec<i64> = exact_scores
.iter()
.take(result_count)
.map(|(id, _)| *id)
.collect();
let scores: Vec<f32> = exact_scores
.iter()
.take(result_count)
.map(|(_, s)| *s)
.collect();
Ok(QueryResult {
query_id: 0,
passage_ids,
scores,
})
}
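/// Batched variant of [`search_one_mmap`] used when the centroid count
/// exceeds `centroid_batch_size`: centroid scores are computed batch by batch
/// and kept sparse, covering only centroids that occur in candidate documents.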
fn search_one_mmap_batched(
index: &crate::index::MmapIndex,
query: &Array2<f32>,
params: &SearchParameters,
subset: Option<&[i64]>,
) -> Result<QueryResult> {
let num_query_tokens = query.nrows();
let cells_to_probe = ivf_probe_batched(
query,
&index.codec.centroids,
params.n_ivf_probe,
params.centroid_batch_size,
params.centroid_score_threshold,
);
let mut candidates = index.get_candidates(&cells_to_probe);
if let Some(subset_docs) = subset {
let subset_set: HashSet<i64> = subset_docs.iter().copied().collect();
candidates.retain(|&c| subset_set.contains(&c));
}
if candidates.is_empty() {
return Ok(QueryResult {
query_id: 0,
passage_ids: vec![],
scores: vec![],
});
}
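    // Collect the distinct centroid codes used by the candidates; only these
    // need query scores, which keeps the score table sparse.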
let mut unique_centroids: HashSet<usize> = HashSet::new();
for &doc_id in &candidates {
let start = index.doc_offsets[doc_id as usize];
let end = index.doc_offsets[doc_id as usize + 1];
let codes = index.mmap_codes.slice(start, end);
for &code in codes.iter() {
unique_centroids.insert(code as usize);
}
}
let sparse_scores =
build_sparse_centroid_scores(query, &index.codec.centroids, &unique_centroids);
let mut approx_scores: Vec<(i64, f32)> = candidates
.par_iter()
.map(|&doc_id| {
let start = index.doc_offsets[doc_id as usize];
let end = index.doc_offsets[doc_id as usize + 1];
let codes = index.mmap_codes.slice(start, end);
let doc_codes: Vec<usize> = codes.iter().map(|&c| c as usize).collect();
let score = approximate_score_sparse(&sparse_scores, &doc_codes, num_query_tokens);
(doc_id, score)
})
.collect();
    approx_scores.sort_by(|a, b| b.1.total_cmp(&a.1));
let top_candidates: Vec<i64> = approx_scores
.iter()
.take(params.n_full_scores)
.map(|(id, _)| *id)
.collect();
let n_decompress = (params.n_full_scores / 4).max(params.top_k);
let to_decompress: Vec<i64> = top_candidates.into_iter().take(n_decompress).collect();
if to_decompress.is_empty() {
return Ok(QueryResult {
query_id: 0,
passage_ids: vec![],
scores: vec![],
});
}
let mut exact_scores: Vec<(i64, f32)> = to_decompress
.par_chunks(DECOMPRESS_CHUNK_SIZE)
.flat_map(|chunk| {
chunk
.iter()
.filter_map(|&doc_id| {
let doc_embeddings = index.get_document_embeddings(doc_id as usize).ok()?;
let score = colbert_score(&query.view(), &doc_embeddings.view());
Some((doc_id, score))
})
.collect::<Vec<_>>()
})
.collect();
    exact_scores.sort_by(|a, b| b.1.total_cmp(&a.1));
let result_count = params.top_k.min(exact_scores.len());
let passage_ids: Vec<i64> = exact_scores
.iter()
.take(result_count)
.map(|(id, _)| *id)
.collect();
let scores: Vec<f32> = exact_scores
.iter()
.take(result_count)
.map(|(_, s)| *s)
.collect();
Ok(QueryResult {
query_id: 0,
passage_ids,
scores,
})
}
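/// Runs many queries, optionally in parallel via rayon. In the parallel path a
/// failed query degrades to an empty result instead of aborting the batch;
/// `query_id` is always set to the query's position in `queries`.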
pub fn search_many_mmap(
index: &crate::index::MmapIndex,
queries: &[Array2<f32>],
params: &SearchParameters,
parallel: bool,
subset: Option<&[i64]>,
) -> Result<Vec<QueryResult>> {
if parallel {
let results: Vec<QueryResult> = queries
.par_iter()
.enumerate()
.map(|(i, query)| {
let mut result =
search_one_mmap(index, query, params, subset).unwrap_or_else(|_| QueryResult {
query_id: i,
passage_ids: vec![],
scores: vec![],
});
result.query_id = i;
result
})
.collect();
Ok(results)
} else {
let mut results = Vec::with_capacity(queries.len());
for (i, query) in queries.iter().enumerate() {
let mut result = search_one_mmap(index, query, params, subset)?;
result.query_id = i;
results.push(result);
}
Ok(results)
}
}
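/// Convenience alias for [`QueryResult`].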
pub type SearchResult = QueryResult;
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_colbert_score() {
let query =
Array2::from_shape_vec((2, 4), vec![1.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0]).unwrap();
let doc = Array2::from_shape_vec(
(3, 4),
            vec![
                0.5, 0.5, 0.0, 0.0, // doc token 0
                0.8, 0.2, 0.0, 0.0, // doc token 1
                0.0, 0.9, 0.1, 0.0, // doc token 2
            ],
)
.unwrap();
let score = colbert_score(&query.view(), &doc.view());
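        // Expected: 0.8 (query row 0 vs doc row 1) + 0.9 (query row 1 vs doc row 2) = 1.7.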
assert!((score - 1.7).abs() < 1e-5);
}
#[test]
fn test_search_params_default() {
let params = SearchParameters::default();
assert_eq!(params.batch_size, 2000);
assert_eq!(params.n_full_scores, 4096);
assert_eq!(params.top_k, 10);
assert_eq!(params.n_ivf_probe, 8);
assert_eq!(params.centroid_score_threshold, Some(0.4));
}
}