use chrono::Utc;
use hnsw_rs::prelude::*;
use serde::{Deserialize, Serialize};
use std::collections::{HashMap, HashSet};
use std::path::{Path, PathBuf};
use std::sync::{Arc, OnceLock};
use tokio::sync::Mutex;
use crate::error::CoreError;
/// Cache key: `(project_hash, embedding dimension)`.
type AnnCacheKey = (String, usize);
/// One index shared between callers behind the async (`tokio`) mutex.
type SharedAnnIndex = Arc<Mutex<AnnIndex>>;
/// Map of already-loaded indexes, guarded by a blocking `std` mutex.
type AnnCache = std::sync::Mutex<HashMap<AnnCacheKey, SharedAnnIndex>>;
/// Basename handed to the HNSW dump/reload calls; the files probed on load
/// are `{HNSW_BASENAME}.hnsw.graph` and `{HNSW_BASENAME}.hnsw.data`.
const HNSW_BASENAME: &str = "hnsw";
/// File name of the JSON sidecar written next to the HNSW dump.
const META_FILENAME: &str = "hnsw.meta.json";
/// Sidecar format version; a sidecar with any other version is ignored on load.
const META_VERSION: u32 = 1;
/// First argument of `Hnsw::new` (max connections per node).
const MAX_NB_CONNECTION: usize = 16;
/// Fourth argument of `Hnsw::new` (construction-time `ef`).
const EF_CONSTRUCTION: usize = 200;
/// Third argument of `Hnsw::new` (max layer count).
const MAX_LAYER: usize = 16;
/// Lower bound for the `ef` argument passed to `Hnsw::search`.
const DEFAULT_EF_SEARCH: usize = 64;
/// Upper bound applied to the caller's `top_k` in `AnnIndex::search`.
const MAX_SEARCH_TOP_K: usize = 50;
/// Upper bound on the number of raw neighbours requested from the graph.
const MAX_RAW_SEARCH_CANDIDATES: usize = 150;
/// Raw candidates requested per wanted result before stale-slot scaling.
const BASE_RAW_SEARCH_MULTIPLIER: usize = 3;
/// Compaction is advised once stale slots reach NUMERATOR / DENOMINATOR
/// (3/10) of all slots; see `AnnIndex::needs_compaction`.
const MAX_STALE_SLOT_RATIO_NUMERATOR: usize = 3;
const MAX_STALE_SLOT_RATIO_DENOMINATOR: usize = 10;
/// JSON sidecar persisted alongside the HNSW dump files.
///
/// Written by `AnnIndex::write_meta` and validated by
/// `AnnIndex::load_or_empty` before the graph itself is trusted.
#[derive(Debug, Clone, Serialize, Deserialize)]
struct AnnMeta {
// Format version; must equal `META_VERSION` for the sidecar to be accepted.
version: u32,
// Embedding dimension the index was saved with.
dim: usize,
// `id_map.len()` at save time, saturated to `u32::MAX`.
size: u32,
// RFC 3339 timestamp taken when the sidecar was written.
built_at: String,
// Copy of `AnnIndex::schema_hash` at save time.
schema_hash: String,
// Slot index -> chunk id; the same id may appear in several slots.
id_map: Vec<String>,
// Chunk ids that were replaced via `upsert` or dropped via `remove`.
tombstones: Vec<String>,
}
/// In-memory ANN index over chunk embeddings, with bookkeeping that maps
/// HNSW slot numbers back to chunk ids.
///
/// Slots are only ever appended in this file: replacing or removing a chunk
/// leaves its old slot in the graph and records the id in `tombstones`, and
/// `search` filters such slots out.
pub struct AnnIndex {
// The HNSW graph; `None` for an empty index until the first `upsert`.
inner: Option<Hnsw<'static, f32, DistCosine>>,
// Slot index -> chunk id (append-only).
id_map: Vec<String>,
// Chunk id -> slot currently considered live for that id.
reverse: HashMap<String, usize>,
// Chunk ids that have at least one superseded or removed slot.
tombstones: HashSet<String>,
// Embedding dimension; `0` means "not decided yet" for an empty index.
dim: usize,
// Set by `build_from_chunks`, `upsert` and `remove`; cleared by `save`.
dirty: bool,
// Identifies the project; used to resolve the on-disk index directory.
project_hash: String,
// Short fingerprint produced by `compute_schema_hash` (empty when the index
// was not built from chunks).
schema_hash: String,
}
impl AnnIndex {
/// Loads the persisted index for `project_hash`, or returns an empty one.
///
/// The on-disk state is trusted only when all of the following hold; any
/// failure yields an empty index of dimension `dim` instead of an error:
/// - the sidecar and both HNSW dump files exist,
/// - the sidecar parses, carries `META_VERSION`, matches `dim`, and passes
///   `meta_sidecar_shape_is_valid`,
/// - the graph loads and its point count equals the sidecar's `id_map` length.
///
/// The `HnswIo` loader is deliberately leaked with `Box::leak` to obtain the
/// `'static` borrow required by `Hnsw<'static, ..>`; every call that gets past
/// sidecar validation leaks one loader.
///
/// NOTE(review): `reverse` is rebuilt from every `id_map` entry, while
/// `remove` deletes the id from `reverse` and the sidecar does not record
/// which tombstones were removals. After `save` + reload a removed id is
/// therefore mapped again and `search` accepts its newest slot.
///
/// # Errors
/// Every path in the current implementation returns `Ok`.
pub async fn load_or_empty(project_hash: &str, dim: usize) -> Result<Self, CoreError> {
    // Every rejected precondition degrades to the same empty index.
    let fallback =
        || -> Result<Self, CoreError> { Ok(Self::empty(project_hash.to_owned(), dim)) };

    let dir = crate::infra::db::project_index_dir(project_hash);
    let meta_path = dir.join(META_FILENAME);
    let graph_path = dir.join(format!("{HNSW_BASENAME}.hnsw.graph"));
    let data_path = dir.join(format!("{HNSW_BASENAME}.hnsw.data"));
    let all_present = [&meta_path, &graph_path, &data_path]
        .iter()
        .all(|path| path.exists());
    if !all_present {
        return fallback();
    }

    // Unreadable and unparsable sidecars are treated alike.
    let parsed = std::fs::read(&meta_path)
        .ok()
        .and_then(|bytes| serde_json::from_slice::<AnnMeta>(&bytes).ok());
    let Some(meta) = parsed else {
        return fallback();
    };
    let meta_usable =
        meta.version == META_VERSION && meta.dim == dim && meta_sidecar_shape_is_valid(&meta);
    if !meta_usable {
        return fallback();
    }

    let reloader: &'static mut HnswIo = Box::leak(Box::new(HnswIo::new(&dir, HNSW_BASENAME)));
    let hnsw = match reloader.load_hnsw::<f32, DistCosine>() {
        // The graph must hold exactly one point per sidecar slot.
        Ok(graph) if graph.get_nb_point() == meta.id_map.len() => graph,
        _ => return fallback(),
    };

    // Later slots win, so each id resolves to its newest slot.
    let reverse: HashMap<String, usize> = meta
        .id_map
        .iter()
        .enumerate()
        .map(|(slot, id)| (id.clone(), slot))
        .collect();
    let tombstones: HashSet<String> = meta.tombstones.into_iter().collect();

    Ok(Self {
        inner: Some(hnsw),
        id_map: meta.id_map,
        reverse,
        tombstones,
        dim,
        dirty: false,
        project_hash: project_hash.to_owned(),
        schema_hash: meta.schema_hash,
    })
}
/// Builds a fresh in-memory index from `(chunk_id, embedding)` pairs.
///
/// The dimension is taken from the first non-empty embedding; chunks whose
/// embedding has a different length are skipped. When no chunk has a
/// non-empty embedding an empty index of dimension `0` is returned.
///
/// If the same chunk id occurs more than once, the last occurrence is the
/// live one and the id is tombstoned, exactly as `upsert` does, so that
/// `search` cannot return the superseded embedding.
///
/// The returned index is marked dirty; nothing is written to disk here.
///
/// # Errors
/// Every path in the current implementation returns `Ok`.
pub async fn build_from_chunks(
    project_hash: &str,
    chunks: &[(String, Vec<f32>)],
) -> Result<Self, CoreError> {
    let dim = chunks
        .iter()
        .find(|(_, v)| !v.is_empty())
        .map_or(0, |(_, v)| v.len());
    if dim == 0 {
        return Ok(Self::empty(project_hash.to_owned(), 0));
    }
    let capacity_hint = chunks.len().max(1);
    let hnsw: Hnsw<'static, f32, DistCosine> = Hnsw::new(
        MAX_NB_CONNECTION,
        capacity_hint,
        MAX_LAYER,
        EF_CONSTRUCTION,
        DistCosine,
    );
    let mut id_map: Vec<String> = Vec::with_capacity(chunks.len());
    let mut reverse: HashMap<String, usize> = HashMap::with_capacity(chunks.len());
    let mut tombstones: HashSet<String> = HashSet::new();
    for (chunk_id, emb) in chunks {
        if emb.len() != dim {
            continue;
        }
        let internal_id = id_map.len();
        hnsw.insert((emb.as_slice(), internal_id));
        // A repeated id supersedes its earlier slot; tombstone it so the
        // stale slot is filtered out by `search`.
        if reverse.insert(chunk_id.clone(), internal_id).is_some() {
            tombstones.insert(chunk_id.clone());
        }
        id_map.push(chunk_id.clone());
    }
    let schema_hash = compute_schema_hash(chunks, dim);
    Ok(Self {
        inner: Some(hnsw),
        id_map,
        reverse,
        tombstones,
        dim,
        dirty: true,
        project_hash: project_hash.to_owned(),
        schema_hash,
    })
}
/// Creates an index with no graph, no slots and a clean dirty flag.
///
/// `dim` is stored as given; `0` lets the first `upsert` choose the dimension.
fn empty(project_hash: String, dim: usize) -> Self {
    Self {
        project_hash,
        dim,
        inner: None,
        dirty: false,
        schema_hash: String::new(),
        id_map: Vec::new(),
        reverse: HashMap::new(),
        tombstones: HashSet::new(),
    }
}
/// Inserts `embedding` for `chunk_id`, superseding any previous embedding.
///
/// The call is a silent no-op when `embedding` is empty or its length does
/// not match the index dimension. An index that has neither a graph nor a
/// dimension yet adopts the length of the first accepted embedding.
///
/// The new vector always gets a fresh slot; if the id already had a live
/// slot the id is tombstoned so `search` ignores the old one. Marks the
/// index dirty.
pub fn upsert(&mut self, chunk_id: &str, embedding: &[f32]) {
    if embedding.is_empty() {
        return;
    }
    // Only a graph-less index may still pick its dimension.
    if self.inner.is_none() && self.dim == 0 {
        self.dim = embedding.len();
    }
    if embedding.len() != self.dim {
        return;
    }
    if self.reverse.contains_key(chunk_id) {
        self.tombstones.insert(chunk_id.to_owned());
    }
    let new_internal = self.id_map.len();
    // Create the graph lazily on first use instead of unwrapping afterwards.
    let hnsw = self.inner.get_or_insert_with(|| {
        Hnsw::new(
            MAX_NB_CONNECTION,
            64,
            MAX_LAYER,
            EF_CONSTRUCTION,
            DistCosine,
        )
    });
    hnsw.insert((embedding, new_internal));
    self.id_map.push(chunk_id.to_owned());
    self.reverse.insert(chunk_id.to_owned(), new_internal);
    self.dirty = true;
}
/// Drops `chunk_id` from the live set.
///
/// The graph slot itself stays in place; the id is tombstoned so `search`
/// skips it. Unknown ids are ignored and do not mark the index dirty.
pub fn remove(&mut self, chunk_id: &str) {
    let was_live = self.reverse.remove(chunk_id).is_some();
    if !was_live {
        return;
    }
    self.tombstones.insert(chunk_id.to_owned());
    self.dirty = true;
}
/// Returns up to `top_k` live chunks for `query` as `(chunk_id, distance)`
/// pairs, in the order the HNSW search produced them.
///
/// Returns an empty vector when `top_k` is `0`, the query is empty or has
/// the wrong dimension, or the index holds no slots / no graph. `top_k` is
/// capped at `MAX_SEARCH_TOP_K`.
///
/// More raw neighbours than `top_k` are requested (see
/// `raw_search_candidate_count`) because superseded and removed slots are
/// filtered out afterwards.
pub fn search(&self, query: &[f32], top_k: usize) -> Vec<(String, f32)> {
    if top_k == 0 || query.is_empty() || self.id_map.is_empty() {
        return Vec::new();
    }
    let top_k = top_k.min(MAX_SEARCH_TOP_K);
    if query.len() != self.dim {
        return Vec::new();
    }
    let Some(hnsw) = self.inner.as_ref() else {
        return Vec::new();
    };
    let raw_k = self.raw_search_candidate_count(top_k);
    // NOTE(review): `raw_k` can exceed `ef`; confirm the library widens `ef`
    // to at least the requested neighbour count.
    let ef = DEFAULT_EF_SEARCH.max(top_k.saturating_mul(2));
    let raw = hnsw.search(query, raw_k, ef);
    let mut out = Vec::with_capacity(top_k);
    let mut seen: HashSet<&str> = HashSet::new();
    for n in raw {
        let internal_id = n.d_id;
        let Some(chunk_id) = self.id_map.get(internal_id) else {
            continue;
        };
        // A tombstoned id is only valid through the slot `reverse` currently
        // points at; removed ids have no entry and are skipped entirely.
        if self.tombstones.contains(chunk_id) {
            match self.reverse.get(chunk_id) {
                Some(&current) if current == internal_id => {}
                _ => continue,
            }
        }
        if !seen.insert(chunk_id.as_str()) {
            continue;
        }
        out.push((chunk_id.clone(), n.distance));
        if out.len() >= top_k {
            break;
        }
    }
    out
}
/// Persists the index into the project's index directory and clears the
/// dirty flag.
///
/// The directory is created if needed. When a graph exists it is dumped
/// first; the sidecar is written afterwards in either case, so an index
/// without a graph only writes the sidecar.
///
/// NOTE(review): the value returned by `file_dump` is discarded — confirm
/// the library always writes under `HNSW_BASENAME`.
///
/// # Errors
/// Propagates directory-creation and sidecar-write failures, and maps a
/// failed graph dump to `CoreError::Internal`.
pub async fn save(&mut self) -> Result<(), CoreError> {
    let dir = crate::infra::db::project_index_dir(&self.project_hash);
    std::fs::create_dir_all(&dir)?;
    if let Some(graph) = self.inner.as_ref() {
        graph
            .file_dump(&dir, HNSW_BASENAME)
            .map_err(|e| CoreError::Internal(format!("hnsw file_dump failed: {e}")))?;
    }
    self.write_meta(&dir)?;
    self.dirty = false;
    Ok(())
}
/// Serializes the current bookkeeping as pretty-printed JSON into
/// `dir/META_FILENAME`.
///
/// `size` saturates at `u32::MAX`; a saturated value no longer equals the
/// `id_map` length, so such a sidecar is rejected by
/// `meta_sidecar_shape_is_valid` on the next load.
///
/// # Errors
/// Propagates serialization and file-write failures.
fn write_meta(&self, dir: &Path) -> Result<(), CoreError> {
    let size = match u32::try_from(self.id_map.len()) {
        Ok(n) => n,
        Err(_) => u32::MAX,
    };
    let tombstones: Vec<String> = self.tombstones.iter().cloned().collect();
    let meta = AnnMeta {
        version: META_VERSION,
        dim: self.dim,
        size,
        built_at: Utc::now().to_rfc3339(),
        schema_hash: self.schema_hash.clone(),
        id_map: self.id_map.clone(),
        tombstones,
    };
    let target = dir.join(META_FILENAME);
    let encoded = serde_json::to_vec_pretty(&meta)?;
    std::fs::write(target, encoded)?;
    Ok(())
}
/// Returns `true` when the index has in-memory changes (from building,
/// `upsert` or `remove`) that `save` has not yet cleared.
pub const fn is_dirty(&self) -> bool {
self.dirty
}
/// Number of chunk ids that currently have a live slot, saturated to
/// `u32::MAX`.
pub fn live_size(&self) -> u32 {
    match u32::try_from(self.reverse.len()) {
        Ok(count) => count,
        Err(_) => u32::MAX,
    }
}
/// Number of slots ever allocated (live and stale), saturated to `u32::MAX`.
pub fn total_size(&self) -> u32 {
    match u32::try_from(self.id_map.len()) {
        Ok(count) => count,
        Err(_) => u32::MAX,
    }
}
/// Reports whether stale slots make up at least
/// `MAX_STALE_SLOT_RATIO_NUMERATOR / MAX_STALE_SLOT_RATIO_DENOMINATOR` of all
/// slots. An index without slots never needs compaction.
pub fn needs_compaction(&self) -> bool {
    let total = self.id_map.len();
    if total == 0 {
        return false;
    }
    // Cross-multiplied so the ratio test stays in integer arithmetic.
    let stale_scaled = self
        .stale_slot_count()
        .saturating_mul(MAX_STALE_SLOT_RATIO_DENOMINATOR);
    let threshold = total.saturating_mul(MAX_STALE_SLOT_RATIO_NUMERATOR);
    stale_scaled >= threshold
}
/// Embedding dimension of this index (`0` if an empty index has not
/// accepted an embedding yet).
pub const fn dim(&self) -> usize {
self.dim
}
/// Slots that no longer back a live chunk: total slots minus live ids,
/// floored at zero.
fn stale_slot_count(&self) -> usize {
    let total_slots = self.id_map.len();
    let live_ids = self.reverse.len();
    total_slots.saturating_sub(live_ids)
}
/// Number of raw neighbours to request from the graph for a `top_k` query.
///
/// Starts from `top_k * BASE_RAW_SEARCH_MULTIPLIER` and scales it by
/// `total / live` (rounded up) so that a larger share of stale slots yields
/// a larger budget. The result is at least `min(top_k, total)` and at most
/// `min(MAX_RAW_SEARCH_CANDIDATES, total)`; the upper bound wins when the
/// two conflict. Returns `0` for an index without slots.
fn raw_search_candidate_count(&self, top_k: usize) -> usize {
    let total = self.id_map.len();
    if total == 0 {
        return 0;
    }
    // Treat "no live ids" as one to avoid dividing by zero.
    let live = self.reverse.len().max(1);
    let scaled = top_k
        .saturating_mul(BASE_RAW_SEARCH_MULTIPLIER)
        .saturating_mul(total)
        .div_ceil(live);
    let floor = top_k.min(total);
    let ceiling = MAX_RAW_SEARCH_CANDIDATES.min(total);
    scaled.max(floor).min(ceiling)
}
}
/// Checks the internal consistency of a sidecar: `size` must equal the
/// number of `id_map` entries and every tombstone must name an id that
/// occurs in `id_map`.
fn meta_sidecar_shape_is_valid(meta: &AnnMeta) -> bool {
    let size_matches = matches!(
        usize::try_from(meta.size),
        Ok(declared) if declared == meta.id_map.len()
    );
    if !size_matches {
        return false;
    }
    let known_ids: HashSet<&str> = meta.id_map.iter().map(String::as_str).collect();
    let has_unknown_tombstone = meta
        .tombstones
        .iter()
        .any(|id| !known_ids.contains(id.as_str()));
    !has_unknown_tombstone
}
/// Produces a 12-hex-character fingerprint of the build input.
///
/// SHA-1 is fed the little-endian bytes of `dim` followed by the ids of the
/// first 64 chunks, each terminated by a zero byte; the first six digest
/// bytes are rendered as lowercase hex. Embedding values are not hashed.
///
/// `dim` is a `usize`, so the number of bytes hashed for it follows the
/// target's pointer width.
fn compute_schema_hash(chunks: &[(String, Vec<f32>)], dim: usize) -> String {
    use sha1::{Digest, Sha1};
    use std::fmt::Write as _;
    let mut hasher = Sha1::new();
    hasher.update(dim.to_le_bytes());
    for (id, _) in chunks.iter().take(64) {
        hasher.update(id.as_bytes());
        // Separator so that adjacent ids cannot run together.
        hasher.update([0]);
    }
    let out = hasher.finalize();
    let mut hex = String::with_capacity(12);
    for b in out.iter().take(6) {
        // Formats straight into the preallocated buffer; writing to a
        // `String` cannot fail.
        let _ = write!(hex, "{b:02x}");
    }
    hex
}
fn ann_cache() -> &'static AnnCache {
static CACHE: OnceLock<AnnCache> = OnceLock::new();
CACHE.get_or_init(|| std::sync::Mutex::new(HashMap::new()))
}
pub async fn get_ann_for_project(
project_hash: &str,
dim: usize,
) -> Result<Arc<Mutex<AnnIndex>>, CoreError> {
{
#[allow(clippy::expect_used)]
let guard = ann_cache().lock().expect("ann cache mutex poisoned");
let key = (project_hash.to_owned(), dim);
if let Some(existing) = guard.get(&key) {
return Ok(Arc::clone(existing));
}
}
let loaded = AnnIndex::load_or_empty(project_hash, dim).await?;
let arc = Arc::new(Mutex::new(loaded));
#[allow(clippy::expect_used)]
let mut guard = ann_cache().lock().expect("ann cache mutex poisoned");
let key = (project_hash.to_owned(), dim);
let entry = guard.entry(key).or_insert(arc);
Ok(Arc::clone(entry))
}
/// Test-only: evicts every cached index of `project_hash`, for all
/// dimensions.
///
/// # Panics
/// Panics if the cache mutex is poisoned.
#[cfg(test)]
pub fn invalidate_cache(project_hash: &str) {
    #[allow(clippy::expect_used)]
    let mut cache = ann_cache().lock().expect("ann cache mutex poisoned");
    cache.retain(|key, _| key.0.as_str() != project_hash);
}
/// Returns the `(graph, data, meta)` file paths of a project's persisted
/// index, in that order. The files are not required to exist.
pub fn ann_files_for_project(project_hash: &str) -> (PathBuf, PathBuf, PathBuf) {
    let dir = crate::infra::db::project_index_dir(project_hash);
    let graph = dir.join(format!("{HNSW_BASENAME}.hnsw.graph"));
    let data = dir.join(format!("{HNSW_BASENAME}.hnsw.data"));
    let meta = dir.join(META_FILENAME);
    (graph, data, meta)
}
#[cfg(test)]
mod tests {
use super::*;
/// Builds a project hash of the form `{tag}-{nanoseconds since the Unix
/// epoch}` so each test works in its own index directory.
fn unique_hash(tag: &str) -> String {
    use std::time::{SystemTime, UNIX_EPOCH};
    // A clock before the epoch degrades to a zero duration.
    let since_epoch = SystemTime::now()
        .duration_since(UNIX_EPOCH)
        .unwrap_or_default();
    let mut hash = String::from(tag);
    hash.push('-');
    hash.push_str(&since_epoch.as_nanos().to_string());
    hash
}
/// Deterministic pseudo-random vector of length `dim`.
///
/// Component `i` is derived by hashing `(seed, i)` with `DefaultHasher` and
/// mapping the result into roughly `[-1, 1]`; the vector is then divided by
/// its Euclidean norm unless that norm is zero.
fn random_vec(seed: u64, dim: usize) -> Vec<f32> {
    use std::collections::hash_map::DefaultHasher;
    use std::hash::{Hash, Hasher};
    let mut components: Vec<f32> = (0..dim)
        .map(|i| {
            let mut hasher = DefaultHasher::new();
            (seed, i).hash(&mut hasher);
            // Reinterpret the hash as signed so values fall on both sides of zero.
            let signed = hasher.finish() as i64;
            ((signed as f64) / (i64::MAX as f64)) as f32
        })
        .collect();
    let norm: f32 = components.iter().map(|c| c * c).sum::<f32>().sqrt();
    if norm > 0.0 {
        components.iter_mut().for_each(|c| *c /= norm);
    }
    components
}
/// A project with nothing on disk loads as an empty index, and searching it
/// yields no hits.
#[tokio::test]
async fn empty_index_search_returns_empty() {
    let _home = crate::infra::db::shared_test_home();
    let project = unique_hash("empty-search");
    let index = AnnIndex::load_or_empty(&project, 16).await.unwrap();
    assert_eq!(index.total_size(), 0);
    let probe = [0.1f32; 16];
    assert!(index.search(&probe, 5).is_empty());
}
/// Searching with a vector that is itself in the index returns that chunk
/// first, at a near-zero distance.
#[tokio::test]
async fn build_and_search_returns_nearest() {
    let _home = crate::infra::db::shared_test_home();
    let project = unique_hash("build-search");
    let dim = 32;
    let needle = random_vec(99u64, dim);
    // Fifty filler chunks followed by the one we will look for.
    let corpus: Vec<(String, Vec<f32>)> = (0..50u64)
        .map(|i| (format!("c{i}"), random_vec(i, dim)))
        .chain(std::iter::once(("target".to_owned(), needle.clone())))
        .collect();
    let index = AnnIndex::build_from_chunks(&project, &corpus)
        .await
        .unwrap();
    assert_eq!(index.total_size(), 51);
    let hits = index.search(&needle, 5);
    assert!(!hits.is_empty());
    assert_eq!(hits[0].0, "target");
    assert!(
        hits[0].1 < 1e-3,
        "self-match distance should be near zero, got {}",
        hits[0].1
    );
}
/// An index that is saved and loaded again answers the same query with the
/// same number of hits and the same best match.
#[tokio::test]
async fn save_and_load_roundtrip() {
    let _home = crate::infra::db::shared_test_home();
    let project = unique_hash("roundtrip");
    let dim = 24;
    let corpus: Vec<(String, Vec<f32>)> = (0..20u64)
        .map(|i| (format!("c{i}"), random_vec(i, dim)))
        .collect();
    let probe = random_vec(7, dim);
    let mut index = AnnIndex::build_from_chunks(&project, &corpus)
        .await
        .unwrap();
    let hits_before = index.search(&probe, 5);
    assert!(!hits_before.is_empty());
    index.save().await.unwrap();
    assert!(!index.is_dirty());
    invalidate_cache(&project);
    let restored = AnnIndex::load_or_empty(&project, dim).await.unwrap();
    let hits_after = restored.search(&probe, 5);
    assert_eq!(hits_before.len(), hits_after.len());
    assert_eq!(hits_before[0].0, hits_after[0].0);
}
// After upserting a new embedding for an existing id, the id must be
// reachable through the new vector.
#[tokio::test]
async fn upsert_replaces_existing_chunk_embedding() {
let _home = crate::infra::db::shared_test_home();
let hash = unique_hash("upsert");
let dim = 8;
let vec_a = random_vec(1, dim);
let vec_b = random_vec(2, dim);
let mut idx = AnnIndex::build_from_chunks(
&hash,
&[
("id1".to_owned(), vec_a.clone()),
("neighbor".to_owned(), random_vec(3, dim)),
],
)
.await
.unwrap();
let hits1 = idx.search(&vec_a, 2);
assert_eq!(hits1[0].0, "id1");
idx.upsert("id1", &vec_b);
let hits_a = idx.search(&vec_a, 2);
// NOTE(review): `upsert` always records "id1" in `reverse`, so the inner
// assertion holds whenever this branch is entered; it does not check that
// the old embedding stopped matching. Consider asserting on the distance.
if !hits_a.is_empty() && hits_a[0].0 == "id1" {
let current = idx.reverse.get("id1").copied();
assert!(current.is_some());
}
let hits_b = idx.search(&vec_b, 2);
assert!(
hits_b.iter().any(|(id, _)| id == "id1"),
"upserted chunk must be searchable via new vector"
);
}
/// A chunk that was found before `remove` must not be returned afterwards.
#[tokio::test]
async fn remove_drops_from_search_results() {
    let _home = crate::infra::db::shared_test_home();
    let project = unique_hash("remove");
    let dim = 12;
    let doomed_vec = random_vec(10, dim);
    let corpus = [
        ("doomed".to_owned(), doomed_vec.clone()),
        ("keep".to_owned(), random_vec(11, dim)),
    ];
    let mut index = AnnIndex::build_from_chunks(&project, &corpus)
        .await
        .unwrap();
    let found_before = index
        .search(&doomed_vec, 2)
        .iter()
        .any(|(id, _)| id == "doomed");
    assert!(found_before);
    index.remove("doomed");
    let found_after = index
        .search(&doomed_vec, 2)
        .iter()
        .any(|(id, _)| id == "doomed");
    assert!(
        !found_after,
        "removed chunk must not appear in search results"
    );
}
/// Loading a saved index with a different embedding dimension yields an
/// empty index instead of the persisted one.
#[tokio::test]
async fn dim_mismatch_triggers_fallback() {
    let _home = crate::infra::db::shared_test_home();
    let project = unique_hash("dim-mismatch");
    let saved_dim = 16;
    let corpus: Vec<(String, Vec<f32>)> = (0..5u64)
        .map(|i| (format!("c{i}"), random_vec(i, saved_dim)))
        .collect();
    let mut index = AnnIndex::build_from_chunks(&project, &corpus)
        .await
        .unwrap();
    index.save().await.unwrap();
    invalidate_cache(&project);
    let reloaded = AnnIndex::load_or_empty(&project, 32).await.unwrap();
    assert_eq!(reloaded.total_size(), 0, "dim drift must reset the index");
    let probe = random_vec(0, 32);
    assert!(reloaded.search(&probe, 5).is_empty());
}
// Recall check: a slightly perturbed copy of one of 1000 indexed vectors
// must bring the original chunk back within the top 10 hits.
#[tokio::test]
async fn search_recall_at_top_10() {
let _home = crate::infra::db::shared_test_home();
let hash = unique_hash("recall10");
let dim = 32;
let mut chunks: Vec<(String, Vec<f32>)> = Vec::new();
for i in 0..1000 {
chunks.push((format!("c{i}"), random_vec(i as u64, dim)));
}
let target_idx = 250u64;
let mut target_query = random_vec(target_idx, dim);
// Nudge every component by 0.01 times a one-dimensional `random_vec` value.
for (i, x) in target_query.iter_mut().enumerate() {
*x = 0.01f32.mul_add(random_vec(i as u64 + 5000, 1)[0], *x);
}
// Re-normalise after the perturbation.
let n: f32 = target_query.iter().map(|x| x * x).sum::<f32>().sqrt();
if n > 0.0 {
for x in &mut target_query {
*x /= n;
}
}
let idx = AnnIndex::build_from_chunks(&hash, &chunks).await.unwrap();
let hits = idx.search(&target_query, 10);
let ids: Vec<_> = hits.iter().map(|(id, _)| id.as_str()).collect();
assert!(
ids.contains(&format!("c{target_idx}").as_str()),
"target chunk must be in top-10 (got {ids:?})"
);
}
/// Requesting far more hits than `MAX_SEARCH_TOP_K` must not return more
/// than that cap.
#[tokio::test]
async fn search_caps_large_top_k_requests() {
    let _home = crate::infra::db::shared_test_home();
    let project = unique_hash("cap-top-k");
    let mut corpus: Vec<(String, Vec<f32>)> = Vec::with_capacity(80);
    for i in 0..80 {
        corpus.push((format!("id-{i:02}"), vec![i as f32, 1.0]));
    }
    let index = AnnIndex::build_from_chunks(&project, &corpus)
        .await
        .unwrap();
    let hits = index.search(&[1.0, 1.0], 500);
    assert!(
        hits.len() <= MAX_SEARCH_TOP_K,
        "ANN search should cap oversized requests, got {} hits",
        hits.len()
    );
}
// A sidecar whose `version` differs from `META_VERSION` must cause the load
// to fall back to an empty index.
#[tokio::test]
async fn persistence_meta_version_bump_invalidates() {
let _home = crate::infra::db::shared_test_home();
let hash = unique_hash("meta-bump");
let dim = 8;
let chunks = vec![
("a".to_owned(), random_vec(1, dim)),
("b".to_owned(), random_vec(2, dim)),
];
let mut idx = AnnIndex::build_from_chunks(&hash, &chunks).await.unwrap();
idx.save().await.unwrap();
// Rewrite the saved sidecar in place with an unknown version number.
let (_g_path, _d_path, meta_path) = ann_files_for_project(&hash);
let raw = std::fs::read(&meta_path).unwrap();
let mut parsed: serde_json::Value = serde_json::from_slice(&raw).unwrap();
parsed["version"] = serde_json::json!(999);
std::fs::write(&meta_path, serde_json::to_vec(&parsed).unwrap()).unwrap();
invalidate_cache(&hash);
let reloaded = AnnIndex::load_or_empty(&hash, dim).await.unwrap();
assert_eq!(
reloaded.total_size(),
0,
"version drift must wipe the in-memory graph"
);
}
// A sidecar whose `size` disagrees with its own `id_map` length must cause
// the load to fall back to an empty index.
#[tokio::test]
async fn persistence_meta_size_mismatch_invalidates() {
let _home = crate::infra::db::shared_test_home();
let hash = unique_hash("meta-size-mismatch");
let dim = 8;
let chunks = vec![
("a".to_owned(), random_vec(1, dim)),
("b".to_owned(), random_vec(2, dim)),
];
let mut idx = AnnIndex::build_from_chunks(&hash, &chunks).await.unwrap();
idx.save().await.unwrap();
// Corrupt only the declared size; `id_map` still has two entries.
let (_g_path, _d_path, meta_path) = ann_files_for_project(&hash);
let raw = std::fs::read(&meta_path).unwrap();
let mut parsed: serde_json::Value = serde_json::from_slice(&raw).unwrap();
parsed["size"] = serde_json::json!(999);
std::fs::write(&meta_path, serde_json::to_vec(&parsed).unwrap()).unwrap();
invalidate_cache(&hash);
let reloaded = AnnIndex::load_or_empty(&hash, dim).await.unwrap();
assert_eq!(
reloaded.total_size(),
0,
"meta size/id_map mismatch must wipe the in-memory graph"
);
}
// A sidecar that is self-consistent but lists fewer ids than the dumped
// graph has points must cause the load to fall back to an empty index.
#[tokio::test]
async fn persistence_graph_meta_count_mismatch_invalidates() {
let _home = crate::infra::db::shared_test_home();
let hash = unique_hash("graph-meta-count-mismatch");
let dim = 8;
let chunks = vec![
("a".to_owned(), random_vec(1, dim)),
("b".to_owned(), random_vec(2, dim)),
];
let mut idx = AnnIndex::build_from_chunks(&hash, &chunks).await.unwrap();
idx.save().await.unwrap();
// Shrink `id_map` and `size` together so only the graph comparison can fail.
let (_g_path, _d_path, meta_path) = ann_files_for_project(&hash);
let raw = std::fs::read(&meta_path).unwrap();
let mut parsed: serde_json::Value = serde_json::from_slice(&raw).unwrap();
parsed["id_map"] = serde_json::json!(["a"]);
parsed["size"] = serde_json::json!(1);
std::fs::write(&meta_path, serde_json::to_vec(&parsed).unwrap()).unwrap();
invalidate_cache(&hash);
let reloaded = AnnIndex::load_or_empty(&hash, dim).await.unwrap();
assert_eq!(
reloaded.total_size(),
0,
"graph point count/id_map mismatch must wipe the in-memory graph"
);
}
/// Re-upserting every chunk twice leaves two stale slots per live chunk,
/// which must trip the compaction policy and enlarge the raw search budget.
#[tokio::test]
async fn stale_slots_trigger_compaction_policy_and_scaled_raw_search() {
    let _home = crate::infra::db::shared_test_home();
    let project = unique_hash("stale-policy");
    let dim = 16;
    let mut corpus: Vec<(String, Vec<f32>)> = Vec::with_capacity(100);
    for i in 0..100u64 {
        corpus.push((format!("c{i}"), random_vec(i, dim)));
    }
    let mut index = AnnIndex::build_from_chunks(&project, &corpus)
        .await
        .unwrap();
    assert!(!index.needs_compaction());
    assert_eq!(index.raw_search_candidate_count(10), 30);
    // Two full passes over c0..c99 with seeds 1_000..1_199.
    for step in 0..200u64 {
        let chunk_id = format!("c{}", step % 100);
        let embedding = random_vec(1_000 + step, dim);
        index.upsert(&chunk_id, &embedding);
    }
    assert!(index.needs_compaction());
    assert_eq!(index.live_size(), 100);
    assert_eq!(index.total_size(), 300);
    let budget = index.raw_search_candidate_count(10);
    assert!(
        budget > 30,
        "raw search budget must scale above the old fixed 3x when stale slots grow"
    );
}
/// Two lookups with the same project and dimension share one cached index.
#[tokio::test]
async fn get_ann_for_project_caches_across_calls() {
    let _home = crate::infra::db::shared_test_home();
    let project = unique_hash("cache");
    let first = get_ann_for_project(&project, 16).await.unwrap();
    let second = get_ann_for_project(&project, 16).await.unwrap();
    let same_instance = Arc::ptr_eq(&first, &second);
    assert!(
        same_instance,
        "cache must return the same Arc across calls"
    );
}
/// The same project requested with two dimensions gets two distinct cache
/// entries.
#[tokio::test]
async fn get_ann_for_project_keys_cache_by_dim() {
    let _home = crate::infra::db::shared_test_home();
    let project = unique_hash("cache-dim");
    let narrow = get_ann_for_project(&project, 16).await.unwrap();
    let wide = get_ann_for_project(&project, 32).await.unwrap();
    let shared = Arc::ptr_eq(&narrow, &wide);
    assert!(
        !shared,
        "same project with different embedding dims must not reuse one ANN cache entry"
    );
}
}