use std::collections::HashMap;
use uuid::Uuid;
use crate::embedding::cosine_similarity;
/// Errors returned by [`VectorIndex`] operations.
#[derive(Debug, thiserror::Error)]
pub enum VectorError {
/// A supplied vector's length differs from the dimensionality the index was created with.
/// `expected` is the index's dimensionality; `got` is the length of the offending vector.
#[error("Dimension mismatch: expected {expected}, got {got}")]
DimensionMismatch { expected: usize, got: usize },
/// Generic index failure carrying a free-form message.
/// NOTE(review): not constructed by any code visible in this file.
#[error("Index error: {0}")]
IndexError(String),
/// No entry is stored under the given id.
#[error("Not found: {0}")]
NotFound(Uuid),
}
/// In-memory collection of fixed-length `f32` embeddings keyed by [`Uuid`].
pub struct VectorIndex {
/// Stored embedding for each id; at most one embedding per id.
entries: HashMap<Uuid, Vec<f32>>,
/// Length that every stored embedding and every search query must have.
dimensions: usize,
}
impl VectorIndex {
    /// Creates an empty index that accepts vectors of exactly `dimensions` components.
    ///
    /// `_capacity_hint` is accepted so callers can state an expected size, but this
    /// implementation does not use it; the backing map grows on demand.
    pub fn new(dimensions: usize, _capacity_hint: usize) -> Self {
        Self {
            entries: HashMap::new(),
            dimensions,
        }
    }

    /// Returns `Ok(())` when `got` equals the index dimensionality.
    fn check_dimensions(&self, got: usize) -> Result<(), VectorError> {
        if got == self.dimensions {
            Ok(())
        } else {
            Err(VectorError::DimensionMismatch {
                expected: self.dimensions,
                got,
            })
        }
    }

    /// Stores a copy of `embedding` under `id`, replacing any embedding already stored
    /// for that id.
    ///
    /// # Errors
    ///
    /// Returns [`VectorError::DimensionMismatch`] when `embedding.len()` differs from
    /// the index dimensionality; the index is left unchanged in that case.
    pub fn add(&mut self, id: Uuid, embedding: &[f32]) -> Result<(), VectorError> {
        self.check_dimensions(embedding.len())?;
        self.entries.insert(id, embedding.to_vec());
        Ok(())
    }

    /// Scores every stored embedding against `query` with `cosine_similarity` and
    /// returns at most `limit` `(id, score)` pairs, highest score first.
    ///
    /// Ordering is a strict total order so the result is deterministic:
    /// * higher scores come first;
    /// * NaN scores (which `cosine_similarity` may produce, e.g. for degenerate
    ///   vectors, depending on its implementation) are ranked after every real score;
    /// * entries with equal scores are ordered by ascending id.
    ///
    /// # Errors
    ///
    /// Returns [`VectorError::DimensionMismatch`] when `query.len()` differs from the
    /// index dimensionality.
    pub fn search(&self, query: &[f32], limit: usize) -> Result<Vec<(Uuid, f32)>, VectorError> {
        use std::cmp::Ordering;

        self.check_dimensions(query.len())?;
        if limit == 0 {
            return Ok(Vec::new());
        }

        // `partial_cmp(..).unwrap_or(Equal)` alone is not a total order once a NaN is
        // present (NaN would compare "equal" to two values that are not equal to each
        // other), and the standard sorts may panic or return an unspecified order when
        // given such a comparator. Handle NaN explicitly and break ties by id.
        let by_rank = |a: &(Uuid, f32), b: &(Uuid, f32)| -> Ordering {
            let by_score = match (a.1.is_nan(), b.1.is_nan()) {
                (true, true) => Ordering::Equal,
                (true, false) => Ordering::Greater,
                (false, true) => Ordering::Less,
                // Neither side is NaN, so `partial_cmp` always yields `Some`.
                (false, false) => b.1.partial_cmp(&a.1).unwrap_or(Ordering::Equal),
            };
            by_score.then_with(|| a.0.cmp(&b.0))
        };

        let mut scored: Vec<(Uuid, f32)> = self
            .entries
            .iter()
            .map(|(id, emb)| (*id, cosine_similarity(query, emb)))
            .collect();

        // Partition out the best `limit` entries first so only those get fully sorted.
        if limit < scored.len() {
            scored.select_nth_unstable_by(limit, by_rank);
            scored.truncate(limit);
        }
        // Ids are unique map keys, so `by_rank` never reports a tie and an unstable
        // sort yields the same result as a stable one.
        scored.sort_unstable_by(by_rank);
        Ok(scored)
    }

    /// Removes the embedding stored under `id`.
    ///
    /// # Errors
    ///
    /// Returns [`VectorError::NotFound`] when no embedding is stored under `id`.
    pub fn remove(&mut self, id: Uuid) -> Result<(), VectorError> {
        match self.entries.remove(&id) {
            Some(_) => Ok(()),
            None => Err(VectorError::NotFound(id)),
        }
    }

    /// Number of embeddings currently stored.
    pub fn len(&self) -> usize {
        self.entries.len()
    }

    /// Returns `true` when no embeddings are stored.
    pub fn is_empty(&self) -> bool {
        self.entries.is_empty()
    }

