use alloc::string::String;
use alloc::vec::Vec;
use hashbrown::HashMap;
use lazy_static::lazy_static;
use spin::Mutex;
use super::hnsw::{HnswIndex, HnswStats};
use super::types::{
DistanceMetric, IndexParams, IndexType, SearchFilter, VectorError, VectorIndexMeta,
VectorSearchResult,
};
/// Per-dataset state: the HNSW graph together with the embedding model it was
/// created for and a modification flag.
struct DatasetIndex {
    /// Name of the embedding model this index was bound to.
    model_name: String,
    /// Hash of `model_name` as computed by `DatasetIndex::hash_model_name`.
    model_id: u32,
    /// Approximate-nearest-neighbour graph holding the stored vectors.
    hnsw: HnswIndex,
    /// True when the index changed since the engine last called `mark_clean`.
    dirty: bool,
}
impl DatasetIndex {
    /// Creates an empty, clean index bound to `model_name`, with its HNSW graph
    /// configured from `params` (`hnsw_m`, `hnsw_ef_construction`,
    /// `hnsw_ef_search` and `metric`).
    fn new(model_name: &str, params: &IndexParams) -> Self {
        let hnsw = HnswIndex::with_params(
            params.hnsw_m,
            params.hnsw_ef_construction,
            params.hnsw_ef_search,
            params.metric,
        );
        Self {
            hnsw,
            model_id: Self::hash_model_name(model_name),
            model_name: String::from(model_name),
            dirty: false,
        }
    }

    /// Polynomial (base 31) hash of the UTF-8 bytes of `name`, wrapping on
    /// overflow.
    ///
    /// Distinct names can hash to the same value, so equal hashes do not prove
    /// equal names.
    fn hash_model_name(name: &str) -> u32 {
        name.bytes()
            .fold(0u32, |hash, byte| hash.wrapping_mul(31).wrapping_add(u32::from(byte)))
    }

    /// Builds the public metadata record for this index from the HNSW stats.
    ///
    /// Narrow fields saturate at their type's maximum instead of silently
    /// wrapping (`as` casts would turn e.g. 65536 into 0).
    /// `entry_point` is reported as 0 when the graph has no entry point.
    fn metadata(&self) -> VectorIndexMeta {
        // Explicit import so this compiles regardless of whether the crate's
        // edition puts `TryFrom` in the prelude.
        use core::convert::TryFrom;

        let stats = self.hnsw.stats();
        VectorIndexMeta {
            index_type: IndexType::Hnsw,
            vector_count: stats.vector_count as u64,
            dimensions: u32::try_from(stats.dimensions).unwrap_or(u32::MAX),
            distance_metric: stats.metric,
            hnsw_m: u16::try_from(stats.m).unwrap_or(u16::MAX),
            hnsw_ef_construction: u16::try_from(stats.ef_construction).unwrap_or(u16::MAX),
            max_layer: u8::try_from(stats.layer_count).unwrap_or(u8::MAX),
            entry_point: stats.entry_point.unwrap_or(0),
        }
    }
}
/// Registry of vector indexes, one per dataset name.
pub struct VectorEngine {
    /// Parameters used when an index is created implicitly on first insert.
    default_params: IndexParams,
    /// Dataset name → that dataset's index.
    indexes: HashMap<String, DatasetIndex>,
}
impl VectorEngine {
    /// How many candidates per requested result `search_filtered` asks the
    /// HNSW graph for, so that enough survive the filter to fill `k` slots.
    const FILTER_OVERSAMPLE: usize = 10;

    /// Creates an engine with no datasets; indexes created later use
    /// `IndexParams::default()` until `set_default_params` is called.
    pub fn new() -> Self {
        Self {
            indexes: HashMap::new(),
            default_params: IndexParams::default(),
        }
    }

    /// Sets the parameters used for indexes created from now on.
    ///
    /// Existing indexes keep their parameters; use `rebuild_index` to change them.
    pub fn set_default_params(&mut self, params: IndexParams) {
        self.default_params = params;
    }

