use crate::vector::cosine_similarity;
use crate::{Document, Result, SearchResult};
use std::collections::HashMap;
/// Exhaustive (brute-force) vector index.
///
/// Every search scores the query against every stored document with cosine
/// similarity, so results are exact. Documents are keyed by their `id`; adding a
/// document whose id is already present replaces the previous entry.
#[derive(Debug, Clone)]
pub struct FlatIndex {
    /// Required length of every stored embedding and every query vector.
    embedding_dim: usize,
    /// Stored documents keyed by `Document::id`.
    documents: HashMap<String, Document>,
}

impl FlatIndex {
    /// Creates an empty index that accepts vectors of length `embedding_dim`.
    pub fn new(embedding_dim: usize) -> Self {
        Self {
            embedding_dim,
            documents: HashMap::new(),
        }
    }

    /// Inserts `document`, replacing any stored document with the same id.
    ///
    /// # Errors
    ///
    /// Returns `RagError::DimensionMismatch` if the embedding length differs from
    /// [`embedding_dim`](Self::embedding_dim), and `RagError::IndexError` if the
    /// embedding contains NaN or infinite values (either would make every
    /// similarity score involving this document undefined). On error the index
    /// is left unchanged.
    pub fn add(&mut self, document: Document) -> Result<()> {
        self.validate_vector(&document.embedding, "embedding")?;
        self.documents.insert(document.id.clone(), document);
        Ok(())
    }

    /// Adds each document in order via [`add`](Self::add).
    ///
    /// This is not atomic: documents preceding the first invalid one stay in the
    /// index, while the invalid document and everything after it are skipped.
    ///
    /// # Errors
    ///
    /// Returns the first error produced by [`add`](Self::add).
    pub fn add_batch(&mut self, documents: Vec<Document>) -> Result<()> {
        self.documents.reserve(documents.len());
        documents
            .into_iter()
            .try_for_each(|document| self.add(document))
    }

    /// Returns up to `k` stored documents most similar to `query`, best first.
    ///
    /// Scores are cosine similarities. A document whose similarity cannot be
    /// computed (the similarity function fails, or yields NaN) is scored `0.0`.
    /// Equal scores are ordered by ascending document id, so the output is
    /// deterministic regardless of hash-map iteration order.
    ///
    /// # Errors
    ///
    /// Returns `RagError::DimensionMismatch` if `query` has the wrong length, and
    /// `RagError::IndexError` if it contains NaN or infinite values.
    pub fn search(&self, query: &[f32], k: usize) -> Result<Vec<SearchResult>> {
        self.validate_vector(query, "query")?;
        if k == 0 {
            return Ok(Vec::new());
        }

        let mut scored: Vec<(f32, &Document)> = self
            .documents
            .values()
            .map(|doc| {
                let score = cosine_similarity(query, &doc.embedding).unwrap_or(0.0);
                // `total_cmp` orders NaN by its sign bit, which would place it above
                // or below every real score; treat it like an uncomputable score.
                let score = if score.is_nan() { 0.0 } else { score };
                (score, doc)
            })
            .collect();

        // Select the best `k` in linear time, then sort only those survivors.
        if k < scored.len() {
            scored.select_nth_unstable_by(k - 1, Self::rank_order);
            scored.truncate(k);
        }
        scored.sort_unstable_by(Self::rank_order);

        let results = scored
            .into_iter()
            .map(|(score, doc)| SearchResult {
                id: doc.id.clone(),
                content: doc.content.clone(),
                score,
                metadata: doc.metadata.clone(),
            })
            .collect();
        Ok(results)
    }

    /// Removes and returns the document stored under `id`, or `None` if absent.
    pub fn remove(&mut self, id: &str) -> Option<Document> {
        self.documents.remove(id)
    }

    /// Number of stored documents.
    pub fn len(&self) -> usize {
        self.documents.len()
    }

    /// Returns `true` if no documents are stored.
    pub fn is_empty(&self) -> bool {
        self.documents.is_empty()
    }

    /// Removes every stored document; the configured dimension is kept.
    pub fn clear(&mut self) {
        self.documents.clear();
    }

    /// Returns clones of all stored documents, in unspecified order.
    pub fn get_all_documents(&self) -> Vec<Document> {
        self.documents.values().cloned().collect()
    }

