use serde::{Deserialize, Serialize};
use std::collections::{HashMap, HashSet};
use thiserror::Error;
pub use hnsw_rs::prelude::{DistCosine, Hnsw, Neighbour};
/// Nearest-neighbour index over `f32` embeddings, backed by an `hnsw_rs`
/// graph that uses `DistCosine`.
///
/// Callers address entries by string node id. The graph is keyed by numeric
/// ids handed out from `next_id`; `id_map` and `reverse_map` translate between
/// the two. `remove` does not touch the graph: it records the internal id in
/// `deleted`, and `search` skips any hit whose id is in that set.
pub struct HNSWIndex {
/// Underlying HNSW graph; vectors are inserted under their internal id.
hnsw: Hnsw<'static, f32, DistCosine>,
/// Internal id -> caller-supplied node id, for live entries only.
id_map: HashMap<usize, String>,
/// Caller-supplied node id -> internal id, for live entries only.
reverse_map: HashMap<String, usize>,
/// Internal ids that were removed but may still be returned by the graph.
deleted: HashSet<usize>,
/// Next internal id to hand out; only reset by `clear`.
next_id: usize,
/// Required length of every inserted vector and every query.
dimension: usize,
/// Parameters the index was created with.
params: HNSWParams,
/// Number of live entries (incremented by `insert`, decremented by `remove`).
count: usize,
/// Copy of `params.max_elements` taken at construction; passed to
/// `Hnsw::new` whenever the graph is recreated.
max_elements: usize,
}
/// Tuning parameters for an `HNSWIndex`.
///
/// `m`, `max_elements`, `max_layer` and `ef_construction` are forwarded to
/// `Hnsw::new` in that order; `ef_search` is used at query time.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct HNSWParams {
/// First argument to `Hnsw::new` (presumably the maximum number of
/// connections per node; confirm against the hnsw_rs docs). Must be > 0.
pub m: usize,
/// Construction-time candidate list size passed to `Hnsw::new`.
/// `validate` requires it to be >= `m`.
pub ef_construction: usize,
/// Lower bound for the `ef` value used by `HNSWIndex::search`. Must be > 0.
pub ef_search: usize,
/// Element-count argument passed to `Hnsw::new`. Must be > 0.
pub max_elements: usize,
/// Layer-count argument passed to `Hnsw::new`. Must be > 0.
pub max_layer: usize,
/// Not read by any code visible in this file.
/// NOTE(review): confirm whether anything consumes this flag.
pub quantized: bool,
}
impl Default for HNSWParams {
    /// Baseline configuration: `m = 16`, `ef_construction = 200`,
    /// `ef_search = 50`, room for 100 000 elements, 16 layers, and
    /// `quantized` switched off.
    fn default() -> Self {
        let m = 16;
        let ef_construction = 200;
        let ef_search = 50;
        let max_elements = 100_000;
        let max_layer = 16;
        let quantized = false;

        Self {
            m,
            ef_construction,
            ef_search,
            max_elements,
            max_layer,
            quantized,
        }
    }
}
impl HNSWParams {
    /// Starts a builder chain from the default configuration.
    #[must_use]
    pub fn new() -> Self {
        Default::default()
    }

    /// Returns the parameters with `m` replaced.
    #[must_use]
    pub fn with_m(self, m: usize) -> Self {
        Self { m, ..self }
    }

    /// Returns the parameters with `ef_construction` replaced by `ef`.
    #[must_use]
    pub fn with_ef_construction(self, ef: usize) -> Self {
        Self {
            ef_construction: ef,
            ..self
        }
    }

    /// Returns the parameters with `ef_search` replaced by `ef`.
    #[must_use]
    pub fn with_ef_search(self, ef: usize) -> Self {
        Self {
            ef_search: ef,
            ..self
        }
    }

    /// Returns the parameters with `max_elements` replaced by `max`.
    #[must_use]
    pub fn with_max_elements(self, max: usize) -> Self {
        Self {
            max_elements: max,
            ..self
        }
    }

    /// Returns the parameters with `max_layer` replaced by `max`.
    #[must_use]
    pub fn with_max_layer(self, max: usize) -> Self {
        Self {
            max_layer: max,
            ..self
        }
    }

