use std::{collections::HashMap, path::Path, sync::RwLock};
use anyhow::{Context, Result};
use hnsw_rs::{
api::AnnT,
hnswio::HnswIo,
prelude::{DistCosine, Hnsw},
};
use serde::{Deserialize, Serialize};
use crate::kb::store::KbStore;
/// Vector dimension used by `HnswCache::empty` when the caller does not choose one.
pub const DEFAULT_DIMENSION: usize = 1024;
// Tuning values handed to `Hnsw::new` in argument order (M, capacity, layers,
// ef_construction). NOTE(review): the exact meaning of each knob is defined by
// hnsw_rs, not by this file — confirm against its docs before changing them.
// First argument to `Hnsw::new` (presumably the max connections per node).
const M: usize = 16;
// Fourth argument to `Hnsw::new` (construction-time candidate list width).
const EF_CONSTRUCTION: usize = 200;
// Third argument to `Hnsw::new` (upper bound on the number of graph layers).
const MAX_NB_LAYER: usize = 16;
// Third argument to `Hnsw::search`, used for every query issued by `HnswCache::search`.
const EF_SEARCH: usize = 64;
// Second argument to `Hnsw::new` for an empty graph, and the lower bound on the
// capacity chosen by `HnswCache::rebuild`.
const INITIAL_CAPACITY: usize = 10_000;
// Base name shared by all snapshot files: it is passed to `file_dump` / `HnswIo::new`
// and used to build `{SNAPSHOT_NAME}.meta.json`, `.hnsw.graph` and `.hnsw.data`.
const SNAPSHOT_NAME: &str = "snapshot";
// Version stamped into the meta file by `snapshot`; `restore` rejects any other value.
const SNAPSHOT_SCHEMA_VERSION: u32 = 1;
/// JSON sidecar stored next to the dumped graph as `{SNAPSHOT_NAME}.meta.json`.
///
/// Written by `HnswCache::snapshot` and validated by `HnswCache::restore`.
#[derive(Serialize, Deserialize)]
struct HnswMeta {
/// Snapshot format version; `default_schema_version()` supplies the value when
/// the field is absent from the JSON.
#[serde(default = "default_schema_version")]
schema_version: u32,
/// Vector dimension of the cache that wrote the snapshot; `restore` errors on a mismatch.
dimension: usize,
/// Chunk ids in insertion order; position `i` is the numeric id given to the graph.
id_to_chunk: Vec<String>,
}
/// Schema version assumed for a meta file that carries no `schema_version` field.
///
/// This is deliberately the literal `1` rather than the current runtime version.
/// A file without the field was written by the original format, so it must keep
/// being reported as version 1 even after the runtime version is bumped.
/// Otherwise such a file would be labelled with whatever the current version is
/// and would slip past the compatibility check in `restore`.
fn default_schema_version() -> u32 {
    1
}
/// Approximate nearest-neighbour cache over chunk vectors that answers queries
/// with chunk ids.
///
/// All mutable state sits behind an `RwLock`; the methods recover a poisoned
/// lock with `into_inner` instead of panicking.
pub struct HnswCache {
/// Graph plus id maps, replaced wholesale by `rebuild` and `restore`.
inner: RwLock<HnswInner>,
/// Required length of every inserted or queried vector; fixed at construction.
dimension: usize,
}
/// State guarded by the `HnswCache` lock: the graph and its id bookkeeping.
struct HnswInner {
/// HNSW graph over `f32` vectors using the `DistCosine` distance.
hnsw: Hnsw<'static, f32, DistCosine>,
/// Graph id -> chunk id; a vector's graph id is its index in this list.
id_to_chunk: Vec<String>,
/// Chunk id -> graph id; when a chunk id is inserted twice the later graph id wins.
chunk_to_id: HashMap<String, usize>,
}
impl HnswInner {
fn empty() -> Self {
Self {
hnsw: Hnsw::<'static, f32, DistCosine>::new(
M,
INITIAL_CAPACITY,
MAX_NB_LAYER,
EF_CONSTRUCTION,
DistCosine,
),
id_to_chunk: Vec::new(),
chunk_to_id: HashMap::new(),
}
}
}
impl HnswCache {
/// Creates an empty cache that accepts vectors of exactly `dimension` elements.
pub fn new(dimension: usize) -> Self {
    let inner = RwLock::new(HnswInner::empty());
    Self { inner, dimension }
}
pub fn empty() -> Self {
Self::new(DEFAULT_DIMENSION)
}
pub fn dim(&self) -> usize {
self.dimension
}
/// Looks up the nearest stored vectors to `query`.
///
/// Asks the graph for `k` neighbours and returns `(chunk_id, score)` pairs in
/// the order the graph reports them, where `score` is `1.0 - distance`.
/// Returns an empty list when the cache holds nothing or when `query` does not
/// have the configured dimension. Neighbours whose graph id has no entry in the
/// id table are dropped.
pub fn search(&self, query: &[f32], k: usize) -> Vec<(String, f32)> {
    // A poisoned lock is recovered: the data is still readable.
    let guard = match self.inner.read() {
        Ok(g) => g,
        Err(poisoned) => poisoned.into_inner(),
    };
    if guard.id_to_chunk.is_empty() {
        return Vec::new();
    }
    if query.len() != self.dimension {
        return Vec::new();
    }

    let neighbours = guard.hnsw.search(query, k, EF_SEARCH);
    let mut hits = Vec::with_capacity(neighbours.len());
    for neighbour in neighbours {
        if let Some(chunk_id) = guard.id_to_chunk.get(neighbour.d_id) {
            hits.push((chunk_id.clone(), 1.0 - neighbour.distance));
        }
    }
    hits
}
/// Appends `vector` to the graph under `chunk_id`.
///
/// Insertion is append-only: every call gets a fresh graph id, even when
/// `chunk_id` was inserted before. In that case the reverse map is repointed to
/// the newest graph id while the older vector stays in the graph.
///
/// # Errors
///
/// Fails when `vector.len()` differs from the cache dimension; nothing is
/// modified in that case.
pub fn insert(&self, chunk_id: &str, vector: &[f32]) -> Result<()> {
    anyhow::ensure!(
        vector.len() == self.dimension,
        "hnsw insert: expected dim={}, got {}",
        self.dimension,
        vector.len()
    );

    let mut guard = match self.inner.write() {
        Ok(g) => g,
        Err(poisoned) => poisoned.into_inner(),
    };
    // The next graph id is simply the current length of the id table.
    let graph_id = guard.id_to_chunk.len();
    let owned_id = chunk_id.to_string();
    guard.id_to_chunk.push(owned_id.clone());
    guard.chunk_to_id.insert(owned_id, graph_id);

    // `Hnsw::insert` wants a `&Vec<f32>`, hence the owned copy of the slice.
    let owned_vector: Vec<f32> = vector.to_vec();
    guard.hnsw.insert((&owned_vector, graph_id));
    Ok(())
}
/// Rebuilds the whole index from the chunks stored in `store`.
///
/// Every row of the `KB_CHUNKS` table is decoded. Chunks whose vector has the
/// configured dimension are numbered in scan order and bulk-inserted into a
/// fresh graph, which then replaces the current state in one assignment.
/// Chunks with any other vector length are left out of the index; their count
/// is reported through a warning so the omission is visible.
///
/// # Errors
///
/// Propagates failures from opening the read transaction or table, iterating
/// it, or decoding a chunk. The existing index is untouched on error.
pub fn rebuild(&self, store: &KbStore) -> Result<()> {
    let mut id_to_chunk: Vec<String> = Vec::new();
    let mut chunk_to_id: HashMap<String, usize> = HashMap::new();
    let mut vectors: Vec<Vec<f32>> = Vec::new();
    let mut skipped: usize = 0;
    {
        use redb::ReadableTable;

        use crate::kb::{
            model::KbChunk,
            store::{codec::decode, schema::KB_CHUNKS},
        };

        // The transaction lives only inside this block, so it is released
        // before the (potentially long) graph construction below starts.
        let rtx = store.begin_read()?;
        let tbl = rtx.open_table(KB_CHUNKS)?;
        for entry in tbl.iter()? {
            let (_, v) = entry?;
            let c: KbChunk = decode(v.value())?;
            if c.vector.len() != self.dimension {
                skipped += 1;
                continue;
            }
            let seq = id_to_chunk.len();
            chunk_to_id.insert(c.id.clone(), seq);
            id_to_chunk.push(c.id);
            vectors.push(c.vector);
        }
    }
    if skipped > 0 {
        tracing::warn!(
            skipped,
            expected_dim = self.dimension,
            "kb hnsw: rebuild skipped chunks with mismatched vector dimension"
        );
    }

    // Leave headroom for later inserts without ever going below the default.
    let capacity = INITIAL_CAPACITY.max(vectors.len().saturating_mul(2));
    let hnsw = Hnsw::<'static, f32, DistCosine>::new(
        M,
        capacity,
        MAX_NB_LAYER,
        EF_CONSTRUCTION,
        DistCosine,
    );
    // Graph id == position in `vectors` == position in `id_to_chunk`.
    let inserts: Vec<(&Vec<f32>, usize)> =
        vectors.iter().enumerate().map(|(i, v)| (v, i)).collect();
    hnsw.parallel_insert(&inserts);

