use crate::{Vector, VectorIndex, SearchResult, SimilarityMetric};
use serde::{Serialize, Deserialize};
/// Brute-force ("flat") vector index: a search scores every stored vector.
///
/// Vectors are kept in insertion order in a plain `Vec`. The struct derives
/// `Serialize`/`Deserialize`, so the whole index can be persisted and restored.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct FlatIndex {
/// Length every stored vector is expected to have; `add` rejects vectors of any other length.
pub dim: usize,
/// Stored vectors, in insertion order.
///
/// NOTE(review): this field is public, and neither `FlatIndex::new` nor deserialization
/// validates it, so it may contain wrong-length vectors or duplicate ids; only `add`
/// performs those checks.
pub data: Vec<Vector>,
}
impl FlatIndex {
    /// Builds an index that takes ownership of `data` and records `dim` as its dimension.
    ///
    /// The vectors are stored exactly as given. Neither their lengths nor the
    /// uniqueness of their ids is checked here; only `add` performs those checks.
    pub fn new(dim: usize, data: Vec<Vector>) -> Self {
        Self { data, dim }
    }

    /// Returns the largest id among the stored vectors, or `None` if the index is empty.
    pub fn max_id(&self) -> Option<u64> {
        self.data
            .iter()
            .max_by_key(|entry| entry.id)
            .map(|entry| entry.id)
    }
}
impl VectorIndex for FlatIndex {
    /// Appends `vector` to the end of the index.
    ///
    /// # Errors
    /// Returns `Err` if `vector.values` does not have exactly `self.dim` elements,
    /// or if a vector with the same `id` is already stored. The duplicate check is a
    /// linear scan over all stored vectors.
    fn add(&mut self, vector: Vector) -> Result<(), String> {
        if vector.values.len() != self.dim {
            return Err("Vector dimension mismatch".to_string());
        }
        if self.data.iter().any(|e| e.id == vector.id) {
            return Err(format!("Vector ID {} already exists", vector.id));
        }
        self.data.push(vector);
        Ok(())
    }

    /// Removes every stored vector whose id equals `id`.
    ///
    /// Never fails: deleting an id that is not present leaves the index unchanged
    /// and still returns `Ok(())`.
    fn delete(&mut self, id: u64) -> Result<(), String> {
        self.data.retain(|e| e.id != id);
        Ok(())
    }

    /// Returns up to `k` stored vectors ranked by `similarity_metric` against
    /// `query`, highest score first.
    ///
    /// Every stored vector is scored (linear scan). The sort is stable, so equal
    /// scores keep insertion order. If the metric yields NaN for some entries, those
    /// entries are ranked after all numeric scores instead of causing a panic.
    ///
    /// NOTE(review): `query.len()` is not checked against `self.dim`; behaviour on a
    /// mismatch depends on `SimilarityMetric::calculate` — confirm.
    fn search(&self, query: &[f64], k: usize, similarity_metric: SimilarityMetric) -> Vec<SearchResult> {
        // Score against a borrow of each entry so `text`/`metadata` are cloned only
        // for the `k` survivors rather than for the whole index.
        let mut scored: Vec<_> = self
            .data
            .iter()
            .map(|entry| (similarity_metric.calculate(&entry.values, query), entry))
            .collect();

        // Descending by score, NaN last. `partial_cmp().unwrap()` would panic on NaN,
        // and `unwrap_or(Equal)` alone would not be a total order (which `sort_by`
        // requires). Handling NaN explicitly keeps the comparator total; for two
        // non-NaN floats `partial_cmp` always returns `Some`.
        scored.sort_by(|(a, _), (b, _)| match (a.is_nan(), b.is_nan()) {
            (false, false) => b.partial_cmp(a).unwrap_or(std::cmp::Ordering::Equal),
            (true, false) => std::cmp::Ordering::Greater,
            (false, true) => std::cmp::Ordering::Less,
            (true, true) => std::cmp::Ordering::Equal,
        });
        scored.truncate(k);

        scored
            .into_iter()
            .map(|(score, entry)| SearchResult {
                id: entry.id,
                score,
                text: entry.text.clone(),
                metadata: entry.metadata.clone(),
            })
            .collect()
    }

    /// Number of stored vectors.
    fn len(&self) -> usize {
        self.data.len()
    }

    /// `true` when no vectors are stored.
    fn is_empty(&self) -> bool {
        self.data.is_empty()
    }

    /// Returns the first stored vector whose id equals `id` (linear scan), or `None`.
    fn get_vector(&self, id: u64) -> Option<&Vector> {
        self.data.iter().find(|e| e.id == id)
    }

