use std::collections::HashMap;
use std::path::Path;
use crate::vector::{
ClusterId, ConcurrentVectorStorage, MmapVectorStorage, Score, SegmentOrdinal, VectorDimension,
VectorError, VectorId, assign_to_nearest_centroid, cosine_similarity, kmeans_clustering,
};
const MIN_CLUSTERS: usize = 1;
const MAX_CLUSTERS: usize = 100;
/// Clustered vector search engine.
///
/// Indexed vectors are grouped into k-means clusters. At query time only the vectors
/// assigned to the cluster returned by `assign_to_nearest_centroid` are scored (with
/// `cosine_similarity`), so results are approximate.
#[derive(Debug)]
pub struct VectorSearchEngine {
    /// Vector payloads; wraps the `MmapVectorStorage` segment created in `new`.
    storage: ConcurrentVectorStorage,
    /// Cluster assigned to each vector by the most recent non-empty `index_vectors` call.
    cluster_assignments: HashMap<VectorId, ClusterId>,
    /// K-means centroids from the most recent non-empty `index_vectors` call; empty until then.
    centroids: Vec<Vec<f32>>,
    /// Dimension that indexed vectors and queries are validated against.
    dimension: VectorDimension,
}
impl VectorSearchEngine {
    /// Creates an empty engine whose vectors are stored in an [`MmapVectorStorage`]
    /// segment (ordinal 0) at `storage_path`.
    ///
    /// The engine has no centroids or cluster assignments until
    /// [`index_vectors`](Self::index_vectors) is called, so searches return no
    /// results until then.
    ///
    /// # Errors
    ///
    /// Returns [`VectorError::Storage`] if the storage segment cannot be created.
    #[must_use = "The created VectorSearchEngine instance should be used for indexing and searching"]
    pub fn new(
        storage_path: impl AsRef<Path>,
        dimension: VectorDimension,
    ) -> Result<Self, VectorError> {
        let mmap_storage =
            MmapVectorStorage::new(storage_path.as_ref(), SegmentOrdinal::new(0), dimension)
                .map_err(|e| {
                    VectorError::Storage(std::io::Error::other(format!(
                        "Failed to create storage: {e}. Check that the directory exists and you have write permissions"
                    )))
                })?;
        Ok(Self {
            storage: ConcurrentVectorStorage::new(mmap_storage),
            cluster_assignments: HashMap::new(),
            centroids: Vec::new(),
            dimension,
        })
    }

    /// Stores `vectors` and rebuilds the cluster index from them.
    ///
    /// The number of clusters is `ceil(sqrt(vectors.len()))`, clamped to
    /// `MIN_CLUSTERS..=MAX_CLUSTERS`. A successful call *replaces* the previous
    /// centroids and cluster assignments. Only vectors from the most recent
    /// non-empty call are reachable through [`search`](Self::search).
    ///
    /// An empty slice is a no-op and leaves the current index untouched.
    ///
    /// Validation and clustering happen before anything is persisted. In-memory state
    /// is only replaced after the storage write succeeds, so a failed call leaves the
    /// existing index unchanged.
    ///
    /// # Errors
    ///
    /// - The dimension-validation error if any vector has the wrong dimension.
    /// - [`VectorError::ClusteringFailed`] if k-means fails, or if it returns a
    ///   different number of assignments than there are input vectors.
    /// - [`VectorError::Storage`] if writing the batch to storage fails.
    pub fn index_vectors(&mut self, vectors: &[(VectorId, Vec<f32>)]) -> Result<(), VectorError> {
        if vectors.is_empty() {
            return Ok(());
        }
        for (_, vec) in vectors {
            self.dimension.validate_vector(vec)?;
        }

        // Cluster first: a clustering failure must not leave freshly written vectors
        // in storage while the in-memory index still describes the previous batch.
        let vecs: Vec<Vec<f32>> = vectors.iter().map(|(_, v)| v.clone()).collect();
        let k = ((vecs.len() as f32).sqrt().ceil() as usize).clamp(MIN_CLUSTERS, MAX_CLUSTERS);
        let clustering_result = kmeans_clustering(&vecs, k)
            .map_err(|e| VectorError::ClusteringFailed(e.to_string()))?;

        // Guard the one-assignment-per-input invariant explicitly instead of risking
        // an index panic (or a silently truncated zip) below.
        if clustering_result.assignments.len() != vectors.len() {
            return Err(VectorError::ClusteringFailed(format!(
                "k-means returned {} assignments for {} vectors",
                clustering_result.assignments.len(),
                vectors.len()
            )));
        }
        let assignments: HashMap<VectorId, ClusterId> = vectors
            .iter()
            .map(|(id, _)| *id)
            .zip(clustering_result.assignments.iter().copied())
            .collect();

        let vector_refs: Vec<(VectorId, &[f32])> = vectors
            .iter()
            .map(|(id, vec)| (*id, vec.as_slice()))
            .collect();
        self.storage.write_batch(&vector_refs).map_err(|e| {
            VectorError::Storage(std::io::Error::other(format!(
                "Failed to store vectors: {e}. Check disk space and file permissions"
            )))
        })?;

        // Commit only after every fallible step has succeeded.
        self.centroids = clustering_result.centroids;
        self.cluster_assignments = assignments;
        Ok(())
    }

