use std::collections::HashMap;
use std::path::PathBuf;
use std::sync::atomic::{AtomicU64, Ordering};
use anyhow::Result;
use parking_lot::RwLock;
use serde::{Deserialize, Serialize};
use super::l2_normalize_f32;
use super::HnswIndex;
use super::MemoryEntry;
/// One result of a semantic lookup: a memory entry together with two scores.
///
/// NOTE(review): nothing in this part of the file constructs a `SemanticHit`, so the
/// exact meaning and range of the two scores should be confirmed at the construction site.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SemanticHit {
/// The memory entry this hit refers to.
pub entry: MemoryEntry,
/// Distance score for the entry (presumably smaller means closer — TODO confirm).
pub distance: f32,
/// Similarity score for the entry (presumably derived from `distance` — TODO confirm).
pub similarity: f32,
}
/// Wraps an [`HnswIndex`] so vectors can be addressed by string id rather than by the
/// numeric `u64` keys the index itself is called with.
///
/// The index and the two lookup maps each sit behind their own `RwLock`, so an operation
/// that touches more than one of them is not atomic as a whole.
pub struct HnswMemoryIndex {
/// The underlying vector index; entries are added and removed by `u64` key.
index: RwLock<HnswIndex>,
/// Numeric index key -> string entry id (used to translate search results).
key_to_id: RwLock<HashMap<u64, String>>,
/// String entry id -> numeric index key (the reverse of `key_to_id`).
id_to_key: RwLock<HashMap<String, u64>>,
/// Next numeric key to hand out. Starts at 1 for a fresh index, or one past the
/// largest restored key when the index is loaded from disk.
next_key: AtomicU64,
/// Directory that holds `memory.usearch` and `key_map.json`. When `None`, `persist`
/// does nothing and `restore_or_new` always creates a fresh index.
persist_path: Option<PathBuf>,
}
/// Summary-only `Debug` output: prints the entry count and the vector dimensionality
/// instead of dumping the index contents or the key maps.
impl std::fmt::Debug for HnswMemoryIndex {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // `len()` takes and releases its own read lock, so call it before taking the
        // guard below rather than while holding it.
        let size = self.len();
        let guard = self.index.read();
        let dimensions = guard.dimensions();