    /// Length that every stored embedding and every query must have.
    pub fn dimensions(&self) -> usize {
        self.dimensions
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// The stored vector identical to the query ranks first with a strictly higher score.
    #[test]
    fn test_add_and_search() {
        let mut idx = VectorIndex::new(4, 100);
        let near = Uuid::new_v4();
        let far = Uuid::new_v4();
        idx.add(near, &[1.0, 0.0, 0.0, 0.0]).unwrap();
        idx.add(far, &[0.0, 1.0, 0.0, 0.0]).unwrap();

        let hits = idx.search(&[1.0, 0.0, 0.0, 0.0], 2).unwrap();

        assert_eq!(hits.len(), 2);
        assert_eq!(hits[0].0, near);
        assert!(hits[0].1 > hits[1].1);
    }

    /// Removing the only entry brings the length back to zero.
    #[test]
    fn test_remove() {
        let mut idx = VectorIndex::new(4, 100);
        let key = Uuid::new_v4();

        idx.add(key, &[1.0, 0.0, 0.0, 0.0]).unwrap();
        assert_eq!(idx.len(), 1);

        idx.remove(key).unwrap();
        assert_eq!(idx.len(), 0);
    }

    /// Adding a vector of the wrong length is rejected.
    #[test]
    fn test_dimension_mismatch() {
        let mut idx = VectorIndex::new(4, 100);
        let outcome = idx.add(Uuid::new_v4(), &[1.0, 0.0]);
        assert!(outcome.is_err());
    }

    /// Searching an index with no entries yields no hits.
    #[test]
    fn test_empty_search() {
        let idx = VectorIndex::new(4, 100);
        let hits = idx.search(&[1.0, 0.0, 0.0, 0.0], 5).unwrap();
        assert!(hits.is_empty());
    }

    /// Removing an id that was never added reports `NotFound`.
    #[test]
    fn test_remove_not_found() {
        let mut idx = VectorIndex::new(4, 100);
        let outcome = idx.remove(Uuid::new_v4());
        assert!(matches!(outcome, Err(VectorError::NotFound(_))));
    }

    /// No more than `limit` hits are returned even when more entries are stored.
    #[test]
    fn test_search_respects_limit() {
        let mut idx = VectorIndex::new(2, 100);
        (0..10).for_each(|_| idx.add(Uuid::new_v4(), &[1.0, 0.0]).unwrap());

        let hits = idx.search(&[1.0, 0.0], 3).unwrap();
        assert_eq!(hits.len(), 3);
    }

    /// Adding under an existing id overwrites the previously stored vector.
    #[test]
    fn test_add_replaces_existing() {
        let mut idx = VectorIndex::new(2, 100);
        let key = Uuid::new_v4();
        idx.add(key, &[1.0, 0.0]).unwrap();
        idx.add(key, &[0.0, 1.0]).unwrap();
        assert_eq!(idx.len(), 1);

        let hits = idx.search(&[0.0, 1.0], 1).unwrap();
        let (hit_id, hit_score) = hits[0];
        assert_eq!(hit_id, key);
        assert!((hit_score - 1.0).abs() < 0.001);
    }

    /// A query of the wrong length is rejected with `DimensionMismatch`.
    #[test]
    fn test_search_dimension_mismatch() {
        let idx = VectorIndex::new(4, 100);
        let outcome = idx.search(&[1.0, 0.0], 5);
        assert!(matches!(outcome, Err(VectorError::DimensionMismatch { .. })));
    }

    /// `is_empty` tracks additions and removals.
    #[test]
    fn test_is_empty() {
        let mut idx = VectorIndex::new(2, 10);
        assert!(idx.is_empty());

        let key = Uuid::new_v4();
        idx.add(key, &[1.0, 0.0]).unwrap();
        assert!(!idx.is_empty());

        idx.remove(key).unwrap();
        assert!(idx.is_empty());
    }

    /// Hits come back ordered from most to least similar.
    ///
    /// NOTE(review): the name mentions HNSW, but the index in this file scores every
    /// stored entry.
    #[test]
    fn test_hnsw_search_finds_nearest() {
        let mut idx = VectorIndex::new(3, 10);
        let exact = Uuid::new_v4();
        let orthogonal = Uuid::new_v4();
        let diagonal = Uuid::new_v4();
        idx.add(exact, &[1.0, 0.0, 0.0]).unwrap();
        idx.add(orthogonal, &[0.0, 1.0, 0.0]).unwrap();
        idx.add(diagonal, &[0.5, 0.5, 0.0]).unwrap();

        let hits = idx.search(&[1.0, 0.0, 0.0], 3).unwrap();

        assert_eq!(hits[0].0, exact);
        assert_eq!(hits[1].0, diagonal);
    }

    /// Removal on a three-dimensional index empties it again.
    #[test]
    fn test_hnsw_remove() {
        let mut idx = VectorIndex::new(3, 10);
        let key = Uuid::new_v4();

        idx.add(key, &[1.0, 0.0, 0.0]).unwrap();
        assert_eq!(idx.len(), 1);

        idx.remove(key).unwrap();
        assert_eq!(idx.len(), 0);
    }

    /// A limit larger than the number of stored entries returns every entry.
    #[test]
    fn test_hnsw_handles_large_k() {
        let mut idx = VectorIndex::new(3, 10);
        let key = Uuid::new_v4();
        idx.add(key, &[1.0, 0.0, 0.0]).unwrap();

        let hits = idx.search(&[1.0, 0.0, 0.0], 100).unwrap();
        assert_eq!(hits.len(), 1);
    }
}