    /// The dimension this index was created with.
    fn dimension(&self) -> usize {
        self.dim
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::SimilarityMetric;
    use serde_json;

    /// Builds a stored vector with placeholder text and no metadata.
    fn entry(id: u64, values: &[f64]) -> Vector {
        Vector {
            id,
            values: values.to_vec(),
            text: "test".to_string(),
            metadata: None,
        }
    }

    /// Unit vectors along the x, y and z axes, with ids 1, 2 and 3.
    fn unit_axes() -> Vec<Vector> {
        vec![
            entry(1, &[1.0, 0.0, 0.0]),
            entry(2, &[0.0, 1.0, 0.0]),
            entry(3, &[0.0, 0.0, 1.0]),
        ]
    }

    /// The origin (id 1) and two points on the same ray at distances 5 (id 2) and 10 (id 3).
    fn points_on_ray() -> Vec<Vector> {
        vec![
            entry(1, &[0.0, 0.0]),
            entry(2, &[3.0, 4.0]),
            entry(3, &[6.0, 8.0]),
        ]
    }

    #[test]
    fn test_serialization_deserialization() {
        let original = FlatIndex::new(3, unit_axes());
        let json = serde_json::to_string(&original).expect("Serialization should work");
        let restored: FlatIndex =
            serde_json::from_str(&json).expect("Deserialization should work");

        // Structure survives the round trip.
        assert_eq!(restored.len(), 3);
        assert_eq!(restored.dimension(), 3);
        assert!(!restored.is_empty());
        for id in 1..=3 {
            assert!(restored.get_vector(id).is_some());
        }

        // Searching the restored index still ranks correctly.
        let query = vec![1.1, 0.1, 0.1];
        let hits = restored.search(&query, 2, SimilarityMetric::Cosine);
        assert_eq!(hits.len(), 2);
        for pair in hits.windows(2) {
            assert!(pair[0].score >= pair[1].score);
        }
        assert_eq!(hits[0].id, 1);
        assert!(hits[0].score > 0.99);
    }

    #[test]
    fn test_flat_index_with_cosine_similarity() {
        let index = FlatIndex::new(3, unit_axes());
        let query = vec![1.0, 0.0, 0.0];
        let hits = index.search(&query, 2, SimilarityMetric::Cosine);
        assert_eq!(hits.len(), 2);
        assert_eq!(hits[0].id, 1);
        assert!((hits[0].score - 1.0).abs() < 1e-10);
    }

    #[test]
    fn test_flat_index_with_euclidean_similarity() {
        let index = FlatIndex::new(2, points_on_ray());
        let query = vec![0.0, 0.0];
        let hits = index.search(&query, 2, SimilarityMetric::Euclidean);
        assert_eq!(hits.len(), 2);
        assert_eq!(hits[0].id, 1);
        assert!((hits[0].score - 1.0).abs() < 1e-10);
    }

    #[test]
    fn test_flat_index_with_manhattan_similarity() {
        let index = FlatIndex::new(2, points_on_ray());
        let query = vec![0.0, 0.0];
        let hits = index.search(&query, 2, SimilarityMetric::Manhattan);
        assert_eq!(hits.len(), 2);
        assert_eq!(hits[0].id, 1);
        assert!((hits[0].score - 1.0).abs() < 1e-10);
    }

    #[test]
    fn test_flat_index_with_dot_product() {
        let index = FlatIndex::new(
            2,
            vec![
                entry(1, &[1.0, 2.0]),
                entry(2, &[2.0, 1.0]),
                entry(3, &[0.0, 0.0]),
            ],
        );
        let query = vec![1.0, 2.0];
        let hits = index.search(&query, 2, SimilarityMetric::DotProduct);
        assert_eq!(hits.len(), 2);
        assert_eq!(hits[0].id, 1);
        assert!((hits[0].score - 5.0).abs() < 1e-10);
    }

    #[test]
    fn test_flat_index_change_similarity_metric() {
        let index = FlatIndex::new(2, vec![entry(1, &[1.0, 2.0]), entry(2, &[2.0, 1.0])]);
        let query = vec![1.0, 2.0];

        let by_cosine = index.search(&query, 1, SimilarityMetric::Cosine);
        assert_eq!(by_cosine[0].id, 1);

        let by_dot = index.search(&query, 1, SimilarityMetric::DotProduct);
        assert_eq!(by_dot[0].id, 1);

        // Same winner, but the metric choice must be reflected in the score.
        assert_ne!(by_cosine[0].score, by_dot[0].score);
    }
}