    /// Required length of stored embeddings and query vectors.
    pub fn embedding_dim(&self) -> usize {
        self.embedding_dim
    }

    /// Checks that `values` has the index dimension and only finite components.
    ///
    /// `label` names the vector in error messages (e.g. `"embedding"`, `"query"`).
    fn validate_vector(&self, values: &[f32], label: &str) -> Result<()> {
        if values.len() != self.embedding_dim {
            return Err(crate::RagError::DimensionMismatch {
                expected: self.embedding_dim,
                actual: values.len(),
            });
        }
        // NaN is checked first so NaN embeddings keep their historical message.
        if values.iter().any(|v| v.is_nan()) {
            return Err(crate::RagError::IndexError(format!(
                "{} contains NaN values",
                label
            )));
        }
        if values.iter().any(|v| v.is_infinite()) {
            return Err(crate::RagError::IndexError(format!(
                "{} contains infinite values",
                label
            )));
        }
        Ok(())
    }

    /// Ranking order: descending score, ties broken by ascending document id.
    fn rank_order(a: &(f32, &Document), b: &(f32, &Document)) -> std::cmp::Ordering {
        b.0.total_cmp(&a.0).then_with(|| a.1.id.cmp(&b.1.id))
    }
}
impl crate::index::VectorIndex for FlatIndex {
fn add(&mut self, document: Document) -> Result<()> {
self.add(document)
}
fn search(&self, query: &[f32], k: usize) -> Result<Vec<SearchResult>> {
self.search(query, k)
}
fn len(&self) -> usize {
self.len()
}
fn clear(&mut self) {
self.clear()
}
fn embedding_dim(&self) -> usize {
self.embedding_dim()
}
}
impl crate::index::VectorIndexSnapshot for FlatIndex {
fn get_all_documents(&self) -> Vec<Document> {
self.get_all_documents()
}
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a document with the given id and embedding and no metadata.
    fn create_test_document(id: &str, embedding: Vec<f32>) -> Document {
        Document {
            id: id.to_string(),
            content: format!("Test document {}", id),
            embedding,
            metadata: None,
        }
    }

    #[test]
    fn test_new_index() {
        let index = FlatIndex::new(384);
        assert_eq!(index.embedding_dim, 384);
        assert_eq!(index.embedding_dim(), 384);
        assert_eq!(index.len(), 0);
        assert!(index.is_empty());
    }

    #[test]
    fn test_add_document() {
        let mut index = FlatIndex::new(3);
        let doc = create_test_document("doc1", vec![1.0, 0.0, 0.0]);
        let result = index.add(doc);
        assert!(result.is_ok());
        assert_eq!(index.len(), 1);
        assert!(!index.is_empty());
    }

    #[test]
    fn test_add_document_dimension_mismatch() {
        let mut index = FlatIndex::new(3);
        let doc = create_test_document("doc1", vec![1.0, 0.0]);
        let result = index.add(doc);
        assert!(matches!(
            result,
            Err(crate::RagError::DimensionMismatch {
                expected: 3,
                actual: 2
            })
        ));
        assert_eq!(index.len(), 0);
    }

    #[test]
    fn test_add_duplicate_id_replaces() {
        let mut index = FlatIndex::new(3);
        let doc1 = create_test_document("doc1", vec![1.0, 0.0, 0.0]);
        let mut doc2 = create_test_document("doc1", vec![0.0, 1.0, 0.0]);
        doc2.content = "Updated content".to_string();
        index.add(doc1).unwrap();
        assert_eq!(index.len(), 1);
        index.add(doc2).unwrap();
        assert_eq!(index.len(), 1);
        let query = [0.0f32, 1.0, 0.0];
        let results = index.search(&query, 1).unwrap();
        assert_eq!(results[0].content, "Updated content");
    }

    #[test]
    fn test_add_batch() {
        let mut index = FlatIndex::new(3);
        let docs = vec![
            create_test_document("doc1", vec![1.0, 0.0, 0.0]),
            create_test_document("doc2", vec![0.0, 1.0, 0.0]),
            create_test_document("doc3", vec![0.0, 0.0, 1.0]),
        ];
        let result = index.add_batch(docs);
        assert!(result.is_ok());
        assert_eq!(index.len(), 3);
    }

