use crate::index::HNSWIndex;
use crate::primitives::Vector;
use std::fs;
use std::path::{Path, PathBuf};
/// Failure modes surfaced by [`PersistentHnsw`] when loading or persisting an index.
#[derive(Debug)]
pub enum PersistentHnswError {
/// A filesystem operation failed; wraps the underlying [`std::io::Error`].
Io(std::io::Error),
/// `bincode` reported an error; wraps the underlying [`bincode::Error`].
///
/// Despite the name, serialization failures raised while flushing are
/// converted into this variant as well, because they share the same error type.
Decode(bincode::Error),
}
impl std::fmt::Display for PersistentHnswError {
    /// Renders the error as `<kind>: <cause>`, where `<kind>` is `io` or `decode`
    /// and `<cause>` is the wrapped error's own `Display` output.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // Pick the label and the wrapped cause first so the message is formatted once.
        let (label, cause) = match self {
            Self::Io(err) => ("io", err as &dyn std::fmt::Display),
            Self::Decode(err) => ("decode", err as &dyn std::fmt::Display),
        };
        write!(f, "{label}: {cause}")
    }
}
impl std::error::Error for PersistentHnswError {
    /// Exposes the wrapped error as the source; every variant carries one,
    /// so this never returns `None`.
    fn source(&self) -> Option<&(dyn std::error::Error + 'static)> {
        let cause: &(dyn std::error::Error + 'static) = match self {
            Self::Io(io_err) => io_err,
            Self::Decode(codec_err) => codec_err,
        };
        Some(cause)
    }
}
impl From<std::io::Error> for PersistentHnswError {
fn from(e: std::io::Error) -> Self {
Self::Io(e)
}
}
impl From<bincode::Error> for PersistentHnswError {
fn from(e: bincode::Error) -> Self {
Self::Decode(e)
}
}
/// An [`HNSWIndex`] paired with the file path it is loaded from and flushed to.
pub struct PersistentHnsw {
/// The in-memory index that all queries and insertions are forwarded to.
inner: HNSWIndex,
/// Location of the serialized index on disk.
path: PathBuf,
/// Set by `add` and cleared by a successful `flush`.
dirty: bool,
}
impl PersistentHnsw {
pub fn open<P: AsRef<Path>>(
path: P,
m: usize,
ef_construction: usize,
) -> Result<Self, PersistentHnswError> {
let path = path.as_ref().to_path_buf();
let inner = if path.is_file() {
let bytes = fs::read(&path)?;
bincode::deserialize::<HNSWIndex>(&bytes)?
} else {
HNSWIndex::new(m, ef_construction, 0.0)
};
Ok(Self {
inner,
path,
dirty: false,
})
}
pub fn add(&mut self, item_id: impl Into<String>, vector: Vector<f64>) {
self.inner.add(item_id, vector);
self.dirty = true;
}
#[must_use]
pub fn search(&self, query: &Vector<f64>, k: usize) -> Vec<(String, f64)> {
self.inner.search(query, k)
}
#[must_use]
pub fn len(&self) -> usize {
self.inner.len()
}
#[must_use]
pub fn is_empty(&self) -> bool {
self.inner.is_empty()
}
#[must_use]
pub fn is_dirty(&self) -> bool {
self.dirty
}
pub fn flush(&mut self) -> Result<(), PersistentHnswError> {
use std::io::Write;
let bytes = bincode::serialize(&self.inner)?;
let tmp = Self::tmp_path(&self.path);
{
let mut f = fs::File::create(&tmp)?;
f.write_all(&bytes)?;
f.sync_all()?;
}
fs::rename(&tmp, &self.path)?;
self.dirty = false;
Ok(())
}
pub(crate) fn tmp_path(path: &Path) -> PathBuf {
let mut s = path.as_os_str().to_os_string();
s.push(".tmp");
PathBuf::from(s)
}
#[must_use]
pub fn inner(&self) -> &HNSWIndex {
&self.inner
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use tempfile::tempdir;

    /// Inserts four 3-dimensional vectors keyed "a" through "d", in that order.
    fn fixture(idx: &mut PersistentHnsw) {
        let entries: [(&str, [f64; 3]); 4] = [
            ("a", [1.0, 0.0, 0.0]),
            ("b", [0.0, 1.0, 0.0]),
            ("c", [0.0, 0.0, 1.0]),
            ("d", [0.5, 0.5, 0.0]),
        ];
        for (id, coords) in entries {
            idx.add(id, Vector::from_slice(&coords));
        }
    }

    /// Opening a path with no file behind it yields an empty, clean index.
    #[test]
    fn open_creates_empty_when_path_does_not_exist() {
        let scratch = tempdir().unwrap();
        let index_file = scratch.path().join("new.bin");

        let handle = PersistentHnsw::open(&index_file, 8, 64).unwrap();

        assert!(handle.is_empty());
        assert!(!handle.is_dirty());
    }

    /// `add` sets the dirty flag; `flush` clears it and creates the file.
    #[test]
    fn add_marks_dirty_flush_clears() {
        let scratch = tempdir().unwrap();
        let index_file = scratch.path().join("dirty.bin");
        let mut handle = PersistentHnsw::open(&index_file, 8, 64).unwrap();

        handle.add("x", Vector::from_slice(&[1.0, 0.0]));
        assert!(handle.is_dirty());

        handle.flush().unwrap();
        assert!(!handle.is_dirty());
        assert!(index_file.is_file());
    }

    /// Search results and length survive a flush followed by a reopen.
    #[test]
    fn flush_then_reopen_preserves_search_byte_stable() {
        let scratch = tempdir().unwrap();
        let index_file = scratch.path().join("rt.bin");
        let query = Vector::from_slice(&[0.9, 0.1, 0.0]);

        let mut writer = PersistentHnsw::open(&index_file, 8, 64).unwrap();
        fixture(&mut writer);
        let before = writer.search(&query, 3);
        writer.flush().unwrap();
        drop(writer);

        let reader = PersistentHnsw::open(&index_file, 8, 64).unwrap();
        let after = reader.search(&query, 3);

        assert_eq!(before, after);
        assert_eq!(reader.len(), 4);
    }

    /// A file that is not a valid payload produces a `Decode` error, not a panic.
    #[test]
    fn open_after_decode_failure_returns_error_not_panic() {
        let scratch = tempdir().unwrap();
        let index_file = scratch.path().join("garbage.bin");
        fs::write(&index_file, b"not a bincode payload").unwrap();

        let outcome = PersistentHnsw::open(&index_file, 8, 64);

        assert!(matches!(outcome, Err(PersistentHnswError::Decode(_))));
    }
}