use std::cmp::Reverse;
use std::collections::{BinaryHeap, HashSet};
use iqdb_distance::compute_batch;
use iqdb_filter::FilterEvaluator;
use iqdb_types::{DistanceMetric, Hit, IqdbError, Result, SearchParams};
use crate::graph::NodeIdx;
use crate::index::HnswIndex;
use crate::topk::{Scored, take_topk_sorted};
pub(crate) fn search(idx: &HnswIndex, query: &[f32], params: &SearchParams) -> Result<Vec<Hit>> {
idx.check_dim(query.len())?;
if params.metric != idx.metric {
return Err(IqdbError::InvalidMetric);
}
if params.k == 0 || idx.live_count == 0 || idx.entry.is_none() {
return Ok(Vec::new());
}
let ef_base = idx.cfg.ef_search.max(params.k);
let ef_effective = if params.filter.is_some() {
ef_base.saturating_mul(idx.cfg.filter_widen)
} else {
ef_base
};
let scored = knn_search(idx, query, ef_effective)?;
match ¶ms.filter {
None => Ok(scored
.into_iter()
.filter(|s| !idx.tombstoned[s.node as usize])
.take(params.k)
.map(|s| Hit {
id: idx.ids[s.node as usize].clone(),
distance: s.dist,
metadata: idx.metadata[s.node as usize].clone(),
})
.collect()),
Some(filter) => {
let evaluator = FilterEvaluator::new(filter.clone())?;
Ok(scored
.into_iter()
.filter(|s| !idx.tombstoned[s.node as usize])
.filter(|s| evaluator.evaluate(idx.metadata[s.node as usize].as_ref()))
.take(params.k)
.map(|s| Hit {
id: idx.ids[s.node as usize].clone(),
distance: s.dist,
metadata: idx.metadata[s.node as usize].clone(),
})
.collect())
}
}
}
/// Runs a k-nearest-neighbour query with an explicit beam width `ef`.
///
/// The layer-0 beam is `max(ef, params.k)`. When `params.filter` is set, the beam is further
/// multiplied (saturating) by `idx.cfg.filter_widen`, so that enough candidates survive
/// filtering. Tombstoned nodes, and nodes whose metadata the filter rejects, are removed
/// before truncating to `params.k`. As a result, fewer than `params.k` hits may be returned.
/// Hits keep the order produced by [`knn_search`]. For `DistanceMetric::DotProduct`,
/// `Hit::distance` is the negated dot product, matching the search-path sign convention.
///
/// Returns an empty vector when `params.k == 0`, the index has no live nodes, or it has no
/// entry point.
///
/// # Errors
/// - Whatever `idx.check_dim` reports for a query of the wrong length.
/// - `IqdbError::InvalidMetric` when `params.metric` differs from the index metric.
/// - Any error from compiling the filter into a `FilterEvaluator`, or from distance
///   computation during the graph walk.
pub(crate) fn search_with_ef(
    idx: &HnswIndex,
    query: &[f32],
    params: &SearchParams,
    ef: usize,
) -> Result<Vec<Hit>> {
    idx.check_dim(query.len())?;
    if params.metric != idx.metric {
        return Err(IqdbError::InvalidMetric);
    }
    if params.k == 0 || idx.live_count == 0 || idx.entry.is_none() {
        return Ok(Vec::new());
    }

    // Compile the filter before the (comparatively expensive) graph walk, so an invalid
    // filter is reported without doing any search work.
    let evaluator = params
        .filter
        .as_ref()
        .map(|filter| FilterEvaluator::new(filter.clone()))
        .transpose()?;

    let ef_base = ef.max(params.k);
    // Filtering discards candidates after the search, so widen the beam to compensate.
    let ef_effective = if evaluator.is_some() {
        ef_base.saturating_mul(idx.cfg.filter_widen)
    } else {
        ef_base
    };

    let scored = knn_search(idx, query, ef_effective)?;
    let hits: Vec<Hit> = scored
        .into_iter()
        .filter(|s| !idx.tombstoned[s.node as usize])
        .filter(|s| {
            evaluator
                .as_ref()
                .map_or(true, |ev| ev.evaluate(idx.metadata[s.node as usize].as_ref()))
        })
        .take(params.k)
        .map(|s| {
            let node = s.node as usize;
            Hit {
                id: idx.ids[node].clone(),
                distance: s.dist,
                metadata: idx.metadata[node].clone(),
            }
        })
        .collect();
    Ok(hits)
}
/// Finds up to `ef` approximate nearest neighbours of `query` across the whole graph.
///
/// The search starts at `idx.entry` and descends greedily from `idx.top_layer` to layer 1.
/// On each layer it runs a width-1 [`search_layer`] and moves to the closest node found.
/// A width-`ef` [`search_layer`] on layer 0 then produces the candidate set. That set is
/// handed to `take_topk_sorted`, which is presumably ordered nearest-first (confirm in
/// `topk`). Tombstoned nodes are not removed here; the search entry points filter them
/// afterwards.
///
/// Returns an empty vector when the index has no entry point.
///
/// # Errors
/// Propagates any distance-computation error.
pub(crate) fn knn_search(idx: &HnswIndex, query: &[f32], ef: usize) -> Result<Vec<Scored>> {
    let Some(entry) = idx.entry else {
        return Ok(Vec::new());
    };

    let mut nearest = Scored {
        dist: distance_to(idx, query, entry)?,
        seq: idx.seqs[entry as usize],
        node: entry,
    };

    // Greedy descent through the upper layers: keep only the single closest node per layer.
    for layer in (1..=idx.top_layer).rev() {
        let layer_best = search_layer(idx, query, &[nearest], layer, 1)?;
        if let Some(better) = best_of(&layer_best) {
            nearest = better;
        }
    }

    let base_layer = search_layer(idx, query, &[nearest], 0, ef)?;
    Ok(take_topk_sorted(base_layer, ef))
}
/// Best-first beam search restricted to a single graph `layer`.
///
/// The search is seeded from `entries`; duplicate seed nodes are ignored. It repeatedly
/// expands the closest not-yet-expanded candidate. The distances to all of that node's
/// unvisited neighbours on `layer` are computed in one `compute_batch` call. The best `ef`
/// nodes seen are kept in a max-heap, ordered by `Scored`, with the worst on top.
///
/// The search stops when the beam is full and the closest pending candidate is farther than
/// the worst kept result. Nodes with no adjacency list at `layer` are not expanded.
/// Tombstoned nodes are not excluded here. For `DistanceMetric::DotProduct`, distances are
/// negated, so smaller always means closer.
///
/// Returns the (unordered) result heap. It is empty when `ef == 0`.
///
/// # Errors
/// Propagates any error from `compute_batch`.
pub(crate) fn search_layer(
    idx: &HnswIndex,
    query: &[f32],
    entries: &[Scored],
    layer: u8,
    ef: usize,
) -> Result<BinaryHeap<Scored>> {
    // A zero-width beam can never hold a result. Without this guard, `results` stays empty,
    // neither pruning check below ever fires, and the loop would expand every node reachable
    // from `entries`, only to return an empty heap.
    if ef == 0 {
        return Ok(BinaryHeap::new());
    }
    let layer = layer as usize;
    // Dot product is a similarity; flip its sign so that "smaller is closer" holds everywhere.
    let negate = matches!(idx.metric, DistanceMetric::DotProduct);

    let mut visited: HashSet<NodeIdx> = HashSet::with_capacity(ef.saturating_mul(2));
    // Min-heap of nodes still to expand (closest on top).
    let mut candidates: BinaryHeap<Reverse<Scored>> = BinaryHeap::with_capacity(ef);
    // Max-heap of the best `ef` nodes found so far (worst on top).
    let mut results: BinaryHeap<Scored> = BinaryHeap::with_capacity(ef);

    // Scratch buffers, reused across expansions instead of being reallocated for every node.
    let mut fresh: Vec<NodeIdx> = Vec::new();
    let mut slices: Vec<&[f32]> = Vec::new();
    let mut dists: Vec<f32> = Vec::new();

    for e in entries {
        if visited.insert(e.node) {
            candidates.push(Reverse(*e));
            push_to_results(&mut results, *e, ef);
        }
    }

    // Popping before the stop check is safe: `candidates` is discarded once we break.
    while let Some(Reverse(current)) = candidates.pop() {
        // With a full beam, stop once even the closest pending candidate is farther than
        // the worst result we are keeping.
        if results.len() >= ef
            && matches!(results.peek(), Some(worst) if current.dist > worst.dist)
        {
            break;
        }

        // Nodes whose level is below `layer` have no adjacency here.
        let Some(adj) = idx.layers[current.node as usize].get(layer) else {
            continue;
        };

        fresh.clear();
        for &n in adj {
            if visited.insert(n) {
                fresh.push(n);
            }
        }
        if fresh.is_empty() {
            continue;
        }

        slices.clear();
        slices.extend(fresh.iter().map(|&n| &idx.vectors[n as usize][..]));
        dists.clear();
        dists.resize(fresh.len(), 0.0);
        compute_batch(idx.metric, query, &slices, &mut dists)?;

        for (&n, &raw) in fresh.iter().zip(&dists) {
            let scored = Scored {
                dist: if negate { -raw } else { raw },
                seq: idx.seqs[n as usize],
                node: n,
            };
            // Only neighbours that would enter the beam are worth expanding later.
            let competitive =
                results.len() < ef || results.peek().map_or(true, |worst| scored < *worst);
            if competitive {
                candidates.push(Reverse(scored));
                push_to_results(&mut results, scored, ef);
            }
        }
    }
    Ok(results)
}
/// Distance from `query` to the stored vector of `node`, using the index metric.
///
/// The result follows the search-path sign convention: for `DistanceMetric::DotProduct`, the
/// raw value from `compute_batch` is negated, so smaller always means closer.
///
/// # Errors
/// Propagates any error from `compute_batch`.
pub(crate) fn distance_to(idx: &HnswIndex, query: &[f32], node: NodeIdx) -> Result<f32> {
    let target = &idx.vectors[node as usize][..];
    let mut out = [0.0_f32];
    compute_batch(idx.metric, query, &[target], &mut out)?;
    let [value] = out;
    let negate = matches!(idx.metric, DistanceMetric::DotProduct);
    Ok(if negate { -value } else { value })
}
/// Distance between the stored vectors of nodes `a` and `b`, using the index metric.
///
/// Vector `a` is the query side of `compute_batch` and `b` the single target. Like the search
/// path, the raw value is negated for `DistanceMetric::DotProduct`.
///
/// # Errors
/// Propagates any error from `compute_batch`.
pub(crate) fn distance_between(idx: &HnswIndex, a: NodeIdx, b: NodeIdx) -> Result<f32> {
    let target = &idx.vectors[b as usize][..];
    let origin = &idx.vectors[a as usize][..];
    let mut out = [0.0_f32];
    compute_batch(idx.metric, origin, &[target], &mut out)?;
    let [value] = out;
    let negate = matches!(idx.metric, DistanceMetric::DotProduct);
    Ok(if negate { -value } else { value })
}
/// Offers `scored` to `results`, a max-heap bounded at `ef` entries.
///
/// While the heap holds fewer than `ef` entries, the candidate is always inserted. Once the
/// heap is full, the candidate replaces the top entry only if it orders strictly before it.
/// The top entry is the greatest by `Scored`'s ordering, i.e. the current worst result.
/// Otherwise the candidate is dropped. With `ef == 0`, nothing is ever inserted.
fn push_to_results(results: &mut BinaryHeap<Scored>, scored: Scored, ef: usize) {
    if results.len() < ef {
        results.push(scored);
        return;
    }
    if matches!(results.peek(), Some(worst) if scored < *worst) {
        results.pop();
        results.push(scored);
    }
}
/// Returns the smallest entry of `results` by `Scored`'s ordering (the best result held by
/// the max-heap), or `None` when the heap is empty.
///
/// This is a linear scan. It is cheap for the width-1 heaps built during the upper-layer
/// descent.
fn best_of(results: &BinaryHeap<Scored>) -> Option<Scored> {
    results.iter().copied().min()
}