use serde::{Deserialize, Serialize};
use std::collections::HashMap;
/// In-memory collection of fixed-length `f32` embeddings keyed by node id.
///
/// All fields are private; the index is manipulated through the methods of
/// `impl VectorIndex`.
#[derive(Debug, Clone)]
pub struct VectorIndex {
/// Stored embeddings, keyed by the node id they were inserted under.
embeddings: HashMap<String, Vec<f32>>,
/// Number of components every embedding must have; `insert` rejects any
/// vector whose length differs from this value.
dimension: usize,
/// Entry counter maintained by `insert`, `remove` and `clear`; this is the
/// value reported by `len()` and tested by `is_empty()`.
count: usize,
}
impl VectorIndex {
    /// Creates an empty index that accepts embeddings with exactly `dimension` components.
    pub fn new(dimension: usize) -> Self {
        Self {
            embeddings: HashMap::new(),
            dimension,
            count: 0,
        }
    }

    /// Stores `embedding` under `node_id`, replacing any embedding already stored for that id.
    ///
    /// # Errors
    ///
    /// Returns [`Error::DimensionMismatch`] when `embedding.len()` differs from the index
    /// dimension. The index is left unchanged in that case.
    pub fn insert(&mut self, node_id: String, embedding: Vec<f32>) -> Result<(), Error> {
        if embedding.len() != self.dimension {
            return Err(Error::DimensionMismatch {
                expected: self.dimension,
                got: embedding.len(),
            });
        }
        // `HashMap::insert` hands back the previous value when the key already existed.
        // Only a brand-new key grows the index: counting overwrites as well would make
        // `len()` / `is_empty()` drift away from the real number of stored entries.
        if self.embeddings.insert(node_id, embedding).is_none() {
            self.count += 1;
        }
        Ok(())
    }

    /// Inserts every `(node_id, embedding)` pair, silently skipping pairs whose embedding
    /// has the wrong dimension.
    ///
    /// Returns the number of pairs that were accepted. A pair that overwrites an existing
    /// id is counted as accepted, so the return value can exceed the growth of `len()`.
    pub fn insert_batch(&mut self, vectors: impl IntoIterator<Item = (String, Vec<f32>)>) -> usize {
        vectors
            .into_iter()
            .map(|(node_id, embedding)| self.insert(node_id, embedding))
            .filter(|outcome| outcome.is_ok())
            .count()
    }

