use rag::{
vector_store::{Document, InMemoryVectorStore, MetadataFilter, MinimalVectorDB, VectorStore},
DistanceMetric, Index, FlatIndex,
};
/// Checks that `search_batch` yields one hit list per query, each capped at
/// the requested size and ordered best-first.
#[tokio::test]
async fn test_batch_search_ranking_consistency() {
    let store = InMemoryVectorStore::new();

    // "Rust"/"Go" lean on the first axis, "JavaScript"/"TypeScript" on the
    // third, and "Python" sits alone on the second.
    let corpus = vec![
        Document::new(String::from("Rust")).with_embedding(vec![1.0, 0.0, 0.0, 0.0]),
        Document::new(String::from("Go")).with_embedding(vec![0.9, 0.1, 0.0, 0.0]),
        Document::new(String::from("Python")).with_embedding(vec![0.0, 1.0, 0.0, 0.0]),
        Document::new(String::from("JavaScript")).with_embedding(vec![0.0, 0.0, 1.0, 0.0]),
        Document::new(String::from("TypeScript")).with_embedding(vec![0.0, 0.1, 0.9, 0.0]),
    ];
    store.add_batch(corpus).await.unwrap();

    let probes = vec![
        vec![1.0, 0.0, 0.0, 0.0],
        vec![0.0, 0.0, 1.0, 0.0],
        vec![0.5, 0.5, 0.0, 0.0],
    ];
    let ranked = store.search_batch(&probes, 2).await.unwrap();
    assert_eq!(ranked.len(), 3);

    // First probe points along the first axis.
    let first_axis_hits = &ranked[0];
    assert_eq!(first_axis_hits.len(), 2);
    assert_eq!(first_axis_hits[0].document.content, "Rust");
    assert_eq!(first_axis_hits[1].document.content, "Go");
    assert!(first_axis_hits[0].score >= first_axis_hits[1].score);

    // Second probe points along the third axis.
    let third_axis_hits = &ranked[1];
    assert_eq!(third_axis_hits.len(), 2);
    assert_eq!(third_axis_hits[0].document.content, "JavaScript");
    assert_eq!(third_axis_hits[1].document.content, "TypeScript");

    // Third probe sits between the first two axes; only the ordering of the
    // scores is checked, not which documents are returned.
    let mixed_hits = &ranked[2];
    assert_eq!(mixed_hits.len(), 2);
    assert!(mixed_hits[0].score >= mixed_hits[1].score);
}
/// Verifies nearest-neighbour ordering of `search_batch` on a store built
/// with `DistanceMetric::Euclidean`.
///
/// "Point A" and "Point B" sit at or next to the origin while "Point C" is
/// far away, so a query at the origin must return A then B. For the query at
/// `(5, 5, 5)`, B is 4.9 away per axis versus 5.0 for both A and C, so B is
/// the closest point.
#[tokio::test]
async fn test_batch_search_euclidean_spatial() {
    let store = InMemoryVectorStore::with_metric(DistanceMetric::Euclidean);
    let docs = vec![
        Document::new("Point A".to_string()).with_embedding(vec![0.0, 0.0, 0.0]),
        Document::new("Point B".to_string()).with_embedding(vec![0.1, 0.1, 0.1]),
        Document::new("Point C".to_string()).with_embedding(vec![10.0, 10.0, 10.0]),
    ];
    store.add_batch(docs).await.unwrap();

    let queries = vec![vec![0.0, 0.0, 0.0], vec![5.0, 5.0, 5.0]];
    let results = store.search_batch(&queries, 2).await.unwrap();

    // Pin the result shape first, so a missing or short hit list fails with a
    // clear assertion message instead of an index-out-of-bounds panic below.
    assert_eq!(results.len(), 2);
    assert_eq!(results[0].len(), 2);
    assert_eq!(results[1].len(), 2);

    // Query at the origin: A is an exact match, B is the next closest.
    assert_eq!(results[0][0].document.content, "Point A");
    assert_eq!(results[0][1].document.content, "Point B");

    // Query at (5, 5, 5): B is marginally closer than A and C.
    assert_eq!(results[1][0].document.content, "Point B");
}
/// Runs the same four documents and the same query through stores built with
/// the cosine, dot-product and Euclidean metrics and checks the ranking each
/// one produces.
#[tokio::test]
async fn test_distance_metrics_produce_different_rankings() {
    let corpus = vec![
        Document::new(String::from("A")).with_embedding(vec![1.0, 0.0, 0.0]),
        Document::new(String::from("B")).with_embedding(vec![0.0, 1.0, 0.0]),
        Document::new(String::from("C")).with_embedding(vec![1.0, 1.0, 0.0]),
        Document::new(String::from("D")).with_embedding(vec![0.9, 0.1, 0.0]),
    ];
    let probe = vec![1.0, 0.0, 0.0];

    // Cosine: A is parallel to the probe, D is the next closest in direction.
    let by_cosine = {
        let store = InMemoryVectorStore::with_metric(DistanceMetric::Cosine);
        store.add_batch(corpus.clone()).await.unwrap();
        store.search(&probe, 4).await.unwrap()
    };
    assert_eq!(by_cosine[0].document.content, "A");
    assert_eq!(by_cosine[1].document.content, "D");

    // Dot product: the best score must exceed 0.99, and neither A nor C may
    // be ranked last.
    // NOTE(review): these checks are loose; they do not pin the exact
    // dot-product order or compare it against the other metrics.
    let by_dot = {
        let store = InMemoryVectorStore::with_metric(DistanceMetric::DotProduct);
        store.add_batch(corpus.clone()).await.unwrap();
        store.search(&probe, 4).await.unwrap()
    };
    assert!(by_dot[0].score > 0.99);
    assert_ne!(by_dot[3].document.content, "A");
    assert_ne!(by_dot[3].document.content, "C");

    // Euclidean: A coincides with the probe, then D, then C.
    let by_euclidean = {
        let store = InMemoryVectorStore::with_metric(DistanceMetric::Euclidean);
        store.add_batch(corpus.clone()).await.unwrap();
        store.search(&probe, 4).await.unwrap()
    };
    assert_eq!(by_euclidean[0].document.content, "A");
    assert_eq!(by_euclidean[1].document.content, "D");
    assert_eq!(by_euclidean[2].document.content, "C");
}
/// Verifies that scoring on a `DistanceMetric::Cosine` store depends on
/// direction only, not on vector length.
///
/// "Small" and "Large" point the same way as the query but differ in length
/// by a factor of 100; "Orthogonal" is perpendicular to it.
#[tokio::test]
async fn test_cosine_ignores_magnitude() {
    let store = InMemoryVectorStore::with_metric(DistanceMetric::Cosine);
    let docs = vec![
        Document::new("Small".to_string()).with_embedding(vec![1.0, 0.0]),
        Document::new("Large".to_string()).with_embedding(vec![100.0, 0.0]),
        Document::new("Orthogonal".to_string()).with_embedding(vec![0.0, 1.0]),
    ];
    store.add_batch(docs).await.unwrap();

    let results = store.search(&[1.0, 0.0], 3).await.unwrap();

    // Guard the indexing below: a short result should fail as an assertion,
    // not as an index-out-of-bounds panic.
    assert_eq!(results.len(), 3);

    // Both parallel vectors score as near-perfect matches whatever their length.
    assert!(results[0].score > 0.99);
    assert!(results[1].score > 0.99);

    // The perpendicular vector scores near zero and therefore ranks last.
    assert!(results[2].score < 0.01);
    assert_eq!(results[2].document.content, "Orthogonal");
}
/// Walks a cosine in-memory store through a full lifecycle: bulk insert,
/// single and batch search, listing, deleting one document and clearing.
#[tokio::test]
async fn test_pure_memory_rag_document_lifecycle() {
    let store = InMemoryVectorStore::with_metric(DistanceMetric::Cosine);

    // (content, embedding, topic) rows for the seeded corpus.
    let rows = [
        ("Rust doc", vec![0.95, 0.10, 0.05, 0.00], "programming"),
        ("Python doc", vec![0.10, 0.90, 0.20, 0.05], "programming"),
        ("ML doc", vec![0.15, 0.20, 0.85, 0.10], "ai"),
        ("Deep learning doc", vec![0.20, 0.25, 0.80, 0.15], "ai"),
        ("Docker doc", vec![0.30, 0.40, 0.10, 0.85], "devops"),
    ];
    let mut corpus = Vec::new();
    for (content, embedding, topic) in rows {
        corpus.push(
            Document::new(String::from(content))
                .with_embedding(embedding)
                .with_metadata(String::from("topic"), String::from(topic)),
        );
    }
    store.add_batch(corpus).await.unwrap();
    assert_eq!(store.count().await.unwrap(), 5);

    // Single-query searches.
    let rust_hits = store.search(&[0.9, 0.1, 0.0, 0.0], 2).await.unwrap();
    assert_eq!(rust_hits[0].document.content, "Rust doc");
    let ai_hits = store.search(&[0.1, 0.1, 0.9, 0.1], 2).await.unwrap();
    assert_eq!(ai_hits[0].document.content, "ML doc");

    // Batch search: check the top hit of every query.
    let probes = vec![
        vec![0.9, 0.1, 0.0, 0.0],
        vec![0.1, 0.1, 0.9, 0.1],
        vec![0.3, 0.3, 0.1, 0.8],
    ];
    let batch_hits = store.search_batch(&probes, 2).await.unwrap();
    assert_eq!(batch_hits.len(), 3);
    for (query_idx, expected_top) in ["Rust doc", "ML doc", "Docker doc"].iter().enumerate() {
        assert_eq!(batch_hits[query_idx][0].document.content, *expected_top);
    }

    // Listing with a limit of 10 and offset 0 returns all five documents.
    let listed = store.list(10, 0).await.unwrap();
    assert_eq!(listed.len(), 5);

    // Deleting one document, then clearing, is reflected in the count.
    store.delete(&listed[0].id).await.unwrap();
    assert_eq!(store.count().await.unwrap(), 4);
    store.clear().await.unwrap();
    assert_eq!(store.count().await.unwrap(), 0);
}
/// Checks that `search_with_filter` only returns documents whose metadata
/// matches the filter, while an unfiltered search returns every document.
#[tokio::test]
async fn test_metadata_filter_with_embedding_search() {
    let store = InMemoryVectorStore::new();

    // Two documents tagged "rust" and one tagged "python".
    let corpus = vec![
        Document::new(String::from("Rust programming"))
            .with_embedding(vec![1.0, 0.0, 0.0])
            .with_metadata(String::from("lang"), String::from("rust")),
        Document::new(String::from("Rust web framework"))
            .with_embedding(vec![0.9, 0.1, 0.0])
            .with_metadata(String::from("lang"), String::from("rust")),
        Document::new(String::from("Python data science"))
            .with_embedding(vec![0.0, 1.0, 0.0])
            .with_metadata(String::from("lang"), String::from("python")),
    ];
    store.add_batch(corpus).await.unwrap();

    // Baseline: without a filter all three documents come back.
    let unfiltered = store.search(&[1.0, 0.0, 0.0], 3).await.unwrap();
    assert_eq!(unfiltered.len(), 3);

    // With the filter only the two "rust" documents remain.
    let rust_only = MetadataFilter::new().add(String::from("lang"), String::from("rust"));
    let filtered = store
        .search_with_filter(&[1.0, 0.0, 0.0], 3, &rust_only)
        .await
        .unwrap();
    assert_eq!(filtered.len(), 2);

    let expected_lang = String::from("rust");
    for hit in filtered.iter() {
        assert!(hit.document.metadata.get("lang") == Some(&expected_lang));
    }
}
/// Feeds identical documents to `InMemoryVectorStore` and `MinimalVectorDB`
/// and checks that search results, counts and `clear` agree between the two.
#[tokio::test]
async fn test_minimal_db_equivalent_behavior() {
    let reference_store = InMemoryVectorStore::new();
    let minimal_store = MinimalVectorDB::new();

    let corpus = vec![
        Document::new(String::from("Doc 1")).with_embedding(vec![1.0, 0.0]),
        Document::new(String::from("Doc 2")).with_embedding(vec![0.0, 1.0]),
    ];
    reference_store.add_batch(corpus.clone()).await.unwrap();
    minimal_store.add_batch(corpus.clone()).await.unwrap();

    // Same query against both backends must give the same ranking.
    let reference_hits = reference_store.search(&[1.0, 0.0], 2).await.unwrap();
    let minimal_hits = minimal_store.search(&[1.0, 0.0], 2).await.unwrap();
    assert_eq!(reference_hits.len(), minimal_hits.len());
    for rank in 0..2 {
        assert_eq!(
            reference_hits[rank].document.content,
            minimal_hits[rank].document.content
        );
    }

    // Document counts agree before clearing.
    let reference_count = reference_store.count().await.unwrap();
    let minimal_count = minimal_store.count().await.unwrap();
    assert_eq!(reference_count, minimal_count);

    // Both stores are empty after clearing.
    reference_store.clear().await.unwrap();
    minimal_store.clear().await.unwrap();
    assert_eq!(reference_store.count().await.unwrap(), 0);
    assert_eq!(minimal_store.count().await.unwrap(), 0);
}
/// Checks that `FlatIndex::metric` reports `Cosine` for an index built with
/// `new` and echoes back the metric passed to `with_metric`.
#[test]
fn test_flat_index_metric_accessor() {
    let default_index = FlatIndex::new();
    let euclidean_index = FlatIndex::with_metric(DistanceMetric::Euclidean);

    assert_eq!(default_index.metric(), DistanceMetric::Cosine);
    assert_eq!(euclidean_index.metric(), DistanceMetric::Euclidean);
}
/// Checks that `FlatIndex::search_batch` returns one hit list of the requested
/// size per query and that each query's best hit is the document whose
/// embedding equals that query.
#[test]
fn test_flat_index_batch_search_parallel() {
    let index = FlatIndex::new();
    index.add(Document::new(String::from("A")).with_embedding(vec![1.0, 0.0]));
    index.add(Document::new(String::from("B")).with_embedding(vec![0.0, 1.0]));
    index.add(Document::new(String::from("C")).with_embedding(vec![0.5, 0.5]));

    // Each probe equals one stored embedding, in insertion order.
    let probes = vec![vec![1.0, 0.0], vec![0.0, 1.0], vec![0.5, 0.5]];
    let ranked = index.search_batch(&probes, 2);

    assert_eq!(ranked.len(), 3);
    for hits in &ranked {
        assert_eq!(hits.len(), 2);
    }

    for (query_idx, expected_top) in ["A", "B", "C"].iter().enumerate() {
        assert_eq!(ranked[query_idx][0].document.content, *expected_top);
    }
}