    let new_inner = HnswInner {
        hnsw,
        id_to_chunk,
        chunk_to_id,
    };
    let n = new_inner.id_to_chunk.len();
    // The write lock is taken only for the swap itself.
    *self.inner.write().unwrap_or_else(|p| p.into_inner()) = new_inner;
    tracing::info!(n, "kb hnsw: rebuild complete");
    Ok(())
}
/// Returns how many vectors have been indexed (the length of the id table).
pub fn len(&self) -> usize {
    let guard = match self.inner.read() {
        Ok(g) => g,
        Err(poisoned) => poisoned.into_inner(),
    };
    guard.id_to_chunk.len()
}
/// Returns `true` when no vector has been indexed.
pub fn is_empty(&self) -> bool {
    let stored = self.len();
    stored == 0
}
pub fn snapshot(&self, dir: &Path) -> Result<()> {
std::fs::create_dir_all(dir)
.with_context(|| format!("create_dir_all {}", dir.display()))?;
let inner = self.inner.read().unwrap_or_else(|p| p.into_inner());
if !inner.id_to_chunk.is_empty() {
inner
.hnsw
.file_dump(dir, SNAPSHOT_NAME)
.map_err(|e| anyhow::anyhow!("hnsw file_dump: {e}"))?;
}
let meta = HnswMeta {
schema_version: SNAPSHOT_SCHEMA_VERSION,
dimension: self.dimension,
id_to_chunk: inner.id_to_chunk.clone(),
};
let meta_path = dir.join(format!("{SNAPSHOT_NAME}.meta.json"));
std::fs::write(&meta_path, serde_json::to_vec(&meta)?)
.with_context(|| format!("write {}", meta_path.display()))?;
tracing::info!(
n = inner.id_to_chunk.len(),
dir = %dir.display(),
"kb hnsw: snapshot written"
);
Ok(())
}
/// Loads a snapshot previously written by `snapshot` from `dir`.
///
/// Returns `Ok(false)` when `dir` contains no meta file, and `Ok(true)` once
/// the in-memory state has been replaced by the snapshot's contents. A meta
/// file with an empty id table restores an empty index without touching any
/// graph files.
///
/// # Errors
///
/// Fails, leaving the current state untouched, when:
/// - the meta file cannot be read or decoded;
/// - its schema version or dimension differs from the runtime values;
/// - the graph/data files are missing or cannot be loaded;
/// - the loaded graph does not contain exactly one point per listed chunk id
///   (e.g. the meta and graph files come from different snapshots), since
///   serving such a pair would map neighbours to the wrong chunk ids.
pub fn restore(&self, dir: &Path) -> Result<bool> {
    let meta_path = dir.join(format!("{SNAPSHOT_NAME}.meta.json"));
    let graph_path = dir.join(format!("{SNAPSHOT_NAME}.hnsw.graph"));
    let data_path = dir.join(format!("{SNAPSHOT_NAME}.hnsw.data"));
    if !meta_path.exists() {
        return Ok(false);
    }
    let meta_bytes =
        std::fs::read(&meta_path).with_context(|| format!("read {}", meta_path.display()))?;
    let meta: HnswMeta = serde_json::from_slice(&meta_bytes)
        .with_context(|| format!("decode {}", meta_path.display()))?;
    if meta.schema_version != SNAPSHOT_SCHEMA_VERSION {
        anyhow::bail!(
            "snapshot schema_version={} is incompatible with runtime version={SNAPSHOT_SCHEMA_VERSION} \
             — delete the hnsw/ directory to force a rebuild from redb",
            meta.schema_version
        );
    }
    if meta.dimension != self.dimension {
        anyhow::bail!(
            "snapshot dim={} does not match runtime dim={}",
            meta.dimension,
            self.dimension
        );
    }
    let n = meta.id_to_chunk.len();
    if n == 0 {
        let mut inner = self.inner.write().unwrap_or_else(|p| p.into_inner());
        *inner = HnswInner::empty();
        tracing::info!(dir = %dir.display(), "kb hnsw: restored empty snapshot");
        return Ok(true);
    }
    if !graph_path.exists() || !data_path.exists() {
        anyhow::bail!(
            "snapshot meta present but graph/data files missing in {}",
            dir.display()
        );
    }
    // The loaded graph borrows from its loader, and the cache stores a
    // `Hnsw<'static, ..>`, so the loader is leaked on purpose to obtain a
    // `'static` borrow. This costs one small allocation per `restore` call that
    // is never reclaimed — NOTE(review): fine for a once-per-startup restore,
    // reconsider if this is ever called repeatedly.
    let reloader_box: Box<HnswIo> = Box::new(HnswIo::new(dir, SNAPSHOT_NAME));
    let reloader: &'static mut HnswIo = Box::leak(reloader_box);
    let hnsw: Hnsw<'static, f32, DistCosine> = reloader
        .load_hnsw::<f32, DistCosine>()
        .map_err(|e| anyhow::anyhow!("hnsw load: {e}"))?;
    // Graph ids index straight into `id_to_chunk`, so both sides must agree.
    let loaded = hnsw.get_nb_point();
    if loaded != n {
        anyhow::bail!(
            "snapshot graph holds {loaded} points but meta lists {n} chunk ids in {}",
            dir.display()
        );
    }
    // Later duplicates overwrite earlier ones, mirroring `insert`.
    let chunk_to_id: HashMap<String, usize> = meta
        .id_to_chunk
        .iter()
        .enumerate()
        .map(|(i, id)| (id.clone(), i))
        .collect();
    let new_inner = HnswInner {
        hnsw,
        id_to_chunk: meta.id_to_chunk,
        chunk_to_id,
    };
    *self.inner.write().unwrap_or_else(|p| p.into_inner()) = new_inner;
    tracing::info!(n, dir = %dir.display(), "kb hnsw: snapshot restored");
    Ok(true)
}
}
#[cfg(test)]
mod tests {
    use std::sync::Arc;

