use crate::error::{Error, Result};
use crate::index::traits::{DistanceType, SearchResult, VectorIndex};
use ahash::AHashMap;
use parking_lot::RwLock;
use std::sync::Arc;
/// Snapshot of one registered index, as produced by [`IndexRegistry::info`].
#[derive(Debug, Clone)]
pub struct IndexInfo {
    /// Name the index is registered under in the registry.
    pub name: String,
    /// Value of `VectorIndex::dimension` for this index.
    pub dimension: usize,
    /// Value of `VectorIndex::distance_type` for this index.
    pub distance_type: DistanceType,
    /// Number of stored vectors, taken from `VectorIndex::len`.
    pub size: usize,
    /// Value of `VectorIndex::memory_usage` (presumably bytes, per the field name — confirm
    /// against the trait's documentation).
    pub memory_bytes: usize,
}
/// The results one index contributed to a multi-index search.
#[derive(Debug, Clone)]
pub struct MultiIndexResult {
    /// Name of the index that produced `results`.
    pub index_name: String,
    /// Results exactly as returned by that index's `search` call.
    pub results: Vec<SearchResult>,
}
/// Search results gathered from several indexes, grouped per index.
#[derive(Debug, Clone, Default)]
pub struct MultiIndexResults {
    /// One group per searched index, in the order the groups were added via `add`.
    pub by_index: Vec<MultiIndexResult>,
    /// Running total of results across all groups. `add` keeps it in sync; it is NOT
    /// recomputed if `by_index` is modified directly, since both fields are public.
    pub total_count: usize,
}
impl MultiIndexResults {
    /// Creates an empty result set with no groups and a `total_count` of zero.
    #[must_use]
    pub fn new() -> Self {
        Self::default()
    }

    /// Appends the `results` produced by the index named `index_name` as a new group
    /// and adds their count to `total_count`.
    pub fn add(&mut self, index_name: String, results: Vec<SearchResult>) {
        self.total_count += results.len();
        self.by_index.push(MultiIndexResult {
            index_name,
            results,
        });
    }