    /// Returns up to `top_k` `(node_id, similarity)` pairs, most similar to `query` first.
    ///
    /// Similarity is computed with [`cosine_similarity`]. Entries with equal similarity are
    /// ordered by ascending `node_id`, and entries whose similarity is NaN are placed last,
    /// so the result is deterministic for a given index content.
    ///
    /// Returns an empty vector when `query.len()` differs from the index dimension.
    pub fn search(&self, query: &[f32], top_k: usize) -> Vec<(String, f32)> {
        if query.len() != self.dimension {
            return Vec::new();
        }
        // Score against borrowed ids; only the ids that make the cut are cloned below.
        let mut scored: Vec<(&String, f32)> = self
            .embeddings
            .iter()
            .map(|(node_id, embedding)| (node_id, cosine_similarity(query, embedding)))
            .collect();
        scored.sort_by(|a, b| {
            // `partial_cmp` alone is not a total order once NaN shows up, and `sort_by`
            // requires one. NaN scores are therefore ranked explicitly (after every real
            // score); for non-NaN floats `partial_cmp` always returns `Some`.
            let by_score = match (a.1.is_nan(), b.1.is_nan()) {
                (true, true) => std::cmp::Ordering::Equal,
                (true, false) => std::cmp::Ordering::Greater,
                (false, true) => std::cmp::Ordering::Less,
                (false, false) => b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal),
            };
            // HashMap iteration order is arbitrary, so break ties on the (unique) id.
            by_score.then_with(|| a.0.cmp(b.0))
        });
        scored
            .into_iter()
            .take(top_k)
            .map(|(node_id, similarity)| (node_id.clone(), similarity))
            .collect()
    }

    /// Returns the number of embeddings currently stored.
    pub fn len(&self) -> usize {
        self.count
    }

    /// Returns `true` when no embedding is stored.
    pub fn is_empty(&self) -> bool {
        self.count == 0
    }

    /// Returns the number of components every stored embedding has.
    pub fn dimension(&self) -> usize {
        self.dimension
    }

    /// Removes the embedding stored under `node_id`.
    ///
    /// Returns `true` if an entry was removed and `false` if the id was unknown.
    pub fn remove(&mut self, node_id: &str) -> bool {
        if self.embeddings.remove(node_id).is_some() {
            self.count -= 1;
            true
        } else {
            false
        }
    }

    /// Removes every embedding; the configured dimension is kept.
    pub fn clear(&mut self) {
        self.embeddings.clear();
        self.count = 0;
    }

    /// Returns the embedding stored under `node_id`, or `None` if the id is unknown.
    pub fn get(&self, node_id: &str) -> Option<&Vec<f32>> {
        self.embeddings.get(node_id)
    }

    /// Returns a rough estimate of the memory held by the index, in bytes.
    ///
    /// For each entry this adds the key's byte length, the `String` and `Vec<f32>` headers
    /// and the embedding's element bytes, then adds `size_of::<Self>()`. It is based on
    /// lengths rather than capacities and does not include the hash table's own overhead.
    #[must_use]
    pub fn estimated_memory_bytes(&self) -> usize {
        let embeddings_size = self
            .embeddings
            .iter()
            .map(|(k, v)| {
                k.len()
                    + std::mem::size_of::<String>()
                    + v.len() * std::mem::size_of::<f32>()
                    + std::mem::size_of::<Vec<f32>>()
            })
            .sum::<usize>();
        embeddings_size + std::mem::size_of::<Self>()
    }
}
/// Computes the cosine similarity of two vectors.
///
/// Returns `0.0` when the slices differ in length or when either vector has
/// a Euclidean norm of zero (which includes two empty slices).
pub fn cosine_similarity(a: &[f32], b: &[f32]) -> f32 {
    if a.len() != b.len() {
        return 0.0;
    }
    // One pass accumulates the dot product and both squared norms, visiting
    // the components in index order.
    let (dot, squared_a, squared_b) = a.iter().zip(b.iter()).fold(
        (0.0_f32, 0.0_f32, 0.0_f32),
        |(dot, squared_a, squared_b), (&x, &y)| (dot + x * y, squared_a + x * x, squared_b + y * y),
    );
    let magnitude_a = squared_a.sqrt();
    let magnitude_b = squared_b.sqrt();
    if magnitude_a == 0.0 || magnitude_b == 0.0 {
        0.0
    } else {
        dot / (magnitude_a * magnitude_b)
    }
}
/// Errors defined for the vector index.
#[derive(Debug, thiserror::Error)]
pub enum Error {
/// An embedding's length did not match the index dimension; returned by
/// `VectorIndex::insert`.
#[error("Dimension mismatch: expected {expected}, got {got}")]
DimensionMismatch {
/// The dimension the index was created with.
expected: usize,
/// The length of the embedding that was supplied.
got: usize,
},
/// NOTE(review): not constructed anywhere in the visible part of this file;
/// presumably reserved for operations that need a non-empty index — confirm.
#[error("Index is empty")]
EmptyIndex,
/// NOTE(review): not constructed anywhere in the visible part of this file;
/// the payload is interpolated into the error message as-is.
#[error("Invalid embedding: {0}")]
InvalidEmbedding(String),
}
/// Serializable pair of a node id and a score.
///
/// NOTE(review): `VectorIndex::search` returns plain `(String, f32)` tuples
/// rather than this type; presumably callers convert — confirm.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SearchResult {
/// Identifier of the node this result refers to.
pub node_id: String,
/// Score attached to the node.
pub score: f32,
}
impl SearchResult {
pub fn new(node_id: String, score: f32) -> Self {
Self { node_id, score }
}
}
impl Default for VectorIndex {
fn default() -> Self {
Self::new(768) }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a 3-dimensional index holding the given `(id, embedding)` pairs.
    fn index_of(entries: &[(&str, [f32; 3])]) -> VectorIndex {
        let mut index = VectorIndex::new(3);
        for (id, embedding) in entries {
            index
                .insert((*id).to_string(), embedding.to_vec())
                .expect("a 3-component embedding must be accepted");
        }
        index
    }

    // A fresh index reports its dimension and holds nothing.
    #[test]
    fn test_vector_index_creation() {
        let fresh = VectorIndex::new(128);
        assert_eq!(fresh.dimension(), 128);
        assert_eq!(fresh.len(), 0);
        assert!(fresh.is_empty());
    }

