use crate::node::{NodeId, SearchHit};
use crate::VectorType;
use instant_distance::{Builder, Point, Search};
/// A vector stored in (or used to query) the HNSW graph, tagged with the id
/// of the node it belongs to.
#[derive(Clone, Debug)]
pub struct VectorPoint<T: VectorType> {
    /// Node identifier carried alongside the vector (not used by `distance`).
    pub id: NodeId,
    /// The vector's components.
    pub vec: Vec<T>,
}
// SAFETY: NOTE(review): this asserts `Sync` for every `T: VectorType` with no
// `T: Sync` bound, which is only sound if `VectorType` values (and `NodeId`)
// really can be shared across threads. If `VectorType` already requires
// `Sync`, this impl is redundant (the auto trait would apply); otherwise a
// `T: Sync` bound should be added. Presumably it exists because
// `instant_distance::Point` requires `Sync` — confirm against `VectorType`.
unsafe impl<T: VectorType> Sync for VectorPoint<T> {}
/// Graph distance between two points, defined as the complement of
/// `T::similarity`: a similarity of 1.0 maps to distance 0.0, and less
/// similar vectors lie farther apart.
impl<T: VectorType> Point for VectorPoint<T> {
    fn distance(&self, other: &Self) -> f32 {
        // The HNSW search treats smaller values as nearer, hence `1 - similarity`.
        1.0 - T::similarity(&self.vec, &other.vec)
    }
}
/// Approximate nearest-neighbour index over [`VectorPoint`]s, backed by an
/// `instant_distance` HNSW map from each point to its `NodeId`.
///
/// The index is built in one shot by [`HnswIndex::rebuild`]. Until then, or
/// after a rebuild that found no usable vectors, it is empty and every search
/// returns no hits.
pub struct HnswIndex<T: VectorType> {
    /// Length of every indexed vector; queries must have this length.
    dim: usize,
    /// The built graph, or `None` when nothing is indexed.
    index: Option<instant_distance::HnswMap<VectorPoint<T>, NodeId>>,
}

impl<T: VectorType> HnswIndex<T> {
    /// Creates an empty, unbuilt index for vectors of length `dim`.
    pub fn new(dim: usize) -> Self {
        Self { dim, index: None }
    }

    /// Rebuilds the index from scratch out of a row-major buffer of vectors.
    ///
    /// `flat_vectors` is read as consecutive rows of `dim` elements, and `dim`
    /// replaces the dimension given to [`HnswIndex::new`]. Row `i` is indexed
    /// only when `is_active(i)` returns true, and it is stored under
    /// `id_mapper(i)`. Trailing elements that do not form a complete row are
    /// ignored. The index is cleared when `dim` is 0 or no row is active.
    pub fn rebuild(
        &mut self,
        flat_vectors: &[T],
        dim: usize,
        id_mapper: impl Fn(usize) -> NodeId,
        is_active: impl Fn(usize) -> bool,
    ) {
        self.dim = dim;
        // A zero dimension has no rows, and the row count would divide by zero.
        if dim == 0 {
            self.index = None;
            return;
        }

        let num_vectors = flat_vectors.len() / dim;
        let mut points = Vec::with_capacity(num_vectors);
        let mut values = Vec::with_capacity(num_vectors);
        // `chunks_exact` yields exactly `num_vectors` full rows and drops any remainder.
        for (i, row) in flat_vectors.chunks_exact(dim).enumerate() {
            if !is_active(i) {
                continue;
            }
            let id = id_mapper(i);
            points.push(VectorPoint {
                id,
                vec: row.to_vec(),
            });
            values.push(id);
        }

        if points.is_empty() {
            self.index = None;
            return;
        }
        self.index = Some(Builder::default().build(points, values));
    }

    /// Returns the best matches for `query`, most similar first.
    ///
    /// At most `top_k` candidates are taken from the graph search, and those
    /// with a score below `min_score` are dropped. The score is
    /// `1 - distance`, which is the `T::similarity` value. Hits carry a `Null`
    /// payload. The result is empty when the index is not built, when
    /// `top_k == 0`, or when `query.len()` differs from the indexed dimension.
    ///
    /// NOTE(review): the number of candidates the HNSW search yields is
    /// presumably bounded by the builder's `ef_search` setting (left at its
    /// default here). A very large `top_k` may therefore return fewer hits —
    /// confirm against the `instant_distance` docs.
    pub fn search(&self, query: &[T], top_k: usize, min_score: f32) -> Vec<SearchHit> {
        let hnsw = match &self.index {
            Some(hnsw) => hnsw,
            None => return Vec::new(),
        };
        // A query of the wrong length cannot be compared meaningfully with
        // the stored vectors, so it matches nothing.
        if query.len() != self.dim {
            return Vec::new();
        }

        // The query's id is a placeholder; distance only looks at `vec`.
        let q_point = VectorPoint {
            id: 0,
            vec: query.to_vec(),
        };
        let mut search = Search::default();
        hnsw.search(&q_point, &mut search)
            .take(top_k)
            .filter_map(|item| {
                let score = 1.0 - item.distance;
                (score >= min_score).then(|| SearchHit {
                    id: *item.value,
                    score,
                    payload: serde_json::Value::Null,
                })
            })
            .collect()
    }

    /// Returns `true` when the index currently holds at least one vector.
    pub fn is_built(&self) -> bool {
        self.index.is_some()
    }
}