    use tempfile::TempDir;

    use super::*;
    use crate::kb::{
        canonicalize::{CanonicalizeInput, canonicalize_by_mime},
        embedder::{KbEmbedder, StubEmbedder},
        paths::KbPaths,
        pipeline::{IngestInput, ingest_canonicalized},
        store::KbStore,
        worker::{DefaultDispatcher, WorkerConfig, WorkerPool, handlers::HandlerCtx},
    };

    /// Opens a store in a temp dir, ingests one small markdown document and
    /// runs a single blocking worker pass over it.
    ///
    /// The `TempDir` is returned so the backing files outlive the test body.
    /// NOTE(review): the worker pass is presumably what produces the embedded
    /// chunks that `rebuild` reads — confirm against the worker handlers.
    fn fixture_with_chunks() -> (TempDir, Arc<KbStore>) {
        let tmp = TempDir::new().unwrap();
        let store = Arc::new(KbStore::open(&tmp.path().join("kb.redb")).unwrap());
        let paths = Arc::new(KbPaths::new(tmp.path().join("kb")));
        paths.ensure_layout().unwrap();
        let embedder: Arc<dyn KbEmbedder> = Arc::new(StubEmbedder::default());
        let bytes = b"# Hi\n\nfirst body content.\n\nsecond body content.";
        let canon = canonicalize_by_mime(CanonicalizeInput {
            bytes,
            mime: "text/markdown",
            hint_title: Some("t"),
            logical_source_id_seed: None,
        })
        .unwrap()
        .unwrap();
        ingest_canonicalized(
            &store,
            IngestInput {
                canon: &canon,
                raw_bytes: bytes,
                raw_ext: "md",
                visibility: None,
                owner_user_id: None,
                seen_key: None,
                source: None,
                paths: &paths,
            },
        )
        .unwrap();
        let index = Arc::new(crate::kb::index::KbIndex::open(&paths).unwrap());
        let ctx = HandlerCtx {
            store: store.clone(),
            paths,
            embedder,
            index,
        };
        let cfg = WorkerConfig {
            worker_id: "w".into(),
            ..WorkerConfig::default()
        };
        WorkerPool::run_one_blocking(&ctx, &cfg, &DefaultDispatcher).unwrap();
        (tmp, store)
    }

