use codemem_core::{CodememError, VectorBackend, VectorConfig, VectorStats};
use std::collections::HashMap;
use std::path::Path;
use usearch::{Index, IndexOptions, MetricKind, ScalarKind};
/// Vector index backed by a `usearch` [`Index`], addressed by string ids.
///
/// The underlying index is keyed by `u64`, so this type keeps a two-way mapping
/// between caller-supplied string ids and the numeric keys it hands out.
pub struct HnswIndex {
/// The underlying `usearch` index that stores the vectors.
index: Index,
/// Configuration this index was built from (dimensions, metric, and tuning values).
config: VectorConfig,
/// String id -> numeric key currently stored in `index` for that id.
id_to_key: HashMap<String, u64>,
/// Numeric key -> string id; the reverse of `id_to_key`, used to translate search hits.
key_to_id: HashMap<u64, String>,
/// Next numeric key that `allocate_key` will hand out.
next_key: u64,
}
impl HnswIndex {
pub fn new(config: VectorConfig) -> Result<Self, CodememError> {
let metric = match config.metric {
codemem_core::DistanceMetric::Cosine => MetricKind::Cos,
codemem_core::DistanceMetric::L2 => MetricKind::L2sq,
codemem_core::DistanceMetric::InnerProduct => MetricKind::IP,
};
let options = IndexOptions {
dimensions: config.dimensions,
metric,
quantization: ScalarKind::F32,
connectivity: config.m,
expansion_add: config.ef_construction,
expansion_search: config.ef_search,
multi: false,
};
let index = Index::new(&options).map_err(|e| CodememError::Vector(e.to_string()))?;
index
.reserve(10_000)
.map_err(|e| CodememError::Vector(e.to_string()))?;
Ok(Self {
index,
config,
id_to_key: HashMap::new(),
key_to_id: HashMap::new(),
next_key: 0,
})
}
pub fn with_defaults() -> Result<Self, CodememError> {
Self::new(VectorConfig::default())
}
pub fn len(&self) -> usize {
self.index.size()
}
pub fn is_empty(&self) -> bool {
self.len() == 0
}
fn allocate_key(&mut self) -> u64 {
let key = self.next_key;
self.next_key += 1;
key
}
}
impl VectorBackend for HnswIndex {
    /// Stores `embedding` under `id`, replacing any vector previously stored for that id.
    ///
    /// # Errors
    /// Returns `CodememError::Vector` when the embedding length differs from the
    /// configured dimensions, or when the underlying index rejects the remove,
    /// reserve, or add call.
    fn insert(&mut self, id: &str, embedding: &[f32]) -> Result<(), CodememError> {
        let expected = self.config.dimensions;
        if embedding.len() != expected {
            return Err(CodememError::Vector(format!(
                "Expected {} dimensions, got {}",
                expected,
                embedding.len()
            )));
        }

        // Re-inserting an existing id replaces its vector under a fresh key.
        if let Some(old_key) = self.id_to_key.get(id).copied() {
            self.index
                .remove(old_key)
                .map_err(|e| CodememError::Vector(e.to_string()))?;
            // Forget both directions right away so that a failure further down
            // cannot leave `id` pointing at a key that is no longer in the index.
            self.id_to_key.remove(id);
            self.key_to_id.remove(&old_key);
        }

        // Grow geometrically, but always make room for at least one more vector:
        // doubling alone never grows an index whose capacity is currently zero.
        let capacity = self.index.capacity();
        let size = self.index.size();
        if size >= capacity {
            let new_cap = capacity.saturating_mul(2).max(size.saturating_add(1));
            self.index
                .reserve(new_cap)
                .map_err(|e| CodememError::Vector(e.to_string()))?;
        }

        let key = self.allocate_key();
        self.index
            .add(key, embedding)
            .map_err(|e| CodememError::Vector(e.to_string()))?;

        // Record the mapping only after the index accepted the vector.
        self.id_to_key.insert(id.to_string(), key);
        self.key_to_id.insert(key, id.to_string());
        Ok(())
    }

    /// Inserts every `(id, embedding)` pair in order via [`VectorBackend::insert`].
    ///
    /// # Errors
    /// Stops at and returns the first error; items before it remain inserted.
    fn insert_batch(&mut self, items: &[(String, Vec<f32>)]) -> Result<(), CodememError> {
        for (id, embedding) in items {
            self.insert(id, embedding)?;
        }
        Ok(())
    }

    /// Looks up the `k` nearest stored vectors to `query`.
    ///
    /// Returns `(id, score)` pairs in the order the underlying index reports
    /// them, where `score` is `1.0 - distance`. Hits whose key has no known
    /// string id are skipped. An empty index yields an empty result.
    ///
    /// # Errors
    /// Returns `CodememError::Vector` when the index is non-empty and the query
    /// length differs from the configured dimensions, or when the underlying
    /// search fails.
    fn search(&self, query: &[f32], k: usize) -> Result<Vec<(String, f32)>, CodememError> {
        if self.is_empty() {
            return Ok(Vec::new());
        }

        // Mirror the check done in `insert`: the raw slice is handed straight to
        // the underlying index, so a wrong-sized query must be rejected here.
        let expected = self.config.dimensions;
        if query.len() != expected {
            return Err(CodememError::Vector(format!(
                "Expected {} dimensions, got {}",
                expected,
                query.len()
            )));
        }

        let results = self
            .index
            .search(query, k)
            .map_err(|e| CodememError::Vector(e.to_string()))?;

        // NOTE(review): `1.0 - distance` is applied for every configured metric;
        // confirm this is the intended score for the L2 and inner-product metrics.
        let output = results
            .keys
            .iter()
            .zip(results.distances.iter())
            .filter_map(|(key, distance)| {
                self.key_to_id
                    .get(key)
                    .map(|id| (id.clone(), 1.0 - distance))
            })
            .collect();
        Ok(output)
    }