    #[test]
    fn test_add_batch_partial_failure() {
        let mut index = FlatIndex::new(3);
        let docs = vec![
            create_test_document("doc1", vec![1.0, 0.0, 0.0]),
            // Wrong dimension: stops the batch here.
            create_test_document("doc2", vec![0.0, 1.0]),
            create_test_document("doc3", vec![0.0, 0.0, 1.0]),
        ];
        let result = index.add_batch(docs);
        assert!(result.is_err());
        assert_eq!(index.len(), 1);
        assert_eq!(index.get_all_documents()[0].id, "doc1");
    }

    #[test]
    fn test_search_exact_match() {
        let mut index = FlatIndex::new(3);
        let doc = create_test_document("doc1", vec![1.0, 0.0, 0.0]);
        index.add(doc).unwrap();
        let query = [1.0f32, 0.0, 0.0];
        let results = index.search(&query, 5).unwrap();
        assert_eq!(results.len(), 1);
        assert_eq!(results[0].id, "doc1");
        assert!((results[0].score - 1.0).abs() < 1e-6);
    }

    #[test]
    fn test_search_multiple_results_sorted() {
        let mut index = FlatIndex::new(3);
        index
            .add(create_test_document("exact", vec![1.0, 0.0, 0.0]))
            .unwrap();
        index
            .add(create_test_document("close", vec![0.9, 0.1, 0.0]))
            .unwrap();
        index
            .add(create_test_document("far", vec![0.0, 1.0, 0.0]))
            .unwrap();
        let query = [1.0f32, 0.0, 0.0];
        let results = index.search(&query, 10).unwrap();
        assert_eq!(results.len(), 3);
        assert_eq!(results[0].id, "exact");
        assert_eq!(results[1].id, "close");
        assert_eq!(results[2].id, "far");
        assert!(results[0].score > results[1].score);
        assert!(results[1].score > results[2].score);
    }

    #[test]
    fn test_search_limit_k() {
        let mut index = FlatIndex::new(3);
        for i in 0..5 {
            let embedding = vec![i as f32, 0.0, 0.0];
            index
                .add(create_test_document(&format!("doc{}", i), embedding))
                .unwrap();
        }
        let query = [10.0f32, 0.0, 0.0];
        let results = index.search(&query, 3).unwrap();
        assert_eq!(results.len(), 3);
    }

    #[test]
    fn test_search_k_zero() {
        let mut index = FlatIndex::new(3);
        index
            .add(create_test_document("doc1", vec![1.0, 0.0, 0.0]))
            .unwrap();
        let query = [1.0f32, 0.0, 0.0];
        let results = index.search(&query, 0).unwrap();
        assert!(results.is_empty());
    }

    #[test]
    fn test_search_dimension_mismatch() {
        let mut index = FlatIndex::new(3);
        index
            .add(create_test_document("doc1", vec![1.0, 0.0, 0.0]))
            .unwrap();
        let query = [1.0f32, 0.0];
        let result = index.search(&query, 5);
        assert!(matches!(
            result,
            Err(crate::RagError::DimensionMismatch {
                expected: 3,
                actual: 2
            })
        ));
    }

    #[test]
    fn test_search_empty_index() {
        let index = FlatIndex::new(3);
        let query = [1.0f32, 0.0, 0.0];
        let results = index.search(&query, 5).unwrap();
        assert_eq!(results.len(), 0);
    }

    #[test]
    fn test_remove_existing() {
        let mut index = FlatIndex::new(3);
        let doc = create_test_document("doc1", vec![1.0, 0.0, 0.0]);
        index.add(doc).unwrap();
        let removed = index.remove("doc1");
        assert!(removed.is_some());
        assert_eq!(removed.unwrap().id, "doc1");
        assert_eq!(index.len(), 0);
        assert!(index.is_empty());
    }

    #[test]
    fn test_remove_non_existing() {
        let mut index = FlatIndex::new(3);
        let removed = index.remove("nonexistent");
        assert!(removed.is_none());
    }

    #[test]
    fn test_remove_excludes_from_search() {
        let mut index = FlatIndex::new(3);
        index
            .add(create_test_document("keep", vec![0.0, 1.0, 0.0]))
            .unwrap();
        index
            .add(create_test_document("drop", vec![1.0, 0.0, 0.0]))
            .unwrap();
        index.remove("drop");
        let query = [1.0f32, 0.0, 0.0];
        let results = index.search(&query, 5).unwrap();
        assert_eq!(results.len(), 1);
        assert_eq!(results[0].id, "keep");
    }

