use std::collections::HashMap;
use std::path::Path;
use std::sync::RwLock;
use crate::error::{Error, Result};
use crate::index::VectorIndex;
use uuid::Uuid;
pub struct UsearchIndex {
index: RwLock<usearch::Index>,
uuid_to_key: RwLock<HashMap<Uuid, u64>>,
key_to_uuid: RwLock<HashMap<u64, Uuid>>,
next_key: RwLock<u64>,
dimensions: usize,
}
impl UsearchIndex {
pub fn new(dimensions: usize) -> Result<Self> {
let opts = usearch::IndexOptions {
dimensions,
metric: usearch::MetricKind::Cos,
quantization: usearch::ScalarKind::F32,
..Default::default()
};
let index = usearch::Index::new(&opts).map_err(|e| Error::Index(e.to_string()))?;
index
.reserve(10_000)
.map_err(|e| Error::Index(e.to_string()))?;
Ok(Self {
index: RwLock::new(index),
uuid_to_key: RwLock::new(HashMap::new()),
key_to_uuid: RwLock::new(HashMap::new()),
next_key: RwLock::new(0),
dimensions,
})
}
fn allocate_key(&self, id: Uuid) -> u64 {
let mut next = self.next_key.write().unwrap_or_else(|e| e.into_inner());
let key = *next;
*next += 1;
self.uuid_to_key
.write()
.unwrap_or_else(|e| e.into_inner())
.insert(id, key);
self.key_to_uuid
.write()
.unwrap_or_else(|e| e.into_inner())
.insert(key, id);
key
}
fn rollback_key(&self, id: Uuid, key: u64) {
self.uuid_to_key
.write()
.unwrap_or_else(|e| e.into_inner())
.remove(&id);
self.key_to_uuid
.write()
.unwrap_or_else(|e| e.into_inner())
.remove(&key);
}
}
impl VectorIndex for UsearchIndex {
fn add(&self, id: Uuid, vector: &[f32]) -> Result<()> {
if vector.len() != self.dimensions {
return Err(Error::Validation(format!(
"expected {} dimensions, got {}",
self.dimensions,
vector.len()
)));
}
if self
.uuid_to_key
.read()
.unwrap_or_else(|e| e.into_inner())
.contains_key(&id)
{
self.remove(id)?;
}
let key = self.allocate_key(id);
let index = self.index.read().unwrap_or_else(|e| e.into_inner());
if index.size() >= index.capacity() {
index
.reserve(index.capacity() + 10_000)
.map_err(|e| Error::Index(e.to_string()))?;
}
if let Err(e) = index.add(key, vector) {
drop(index);
self.rollback_key(id, key);
return Err(Error::Index(e.to_string()));
}
Ok(())
}
fn remove(&self, id: Uuid) -> Result<()> {
let key = {
let map = self.uuid_to_key.read().unwrap_or_else(|e| e.into_inner());
match map.get(&id) {
Some(&k) => k,
None => return Ok(()),
}
};
let index = self.index.read().unwrap_or_else(|e| e.into_inner());
index.remove(key).map_err(|e| Error::Index(e.to_string()))?;
self.uuid_to_key
.write()
.unwrap_or_else(|e| e.into_inner())
.remove(&id);
self.key_to_uuid
.write()
.unwrap_or_else(|e| e.into_inner())
.remove(&key);
Ok(())
}
fn search(&self, query: &[f32], limit: usize) -> Result<Vec<(Uuid, f32)>> {
let index = self.index.read().unwrap_or_else(|e| e.into_inner());
let results = index
.search(query, limit)
.map_err(|e| Error::Index(e.to_string()))?;
let key_map = self.key_to_uuid.read().unwrap_or_else(|e| e.into_inner());
let mut output = Vec::new();
for (key, distance) in results.keys.iter().zip(results.distances.iter()) {
if let Some(&uuid) = key_map.get(key) {
output.push((uuid, *distance));
}
}
Ok(output)
}
fn filtered_search(
&self,
query: &[f32],
limit: usize,
filter: &dyn Fn(Uuid) -> bool,
) -> Result<Vec<(Uuid, f32)>> {
let index_size = self.len();
if index_size == 0 {
return Ok(Vec::new());
}
let mut oversample = (limit * 3).max(1);
loop {
let results = self.search(query, oversample.min(index_size))?;
let filtered: Vec<(Uuid, f32)> = results
.into_iter()
.filter(|(uuid, _)| filter(*uuid))
.take(limit)
.collect();
if filtered.len() >= limit || oversample >= index_size {
return Ok(filtered);
}
oversample = (oversample * 2).min(index_size);
}
}
fn save(&self, path: &Path) -> Result<()> {
let path_str = path
.to_str()
.ok_or_else(|| Error::Index("non-UTF-8 index path".to_string()))?;
let index = self.index.read().unwrap_or_else(|e| e.into_inner());
index
.save(path_str)
.map_err(|e| Error::Index(e.to_string()))?;
let mappings_path = path.with_extension("mappings.json");
let uuid_to_key = self.uuid_to_key.read().unwrap_or_else(|e| e.into_inner());
let next_key = *self.next_key.read().unwrap_or_else(|e| e.into_inner());
let data = serde_json::json!({
"uuid_to_key": uuid_to_key.iter().map(|(k, v)| (k.to_string(), v)).collect::<HashMap<String, &u64>>(),
"next_key": next_key,
});
let json_str = serde_json::to_string(&data).map_err(|e| Error::Index(e.to_string()))?;
std::fs::write(&mappings_path, json_str).map_err(|e| Error::Index(e.to_string()))?;
Ok(())
}
fn load(&self, path: &Path) -> Result<()> {
let path_str = path
.to_str()
.ok_or_else(|| Error::Index("non-UTF-8 index path".to_string()))?;
let index = self.index.read().unwrap_or_else(|e| e.into_inner());
index
.load(path_str)
.map_err(|e| Error::Index(e.to_string()))?;
let mappings_path = path.with_extension("mappings.json");
if mappings_path.exists() {
let data =
std::fs::read_to_string(&mappings_path).map_err(|e| Error::Index(e.to_string()))?;
let parsed: serde_json::Value =
serde_json::from_str(&data).map_err(|e| Error::Index(e.to_string()))?;
let mut uuid_to_key = self.uuid_to_key.write().unwrap_or_else(|e| e.into_inner());
let mut key_to_uuid = self.key_to_uuid.write().unwrap_or_else(|e| e.into_inner());
let mut next_key = self.next_key.write().unwrap_or_else(|e| e.into_inner());
uuid_to_key.clear();
key_to_uuid.clear();
if let Some(map) = parsed["uuid_to_key"].as_object() {
for (uuid_str, key_val) in map {
let uuid =
Uuid::parse_str(uuid_str).map_err(|e| Error::Index(e.to_string()))?;
let key = key_val.as_u64().ok_or_else(|| {
Error::Index(format!("invalid key value for UUID {uuid_str}"))
})?;
uuid_to_key.insert(uuid, key);
key_to_uuid.insert(key, uuid);
}
}
if let Some(nk) = parsed["next_key"].as_u64() {
*next_key = nk;
}
}
Ok(())
}
fn len(&self) -> usize {
let index = self.index.read().unwrap_or_else(|e| e.into_inner());
index.size()
}
}
#[cfg(test)]
mod tests {
use super::*;
fn random_vector(dims: usize, seed: u64) -> Vec<f32> {
let mut v = Vec::with_capacity(dims);
let mut x = seed;
for _ in 0..dims {
x = x.wrapping_mul(6364136223846793005).wrapping_add(1);
v.push((x as f32) / (u64::MAX as f32));
}
let norm: f32 = v.iter().map(|x| x * x).sum::<f32>().sqrt();
if norm > 0.0 {
for x in &mut v {
*x /= norm;
}
}
v
}
#[test]
fn test_add_and_search() {
let index = UsearchIndex::new(128).unwrap();
let mut ids = Vec::new();
let mut vectors = Vec::new();
for i in 0..100 {
let id = Uuid::now_v7();
let vec = random_vector(128, i);
index.add(id, &vec).unwrap();
ids.push(id);
vectors.push(vec);
}
assert_eq!(index.len(), 100);
let results = index.search(&vectors[0], 5).unwrap();
assert!(!results.is_empty());
assert_eq!(results[0].0, ids[0]);
}
#[test]
fn test_remove() {
let index = UsearchIndex::new(128).unwrap();
let id = Uuid::now_v7();
let vec = random_vector(128, 42);
index.add(id, &vec).unwrap();
assert_eq!(index.len(), 1);
index.remove(id).unwrap();
assert_eq!(index.len(), 0);
}
#[test]
fn test_filtered_search() {
let index = UsearchIndex::new(128).unwrap();
let mut ids = Vec::new();
for i in 0..50 {
let id = Uuid::now_v7();
let vec = random_vector(128, i);
index.add(id, &vec).unwrap();
ids.push(id);
}
let excluded: std::collections::HashSet<Uuid> = ids.iter().step_by(2).copied().collect();
let query = random_vector(128, 0);
let results = index
.filtered_search(&query, 10, &|id| !excluded.contains(&id))
.unwrap();
for (id, _) in &results {
assert!(!excluded.contains(id));
}
}
#[test]
fn test_save_and_load() {
let dir = std::env::temp_dir().join(format!("usearch_test_{}", Uuid::now_v7()));
std::fs::create_dir_all(&dir).unwrap();
let index_path = dir.join("test.usearch");
let index = UsearchIndex::new(128).unwrap();
let id1 = Uuid::now_v7();
let id2 = Uuid::now_v7();
index.add(id1, &random_vector(128, 1)).unwrap();
index.add(id2, &random_vector(128, 2)).unwrap();
index.save(&index_path).unwrap();
let index2 = UsearchIndex::new(128).unwrap();
index2.load(&index_path).unwrap();
assert_eq!(index2.len(), 2);
let results = index2.search(&random_vector(128, 1), 1).unwrap();
assert_eq!(results[0].0, id1);
std::fs::remove_dir_all(&dir).ok();
}
#[test]
fn test_dimension_mismatch() {
let index = UsearchIndex::new(128).unwrap();
let result = index.add(Uuid::now_v7(), &vec![0.1; 64]);
assert!(result.is_err());
}
}