use iqdb_distance::compute_batch;
use iqdb_filter::FilterEvaluator;
use iqdb_quantize::{ProductQuantizer, Quantizer};
use iqdb_types::{DistanceMetric, Hit, IqdbError, Result};
use crate::config::IvfConfig;
use crate::index::IvfIndex;
use crate::topk::select_topk_indices;
/// Codebook size used for every PQ subspace.
///
/// NOTE(review): 256 presumably lets each sub-code fit in a single byte — confirm
/// against `ProductQuantizer`'s code layout.
const PQ_K_CENTROIDS: usize = 256;

/// Builds and trains the product quantizer used by IVF-PQ mode.
///
/// The subvector count is taken from `cfg.pq_subvectors`, the per-subspace codebook
/// size is [`PQ_K_CENTROIDS`], and `cfg.seed` is forwarded to the quantizer.
///
/// # Errors
///
/// Returns [`IqdbError::InvalidConfig`] when `cfg.pq_subvectors` is `None`, and
/// propagates any error reported by [`Quantizer::train`] on `sample`.
pub(crate) fn train_pq(cfg: &IvfConfig, sample: &[&[f32]]) -> Result<ProductQuantizer> {
    let subvectors = cfg
        .pq_subvectors
        .ok_or_else(|| IqdbError::InvalidConfig {
            reason: "IvfConfig.pq_subvectors must be Some(_) when use_pq = true",
        })?;
    let mut quantizer = ProductQuantizer::with_config(subvectors, PQ_K_CENTROIDS, cfg.seed);
    quantizer.train(sample)?;
    Ok(quantizer)
}
/// Scores every entry of the probed inverted lists with PQ asymmetric distance
/// computation (ADC), then hands the candidates to `finish_scan_pq` for top-k
/// selection and optional exact refinement.
///
/// Under [`DistanceMetric::DotProduct`] the table score is negated, so smaller
/// values rank first. The exact-refine path uses the same convention.
///
/// Returns an empty vector when the probed lists contain no entries.
///
/// # Errors
///
/// Returns [`IqdbError::InvalidConfig`] if the index has no trained quantizer.
/// Also propagates errors from building the query tables, fetching an entry's
/// PQ code, or scoring a code.
pub(crate) fn scan_pq_unfiltered(
    index: &IvfIndex,
    query: &[f32],
    probed: &[usize],
    k: usize,
) -> Result<Vec<Hit>> {
    let quantizer = index.pq().ok_or_else(|| IqdbError::InvalidConfig {
        reason: "IvfIndex is in IVF-PQ mode but the quantizer was not trained",
    })?;
    let metric = index.metric();
    let tables = quantizer.build_query_tables(query, metric)?;
    // Dot-product scores are similarities; flip the sign so "smaller is better".
    let flip_sign = matches!(metric, DistanceMetric::DotProduct);

    // Parallel arrays: ADC distance, insertion sequence, and (cluster, slot) address.
    let mut distances: Vec<f32> = Vec::new();
    let mut seqs: Vec<u64> = Vec::new();
    let mut addrs: Vec<(usize, usize)> = Vec::new();
    for &cluster in probed {
        let entries = index.inverted_list(cluster);
        for (slot, entry) in entries.iter().enumerate() {
            let raw = tables.distance(entry.pq_code_or_err()?)?;
            distances.push(if flip_sign { -raw } else { raw });
            seqs.push(entry.seq());
            addrs.push((cluster, slot));
        }
    }

    if distances.is_empty() {
        Ok(Vec::new())
    } else {
        finish_scan_pq(index, query, k, distances, seqs, addrs)
    }
}
/// Filtered variant of the IVF-PQ scan. Only entries whose metadata is accepted
/// by `evaluator` become candidates. Rejected entries are skipped before their
/// PQ code is read or scored.
///
/// Surviving candidates are scored with PQ asymmetric distance computation (ADC).
/// Under [`DistanceMetric::DotProduct`] the score is negated, so smaller values
/// rank first. Candidates are then passed to `finish_scan_pq`.
///
/// Returns an empty vector when no entry passes the filter.
///
/// # Errors
///
/// Returns [`IqdbError::InvalidConfig`] if the index has no trained quantizer.
/// Also propagates errors from building the query tables, fetching an entry's
/// PQ code, or scoring a code.
pub(crate) fn scan_pq_filtered(
    index: &IvfIndex,
    evaluator: &FilterEvaluator,
    query: &[f32],
    probed: &[usize],
    k: usize,
) -> Result<Vec<Hit>> {
    let quantizer = index.pq().ok_or_else(|| IqdbError::InvalidConfig {
        reason: "IvfIndex is in IVF-PQ mode but the quantizer was not trained",
    })?;
    let metric = index.metric();
    let tables = quantizer.build_query_tables(query, metric)?;
    // Dot-product scores are similarities; flip the sign so "smaller is better".
    let flip_sign = matches!(metric, DistanceMetric::DotProduct);

    // Parallel arrays: ADC distance, insertion sequence, and (cluster, slot) address.
    let mut distances: Vec<f32> = Vec::new();
    let mut seqs: Vec<u64> = Vec::new();
    let mut addrs: Vec<(usize, usize)> = Vec::new();
    for &cluster in probed {
        let entries = index.inverted_list(cluster);
        // `enumerate` comes before `filter` so `slot` remains the entry's position
        // in the full inverted list, not its rank among the survivors.
        let admitted = entries
            .iter()
            .enumerate()
            .filter(|(_, entry)| evaluator.evaluate(entry.metadata()));
        for (slot, entry) in admitted {
            let raw = tables.distance(entry.pq_code_or_err()?)?;
            distances.push(if flip_sign { -raw } else { raw });
            seqs.push(entry.seq());
            addrs.push((cluster, slot));
        }
    }

    if distances.is_empty() {
        Ok(Vec::new())
    } else {
        finish_scan_pq(index, query, k, distances, seqs, addrs)
    }
}
/// Turns ADC-scored candidates into the final hit list.
///
/// When `pq_refine_factor` is 0, the top `k` candidates by ADC distance are
/// returned directly, with their ADC distances. Otherwise a shortlist of
/// `pq_refine_factor * k` candidates is selected by ADC distance. The product
/// saturates and is capped at the candidate count. The shortlist is then
/// re-ranked by exact distance.
///
/// The three candidate vectors are parallel: index `i` describes one candidate.
/// `cand_seqs` is passed alongside the distances to `select_topk_indices`,
/// presumably as a tie-breaker — confirm against that helper.
fn finish_scan_pq(
    index: &IvfIndex,
    query: &[f32],
    k: usize,
    cand_distances: Vec<f32>,
    cand_seqs: Vec<u64>,
    cand_addrs: Vec<(usize, usize)>,
) -> Result<Vec<Hit>> {
    match index.cfg().pq_refine_factor {
        0 => {
            let top = select_topk_indices(&cand_distances, &cand_seqs, k);
            build_hits_from_adc(index, &top, &cand_distances, &cand_addrs)
        }
        factor => {
            let budget = (factor as usize).saturating_mul(k);
            let shortlist_len = cand_distances.len().min(budget);
            let shortlist = select_topk_indices(&cand_distances, &cand_seqs, shortlist_len);
            refine_shortlist(index, query, &shortlist, &cand_addrs, k)
        }
    }
}
/// Materializes [`Hit`]s for the candidates at positions `chosen`, reporting
/// their ADC (approximate) distances.
///
/// Each chosen position indexes both `cand_distances` and `cand_addrs`. The
/// address locates the entry in the index's inverted lists, and its id and
/// metadata are cloned into the hit. The output preserves the order of `chosen`.
///
/// Never returns `Err`; the `Result` return type only matches the refine path's
/// signature.
fn build_hits_from_adc(
    index: &IvfIndex,
    chosen: &[usize],
    cand_distances: &[f32],
    cand_addrs: &[(usize, usize)],
) -> Result<Vec<Hit>> {
    let hits: Vec<Hit> = chosen
        .iter()
        .map(|&i| {
            let (cluster, slot) = cand_addrs[i];
            let entry = &index.inverted_list(cluster)[slot];
            Hit {
                id: entry.id().clone(),
                distance: cand_distances[i],
                metadata: entry.metadata().cloned(),
            }
        })
        .collect();
    Ok(hits)
}
/// Re-ranks an ADC shortlist by exact distance and returns the best `k` hits.
///
/// Each shortlisted candidate index is resolved through `cand_addrs` to an entry.
/// The entry's stored vector is scored against `query` with `compute_batch` under
/// the index metric. Under [`DistanceMetric::DotProduct`] the scores are negated,
/// so smaller values rank first, matching the ADC path. The reported
/// `Hit::distance` is the exact (possibly negated) value.
///
/// Returns an empty vector for an empty shortlist.
///
/// # Errors
///
/// Propagates any error from `compute_batch`.
fn refine_shortlist(
    index: &IvfIndex,
    query: &[f32],
    shortlist: &[usize],
    cand_addrs: &[(usize, usize)],
    k: usize,
) -> Result<Vec<Hit>> {
    if shortlist.is_empty() {
        return Ok(Vec::new());
    }
    let metric = index.metric();

    // Parallel arrays over the shortlist: full vector, insertion sequence, address.
    let n = shortlist.len();
    let mut vectors: Vec<&[f32]> = Vec::with_capacity(n);
    let mut seqs: Vec<u64> = Vec::with_capacity(n);
    let mut addrs: Vec<(usize, usize)> = Vec::with_capacity(n);
    for &cand in shortlist {
        let addr = cand_addrs[cand];
        let entry = &index.inverted_list(addr.0)[addr.1];
        vectors.push(entry.vector_slice());
        seqs.push(entry.seq());
        addrs.push(addr);
    }

    let mut exact = vec![0.0_f32; vectors.len()];
    compute_batch(metric, query, &vectors, &mut exact)?;
    if matches!(metric, DistanceMetric::DotProduct) {
        // Similarity -> "smaller is better" distance.
        exact.iter_mut().for_each(|d| *d = -*d);
    }

    let winners = select_topk_indices(&exact, &seqs, k);
    let hits: Vec<Hit> = winners
        .iter()
        .map(|&i| {
            let (cluster, slot) = addrs[i];
            let entry = &index.inverted_list(cluster)[slot];
            Hit {
                id: entry.id().clone(),
                distance: exact[i],
                metadata: entry.metadata().cloned(),
            }
        })
        .collect();
    Ok(hits)
}