        let mut out = f.debug_struct("HnswMemoryIndex");
        out.field("size", &size);
        out.field("dimensions", &dimensions);
        out.finish()
    }
}
impl HnswMemoryIndex {
/// Creates an empty index for vectors of `dimensions` components.
///
/// `capacity` is forwarded to `HnswIndex::new`. `persist_path` is the directory later
/// used by `persist`; pass `None` to disable persistence. Key numbering starts at 1.
///
/// # Errors
///
/// Propagates any error returned by `HnswIndex::new`.
pub fn new(dimensions: usize, capacity: usize, persist_path: Option<PathBuf>) -> Result<Self> {
    let inner = HnswIndex::new(dimensions, capacity)?;
    let fresh = Self {
        persist_path,
        next_key: AtomicU64::new(1),
        id_to_key: RwLock::new(HashMap::new()),
        key_to_id: RwLock::new(HashMap::new()),
        index: RwLock::new(inner),
    };
    Ok(fresh)
}
/// Restores an index from the files under `persist_path`, or builds an empty one.
///
/// A restore is attempted only when `persist_path` is `Some` and both `memory.usearch`
/// and `key_map.json` exist in that directory. If both exist but loading fails, the
/// cause is logged as a warning and a fresh index is created instead; nothing on disk
/// is modified here.
///
/// `dimensions` and `capacity` are used only when a fresh index is created.
///
/// # Errors
///
/// Returns an error only if creating the fresh index fails.
pub fn restore_or_new(
    dimensions: usize,
    capacity: usize,
    persist_path: Option<PathBuf>,
) -> Result<Self> {
    if let Some(dir) = persist_path.as_deref() {
        match Self::try_restore(dir) {
            Ok(Some(restored)) => return Ok(restored),
            // No snapshot on disk yet: nothing to report.
            Ok(None) => {}
            // Surface the reason instead of silently discarding it.
            Err(reason) => tracing::warn!(
                error = %reason,
                "Failed to restore HNSW index, creating new one"
            ),
        }
    }
    Self::new(dimensions, capacity, persist_path)
}

/// Loads the persisted index and key maps from `dir`.
///
/// Returns `Ok(None)` when either snapshot file is missing, and `Err` when both exist
/// but the index cannot be loaded or the key map cannot be read or parsed.
fn try_restore(dir: &std::path::Path) -> Result<Option<Self>> {
    let index_path = dir.join("memory.usearch");
    let mapping_path = dir.join("key_map.json");
    if !index_path.exists() || !mapping_path.exists() {
        return Ok(None);
    }

    tracing::info!(path = %index_path.display(), "Restoring HNSW index from disk");
    let index = HnswIndex::load(&index_path)?;
    let data = std::fs::read_to_string(&mapping_path)?;
    let (k2i, i2k): (HashMap<u64, String>, HashMap<String, u64>) =
        serde_json::from_str(&data)?;

    // Resume numbering past every key mentioned in either map, so a key map whose two
    // halves disagree cannot cause a freshly issued key to collide with a restored one.
    let max_key = k2i.keys().chain(i2k.values()).copied().max().unwrap_or(0);

    Ok(Some(Self {
        index: RwLock::new(index),
        key_to_id: RwLock::new(k2i),
        id_to_key: RwLock::new(i2k),
        next_key: AtomicU64::new(max_key + 1),
        persist_path: Some(dir.to_path_buf()),
    }))
}
/// Returns the numeric key for `id`, assigning a new one if `id` has none yet.
///
/// The lookup is first tried under a read lock. On a miss, both maps are write-locked
/// (`id_to_key` first, then `key_to_id`) and the lookup is repeated, because another
/// thread may have inserted `id` between the two lock acquisitions.
fn get_or_create_key(&self, id: &str) -> u64 {
    // Fast path: the read guard is a temporary and is released at the end of this
    // statement, before any write lock is requested.
    let cached = self.id_to_key.read().get(id).copied();
    if let Some(existing) = cached {
        return existing;
    }

    let mut by_id = self.id_to_key.write();
    let mut by_key = self.key_to_id.write();
    match by_id.get(id).copied() {
        // Another caller registered `id` while we were waiting for the write locks.
        Some(raced) => raced,
        None => {
            let fresh = self.next_key.fetch_add(1, Ordering::Relaxed);
            by_id.insert(id.to_string(), fresh);
            by_key.insert(fresh, id.to_string());
            fresh
        }
    }
}
/// Adds `vector` to the index under the string `id`.
///
/// `id` is mapped to a numeric key (reusing the existing key if `id` is already known),
/// and a copy of `vector` is passed through `l2_normalize_f32` before being handed to
/// the index. The caller's slice is not modified.
///
/// # Errors
///
/// Propagates any error returned by the underlying index's `add`.
pub fn add_entry(&self, id: &str, vector: &[f32]) -> Result<()> {
    let key = self.get_or_create_key(id);

    let unit = {
        let mut copy = vector.to_vec();
        l2_normalize_f32(&mut copy);
        copy
    };

    let mut index = self.index.write();
    index.add(key, &unit)?;
    Ok(())
}
/// Removes the entry registered under `id` from the index and from both key maps.
///
/// Does nothing and returns `Ok(())` when `id` is unknown.
///
/// # Errors
///
/// Propagates any error returned by the underlying index's `remove`; in that case the
/// key maps are left untouched.
pub fn remove_entry(&self, id: &str) -> Result<()> {
    // The read guard is a temporary and is released at the end of this statement.
    let key = self.id_to_key.read().get(id).copied();

    if let Some(key) = key {
        self.index.write().remove(key)?;

        // Take the map locks in the same order as `get_or_create_key` (`id_to_key`,
        // then `key_to_id`). Acquiring them in the opposite order lets a concurrent
        // add and remove each hold one lock while waiting for the other.
        let mut i2k = self.id_to_key.write();
        let mut k2i = self.key_to_id.write();
        i2k.remove(id);
        k2i.remove(&key);
    }
    Ok(())
}
/// Searches the index for the neighbours of `query` and reports them by string id.
///
/// A copy of `query` is passed through `l2_normalize_f32`, the index is asked for `k`
/// results, and each returned key is translated back to its id. Results whose key has
/// no id mapping are dropped, so fewer entries than the index returned may come back.
/// The `f32` in each pair is the value the index reported for that key (presumably a
/// distance — confirm against `HnswIndex::search`).
///
/// # Errors
///
/// Propagates any error returned by the underlying index's `search`.
pub fn search(&self, query: &[f32], k: usize) -> Result<Vec<(String, f32)>> {
    let mut unit_query = query.to_vec();
    l2_normalize_f32(&mut unit_query);

    // The index read guard is a temporary; it is released before the map is locked.
    let neighbours = self.index.read().search(&unit_query, k)?;

    let ids = self.key_to_id.read();
    let mut hits = Vec::new();
    for (key, dist) in neighbours {
        if let Some(id) = ids.get(&key) {
            hits.push((id.clone(), dist));
        }
    }
    Ok(hits)
}
/// Returns the entry count reported by the underlying index (not the size of the key
/// maps).
pub fn len(&self) -> usize {
    let index = self.index.read();
    index.len()
}
/// Returns whether the underlying index reports itself as empty.
pub fn is_empty(&self) -> bool {
    let index = self.index.read();
    index.is_empty()
}
/// Writes the index and the key maps to the persistence directory.
///
/// Creates the directory if needed, saves the index to `memory.usearch`, and writes
/// both key maps as a JSON pair `[key_to_id, id_to_key]` to `key_map.json`. Does
/// nothing and returns `Ok(())` when no persistence path was configured.
///
/// The index and the maps are captured under separate lock acquisitions, so writes
/// that happen in between can make the two files disagree slightly.
///
/// # Errors
///
/// Propagates directory-creation, index-save, serialization and file-write errors.
pub fn persist(&self) -> Result<()> {
    if let Some(ref path) = self.persist_path {
        std::fs::create_dir_all(path)?;
        let index_path = path.join("memory.usearch");
        let mapping_path = path.join("key_map.json");
        self.index.read().save(&index_path)?;

        // Lock order matches `get_or_create_key` (`id_to_key`, then `key_to_id`) so a
        // concurrent writer cannot hold one map while this reader holds the other.
        let i2k = self.id_to_key.read();
        let k2i = self.key_to_id.read();
        // Serialize by reference: same JSON as before, without copying `key_to_id`.
        let data = serde_json::to_string(&(&*k2i, &*i2k))?;
        // Release the map locks before doing disk I/O.
        drop(k2i);
        drop(i2k);

        std::fs::write(&mapping_path, data)?;
        tracing::debug!(path = %path.display(), entries = self.len(), "HNSW index persisted");
    }
    Ok(())
}
}