    /// A rebuild from a populated store yields a searchable, non-empty index.
    #[test]
    fn rebuild_then_search_returns_hits() {
        let (_tmp, store) = fixture_with_chunks();
        let cache = HnswCache::empty();
        cache.rebuild(&store).unwrap();
        assert!(!cache.is_empty(), "expected chunks to be loaded");
        let q = [0.0_f32; DEFAULT_DIMENSION];
        let hits = cache.search(&q, 5);
        assert!(!hits.is_empty(), "expected hits, got empty");
    }

    /// Searching a cache with no vectors returns nothing.
    #[test]
    fn search_on_empty_returns_empty() {
        let cache = HnswCache::empty();
        assert!(cache.search(&[0.0_f32; DEFAULT_DIMENSION], 5).is_empty());
    }

    /// Inserting a vector of the wrong length is rejected.
    #[test]
    fn insert_dim_mismatch_errors() {
        let cache = HnswCache::empty();
        assert!(cache.insert("c1", &[0.0; 512]).is_err());
    }

    /// Re-inserting an existing chunk id appends a second vector; a query for
    /// the newer vector still resolves to that chunk id.
    #[test]
    fn append_only_insert_returns_new_chunk_for_same_id() {
        let cache = HnswCache::empty();
        let v1 = [1.0_f32; DEFAULT_DIMENSION];
        let mut v2 = vec![0.0_f32; DEFAULT_DIMENSION];
        v2[0] = 1.0;
        cache.insert("c1", &v1).unwrap();
        cache.insert("c1", &v2).unwrap();
        let hits = cache.search(&v2, 1);
        assert!(!hits.is_empty());
        assert_eq!(hits[0].0, "c1");
    }

