use std::collections::HashMap;
use std::sync::atomic::{AtomicU64, Ordering};
use std::sync::Arc;
use anyhow::{anyhow, Result};
use async_trait::async_trait;
use tokio::sync::RwLock;
use usearch::{Index, IndexOptions, MetricKind, ScalarKind};
/// Capacity hint passed by `UsearchStore::new`, and the lower bound applied to the
/// initial reservation in `UsearchStore::with_capacity_hint`.
const INITIAL_CAPACITY: usize = 1_024;
/// Element cap returned by `hnsw_max_elements` when the `TRUSTY_MAX_CHUNKS`
/// environment variable is unset, does not parse as `usize`, or is zero.
const DEFAULT_HNSW_MAX_ELEMENTS: usize = 1_000_000;
fn hnsw_max_elements() -> usize {
std::env::var("TRUSTY_MAX_CHUNKS")
.ok()
.and_then(|v| v.parse().ok())
.filter(|&n: &usize| n > 0)
.unwrap_or(DEFAULT_HNSW_MAX_ELEMENTS)
}
/// A single match produced by a vector search.
#[derive(Clone, Debug)]
pub struct VectorHit {
    /// Identifier under which the matching vector was stored.
    pub chunk_id: String,
    /// Match score. `UsearchStore::search` fills this with `1.0 - distance`,
    /// so larger values mean closer matches there.
    pub score: f32,
}
/// Async interface to a store of embeddings addressed by string ids.
///
/// Implementors must be `Send + Sync` so a store can be shared across tasks.
#[async_trait]
#[allow(clippy::len_without_is_empty)]
pub trait VectorStore: Send + Sync {
    /// Stores `embedding` under `id`.
    async fn upsert(&self, id: &str, embedding: Vec<f32>) -> Result<()>;

    /// Returns up to `top_k` hits for `query`.
    async fn search(&self, query: &[f32], top_k: usize) -> Result<Vec<VectorHit>>;

    /// Removes the vector stored under `id`.
    async fn remove(&self, id: &str) -> Result<()>;

    /// Returns the number of stored vectors.
    async fn len(&self) -> Result<usize>;

    /// Stores every `(id, embedding)` pair in `items`.
    ///
    /// The default implementation calls [`VectorStore::upsert`] once per item,
    /// in order, cloning each embedding, and stops at the first error. Items
    /// processed before the failing one stay stored.
    async fn upsert_batch(&self, items: &[(String, Vec<f32>)]) -> Result<()> {
        for entry in items.iter() {
            let (id, embedding) = entry;
            self.upsert(id.as_str(), embedding.clone()).await?;
        }
        Ok(())
    }
}
/// [`VectorStore`] backed by a usearch `Index`.
///
/// usearch addresses vectors by `u64` keys, so the store keeps a two-way mapping
/// between caller-supplied string ids and the numeric keys it hands out.
pub struct UsearchStore {
/// The usearch index, guarded by an async read/write lock.
index: Arc<RwLock<Index>>,
/// Maps a string id to the numeric key used inside the index.
id_to_key: Arc<RwLock<HashMap<String, u64>>>,
/// Reverse mapping, used by `search` to turn matched keys back into ids.
key_to_id: Arc<RwLock<HashMap<u64, String>>>,
/// Next key to hand out; initialised to 1 and advanced with `fetch_add`.
next_key: Arc<AtomicU64>,
/// Vector dimensionality; `upsert` and `search` reject inputs of any other length.
dim: usize,
}
impl UsearchStore {
    /// Creates a store for `dim`-dimensional vectors using `INITIAL_CAPACITY`
    /// as the capacity hint.
    pub fn new(dim: usize) -> Result<Self> {
        Self::with_capacity_hint(dim, INITIAL_CAPACITY)
    }

