use std::collections::HashSet;
use std::fs::OpenOptions;
use std::io::BufReader;
use std::path::{Path, PathBuf};
use hnsw_rs::hnswio::load_description;
use hnsw_rs::prelude::{DistCosine, Hnsw, HnswIo, Neighbour};
use parking_lot::RwLock;
use solo_core::{Error, Result, VectorIndex, VectorIndexFactory};
/// Chooses the HNSW search `ef` (candidate-list size) for a query asking for `knbn` neighbours.
///
/// Uses four candidates per requested neighbour, clamped to `[16, 200]`. The multiplication
/// saturates, so an extreme `knbn` cannot overflow. An overflow would panic in debug builds and
/// wrap to a tiny value in release builds.
fn ef_for_search(knbn: usize) -> usize {
    knbn.saturating_mul(4).clamp(16, 200)
}
/// Construction parameters for an [`HnswIndex`].
#[derive(Debug, Clone, Copy)]
pub struct HnswParams {
    /// Maximum number of links per graph node.
    pub max_nb_connection: usize,
    /// Candidate-list size used while inserting points into the graph.
    pub ef_construction: usize,
    /// Maximum number of graph layers.
    pub max_layer: usize,
    /// Expected element count, passed to `Hnsw::new` as its capacity hint.
    pub max_elements_hint: usize,
}

impl Default for HnswParams {
    /// 16 links per node, `ef_construction` 200, 16 layers, sized for ~10k elements.
    fn default() -> Self {
        let max_nb_connection = 16;
        let ef_construction = 200;
        let max_layer = 16;
        let max_elements_hint = 10_000;
        Self {
            max_nb_connection,
            ef_construction,
            max_layer,
            max_elements_hint,
        }
    }
}
/// The concrete hnsw_rs graph type: `f32` vectors compared with cosine distance.
/// The `'static` lifetime lets `HnswIndex` hold the graph without a lifetime parameter of its own.
type Inner = Hnsw<'static, f32, DistCosine>;
/// Approximate-nearest-neighbour index over `f32` embeddings keyed by non-negative `i64` rowids.
///
/// Deletions are soft. `remove` records the rowid in `tombstones`, and searches filter it
/// out, while the vector itself stays in the underlying graph.
pub struct HnswIndex {
// Underlying hnsw_rs graph holding every inserted vector (including tombstoned ones).
inner: Inner,
// Embedding dimension that every `add` and `search` call must match.
dim: usize,
// Rowids removed since they were last added; excluded from search results.
tombstones: RwLock<HashSet<i64>>,
}
impl std::fmt::Debug for HnswIndex {
    /// Summarises the index as its dimension, raw graph point count, and tombstone count.
    /// The vectors themselves are never printed.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let point_count = self.inner.get_nb_point();
        let tombstone_count = self.tombstones.read().len();
        let mut summary = f.debug_struct("HnswIndex");
        summary.field("dim", &self.dim);
        summary.field("len", &point_count);
        summary.field("tombstones", &tombstone_count);
        summary.finish()
    }
}
impl HnswIndex {
pub fn new(dim: usize, params: HnswParams) -> Self {
let inner = Hnsw::<'static, f32, DistCosine>::new(
params.max_nb_connection,
params.max_elements_hint,
params.max_layer,
params.ef_construction,
DistCosine,
);
Self {
inner,
dim,
tombstones: RwLock::new(HashSet::new()),
}
}
pub(crate) fn from_inner(inner: Inner, dim: usize) -> Self {
Self {
inner,
dim,
tombstones: RwLock::new(HashSet::new()),
}
}
pub(crate) fn inner(&self) -> &Inner {
&self.inner
}
pub(crate) fn raw_len(&self) -> usize {
self.inner.get_nb_point()
}
}
impl VectorIndex for HnswIndex {
fn add(&self, rowid: i64, embedding: &[f32]) -> Result<()> {
if rowid < 0 {
return Err(Error::vector_index(format!(
"rowid must be non-negative; got {rowid}"
)));
}
if embedding.len() != self.dim {
return Err(Error::vector_index(format!(
"embedding dim mismatch: index dim={}, got {}",
self.dim,
embedding.len()
)));
}
self.inner.insert((embedding, rowid as usize));
self.tombstones.write().remove(&rowid);
Ok(())
}
fn remove(&self, rowid: i64) -> Result<()> {
self.tombstones.write().insert(rowid);
Ok(())
}
fn search(&self, query: &[f32], k: usize) -> Result<Vec<(i64, f32)>> {
if query.len() != self.dim {
return Err(Error::vector_index(format!(
"query dim mismatch: index dim={}, got {}",
self.dim,
query.len()
)));
}
if k == 0 {
return Ok(Vec::new());
}
let ef = ef_for_search(k);
let widened = (k * 2).min(k + 32);
let neighbours: Vec<Neighbour> = self.inner.search(query, widened, ef);
let tombs = self.tombstones.read();
let mut out: Vec<(i64, f32)> = neighbours
.into_iter()
.map(|n| (n.d_id as i64, n.distance))
.filter(|(rowid, _)| !tombs.contains(rowid))
.take(k)
.collect();
out.sort_by(|a, b| a.1.partial_cmp(&b.1).unwrap_or(std::cmp::Ordering::Equal));
Ok(out)
}
fn save(&self, path: &Path) -> Result<()> {
crate::snapshot::save(self, path)
}
fn len(&self) -> usize {
let total = self.inner.get_nb_point();
let tomb = self.tombstones.read().len();
total.saturating_sub(tomb)
}
fn dim(&self) -> usize {
self.dim
}
}
/// Factory that creates or loads [`HnswIndex`] instances with one shared [`HnswParams`].
#[derive(Debug, Clone, Default)]
pub struct HnswFactory {
    params: HnswParams,
}

