use hnsw_rs::api::AnnT;
use hnsw_rs::filter::FilterT;
use hnsw_rs::prelude::DataId;
use crate::embedder::Embedding;
use crate::index::IndexResult;
use super::HnswIndex;
/// Adapter that lets a borrowed predicate over internal HNSW point ids be passed to
/// `hnsw_rs` wherever a [`FilterT`] is expected (e.g. `search_filter`).
struct PredicateFilter<'a>(&'a dyn Fn(&usize) -> bool);

impl FilterT for PredicateFilter<'_> {
    /// Accepts point `id` exactly when the wrapped predicate returns `true` for it.
    fn hnsw_filter(&self, id: &DataId) -> bool {
        let Self(predicate) = self;
        predicate(id)
    }
}
impl HnswIndex {
    /// Returns up to `k` approximate nearest neighbours of `query`, in the order the HNSW
    /// search yields them.
    ///
    /// Each result carries the chunk id stored in `id_map` and a score of `1.0 - distance`.
    /// An empty vector is returned when the index is empty, `k` is zero, or the query's
    /// dimension does not match the index dimension.
    pub fn search(&self, query: &Embedding, k: usize) -> Vec<IndexResult> {
        self.search_impl(query, k, None)
    }

    /// Like [`search`](Self::search), but only chunks whose id satisfies `filter` are eligible.
    ///
    /// `filter` receives the chunk id (the `id_map` entry) of each candidate. Internal HNSW
    /// ids that have no `id_map` entry are always rejected. The predicate is handed to the
    /// HNSW search itself rather than applied to its output afterwards.
    pub fn search_filtered(
        &self,
        query: &Embedding,
        k: usize,
        filter: &dyn Fn(&str) -> bool,
    ) -> Vec<IndexResult> {
        let id_filter = |id: &usize| -> bool {
            self.id_map
                .get(*id)
                .is_some_and(|chunk_id| filter(chunk_id))
        };
        self.search_impl(query, k, Some(&id_filter))
    }

    /// Shared implementation of [`search`](Self::search) and
    /// [`search_filtered`](Self::search_filtered).
    ///
    /// `filter`, when present, is evaluated on internal HNSW point ids. Results whose id is
    /// out of range for `id_map`, or whose score is not finite, are logged and skipped.
    fn search_impl(
        &self,
        query: &Embedding,
        k: usize,
        filter: Option<&dyn Fn(&usize) -> bool>,
    ) -> Vec<IndexResult> {
        let index_size = self.id_map.len();
        if index_size == 0 || k == 0 {
            return Vec::new();
        }
        let _span = tracing::debug_span!(
            "hnsw_search",
            k,
            index_size,
            filtered = filter.is_some()
        )
        .entered();
        if query.is_empty() || query.len() != self.dim {
            if !query.is_empty() {
                tracing::warn!(
                    expected = self.dim,
                    actual = query.len(),
                    "Query embedding dimension mismatch"
                );
            }
            return Vec::new();
        }

        // The index can never return more than `index_size` points. Clamping `k` keeps
        // `ef_search >= k` below (the usual HNSW requirement) and stops an absurdly large `k`
        // from reaching the search.
        let k = k.min(index_size);
        // Explore at least twice as many candidates as requested (for better recall), but no
        // more than the index holds. `saturating_mul` avoids the overflow a plain `k * 2`
        // would hit (a panic in debug builds, a wrap to a tiny value in release builds).
        let ef_search = self.ef_search.max(k.saturating_mul(2)).min(index_size);

        let neighbors = match filter {
            Some(f) => {
                let wrapper = PredicateFilter(f);
                self.inner
                    .with_hnsw(|h| h.search_filter(query.as_slice(), k, ef_search, Some(&wrapper)))
            }
            None => self
                .inner
                .with_hnsw(|h| h.search_neighbours(query.as_slice(), k, ef_search)),
        };

        neighbors
            .into_iter()
            .filter_map(|n| {
                let idx = n.d_id;
                let Some(chunk_id) = self.id_map.get(idx) else {
                    tracing::warn!(idx, "Invalid index in HNSW result");
                    return None;
                };
                // NOTE(review): assumes a distance where 0 means identical (e.g. cosine
                // distance), so `1.0 - distance` is a similarity. Confirm against the metric
                // the index is built with.
                let score = 1.0 - n.distance;
                if !score.is_finite() {
                    tracing::warn!(
                        idx,
                        distance = n.distance,
                        "Non-finite HNSW score, skipping"
                    );
                    return None;
                }
                Some(IndexResult {
                    id: chunk_id.clone(),
                    score,
                })
            })
            .collect()
    }
}