    /// Creates a store for `dim`-dimensional vectors sized for roughly
    /// `expected_chunks` entries.
    ///
    /// The index uses the cosine metric with `f32` storage. The initial
    /// reservation is `expected_chunks`, raised to at least `INITIAL_CAPACITY`
    /// and then capped at `hnsw_max_elements()`.
    ///
    /// # Errors
    ///
    /// Returns an error if usearch fails to create the index or to reserve the
    /// initial capacity.
    pub fn with_capacity_hint(dim: usize, expected_chunks: usize) -> Result<Self> {
        // Explicit HNSW parameters are only supplied for large indexes; otherwise
        // zeros are passed (presumably selecting usearch defaults — confirm
        // against the usearch documentation).
        let (connectivity, expansion_add, expansion_search) = match expected_chunks {
            n if n > 50_000 => (32, 128, 64),
            _ => (0, 0, 0),
        };

        let index = Index::new(&IndexOptions {
            dimensions: dim,
            metric: MetricKind::Cos,
            quantization: ScalarKind::F32,
            connectivity,
            expansion_add,
            expansion_search,
            multi: false,
        })
        .map_err(|e| anyhow!("usearch Index::new failed: {e}"))?;

        // Floor first, then cap: the configured cap wins if it is below the floor.
        let floor = std::cmp::max(expected_chunks, INITIAL_CAPACITY);
        let initial = std::cmp::min(floor, hnsw_max_elements());
        index
            .reserve(initial)
            .map_err(|e| anyhow!("usearch reserve failed: {e}"))?;

        Ok(Self {
            index: Arc::new(RwLock::new(index)),
            id_to_key: Arc::new(RwLock::new(HashMap::new())),
            key_to_id: Arc::new(RwLock::new(HashMap::new())),
            next_key: Arc::new(AtomicU64::new(1)),
            dim,
        })
    }

    /// Returns the dimensionality this store was created with.
    pub fn dim(&self) -> usize {
        self.dim
    }

    /// Makes sure `index` has room for one more element.
    ///
    /// When the index is full, its capacity is doubled, capped at
    /// `hnsw_max_elements()`.
    ///
    /// # Errors
    ///
    /// Returns an error if the index already holds `hnsw_max_elements()` or more
    /// elements, or if usearch fails to reserve the larger capacity.
    fn ensure_capacity(index: &Index) -> Result<()> {
        let used = index.size();
        let reserved = index.capacity();
        let limit = hnsw_max_elements();

        if used >= limit {
            return Err(anyhow!(
                "usearch index at TRUSTY_MAX_CHUNKS cap ({} elements) — refusing further upserts",
                limit
            ));
        }
        if used < reserved {
            // At least one free slot is left; nothing to do.
            return Ok(());
        }

        // `max(1)` keeps doubling meaningful even for a zero-capacity index.
        let grown = reserved.max(1).saturating_mul(2).min(limit);
        index
            .reserve(grown)
            .map_err(|e| anyhow!("usearch reserve grow failed: {e}"))?;
        Ok(())
    }
}
#[async_trait]
impl VectorStore for UsearchStore {
async fn upsert(&self, id: &str, embedding: Vec<f32>) -> Result<()> {
if embedding.len() != self.dim {
return Err(anyhow!(
"embedding dim mismatch: got {}, expected {}",
embedding.len(),
self.dim
));
}
let key = {
let mut id_to_key = self.id_to_key.write().await;
if let Some(&existing) = id_to_key.get(id) {
existing
} else {
let key = self.next_key.fetch_add(1, Ordering::Relaxed);
id_to_key.insert(id.to_string(), key);
self.key_to_id.write().await.insert(key, id.to_string());
key
}
};
let index = self.index.write().await;
if index.contains(key) {
index
.remove(key)
.map_err(|e| anyhow!("usearch remove (for upsert) failed: {e}"))?;
}
Self::ensure_capacity(&index)?;
index
.add(key, &embedding)
.map_err(|e| anyhow!("usearch add failed: {e}"))?;
Ok(())
}
async fn search(&self, query: &[f32], top_k: usize) -> Result<Vec<VectorHit>> {
if query.len() != self.dim {
return Err(anyhow!(
"query dim mismatch: got {}, expected {}",
query.len(),
self.dim
));
}
if top_k == 0 {
return Ok(Vec::new());
}
let matches = {
let index = self.index.read().await;
index
.search(query, top_k)
.map_err(|e| anyhow!("usearch search failed: {e}"))?
};
let key_to_id = self.key_to_id.read().await;
let mut hits = Vec::with_capacity(matches.keys.len());
for (key, dist) in matches.keys.iter().zip(matches.distances.iter()) {
if let Some(chunk_id) = key_to_id.get(key) {
let score = 1.0 - *dist;
hits.push(VectorHit {
chunk_id: chunk_id.clone(),
score,
});
}
}
Ok(hits)
}
async fn remove(&self, id: &str) -> Result<()> {
let key = {
let mut id_to_key = self.id_to_key.write().await;
match id_to_key.remove(id) {
Some(k) => k,
None => return Ok(()), }
};
self.key_to_id.write().await.remove(&key);
let index = self.index.write().await;
if index.contains(key) {
index
.remove(key)
.map_err(|e| anyhow!("usearch remove failed: {e}"))?;
}
Ok(())
}
async fn len(&self) -> Result<usize> {
Ok(self.index.read().await.size())
}
async fn upsert_batch(&self, items: &[(String, Vec<f32>)]) -> Result<()> {
if items.is_empty() {
return Ok(());
}
for (_, v) in items {
if v.len() != self.dim {
return Err(anyhow!(
"embedding dim mismatch: got {}, expected {}",
v.len(),
self.dim
));
}
}
{
let mut id_map = self.id_to_key.write().await;
let mut key_map = self.key_to_id.write().await;
for (id, _) in items {
if !id_map.contains_key(id.as_str()) {
let k = self.next_key.fetch_add(1, Ordering::Relaxed);
id_map.insert(id.clone(), k);
key_map.insert(k, id.clone());
}
}
}
let id_map = self.id_to_key.read().await;
let index = self.index.write().await;
let want = index.size() + items.len();
let max_elem = hnsw_max_elements();
if index.size() >= max_elem {
return Err(anyhow!(
"usearch index at TRUSTY_MAX_CHUNKS cap ({} elements) — refusing batch upsert",
max_elem
));
}
if want > index.capacity() {
let mut new_cap = index.capacity().max(1);
while new_cap < want {
new_cap = new_cap.saturating_mul(2);
}
if new_cap > max_elem {
new_cap = max_elem;
}
index
.reserve(new_cap)
.map_err(|e| anyhow!("usearch reserve grow failed: {e}"))?;
}
for (id, embedding) in items {
if let Some(&key) = id_map.get(id.as_str()) {
if index.contains(key) {
index
.remove(key)
.map_err(|e| anyhow!("usearch remove (for upsert) failed: {e}"))?;
}
index
.add(key, embedding)
.map_err(|e| anyhow!("usearch add failed: {e}"))?;
}
}
Ok(())
}
}
#[cfg(test)]
mod tests {
    use super::*;

