use crate::filter::{Filter, Metadata};
use crate::simd;
use crate::types::{DistanceMetric, IndexStats, SearchConfig, SearchResult};
use anyhow::{anyhow, Result};
use rayon::prelude::*;
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
use tracing::{debug, info};
/// In-memory vector index that answers queries by exhaustively scoring stored vectors.
///
/// Vectors are kept twice: keyed by entity id in `embeddings`, and row-wise in
/// `embedding_matrix`, whose row order is given by `entity_ids`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct VectorSearchIndex {
/// Search settings; the methods in this file read `metric`, `parallel` and `normalize`.
config: SearchConfig,
/// Entity id -> vector. `build` stores the caller's vectors unchanged, while
/// `add_vector`, `add_vectors` and `update_vector` store them after optional normalization.
embeddings: HashMap<String, Vec<f32>>,
/// Entity ids in matrix row order: `entity_ids[i]` names row `i` of `embedding_matrix`.
entity_ids: Vec<String>,
/// Rows scored during search (normalized when `config.normalize` is set);
/// `None` until the first `build` or add, and again after `clear`.
embedding_matrix: Option<Vec<Vec<f32>>>,
/// Vector length required of queries and of vectors added to a built index; 0 while empty.
dimensions: usize,
/// Set by `build`, `add_vector` and `add_vectors`; reset when the index is emptied or cleared.
is_built: bool,
/// Per-entity metadata evaluated by filters in `filtered_search` and `prefiltered_search`.
metadata: HashMap<String, Metadata>,
}
impl VectorSearchIndex {
/// Creates an empty, unbuilt index that will use `config` for every later operation.
///
/// The chosen metric and parallelism flag are logged at `info` level.
pub fn new(config: SearchConfig) -> Self {
    let index = Self {
        config,
        embeddings: HashMap::new(),
        entity_ids: Vec::new(),
        embedding_matrix: None,
        dimensions: 0,
        is_built: false,
        metadata: HashMap::new(),
    };
    info!(
        "Initialized vector search index: metric={:?}, parallel={}",
        index.config.metric, index.config.parallel
    );
    index
}
/// Replaces the index contents with `embeddings`.
///
/// Every vector must have the same length; that length becomes the index
/// dimension. When `config.normalize` is set, the rows used for scoring are
/// normalized, while `get_vector` keeps returning the caller's original
/// values. Metadata already attached to the index is left untouched.
///
/// # Errors
///
/// Fails if `embeddings` is empty or if its vectors do not all share one
/// dimension. On failure the previous index contents are preserved.
pub fn build(&mut self, embeddings: &HashMap<String, Vec<f32>>) -> Result<()> {
    let Some(first) = embeddings.values().next() else {
        return Err(anyhow!("Cannot build index from empty embeddings"));
    };
    let dimensions = first.len();

    // Validate before mutating `self`: scoring rows of different lengths
    // against one query would otherwise corrupt every later search.
    if let Some((entity_id, embedding)) = embeddings
        .iter()
        .find(|(_, embedding)| embedding.len() != dimensions)
    {
        return Err(anyhow!(
            "Inconsistent embedding dimensions: entity '{}' has {}, expected {}",
            entity_id,
            embedding.len(),
            dimensions
        ));
    }

    info!(
        "Building vector search index for {} entities",
        embeddings.len()
    );

    // A single pass keeps `entity_ids[i]` and `matrix[i]` aligned by construction.
    let mut entity_ids = Vec::with_capacity(embeddings.len());
    let mut matrix = Vec::with_capacity(embeddings.len());
    for (entity_id, embedding) in embeddings {
        let mut row = embedding.clone();
        if self.config.normalize {
            Self::normalize_vector(&mut row);
        }
        entity_ids.push(entity_id.clone());
        matrix.push(row);
    }

    self.embeddings = embeddings.clone();
    self.entity_ids = entity_ids;
    self.dimensions = dimensions;
    self.embedding_matrix = Some(matrix);
    self.is_built = true;
    info!("Vector search index built successfully");
    Ok(())
}
/// Returns up to `k` entities ranked best-first for `query`.
///
/// The query is copied and, when `config.normalize` is set, normalized
/// before being scored against every stored vector.
///
/// # Errors
///
/// Fails if the index has not been built or if `query` does not have the
/// index dimension.
pub fn search(&self, query: &[f32], k: usize) -> Result<Vec<SearchResult>> {
    if !self.is_built {
        return Err(anyhow!("Index not built. Call build() first"));
    }

    let query_dim = query.len();
    if query_dim != self.dimensions {
        return Err(anyhow!(
            "Query dimension {} doesn't match index dimension {}",
            query_dim,
            self.dimensions
        ));
    }

    let mut prepared = Vec::from(query);
    if self.config.normalize {
        Self::normalize_vector(&mut prepared);
    }

    debug!("Searching for {} nearest neighbors", k);
    self.exact_search(&prepared, k)
}
fn exact_search(&self, query: &[f32], k: usize) -> Result<Vec<SearchResult>> {
let matrix = self.embedding_matrix.as_ref().unwrap();
let scores: Vec<(usize, f32)> = if self.config.parallel {
(0..self.entity_ids.len())
.into_par_iter()
.map(|i| {
let score = self.compute_similarity(query, &matrix[i]);
(i, score)
})
.collect()
} else {
(0..self.entity_ids.len())
.map(|i| {
let score = self.compute_similarity(query, &matrix[i]);
(i, score)
})
.collect()
};
let mut sorted_scores = scores;
sorted_scores.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
let results: Vec<SearchResult> = sorted_scores
.iter()
.take(k.min(self.entity_ids.len()))
.enumerate()
.map(|(rank, &(idx, score))| SearchResult {
entity_id: self.entity_ids[idx].clone(),
score,
distance: self.score_to_distance(score),
rank: rank + 1,
})
.collect();
debug!("Found {} results", results.len());
Ok(results)
}
pub fn batch_search(&self, queries: &[Vec<f32>], k: usize) -> Result<Vec<Vec<SearchResult>>> {
if !self.is_built {
return Err(anyhow!("Index not built. Call build() first"));
}
info!("Batch searching for {} queries", queries.len());
let results: Vec<Vec<SearchResult>> = if self.config.parallel {
queries
.par_iter()
.map(|query| self.search(query, k).unwrap_or_default())
.collect()
} else {
queries
.iter()
.map(|query| self.search(query, k).unwrap_or_default())
.collect()
};
Ok(results)
}
/// Scores `a` against `b` under the configured metric by delegating to
/// `simd::compute_distance_simd`.
#[inline]
fn compute_similarity(&self, a: &[f32], b: &[f32]) -> f32 {
    let metric = self.config.metric;
    simd::compute_distance_simd(metric, a, b)
}
/// Converts a score into the `distance` reported in search results.
///
/// Cosine scores map to `1.0 - score`; every other metric reports the
/// negated score.
#[inline]
fn score_to_distance(&self, score: f32) -> f32 {
    match self.config.metric {
        DistanceMetric::Cosine => 1.0 - score,
        DistanceMetric::Euclidean | DistanceMetric::Manhattan | DistanceMetric::DotProduct => {
            -score
        }
    }
}
/// Normalizes `vec` in place by delegating to `simd::normalize_vector_simd`.
// NOTE(review): behavior for an all-zero vector is decided by the SIMD helper,
// which is not visible here - confirm it does not produce NaN.
#[inline]
fn normalize_vector(vec: &mut [f32]) {
simd::normalize_vector_simd(vec);
}
/// Returns a snapshot of the index size, dimension, built flag and metric.
pub fn get_stats(&self) -> IndexStats {
    let num_entities = self.entity_ids.len();
    let metric = self.config.metric;
    IndexStats {
        num_entities,
        dimensions: self.dimensions,
        is_built: self.is_built,
        metric,
    }
}
/// Returns every entity whose reported `distance` to `query` is at most `radius`.
///
/// All stored vectors are ranked first, so results keep the same best-first
/// order (and ranks) that `search` would assign over the whole index.
///
/// # Errors
///
/// Fails if the index has not been built or the query dimension is wrong.
pub fn radius_search(&self, query: &[f32], radius: f32) -> Result<Vec<SearchResult>> {
    if !self.is_built {
        return Err(anyhow!("Index not built. Call build() first"));
    }
    let mut within = self.search(query, self.entity_ids.len())?;
    within.retain(|hit| hit.distance <= radius);
    Ok(within)
}
/// Attaches `metadata` to `entity_id`, replacing any metadata stored for it.
///
/// The entity is not required to be present in the index.
pub fn set_metadata(&mut self, entity_id: &str, metadata: Metadata) {
    let key = entity_id.to_owned();
    self.metadata.insert(key, metadata);
}
/// Attaches metadata for many entities at once; entries for ids that already
/// have metadata replace the stored value.
pub fn set_metadata_batch(&mut self, metadata_map: HashMap<String, Metadata>) {
    for (entity_id, entry) in metadata_map {
        self.metadata.insert(entity_id, entry);
    }
}
/// Returns the metadata stored for `entity_id`, or `None` if none has been set.
pub fn get_metadata(&self, entity_id: &str) -> Option<&Metadata> {
self.metadata.get(entity_id)
}
/// Post-filtering search: ranks the whole index, then keeps the first `k`
/// hits whose metadata satisfies `filter`.
///
/// Entities without metadata never match. An empty filter behaves exactly
/// like `search`. Ranks in the returned list are renumbered from 1.
///
/// # Errors
///
/// Fails if the index has not been built or the query dimension is wrong.
pub fn filtered_search(
    &self,
    query: &[f32],
    k: usize,
    filter: &Filter,
) -> Result<Vec<SearchResult>> {
    if !self.is_built {
        return Err(anyhow!("Index not built. Call build() first"));
    }
    if filter.is_empty() {
        return self.search(query, k);
    }
    debug!(
        "Filtered search: k={}, filter conditions={}",
        k,
        filter.conditions().len()
    );

    let candidates = self.search(query, self.entity_ids.len())?;
    let mut kept: Vec<SearchResult> = Vec::new();
    for mut hit in candidates {
        // Stop before evaluating the filter again once `k` hits are collected.
        if kept.len() == k {
            break;
        }
        let passes = match self.metadata.get(&hit.entity_id) {
            Some(meta) => filter.matches(meta),
            None => false,
        };
        if passes {
            hit.rank = kept.len() + 1;
            kept.push(hit);
        }
    }

    debug!("Filtered search returned {} results", kept.len());
    Ok(kept)
}
/// Pre-filtering search: selects the entities whose metadata satisfies
/// `filter`, then scores only those rows and returns the `k` best.
///
/// Entities without metadata never match. An empty filter behaves exactly
/// like `search`. Results are ordered by descending score; ties keep row
/// order, and NaN scores are placed last so the comparator is a total order.
///
/// # Errors
///
/// Fails if the index has not been built, if the query dimension is wrong,
/// or if the scoring matrix has not been populated.
pub fn prefiltered_search(
    &self,
    query: &[f32],
    k: usize,
    filter: &Filter,
) -> Result<Vec<SearchResult>> {
    if !self.is_built {
        return Err(anyhow!("Index not built. Call build() first"));
    }
    if query.len() != self.dimensions {
        return Err(anyhow!(
            "Query dimension {} doesn't match index dimension {}",
            query.len(),
            self.dimensions
        ));
    }
    if filter.is_empty() {
        return self.search(query, k);
    }
    debug!("Pre-filtered search: k={}", k);

    let mut normalized_query = query.to_vec();
    if self.config.normalize {
        Self::normalize_vector(&mut normalized_query);
    }
    let matrix = self
        .embedding_matrix
        .as_ref()
        .ok_or_else(|| anyhow!("Index not built. Call build() first"))?;

    let matching_indices: Vec<usize> = (0..self.entity_ids.len())
        .filter(|&i| {
            self.metadata
                .get(&self.entity_ids[i])
                .is_some_and(|m| filter.matches(m))
        })
        .collect();
    if matching_indices.is_empty() {
        return Ok(Vec::new());
    }

    let score_row = |&i: &usize| (i, self.compute_similarity(&normalized_query, &matrix[i]));
    let mut scored: Vec<(usize, f32)> = if self.config.parallel {
        matching_indices.par_iter().map(score_row).collect()
    } else {
        matching_indices.iter().map(score_row).collect()
    };

    // `partial_cmp(..).unwrap_or(Equal)` alone is not a total order once a
    // NaN appears, and `sort_by` may then panic or return arbitrary order.
    // Handling NaN explicitly keeps non-NaN ordering unchanged.
    scored.sort_by(|a, b| match (a.1.is_nan(), b.1.is_nan()) {
        (false, false) => b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal),
        (true, true) => std::cmp::Ordering::Equal,
        (true, false) => std::cmp::Ordering::Greater,
        (false, true) => std::cmp::Ordering::Less,
    });