    /// A snapshot restored into a fresh cache answers queries like the original.
    #[test]
    fn snapshot_roundtrip_preserves_search() {
        let dir = TempDir::new().unwrap();
        let dump_dir = dir.path().join("snap");
        let cache = HnswCache::empty();
        let mut v_alpha = vec![0.0_f32; DEFAULT_DIMENSION];
        v_alpha[0] = 1.0;
        let mut v_beta = vec![0.0_f32; DEFAULT_DIMENSION];
        v_beta[1] = 1.0;
        cache.insert("alpha", &v_alpha).unwrap();
        cache.insert("beta", &v_beta).unwrap();
        cache.snapshot(&dump_dir).unwrap();

        let restored = HnswCache::empty();
        assert!(restored.restore(&dump_dir).unwrap());
        assert_eq!(restored.len(), 2);
        let hits = restored.search(&v_alpha, 1);
        assert!(!hits.is_empty());
        assert_eq!(hits[0].0, "alpha");
    }

    /// `restore` reports `false` for a directory without a meta file.
    #[test]
    fn restore_returns_false_when_no_snapshot() {
        let dir = TempDir::new().unwrap();
        let cache = HnswCache::empty();
        assert!(!cache.restore(dir.path()).unwrap());
    }

    /// A meta file with an unknown schema version is refused with a message
    /// that names the version field.
    #[test]
    fn restore_rejects_incompatible_schema_version() {
        let dir = TempDir::new().unwrap();
        let dump_dir = dir.path().join("snap");
        std::fs::create_dir_all(&dump_dir).unwrap();
        let bad_meta = serde_json::json!({
            "schema_version": 999,
            "dimension": DEFAULT_DIMENSION,
            "id_to_chunk": ["c1"],
        });
        std::fs::write(
            dump_dir.join("snapshot.meta.json"),
            serde_json::to_vec(&bad_meta).unwrap(),
        )
        .unwrap();
        let cache = HnswCache::empty();
        let err = cache.restore(&dump_dir).unwrap_err();
        let msg = format!("{err:#}");
        assert!(
            msg.contains("schema_version"),
            "expected schema_version mismatch error, got: {msg}"
        );
    }

    /// Snapshotting an empty cache writes only the meta file, and that
    /// snapshot restores to an empty cache.
    #[test]
    fn snapshot_empty_cache_writes_meta_only() {
        let dir = TempDir::new().unwrap();
        let dump_dir = dir.path().join("empty");
        let cache = HnswCache::empty();
        cache.snapshot(&dump_dir).unwrap();
        assert!(dump_dir.join("snapshot.meta.json").exists());
        assert!(!dump_dir.join("snapshot.hnsw.graph").exists());

        let restored = HnswCache::empty();
        assert!(restored.restore(&dump_dir).unwrap());
        assert_eq!(restored.len(), 0);
    }
}