    /// The closest stored vector comes back first and `top_k` bounds the result.
    #[tokio::test]
    async fn test_upsert_and_search() {
        let store = UsearchStore::new(4).expect("store init");
        let v = vec![1.0f32, 0.0, 0.0, 0.0];
        store.upsert("chunk:a", v.clone()).await.expect("upsert a");
        store
            .upsert("chunk:b", vec![0.0, 1.0, 0.0, 0.0])
            .await
            .expect("upsert b");
        store
            .upsert("chunk:c", vec![0.9, 0.1, 0.0, 0.0])
            .await
            .expect("upsert c");
        let hits = store.search(&v, 2).await.expect("search");
        assert_eq!(hits.len(), 2);
        assert_eq!(hits[0].chunk_id, "chunk:a");
    }

    /// `len` tracks the number of stored vectors.
    #[tokio::test]
    async fn test_len() {
        let store = UsearchStore::new(4).expect("store init");
        assert_eq!(store.len().await.unwrap(), 0);
        store.upsert("x", vec![1.0, 0.0, 0.0, 0.0]).await.unwrap();
        assert_eq!(store.len().await.unwrap(), 1);
    }

    /// A removed id no longer shows up in search results.
    #[tokio::test]
    async fn test_remove() {
        let store = UsearchStore::new(4).expect("store init");
        store
            .upsert("del-me", vec![1.0, 0.0, 0.0, 0.0])
            .await
            .unwrap();
        assert_eq!(store.len().await.unwrap(), 1);
        store.remove("del-me").await.unwrap();
        let hits = store.search(&[1.0, 0.0, 0.0, 0.0], 5).await.unwrap();
        assert!(!hits.iter().any(|h| h.chunk_id == "del-me"));
    }