    #[test]
    fn test_clear() {
        let mut index = FlatIndex::new(3);
        for i in 0..5 {
            let doc = create_test_document(&format!("doc{}", i), vec![i as f32, 0.0, 0.0]);
            index.add(doc).unwrap();
        }
        assert_eq!(index.len(), 5);
        index.clear();
        assert_eq!(index.len(), 0);
        assert!(index.is_empty());
    }

    #[test]
    fn test_get_all_documents() {
        let mut index = FlatIndex::new(3);
        index
            .add(create_test_document("doc2", vec![0.0, 1.0, 0.0]))
            .unwrap();
        index
            .add(create_test_document("doc1", vec![1.0, 0.0, 0.0]))
            .unwrap();
        // Order is unspecified, so compare sorted ids.
        let mut ids: Vec<String> = index
            .get_all_documents()
            .into_iter()
            .map(|doc| doc.id)
            .collect();
        ids.sort();
        assert_eq!(ids, vec!["doc1".to_string(), "doc2".to_string()]);
    }

    #[test]
    fn test_orthogonal_vectors() {
        let mut index = FlatIndex::new(3);
        index
            .add(create_test_document("x_axis", vec![1.0, 0.0, 0.0]))
            .unwrap();
        index
            .add(create_test_document("y_axis", vec![0.0, 1.0, 0.0]))
            .unwrap();
        index
            .add(create_test_document("z_axis", vec![0.0, 0.0, 1.0]))
            .unwrap();
        let query = [1.0f32, 0.0, 0.0];
        let results = index.search(&query, 3).unwrap();
        assert_eq!(results.len(), 3);
        assert_eq!(results[0].id, "x_axis");
        assert!((results[0].score - 1.0).abs() < 1e-6);
        assert!(results[1].score.abs() < 1e-6);
        assert!(results[2].score.abs() < 1e-6);
    }

    #[test]
    fn test_with_metadata() {
        let mut index = FlatIndex::new(3);
        let mut doc = create_test_document("doc1", vec![1.0, 0.0, 0.0]);
        doc.metadata = Some(serde_json::json!({
            "source": "test",
            "timestamp": 123456789
        }));
        index.add(doc).unwrap();
        let query = [1.0f32, 0.0, 0.0];
        let results = index.search(&query, 1).unwrap();
        assert_eq!(results.len(), 1);
        assert!(results[0].metadata.is_some());
        assert_eq!(results[0].metadata.as_ref().unwrap()["source"], "test");
    }

    #[test]
    fn test_negative_similarity() {
        let mut index = FlatIndex::new(3);
        index
            .add(create_test_document("positive", vec![1.0, 0.0, 0.0]))
            .unwrap();
        index
            .add(create_test_document("negative", vec![-1.0, 0.0, 0.0]))
            .unwrap();
        let query = [1.0f32, 0.0, 0.0];
        let results = index.search(&query, 2).unwrap();
        assert_eq!(results.len(), 2);
        assert_eq!(results[0].id, "positive");
        assert!((results[0].score - 1.0).abs() < 1e-6);
        assert_eq!(results[1].id, "negative");
        assert!((results[1].score + 1.0).abs() < 1e-6);
    }

    #[test]
    fn test_large_batch() {
        let mut index = FlatIndex::new(128);
        let docs: Vec<Document> = (0..1000)
            .map(|i| {
                let mut embedding = vec![0.0; 128];
                embedding[0] = i as f32;
                create_test_document(&format!("doc{}", i), embedding)
            })
            .collect();
        let result = index.add_batch(docs);
        assert!(result.is_ok());
        assert_eq!(index.len(), 1000);
        let mut query = [0.0f32; 128];
        query[0] = 500.0;
        let results = index.search(&query, 10).unwrap();
        assert_eq!(results.len(), 10);
    }

    #[test]
    fn test_add_nan_embedding_rejected() {
        let mut index = FlatIndex::new(3);
        let doc = create_test_document("nan_doc", vec![1.0, f32::NAN, 0.0]);
        let result = index.add(doc);
        assert!(matches!(result, Err(crate::RagError::IndexError(_))));
        assert_eq!(index.len(), 0);
    }
}