#![allow(clippy::duplicated_attributes)]
#![cfg(feature = "memory-hnsw")]
use std::collections::HashMap;
use std::path::PathBuf;
use std::sync::Arc;
use async_trait::async_trait;
use instant_distance::{Builder, HnswMap, Search};
use serde::{Deserialize, Serialize};
use tokio::sync::RwLock;
use tracing::warn;
use crate::error::Result;
use crate::providers::LLMProvider;
use super::traits::MemorySearcher;
/// Cosine similarity of two vectors, clamped into `[0.0, 1.0]`.
///
/// Returns `0.0` when the slices differ in length, when they are empty, or
/// when either vector has zero magnitude. Negative similarities (vectors
/// pointing in opposing directions) are clamped up to `0.0`.
fn cosine_similarity(a: &[f32], b: &[f32]) -> f32 {
    let comparable = !a.is_empty() && a.len() == b.len();
    if !comparable {
        return 0.0;
    }
    let norm = |v: &[f32]| v.iter().map(|x| x * x).sum::<f32>().sqrt();
    let dot = a.iter().zip(b).map(|(x, y)| x * y).sum::<f32>();
    let (norm_a, norm_b) = (norm(a), norm(b));
    if norm_a == 0.0 || norm_b == 0.0 {
        0.0
    } else {
        (dot / (norm_a * norm_b)).clamp(0.0, 1.0)
    }
}
#[derive(Clone)]
struct EmbeddingPoint(Vec<f32>);
impl instant_distance::Point for EmbeddingPoint {
fn distance(&self, other: &Self) -> f32 {
1.0 - cosine_similarity(&self.0, &other.0)
}
}
/// Persisted embedding store: each memory key mapped to its embedding vector.
///
/// Serialized to and from JSON by `load_vector_store` / `save_vector_store`.
#[derive(Debug, Default, Serialize, Deserialize)]
struct VectorStore {
    /// Embedding vector for each indexed key.
    vectors: HashMap<String, Vec<f32>>,
}
/// Loads the persisted vector store from `path`.
///
/// A missing file is the normal first-run case and quietly yields an empty
/// store. Any other read failure, or content that does not parse as a
/// [`VectorStore`], is logged with `warn!` and also yields an empty store, so
/// construction never fails.
///
/// NOTE(review): after a failed load the in-memory store starts empty, and the
/// next `save_vector_store` call overwrites the unreadable file with it.
fn load_vector_store(path: &std::path::Path) -> VectorStore {
    let content = match std::fs::read_to_string(path) {
        Ok(content) => content,
        Err(e) if e.kind() == std::io::ErrorKind::NotFound => return VectorStore::default(),
        Err(e) => {
            // Permission or I/O problems are not "no store yet"; surface them.
            warn!(
                "Failed to read HNSW vector store at {}: {}",
                path.display(),
                e
            );
            return VectorStore::default();
        }
    };
    serde_json::from_str(&content).unwrap_or_else(|e| {
        warn!(
            "Failed to parse HNSW vector store at {}: {}",
            path.display(),
            e
        );
        VectorStore::default()
    })
}
/// Persists `store` as pretty-printed JSON at `path`, best-effort.
///
/// Creates the parent directory if needed. The JSON is first written to a
/// sibling `<path>.tmp` file and then renamed over `path`. A crash or kill
/// mid-write therefore leaves the previous file intact instead of a torn one,
/// which `load_vector_store` would discard as unparsable. No fsync is issued,
/// so this does not guarantee durability across power loss.
///
/// Failures are logged with `warn!` and never propagated; on a failed write
/// the temp file is removed.
fn save_vector_store(path: &std::path::Path, store: &VectorStore) {
    if let Some(parent) = path.parent() {
        if let Err(e) = std::fs::create_dir_all(parent) {
            warn!("Failed to create HNSW vector store directory: {}", e);
            return;
        }
    }
    let json = match serde_json::to_string_pretty(store) {
        Ok(json) => json,
        Err(e) => {
            warn!("Failed to serialize HNSW vector store: {}", e);
            return;
        }
    };
    let mut tmp_name = path.as_os_str().to_owned();
    tmp_name.push(".tmp");
    let tmp_path = PathBuf::from(tmp_name);
    let written = std::fs::write(&tmp_path, &json).and_then(|()| std::fs::rename(&tmp_path, path));
    if let Err(e) = written {
        warn!(
            "Failed to write HNSW vector store to {}: {}",
            path.display(),
            e
        );
        // Best-effort cleanup; the temp file may not exist if the write itself failed.
        let _ = std::fs::remove_file(&tmp_path);
    }
}
/// Builds an HNSW index over every vector in `store`, mapping each point back
/// to its key. Returns `None` when the store is empty.
///
/// Entries are sorted by key first, so the builder always receives the same
/// input order regardless of `HashMap` iteration order.
fn build_hnsw_index(store: &VectorStore) -> Option<HnswMap<EmbeddingPoint, String>> {
    if store.vectors.is_empty() {
        return None;
    }
    let mut sorted: Vec<(&String, &Vec<f32>)> = store.vectors.iter().collect();
    // Keys are unique (they come from a HashMap), so stability is irrelevant.
    sorted.sort_unstable_by(|left, right| left.0.cmp(right.0));
    let (keys, points): (Vec<String>, Vec<EmbeddingPoint>) = sorted
        .into_iter()
        .map(|(key, vector)| (key.clone(), EmbeddingPoint(vector.clone())))
        .unzip();
    Some(Builder::default().build(points, keys))
}
/// [`MemorySearcher`] that scores memory keys by cosine similarity of their
/// embeddings, using an in-memory HNSW index for nearest-neighbour search.
///
/// Vectors are persisted as JSON at `store_path`. The HNSW index is derived
/// from the store and rebuilt after every `index`/`remove`.
pub struct HnswSearcher {
    /// Source of embeddings for both indexed text and queries.
    provider: Arc<dyn LLMProvider>,
    /// Key -> embedding map; written to `store_path` whenever it changes.
    store: RwLock<VectorStore>,
    /// Nearest-neighbour index built from `store`; `None` while the store is empty.
    index: RwLock<Option<HnswMap<EmbeddingPoint, String>>>,
    /// Location of the JSON file backing `store`.
    store_path: PathBuf,
}
impl HnswSearcher {
    /// Creates a searcher that uses `provider` for embeddings and persists
    /// vectors as JSON at `store_path`.
    ///
    /// Any vectors already saved at `store_path` are loaded and indexed
    /// synchronously, so construction performs file I/O and an index build.
    pub fn new(provider: Arc<dyn LLMProvider>, store_path: PathBuf) -> Self {
        let persisted = load_vector_store(&store_path);
        let initial_index = build_hnsw_index(&persisted);
        Self {
            provider,
            store: RwLock::new(persisted),
            index: RwLock::new(initial_index),
            store_path,
        }
    }