    /// Two searches can be awaited concurrently on a shared store.
    #[tokio::test]
    async fn test_concurrent_reads() {
        let store = Arc::new(UsearchStore::new(4).expect("store init"));
        store.upsert("r1", vec![1.0, 0.0, 0.0, 0.0]).await.unwrap();
        store.upsert("r2", vec![0.0, 1.0, 0.0, 0.0]).await.unwrap();
        let s1 = store.clone();
        let s2 = store.clone();
        let q = vec![1.0f32, 0.0, 0.0, 0.0];
        let (r1, r2) = tokio::join!(s1.search(&q, 2), s2.search(&q, 2));
        assert!(!r1.unwrap().is_empty());
        assert!(!r2.unwrap().is_empty());
    }

    /// Upserting an existing id replaces its vector instead of adding a second one.
    #[tokio::test]
    async fn test_upsert_replaces_existing() {
        let store = UsearchStore::new(4).expect("store init");
        store
            .upsert("same", vec![1.0, 0.0, 0.0, 0.0])
            .await
            .unwrap();
        store
            .upsert("same", vec![0.0, 1.0, 0.0, 0.0])
            .await
            .unwrap();
        assert_eq!(store.len().await.unwrap(), 1);
        let hits = store.search(&[0.0, 1.0, 0.0, 0.0], 1).await.unwrap();
        assert_eq!(hits[0].chunk_id, "same");
    }

    /// Wrong-length vectors are rejected by both `upsert` and `search`.
    #[tokio::test]
    async fn test_dim_mismatch_errors() {
        let store = UsearchStore::new(4).expect("store init");
        assert!(store.upsert("bad", vec![1.0, 0.0]).await.is_err());
        assert!(store.search(&[1.0, 0.0], 1).await.is_err());
    }

    /// A batch inserts every item, and re-running it replaces rather than duplicates.
    #[tokio::test]
    async fn test_upsert_batch_inserts_all() {
        let store = UsearchStore::new(4).expect("store init");
        let dirs: [[f32; 4]; 4] = [
            [1.0, 0.0, 0.0, 0.0],
            [0.0, 1.0, 0.0, 0.0],
            [0.0, 0.0, 1.0, 0.0],
            [0.0, 0.0, 0.0, 1.0],
        ];
        let items: Vec<(String, Vec<f32>)> = (0..4)
            .map(|i| (format!("k{i}"), dirs[i].to_vec()))
            .collect();
        store.upsert_batch(&items).await.expect("batch upsert");
        assert_eq!(store.len().await.unwrap(), 4);
        store.upsert_batch(&items).await.expect("re-batch upsert");
        assert_eq!(store.len().await.unwrap(), 4);
        let hits = store.search(&dirs[2], 1).await.unwrap();
        assert_eq!(hits[0].chunk_id, "k2");
    }

    /// An empty batch succeeds and stores nothing.
    #[tokio::test]
    async fn test_upsert_batch_empty_noop() {
        let store = UsearchStore::new(4).expect("store init");
        store.upsert_batch(&[]).await.unwrap();
        assert_eq!(store.len().await.unwrap(), 0);
    }

    /// A batch containing a wrong-length vector is rejected.
    #[tokio::test]
    async fn test_upsert_batch_dim_mismatch_errors() {
        let store = UsearchStore::new(4).expect("store init");
        let items = vec![("bad".to_string(), vec![1.0, 0.0])];
        assert!(store.upsert_batch(&items).await.is_err());
    }

    /// Inserting past the initial reservation forces `ensure_capacity` to grow
    /// the index; every vector must still be stored.
    #[tokio::test]
    async fn test_capacity_growth() {
        // `new` reserves INITIAL_CAPACITY slots, so the count has to exceed that
        // for the growth path to run at all.
        let total = INITIAL_CAPACITY * 2 + 1;
        let store = UsearchStore::new(4).expect("store init");
        for i in 0..total {
            // Offset by one so no all-zero vector is fed to the cosine metric.
            let v = vec![(i + 1) as f32, 1.0, 0.0, 0.0];
            store.upsert(&format!("k{i}"), v).await.unwrap();
        }
        assert_eq!(store.len().await.unwrap(), total);
    }
}