mod build;
mod persist;
mod safety;
mod search;
pub use persist::{verify_hnsw_checksums, HNSW_ALL_EXTENSIONS};
use std::cell::UnsafeCell;
use hnsw_rs::anndists::dist::distances::DistCosine;
use hnsw_rs::api::AnnT;
use hnsw_rs::hnsw::Hnsw;
use hnsw_rs::hnswio::HnswIo;
use self_cell::self_cell;
use thiserror::Error;
use crate::embedder::Embedding;
use crate::index::{IndexResult, VectorIndex};
/// Maximum number of graph layers for the HNSW index.
// NOTE(review): presumably passed as hnsw_rs's `max_layer` argument by the build/load
// code (not visible in this file) — confirm there.
pub(crate) const MAX_LAYER: usize = 16;
/// Default `M` (maximum connections per node) returned by `max_nb_connection()`
/// when `CQS_HNSW_M` is unset or invalid.
const DEFAULT_M: usize = 24;
/// Default `ef_construction` returned by `ef_construction()` when
/// `CQS_HNSW_EF_CONSTRUCTION` is unset or invalid.
const DEFAULT_EF_CONSTRUCTION: usize = 200;
/// Default search-time `ef` returned by `ef_search()` when `CQS_HNSW_EF_SEARCH`
/// is unset or invalid.
const DEFAULT_EF_SEARCH: usize = 100;
/// Reads a strictly positive `usize` override from the environment variable `var`.
///
/// Returns `fallback` when `var` is unset (or not valid Unicode). When it is set
/// but does not parse as a positive integer after trimming surrounding
/// whitespace, a warning is logged and `fallback` is returned. This includes `0`,
/// which is not a meaningful HNSW `M`/`ef` setting. The warning replaces the
/// previous behavior of silently ignoring the bad setting.
fn env_positive_usize(var: &str, fallback: usize) -> usize {
    let raw = match std::env::var(var) {
        Ok(raw) => raw,
        Err(_) => return fallback,
    };
    match raw.trim().parse::<usize>() {
        Ok(value) if value > 0 => value,
        _ => {
            tracing::warn!(
                var,
                value = %raw,
                fallback,
                "ignoring invalid HNSW env override (expected a positive integer)"
            );
            fallback
        }
    }
}

/// HNSW `M` (maximum connections per node).
///
/// Returns `CQS_HNSW_M` when it is set to a positive integer, otherwise `DEFAULT_M`.
/// Logs at info level whenever the effective value differs from the default.
pub(crate) fn max_nb_connection() -> usize {
    let m = env_positive_usize("CQS_HNSW_M", DEFAULT_M);
    if m != DEFAULT_M {
        tracing::info!(m, "CQS_HNSW_M override active");
    }
    m
}

/// HNSW `ef_construction` (candidate-list size used while building the graph).
///
/// Returns `CQS_HNSW_EF_CONSTRUCTION` when it is set to a positive integer,
/// otherwise `DEFAULT_EF_CONSTRUCTION`. Logs at info level when overridden.
pub(crate) fn ef_construction() -> usize {
    let ef = env_positive_usize("CQS_HNSW_EF_CONSTRUCTION", DEFAULT_EF_CONSTRUCTION);
    if ef != DEFAULT_EF_CONSTRUCTION {
        tracing::info!(ef, "CQS_HNSW_EF_CONSTRUCTION override active");
    }
    ef
}