    /// Checks the parameters for internal consistency.
    ///
    /// # Errors
    ///
    /// Returns `IndexError::InvalidParameter` describing the first rule that
    /// is violated, checked in this order: `m > 0`, `ef_construction >= m`,
    /// `ef_search > 0`, `max_elements > 0`, `max_layer > 0`.
    pub fn validate(&self) -> Result<(), IndexError> {
        // (rule is violated, message reported for it), in reporting order.
        let rules = [
            (self.m == 0, "m must be > 0"),
            (self.ef_construction < self.m, "ef_construction must be >= m"),
            (self.ef_search == 0, "ef_search must be > 0"),
            (self.max_elements == 0, "max_elements must be > 0"),
            (self.max_layer == 0, "max_layer must be > 0"),
        ];

        match rules.iter().find(|(violated, _)| *violated) {
            Some((_, reason)) => Err(IndexError::InvalidParameter((*reason).to_string())),
            None => Ok(()),
        }
    }
}
/// Summary returned by `HNSWIndex::rebuild`.
#[derive(Debug, Clone)]
pub struct RebuildStats {
/// Number of live entries when the rebuild was requested.
pub active: usize,
/// Number of tombstoned entries when the rebuild was requested.
pub deleted: usize,
/// Elapsed wall-clock time in milliseconds; `0` when there was nothing to do.
pub duration_ms: u64,
}
impl HNSWIndex {
/// Creates an empty index for vectors of length `dimension`, using
/// `HNSWParams::default()`.
pub fn new(dimension: usize) -> Self {
Self::with_params(dimension, HNSWParams::default())
}
pub fn with_params(dimension: usize, params: HNSWParams) -> Self {
params.validate().unwrap_or_else(|e| {
tracing::warn!("Invalid HNSW params, using defaults: {:?}", e);
});
let hnsw = Hnsw::new(
params.m,
params.max_elements,
params.max_layer,
params.ef_construction,
DistCosine {},
);
let max_elements = params.max_elements;
Self {
hnsw,
id_map: HashMap::new(),
reverse_map: HashMap::new(),
deleted: HashSet::new(),
next_id: 0,
dimension,
params,
count: 0,
max_elements,
}
}
/// Adds `embedding` to the index under `node_id`.
///
/// # Errors
///
/// * `IndexError::DimensionMismatch` if `embedding.len()` differs from the
///   index dimension.
/// * `IndexError::NodeExists` if `node_id` is already a live entry.
pub fn insert(&mut self, node_id: String, embedding: Vec<f32>) -> Result<(), IndexError> {
    let got = embedding.len();
    if got != self.dimension {
        return Err(IndexError::DimensionMismatch {
            expected: self.dimension,
            got,
        });
    }
    if self.reverse_map.contains_key(&node_id) {
        return Err(IndexError::NodeExists(node_id));
    }

    // Reserve the next internal id before touching the graph.
    let slot = self.next_id;
    self.next_id = slot + 1;

    self.hnsw.insert((&embedding, slot));

    // Register the id translation in both directions.
    self.id_map.insert(slot, node_id.clone());
    self.reverse_map.insert(node_id, slot);
    self.count += 1;
    Ok(())
}
/// Inserts every `(node_id, embedding)` pair in order and returns how many
/// were accepted.
///
/// Pairs rejected by `insert` (wrong dimension, duplicate id) are skipped
/// without stopping the batch.
pub fn insert_batch(&mut self, vectors: impl IntoIterator<Item = (String, Vec<f32>)>) -> usize {
    vectors
        .into_iter()
        .map(|(node_id, embedding)| self.insert(node_id, embedding))
        .filter(Result::is_ok)
        .count()
}
/// Returns up to `top_k` live entries closest to `query`, best match first.
///
/// Each result is `(node_id, similarity)`, where similarity is
/// `1.0 - distance` as reported by the graph, clamped to be non-negative.
///
/// An empty vector is returned when `top_k` is zero, when the index has no
/// live entries, or when `query.len()` differs from the index dimension.
///
/// Removed entries are only tombstoned and can still be returned by the
/// graph. To keep them from crowding out live hits, the graph is asked for
/// `top_k` plus one neighbour per tombstone and the filtered list is then cut
/// back to `top_k`. The extra work grows with the number of tombstones.
pub fn search(&self, query: &[f32], top_k: usize) -> Vec<(String, f32)> {
    if top_k == 0 || self.count == 0 || query.len() != self.dimension {
        return Vec::new();
    }

    let tombstones = self.deleted.len();
    // Over-fetch to compensate for tombstones, but never ask for more
    // neighbours than there are entries (live + tombstoned).
    let fetch = top_k
        .saturating_add(tombstones)
        .min(self.count.saturating_add(tombstones));
    let ef_search = self.params.ef_search.max(fetch);

    let mut output: Vec<(String, f32)> = self
        .hnsw
        .search(query, fetch, ef_search)
        .into_iter()
        .filter(|neighbour| !self.deleted.contains(&neighbour.d_id))
        .filter_map(|neighbour| {
            let node_id = self.id_map.get(&neighbour.d_id)?;
            let similarity = 1.0 - neighbour.distance;
            Some((node_id.clone(), similarity.max(0.0)))
        })
        .collect();

    output.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
    output.truncate(top_k);
    output
}
/// Number of live entries (inserted and not removed).
#[must_use]
pub fn len(&self) -> usize {
self.count
}
/// Returns `true` when the index has no live entries.
#[must_use]
pub fn is_empty(&self) -> bool {
self.count == 0
}
/// Vector length this index was created for.
#[must_use]
pub fn dimension(&self) -> usize {
self.dimension
}
/// Removes `node_id` from the index, returning whether it was present.
///
/// The vector is not taken out of the graph; its internal id is recorded as
/// a tombstone so that `search` skips it.
pub fn remove(&mut self, node_id: &str) -> bool {
    match self.reverse_map.remove(node_id) {
        None => false,
        Some(slot) => {
            self.id_map.remove(&slot);
            self.deleted.insert(slot);
            self.count -= 1;
            true
        }
    }
}
/// Empties the index: drops the graph, all id mappings and all tombstones,
/// and restarts internal id numbering from zero. Dimension and parameters
/// are kept.
pub fn clear(&mut self) {
    // Reset the bookkeeping first; it is independent of the graph.
    self.count = 0;
    self.next_id = 0;
    self.deleted.clear();
    self.reverse_map.clear();
    self.id_map.clear();

    // Replace the graph with a fresh one built from the same settings.
    let empty_graph = Hnsw::new(
        self.params.m,
        self.max_elements,
        self.params.max_layer,
        self.params.ef_construction,
        DistCosine {},
    );
    self.hnsw = empty_graph;
}
pub fn rebuild(&mut self) -> Result<RebuildStats, IndexError> {
let active_count = self.count;
let deleted_count = self.deleted.len();
if deleted_count == 0 {
return Ok(RebuildStats {
active: active_count,
deleted: deleted_count,
duration_ms: 0,
});
}
let start = std::time::Instant::now();
tracing::warn!(
"HNSW rebuild called with {} deleted nodes. Note: Full rebuild requires external embedding store.",
deleted_count
);
self.hnsw = Hnsw::new(
self.params.m,
self.max_elements,
self.params.max_layer,
self.params.ef_construction,
DistCosine {},
);
let stats = RebuildStats {
active: active_count,
deleted: deleted_count,
duration_ms: start.elapsed().as_millis() as u64,
};
tracing::info!(
"HNSW rebuild complete: {} active nodes, {} deleted nodes removed in {}ms",
stats.active,
stats.deleted,
stats.duration_ms
);
self.deleted.clear();
Ok(stats)
}
/// Placeholder lookup: always returns `None`, whatever `_node_id` is.
///
/// NOTE(review): the index keeps no separate copy of the embeddings outside
/// the graph, so there is nothing to hand out a `&Vec<f32>` to here.
pub fn get(&self, _node_id: &str) -> Option<&Vec<f32>> {
None
}
/// Parameters stored in this index.
#[must_use]
pub fn params(&self) -> &HNSWParams {
&self.params
}
/// Rough estimate, in bytes, of the memory held by the index.
///
/// The estimate is the sum of:
/// * vector storage: one `f32` per dimension for every point resident in the
///   graph. Removed points are only tombstoned, so they are counted too;
/// * graph links: assumes about `2 * m` pointer-sized links per resident
///   point (an approximation, not measured from the graph);
/// * the two id maps and the tombstone set, counted per entry.
///
/// Heap bytes of the node-id strings and allocator or hash-table overhead
/// are not included. The previous formula multiplied the link term by
/// `dimension` and ignored tombstoned points.
#[must_use]
pub fn estimated_memory_bytes(&self) -> usize {
    let id_size = std::mem::size_of::<usize>();
    let resident = self.count + self.deleted.len();

    let vector_bytes = resident * self.dimension * std::mem::size_of::<f32>();
    let link_bytes = resident * self.params.m * 2 * id_size;
    let map_entry_bytes = id_size + std::mem::size_of::<String>();
    let map_bytes = (self.id_map.len() + self.reverse_map.len()) * map_entry_bytes;
    let tombstone_bytes = self.deleted.len() * id_size;

    vector_bytes + link_bytes + map_bytes + tombstone_bytes
}
}
impl Default for HNSWIndex {
    /// Creates an empty index for 768-dimensional vectors with default
    /// parameters.
    fn default() -> Self {
        const DEFAULT_DIMENSION: usize = 768;
        Self::new(DEFAULT_DIMENSION)
    }
}
/// Errors reported by `HNSWIndex` and `HNSWParams`.
#[derive(Debug, Error)]
pub enum IndexError {
/// A vector's length differs from the index dimension (returned by
/// `HNSWIndex::insert`).
#[error("Dimension mismatch: expected {expected}, got {got}")]
DimensionMismatch {
/// Dimension the index was created with.
expected: usize,
/// Length of the vector that was supplied.
got: usize,
},
/// The node id is already a live entry (returned by `HNSWIndex::insert`).
#[error("Node {0} already exists")]
NodeExists(String),
/// Not constructed by the code visible in this file.
#[error("Node {0} not found")]
NodeNotFound(String),
/// A parameter failed `HNSWParams::validate`; the payload names the rule.
#[error("Invalid parameter: {0}")]
InvalidParameter(String),
/// Not constructed by the code visible in this file.
#[error("Insertion failed: {0}")]
InsertionFailed(String),
/// Not constructed by the code visible in this file.
#[error("Serialization failed: {0}")]
SerializationFailed(String),
/// Not constructed by the code visible in this file.
#[error("Deserialization failed: {0}")]
DeserializationFailed(String),
}
// Unit tests for `HNSWIndex` and `HNSWParams`.
#[cfg(test)]
mod tests {
use super::*;
// A new index reports its dimension and starts empty.
#[test]
fn test_hnsw_index_creation() {
let index = HNSWIndex::new(3);
assert_eq!(index.dimension(), 3);
assert_eq!(index.len(), 0);
assert!(index.is_empty());
}
// A vector of the right length is accepted and counted.
#[test]
fn test_hnsw_index_insert() {
let mut index = HNSWIndex::new(3);
let result = index.insert("test".to_string(), vec![0.1, 0.2, 0.3]);
assert!(result.is_ok());
assert_eq!(index.len(), 1);
assert!(!index.is_empty());
}
// A vector of the wrong length is rejected with `DimensionMismatch`.
#[test]
fn test_hnsw_index_dimension_mismatch() {
let mut index = HNSWIndex::new(3);
let result = index.insert("test".to_string(), vec![0.1, 0.2]);
assert!(result.is_err());
assert!(matches!(
result.unwrap_err(),
IndexError::DimensionMismatch { .. }
));
}
// Re-using a live node id is rejected with `NodeExists`.
#[test]
fn test_hnsw_index_duplicate_insert() {
let mut index = HNSWIndex::new(3);
index
.insert("test".to_string(), vec![0.1, 0.2, 0.3])
.unwrap();
let result = index.insert("test".to_string(), vec![0.4, 0.5, 0.6]);
assert!(result.is_err());
assert!(matches!(result.unwrap_err(), IndexError::NodeExists(_)));
}
// The entry identical to the query ranks first with similarity above 0.9.
#[test]
fn test_hnsw_search() {
let mut index = HNSWIndex::new(3);
index.insert("a".to_string(), vec![1.0, 0.0, 0.0]).unwrap();
index.insert("b".to_string(), vec![0.0, 1.0, 0.0]).unwrap();
index.insert("c".to_string(), vec![0.9, 0.1, 0.0]).unwrap();
let query = vec![1.0, 0.0, 0.0];
let results = index.search(&query, 2);
assert!(!results.is_empty());
assert_eq!(results[0].0, "a");
assert!(results[0].1 > 0.9);
}
// Searching an index with no entries yields no results.
#[test]
fn test_hnsw_search_empty_index() {
let index = HNSWIndex::new(3);
let query = vec![0.1, 0.2, 0.3];
let results = index.search(&query, 10);
assert!(results.is_empty());
}
// `insert_batch` reports how many pairs were accepted.
#[test]
fn test_hnsw_batch_insert() {
let mut index = HNSWIndex::new(3);
let vectors = vec![
("a".to_string(), vec![1.0, 0.0, 0.0]),
("b".to_string(), vec![0.0, 1.0, 0.0]),
("c".to_string(), vec![0.0, 0.0, 1.0]),
];
let inserted = index.insert_batch(vectors);
assert_eq!(inserted, 3);
assert_eq!(index.len(), 3);
}
// Removing a live id succeeds and lowers the count; an unknown id returns false.
#[test]
fn test_hnsw_remove() {
let mut index = HNSWIndex::new(3);
index
.insert("test".to_string(), vec![0.1, 0.2, 0.3])
.unwrap();
assert_eq!(index.len(), 1);
assert!(index.remove("test"));
assert_eq!(index.len(), 0);
assert!(!index.remove("nonexistent"));
}
// `clear` leaves the index empty.
#[test]
fn test_hnsw_clear() {
let mut index = HNSWIndex::new(3);
index.insert("a".to_string(), vec![1.0, 0.0, 0.0]).unwrap();
index.insert("b".to_string(), vec![0.0, 1.0, 0.0]).unwrap();
assert_eq!(index.len(), 2);
index.clear();
assert_eq!(index.len(), 0);
assert!(index.is_empty());
}
// Spot-checks three of the default parameter values.
#[test]
fn test_hnsw_params_default() {
let params = HNSWParams::default();
assert_eq!(params.m, 16);
assert_eq!(params.ef_construction, 200);
assert_eq!(params.ef_search, 50);
}
// Builder methods override the corresponding fields.
#[test]
fn test_hnsw_params_builder() {
let params = HNSWParams::new()
.with_m(32)
.with_ef_construction(400)
.with_ef_search(100);
assert_eq!(params.m, 32);
assert_eq!(params.ef_construction, 400);
assert_eq!(params.ef_search, 100);
}
// Defaults validate; m == 0, ef_construction < m and ef_search == 0 do not.
#[test]
fn test_hnsw_params_validation() {
let params = HNSWParams::default();
assert!(params.validate().is_ok());
let params = HNSWParams {
m: 0,
..Default::default()
};
assert!(params.validate().is_err());
let params = HNSWParams {
m: 100,
ef_construction: 50,
..Default::default()
};
assert!(params.validate().is_err());
let params = HNSWParams {
ef_search: 0,
..Default::default()
};
assert!(params.validate().is_err());
}
// Valid custom parameters are stored unchanged by `with_params`.
#[test]
fn test_hnsw_custom_params() {
let params = HNSWParams {
m: 8,
ef_construction: 100,
ef_search: 25,
max_elements: 50000,
max_layer: 12,
..Default::default()
};
let index = HNSWIndex::with_params(3, params);
assert_eq!(index.params().m, 8);
assert_eq!(index.params().ef_construction, 100);
assert_eq!(index.params().ef_search, 25);
assert_eq!(index.params().max_elements, 50000);
assert_eq!(index.params().max_layer, 12);
}
// Inserts 1000 random 128-dimensional vectors and expects a top-10 query to
// return exactly 10 results. Input comes from `rand::random`, so it differs
// on every run.
#[test]
fn test_hnsw_large_scale() {
let mut index = HNSWIndex::new(128);
for i in 0..1000 {
let vector: Vec<f32> = (0..128).map(|_| rand::random::<f32>()).collect();
index.insert(format!("node_{}", i), vector).unwrap();
}
assert_eq!(index.len(), 1000);
let query: Vec<f32> = (0..128).map(|_| rand::random::<f32>()).collect();
let results = index.search(&query, 10);
assert_eq!(results.len(), 10);
}
// `get` is a stub and returns `None` even for an inserted id.
#[test]
fn test_hnsw_get_returns_none() {
let mut index = HNSWIndex::new(3);
index
.insert("test".to_string(), vec![0.1, 0.2, 0.3])
.unwrap();
assert!(index.get("test").is_none());
}
// Nothing is removed first, so this only covers `rebuild`'s early return for
// an index without tombstones.
#[test]
fn test_hnsw_rebuild() {
let mut index = HNSWIndex::new(3);
index
.insert("test".to_string(), vec![0.1, 0.2, 0.3])
.unwrap();
assert!(index.rebuild().is_ok());
}
// The estimate for one 768-dimensional vector is at least its raw f32 size.
#[test]
fn test_hnsw_estimated_memory() {
let mut index = HNSWIndex::new(768);
index.insert("test".to_string(), vec![0.0; 768]).unwrap();
let memory = index.estimated_memory_bytes();
assert!(memory > 0);
let min_expected = 768 * 4; assert!(memory >= min_expected);
}
// A non-default dimension is reported and accepted on insert.
#[test]
fn test_hnsw_with_custom_dimension() {
let index = HNSWIndex::new(256);
assert_eq!(index.dimension(), 256);
let mut index = HNSWIndex::new(256);
index.insert("test".to_string(), vec![0.0; 256]).unwrap();
assert_eq!(index.len(), 1);
}
}