    /// Removes the vector stored under `id`.
    ///
    /// Returns `Ok(true)` if the id was known and removed, `Ok(false)` if it was unknown.
    ///
    /// # Errors
    /// Returns `CodememError::Vector` when the underlying index fails to remove
    /// the key; in that case the id mapping is left untouched.
    fn remove(&mut self, id: &str) -> Result<bool, CodememError> {
        let key = match self.id_to_key.get(id) {
            Some(&key) => key,
            None => return Ok(false),
        };

        // Remove from the index first; only drop the mapping once that succeeded
        // so the two maps and the index never disagree after an error.
        self.index
            .remove(key)
            .map_err(|e| CodememError::Vector(e.to_string()))?;
        self.id_to_key.remove(id);
        self.key_to_id.remove(&key);
        Ok(true)
    }

    /// Persists the index to `path` and the id mapping, as JSON, to the same
    /// path with its extension replaced by `idmap`.
    ///
    /// # Errors
    /// Returns `CodememError::Vector` when `path` is not valid UTF-8, when the
    /// underlying index fails to save, or when the mapping cannot be serialized;
    /// I/O errors from writing the mapping file are propagated.
    fn save(&self, path: &Path) -> Result<(), CodememError> {
        // Refuse non-UTF-8 paths instead of silently writing to a fallback file.
        let index_path = path.to_str().ok_or_else(|| {
            CodememError::Vector(format!(
                "Index path is not valid UTF-8: {}",
                path.display()
            ))
        })?;
        self.index
            .save(index_path)
            .map_err(|e| CodememError::Vector(e.to_string()))?;

        let map_path = path.with_extension("idmap");
        let map_data = serde_json::to_string(&IdMapping {
            id_to_key: &self.id_to_key,
            next_key: self.next_key,
        })
        .map_err(|e| CodememError::Vector(e.to_string()))?;
        std::fs::write(map_path, map_data)?;
        Ok(())
    }

    /// Loads the index from `path` and, if present, the id mapping from the
    /// sibling file with the `idmap` extension.
    ///
    /// When the mapping file exists, `id_to_key`, `key_to_id`, and `next_key`
    /// are replaced from it.
    ///
    /// # Errors
    /// Returns `CodememError::Vector` when `path` is not valid UTF-8, when the
    /// underlying index fails to load, or when the mapping file is not valid
    /// JSON for the expected shape; I/O errors from reading it are propagated.
    fn load(&mut self, path: &Path) -> Result<(), CodememError> {
        // Refuse non-UTF-8 paths instead of silently reading a fallback file.
        let index_path = path.to_str().ok_or_else(|| {
            CodememError::Vector(format!(
                "Index path is not valid UTF-8: {}",
                path.display()
            ))
        })?;
        self.index
            .load(index_path)
            .map_err(|e| CodememError::Vector(e.to_string()))?;

        let map_path = path.with_extension("idmap");
        // NOTE(review): when the mapping file is missing, the previous mappings
        // and `next_key` are kept as-is — confirm that is intended for callers
        // that load into an index that already holds entries.
        if map_path.exists() {
            let map_data = std::fs::read_to_string(map_path)?;
            let mapping: IdMappingOwned =
                serde_json::from_str(&map_data).map_err(|e| CodememError::Vector(e.to_string()))?;
            self.id_to_key = mapping.id_to_key;
            // Only the forward map is persisted; rebuild the reverse map from it.
            self.key_to_id = self
                .id_to_key
                .iter()
                .map(|(id, key)| (*key, id.clone()))
                .collect();
            self.next_key = mapping.next_key;
        }
        Ok(())
    }

    /// Reports the vector count, configured dimensions, the metric's `Debug`
    /// name, and the memory usage reported by the underlying index.
    fn stats(&self) -> VectorStats {
        VectorStats {
            count: self.len(),
            dimensions: self.config.dimensions,
            metric: format!("{:?}", self.config.metric),
            memory_bytes: self.index.memory_usage(),
        }
    }
}
use serde::{Deserialize, Serialize};
/// Borrowed view of the id mapping, serialized to JSON when the index is saved.
#[derive(Serialize)]
struct IdMapping<'a> {
/// String id -> numeric key, borrowed from the index being saved.
id_to_key: &'a HashMap<String, u64>,
/// Key counter value at save time.
next_key: u64,
}
/// Owned counterpart of `IdMapping`, deserialized from JSON when the index is loaded.
#[derive(Deserialize)]
struct IdMappingOwned {
/// String id -> numeric key as read from the mapping file.
id_to_key: HashMap<String, u64>,
/// Key counter value as read from the mapping file.
next_key: u64,
}
// Unit tests are compiled only for test builds and live in `tests/lib_tests.rs`.
#[cfg(test)]
#[path = "tests/lib_tests.rs"]
mod tests;