    /// Returns the index for `dataset`, creating an empty one bound to `model`
    /// (with the default parameters) if the dataset does not exist yet.
    ///
    /// An existing index is returned as-is; `model` is not checked here.
    fn get_or_create_index(&mut self, dataset: &str, model: &str) -> &mut DatasetIndex {
        // Probe first so the common "already exists" path does not allocate a key.
        if !self.indexes.contains_key(dataset) {
            let index = DatasetIndex::new(model, &self.default_params);
            self.indexes.insert(String::from(dataset), index);
        }
        self.indexes
            .get_mut(dataset)
            .expect("index for dataset was inserted above if it was missing")
    }

    /// Looks up `dataset`, mapping a miss to `VectorError::DatasetNotFound`.
    fn lookup(&self, dataset: &str) -> Result<&DatasetIndex, VectorError> {
        self.indexes
            .get(dataset)
            .ok_or_else(|| VectorError::DatasetNotFound(String::from(dataset)))
    }

    /// Mutable counterpart of `lookup`.
    fn lookup_mut(&mut self, dataset: &str) -> Result<&mut DatasetIndex, VectorError> {
        self.indexes
            .get_mut(dataset)
            .ok_or_else(|| VectorError::DatasetNotFound(String::from(dataset)))
    }

    /// Like `lookup`, but also rejects an index holding no vectors with
    /// `VectorError::EmptyIndex`.
    fn lookup_searchable(&self, dataset: &str) -> Result<&DatasetIndex, VectorError> {
        let index = self.lookup(dataset)?;
        if index.hnsw.is_empty() {
            return Err(VectorError::EmptyIndex);
        }
        Ok(index)
    }

    /// Stores `embedding` for `object_id` in `dataset`, creating the dataset's
    /// index on first use, and marks the index dirty.
    ///
    /// A dataset is bound to one model. While its index is empty, storing with
    /// a different model re-binds it to that model.
    ///
    /// # Errors
    ///
    /// - `VectorError::NotSupported` if the index already holds vectors from a
    ///   different model.
    /// - Any error returned by the HNSW insert (presumably e.g. a dimension
    ///   mismatch — confirm in the `hnsw` module).
    pub fn store_embedding(
        &mut self,
        dataset: &str,
        object_id: u64,
        model: &str,
        embedding: &[f32],
    ) -> Result<(), VectorError> {
        let index = self.get_or_create_index(dataset, model);
        let model_id = DatasetIndex::hash_model_name(model);
        // The hash is a cheap pre-filter; it can collide, so confirm by name.
        let same_model = index.model_id == model_id && index.model_name == model;
        if !same_model {
            if !index.hnsw.is_empty() {
                return Err(VectorError::NotSupported(
                    "Cannot mix embeddings from different models in same dataset".into(),
                ));
            }
            // Nothing stored yet, so nothing can conflict. Adopt the caller's
            // model; otherwise the *next* insert with this same model would be
            // rejected against the stale binding.
            index.model_name = String::from(model);
            index.model_id = model_id;
        }
        index.hnsw.insert(object_id, embedding)?;
        index.dirty = true;
        Ok(())
    }

    /// Returns up to `k` approximate nearest neighbours of `query` in `dataset`.
    ///
    /// # Errors
    ///
    /// `VectorError::DatasetNotFound` if the dataset does not exist, and
    /// `VectorError::EmptyIndex` if it holds no vectors.
    pub fn search(
        &self,
        dataset: &str,
        query: &[f32],
        k: usize,
    ) -> Result<Vec<VectorSearchResult>, VectorError> {
        Ok(self.lookup_searchable(dataset)?.hnsw.search(query, k))
    }

    /// Like `search`, but keeps only results accepted by `filter.matches_basic`.
    ///
    /// Filtering happens after the graph search, over
    /// `k * FILTER_OVERSAMPLE` candidates, so fewer than `k` results may be
    /// returned when the filter is selective.
    ///
    /// # Errors
    ///
    /// Same as `search`.
    pub fn search_filtered(
        &self,
        dataset: &str,
        query: &[f32],
        k: usize,
        filter: &SearchFilter,
    ) -> Result<Vec<VectorSearchResult>, VectorError> {
        let index = self.lookup_searchable(dataset)?;
        // Saturate so a very large `k` cannot overflow (which panics in debug builds).
        let candidates = index
            .hnsw
            .search(query, k.saturating_mul(Self::FILTER_OVERSAMPLE));
        Ok(candidates
            .into_iter()
            .filter(|r| filter.matches_basic(r))
            .take(k)
            .collect())
    }

