use std::collections::HashMap;
use crate::memory::math::cosine_similarity;
/// In-memory, exact (brute-force) similarity index over embedding vectors.
///
/// Each entry maps a caller-chosen `i64` id to one embedding. A query is scored against
/// every stored embedding with `cosine_similarity`, so the scoring cost of a search grows
/// linearly with the number of entries.
#[derive(Debug, Default, Clone)]
pub struct VectorIndex {
    /// Embeddings keyed by id; inserting an existing id replaces its embedding.
    entries: HashMap<i64, Vec<f32>>,
}

impl VectorIndex {
    /// Creates an empty index.
    #[must_use]
    pub fn new() -> Self {
        Self::default()
    }

    /// Returns the number of stored embeddings.
    #[must_use]
    pub fn len(&self) -> usize {
        self.entries.len()
    }

    /// Returns `true` when the index holds no embeddings.
    #[must_use]
    pub fn is_empty(&self) -> bool {
        self.entries.is_empty()
    }

    /// Returns `true` if an embedding is stored under `id`.
    #[must_use]
    pub fn contains(&self, id: i64) -> bool {
        self.entries.contains_key(&id)
    }

    /// Stores `embedding` under `id`, replacing any embedding already stored for that id.
    ///
    /// No dimensionality check is done here; how a length mismatch with the query is
    /// scored is left to `cosine_similarity` (NOTE(review): confirm whether mismatched
    /// dimensions should be rejected at insert time).
    pub fn insert(&mut self, id: i64, embedding: Vec<f32>) {
        self.entries.insert(id, embedding);
    }

    /// Removes the embedding stored under `id`; removing an absent id is a no-op.
    pub fn remove(&mut self, id: i64) {
        self.entries.remove(&id);
    }

    /// Returns up to `k` `(id, score)` pairs ranked by cosine similarity to `query`.
    ///
    /// Results are ordered by descending score. Equal scores are ordered by ascending id,
    /// so the output is deterministic regardless of `HashMap` iteration order. A NaN score
    /// (e.g. if the similarity function yields NaN for a zero-norm vector) ranks after
    /// every real score rather than being treated as equal to everything.
    ///
    /// Returns an empty vector when `k == 0`, `query` is empty, or the index is empty.
    #[must_use]
    pub fn search(&self, query: &[f32], k: usize) -> Vec<(i64, f32)> {
        if k == 0 || query.is_empty() || self.entries.is_empty() {
            return Vec::new();
        }

        let mut scored: Vec<(i64, f32)> = self
            .entries
            .iter()
            .map(|(&id, embedding)| (id, cosine_similarity(query, embedding)))
            .collect();

        // Move the best `k` entries to the front in O(n), then sort only those, instead
        // of sorting every entry and discarding the tail.
        if k < scored.len() {
            scored.select_nth_unstable_by(k, Self::rank);
            scored.truncate(k);
        }
        // `rank` never returns `Equal` for two distinct entries (ids are unique map keys),
        // so an unstable sort produces the same order a stable sort would.
        scored.sort_unstable_by(Self::rank);
        scored
    }

