use iqdb_distance::compute_batch;
use iqdb_filter::FilterEvaluator;
use iqdb_types::{DistanceMetric, Hit, IqdbError, Result, SearchParams};
use crate::index::{InvertedListEntry, IvfIndex};
use crate::pq_variant;
use crate::topk::select_topk_indices;
/// Runs an approximate nearest-neighbour search over an IVF index.
///
/// The query is first scored against every coarse centroid. The
/// `index.n_probes()` best centroids (capped at the number of centroids) are
/// then scanned. The scan uses the product-quantized routines when
/// `index.cfg().use_pq` is set and the flat routines otherwise. Results are
/// restricted by `params.filter` when one is given.
///
/// Returns an empty vector when `params.k == 0`, when the index is empty, or
/// when the index has no centroids.
///
/// # Errors
///
/// - [`IqdbError::DimensionMismatch`] if `query.len() != index.dim()`.
/// - [`IqdbError::InvalidMetric`] if `params.metric` differs from the index metric.
/// - Any error from building the filter evaluator, from `compute_batch`, or
///   from the selected scan routine.
pub(crate) fn ivf_search(
    index: &IvfIndex,
    query: &[f32],
    params: &SearchParams,
) -> Result<Vec<Hit>> {
    if query.len() != index.dim() {
        return Err(IqdbError::DimensionMismatch {
            expected: index.dim(),
            found: query.len(),
        });
    }
    if params.metric != index.metric() {
        return Err(IqdbError::InvalidMetric);
    }
    if params.k == 0 || index.is_empty() {
        return Ok(Vec::new());
    }

    // Build the evaluator before the centroid check so an invalid filter is
    // reported even when there is nothing to probe.
    let evaluator: Option<FilterEvaluator> = match &params.filter {
        None => None,
        Some(f) => Some(FilterEvaluator::new(f.clone())?),
    };

    let centroids = index.centroids_slice();
    if centroids.is_empty() {
        return Ok(Vec::new());
    }

    // Score the query against every coarse centroid.
    let centroid_refs: Vec<&[f32]> = centroids.iter().map(|c| c.as_slice()).collect();
    let mut centroid_dists = vec![0.0_f32; centroid_refs.len()];
    compute_batch(index.metric(), query, &centroid_refs, &mut centroid_dists)?;
    if matches!(index.metric(), DistanceMetric::DotProduct) {
        // Dot product is a similarity (larger = closer). Negate it so the
        // top-k selection below prefers the most similar centroids. This
        // assumes select_topk_indices keeps the smallest values; confirm
        // against the topk module.
        for d in centroid_dists.iter_mut() {
            *d = -*d;
        }
    }

    // The centroid's own index doubles as its tie-break sequence number.
    let centroid_seqs: Vec<u64> = (0..centroid_refs.len() as u64).collect();
    let n_probes = index.n_probes().min(centroid_refs.len());
    let probed: Vec<usize> = select_topk_indices(&centroid_dists, &centroid_seqs, n_probes);

    match (index.cfg().use_pq, evaluator.as_ref()) {
        (false, None) => scan_flat_unfiltered(index, query, &probed, params.k),
        (false, Some(eval)) => scan_flat_filtered(index, eval, query, &probed, params.k),
        (true, None) => pq_variant::scan_pq_unfiltered(index, query, &probed, params.k),
        (true, Some(eval)) => pq_variant::scan_pq_filtered(index, eval, query, &probed, params.k),
    }
}
/// Exhaustive, non-quantized scan of the probed inverted lists with no filter.
///
/// Every vector in each list named by `probed` is scored against `query`.
/// The `k` best candidates across all lists are then chosen with
/// `select_topk_indices`, using each entry's `seq()` as the tie-breaker.
/// For [`DistanceMetric::DotProduct`] the scores are negated before
/// selection, and that negated value is what is reported in `Hit::distance`.
///
/// # Errors
///
/// Propagates any error returned by `compute_batch`.
fn scan_flat_unfiltered(
    index: &IvfIndex,
    query: &[f32],
    probed: &[usize],
    k: usize,
) -> Result<Vec<Hit>> {
    let flip_sign = matches!(index.metric(), DistanceMetric::DotProduct);

    // Parallel pools: slot `i` of each pool describes the same candidate.
    let mut pool_dist: Vec<f32> = Vec::new();
    let mut pool_seq: Vec<u64> = Vec::new();
    let mut pool_addr: Vec<(usize, usize)> = Vec::new();

    for &list_id in probed {
        let entries: &[InvertedListEntry] = index.inverted_list(list_id);
        if entries.is_empty() {
            continue;
        }

        let vectors: Vec<&[f32]> = entries.iter().map(|e| e.vector_slice()).collect();
        let mut scores = vec![0.0_f32; vectors.len()];
        compute_batch(index.metric(), query, &vectors, &mut scores)?;
        if flip_sign {
            scores.iter_mut().for_each(|s| *s = -*s);
        }

        for ((pos, entry), score) in entries.iter().enumerate().zip(scores) {
            pool_dist.push(score);
            pool_seq.push(entry.seq());
            pool_addr.push((list_id, pos));
        }
    }

    if pool_dist.is_empty() {
        return Ok(Vec::new());
    }

    // Map each chosen pool slot back to its (list, position) entry.
    let hits: Vec<Hit> = select_topk_indices(&pool_dist, &pool_seq, k)
        .into_iter()
        .map(|slot| {
            let (list_id, pos) = pool_addr[slot];
            let entry = &index.inverted_list(list_id)[pos];
            Hit {
                id: entry.id().clone(),
                distance: pool_dist[slot],
                metadata: entry.metadata().cloned(),
            }
        })
        .collect();
    Ok(hits)
}
/// Exhaustive, non-quantized scan of the probed inverted lists, keeping only
/// entries whose metadata passes `evaluator`.
///
/// The filter runs before scoring, so rejected entries never reach
/// `compute_batch`. Surviving vectors are scored against `query`, and the
/// `k` best across all probed lists are chosen with `select_topk_indices`,
/// using each entry's `seq()` as the tie-breaker. For
/// [`DistanceMetric::DotProduct`] the scores are negated before selection,
/// and that negated value is what is reported in `Hit::distance`.
///
/// # Errors
///
/// Propagates any error returned by `compute_batch`.
fn scan_flat_filtered(
    index: &IvfIndex,
    evaluator: &FilterEvaluator,
    query: &[f32],
    probed: &[usize],
    k: usize,
) -> Result<Vec<Hit>> {
    let flip_sign = matches!(index.metric(), DistanceMetric::DotProduct);

    // Parallel pools: slot `i` of each pool describes the same candidate.
    let mut pool_dist: Vec<f32> = Vec::new();
    let mut pool_seq: Vec<u64> = Vec::new();
    let mut pool_addr: Vec<(usize, usize)> = Vec::new();

    for &list_id in probed {
        let entries: &[InvertedListEntry] = index.inverted_list(list_id);
        if entries.is_empty() {
            continue;
        }

        // Split the filter survivors into their list positions and vectors.
        let (kept_pos, kept_vecs): (Vec<usize>, Vec<&[f32]>) = entries
            .iter()
            .enumerate()
            .filter(|(_, entry)| evaluator.evaluate(entry.metadata()))
            .map(|(pos, entry)| (pos, entry.vector_slice()))
            .unzip();
        if kept_vecs.is_empty() {
            continue;
        }

        let mut scores = vec![0.0_f32; kept_vecs.len()];
        compute_batch(index.metric(), query, &kept_vecs, &mut scores)?;
        if flip_sign {
            scores.iter_mut().for_each(|s| *s = -*s);
        }

        for (&pos, score) in kept_pos.iter().zip(scores) {
            pool_dist.push(score);
            pool_seq.push(entries[pos].seq());
            pool_addr.push((list_id, pos));
        }
    }

    if pool_dist.is_empty() {
        return Ok(Vec::new());
    }

    // Map each chosen pool slot back to its (list, position) entry.
    let hits: Vec<Hit> = select_topk_indices(&pool_dist, &pool_seq, k)
        .into_iter()
        .map(|slot| {
            let (list_id, pos) = pool_addr[slot];
            let entry = &index.inverted_list(list_id)[pos];
            Hit {
                id: entry.id().clone(),
                distance: pool_dist[slot],
                metadata: entry.metadata().cloned(),
            }
        })
        .collect();
    Ok(hits)
}