    /// Like `search`, but with an explicit search-time candidate list size `ef`.
    ///
    /// # Errors
    ///
    /// Same as `search`.
    pub fn search_with_ef(
        &self,
        dataset: &str,
        query: &[f32],
        k: usize,
        ef: usize,
    ) -> Result<Vec<VectorSearchResult>, VectorError> {
        Ok(self
            .lookup_searchable(dataset)?
            .hnsw
            .search_with_ef(query, k, ef))
    }

    /// Returns up to `k` neighbours of the vector stored for `object_id`,
    /// excluding `object_id` itself.
    ///
    /// # Errors
    ///
    /// `VectorError::DatasetNotFound` if the dataset does not exist, and
    /// `VectorError::ObjectNotFound` if it has no vector for `object_id`.
    pub fn find_similar(
        &self,
        dataset: &str,
        object_id: u64,
        k: usize,
    ) -> Result<Vec<VectorSearchResult>, VectorError> {
        let index = self.lookup(dataset)?;
        let query = index
            .hnsw
            .get_vector(object_id)
            .ok_or(VectorError::ObjectNotFound(object_id))?;
        // Request one extra hit to make room for the object itself, which is
        // then dropped; saturate so `k == usize::MAX` cannot overflow.
        let mut results = index.hnsw.search(query, k.saturating_add(1));
        results.retain(|r| r.object_id != object_id);
        results.truncate(k);
        Ok(results)
    }

    /// Stores every `(object_id, embedding)` pair via `store_embedding` and
    /// returns how many were stored.
    ///
    /// # Errors
    ///
    /// Stops at the first failing pair and returns its error; pairs stored
    /// before it remain stored.
    pub fn batch_store(
        &mut self,
        dataset: &str,
        model: &str,
        embeddings: &[(u64, Vec<f32>)],
    ) -> Result<usize, VectorError> {
        for (id, embedding) in embeddings {
            self.store_embedding(dataset, *id, model, embedding)?;
        }
        Ok(embeddings.len())
    }

    /// Removes `object_id` from `dataset` and marks the index dirty.
    ///
    /// # Errors
    ///
    /// `VectorError::DatasetNotFound` if the dataset does not exist, or any
    /// error returned by the HNSW delete.
    pub fn delete(&mut self, dataset: &str, object_id: u64) -> Result<(), VectorError> {
        let index = self.lookup_mut(dataset)?;
        index.hnsw.delete(object_id)?;
        index.dirty = true;
        Ok(())
    }

    /// Returns an owned copy of the vector stored for `object_id` in `dataset`.
    ///
    /// # Errors
    ///
    /// `VectorError::DatasetNotFound` or `VectorError::ObjectNotFound`.
    pub fn get_vector(&self, dataset: &str, object_id: u64) -> Result<Vec<f32>, VectorError> {
        self.lookup(dataset)?
            .hnsw
            .get_vector(object_id)
            .map(|v| v.to_vec())
            .ok_or(VectorError::ObjectNotFound(object_id))
    }

    /// Returns whether `dataset` exists and holds a vector for `object_id`.
    pub fn contains(&self, dataset: &str, object_id: u64) -> bool {
        self.indexes
            .get(dataset)
            .map(|idx| idx.hnsw.contains(object_id))
            .unwrap_or(false)
    }

    /// Returns the metadata record for `dataset`, or `None` if it does not exist.
    pub fn get_index_meta(&self, dataset: &str) -> Option<VectorIndexMeta> {
        self.indexes.get(dataset).map(|idx| idx.metadata())
    }

    /// Returns the raw HNSW statistics for `dataset`, or `None` if it does not exist.
    pub fn get_stats(&self, dataset: &str) -> Option<HnswStats> {
        self.indexes.get(dataset).map(|idx| idx.hnsw.stats())
    }

    /// Returns the names of all datasets, in no particular order.
    pub fn list_datasets(&self) -> Vec<String> {
        self.indexes.keys().cloned().collect()
    }

    /// Returns whether an index exists for `dataset`.
    pub fn has_index(&self, dataset: &str) -> bool {
        self.indexes.contains_key(dataset)
    }