    /// Replaces the HNSW index with one built from the current store contents.
    ///
    /// Lock order is `store` then `index`. The store read guard is held until
    /// the new index is installed, so no writer can change the store (and start
    /// its own rebuild) in between. This means an older snapshot can never
    /// overwrite the index built from a newer one. The build itself runs before
    /// the index write lock is taken, so searches keep using the old index
    /// meanwhile.
    ///
    /// NOTE(review): the build is CPU-bound and runs on the async executor thread.
    async fn rebuild_index(&self) {
        let snapshot = self.store.read().await;
        let fresh = build_hnsw_index(&snapshot);
        *self.index.write().await = fresh;
        drop(snapshot);
    }
}
#[async_trait]
impl MemorySearcher for HnswSearcher {
    /// Backend identifier.
    fn name(&self) -> &str {
        "hnsw"
    }

    /// Synchronous scoring is not supported by this backend and always returns
    /// `0.0`; scoring needs an awaited provider embedding (see `score_batch`).
    fn score(&self, _chunk: &str, _query: &str) -> f32 {
        0.0
    }

    /// Scores each entry of `chunks` against `query`.
    ///
    /// Each element of `chunks` is looked up as an index key (the `key` given to
    /// `index`). NOTE(review): confirm callers pass keys rather than chunk text.
    /// The query is embedded once and the HNSW index is searched. A chunk that
    /// appears among the first `max(chunks.len(), 10)` neighbours scores
    /// `1 - distance` clamped into `[0.0, 1.0]`; every other chunk scores `0.0`.
    ///
    /// Never fails. An empty index, an embedding error, or an empty embedding
    /// all yield one `0.0` per chunk.
    async fn score_batch(&self, chunks: &[&str], query: &str) -> Vec<f32> {
        if chunks.is_empty() {
            return Vec::new();
        }
        // Avoid a (potentially remote, costly) embedding call when there is
        // nothing to search. The guard is a temporary of the `if` condition, so
        // it is released before the `embed` await below.
        if self.index.read().await.is_none() {
            return vec![0.0; chunks.len()];
        }
        let query_vec = match self.provider.embed(&[query.to_string()]).await {
            Ok(vecs) => match vecs.into_iter().next() {
                Some(v) if !v.is_empty() => v,
                _ => {
                    warn!("HNSW: embed() returned no vector; returning zero scores");
                    return vec![0.0; chunks.len()];
                }
            },
            Err(e) => {
                warn!(
                    "HNSW: embedding failed in score_batch: {}; returning zero scores",
                    e
                );
                return vec![0.0; chunks.len()];
            }
        };
        let query_point = EmbeddingPoint(query_vec);
        // Only the index lock is taken here. Taking `store` while holding
        // `index` would invert the store -> index order used by `rebuild_index`.
        // Because tokio's RwLock queues new readers behind a pending writer,
        // that inversion can deadlock against a concurrent `index`/`remove`.
        let index = self.index.read().await;
        let hnsw_map = match index.as_ref() {
            Some(m) => m,
            // The index may have been emptied while the query was being embedded.
            None => return vec![0.0; chunks.len()],
        };
        // The search yields at most the number of indexed points, so no extra
        // cap on the store size is needed.
        let k = chunks.len().max(10);
        let mut search = Search::default();
        let neighbors: HashMap<&str, f32> = hnsw_map
            .search(&query_point, &mut search)
            .take(k)
            .map(|item| (item.value.as_str(), (1.0 - item.distance).clamp(0.0, 1.0)))
            .collect();
        chunks
            .iter()
            .map(|chunk| neighbors.get(chunk).copied().unwrap_or(0.0))
            .collect()
    }