/// HNSW search-time `ef`.
///
/// Returns `CQS_HNSW_EF_SEARCH` when it is set to a positive integer, otherwise
/// `DEFAULT_EF_SEARCH`. Logs at info level when overridden.
pub(crate) fn ef_search() -> usize {
    let ef = env_positive_usize("CQS_HNSW_EF_SEARCH", DEFAULT_EF_SEARCH);
    if ef != DEFAULT_EF_SEARCH {
        tracing::info!(ef, "CQS_HNSW_EF_SEARCH override active");
    }
    ef
}
/// Errors produced by the HNSW index: building, inserting, and persisting/loading.
#[derive(Error, Debug)]
pub enum HnswError {
/// Underlying I/O failure; `?` converts a `std::io::Error` into this variant (`#[from]`).
#[error("IO error: {0}")]
Io(#[from] std::io::Error),
/// No index exists at the given location.
// NOTE(review): the payload is presumably a path/prefix — confirm at the call sites.
#[error("HNSW index not found at {0}")]
NotFound(String),
/// A vector's length differs from the index dimension (returned by `insert_batch`).
#[error("Dimension mismatch: expected {expected}, got {actual}")]
DimensionMismatch { expected: usize, actual: usize },
/// Index construction failed, e.g. no embeddings, a per-embedding dimension
/// mismatch, or a size overflow in `prepare_index_data`.
#[error("Build error: {0}")]
Build(String),
/// Any other HNSW failure, such as inserting into an index loaded from disk.
#[error("HNSW error: {0}")]
Internal(String),
/// A persisted index file's checksum did not match the expected value.
// NOTE(review): presumably raised by `verify_hnsw_checksums` (persist module) — confirm.
#[error(
"Checksum mismatch for {file}: expected {expected}, got {actual}. Index may be corrupted."
)]
ChecksumMismatch {
file: String,
expected: String,
actual: String,
},
}
/// The hnsw_rs graph type used throughout this module: `f32` vectors compared with
/// cosine distance (`DistCosine`). `'a` bounds any data the graph borrows.
type HnswGraph<'a> = Hnsw<'a, f32, DistCosine>;
/// Owner half of [`LoadedHnsw`]: the hnsw_rs `HnswIo` loader wrapped in an
/// `UnsafeCell`. This is presumably so the loader can be borrowed mutably while
/// the graph is built, even though `self_cell` passes the owner to its builder
/// by shared reference.
// NOTE(review): the code that takes `&mut HnswIo` out of the cell is not in this file
// (presumably the load path in the persist module) — confirm it does so only while
// constructing the dependent graph, and never afterwards.
pub(crate) struct HnswIoCell(pub(crate) UnsafeCell<HnswIo>);
// SAFETY (assumed — verify): `UnsafeCell` makes this wrapper `!Sync`, so these impls are a
// manual promise that cross-thread use is sound. That requires every mutable access through
// the cell to be exclusive and finished before the cell is sent or shared, and `HnswIo`'s own
// contents to be safe to move between threads. Neither is checked by the compiler here.
unsafe impl Send for HnswIoCell {}
unsafe impl Sync for HnswIoCell {}
// A graph loaded from disk, stored together with the boxed `HnswIo` loader it borrows from.
// `self_cell` ties the dependent `HnswGraph<'_>` to its owner so the pair moves as one value.
// `#[not_covariant]` declares that `HnswGraph<'_>` is not covariant in its lifetime. self_cell
// then only exposes the graph through closure-based access (`with_dependent`), not a plain
// borrow accessor.
self_cell!(
pub(crate) struct LoadedHnsw {
owner: Box<HnswIoCell>,
#[not_covariant]
dependent: HnswGraph,
}
);
// SAFETY (assumed — verify): these impls assert that the owner/graph pair may be moved to, and
// shared between, threads. That holds only if (a) the `HnswIo` owner is never mutated after the
// dependent graph is built, so any concurrent access to it is read-only, and (b) hnsw_rs's `Hnsw`
// is safe to use through shared references from several threads. Neither can be proven from this
// file. `send_sync_tests` only checks that the impls exist, not that they are sound.
unsafe impl Send for LoadedHnsw {}
unsafe impl Sync for LoadedHnsw {}
/// Approximate-nearest-neighbour index over cosine distance, backed by an hnsw_rs graph.
/// Graph point ids are translated back to chunk ids through `id_map`.
// Fields are dropped in declaration order, so `_lock_file` is released after `inner`.
// Keep it last if the lock is meant to outlive the graph.
pub struct HnswIndex {
/// The graph: built in memory (`Owned`) or loaded together with its loader (`Loaded`).
pub(crate) inner: HnswInner,
/// Chunk id per graph point. `insert_batch` keeps data id `i` ↔ `id_map[i]`;
/// `len()`/`is_empty()` are derived from this vector.
pub(crate) id_map: Vec<String>,
/// `ef` used at search time; changed via `set_ef_search`.
pub(crate) ef_search: usize,
/// Expected vector length; `insert_batch` rejects vectors of any other length.
pub(crate) dim: usize,
/// Held for the lifetime of the index and never read (hence the leading underscore).
// NOTE(review): presumably an advisory lock on the on-disk index files — confirm in persist.
pub(crate) _lock_file: Option<std::fs::File>,
}
/// Storage for the HNSW graph, depending on how the index was created.
pub(crate) enum HnswInner {
/// Graph owned directly with no borrowed data (`'static`); the only variant that
/// `HnswIndex::insert_batch` accepts.
Owned(Hnsw<'static, f32, DistCosine>),
/// Graph that borrows from its `HnswIo` loader; the two live together in a `self_cell`.
Loaded(LoadedHnsw),
}
impl HnswInner {
pub(crate) fn with_hnsw<R>(&self, f: impl FnOnce(&Hnsw<'_, f32, DistCosine>) -> R) -> R {
match self {
HnswInner::Owned(hnsw) => f(hnsw),
HnswInner::Loaded(loaded) => loaded.with_dependent(|_, hnsw| f(hnsw)),
}
}
}
impl HnswIndex {
    /// Sets the `ef` value used by subsequent searches.
    // NOTE(review): no validation here; how `0` or values below `k` behave is decided by
    // the search code (search module) — confirm.
    pub fn set_ef_search(&mut self, ef: usize) {
        self.ef_search = ef;
    }