    /// Flattens the per-index groups into `(index_name, result)` pairs.
    ///
    /// Group order and the order of results within each group are preserved. Every
    /// pair owns clones of its index name and result.
    #[must_use]
    pub fn flatten(&self) -> Vec<(String, SearchResult)> {
        // `flat_map` reports a zero lower size hint, so size the output up front. The
        // length is recomputed instead of trusting the public, caller-mutable
        // `total_count` field.
        let len: usize = self.by_index.iter().map(|group| group.results.len()).sum();
        let mut flat = Vec::with_capacity(len);
        flat.extend(self.by_index.iter().flat_map(|group| {
            group
                .results
                .iter()
                .map(move |result| (group.index_name.clone(), result.clone()))
        }));
        flat
    }
}
/// A collection of vector indexes, each registered under a unique name.
///
/// Indexes are stored as boxed `VectorIndex` trait objects, so different
/// implementations can live in the same registry.
#[derive(Debug, Default)]
pub struct IndexRegistry {
    // Keyed by registration name. Iteration order of this map is unspecified.
    indexes: AHashMap<String, Box<dyn VectorIndex>>,
}
impl IndexRegistry {
#[must_use]
pub fn new() -> Self {
Self {
indexes: AHashMap::new(),
}
}
#[must_use]
pub fn with_capacity(capacity: usize) -> Self {
Self {
indexes: AHashMap::with_capacity(capacity),
}
}
pub fn register<I: VectorIndex + 'static>(
&mut self,
name: impl Into<String>,
index: I,
) -> Result<()> {
let name = name.into();
if self.indexes.contains_key(&name) {
return Err(Error::DuplicateIndex { name });
}
self.indexes.insert(name, Box::new(index));
Ok(())
}
pub fn register_or_replace<I: VectorIndex + 'static>(
&mut self,
name: impl Into<String>,
index: I,
) -> Option<Box<dyn VectorIndex>> {
self.indexes.insert(name.into(), Box::new(index))
}
#[must_use]
pub fn get(&self, name: &str) -> Option<&dyn VectorIndex> {
self.indexes.get(name).map(AsRef::as_ref)
}
pub fn remove(&mut self, name: &str) -> Option<Box<dyn VectorIndex>> {
self.indexes.remove(name)
}
#[must_use]
pub fn contains(&self, name: &str) -> bool {
self.indexes.contains_key(name)
}
#[must_use]
pub fn list(&self) -> Vec<&str> {
self.indexes.keys().map(String::as_str).collect()
}
#[must_use]
pub fn info(&self) -> Vec<IndexInfo> {
self.indexes
.iter()
.map(|(name, index)| IndexInfo {
name: name.clone(),
dimension: index.dimension(),
distance_type: index.distance_type(),
size: index.len(),
memory_bytes: index.memory_usage(),
})
.collect()
}
#[must_use]
pub fn len(&self) -> usize {
self.indexes.len()
}
#[must_use]
pub fn is_empty(&self) -> bool {
self.indexes.is_empty()
}
#[must_use]
pub fn total_vectors(&self) -> usize {
self.indexes.values().map(|i| i.len()).sum()
}
#[must_use]
pub fn total_memory(&self) -> usize {
self.indexes.values().map(|i| i.memory_usage()).sum()
}
pub fn search(&self, name: &str, query: &[f32], k: usize) -> Result<Vec<SearchResult>> {
let index = self.indexes.get(name).ok_or_else(|| Error::IndexNotFound {
name: name.to_string(),
})?;
index.search(query, k)
}
pub fn search_all(&self, query: &[f32], k: usize) -> Result<MultiIndexResults> {
let mut results = MultiIndexResults::new();
for (name, index) in &self.indexes {
if index.dimension() != query.len() {
continue;
}
let index_results = index.search(query, k)?;
results.add(name.clone(), index_results);
}
Ok(results)
}
pub fn search_indexes(
&self,
names: &[&str],
query: &[f32],
k: usize,
) -> Result<MultiIndexResults> {
let mut results = MultiIndexResults::new();
for name in names {
let index = self.indexes.get(*name).ok_or_else(|| Error::IndexNotFound {
name: (*name).to_string(),
})?;
if index.dimension() != query.len() {
return Err(Error::DimensionMismatch {
expected: index.dimension(),
got: query.len(),
});
}
let index_results = index.search(query, k)?;
results.add((*name).to_string(), index_results);
}
Ok(results)
}
pub fn add(&mut self, index_name: &str, id: String, vector: &[f32]) -> Result<()> {
let index = self
.indexes
.get_mut(index_name)
.ok_or_else(|| Error::IndexNotFound {
name: index_name.to_string(),
})?;
index.add(id, vector)
}
pub fn remove_vector(&mut self, index_name: &str, id: &str) -> Result<bool> {
let index = self
.indexes
.get_mut(index_name)
.ok_or_else(|| Error::IndexNotFound {
name: index_name.to_string(),
})?;
index.remove(id)
}
pub fn clear_all(&mut self) {
for index in self.indexes.values_mut() {
index.clear();
}
}
}
pub type SharedRegistry = Arc<RwLock<IndexRegistry>>;
#[must_use]
pub fn shared_registry() -> SharedRegistry {
Arc::new(RwLock::new(IndexRegistry::new()))
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::index::{FlatIndex, IndexConfig};

    /// Builds an empty flat index of the given dimension.
    fn create_test_index(dim: usize) -> FlatIndex {
        FlatIndex::new(IndexConfig::new(dim))
    }

    #[test]
    fn test_register_and_get() {
        let mut registry = IndexRegistry::new();
        let index = create_test_index(128);
        registry.register("test", index).unwrap();
        assert!(registry.contains("test"));
        assert!(!registry.contains("other"));
        assert_eq!(registry.len(), 1);
        let retrieved = registry.get("test").unwrap();
        assert_eq!(retrieved.dimension(), 128);
    }

    #[test]
    fn test_with_capacity_starts_empty() {
        let registry = IndexRegistry::with_capacity(8);
        assert!(registry.is_empty());
        assert_eq!(registry.len(), 0);
    }

    #[test]
    fn test_duplicate_register_error() {
        let mut registry = IndexRegistry::new();
        registry.register("test", create_test_index(128)).unwrap();
        let result = registry.register("test", create_test_index(256));
        assert!(matches!(result, Err(Error::DuplicateIndex { .. })));
        // The originally registered index must survive the rejected registration.
        assert_eq!(registry.get("test").unwrap().dimension(), 128);
        assert_eq!(registry.len(), 1);
    }

    #[test]
    fn test_register_or_replace() {
        let mut registry = IndexRegistry::new();
        let old = registry.register_or_replace("test", create_test_index(128));
        assert!(old.is_none());
        let old = registry.register_or_replace("test", create_test_index(256));
        assert!(old.is_some());
        assert_eq!(old.unwrap().dimension(), 128);
        assert_eq!(registry.get("test").unwrap().dimension(), 256);
    }

    #[test]
    fn test_remove() {
        let mut registry = IndexRegistry::new();
        registry.register("test", create_test_index(128)).unwrap();
        let removed = registry.remove("test");
        assert!(removed.is_some());
        assert_eq!(removed.unwrap().dimension(), 128);
        assert!(registry.is_empty());
    }

    #[test]
    fn test_list_and_info() {
        let mut registry = IndexRegistry::new();
        registry.register("a", create_test_index(128)).unwrap();
        registry.register("b", create_test_index(256)).unwrap();
        let names = registry.list();
        assert_eq!(names.len(), 2);
        assert!(names.contains(&"a"));
        assert!(names.contains(&"b"));
        let info = registry.info();
        assert_eq!(info.len(), 2);
    }

    #[test]
    fn test_search_specific_index() {
        let mut registry = IndexRegistry::new();
        let mut index = create_test_index(4);
        index.add("v1".to_string(), &[1.0, 0.0, 0.0, 0.0]).unwrap();
        index.add("v2".to_string(), &[0.0, 1.0, 0.0, 0.0]).unwrap();
        registry.register("test", index).unwrap();
        let query = [1.0, 0.0, 0.0, 0.0];
        let results = registry.search("test", &query, 2).unwrap();
        assert_eq!(results.len(), 2);
        // v1 is identical to the query, so it must rank first.
        assert_eq!(results[0].id, "v1");
    }

    #[test]
    fn test_search_nonexistent_index() {
        let registry = IndexRegistry::new();
        let result = registry.search("nonexistent", &[1.0], 1);
        assert!(matches!(result, Err(Error::IndexNotFound { .. })));
    }

    #[test]
    fn test_search_all() {
        let mut registry = IndexRegistry::new();
        let mut index1 = create_test_index(4);
        index1.add("a1".to_string(), &[1.0, 0.0, 0.0, 0.0]).unwrap();
        let mut index2 = create_test_index(4);
        index2.add("b1".to_string(), &[0.0, 1.0, 0.0, 0.0]).unwrap();
        registry.register("index1", index1).unwrap();
        registry.register("index2", index2).unwrap();
        let query = [0.5, 0.5, 0.0, 0.0];
        let results = registry.search_all(&query, 10).unwrap();
        assert_eq!(results.by_index.len(), 2);
        assert_eq!(results.total_count, 2);
    }

    #[test]
    fn test_search_all_skips_incompatible_dimensions() {
        let mut registry = IndexRegistry::new();
        let mut index1 = create_test_index(4);
        index1.add("a1".to_string(), &[1.0, 0.0, 0.0, 0.0]).unwrap();
        // An 8-dimensional index cannot accept the 4-dimensional query below.
        let mut index2 = create_test_index(8);
        index2
            .add("b1".to_string(), &[1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0])
            .unwrap();
        registry.register("index1", index1).unwrap();
        registry.register("index2", index2).unwrap();
        let query = [0.5, 0.5, 0.0, 0.0];
        let results = registry.search_all(&query, 10).unwrap();
        assert_eq!(results.by_index.len(), 1);
        assert_eq!(results.by_index[0].index_name, "index1");
    }

    #[test]
    fn test_search_indexes() {
        let mut registry = IndexRegistry::new();
        let mut index1 = create_test_index(4);
        index1.add("a1".to_string(), &[1.0, 0.0, 0.0, 0.0]).unwrap();
        let mut index2 = create_test_index(4);
        index2.add("b1".to_string(), &[0.0, 1.0, 0.0, 0.0]).unwrap();
        let mut index3 = create_test_index(4);
        index3.add("c1".to_string(), &[0.0, 0.0, 1.0, 0.0]).unwrap();
        registry.register("idx1", index1).unwrap();
        registry.register("idx2", index2).unwrap();
        registry.register("idx3", index3).unwrap();
        let query = [0.5, 0.5, 0.0, 0.0];
        let results = registry
            .search_indexes(&["idx1", "idx2"], &query, 10)
            .unwrap();
        assert_eq!(results.by_index.len(), 2);
        assert_eq!(results.total_count, 2);
    }

    #[test]
    fn test_search_indexes_unknown_name() {
        let mut registry = IndexRegistry::new();
        registry.register("idx1", create_test_index(4)).unwrap();
        let result = registry.search_indexes(&["idx1", "missing"], &[1.0, 0.0, 0.0, 0.0], 1);
        assert!(matches!(result, Err(Error::IndexNotFound { .. })));
    }

    #[test]
    fn test_search_indexes_dimension_mismatch() {
        let mut registry = IndexRegistry::new();
        registry.register("idx4", create_test_index(4)).unwrap();
        let result = registry.search_indexes(&["idx4"], &[1.0, 0.0], 1);
        assert!(matches!(
            result,
            Err(Error::DimensionMismatch {
                expected: 4,
                got: 2
            })
        ));
    }

    #[test]
    fn test_add_to_index() {
        let mut registry = IndexRegistry::new();
        registry.register("test", create_test_index(4)).unwrap();
        registry
            .add("test", "v1".to_string(), &[1.0, 0.0, 0.0, 0.0])
            .unwrap();
        assert_eq!(registry.get("test").unwrap().len(), 1);
    }

    #[test]
    fn test_mutations_on_unknown_index_fail() {
        let mut registry = IndexRegistry::new();
        assert!(matches!(
            registry.add("missing", "v1".to_string(), &[1.0]),
            Err(Error::IndexNotFound { .. })
        ));
        assert!(matches!(
            registry.remove_vector("missing", "v1"),
            Err(Error::IndexNotFound { .. })
        ));
    }

    #[test]
    fn test_remove_vector() {
        let mut registry = IndexRegistry::new();
        registry.register("test", create_test_index(4)).unwrap();
        registry
            .add("test", "v1".to_string(), &[1.0, 0.0, 0.0, 0.0])
            .unwrap();
        assert!(registry.remove_vector("test", "v1").unwrap());
        assert_eq!(registry.get("test").unwrap().len(), 0);
    }

    #[test]
    fn test_multi_index_results_flatten() {
        let mut results = MultiIndexResults::new();
        results.add(
            "idx1".to_string(),
            vec![SearchResult::new("a".to_string(), 0.5, DistanceType::L2)],
        );
        results.add(
            "idx2".to_string(),
            vec![SearchResult::new("b".to_string(), 0.3, DistanceType::L2)],
        );
        let flat = results.flatten();
        assert_eq!(flat.len(), 2);
        assert_eq!(flat[0].0, "idx1");
        assert_eq!(flat[0].1.id, "a");
        assert_eq!(flat[1].0, "idx2");
        assert_eq!(flat[1].1.id, "b");
    }

    #[test]
    fn test_total_vectors_and_memory() {
        let mut registry = IndexRegistry::new();
        let mut index1 = create_test_index(4);
        index1.add("a".to_string(), &[1.0, 0.0, 0.0, 0.0]).unwrap();
        index1.add("b".to_string(), &[0.0, 1.0, 0.0, 0.0]).unwrap();
        let mut index2 = create_test_index(4);
        index2.add("c".to_string(), &[0.0, 0.0, 1.0, 0.0]).unwrap();
        registry.register("idx1", index1).unwrap();
        registry.register("idx2", index2).unwrap();
        assert_eq!(registry.total_vectors(), 3);
        assert!(registry.total_memory() > 0);
    }

    #[test]
    fn test_clear_all() {
        let mut registry = IndexRegistry::new();
        let mut index1 = create_test_index(4);
        index1.add("a".to_string(), &[1.0, 0.0, 0.0, 0.0]).unwrap();
        let mut index2 = create_test_index(4);
        index2.add("b".to_string(), &[0.0, 1.0, 0.0, 0.0]).unwrap();
        registry.register("idx1", index1).unwrap();
        registry.register("idx2", index2).unwrap();
        assert_eq!(registry.total_vectors(), 2);
        registry.clear_all();
        assert_eq!(registry.total_vectors(), 0);
        // Clearing empties the indexes but keeps them registered.
        assert_eq!(registry.len(), 2);
    }

    #[test]
    fn test_shared_registry() {
        let registry = shared_registry();
        {
            let mut reg = registry.write();
            reg.register("test", create_test_index(128)).unwrap();
        }
        {
            let reg = registry.read();
            assert!(reg.contains("test"));
        }
    }
}