    /// Returns up to `k` `(id, score)` pairs most similar to `query`.
    ///
    /// Only vectors assigned to the cluster chosen by `assign_to_nearest_centroid`
    /// are scored, using cosine similarity, so this is an approximate search. Vectors
    /// that cannot be read back from storage, or whose similarity `Score::new`
    /// rejects, are skipped.
    ///
    /// Results are sorted by descending score. The relative order of equal scores
    /// is unspecified.
    ///
    /// An empty vector is returned when `k == 0` or nothing has been indexed yet.
    ///
    /// # Errors
    ///
    /// Returns the dimension-validation error if `query` has the wrong dimension.
    /// This check runs even when `k == 0` or the index is empty.
    #[must_use = "Search results should be processed to retrieve relevant vectors"]
    pub fn search(&self, query: &[f32], k: usize) -> Result<Vec<(VectorId, Score)>, VectorError> {
        self.dimension.validate_vector(query)?;
        if k == 0 || self.centroids.is_empty() {
            return Ok(Vec::new());
        }

        let centroid_refs: Vec<&[f32]> = self.centroids.iter().map(Vec::as_slice).collect();
        let nearest_cluster = assign_to_nearest_centroid(query, &centroid_refs);

        let mut candidates: Vec<(VectorId, Score)> = self
            .cluster_assignments
            .iter()
            .filter(|(_, cluster_id)| **cluster_id == nearest_cluster)
            .filter_map(|(vector_id, _)| {
                let vector = self.storage.read_vector(*vector_id)?;
                let score = Score::new(cosine_similarity(query, &vector)).ok()?;
                Some((*vector_id, score))
            })
            .collect();

        candidates.sort_by(|a, b| b.1.cmp(&a.1));
        candidates.truncate(k);
        Ok(candidates)
    }

    /// Returns the cluster `id` was assigned to by the most recent indexing, if any.
    #[must_use = "The cluster assignment should be used for cluster-aware operations"]
    pub fn get_cluster_for_vector(&self, id: VectorId) -> Option<ClusterId> {
        self.cluster_assignments.get(&id).copied()
    }

    /// Returns the centroids from the most recent non-empty indexing.
    ///
    /// The slice is empty before any vectors have been indexed.
    #[must_use]
    pub fn as_centroids(&self) -> &[Vec<f32>] {
        &self.centroids
    }

    /// Returns the number of distinct vector ids in the current cluster index.
    ///
    /// Vectors written to storage by earlier, since-replaced indexing calls are not
    /// counted.
    #[must_use]
    pub fn vector_count(&self) -> usize {
        self.cluster_assignments.len()
    }

    /// Returns the dimension that indexed vectors and queries are validated against.
    #[must_use]
    pub fn dimension(&self) -> VectorDimension {
        self.dimension
    }