    /// Embeds `text`, stores the vector under `key` (replacing any previous
    /// one), persists the store and rebuilds the index.
    ///
    /// If the provider returns no vector or an empty one, nothing is stored for
    /// `key` and any previous vector for it is dropped, with a warning. An empty
    /// vector would never score above zero and would only pollute the index.
    ///
    /// # Errors
    /// Propagates an error from the provider's `embed` call; the store is left
    /// unchanged in that case.
    ///
    /// NOTE(review): persistence is synchronous file I/O performed while the
    /// store write lock is held.
    async fn index(&self, key: &str, text: &str) -> Result<()> {
        let embeddings = self.provider.embed(&[text.to_string()]).await?;
        let vector = embeddings.into_iter().next().filter(|v| !v.is_empty());
        {
            let mut store = self.store.write().await;
            match vector {
                Some(vector) => {
                    store.vectors.insert(key.to_string(), vector);
                }
                None => {
                    warn!(
                        "HNSW: embed() returned no vector for key {}; not indexing it",
                        key
                    );
                    store.vectors.remove(key);
                }
            }
            save_vector_store(&self.store_path, &store);
        }
        self.rebuild_index().await;
        Ok(())
    }

    /// Removes the vector stored under `key`, persists the store and rebuilds
    /// the index. Removing an absent key is a no-op: the file is not rewritten
    /// and the index is not rebuilt.
    async fn remove(&self, key: &str) -> Result<()> {
        {
            let mut store = self.store.write().await;
            if store.vectors.remove(key).is_none() {
                return Ok(());
            }
            save_vector_store(&self.store_path, &store);
        }
        self.rebuild_index().await;
        Ok(())
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::error::Result as ZResult;
    use crate::providers::{ChatOptions, LLMProvider, LLMResponse, ToolDefinition};
    use crate::session::Message;
    use async_trait::async_trait;
    use std::sync::Arc;

    /// Deterministic fake provider: the i-th text of each `embed` batch maps to
    /// the one-hot vector with `1.0` at position `i % dim` (all zeros if `dim`
    /// is 0). Single-text calls therefore always return the same vector.
    struct FakeHnswProvider {
        dim: usize,
    }

    #[async_trait]
    impl LLMProvider for FakeHnswProvider {
        fn name(&self) -> &str {
            "fake-hnsw"
        }

        fn default_model(&self) -> &str {
            "fake-model"
        }

        async fn chat(
            &self,
            _messages: Vec<Message>,
            _tools: Vec<ToolDefinition>,
            _model: Option<&str>,
            _options: ChatOptions,
        ) -> ZResult<LLMResponse> {
            Ok(LLMResponse::text("ok"))
        }

        async fn embed(&self, texts: &[String]) -> ZResult<Vec<Vec<f32>>> {
            Ok(texts
                .iter()
                .enumerate()
                .map(|(i, _)| {
                    let mut v = vec![0.0f32; self.dim];
                    if !v.is_empty() {
                        v[i % self.dim] = 1.0;
                    }
                    v
                })
                .collect())
        }
    }

    #[test]
    fn test_hnsw_cosine_identical() {
        let v = vec![1.0f32, 2.0, 3.0];
        let score = cosine_similarity(&v, &v);
        assert!(
            (score - 1.0).abs() < 1e-6,
            "Identical vectors should produce similarity 1.0, got {}",
            score
        );
    }

    #[test]
    fn test_hnsw_cosine_orthogonal() {
        let a = vec![1.0f32, 0.0, 0.0];
        let b = vec![0.0f32, 1.0, 0.0];
        let score = cosine_similarity(&a, &b);
        assert!(
            score.abs() < 1e-6,
            "Orthogonal vectors should produce similarity 0.0, got {}",
            score
        );
    }

    #[test]
    fn test_hnsw_searcher_name() {
        // Per-test temp dir: a shared fixed path could pick up leftovers from
        // other runs.
        let dir = tempfile::TempDir::new().unwrap();
        let provider = Arc::new(FakeHnswProvider { dim: 4 });
        let searcher = HnswSearcher::new(provider, dir.path().join("hnsw_vectors.json"));
        assert_eq!(searcher.name(), "hnsw");
    }

    #[test]
    fn test_hnsw_sync_score_returns_zero() {
        let dir = tempfile::TempDir::new().unwrap();
        let provider = Arc::new(FakeHnswProvider { dim: 4 });
        let searcher = HnswSearcher::new(provider, dir.path().join("hnsw_vectors.json"));
        assert_eq!(searcher.score("hello world", "hello"), 0.0);
        assert_eq!(searcher.score("", ""), 0.0);
    }

    #[test]
    fn test_hnsw_vector_persistence() {
        let dir = tempfile::TempDir::new().unwrap();
        let path = dir.path().join("hnsw_vectors.json");
        let mut store = VectorStore::default();
        store.vectors.insert("k1".to_string(), vec![1.0, 0.0]);
        store.vectors.insert("k2".to_string(), vec![0.0, 1.0]);
        save_vector_store(&path, &store);
        let loaded = load_vector_store(&path);
        assert_eq!(loaded.vectors.len(), 2);
        assert_eq!(loaded.vectors["k1"], vec![1.0, 0.0]);
        assert_eq!(loaded.vectors["k2"], vec![0.0, 1.0]);
    }

    #[test]
    fn test_hnsw_load_missing_file_is_empty() {
        let dir = tempfile::TempDir::new().unwrap();
        let loaded = load_vector_store(&dir.path().join("does_not_exist.json"));
        assert!(loaded.vectors.is_empty());
    }

    #[test]
    fn test_hnsw_load_corrupt_file_is_empty() {
        let dir = tempfile::TempDir::new().unwrap();
        let path = dir.path().join("hnsw_vectors.json");
        std::fs::write(&path, "not json").unwrap();
        assert!(load_vector_store(&path).vectors.is_empty());
    }

    #[tokio::test]
    async fn test_hnsw_index_stores_vector() {
        let dir = tempfile::TempDir::new().unwrap();
        let path = dir.path().join("hnsw_vectors.json");
        let provider = Arc::new(FakeHnswProvider { dim: 4 });
        let searcher = HnswSearcher::new(provider, path.clone());
        searcher.index("key:hello", "hello world").await.unwrap();
        let store = load_vector_store(&path);
        assert!(
            store.vectors.contains_key("key:hello"),
            "Expected 'key:hello' in persisted store"
        );
        assert_eq!(store.vectors["key:hello"].len(), 4);
    }

    #[tokio::test]
    async fn test_hnsw_remove_deletes_vector() {
        let dir = tempfile::TempDir::new().unwrap();
        let path = dir.path().join("hnsw_vectors.json");
        let provider = Arc::new(FakeHnswProvider { dim: 4 });
        let searcher = HnswSearcher::new(provider, path.clone());
        searcher.index("key:a", "alpha").await.unwrap();
        searcher.index("key:b", "beta").await.unwrap();
        assert_eq!(load_vector_store(&path).vectors.len(), 2);
        searcher.remove("key:a").await.unwrap();
        let store = load_vector_store(&path);
        assert!(
            !store.vectors.contains_key("key:a"),
            "key:a should be removed"
        );
        assert!(store.vectors.contains_key("key:b"), "key:b should remain");
    }

    #[tokio::test]
    async fn test_hnsw_rebuild_index_search_works() {
        let dir = tempfile::TempDir::new().unwrap();
        let path = dir.path().join("hnsw_vectors.json");
        let provider = Arc::new(FakeHnswProvider { dim: 4 });
        let searcher = HnswSearcher::new(provider, path.clone());
        searcher.index("entry:0", "text-0").await.unwrap();
        searcher.index("entry:1", "text-1").await.unwrap();
        searcher.remove("entry:1").await.unwrap();
        let scores = searcher
            .score_batch(&["entry:0", "entry:1"], "text-0")
            .await;
        assert_eq!(scores.len(), 2);
        // The fake provider embeds every single text to the same one-hot
        // vector, so the surviving entry is identical to the query.
        assert!(
            (scores[0] - 1.0).abs() < 1e-6,
            "Identical embedding should score 1.0, got {}",
            scores[0]
        );
        assert_eq!(scores[1], 0.0, "Removed key must not score");
    }

    #[tokio::test]
    async fn test_hnsw_empty_index_search() {
        let dir = tempfile::TempDir::new().unwrap();
        let path = dir.path().join("hnsw_vectors.json");
        let provider = Arc::new(FakeHnswProvider { dim: 4 });
        let searcher = HnswSearcher::new(provider, path);
        let scores = searcher.score_batch(&["some chunk"], "query text").await;
        assert_eq!(scores.len(), 1);
        assert_eq!(scores[0], 0.0, "Empty index should return 0.0 scores");
    }
}