    /// Drops the index for `dataset`; returns whether one existed.
    pub fn remove_index(&mut self, dataset: &str) -> bool {
        self.indexes.remove(dataset).is_some()
    }

    /// Sets the search-time candidate list size for `dataset`'s index.
    ///
    /// # Errors
    ///
    /// `VectorError::DatasetNotFound` if the dataset does not exist.
    pub fn set_ef_search(&mut self, dataset: &str, ef: usize) -> Result<(), VectorError> {
        self.lookup_mut(dataset)?.hnsw.set_ef_search(ef);
        Ok(())
    }

    /// Rebuilds `dataset`'s index from scratch with `params`, keeping its
    /// vectors and model binding and marking it dirty.
    ///
    /// All HNSW parameters, including `ef_search`, come from `params`; a value
    /// set earlier via `set_ef_search` is not carried over.
    ///
    /// # Errors
    ///
    /// `VectorError::DatasetNotFound` if the dataset does not exist, or the
    /// first HNSW insert error. On error the existing index is left untouched.
    pub fn rebuild_index(
        &mut self,
        dataset: &str,
        params: &IndexParams,
    ) -> Result<(), VectorError> {
        let old_index = self.lookup(dataset)?;
        let mut new_index = DatasetIndex::new(&old_index.model_name, params);
        // Copy straight from the old graph into the new one rather than first
        // cloning every vector into a temporary list. The old index is only
        // replaced once every insert has succeeded.
        let ids = old_index.hnsw.get_ids();
        for &id in ids.iter() {
            if let Some(vector) = old_index.hnsw.get_vector(id) {
                new_index.hnsw.insert(id, vector)?;
            }
        }
        new_index.dirty = true;
        self.indexes.insert(String::from(dataset), new_index);
        Ok(())
    }

    /// Total number of vectors across all datasets.
    pub fn total_vectors(&self) -> usize {
        self.indexes.values().map(|idx| idx.hnsw.len()).sum()
    }

    /// Clears the dirty flag on every index.
    pub fn mark_clean(&mut self) {
        for index in self.indexes.values_mut() {
            index.dirty = false;
        }
    }