    /// Returns every `(vector id, cluster id)` pair in the current index.
    ///
    /// The pairs are returned in unspecified order.
    #[must_use]
    pub fn get_all_cluster_assignments(&self) -> Vec<(VectorId, ClusterId)> {
        self.cluster_assignments
            .iter()
            .map(|(&id, &cluster)| (id, cluster))
            .collect()
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use tempfile::TempDir;

    /// Builds `n` unit-length vectors of length `dim` with ids `1..=n`.
    ///
    /// The first two components place the vectors evenly around a circle so they
    /// spread across clusters. Components `2..min(dim, 10)` add a small
    /// id-dependent perturbation. Requires `dim >= 2`.
    fn create_test_vectors(n: usize, dim: usize) -> Vec<(VectorId, Vec<f32>)> {
        (1..=n)
            .map(|i| {
                let id = VectorId::new(i as u32).unwrap();
                let mut vec = vec![0.0; dim];
                let angle = (i as f32 - 1.0) * std::f32::consts::PI * 2.0 / n as f32;
                vec[0] = angle.cos();
                vec[1] = angle.sin();
                // `vec.len() == dim`, so `take(10)` visits indices `0..min(dim, 10)`.
                for (j, slot) in vec.iter_mut().enumerate().take(10).skip(2) {
                    *slot = ((i * j) as f32 / (n * dim) as f32).sin();
                }
                let norm = vec.iter().map(|x| x * x).sum::<f32>().sqrt();
                for x in &mut vec {
                    *x /= norm;
                }
                (id, vec)
            })
            .collect()
    }

    #[test]
    fn test_engine_creation() {
        let temp_dir = TempDir::new().unwrap();
        let dimension = VectorDimension::new(128).unwrap();
        let engine = VectorSearchEngine::new(temp_dir.path(), dimension).unwrap();
        assert!(engine.centroids.is_empty());
        assert!(engine.cluster_assignments.is_empty());
    }

    #[test]
    fn test_index_and_search() {
        let temp_dir = TempDir::new().unwrap();
        let dimension = VectorDimension::new(128).unwrap();
        let mut engine = VectorSearchEngine::new(temp_dir.path(), dimension).unwrap();
        let vectors = create_test_vectors(20, 128);
        engine.index_vectors(&vectors).unwrap();
        assert!(!engine.centroids.is_empty());
        assert_eq!(engine.cluster_assignments.len(), 20);

        let found_count = vectors
            .iter()
            .filter(|(query_id, query_vec)| {
                let results = engine.search(query_vec, 5).unwrap();
                results.iter().any(|(id, _)| id == query_id)
            })
            .count();
        assert!(found_count > 0, "Should find at least some vectors");

        let results = engine.search(&vectors[0].1, 10).unwrap();
        assert!(
            results.windows(2).all(|w| w[0].1 >= w[1].1),
            "Results should be sorted by score"
        );
    }

    #[test]
    fn test_empty_index_search() {
        let temp_dir = TempDir::new().unwrap();
        let dimension = VectorDimension::new(128).unwrap();
        let engine = VectorSearchEngine::new(temp_dir.path(), dimension).unwrap();
        let query = vec![0.5; 128];
        let results = engine.search(&query, 5).unwrap();
        assert!(results.is_empty());
    }

    #[test]
    fn test_dimension_validation() {
        let temp_dir = TempDir::new().unwrap();
        let dimension = VectorDimension::new(128).unwrap();
        let mut engine = VectorSearchEngine::new(temp_dir.path(), dimension).unwrap();
        let bad_vectors = vec![(VectorId::new(1).unwrap(), vec![0.5; 64])];
        assert!(engine.index_vectors(&bad_vectors).is_err());
        let bad_query = vec![0.5; 64];
        assert!(engine.search(&bad_query, 5).is_err());
    }

    #[test]
    fn test_cluster_assignment_lookup() {
        let temp_dir = TempDir::new().unwrap();
        let dimension = VectorDimension::new(128).unwrap();
        let mut engine = VectorSearchEngine::new(temp_dir.path(), dimension).unwrap();
        let vectors = create_test_vectors(10, 128);
        engine.index_vectors(&vectors).unwrap();
        for (id, _) in &vectors {
            assert!(engine.get_cluster_for_vector(*id).is_some());
        }
        let non_existent = VectorId::new(999).unwrap();
        assert!(engine.get_cluster_for_vector(non_existent).is_none());
    }

    #[test]
    fn test_search_returns_sorted_results() {
        let temp_dir = TempDir::new().unwrap();
        let dimension = VectorDimension::new(128).unwrap();
        let mut engine = VectorSearchEngine::new(temp_dir.path(), dimension).unwrap();
        let vectors = create_test_vectors(50, 128);
        engine.index_vectors(&vectors).unwrap();
        let query = vectors[25].1.clone();
        let results = engine.search(&query, 10).unwrap();
        assert!(results.windows(2).all(|w| w[0].1 >= w[1].1));
    }

    #[test]
    fn test_search_with_zero_k_returns_empty() {
        let temp_dir = TempDir::new().unwrap();
        let dimension = VectorDimension::new(128).unwrap();
        let mut engine = VectorSearchEngine::new(temp_dir.path(), dimension).unwrap();
        let vectors = create_test_vectors(20, 128);
        engine.index_vectors(&vectors).unwrap();
        assert!(engine.search(&vectors[0].1, 0).unwrap().is_empty());
    }

    #[test]
    fn test_search_respects_k_limit() {
        let temp_dir = TempDir::new().unwrap();
        let dimension = VectorDimension::new(128).unwrap();
        let mut engine = VectorSearchEngine::new(temp_dir.path(), dimension).unwrap();
        let vectors = create_test_vectors(50, 128);
        engine.index_vectors(&vectors).unwrap();
        let results = engine.search(&vectors[3].1, 3).unwrap();
        assert!(results.len() <= 3);
    }

    #[test]
    fn test_empty_batch_is_noop() {
        let temp_dir = TempDir::new().unwrap();
        let dimension = VectorDimension::new(128).unwrap();
        let mut engine = VectorSearchEngine::new(temp_dir.path(), dimension).unwrap();
        engine.index_vectors(&create_test_vectors(10, 128)).unwrap();
        let centroid_count = engine.as_centroids().len();
        engine.index_vectors(&[]).unwrap();
        assert_eq!(engine.vector_count(), 10);
        assert_eq!(engine.as_centroids().len(), centroid_count);
    }

    #[test]
    fn test_accessors_reflect_index() {
        let temp_dir = TempDir::new().unwrap();
        let dimension = VectorDimension::new(128).unwrap();
        let mut engine = VectorSearchEngine::new(temp_dir.path(), dimension).unwrap();
        engine.index_vectors(&create_test_vectors(20, 128)).unwrap();
        assert_eq!(engine.vector_count(), 20);
        assert!(!engine.as_centroids().is_empty());
        assert!(engine.as_centroids().len() <= MAX_CLUSTERS);
        let all = engine.get_all_cluster_assignments();
        assert_eq!(all.len(), 20);
        for (id, cluster) in all {
            assert!(engine.get_cluster_for_vector(id) == Some(cluster));
        }
    }
}