impl HnswFactory {
    /// Factory using [`HnswParams::default`].
    pub fn new() -> Self {
        Self::with_params(HnswParams::default())
    }

    /// Factory using caller-supplied construction parameters.
    pub fn with_params(params: HnswParams) -> Self {
        Self { params }
    }
}
impl VectorIndexFactory for HnswFactory {
    type Index = HnswIndex;

    /// Builds an empty index of dimension `dim` using this factory's parameters.
    fn create(&self, dim: usize) -> Result<Self::Index> {
        Ok(HnswIndex::new(dim, self.params))
    }

    /// Loads an index snapshot from `path`, falling back to its `.bak` copy.
    ///
    /// This factory's parameters are not consulted when loading. If the primary snapshot
    /// fails, that error is logged at `warn` level and the outcome of loading the backup
    /// (its index or its error) is returned.
    fn load(&self, path: &Path) -> Result<Self::Index> {
        crate::snapshot::load(path).or_else(|primary_err| {
            tracing::warn!(
                error = %primary_err,
                "primary HNSW snapshot failed to load; trying .bak"
            );
            crate::snapshot::load_bak(path)
        })
    }
}
pub(crate) fn load_inner_from_basename(
dir: &Path,
basename: &str,
) -> Result<HnswIndex> {
let mut graph_path = PathBuf::from(dir);
graph_path.push(format!("{basename}.hnsw.graph"));
let dim = peek_dim(&graph_path)?;
let io: &'static mut HnswIo = Box::leak(Box::new(HnswIo::new(dir, basename)));
let inner: Inner = io
.load_hnsw::<f32, DistCosine>()
.map_err(|e| Error::vector_index(format!("HnswIo::load_hnsw: {e}")))?;
Ok(HnswIndex::from_inner(inner, dim))
}
fn peek_dim(graph_path: &Path) -> Result<usize> {
let f = OpenOptions::new()
.read(true)
.open(graph_path)
.map_err(|e| Error::vector_index(format!("open {graph_path:?}: {e}")))?;
let mut buf = BufReader::new(f);
let descr = load_description(&mut buf)
.map_err(|e| Error::vector_index(format!("load_description {graph_path:?}: {e}")))?;
Ok(descr.dimension)
}
// Unit tests for `HnswIndex` and `HnswFactory`.
#[cfg(test)]
mod tests {
use super::*;
use std::sync::Arc;
// Deterministic test vector: sin(seed*0.123 + i*0.317) per component, scaled to unit length
// (the norm is floored at 1e-9 to avoid dividing by zero).
fn unit_vec(seed: u32, dim: usize) -> Vec<f32> {
let mut v = vec![0.0f32; dim];
let s = (seed as f32) * 0.123;
for (i, x) in v.iter_mut().enumerate() {
let t = s + i as f32 * 0.317;
*x = t.sin();
}
let norm: f32 = v.iter().map(|x| x * x).sum::<f32>().sqrt().max(1e-9);
for x in &mut v {
*x /= norm;
}
v
}
// A new index reports zero length, the requested dim, and returns no hits.
#[test]
fn fresh_index_is_empty_and_searchable() {
let idx = HnswIndex::new(8, HnswParams::default());
assert_eq!(idx.len(), 0);
assert!(idx.is_empty());
assert_eq!(idx.dim(), 8);
let res = idx.search(&unit_vec(1, 8), 5).unwrap();
assert!(res.is_empty());
}
// Searching with an inserted vector returns that vector's rowid first.
#[test]
fn add_and_search_finds_self() {
let idx = HnswIndex::new(8, HnswParams::default());
let v = unit_vec(7, 8);
idx.add(42, &v).unwrap();
assert_eq!(idx.len(), 1);
let hits = idx.search(&v, 3).unwrap();
assert!(!hits.is_empty(), "search returned no results");
assert_eq!(hits[0].0, 42, "self-search must return rowid 42 first");
}
// Both `add` and `search` reject vectors whose length differs from the index dim.
#[test]
fn dim_mismatch_is_rejected() {
let idx = HnswIndex::new(8, HnswParams::default());
let err = idx.add(1, &vec![0.0; 4]).unwrap_err();
assert!(err.to_string().contains("dim mismatch"));
let err = idx.search(&vec![0.0; 4], 1).unwrap_err();
assert!(err.to_string().contains("dim mismatch"));
}
// Negative rowids are refused by `add`.
#[test]
fn negative_rowid_rejected() {
let idx = HnswIndex::new(4, HnswParams::default());
let err = idx.add(-1, &unit_vec(1, 4)).unwrap_err();
assert!(err.to_string().contains("non-negative"));
}
// A removed rowid is filtered from search results and subtracted from `len`.
#[test]
fn remove_tombstones_filtered_from_search() {
let idx = HnswIndex::new(8, HnswParams::default());
idx.add(1, &unit_vec(1, 8)).unwrap();
idx.add(2, &unit_vec(2, 8)).unwrap();
idx.add(3, &unit_vec(3, 8)).unwrap();
idx.remove(2).unwrap();
let hits = idx.search(&unit_vec(2, 8), 3).unwrap();
assert!(
!hits.iter().any(|(r, _)| *r == 2),
"tombstoned rowid 2 must not appear: {hits:?}"
);
assert_eq!(idx.len(), 2);
}
// Re-adding a removed rowid clears its tombstone so it is searchable again.
#[test]
fn re_add_lifts_tombstone() {
let idx = HnswIndex::new(8, HnswParams::default());
idx.add(5, &unit_vec(5, 8)).unwrap();
idx.remove(5).unwrap();
idx.add(5, &unit_vec(5, 8)).unwrap();
let hits = idx.search(&unit_vec(5, 8), 3).unwrap();
assert!(
hits.iter().any(|(r, _)| *r == 5),
"re-added rowid must reappear: {hits:?}"
);
}
// The factory's `create` yields an empty index of the requested dimension.
#[test]
fn factory_create_returns_empty_index() {
let factory = HnswFactory::new();
let idx = factory.create(16).unwrap();
assert_eq!(idx.dim(), 16);
assert_eq!(idx.len(), 0);
}
// The index can be shared as `Arc<dyn VectorIndex + Send + Sync>` and written from another
// thread; the write is visible after `join`.
#[test]
fn shareable_across_threads_via_arc() {
let idx: Arc<dyn VectorIndex + Send + Sync> =
Arc::new(HnswIndex::new(4, HnswParams::default()));
let idx2 = idx.clone();
let h = std::thread::spawn(move || {
idx2.add(1, &unit_vec(1, 4)).unwrap();
});
h.join().unwrap();
assert_eq!(idx.len(), 1);
}
// Probe of hnsw_rs behavior: inserting two orthogonal vectors under the same rowid keeps
// both points (len() == 2) rather than overwriting or erroring. If this starts failing,
// the library changed and the dependent docs named in the assertion need updating.
#[test]
fn hnsw_rs_accepts_duplicate_origin_id_silently() {
let dim = 8usize;
let idx = HnswIndex::new(dim, HnswParams::default());
let mut vec_a = vec![0.0f32; dim];
vec_a[0] = 1.0;
let mut vec_b = vec![0.0f32; dim];
vec_b[1] = 1.0;
let dot: f32 = vec_a.iter().zip(vec_b.iter()).map(|(a, b)| a * b).sum();
assert_eq!(dot, 0.0, "test vectors must be orthogonal");
idx.add(1, &vec_a).unwrap();
idx.add(1, &vec_b).unwrap();
let len = idx.len();
eprintln!(
"hnsw_rs duplicate-add probe: after add(1, vec_a) + add(1, vec_b), len() = {len}"
);
assert_eq!(
len, 2,
"hnsw_rs 0.3.4 accepts duplicate origin_id silently; \
observed len()={len}. If this assertion fails, the lib's \
behavior has changed (case b 'overwrite' would give len=1, \
case c 'error' would have errored above). Update \
docs/dev-log/0084-... and the hnsw_id module docs."
);
// Only requires at least one hit for the duplicated rowid, so it tolerates search
// returning either one or both of the stored vectors.
let hits = idx.search(&vec_a, 5).unwrap();
eprintln!("hnsw_rs duplicate-add probe: search returned {} hits: {hits:?}", hits.len());
let dup_rowid_hits = hits.iter().filter(|(r, _)| *r == 1).count();
assert!(
dup_rowid_hits >= 1,
"search must return at least one hit with the duplicated origin_id; got {dup_rowid_hits} (hits={hits:?})"
);
}
}