use codemem_core::{CodememError, VectorBackend, VectorConfig, VectorStats};
use std::collections::HashMap;
use std::path::Path;
use usearch::{Index, IndexOptions, MetricKind, ScalarKind};
/// Approximate-nearest-neighbour vector store backed by a usearch HNSW index.
///
/// usearch addresses vectors by `u64` keys, so string ids are translated
/// through `id_to_key` / `key_to_id`. Re-inserting or removing an id removes
/// its old key from the index and counts it in `ghost_count`; a fresh key is
/// allocated for every insert.
pub struct HnswIndex {
// Underlying usearch index holding the embeddings.
index: Index,
// Dimensions, metric and HNSW parameters the index was built with.
config: VectorConfig,
// String id -> usearch key for every currently stored entry.
id_to_key: HashMap<String, u64>,
// Reverse of `id_to_key`; used to map search hits back to string ids.
key_to_id: HashMap<u64, String>,
// Next key handed out by `allocate_key`; reset by rebuild, restored by load.
next_key: u64,
// Removals/overwrites since the last rebuild or load; drives `needs_compaction`.
ghost_count: usize,
}
impl HnswIndex {
pub fn new(config: VectorConfig) -> Result<Self, CodememError> {
let metric = match config.metric {
codemem_core::DistanceMetric::Cosine => MetricKind::Cos,
codemem_core::DistanceMetric::L2 => MetricKind::L2sq,
codemem_core::DistanceMetric::InnerProduct => MetricKind::IP,
};
let options = IndexOptions {
dimensions: config.dimensions,
metric,
quantization: ScalarKind::F32,
connectivity: config.m,
expansion_add: config.ef_construction,
expansion_search: config.ef_search,
multi: false,
};
let index = Index::new(&options).map_err(|e| CodememError::Vector(e.to_string()))?;
index
.reserve(10_000)
.map_err(|e| CodememError::Vector(e.to_string()))?;
Ok(Self {
index,
config,
id_to_key: HashMap::new(),
key_to_id: HashMap::new(),
next_key: 0,
ghost_count: 0,
})
}
pub fn with_defaults() -> Result<Self, CodememError> {
Self::new(VectorConfig::default())
}
pub fn len(&self) -> usize {
self.index.size()
}
pub fn is_empty(&self) -> bool {
self.len() == 0
}
fn allocate_key(&mut self) -> u64 {
let key = self.next_key;
self.next_key += 1;
key
}
pub fn rebuild_from_entries(
&mut self,
entries: &[(String, Vec<f32>)],
) -> Result<(), CodememError> {
let new_index = Index::new(&IndexOptions {
dimensions: self.config.dimensions,
metric: match self.config.metric {
codemem_core::DistanceMetric::Cosine => MetricKind::Cos,
codemem_core::DistanceMetric::L2 => MetricKind::L2sq,
codemem_core::DistanceMetric::InnerProduct => MetricKind::IP,
},
quantization: ScalarKind::F32,
connectivity: self.config.m,
expansion_add: self.config.ef_construction,
expansion_search: self.config.ef_search,
multi: false,
})
.map_err(|e| CodememError::Vector(e.to_string()))?;
let capacity = entries.len().max(1024);
new_index
.reserve(capacity)
.map_err(|e| CodememError::Vector(e.to_string()))?;
self.index = new_index;
self.id_to_key.clear();
self.key_to_id.clear();
self.next_key = 0;
self.ghost_count = 0;
for (id, embedding) in entries {
self.insert(id, embedding)?;
}
Ok(())
}
pub fn needs_compaction(&self) -> bool {
let live = self.id_to_key.len();
live > 0 && self.ghost_count > live / 5
}
pub fn ghost_count(&self) -> usize {
self.ghost_count
}
}
impl VectorBackend for HnswIndex {
fn insert(&mut self, id: &str, embedding: &[f32]) -> Result<(), CodememError> {
if embedding.len() != self.config.dimensions {
return Err(CodememError::Vector(format!(
"Expected {} dimensions, got {}",
self.config.dimensions,
embedding.len()
)));
}
if let Some(&old_key) = self.id_to_key.get(id) {
self.index
.remove(old_key)
.map_err(|e| CodememError::Vector(e.to_string()))?;
self.key_to_id.remove(&old_key);
self.ghost_count += 1;
}
let key = self.allocate_key();
if self.index.size() >= self.index.capacity() {
let cap = self.index.capacity();
let new_cap = cap + 1024.max(cap / 4);
self.index
.reserve(new_cap)
.map_err(|e| CodememError::Vector(e.to_string()))?;
}
self.index
.add(key, embedding)
.map_err(|e| CodememError::Vector(e.to_string()))?;
self.id_to_key.insert(id.to_string(), key);
self.key_to_id.insert(key, id.to_string());
Ok(())
}
fn insert_batch(&mut self, items: &[(String, Vec<f32>)]) -> Result<(), CodememError> {
let needed = self.index.size() + items.len();
if needed > self.index.capacity() {
self.index
.reserve(needed)
.map_err(|e| CodememError::Vector(e.to_string()))?;
}
for (id, embedding) in items {
self.insert(id, embedding)?;
}
Ok(())
}
fn search(&self, query: &[f32], k: usize) -> Result<Vec<(String, f32)>, CodememError> {
if self.is_empty() {
return Ok(vec![]);
}
let results = self
.index
.search(query, k)
.map_err(|e| CodememError::Vector(e.to_string()))?;
let mut output = Vec::with_capacity(results.keys.len());
for (key, distance) in results.keys.iter().zip(results.distances.iter()) {
if let Some(id) = self.key_to_id.get(key) {
let similarity = 1.0 - distance;
output.push((id.clone(), similarity));
}
}
Ok(output)
}
fn remove(&mut self, id: &str) -> Result<bool, CodememError> {
if let Some(key) = self.id_to_key.remove(id) {
self.index
.remove(key)
.map_err(|e| CodememError::Vector(e.to_string()))?;
self.key_to_id.remove(&key);
self.ghost_count += 1;
Ok(true)
} else {
Ok(false)
}
}
fn save(&self, path: &Path) -> Result<(), CodememError> {
let path_str = path
.to_str()
.ok_or_else(|| CodememError::Vector("Path contains non-UTF-8 characters".into()))?;
let idmap_path = path.with_extension("idmap");
let map_data = serde_json::to_string(&IdMapping {
id_to_key: &self.id_to_key,
next_key: self.next_key,
})
.map_err(|e| CodememError::Vector(e.to_string()))?;
let tmp_idmap = path.with_extension("idmap.tmp");
std::fs::write(&tmp_idmap, map_data)?;
let tmp_idx = path.with_extension("idx.tmp");
let tmp_idx_str = tmp_idx.to_str().ok_or_else(|| {
CodememError::Vector("Temp path contains non-UTF-8 characters".into())
})?;
self.index
.save(tmp_idx_str)
.map_err(|e| CodememError::Vector(e.to_string()))?;
std::fs::rename(&tmp_idmap, &idmap_path)?;
std::fs::rename(&tmp_idx, path_str)?;
Ok(())
}
fn load(&mut self, path: &Path) -> Result<(), CodememError> {
let path_str = path
.to_str()
.ok_or_else(|| CodememError::Vector("Path contains non-UTF-8 characters".into()))?;
self.index
.load(path_str)
.map_err(|e| CodememError::Vector(e.to_string()))?;
let map_path = path.with_extension("idmap");
if map_path.exists() {
let map_data = std::fs::read_to_string(map_path)?;
let mapping: IdMappingOwned =
serde_json::from_str(&map_data).map_err(|e| CodememError::Vector(e.to_string()))?;
self.id_to_key = mapping.id_to_key;
self.key_to_id = self
.id_to_key
.iter()
.map(|(id, key)| (*key, id.clone()))
.collect();
self.next_key = mapping.next_key;
self.ghost_count = 0; }
Ok(())
}
fn stats(&self) -> VectorStats {
VectorStats {
count: self.len(),
dimensions: self.config.dimensions,
metric: format!("{:?}", self.config.metric),
memory_bytes: self.index.memory_usage(),
}
}
fn needs_compaction(&self) -> bool {
let live = self.id_to_key.len();
live > 0 && self.ghost_count > live / 5
}
fn ghost_count(&self) -> usize {
self.ghost_count
}
fn rebuild_from_entries(&mut self, entries: &[(String, Vec<f32>)]) -> Result<(), CodememError> {
HnswIndex::rebuild_from_entries(self, entries)
}
}
use serde::{Deserialize, Serialize};
/// Borrowed view of the id map that `save` serializes with `serde_json`.
///
/// Borrowing `id_to_key` avoids cloning the map just to write it out.
#[derive(Serialize)]
struct IdMapping<'a> {
// String id -> usearch key for every stored entry.
id_to_key: &'a HashMap<String, u64>,
// Next key to allocate; persisted so `load` can restore the counter.
next_key: u64,
}
/// Owned counterpart of `IdMapping`, deserialized from the `.idmap` JSON file in `load`.
#[derive(Deserialize)]
struct IdMappingOwned {
// String id -> usearch key, as written by `save`.
id_to_key: HashMap<String, u64>,
// Next key to allocate after loading.
next_key: u64,
}
/// Cosine similarity of two `f32` vectors, accumulated in `f64`.
///
/// Returns `0.0` in three cases:
/// - the slices differ in length;
/// - the slices are empty;
/// - the product of their norms is below `1e-12` (either vector is ~zero).
///
/// NaN inputs propagate to a NaN result.
pub fn cosine_similarity(a: &[f32], b: &[f32]) -> f64 {
    if a.is_empty() || a.len() != b.len() {
        return 0.0;
    }
    let (dot, sq_norm_a, sq_norm_b) = a
        .iter()
        .zip(b)
        .map(|(&p, &q)| (f64::from(p), f64::from(q)))
        .fold((0.0_f64, 0.0_f64, 0.0_f64), |(acc, na, nb), (p, q)| {
            (acc + p * q, na + p * p, nb + q * q)
        });
    let magnitude = sq_norm_a.sqrt() * sq_norm_b.sqrt();
    if magnitude < 1e-12 {
        0.0
    } else {
        dot / magnitude
    }
}
// Unit tests are compiled only for test builds and loaded from tests/vector_tests.rs via #[path].
#[cfg(test)]
#[path = "tests/vector_tests.rs"]
mod tests;