    /// Number of vectors in the index (the length of the chunk-id map).
    pub fn len(&self) -> usize {
        self.id_map.len()
    }

    /// Returns `true` when the index holds no vectors.
    pub fn is_empty(&self) -> bool {
        self.id_map.is_empty()
    }

    /// Inserts a batch of `(chunk_id, vector)` pairs into an in-memory (owned) index.
    ///
    /// New points get graph data ids `len()..len() + items.len()` in input order.
    /// Their chunk ids are appended to the id map so search hits can be mapped back.
    ///
    /// Returns the number of items inserted. An empty batch returns `Ok(0)`
    /// immediately, even for a loaded index.
    ///
    /// # Errors
    ///
    /// * [`HnswError::Internal`] if the index was loaded from disk. A loaded graph
    ///   borrows from its loader and cannot be extended; rebuild instead.
    /// * [`HnswError::DimensionMismatch`] if any vector's length differs from the
    ///   index dimension. The whole batch is validated first, so on error neither
    ///   the graph nor the id map is modified.
    pub fn insert_batch(&mut self, items: &[(String, &[f32])]) -> Result<usize, HnswError> {
        let _span = tracing::info_span!("hnsw_insert_batch", count = items.len()).entered();
        if items.is_empty() {
            return Ok(0);
        }
        let dim = self.dim;
        let hnsw = match &mut self.inner {
            HnswInner::Owned(h) => h,
            HnswInner::Loaded(_) => {
                return Err(HnswError::Internal(
                    "Cannot incrementally insert into loaded index; rebuild required".into(),
                ));
            }
        };

        // Validate the entire batch before mutating anything.
        if let Some((_, emb)) = items.iter().find(|(_, emb)| emb.len() != dim) {
            return Err(HnswError::DimensionMismatch {
                expected: dim,
                actual: emb.len(),
            });
        }
        // Trace only after validation, so the log never claims inserts from a rejected batch.
        for (id, _) in items {
            tracing::trace!("Inserting {} into HNSW index", id);
        }

        let base_idx = self.id_map.len();
        // hnsw_rs's insert API takes `&Vec<f32>` per point, so the borrowed slices are copied.
        let owned_vecs: Vec<Vec<f32>> = items.iter().map(|(_, emb)| emb.to_vec()).collect();
        let data_for_insert: Vec<(&Vec<f32>, usize)> =
            owned_vecs.iter().zip(base_idx..).collect();
        hnsw.parallel_insert_data(&data_for_insert);

        // Record the chunk ids only once the graph insert has returned. Then a panic inside
        // hnsw_rs cannot leave `id_map` claiming points the graph never received.
        self.id_map.extend(items.iter().map(|(id, _)| id.clone()));

        tracing::info!(
            inserted = items.len(),
            total = self.id_map.len(),
            "HNSW batch insert complete"
        );
        Ok(items.len())
    }
}
/// Validates `embeddings` and flattens them into the layout used to build an index.
///
/// Returns `(ids, flat, count)`:
/// * `ids[i]` is the chunk id of the `i`-th embedding.
/// * `flat` holds the embeddings' values concatenated in input order.
/// * `count` is the number of embeddings.
///
/// # Errors
///
/// Returns [`HnswError::Build`] when `embeddings` is empty, when any embedding's
/// length differs from `expected_dim`, or when `count * expected_dim` overflows `usize`.
pub(crate) fn prepare_index_data(
    embeddings: Vec<(String, crate::Embedding)>,
    expected_dim: usize,
) -> Result<(Vec<String>, Vec<f32>, usize), HnswError> {
    if embeddings.is_empty() {
        return Err(HnswError::Build("No embeddings to index".into()));
    }
    let count = embeddings.len();

    // Report the first embedding whose length disagrees with the requested dimension.
    let mismatch = embeddings
        .iter()
        .find(|(_, emb)| emb.len() != expected_dim);
    if let Some((id, emb)) = mismatch {
        return Err(HnswError::Build(format!(
            "Embedding dimension mismatch for {}: got {}, expected {}",
            id,
            emb.len(),
            expected_dim
        )));
    }

    // Guard the flat buffer's capacity computation against overflow.
    let total_len = match count.checked_mul(expected_dim) {
        Some(len) => len,
        None => {
            return Err(HnswError::Build(
                "embedding count * dimension would overflow".into(),
            ))
        }
    };

    let mut ids = Vec::with_capacity(count);
    let mut flat = Vec::with_capacity(total_len);
    for (id, emb) in embeddings {
        ids.push(id);
        flat.extend(emb.into_inner());
    }
    Ok((ids, flat, count))
}
impl VectorIndex for HnswIndex {
fn search(&self, query: &Embedding, k: usize) -> Vec<IndexResult> {
self.search(query, k)
}
fn search_with_filter(
&self,
query: &Embedding,
k: usize,
filter: &dyn Fn(&str) -> bool,
) -> Vec<IndexResult> {
self.search_filtered(query, k, filter)
}
fn len(&self) -> usize {
self.len()
}
fn is_empty(&self) -> bool {
self.is_empty()
}
fn name(&self) -> &'static str {
"HNSW"
}
fn dim(&self) -> usize {
self.dim
}
}
/// Test helper that returns a deterministic embedding of length `crate::EMBEDDING_DIM`
/// derived from `seed`.
///
/// Before normalisation, component `i` is `sin(seed * 0.1 + i * 0.001)`. The vector
/// is then scaled to unit L2 norm, unless its norm is zero.
#[cfg(test)]
pub(crate) fn make_test_embedding(seed: u32) -> Embedding {
    let phase = seed as f32 * 0.1;
    let mut values: Vec<f32> = (0..crate::EMBEDDING_DIM)
        .map(|i| (phase + i as f32 * 0.001).sin())
        .collect();

    let norm = values.iter().map(|x| x * x).sum::<f32>().sqrt();
    if norm > 0.0 {
        values.iter_mut().for_each(|x| *x /= norm);
    }
    Embedding::new(values)
}
#[cfg(test)]
mod send_sync_tests {
    use super::*;