    /// Returns whether any index changed since the last `mark_clean`.
    pub fn is_dirty(&self) -> bool {
        self.indexes.values().any(|idx| idx.dirty)
    }
}
impl Default for VectorEngine {
fn default() -> Self {
Self::new()
}
}
lazy_static! {
/// Process-wide engine used by the free functions in this module; each of them
/// holds this `spin::Mutex` for the duration of its call. Initialized on first
/// access.
static ref VECTOR_ENGINE: Mutex<VectorEngine> = Mutex::new(VectorEngine::new());
}
/// Stores one embedding in the global engine (see `VectorEngine::store_embedding`).
pub fn store_embedding(
    dataset: &str,
    object_id: u64,
    model: &str,
    embedding: &[f32],
) -> Result<(), VectorError> {
    let mut engine = VECTOR_ENGINE.lock();
    engine.store_embedding(dataset, object_id, model, embedding)
}
/// k-nearest-neighbour search against the global engine (see `VectorEngine::search`).
pub fn search(
    dataset: &str,
    query: &[f32],
    k: usize,
) -> Result<Vec<VectorSearchResult>, VectorError> {
    let engine = VECTOR_ENGINE.lock();
    engine.search(dataset, query, k)
}
/// Filtered search against the global engine (see `VectorEngine::search_filtered`).
pub fn search_filtered(
    dataset: &str,
    query: &[f32],
    k: usize,
    filter: &SearchFilter,
) -> Result<Vec<VectorSearchResult>, VectorError> {
    let engine = VECTOR_ENGINE.lock();
    engine.search_filtered(dataset, query, k, filter)
}
/// Neighbours of a stored object in the global engine (see `VectorEngine::find_similar`).
pub fn find_similar(
    dataset: &str,
    object_id: u64,
    k: usize,
) -> Result<Vec<VectorSearchResult>, VectorError> {
    let engine = VECTOR_ENGINE.lock();
    engine.find_similar(dataset, object_id, k)
}
/// Stores many embeddings in the global engine under a single lock acquisition
/// (see `VectorEngine::batch_store`).
pub fn batch_store(
    dataset: &str,
    model: &str,
    embeddings: &[(u64, Vec<f32>)],
) -> Result<usize, VectorError> {
    let mut engine = VECTOR_ENGINE.lock();
    engine.batch_store(dataset, model, embeddings)
}
/// Removes one object from a dataset in the global engine (see `VectorEngine::delete`).
pub fn delete_embedding(dataset: &str, object_id: u64) -> Result<(), VectorError> {
    let mut engine = VECTOR_ENGINE.lock();
    engine.delete(dataset, object_id)
}
/// Copies a stored vector out of the global engine (see `VectorEngine::get_vector`).
pub fn get_embedding(dataset: &str, object_id: u64) -> Result<Vec<f32>, VectorError> {
    let engine = VECTOR_ENGINE.lock();
    engine.get_vector(dataset, object_id)
}
/// Index metadata for a dataset in the global engine, if it exists.
pub fn get_index_meta(dataset: &str) -> Option<VectorIndexMeta> {
    let engine = VECTOR_ENGINE.lock();
    engine.get_index_meta(dataset)
}
/// Raw HNSW statistics for a dataset in the global engine, if it exists.
pub fn get_index_stats(dataset: &str) -> Option<HnswStats> {
    let engine = VECTOR_ENGINE.lock();
    engine.get_stats(dataset)
}
/// Names of every dataset in the global engine, in no particular order.
pub fn list_datasets() -> Vec<String> {
    let engine = VECTOR_ENGINE.lock();
    engine.list_datasets()
}
/// Rebuilds a dataset's index in the global engine with new parameters
/// (see `VectorEngine::rebuild_index`).
pub fn rebuild_index(dataset: &str, params: &IndexParams) -> Result<(), VectorError> {
    let mut engine = VECTOR_ENGINE.lock();
    engine.rebuild_index(dataset, params)
}
/// Sets the search-time `ef` for a dataset in the global engine.
pub fn set_ef_search(dataset: &str, ef: usize) -> Result<(), VectorError> {
    let mut engine = VECTOR_ENGINE.lock();
    engine.set_ef_search(dataset, ef)
}
/// Total number of vectors stored across all datasets in the global engine.
pub fn total_vectors() -> usize {
    let engine = VECTOR_ENGINE.lock();
    engine.total_vectors()
}
#[cfg(test)]
mod tests {
    use super::*;
    use alloc::vec;

    /// Deterministic pseudo-random unit vector of length `dim` (LCG seeded by
    /// `seed`, components in [-1, 1) before normalization).
    fn random_vector(dim: usize, seed: u64) -> Vec<f32> {
        let mut rng = seed;
        let mut v: Vec<f32> = (0..dim)
            .map(|_| {
                rng = rng.wrapping_mul(6364136223846793005).wrapping_add(1);
                ((rng >> 33) as f32 / (1u64 << 31) as f32) * 2.0 - 1.0
            })
            .collect();
        let norm: f32 = v.iter().map(|x| x * x).sum::<f32>().sqrt();
        for x in &mut v {
            *x /= norm;
        }
        v
    }

    #[test]
    fn test_vector_engine_basic() {
        let mut engine = VectorEngine::new();
        let embedding = vec![1.0, 0.0, 0.0, 0.0];
        engine
            .store_embedding("test/dataset", 1, "test-model", &embedding)
            .unwrap();
        assert!(engine.contains("test/dataset", 1));
        assert!(!engine.contains("test/dataset", 2));
        let retrieved = engine.get_vector("test/dataset", 1).unwrap();
        assert_eq!(retrieved, embedding);
    }

    #[test]
    fn test_vector_engine_search() {
        let mut engine = VectorEngine::new();
        engine
            .store_embedding("test", 1, "model", &[1.0, 0.0, 0.0, 0.0])
            .unwrap();
        engine
            .store_embedding("test", 2, "model", &[0.9, 0.1, 0.0, 0.0])
            .unwrap();
        engine
            .store_embedding("test", 3, "model", &[0.0, 1.0, 0.0, 0.0])
            .unwrap();
        let results = engine.search("test", &[1.0, 0.0, 0.0, 0.0], 2).unwrap();
        assert_eq!(results.len(), 2);
        assert_eq!(results[0].object_id, 1);
    }

