#[cfg(test)]
mod tests {
use langchainrust::vector_stores::{
Document, VectorStore, InMemoryVectorStore,
VectorStoreProvider, VectorStoreType, VectorStoreBuilder,
};
use langchainrust::retrieval::{SimilarityRetriever, RetrieverTrait};
use langchainrust::embeddings::{MockEmbeddings, cosine_similarity};
use std::sync::Arc;
use std::collections::HashMap;
/// The `Document` builder methods populate content, id, and metadata.
#[test]
fn test_document_creation() {
    let built = Document::new("Test content")
        .with_id("doc-1")
        .with_metadata("author", "test");

    // Expected values are spelled out up front so each assertion reads as
    // "actual vs. expected".
    let expected_id = Some(String::from("doc-1"));
    let expected_author = String::from("test");

    assert_eq!(built.content, "Test content");
    assert_eq!(built.id, expected_id);
    assert_eq!(built.metadata.get("author"), Some(&expected_author));
}
/// Round-trips a `Document` through JSON and checks that content, id, and
/// the metadata entry set on the original are all present after decoding.
#[test]
fn test_document_serialization() {
    let doc = Document::new("Serialization test")
        .with_id("serde-doc")
        .with_metadata("key", "value");

    let json = serde_json::to_string(&doc).expect("Document should serialize to JSON");
    let decoded: Document =
        serde_json::from_str(&json).expect("serialized Document should deserialize");

    assert_eq!(decoded.content, doc.content);
    assert_eq!(decoded.id, doc.id);
    // The document was built with a metadata entry; without this check a
    // round trip that silently dropped metadata would still pass.
    assert_eq!(decoded.metadata.get("key"), Some(&"value".to_string()));
}
/// Adds two documents with orthogonal embeddings and checks that a query
/// vector close to the first embedding puts the "Rust" document first.
#[tokio::test]
async fn test_add_and_search() {
    let store = InMemoryVectorStore::new();

    let rust_doc = Document::new("Rust programming");
    let python_doc = Document::new("Python scripting");
    let rust_vec = vec![1.0, 0.0, 0.0];
    let python_vec = vec![0.0, 1.0, 0.0];

    let ids = store
        .add_documents(vec![rust_doc, python_doc], vec![rust_vec, python_vec])
        .await
        .unwrap();
    assert_eq!(ids.len(), 2);
    assert_eq!(store.count().await, 2);

    // The query leans heavily towards the first axis, i.e. the "Rust" embedding.
    let query = vec![0.9, 0.1, 0.0];
    let hits = store.similarity_search(&query, 2).await.unwrap();
    assert_eq!(hits.len(), 2);
    assert!(hits[0].document.content.contains("Rust"));
}
/// Stores one document under an explicit id, then checks that it can be
/// fetched and that it is no longer returned after deletion.
#[tokio::test]
async fn test_get_delete_document() {
    const DOC_ID: &str = "test-id";

    let store = InMemoryVectorStore::new();
    store
        .add_documents(
            vec![Document::new("Test doc").with_id(DOC_ID)],
            vec![vec![1.0, 0.0]],
        )
        .await
        .unwrap();

    let before_delete = store.get_document(DOC_ID).await.unwrap();
    assert!(before_delete.is_some());

    store.delete_document(DOC_ID).await.unwrap();

    let after_delete = store.get_document(DOC_ID).await.unwrap();
    assert!(after_delete.is_none());
}
/// Passing two documents with only one embedding must yield an error.
#[tokio::test]
async fn test_count_mismatch_error() {
    let store = InMemoryVectorStore::new();
    let two_docs = vec![Document::new("A"), Document::new("B")];
    let one_embedding = vec![vec![1.0, 0.0]];

    let outcome = store.add_documents(two_docs, one_embedding).await;
    assert!(outcome.is_err());
}
/// Inserts five documents one at a time, then checks that `clear` brings
/// the reported count back to zero.
#[tokio::test]
async fn test_clear_and_count() {
    let store = InMemoryVectorStore::new();

    for n in 0..5 {
        let content = format!("Doc {}", n);
        let embedding = vec![n as f32, 0.0];
        store
            .add_documents(vec![Document::new(content)], vec![embedding])
            .await
            .unwrap();
    }
    assert_eq!(store.count().await, 5);

    store.clear().await.unwrap();
    assert_eq!(store.count().await, 0);
}
/// Checks `cosine_similarity` against identical, orthogonal, opposite, and
/// all-zero vectors.
#[test]
fn test_cosine_similarity() {
    let unit_x = vec![1.0, 0.0, 0.0];

    // Identical direction: expected similarity 1.
    let same = vec![1.0, 0.0, 0.0];
    let sim_same = cosine_similarity(&unit_x, &same);
    assert!((sim_same - 1.0).abs() < 0.001);

    // Orthogonal direction: expected similarity 0.
    let unit_y = vec![0.0, 1.0, 0.0];
    let sim_orthogonal = cosine_similarity(&unit_x, &unit_y);
    assert!(sim_orthogonal.abs() < 0.001);

    // Opposite direction: expected similarity -1.
    let neg_x = vec![-1.0, 0.0, 0.0];
    let sim_opposite = cosine_similarity(&unit_x, &neg_x);
    assert!((sim_opposite + 1.0).abs() < 0.001);

    // Zero vector: the result is compared exactly against 0.0, not with a tolerance.
    let zero = vec![0.0, 0.0, 0.0];
    assert_eq!(cosine_similarity(&unit_x, &zero), 0.0);
}
/// Adds two documents through a `SimilarityRetriever` built on an in-memory
/// store and mock embeddings, then checks that `retrieve("programming", 2)`
/// yields two results.
#[tokio::test]
async fn test_retriever() {
    let store = Arc::new(InMemoryVectorStore::new());
    let embedder = Arc::new(MockEmbeddings::new(64));
    let retriever = SimilarityRetriever::new(store.clone(), embedder);

    let corpus = vec![
        Document::new("Rust tutorial").with_metadata("type", "lang"),
        Document::new("Qdrant database").with_metadata("type", "db"),
    ];
    retriever.add_documents(corpus).await.unwrap();

    let hits = retriever.retrieve("programming", 2).await.unwrap();
    assert_eq!(hits.len(), 2);
}
/// A store created via `VectorStoreProvider::create` with the `InMemory`
/// type reports a count of zero.
#[tokio::test]
async fn test_provider_in_memory() {
    let created = VectorStoreProvider::create(VectorStoreType::InMemory).await;
    let store = created.unwrap();

    let initial_count = store.count().await;
    assert_eq!(initial_count, 0);
}
/// A store built via `VectorStoreBuilder::in_memory()` reports a count of zero.
#[tokio::test]
async fn test_builder() {
    let store = VectorStoreBuilder::in_memory()
        .build()
        .await
        .unwrap();

    let initial_count = store.count().await;
    assert_eq!(initial_count, 0);
}
/// The built store can be held as `Arc<dyn VectorStore>` and used through
/// the trait object: adding one document makes the count one.
#[tokio::test]
async fn test_provider_trait_object() {
    let store: Arc<dyn VectorStore> = VectorStoreBuilder::in_memory()
        .build()
        .await
        .unwrap();

    let docs = vec![Document::new("Test")];
    let embeddings = vec![vec![1.0, 0.0]];
    store.add_documents(docs, embeddings).await.unwrap();

    assert_eq!(store.count().await, 1);
}
}