use instant_distance::{Builder, HnswMap, Point, Search};
use std::sync::Mutex;
/// Number of entries allowed to accumulate in the overflow buffer before
/// `VectorIndex::insert` rebuilds the HNSW graph from every stored entry.
const REBUILD_THRESHOLD: usize = 200;
/// Newtype around an `f32` embedding vector so that it can implement
/// `instant_distance::Point` and be stored in the HNSW graph.
#[derive(Clone, Debug)]
pub struct EmbeddingPoint(pub Vec<f32>);
/// Distance metric used by the graph: one minus the dot product.
impl instant_distance::Point for EmbeddingPoint {
    /// Returns `1.0 - dot(self, other)`.
    ///
    /// Components are paired positionally; if the two vectors differ in length
    /// the extra trailing components of the longer one are ignored.
    ///
    /// NOTE(review): this equals cosine distance only when both vectors have
    /// unit length — confirm that callers normalise their embeddings.
    fn distance(&self, other: &Self) -> f32 {
        let similarity = self
            .0
            .iter()
            .zip(&other.0)
            .map(|(lhs, rhs)| lhs * rhs)
            .sum::<f32>();
        1.0 - similarity
    }
}
/// Nearest-neighbour index over string-keyed embeddings.
///
/// Searches combine an HNSW graph (built from a snapshot of the entries) with
/// a linear scan over the entries inserted since that snapshot. All state
/// lives behind a [`Mutex`], so every method takes `&self`.
///
/// Ids are treated as unique keys: inserting an id that already exists
/// replaces the previous embedding.
///
/// # Panics
///
/// Every method panics if the internal mutex has been poisoned by a panic in
/// another thread.
pub struct VectorIndex {
    inner: Mutex<IndexState>,
}

/// Mutable state guarded by the `VectorIndex` mutex.
struct IndexState {
    /// Graph built from a snapshot of `all_entries`; `None` when that snapshot was empty.
    hnsw: Option<HnswMap<EmbeddingPoint, String>>,
    /// Entries inserted since the last rebuild; scanned linearly by `search`.
    overflow: Vec<(String, EmbeddingPoint)>,
    /// Every live entry; the source of truth for rebuilds and for `len`.
    all_entries: Vec<(String, Vec<f32>)>,
    /// Ids that were removed or re-inserted since the last rebuild. Any copy of
    /// such an id still present in `hnsw` is outdated and must be skipped.
    stale_ids: std::collections::HashSet<String>,
}

impl IndexState {
    /// Creates a state whose graph is built from `entries` and whose overflow is empty.
    fn new(entries: Vec<(String, Vec<f32>)>) -> Self {
        IndexState {
            hnsw: VectorIndex::build_hnsw(&entries),
            overflow: Vec::new(),
            all_entries: entries,
            stale_ids: std::collections::HashSet::new(),
        }
    }

    /// Rebuilds the graph from `all_entries`, after which nothing is pending or stale.
    fn rebuild(&mut self) {
        self.hnsw = VectorIndex::build_hnsw(&self.all_entries);
        self.overflow.clear();
        self.stale_ids.clear();
    }

    /// Rebuilds once enough inserts or tombstones have piled up.
    ///
    /// Tombstones count as well because stale graph nodes still occupy
    /// candidate slots during a graph search and would otherwise crowd out
    /// live results indefinitely.
    fn rebuild_if_due(&mut self) {
        if self.overflow.len() >= REBUILD_THRESHOLD || self.stale_ids.len() >= REBUILD_THRESHOLD {
            self.rebuild();
        }
    }

    /// Records that any copy of `id` inside the current graph is outdated.
    fn mark_stale(&mut self, id: &str) {
        // Without a graph there is nothing to shadow.
        if self.hnsw.is_some() {
            self.stale_ids.insert(id.to_owned());
        }
    }
}

/// A single search result.
#[derive(Debug, Clone)]
pub struct VectorHit {
    /// Id of the matching entry.
    pub id: String,
    /// Distance between the query and the entry (smaller is closer).
    pub distance: f32,
}

impl VectorIndex {
    /// Builds an index whose graph contains all of `entries`.
    pub fn build(entries: Vec<(String, Vec<f32>)>) -> Self {
        VectorIndex {
            inner: Mutex::new(IndexState::new(entries)),
        }
    }

    /// Creates an index with no entries and no graph.
    pub fn empty() -> Self {
        VectorIndex {
            inner: Mutex::new(IndexState::new(Vec::new())),
        }
    }

    /// Builds an HNSW map from `entries`, or returns `None` when there are none.
    fn build_hnsw(entries: &[(String, Vec<f32>)]) -> Option<HnswMap<EmbeddingPoint, String>> {
        if entries.is_empty() {
            return None;
        }
        let (points, values): (Vec<EmbeddingPoint>, Vec<String>) = entries
            .iter()
            .map(|(id, emb)| (EmbeddingPoint(emb.clone()), id.clone()))
            .unzip();
        Some(Builder::default().build(points, values))
    }