    #[test]
    fn test_vector_engine_find_similar() {
        let mut engine = VectorEngine::new();
        engine
            .store_embedding("test", 1, "model", &[1.0, 0.0, 0.0])
            .unwrap();
        engine
            .store_embedding("test", 2, "model", &[0.9, 0.1, 0.0])
            .unwrap();
        engine
            .store_embedding("test", 3, "model", &[0.0, 1.0, 0.0])
            .unwrap();
        let similar = engine.find_similar("test", 1, 2).unwrap();
        assert!(similar.iter().all(|r| r.object_id != 1));
        assert_eq!(similar[0].object_id, 2);
    }

    #[test]
    fn test_vector_engine_batch() {
        let mut engine = VectorEngine::new();
        let embeddings = vec![
            (1, vec![1.0, 0.0, 0.0]),
            (2, vec![0.0, 1.0, 0.0]),
            (3, vec![0.0, 0.0, 1.0]),
        ];
        let count = engine.batch_store("test", "model", &embeddings).unwrap();
        assert_eq!(count, 3);
        assert!(engine.contains("test", 1));
        assert!(engine.contains("test", 2));
        assert!(engine.contains("test", 3));
    }

    #[test]
    fn test_vector_engine_delete() {
        let mut engine = VectorEngine::new();
        engine
            .store_embedding("test", 1, "model", &[1.0, 0.0, 0.0])
            .unwrap();
        engine
            .store_embedding("test", 2, "model", &[0.0, 1.0, 0.0])
            .unwrap();
        assert!(engine.contains("test", 1));
        engine.delete("test", 1).unwrap();
        assert!(!engine.contains("test", 1));
        assert!(engine.contains("test", 2));
    }

    #[test]
    fn test_vector_engine_metadata() {
        let mut engine = VectorEngine::new();
        for i in 0..10 {
            let v = random_vector(64, i);
            engine.store_embedding("test", i, "model", &v).unwrap();
        }
        let meta = engine.get_index_meta("test").unwrap();
        assert_eq!(meta.vector_count, 10);
        assert_eq!(meta.dimensions, 64);
        let stats = engine.get_stats("test").unwrap();
        assert_eq!(stats.vector_count, 10);
    }

    #[test]
    fn test_vector_engine_rebuild() {
        let mut engine = VectorEngine::new();
        for i in 0..5 {
            let v = random_vector(32, i);
            engine.store_embedding("test", i, "model", &v).unwrap();
        }
        let params = IndexParams {
            hnsw_m: 32,
            hnsw_ef_construction: 400,
            ..Default::default()
        };
        // Was a mis-encoded `¶ms` (an `&para` entity collapsed to U+00B6),
        // which does not compile.
        engine.rebuild_index("test", &params).unwrap();
        for i in 0..5 {
            assert!(engine.contains("test", i as u64));
        }
        let stats = engine.get_stats("test").unwrap();
        assert_eq!(stats.m, 32);
    }

    #[test]
    fn test_vector_engine_multiple_datasets() {
        let mut engine = VectorEngine::new();
        engine
            .store_embedding("photos", 1, "clip", &[1.0, 0.0])
            .unwrap();
        engine
            .store_embedding("documents", 1, "bert", &[0.0, 1.0])
            .unwrap();
        assert!(engine.has_index("photos"));
        assert!(engine.has_index("documents"));
        assert!(!engine.has_index("other"));
        let datasets = engine.list_datasets();
        assert_eq!(datasets.len(), 2);
        assert!(datasets.contains(&String::from("photos")));
        assert!(datasets.contains(&String::from("documents")));
    }

    #[test]
    fn test_vector_engine_error_cases() {
        let engine = VectorEngine::new();
        let result = engine.search("nonexistent", &[1.0, 0.0], 10);
        assert!(matches!(result, Err(VectorError::DatasetNotFound(_))));
    }

    #[test]
    fn test_vector_engine_dirty_tracking() {
        let mut engine = VectorEngine::new();
        assert!(!engine.is_dirty());
        engine
            .store_embedding("test", 1, "model", &[1.0, 0.0, 0.0])
            .unwrap();
        assert!(engine.is_dirty());
        engine.mark_clean();
        assert!(!engine.is_dirty());
    }
}