    /// Total order used to rank search hits.
    ///
    /// Orders by higher score first, then NaN scores last, then lower id first. NaN is
    /// handled explicitly because `partial_cmp` alone is not a total order over `f32`.
    fn rank(a: &(i64, f32), b: &(i64, f32)) -> std::cmp::Ordering {
        use std::cmp::Ordering;

        let by_score = match (a.1.is_nan(), b.1.is_nan()) {
            // Both real: `partial_cmp` cannot fail here; it keeps -0.0 and 0.0 tied.
            (false, false) => b.1.partial_cmp(&a.1).unwrap_or(Ordering::Equal),
            (true, true) => Ordering::Equal,
            // A NaN score sorts after any real score.
            (true, false) => Ordering::Greater,
            (false, true) => Ordering::Less,
        };
        by_score.then_with(|| a.0.cmp(&b.0))
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Collects just the ids from a search result, preserving rank order.
    fn ids_of(results: &[(i64, f32)]) -> Vec<i64> {
        results.iter().map(|&(id, _)| id).collect()
    }

    #[test]
    fn empty_index_search_returns_empty() {
        let index = VectorIndex::new();
        let results = index.search(&[1.0, 0.0, 0.0], 5);
        assert!(results.is_empty());
    }

    #[test]
    fn finds_inserted_vector_with_high_self_similarity() {
        let mut index = VectorIndex::new();
        index.insert(42, vec![1.0, 0.0, 0.0]);
        let results = index.search(&[1.0, 0.0, 0.0], 5);
        assert_eq!(results.len(), 1);
        assert_eq!(results[0].0, 42);
        assert!((results[0].1 - 1.0).abs() < 1e-6, "self-sim should be ~1.0");
    }

    #[test]
    fn orders_results_by_descending_similarity() {
        let mut index = VectorIndex::new();
        index.insert(1, vec![1.0, 0.0, 0.0]);
        index.insert(2, vec![0.0, 1.0, 0.0]);
        index.insert(3, vec![0.7, 0.7, 0.0]);
        let results = index.search(&[1.0, 0.0, 0.0], 5);
        assert_eq!(ids_of(&results), vec![1, 3, 2]);
        for pair in results.windows(2) {
            assert!(pair[0].1 >= pair[1].1);
        }
    }

    #[test]
    fn respects_k_limit() {
        let mut index = VectorIndex::new();
        for i in 0..10 {
            index.insert(i, vec![i as f32, 1.0, 0.0]);
        }
        let results = index.search(&[1.0, 1.0, 0.0], 3);
        assert_eq!(results.len(), 3);
        // Cosine similarity to [1, 1, 0] peaks at i == 1 and decreases for larger i
        // (i == 0 is the least similar), so the best three are ids 1, 2 and 3.
        assert_eq!(ids_of(&results), vec![1, 2, 3]);
    }

    #[test]
    fn equal_scores_are_ordered_by_ascending_id() {
        let mut index = VectorIndex::new();
        // Identical embeddings score identically, so only the id tie-break decides order.
        for id in [5, 3, 9, 1] {
            index.insert(id, vec![1.0, 0.0]);
        }
        let results = index.search(&[1.0, 0.0], 3);
        assert_eq!(ids_of(&results), vec![1, 3, 5]);
    }

    #[test]
    fn k_larger_than_count_returns_all() {
        let mut index = VectorIndex::new();
        index.insert(1, vec![1.0, 0.0]);
        index.insert(2, vec![0.0, 1.0]);
        let results = index.search(&[1.0, 1.0], 100);
        assert_eq!(results.len(), 2);
    }

    #[test]
    fn k_zero_returns_empty() {
        let mut index = VectorIndex::new();
        index.insert(1, vec![1.0, 0.0]);
        assert!(index.search(&[1.0, 0.0], 0).is_empty());
    }

    #[test]
    fn insert_with_existing_id_upserts() {
        let mut index = VectorIndex::new();
        index.insert(1, vec![1.0, 0.0, 0.0]);
        index.insert(1, vec![0.0, 1.0, 0.0]);
        assert_eq!(index.len(), 1);
        let results = index.search(&[0.0, 1.0, 0.0], 5);
        assert_eq!(results[0].0, 1);
        assert!((results[0].1 - 1.0).abs() < 1e-6);
    }

    #[test]
    fn remove_deletes_entry() {
        let mut index = VectorIndex::new();
        index.insert(1, vec![1.0, 0.0]);
        index.insert(2, vec![0.0, 1.0]);
        assert!(index.contains(1));
        index.remove(1);
        assert!(!index.contains(1));
        assert_eq!(index.len(), 1);
        assert_eq!(ids_of(&index.search(&[1.0, 1.0], 5)), vec![2]);
    }

    #[test]
    fn remove_missing_id_is_noop() {
        let mut index = VectorIndex::new();
        index.insert(1, vec![1.0, 0.0]);
        index.remove(999);
        assert_eq!(index.len(), 1);
    }

    #[test]
    fn empty_query_returns_empty() {
        let mut index = VectorIndex::new();
        index.insert(1, vec![1.0, 0.0]);
        assert!(index.search(&[], 5).is_empty());
    }
}