    /// Locks the shared state, panicking if the mutex is poisoned.
    fn lock_state(&self) -> std::sync::MutexGuard<'_, IndexState> {
        self.inner.lock().expect("VectorIndex mutex poisoned")
    }

    /// Inserts `embedding` under `id`, replacing any existing entry with that id.
    ///
    /// The entry is searchable immediately through the overflow buffer. The
    /// graph is rebuilt once `REBUILD_THRESHOLD` pending changes accumulate.
    pub fn insert(&self, id: String, embedding: Vec<f32>) {
        let mut guard = self.lock_state();
        let state = &mut *guard;
        // Upsert: drop any previous embedding so it can neither be returned by
        // a search nor resurrected by the next rebuild.
        state.all_entries.retain(|(eid, _)| eid != &id);
        state.overflow.retain(|(eid, _)| eid != &id);
        // The graph may still hold an older vector for this id; the overflow
        // copy is now the authoritative one.
        state.mark_stale(&id);
        state.all_entries.push((id.clone(), embedding.clone()));
        state.overflow.push((id, EmbeddingPoint(embedding)));
        state.rebuild_if_due();
    }

    /// Removes the entry stored under `id`; does nothing if the id is unknown.
    pub fn remove(&self, id: &str) {
        let mut guard = self.lock_state();
        let state = &mut *guard;
        let before = state.all_entries.len();
        state.all_entries.retain(|(eid, _)| eid != id);
        state.overflow.retain(|(eid, _)| eid != id);
        if state.all_entries.len() != before {
            // The graph cannot delete nodes, so hide its copy until the next rebuild.
            state.mark_stale(id);
            state.rebuild_if_due();
        }
    }

    /// Returns up to `k` hits for `query`, ordered by ascending distance.
    ///
    /// At most `2 * k` live candidates are taken from the graph and merged
    /// with every entry in the overflow buffer. Each id appears at most once.
    /// Hits whose distance is NaN are ranked last instead of causing a panic.
    pub fn search(&self, query: &[f32], k: usize) -> Vec<VectorHit> {
        if k == 0 {
            return Vec::new();
        }
        let state = self.lock_state();
        let query_point = EmbeddingPoint(query.to_vec());
        // Saturate so that a huge `k` ("give me everything") cannot overflow.
        let graph_limit = k.saturating_mul(2);
        let mut results: Vec<VectorHit> = Vec::new();

        if let Some(hnsw) = &state.hnsw {
            let stale = &state.stale_ids;
            let mut search = Search::default();
            results.extend(
                hnsw.search(&query_point, &mut search)
                    .filter(|item| !stale.contains(item.value.as_str()))
                    .take(graph_limit)
                    .map(|item| VectorHit {
                        id: item.value.clone(),
                        distance: item.distance,
                    }),
            );
        }

        results.extend(state.overflow.iter().map(|(id, point)| VectorHit {
            id: id.clone(),
            distance: query_point.distance(point),
        }));

        // Stable sort: on equal distances graph hits stay ahead of overflow hits.
        results.sort_by(|a, b| nan_last(a.distance, b.distance));
        // De-duplicate after sorting so the closest copy of an id is the one kept
        // (duplicates are only possible if `build` was given repeated ids).
        let mut seen = std::collections::HashSet::new();
        results.retain(|hit| seen.insert(hit.id.clone()));
        results.truncate(k);
        results
    }

    /// Returns the number of stored entries.
    pub fn len(&self) -> usize {
        self.lock_state().all_entries.len()
    }

    /// Returns `true` when the index holds no entries.
    pub fn is_empty(&self) -> bool {
        self.lock_state().all_entries.is_empty()
    }

    /// Rebuilds the graph from all stored entries and clears the overflow buffer.
    pub fn rebuild(&self) {
        self.lock_state().rebuild();
    }
}

/// Orders two distances ascending, placing NaN after every real number.
fn nan_last(a: f32, b: f32) -> std::cmp::Ordering {
    // `partial_cmp` is `None` exactly when at least one side is NaN.
    a.partial_cmp(&b)
        .unwrap_or_else(|| a.is_nan().cmp(&b.is_nan()))
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Scales `values` so that the resulting vector has unit Euclidean length.
    fn make_embedding(values: &[f32]) -> Vec<f32> {
        let length = values.iter().map(|v| v * v).sum::<f32>().sqrt();
        values.iter().map(|v| v / length).collect()
    }

    /// Builds an `(id, normalised embedding)` pair for use as index input.
    fn entry(id: &str, raw: &[f32]) -> (String, Vec<f32>) {
        (String::from(id), make_embedding(raw))
    }

    /// Searching an index without entries yields no hits.
    #[test]
    fn empty_index_returns_empty() {
        let index = VectorIndex::empty();
        let hits = index.search(&[1.0, 0.0, 0.0], 10);
        assert!(hits.is_empty());
    }

    /// The entry best aligned with the query is ranked first.
    #[test]
    fn basic_search() {
        let index = VectorIndex::build(vec![
            entry("a", &[1.0, 0.0, 0.0]),
            entry("b", &[0.0, 1.0, 0.0]),
            entry("c", &[0.0, 0.0, 1.0]),
        ]);
        let query = make_embedding(&[1.0, 0.1, 0.0]);
        let hits = index.search(&query, 2);
        assert_eq!(hits.len(), 2);
        assert_eq!(hits[0].id, "a");
    }

    /// Entries inserted after the build are found alongside the built ones.
    #[test]
    fn insert_and_search_overflow() {
        let index = VectorIndex::build(vec![entry("a", &[1.0, 0.0, 0.0])]);
        let (new_id, new_embedding) = entry("b", &[0.9, 0.1, 0.0]);
        index.insert(new_id, new_embedding);

        let query = make_embedding(&[1.0, 0.0, 0.0]);
        let hits = index.search(&query, 2);
        let ids: Vec<&str> = hits.iter().map(|hit| hit.id.as_str()).collect();
        assert_eq!(ids, ["a", "b"]);
    }

    /// A removed id never shows up in later searches.
    #[test]
    fn remove_excludes_from_results() {
        let index = VectorIndex::build(vec![
            entry("a", &[1.0, 0.0, 0.0]),
            entry("b", &[0.9, 0.1, 0.0]),
        ]);
        index.remove("a");

        let query = make_embedding(&[1.0, 0.0, 0.0]);
        let hits = index.search(&query, 5);
        assert!(!hits.iter().any(|hit| hit.id == "a"));
    }
}