    /// Compiles only if `T` is both `Send` and `Sync`; instantiating it is the assertion.
    fn require_send_sync<T: Send + Sync>() {}

    /// Compile-time check that `HnswIndex` is `Send + Sync`.
    #[test]
    fn test_hnsw_index_is_send_sync() {
        require_send_sync::<HnswIndex>();
    }

    /// Compile-time check that the manual `Send`/`Sync` impls for `LoadedHnsw` exist.
    #[test]
    fn test_loaded_hnsw_is_send_sync() {
        require_send_sync::<LoadedHnsw>();
    }
}
#[cfg(test)]
mod insert_batch_tests {
    use super::*;
    use crate::EMBEDDING_DIM;

    /// Builds `("{prefix}{seed}", make_test_embedding(seed))` for every seed in `seeds`.
    fn labelled(prefix: &str, seeds: std::ops::Range<u32>) -> Vec<(String, Embedding)> {
        seeds
            .map(|seed| (format!("{}{}", prefix, seed), make_test_embedding(seed)))
            .collect()
    }

    /// Borrows a batch as the `(id, &[f32])` pairs that `insert_batch` takes.
    fn borrowed(batch: &[(String, Embedding)]) -> Vec<(String, &[f32])> {
        batch
            .iter()
            .map(|(id, emb)| (id.clone(), emb.as_slice()))
            .collect()
    }

    /// Builds a fresh owned index over `chunk_0 .. chunk_{n-1}`.
    fn owned_index(n: u32) -> HnswIndex {
        HnswIndex::build_with_dim(labelled("chunk_", 0..n), EMBEDDING_DIM).unwrap()
    }

    #[test]
    fn test_insert_batch_on_owned() {
        let mut index = owned_index(5);
        let initial_len = index.len();
        assert_eq!(initial_len, 5);

        let extra = labelled("chunk_", 5..8);
        let inserted = index.insert_batch(&borrowed(&extra)).unwrap();
        assert_eq!(inserted, 3);
        assert_eq!(index.len(), initial_len + 3);

        let results = index.search(&make_test_embedding(6), 3);
        assert!(!results.is_empty());
        assert!(results.iter().any(|r| r.id == "chunk_6"));
    }

    #[test]
    fn test_insert_batch_empty() {
        let mut index = owned_index(3);
        let initial_len = index.len();
        let inserted = index.insert_batch(&[]).unwrap();
        assert_eq!(inserted, 0);
        assert_eq!(index.len(), initial_len);
    }