    // A correctly sized embedding is accepted and counted.
    #[test]
    fn test_vector_index_insert() {
        let mut index = VectorIndex::new(3);
        assert!(index.insert("test".to_string(), vec![0.1, 0.2, 0.3]).is_ok());
        assert_eq!(index.len(), 1);
        assert!(!index.is_empty());
    }

    // A two-component embedding is rejected by a three-dimensional index.
    #[test]
    fn test_vector_index_dimension_mismatch() {
        let mut index = VectorIndex::new(3);
        let outcome = index.insert("test".to_string(), vec![0.1, 0.2]);
        assert!(outcome.is_err());
    }

    // The two best matches come back in descending similarity order.
    #[test]
    fn test_vector_index_search() {
        let index = index_of(&[
            ("a", [1.0, 0.0, 0.0]),
            ("b", [0.0, 1.0, 0.0]),
            ("c", [0.9, 0.1, 0.0]),
        ]);
        let hits = index.search(&[1.0, 0.0, 0.0], 2);
        assert_eq!(hits.len(), 2);
        assert_eq!(hits[0].0, "a");
        assert_eq!(hits[1].0, "c");
    }

    // Identical vectors score 1, orthogonal vectors score 0.
    #[test]
    fn test_cosine_similarity() {
        let x_axis = [1.0_f32, 0.0, 0.0];
        let y_axis = [0.0_f32, 1.0, 0.0];
        let same = cosine_similarity(&x_axis, &x_axis);
        assert!((same - 1.0).abs() < f32::EPSILON);
        let orthogonal = cosine_similarity(&x_axis, &y_axis);
        assert!((orthogonal - 0.0).abs() < f32::EPSILON);
    }

    // Removing a known id succeeds once; unknown ids report `false`.
    #[test]
    fn test_vector_index_remove() {
        let mut index = index_of(&[("test", [0.1, 0.2, 0.3])]);
        assert_eq!(index.len(), 1);
        assert!(index.remove("test"));
        assert_eq!(index.len(), 0);
        assert!(!index.remove("nonexistent"));
    }

    // Every valid pair in a batch is inserted and counted.
    #[test]
    fn test_vector_index_batch_insert() {
        let mut index = VectorIndex::new(3);
        let batch = [
            ("a", [1.0, 0.0, 0.0]),
            ("b", [0.0, 1.0, 0.0]),
            ("c", [0.0, 0.0, 1.0]),
        ]
        .into_iter()
        .map(|(id, embedding): (&str, [f32; 3])| (id.to_string(), embedding.to_vec()));
        let accepted = index.insert_batch(batch);
        assert_eq!(accepted, 3);
        assert_eq!(index.len(), 3);
    }

    // `get` returns the stored embedding for known ids and `None` otherwise.
    #[test]
    fn test_vector_index_get() {
        let expected: Vec<f32> = vec![0.1, 0.2, 0.3];
        let mut index = VectorIndex::new(3);
        index.insert("test".to_string(), expected.clone()).unwrap();
        let stored = index.get("test");
        assert!(stored.is_some());
        assert_eq!(stored.unwrap(), &expected);
        assert!(index.get("nonexistent").is_none());
    }

    // `clear` empties a populated index.
    #[test]
    fn test_vector_index_clear() {
        let mut index = index_of(&[("a", [1.0, 0.0, 0.0]), ("b", [0.0, 1.0, 0.0])]);
        assert_eq!(index.len(), 2);
        index.clear();
        assert_eq!(index.len(), 0);
        assert!(index.is_empty());
    }

    // An all-zero query still returns the stored entry.
    #[test]
    fn test_search_with_zero_query() {
        let index = index_of(&[("test", [0.1, 0.2, 0.3])]);
        let hits = index.search(&[0.0, 0.0, 0.0], 10);
        assert_eq!(hits.len(), 1);
        assert_eq!(hits[0].0, "test");
    }

    // Searching an empty index yields no hits.
    #[test]
    fn test_search_empty_index() {
        let empty = VectorIndex::new(3);
        let hits = empty.search(&[0.1, 0.2, 0.3], 10);
        assert_eq!(hits.len(), 0);
    }

    // No more than `top_k` hits are returned even when more entries exist.
    #[test]
    fn test_search_respects_top_k() {
        let mut index = VectorIndex::new(3);
        for divisor in 1..=10 {
            let node_id = format!("node{}", divisor - 1);
            index
                .insert(node_id, vec![1.0 / divisor as f32, 0.0, 0.0])
                .unwrap();
        }
        let hits = index.search(&[1.0, 0.0, 0.0], 3);
        assert_eq!(hits.len(), 3);
    }
}