use std::collections::HashMap;
use std::fs::File;
use std::io::{BufReader, BufWriter};
use std::path::Path;
use instant_distance::{Builder, HnswMap, Point, Search};
use serde::{Deserialize, Serialize};
use tracing::{debug, instrument};
use crate::domain::{EmbeddingId, HnswConfig, VectorError};
#[derive(Clone, Debug)]
struct VectorPoint {
data: Vec<f32>,
}
impl instant_distance::Point for VectorPoint {
fn distance(&self, other: &Self) -> f32 {
let dot: f32 = self
.data
.iter()
.zip(other.data.iter())
.map(|(a, b)| a * b)
.sum();
let norm_a: f32 = self.data.iter().map(|x| x * x).sum::<f32>().sqrt();
let norm_b: f32 = other.data.iter().map(|x| x * x).sum::<f32>().sqrt();
if norm_a < 1e-10 || norm_b < 1e-10 {
return 1.0; }
let cosine_sim = dot / (norm_a * norm_b);
1.0 - cosine_sim.clamp(-1.0, 1.0)
}
}
#[derive(Serialize, Deserialize)]
struct SerializedIndex {
vectors: Vec<(EmbeddingId, Vec<f32>)>,
dimensions: usize,
}
pub struct HnswIndex {
inner: Option<HnswMap<VectorPoint, EmbeddingId>>,
id_to_idx: HashMap<EmbeddingId, usize>,
vectors: Vec<(EmbeddingId, Vec<f32>)>,
config: HnswConfig,
needs_rebuild: bool,
_search_buf: Search,
}
impl HnswIndex {
pub fn new(config: &HnswConfig) -> Self {
Self {
inner: None,
id_to_idx: HashMap::new(),
vectors: Vec::new(),
config: config.clone(),
needs_rebuild: false,
_search_buf: Search::default(),
}
}
#[instrument(skip(self, vector), fields(dim = vector.len()))]
pub fn insert(&mut self, id: EmbeddingId, vector: &[f32]) -> Result<(), VectorError> {
if vector.len() != self.config.dimensions {
return Err(VectorError::dimension_mismatch(
self.config.dimensions,
vector.len(),
));
}
if self.vectors.len() >= self.config.max_elements {
return Err(VectorError::capacity_exceeded(
self.config.max_elements,
self.vectors.len(),
));
}
if self.id_to_idx.contains_key(&id) {
return Err(VectorError::DuplicateId(id));
}
let idx = self.vectors.len();
self.id_to_idx.insert(id, idx);
self.vectors.push((id, vector.to_vec()));
self.needs_rebuild = true;
debug!(id = %id, idx = idx, "Inserted vector");
Ok(())
}
#[instrument(skip(self, query), fields(dim = query.len(), k))]
pub fn search(&self, query: &[f32], k: usize) -> Vec<(EmbeddingId, f32)> {
if self.vectors.is_empty() {
return Vec::new();
}
let inner = match &self.inner {
Some(inner) => inner,
None => return self.brute_force_search(query, k),
};
let query_point = VectorPoint {
data: query.to_vec(),
};
let mut search = Search::default();
let results = inner.search(&query_point, &mut search);
results
.take(k)
.map(|item| {
let id = *item.value;
let distance = item.distance;
(id, distance)
})
.collect()
}
fn brute_force_search(&self, query: &[f32], k: usize) -> Vec<(EmbeddingId, f32)> {
let query_point = VectorPoint {
data: query.to_vec(),
};
let mut distances: Vec<_> = self
.vectors
.iter()
.map(|(id, vec)| {
let point = VectorPoint { data: vec.clone() };
let dist = query_point.distance(&point);
(*id, dist)
})
.collect();
distances.sort_by(|a, b| a.1.partial_cmp(&b.1).unwrap_or(std::cmp::Ordering::Equal));
distances.truncate(k);
distances
}
pub fn remove(&mut self, id: &EmbeddingId) -> Result<(), VectorError> {
let idx = self
.id_to_idx
.remove(id)
.ok_or_else(|| VectorError::NotFound(*id))?;
self.vectors.swap_remove(idx);
if idx < self.vectors.len() {
let swapped_id = self.vectors[idx].0;
self.id_to_idx.insert(swapped_id, idx);
}
self.needs_rebuild = true;
debug!(id = %id, "Removed vector");
Ok(())
}
#[inline]
pub fn contains(&self, id: &EmbeddingId) -> bool {
self.id_to_idx.contains_key(id)
}
#[inline]
pub fn len(&self) -> usize {
self.vectors.len()
}
#[inline]
pub fn is_empty(&self) -> bool {
self.vectors.is_empty()
}
pub fn get_vector(&self, id: &EmbeddingId) -> Option<Vec<f32>> {
self.id_to_idx
.get(id)
.map(|&idx| self.vectors[idx].1.clone())
}
pub fn clear(&mut self) {
self.vectors.clear();
self.id_to_idx.clear();
self.inner = None;
self.needs_rebuild = false;
}
#[instrument(skip(self))]
pub fn build(&mut self) -> Result<(), VectorError> {
if self.vectors.is_empty() {
self.inner = None;
self.needs_rebuild = false;
return Ok(());
}
let points: Vec<VectorPoint> = self
.vectors
.iter()
.map(|(_, vec)| VectorPoint { data: vec.clone() })
.collect();
let values: Vec<EmbeddingId> = self.vectors.iter().map(|(id, _)| *id).collect();
let hnsw = Builder::default()
.ef_construction(self.config.ef_construction)
.build(points, values);
self.inner = Some(hnsw);
self.needs_rebuild = false;
debug!(
vectors = self.vectors.len(),
"Built HNSW index"
);
Ok(())
}
pub fn rebuild_if_needed(&mut self) -> Result<(), VectorError> {
if self.needs_rebuild {
self.build()
} else {
Ok(())
}
}
#[instrument(skip(self))]
pub fn save(&self, path: &Path) -> Result<(), VectorError> {
let file = File::create(path)?;
let writer = BufWriter::new(file);
let serialized = SerializedIndex {
vectors: self.vectors.clone(),
dimensions: self.config.dimensions,
};
bincode::serialize_into(writer, &serialized)?;
debug!(path = %path.display(), vectors = self.vectors.len(), "Saved index");
Ok(())
}
#[instrument]
pub fn load(path: &Path) -> Result<Self, VectorError> {
if !path.exists() {
return Err(VectorError::FileNotFound(path.to_path_buf()));
}
let file = File::open(path)?;
let reader = BufReader::new(file);
let serialized: SerializedIndex = bincode::deserialize_from(reader)
.map_err(|e| VectorError::corrupted(format!("Failed to deserialize: {e}")))?;
let config = HnswConfig::for_dimension(serialized.dimensions);
let mut index = Self::new(&config);
for (id, vector) in serialized.vectors {
let idx = index.vectors.len();
index.id_to_idx.insert(id, idx);
index.vectors.push((id, vector));
}
index.build()?;
debug!(
path = %path.display(),
vectors = index.vectors.len(),
"Loaded index"
);
Ok(index)
}
pub fn ids(&self) -> impl Iterator<Item = &EmbeddingId> {
self.id_to_idx.keys()
}
pub fn iter(&self) -> impl Iterator<Item = (&EmbeddingId, &[f32])> {
self.vectors.iter().map(|(id, vec)| (id, vec.as_slice()))
}
}
impl Default for HnswIndex {
fn default() -> Self {
Self::new(&HnswConfig::default())
}
}
#[cfg(test)]
mod tests {
use super::*;
use tempfile::NamedTempFile;
fn create_test_index() -> HnswIndex {
let config = HnswConfig::for_dimension(64).with_max_elements(1000);
HnswIndex::new(&config)
}
fn random_vector(dim: usize, seed: u64) -> Vec<f32> {
use std::collections::hash_map::DefaultHasher;
use std::hash::{Hash, Hasher};
(0..dim)
.map(|i| {
let mut hasher = DefaultHasher::new();
(seed, i).hash(&mut hasher);
let h = hasher.finish();
((h % 1000) as f32 / 1000.0) * 2.0 - 1.0
})
.collect()
}
#[test]
fn test_insert_and_search() {
let mut index = create_test_index();
let id1 = EmbeddingId::new();
let id2 = EmbeddingId::new();
let v1 = random_vector(64, 1);
let v2 = random_vector(64, 2);
index.insert(id1, &v1).unwrap();
index.insert(id2, &v2).unwrap();
index.build().unwrap();
let results = index.search(&v1, 2);
assert_eq!(results.len(), 2);
assert_eq!(results[0].0, id1); }
#[test]
fn test_dimension_validation() {
let mut index = create_test_index();
let id = EmbeddingId::new();
let wrong_dim = random_vector(32, 1);
let result = index.insert(id, &wrong_dim);
assert!(matches!(
result,
Err(VectorError::DimensionMismatch { expected: 64, got: 32 })
));
}
#[test]
fn test_duplicate_detection() {
let mut index = create_test_index();
let id = EmbeddingId::new();
let v = random_vector(64, 1);
index.insert(id, &v).unwrap();
let result = index.insert(id, &v);
assert!(matches!(result, Err(VectorError::DuplicateId(_))));
}
#[test]
fn test_remove() {
let mut index = create_test_index();
let id1 = EmbeddingId::new();
let id2 = EmbeddingId::new();
index.insert(id1, &random_vector(64, 1)).unwrap();
index.insert(id2, &random_vector(64, 2)).unwrap();
assert_eq!(index.len(), 2);
assert!(index.contains(&id1));
index.remove(&id1).unwrap();
assert_eq!(index.len(), 1);
assert!(!index.contains(&id1));
assert!(index.contains(&id2));
}
#[test]
fn test_capacity_limit() {
let config = HnswConfig::for_dimension(64).with_max_elements(2);
let mut index = HnswIndex::new(&config);
index.insert(EmbeddingId::new(), &random_vector(64, 1)).unwrap();
index.insert(EmbeddingId::new(), &random_vector(64, 2)).unwrap();
let result = index.insert(EmbeddingId::new(), &random_vector(64, 3));
assert!(matches!(result, Err(VectorError::CapacityExceeded { .. })));
}
#[test]
fn test_save_and_load() {
let mut index = create_test_index();
let ids: Vec<_> = (0..10).map(|_| EmbeddingId::new()).collect();
for (i, id) in ids.iter().enumerate() {
index.insert(*id, &random_vector(64, i as u64)).unwrap();
}
index.build().unwrap();
let file = NamedTempFile::new().unwrap();
index.save(file.path()).unwrap();
let loaded = HnswIndex::load(file.path()).unwrap();
assert_eq!(loaded.len(), index.len());
for id in &ids {
assert!(loaded.contains(id));
}
}
#[test]
fn test_brute_force_fallback() {
let mut index = create_test_index();
let id1 = EmbeddingId::new();
let id2 = EmbeddingId::new();
let v1 = random_vector(64, 1);
let v2 = random_vector(64, 2);
index.insert(id1, &v1).unwrap();
index.insert(id2, &v2).unwrap();
let results = index.search(&v1, 2);
assert_eq!(results.len(), 2);
}
#[test]
fn test_get_vector() {
let mut index = create_test_index();
let id = EmbeddingId::new();
let v = random_vector(64, 1);
index.insert(id, &v).unwrap();
let retrieved = index.get_vector(&id).unwrap();
assert_eq!(retrieved, v);
let unknown = EmbeddingId::new();
assert!(index.get_vector(&unknown).is_none());
}
#[test]
fn test_search_accuracy() {
let config = HnswConfig::for_dimension(64)
.with_max_elements(100)
.with_ef_construction(200)
.with_ef_search(128);
let mut index = HnswIndex::new(&config);
let base: Vec<f32> = (0..64).map(|i| i as f32 / 64.0).collect();
let id_base = EmbeddingId::new();
index.insert(id_base, &base).unwrap();
let similar_ids: Vec<_> = (0..5)
.map(|i| {
let id = EmbeddingId::new();
let v: Vec<f32> = base
.iter()
.map(|&x| x + 0.01 * (i as f32 + 1.0))
.collect();
index.insert(id, &v).unwrap();
id
})
.collect();
for i in 0..10 {
let id = EmbeddingId::new();
let v: Vec<f32> = (0..64).map(|j| ((i + j) % 7) as f32 / 7.0).collect();
index.insert(id, &v).unwrap();
}
index.build().unwrap();
let results = index.search(&base, 6);
assert_eq!(results[0].0, id_base);
let top_ids: std::collections::HashSet<_> =
results.iter().take(6).map(|(id, _)| *id).collect();
for similar_id in &similar_ids {
assert!(
top_ids.contains(similar_id),
"Similar vector not found in top results"
);
}
}
}