    #[test]
    fn test_insert_batch_on_loaded_fails() {
        let index = owned_index(3);
        let dir = tempfile::tempdir().unwrap();
        index.save(dir.path(), "test").unwrap();
        let mut loaded = HnswIndex::load_with_dim(dir.path(), "test", EMBEDDING_DIM).unwrap();

        let probe = make_test_embedding(10);
        let result = loaded.insert_batch(&[("new_chunk".to_string(), probe.as_slice())]);
        assert!(result.is_err());
        let err = result.unwrap_err().to_string();
        assert!(
            err.contains("Cannot incrementally insert"),
            "Expected 'Cannot incrementally insert' error, got: {}",
            err
        );
    }

    #[test]
    fn test_insert_batch_dimension_mismatch() {
        let mut index = owned_index(3);
        let bad_vec = vec![1.0f32; 10];
        let result = index.insert_batch(&[("bad".to_string(), bad_vec.as_slice())]);
        assert!(result.is_err());
        match result.unwrap_err() {
            HnswError::DimensionMismatch { expected, actual } => {
                assert_eq!(expected, EMBEDDING_DIM);
                assert_eq!(actual, 10);
            }
            other => panic!("Expected DimensionMismatch, got: {}", other),
        }
    }

    #[test]
    fn test_insert_batch_dim_mismatch_leaves_id_map_untouched() {
        let mut index = owned_index(3);
        let before = index.len();
        let bad_vec = vec![1.0f32; 10];
        let _ = index.insert_batch(&[("bad".to_string(), bad_vec.as_slice())]);
        assert_eq!(
            index.len(),
            before,
            "id_map must not grow when insert fails validation"
        );
    }

    #[test]
    fn test_insert_batch_monotonic_base_idx() {
        let mut index = owned_index(3);
        let after_build = index.len();

        let batch_a = labelled("a", 3..5);
        index.insert_batch(&borrowed(&batch_a)).unwrap();
        let after_a = index.len();
        assert_eq!(after_a, after_build + 2);

        let batch_b = labelled("b", 5..8);
        index.insert_batch(&borrowed(&batch_b)).unwrap();
        let after_b = index.len();
        assert_eq!(after_b, after_a + 3);

        let hits = index.search(&make_test_embedding(4), 5);
        assert!(hits.iter().any(|n| n.id == "a4"), "a4 should be findable");
        let hits = index.search(&make_test_embedding(6), 5);
        assert!(hits.iter().any(|n| n.id == "b6"), "b6 should be findable");
    }
}
#[cfg(test)]
mod env_override_tests {
    use std::sync::Mutex;

    /// Serialises the tests in this module, which mutate process-wide environment variables.
    static ENV_MUTEX: Mutex<()> = Mutex::new(());

    /// Sets `var` to `value` (or removes it when `None`), asserts `read()` returns
    /// `expected`, then removes `var` again. `ENV_MUTEX` is held throughout.
    fn check_env(var: &str, value: Option<&str>, read: fn() -> usize, expected: usize) {
        let _guard = ENV_MUTEX.lock().unwrap();
        match value {
            Some(v) => std::env::set_var(var, v),
            None => std::env::remove_var(var),
        }
        assert_eq!(read(), expected);
        std::env::remove_var(var);
    }

    #[test]
    fn test_m_default() {
        check_env("CQS_HNSW_M", None, super::max_nb_connection, 24);
    }

    #[test]
    fn test_m_override() {
        check_env("CQS_HNSW_M", Some("32"), super::max_nb_connection, 32);
    }

    #[test]
    fn test_m_invalid_falls_back() {
        check_env("CQS_HNSW_M", Some("not_a_number"), super::max_nb_connection, 24);
    }

    #[test]
    fn test_ef_construction_default() {
        check_env("CQS_HNSW_EF_CONSTRUCTION", None, super::ef_construction, 200);
    }

    #[test]
    fn test_ef_construction_override() {
        check_env("CQS_HNSW_EF_CONSTRUCTION", Some("400"), super::ef_construction, 400);
    }

    #[test]
    fn test_ef_construction_invalid_falls_back() {
        check_env("CQS_HNSW_EF_CONSTRUCTION", Some("xyz"), super::ef_construction, 200);
    }

    #[test]
    fn test_ef_search_default() {
        check_env("CQS_HNSW_EF_SEARCH", None, super::ef_search, 100);
    }

    #[test]
    fn test_ef_search_override() {
        check_env("CQS_HNSW_EF_SEARCH", Some("250"), super::ef_search, 250);
    }

    #[test]
    fn test_ef_search_invalid_falls_back() {
        check_env("CQS_HNSW_EF_SEARCH", Some(""), super::ef_search, 100);
    }
}