    let results: Vec<SearchResult> = scored
        .iter()
        .take(k)
        .enumerate()
        .map(|(rank, &(idx, score))| SearchResult {
            entity_id: self.entity_ids[idx].clone(),
            score,
            distance: self.score_to_distance(score),
            rank: rank + 1,
        })
        .collect();
    debug!("Pre-filtered search returned {} results", results.len());
    Ok(results)
}
/// Appends a single vector under `entity_id`.
///
/// On an unbuilt index the vector's length becomes the index dimension and
/// the index is marked built. The vector is normalized first when
/// `config.normalize` is set, and that normalized form is what gets stored.
///
/// # Errors
///
/// Fails if the index is built and the dimension differs, or if
/// `entity_id` is already present.
pub fn add_vector(&mut self, entity_id: String, mut embedding: Vec<f32>) -> Result<()> {
    let incoming_dim = embedding.len();
    if self.is_built && incoming_dim != self.dimensions {
        return Err(anyhow!(
            "Vector dimension {} doesn't match index dimension {}",
            incoming_dim,
            self.dimensions
        ));
    }
    if self.embeddings.contains_key(&entity_id) {
        return Err(anyhow!("Entity '{}' already exists in index", entity_id));
    }

    if !self.is_built {
        self.dimensions = incoming_dim;
    }
    if self.config.normalize {
        Self::normalize_vector(&mut embedding);
    }

    self.embeddings.insert(entity_id.clone(), embedding.clone());
    self.entity_ids.push(entity_id);
    // Creates the matrix on first use, then appends the new row.
    self.embedding_matrix
        .get_or_insert_with(Vec::new)
        .push(embedding);
    self.is_built = true;

    debug!("Added vector to index (total: {})", self.entity_ids.len());
    Ok(())
}
/// Appends every vector in `embeddings`; an empty map is a no-op.
///
/// The whole batch is validated before anything is inserted, so a failed
/// call leaves the index unchanged. On an unbuilt index the batch must be
/// internally consistent: the dimension of its first vector is adopted and
/// every other vector is checked against it. Vectors are normalized before
/// storage when `config.normalize` is set.
///
/// # Errors
///
/// Fails if any vector has the wrong dimension or any id already exists.
pub fn add_vectors(&mut self, embeddings: &HashMap<String, Vec<f32>>) -> Result<()> {
    let Some(first) = embeddings.values().next() else {
        return Ok(());
    };
    info!("Adding {} vectors to index", embeddings.len());

    // A built index dictates the dimension; otherwise the batch does. Either
    // way every vector is checked, so mixed-length batches cannot slip into
    // an empty index.
    let expected_dim = if self.is_built {
        self.dimensions
    } else {
        first.len()
    };
    for (entity_id, embedding) in embeddings {
        if embedding.len() != expected_dim {
            return Err(anyhow!(
                "Vector dimension {} doesn't match index dimension {}",
                embedding.len(),
                expected_dim
            ));
        }
        if self.embeddings.contains_key(entity_id) {
            return Err(anyhow!("Entity '{}' already exists in index", entity_id));
        }
    }

    self.dimensions = expected_dim;
    self.entity_ids.reserve(embeddings.len());
    let matrix = self.embedding_matrix.get_or_insert_with(Vec::new);
    matrix.reserve(embeddings.len());
    for (entity_id, embedding) in embeddings {
        let mut row = embedding.clone();
        if self.config.normalize {
            Self::normalize_vector(&mut row);
        }
        self.embeddings.insert(entity_id.clone(), row.clone());
        self.entity_ids.push(entity_id.clone());
        matrix.push(row);
    }
    self.is_built = true;

    info!(
        "Added vectors successfully (total: {})",
        self.entity_ids.len()
    );
    Ok(())
}
pub fn remove_vector(&mut self, entity_id: &str) -> Result<()> {
if !self.embeddings.contains_key(entity_id) {
return Err(anyhow!("Entity '{}' not found in index", entity_id));
}
let idx = self
.entity_ids
.iter()
.position(|id| id == entity_id)
.ok_or_else(|| anyhow!("Entity '{}' not found in entity_ids", entity_id))?;
self.embeddings.remove(entity_id);
self.entity_ids.remove(idx);
if let Some(ref mut matrix) = self.embedding_matrix {
matrix.remove(idx);
}
self.metadata.remove(entity_id);
if self.embeddings.is_empty() {
self.is_built = false;
self.dimensions = 0;
}
debug!(
"Removed vector from index (remaining: {})",
self.entity_ids.len()
);
Ok(())
}
pub fn remove_vectors(&mut self, entity_ids: &[&str]) -> Result<()> {
info!("Removing {} vectors from index", entity_ids.len());
for entity_id in entity_ids {
self.remove_vector(entity_id)?;
}
info!(
"Removed vectors successfully (remaining: {})",
self.entity_ids.len()
);
Ok(())
}
/// Replaces the vector stored for an existing `entity_id`.
///
/// The new vector is normalized first when `config.normalize` is set; both
/// the id-keyed copy and the matrix row are updated.
///
/// # Errors
///
/// Fails if `entity_id` is unknown or if the new vector does not have the
/// index dimension.
pub fn update_vector(&mut self, entity_id: &str, mut new_embedding: Vec<f32>) -> Result<()> {
    if !self.embeddings.contains_key(entity_id) {
        return Err(anyhow!("Entity '{}' not found in index", entity_id));
    }
    let new_dim = new_embedding.len();
    if new_dim != self.dimensions {
        return Err(anyhow!(
            "Vector dimension {} doesn't match index dimension {}",
            new_dim,
            self.dimensions
        ));
    }
    if self.config.normalize {
        Self::normalize_vector(&mut new_embedding);
    }

    let Some(row) = self
        .entity_ids
        .iter()
        .position(|candidate| candidate == entity_id)
    else {
        return Err(anyhow!(
            "Entity '{}' not found in entity_ids",
            entity_id
        ));
    };

    self.embeddings
        .insert(entity_id.to_string(), new_embedding.clone());
    if let Some(matrix) = self.embedding_matrix.as_mut() {
        matrix[row] = new_embedding;
    }
    debug!("Updated vector in index: {}", entity_id);
    Ok(())
}
/// Empties the index: drops all vectors, ids and metadata and returns it to
/// the unbuilt state with dimension 0. The configuration is kept.
pub fn clear(&mut self) {
    self.is_built = false;
    self.dimensions = 0;
    self.embedding_matrix.take();
    self.entity_ids.clear();
    self.embeddings.clear();
    self.metadata.clear();
    info!("Index cleared");
}
/// Returns the number of indexed vectors (the length of `entity_ids`).
#[inline]
pub fn len(&self) -> usize {
self.entity_ids.len()
}
/// Returns `true` when no vectors are indexed.
#[inline]
pub fn is_empty(&self) -> bool {
self.entity_ids.is_empty()
}
/// Returns `true` when a vector is stored under `entity_id`.
#[inline]
pub fn contains(&self, entity_id: &str) -> bool {
self.embeddings.contains_key(entity_id)
}
/// Returns the vector stored under `entity_id` in the id-keyed map, if any.
///
/// Vectors loaded through `build` are returned as supplied; vectors written by
/// `add_vector`, `add_vectors` or `update_vector` are returned in the
/// (optionally normalized) form those methods stored.
#[inline]
pub fn get_vector(&self, entity_id: &str) -> Option<&Vec<f32>> {
self.embeddings.get(entity_id)
}
/// Copies the vectors (and their metadata) of `other` into this index.
///
/// Ids that exist in both indexes are overwritten when
/// `overwrite_duplicates` is `true` and left completely untouched otherwise:
/// a skipped entity keeps both its vector and its metadata. Metadata from
/// `other` is copied only for entities that were actually added or updated.
///
/// # Errors
///
/// Fails if `other` is unbuilt or if both indexes are built with different
/// dimensions.
pub fn merge(&mut self, other: &VectorSearchIndex, overwrite_duplicates: bool) -> Result<()> {
    if !other.is_built {
        return Err(anyhow!("Cannot merge from an unbuilt index"));
    }
    if !self.is_built {
        // An empty target adopts the source dimension.
        self.dimensions = other.dimensions;
    } else if self.dimensions != other.dimensions {
        return Err(anyhow!(
            "Cannot merge indexes with different dimensions: {} vs {}",
            self.dimensions,
            other.dimensions
        ));
    }
    info!("Merging index with {} vectors", other.entity_ids.len());

    let mut added = 0;
    let mut updated = 0;
    let mut skipped = 0;
    for entity_id in &other.entity_ids {
        let embedding = &other.embeddings[entity_id];
        let copied = if !self.embeddings.contains_key(entity_id) {
            self.add_vector(entity_id.clone(), embedding.clone())?;
            added += 1;
            true
        } else if overwrite_duplicates {
            self.update_vector(entity_id, embedding.clone())?;
            updated += 1;
            true
        } else {
            skipped += 1;
            false
        };

        // Metadata follows the vector: a skipped duplicate must not have its
        // existing metadata replaced by the other index's.
        if copied {
            if let Some(metadata) = other.metadata.get(entity_id) {
                self.metadata.insert(entity_id.clone(), metadata.clone());
            }
        }
    }

    info!(
        "Merge complete: added={}, updated={}, skipped={}",
        added, updated, skipped
    );
    Ok(())
}
pub fn merge_multiple(indexes: &[&VectorSearchIndex]) -> Result<VectorSearchIndex> {
if indexes.is_empty() {
return Err(anyhow!("Cannot merge zero indexes"));
}
let first_built = indexes
.iter()
.find(|idx| idx.is_built)
.ok_or_else(|| anyhow!("At least one index must be built"))?;
let dimensions = first_built.dimensions;
let config = first_built.config.clone();
for (i, index) in indexes.iter().enumerate() {
if index.is_built && index.dimensions != dimensions {
return Err(anyhow!(
"Index {} has incompatible dimensions: {} vs {}",
i,
index.dimensions,
dimensions
));
}
}
info!("Merging {} indexes into one", indexes.len());
let mut all_embeddings = HashMap::new();
let mut all_metadata = HashMap::new();
for index in indexes {
for (entity_id, embedding) in &index.embeddings {
all_embeddings.insert(entity_id.clone(), embedding.clone());
}
for (entity_id, metadata) in &index.metadata {
all_metadata.insert(entity_id.clone(), metadata.clone());
}
}
let mut merged = VectorSearchIndex::new(config);
merged.build(&all_embeddings)?;
merged.metadata = all_metadata;
info!("Merged index contains {} vectors", merged.len());
Ok(merged)
}
}
#[cfg(test)]
mod tests {
use super::*;
use crate::filter::FilterValue;
/// Fixture: five 3-dimensional vectors keyed `entity1` through `entity5`.
fn create_test_embeddings() -> HashMap<String, Vec<f32>> {
    HashMap::from([
        ("entity1".to_string(), vec![1.0, 0.0, 0.0]),
        ("entity2".to_string(), vec![0.9, 0.1, 0.0]),
        ("entity3".to_string(), vec![0.0, 1.0, 0.0]),
        ("entity4".to_string(), vec![0.0, 0.0, 1.0]),
        ("entity5".to_string(), vec![0.7, 0.7, 0.0]),
    ])
}
fn create_test_metadata() -> HashMap<String, Metadata> {
let mut metadata = HashMap::new();
let mut m1 = HashMap::new();
m1.insert(
"type".to_string(),
FilterValue::String("article".to_string()),
);
m1.insert("year".to_string(), FilterValue::Int(2023));
metadata.insert("entity1".to_string(), m1);
let mut m2 = HashMap::new();
m2.insert(
"type".to_string(),
FilterValue::String("article".to_string()),
);
m2.insert("year".to_string(), FilterValue::Int(2022));
metadata.insert("entity2".to_string(), m2);
let mut m3 = HashMap::new();
m3.insert("type".to_string(), FilterValue::String("book".to_string()));
m3.insert("year".to_string(), FilterValue::Int(2023));
metadata.insert("entity3".to_string(), m3);
let mut m4 = HashMap::new();
m4.insert("type".to_string(), FilterValue::String("book".to_string()));
m4.insert("year".to_string(), FilterValue::Int(2021));
metadata.insert("entity4".to_string(), m4);
let mut m5 = HashMap::new();
m5.insert(
"type".to_string(),
FilterValue::String("article".to_string()),
);
m5.insert("year".to_string(), FilterValue::Int(2024));
metadata.insert("entity5".to_string(), m5);
metadata
}
/// A freshly created index is unbuilt and has dimension 0.
#[test]
fn test_index_creation() {
    let index = VectorSearchIndex::new(SearchConfig::default());
    assert!(!index.is_built);
    assert_eq!(index.dimensions, 0);
}
/// Building from the fixture marks the index built with 5 entities of dimension 3.
#[test]
fn test_index_building() {
    let fixture = create_test_embeddings();
    let mut index = VectorSearchIndex::new(SearchConfig::default());
    let outcome = index.build(&fixture);
    assert!(outcome.is_ok());
    assert!(index.is_built);
    assert_eq!(index.dimensions, 3);
    assert_eq!(index.entity_ids.len(), 5);
}
/// Searching along the x axis returns 3 hits led by `entity1` or `entity2`.
#[test]
fn test_search() {
    let mut index = VectorSearchIndex::new(SearchConfig::default());
    index.build(&create_test_embeddings()).unwrap();

    let query = vec![1.0, 0.0, 0.0];
    let hits = index.search(&query, 3).unwrap();

    assert_eq!(hits.len(), 3);
    assert!(matches!(
        hits[0].entity_id.as_str(),
        "entity1" | "entity2"
    ));
}
/// A batch of two queries yields two result lists of `k` hits each.
#[test]
fn test_batch_search() {
    let mut index = VectorSearchIndex::new(SearchConfig::default());
    index.build(&create_test_embeddings()).unwrap();

    let queries = vec![vec![1.0, 0.0, 0.0], vec![0.0, 1.0, 0.0]];
    let batches = index.batch_search(&queries, 2).unwrap();

    assert_eq!(batches.len(), 2);
    assert_eq!(batches[0].len(), 2);
    assert_eq!(batches[1].len(), 2);
}
/// Stats reflect the fixture size, dimension, built flag and the default cosine metric.
#[test]
fn test_get_stats() {
    let mut index = VectorSearchIndex::new(SearchConfig::default());
    index.build(&create_test_embeddings()).unwrap();

    let snapshot = index.get_stats();

    assert_eq!(snapshot.num_entities, 5);
    assert_eq!(snapshot.dimensions, 3);
    assert!(snapshot.is_built);
    assert_eq!(snapshot.metric, DistanceMetric::Cosine);
}
/// Metadata set for an entity can be read back unchanged.
#[test]
fn test_set_and_get_metadata() {
    let mut index = VectorSearchIndex::new(SearchConfig::default());
    index.build(&create_test_embeddings()).unwrap();

    let mut fields = HashMap::new();
    fields.insert(
        "type".to_string(),
        FilterValue::String("article".to_string()),
    );
    index.set_metadata("entity1", fields.clone());

    let stored = index.get_metadata("entity1");
    assert!(stored.is_some());
    let expected = FilterValue::String("article".to_string());
    assert_eq!(stored.unwrap().get("type"), Some(&expected));
}
/// Filtering on `type == article` returns exactly the three article entities.
#[test]
fn test_filtered_search() {
    let mut index = VectorSearchIndex::new(SearchConfig::default());
    index.build(&create_test_embeddings()).unwrap();
    index.set_metadata_batch(create_test_metadata());

    let filter = Filter::new().eq("type", "article");
    let query = vec![1.0, 0.0, 0.0];
    let hits = index.filtered_search(&query, 5, &filter).unwrap();

    assert_eq!(hits.len(), 3);
    let article = FilterValue::String("article".to_string());
    for hit in &hits {
        let meta = index.get_metadata(&hit.entity_id).unwrap();
        assert_eq!(meta.get("type"), Some(&article));
    }
}
/// Filtering on `year >= 2023` keeps the three matching entities.
#[test]
fn test_filtered_search_with_year() {
    let mut index = VectorSearchIndex::new(SearchConfig::default());
    index.build(&create_test_embeddings()).unwrap();
    index.set_metadata_batch(create_test_metadata());

    let filter = Filter::new().gte("year", 2023i64);
    let query = vec![1.0, 0.0, 0.0];
    let hits = index.filtered_search(&query, 5, &filter).unwrap();

    assert_eq!(hits.len(), 3);
}
/// Pre-filtering on `type == book` returns exactly the two book entities.
#[test]
fn test_prefiltered_search() {
    let mut index = VectorSearchIndex::new(SearchConfig::default());
    index.build(&create_test_embeddings()).unwrap();
    index.set_metadata_batch(create_test_metadata());

    let filter = Filter::new().eq("type", "book");
    let query = vec![0.0, 1.0, 0.0];
    let hits = index.prefiltered_search(&query, 5, &filter).unwrap();

    assert_eq!(hits.len(), 2);
    let book = FilterValue::String("book".to_string());
    for hit in &hits {
        let meta = index.get_metadata(&hit.entity_id).unwrap();
        assert_eq!(meta.get("type"), Some(&book));
    }
}
/// An empty filter behaves like a plain search and returns `k` hits.
#[test]
fn test_filtered_search_empty_filter() {
    let mut index = VectorSearchIndex::new(SearchConfig::default());
    index.build(&create_test_embeddings()).unwrap();

    let no_conditions = Filter::new();
    let query = vec![1.0, 0.0, 0.0];
    let hits = index.filtered_search(&query, 3, &no_conditions).unwrap();

    assert_eq!(hits.len(), 3);
}
/// A filter that matches no entity produces an empty result list.
#[test]
fn test_filtered_search_no_matches() {
    let mut index = VectorSearchIndex::new(SearchConfig::default());
    index.build(&create_test_embeddings()).unwrap();
    index.set_metadata_batch(create_test_metadata());

    let filter = Filter::new().eq("type", "journal");
    let query = vec![1.0, 0.0, 0.0];
    let hits = index.filtered_search(&query, 5, &filter).unwrap();

    assert_eq!(hits.len(), 0);
}
/// A vector added after build is counted, findable and returned as the top hit for itself.
#[test]
fn test_add_vector() {
    let mut index = VectorSearchIndex::new(SearchConfig::default());
    index.build(&create_test_embeddings()).unwrap();
    let before = index.len();

    let outcome = index.add_vector("entity6".to_string(), vec![0.5, 0.5, 0.5]);
    assert!(outcome.is_ok());
    assert_eq!(index.len(), before + 1);
    assert!(index.contains("entity6"));

    let query = vec![0.5, 0.5, 0.5];
    let hits = index.search(&query, 1).unwrap();
    assert_eq!(hits[0].entity_id, "entity6");
}
/// Adding a vector under an id that already exists is rejected.
#[test]
fn test_add_vector_duplicate() {
    let mut index = VectorSearchIndex::new(SearchConfig::default());
    index.build(&create_test_embeddings()).unwrap();

    let outcome = index.add_vector("entity1".to_string(), vec![0.5, 0.5, 0.5]);
    assert!(outcome.is_err());
}
/// Adding a 2-dimensional vector to a 3-dimensional index is rejected.
#[test]
fn test_add_vector_dimension_mismatch() {
    let mut index = VectorSearchIndex::new(SearchConfig::default());
    index.build(&create_test_embeddings()).unwrap();

    let outcome = index.add_vector("entity6".to_string(), vec![0.5, 0.5]);
    assert!(outcome.is_err());
}
/// A batch add of two new vectors grows the index by two and makes both findable.
#[test]
fn test_add_vectors() {
    let mut index = VectorSearchIndex::new(SearchConfig::default());
    index.build(&create_test_embeddings()).unwrap();
    let before = index.len();

    let mut batch = HashMap::new();
    batch.insert("entity6".to_string(), vec![0.5, 0.5, 0.5]);
    batch.insert("entity7".to_string(), vec![0.6, 0.6, 0.6]);
    let outcome = index.add_vectors(&batch);

    assert!(outcome.is_ok());
    assert_eq!(index.len(), before + 2);
    assert!(index.contains("entity6"));
    assert!(index.contains("entity7"));
}
/// A removed entity disappears from the count, from lookups and from search results.
#[test]
fn test_remove_vector() {
    let mut index = VectorSearchIndex::new(SearchConfig::default());
    index.build(&create_test_embeddings()).unwrap();
    let before = index.len();

    let outcome = index.remove_vector("entity1");
    assert!(outcome.is_ok());
    assert_eq!(index.len(), before - 1);
    assert!(!index.contains("entity1"));

    let query = vec![1.0, 0.0, 0.0];
    let hits = index.search(&query, 5).unwrap();
    assert!(hits.iter().all(|hit| hit.entity_id != "entity1"));
}
/// Removing an unknown id is an error.
#[test]
fn test_remove_vector_not_found() {
    let mut index = VectorSearchIndex::new(SearchConfig::default());
    index.build(&create_test_embeddings()).unwrap();

    let outcome = index.remove_vector("nonexistent");
    assert!(outcome.is_err());
}
/// Batch removal of two known ids shrinks the index by two.
#[test]
fn test_remove_vectors() {
    let mut index = VectorSearchIndex::new(SearchConfig::default());
    index.build(&create_test_embeddings()).unwrap();
    let before = index.len();

    let outcome = index.remove_vectors(&["entity1", "entity2"]);

    assert!(outcome.is_ok());
    assert_eq!(index.len(), before - 2);
    assert!(!index.contains("entity1"));
    assert!(!index.contains("entity2"));
}
/// After an update, `get_vector` returns the normalized form of the new vector.
#[test]
fn test_update_vector() {
    let mut index = VectorSearchIndex::new(SearchConfig::default());
    index.build(&create_test_embeddings()).unwrap();

    let new_embedding = vec![0.9, 0.9, 0.9];
    let outcome = index.update_vector("entity1", new_embedding.clone());
    assert!(outcome.is_ok());

    let stored = index.get_vector("entity1").unwrap();
    let mut expected = new_embedding.clone();
    VectorSearchIndex::normalize_vector(&mut expected);

    assert_eq!(stored.len(), expected.len());
    for (actual, wanted) in stored.iter().zip(expected.iter()) {
        assert!((actual - wanted).abs() < 1e-6);
    }
}
/// Updating an unknown id is an error.
#[test]
fn test_update_vector_not_found() {
    let mut index = VectorSearchIndex::new(SearchConfig::default());
    index.build(&create_test_embeddings()).unwrap();

    let outcome = index.update_vector("nonexistent", vec![0.5, 0.5, 0.5]);
    assert!(outcome.is_err());
}
/// Updating with a 2-dimensional vector in a 3-dimensional index is rejected.
#[test]
fn test_update_vector_dimension_mismatch() {
    let mut index = VectorSearchIndex::new(SearchConfig::default());
    index.build(&create_test_embeddings()).unwrap();

    let outcome = index.update_vector("entity1", vec![0.5, 0.5]);
    assert!(outcome.is_err());
}
/// `clear` returns a built index to the empty, unbuilt, zero-dimension state.
#[test]
fn test_clear() {
    let mut index = VectorSearchIndex::new(SearchConfig::default());
    index.build(&create_test_embeddings()).unwrap();
    assert!(!index.is_empty());
    assert!(index.is_built);

    index.clear();

    assert_eq!(index.len(), 0);
    assert!(index.is_empty());
    assert!(!index.is_built);
    assert_eq!(index.dimensions, 0);
}
/// `get_vector` finds a stored id and returns `None` for an unknown one.
#[test]
fn test_get_vector() {
    let mut index = VectorSearchIndex::new(SearchConfig::default());
    index.build(&create_test_embeddings()).unwrap();

    let present = index.get_vector("entity1");
    assert!(present.is_some());

    let absent = index.get_vector("nonexistent");
    assert!(absent.is_none());
}
/// An index populated only through `add_vector` becomes built and searchable.
#[test]
fn test_incremental_build() {
    let mut index = VectorSearchIndex::new(SearchConfig::default());
    let axes = [
        ("entity1", vec![1.0, 0.0, 0.0]),
        ("entity2", vec![0.0, 1.0, 0.0]),
        ("entity3", vec![0.0, 0.0, 1.0]),
    ];
    for (id, vector) in &axes {
        index.add_vector(id.to_string(), vector.clone()).unwrap();
    }

    assert_eq!(index.len(), 3);
    assert!(index.is_built);
    assert_eq!(index.dimensions, 3);

    let query = vec![1.0, 0.0, 0.0];
    let hits = index.search(&query, 1).unwrap();
    assert_eq!(hits[0].entity_id, "entity1");
}
/// Merging an index with two new ids adds both to the target.
#[test]
fn test_merge_indexes() {
    let mut target = VectorSearchIndex::new(SearchConfig::default());
    target.build(&create_test_embeddings()).unwrap();

    let mut extra = HashMap::new();
    extra.insert("entity6".to_string(), vec![0.6, 0.6, 0.0]);
    extra.insert("entity7".to_string(), vec![0.7, 0.7, 0.0]);
    let mut source = VectorSearchIndex::new(SearchConfig::default());
    source.build(&extra).unwrap();

    let before = target.len();
    let outcome = target.merge(&source, false);

    assert!(outcome.is_ok());
    assert_eq!(target.len(), before + 2);
    assert!(target.contains("entity6"));
    assert!(target.contains("entity7"));
}
/// With overwriting disabled, a duplicate id is skipped and only the new id is added.
#[test]
fn test_merge_with_duplicates_skip() {
    let mut target = VectorSearchIndex::new(SearchConfig::default());
    target.build(&create_test_embeddings()).unwrap();

    let mut overlapping = HashMap::new();
    overlapping.insert("entity1".to_string(), vec![0.9, 0.9, 0.9]);
    overlapping.insert("entity6".to_string(), vec![0.6, 0.6, 0.0]);
    let mut source = VectorSearchIndex::new(SearchConfig::default());
    source.build(&overlapping).unwrap();

    let before = target.len();
    let outcome = target.merge(&source, false);

    assert!(outcome.is_ok());
    assert_eq!(target.len(), before + 1);
}
/// With overwriting enabled, a duplicate id keeps the size but takes the new (normalized) vector.
#[test]
fn test_merge_with_duplicates_overwrite() {
    let mut target = VectorSearchIndex::new(SearchConfig::default());
    target.build(&create_test_embeddings()).unwrap();

    let mut replacement = HashMap::new();
    replacement.insert("entity1".to_string(), vec![0.9, 0.9, 0.9]);
    let mut source = VectorSearchIndex::new(SearchConfig::default());
    source.build(&replacement).unwrap();

    let before = target.len();
    let outcome = target.merge(&source, true);
    assert!(outcome.is_ok());
    assert_eq!(target.len(), before);

    let stored = target.get_vector("entity1").unwrap();
    let mut expected = vec![0.9, 0.9, 0.9];
    VectorSearchIndex::normalize_vector(&mut expected);
    for (actual, wanted) in stored.iter().zip(expected.iter()) {
        assert!((actual - wanted).abs() < 1e-6);
    }
}
/// Merging a 2-dimensional index into a 3-dimensional one is rejected.
#[test]
fn test_merge_dimension_mismatch() {
    let mut target = VectorSearchIndex::new(SearchConfig::default());
    target.build(&create_test_embeddings()).unwrap();

    let mut two_dimensional = HashMap::new();
    two_dimensional.insert("entity6".to_string(), vec![0.6, 0.6]);
    let mut source = VectorSearchIndex::new(SearchConfig::default());
    source.build(&two_dimensional).unwrap();

    let outcome = target.merge(&source, false);
    assert!(outcome.is_err());
}
/// Merging three disjoint indexes yields one searchable index with all five documents.
#[test]
fn test_merge_multiple_indexes() {
    let mut first = VectorSearchIndex::new(SearchConfig::default());
    let mut first_docs = HashMap::new();
    first_docs.insert("doc1".to_string(), vec![1.0, 0.0, 0.0]);
    first_docs.insert("doc2".to_string(), vec![0.9, 0.1, 0.0]);
    first.build(&first_docs).unwrap();

    let mut second = VectorSearchIndex::new(SearchConfig::default());
    let mut second_docs = HashMap::new();
    second_docs.insert("doc3".to_string(), vec![0.0, 1.0, 0.0]);
    second_docs.insert("doc4".to_string(), vec![0.1, 0.9, 0.0]);
    second.build(&second_docs).unwrap();

    let mut third = VectorSearchIndex::new(SearchConfig::default());
    let mut third_docs = HashMap::new();
    third_docs.insert("doc5".to_string(), vec![0.0, 0.0, 1.0]);
    third.build(&third_docs).unwrap();

    let outcome = VectorSearchIndex::merge_multiple(&[&first, &second, &third]);
    assert!(outcome.is_ok());
    let merged = outcome.unwrap();
    assert_eq!(merged.len(), 5);
    assert!(merged.contains("doc1"));
    assert!(merged.contains("doc3"));
    assert!(merged.contains("doc5"));

    let query = vec![1.0, 0.0, 0.0];
    let hits = merged.search(&query, 2).unwrap();
    assert_eq!(hits.len(), 2);
}
/// Merging an empty list of indexes is an error.
#[test]
fn test_merge_multiple_empty() {
    let outcome = VectorSearchIndex::merge_multiple(&[]);
    assert!(outcome.is_err());
}
// Property-based tests for `VectorSearchIndex`, driven by randomly generated
// embeddings and queries. Every property uses `SearchConfig::default()`.
#[cfg(test)]
mod proptest_tests {
use super::*;
use proptest::prelude::*;
// Strategy producing a vector of exactly `dim` components, each in [-1.0, 1.0).
fn vector_strategy(dim: usize) -> impl Strategy<Value = Vec<f32>> {
proptest::collection::vec(-1.0f32..1.0f32, dim..=dim)
}
// Strategy producing a map built from `count` (id, vector) pairs of dimension `dim`.
// Ids are 5-10 characters from [a-z0-9]; if two generated ids collide, the
// collected map holds fewer than `count` entries.
fn embeddings_strategy(
count: usize,
dim: usize,
) -> impl Strategy<Value = HashMap<String, Vec<f32>>> {
proptest::collection::vec(
(
proptest::string::string_regex("[a-z0-9]{5,10}").unwrap(),
vector_strategy(dim),
),
count..=count,
)
.prop_map(|pairs| pairs.into_iter().collect())
}
proptest! {
// Searching a built index must not panic; the result itself is ignored.
#[test]
fn prop_search_never_panics(
embeddings in embeddings_strategy(10, 128),
query in vector_strategy(128),
k in 1usize..10
) {
let mut index = VectorSearchIndex::new(SearchConfig::default());
index.build(&embeddings).unwrap();
let _ = index.search(&query, k);
}
// The number of hits never exceeds `k` or the number of indexed vectors.
#[test]
fn prop_search_respects_k(
embeddings in embeddings_strategy(20, 64),
query in vector_strategy(64),
k in 1usize..15
) {
let mut index = VectorSearchIndex::new(SearchConfig::default());
index.build(&embeddings).unwrap();
let results = index.search(&query, k).unwrap();
prop_assert!(results.len() <= k);
prop_assert!(results.len() <= embeddings.len());
}
// Hits are ordered by non-increasing score.
#[test]
fn prop_search_results_sorted(
embeddings in embeddings_strategy(15, 32),
query in vector_strategy(32),
k in 2usize..10
) {
let mut index = VectorSearchIndex::new(SearchConfig::default());
index.build(&embeddings).unwrap();
let results = index.search(&query, k).unwrap();
for i in 1..results.len() {
prop_assert!(results[i-1].score >= results[i].score,
"Results not sorted: {} < {}", results[i-1].score, results[i].score);
}
}
// Ranks are 1-based and consecutive in result order.
#[test]
fn prop_search_ranks_consecutive(
embeddings in embeddings_strategy(10, 16),
query in vector_strategy(16),
k in 1usize..8
) {
let mut index = VectorSearchIndex::new(SearchConfig::default());
index.build(&embeddings).unwrap();
let results = index.search(&query, k).unwrap();
for (i, result) in results.iter().enumerate() {
prop_assert_eq!(result.rank, i + 1);
}
}
// Repeating the same search on the same index yields the same ids and scores.
#[test]
fn prop_search_deterministic(
embeddings in embeddings_strategy(12, 48),
query in vector_strategy(48),
k in 1usize..10
) {
let mut index = VectorSearchIndex::new(SearchConfig::default());
index.build(&embeddings).unwrap();
let results1 = index.search(&query, k).unwrap();
let results2 = index.search(&query, k).unwrap();
prop_assert_eq!(results1.len(), results2.len());
for (r1, r2) in results1.iter().zip(results2.iter()) {
prop_assert_eq!(&r1.entity_id, &r2.entity_id);
prop_assert!((r1.score - r2.score).abs() < 1e-6);
}
}
// A batch returns one result list per query, each with at most `k` hits.
// Query `i` is a constant vector filled with `i`, so the first query is all zeros.
// NOTE(review): how a zero vector normalizes depends on `simd::normalize_vector_simd` - confirm.
#[test]
fn prop_batch_search_count(
embeddings in embeddings_strategy(10, 32),
num_queries in 1usize..5,
k in 1usize..8
) {
let mut index = VectorSearchIndex::new(SearchConfig::default());
index.build(&embeddings).unwrap();
let queries: Vec<Vec<f32>> = (0..num_queries)
.map(|i| vec![i as f32; 32])
.collect();
let results = index.batch_search(&queries, k).unwrap();
prop_assert_eq!(results.len(), num_queries);
for result_set in results {
prop_assert!(result_set.len() <= k);
}
}
// Reported distances are never negative under the default configuration.
// NOTE(review): cosine distance is computed as `1.0 - score`; floating-point
// rounding could push it marginally below zero - consider a small tolerance.
#[test]
fn prop_distance_non_negative(
embeddings in embeddings_strategy(8, 24),
query in vector_strategy(24),
k in 1usize..6
) {
let mut index = VectorSearchIndex::new(SearchConfig::default());
index.build(&embeddings).unwrap();
let results = index.search(&query, k).unwrap();
for result in results {
prop_assert!(result.distance >= 0.0,
"Negative distance: {}", result.distance);
}
}
// Building from an empty map always fails; `_dim` is generated but not used.
#[test]
fn prop_empty_embeddings_fail(_dim in 1usize..128) {
let embeddings: HashMap<String, Vec<f32>> = HashMap::new();
let mut index = VectorSearchIndex::new(SearchConfig::default());
prop_assert!(index.build(&embeddings).is_err());
}
// A query whose length differs from the index dimension (64) is rejected.
#[test]
fn prop_dimension_mismatch_fail(
embeddings in embeddings_strategy(5, 64),
wrong_dim in 1usize..128
) {
prop_assume!(wrong_dim != 64);
let mut index = VectorSearchIndex::new(SearchConfig::default());
index.build(&embeddings).unwrap();
let query = vec![0.0; wrong_dim];
let result = index.search(&query, 3);
prop_assert!(result.is_err());
}
// After normalization the Euclidean norm is 1 within 1e-5.
// NOTE(review): the strategy can in principle generate an all-zero vector,
// for which this property may not hold - confirm the helper's behavior.
#[test]
fn prop_normalize_unit_norm(mut vec in vector_strategy(128)) {
VectorSearchIndex::normalize_vector(&mut vec);
let norm: f32 = vec.iter().map(|x| x * x).sum::<f32>().sqrt();
prop_assert!((norm - 1.0).abs() < 1e-5, "Norm: {}", norm);
}
}
}
}