#[path = "../common/mod.rs"]
mod common;
use common::TestConfig;
use langchainrust::{
Document, InMemoryVectorStore, RecursiveCharacterSplitter,
SimilarityRetriever, RetrieverTrait, TextSplitter,
};
use langchainrust::schema::Message;
use langchainrust::BaseChatModel;
use std::sync::Arc;
/// Retrieval-only RAG check.
///
/// Splits a short passage about Rust into chunks, indexes them through a
/// `SimilarityRetriever` backed by an `InMemoryVectorStore`, then asks for the
/// chunks most relevant to a question about Rust's safety.
///
/// Ignored by default; the ignore reason (Chinese) reads "API Key must be
/// configured".
#[tokio::test]
#[ignore = "需要配置 API Key"]
async fn test_rag_retrieval() {
    let config = TestConfig::get();
    let embeddings = Arc::new(config.embeddings());
    let store = Arc::new(InMemoryVectorStore::new());
    // `store` and `embeddings` are not used again, so move them instead of cloning.
    let retriever = SimilarityRetriever::new(store, embeddings);

    let doc = Document::new(
        "Rust is a systems programming language focused on safety, speed, and concurrency. \
         It prevents common programming errors through its ownership system. \
         Rust achieves memory safety without garbage collection.",
    );

    // Small chunk size (50) with overlap (10) so the passage is split into
    // several chunks; units presumably characters — confirm against the splitter.
    let splitter = RecursiveCharacterSplitter::new(50, 10);
    let chunks: Vec<Document> = splitter
        .split_text(&doc.page_content())
        .into_iter()
        .map(Document::new)
        .collect();
    // Guard so an empty split fails here rather than as a confusing
    // "no documents retrieved" failure later.
    assert!(!chunks.is_empty(), "splitter produced no chunks");

    retriever
        .add_documents(chunks)
        .await
        .expect("failed to index chunks");

    let relevant_docs = retriever
        .retrieve("What makes Rust safe?", 2)
        .await
        .expect("retrieval request failed");

    println!("Retrieved {} documents", relevant_docs.len());
    for (i, doc) in relevant_docs.iter().enumerate() {
        println!("Doc {}: {}", i, doc.page_content());
    }

    assert!(!relevant_docs.is_empty());
    // NOTE(review): assumes the second argument of `retrieve` is a top-k cap.
    assert!(
        relevant_docs.len() <= 2,
        "retriever returned {} documents, expected at most 2",
        relevant_docs.len()
    );
}
/// Full RAG round-trip: retrieve supporting facts, then have the chat model
/// answer from that context.
///
/// Indexes three one-sentence facts about Rust, retrieves the two most relevant
/// to the question, and sends them to the chat model as context. The test
/// checks that the retrieved context contains the release year and that the
/// answer mentions "2015".
///
/// Ignored by default; the ignore reason (Chinese) reads "API Key must be
/// configured".
#[tokio::test]
#[ignore = "需要配置 API Key"]
async fn test_rag_with_llm_generation() {
    // A single question drives both retrieval and generation so the context
    // fetched is the context for the question actually asked.
    let question = "When was Rust 1.0 released?";

    let config = TestConfig::get();
    let embeddings = Arc::new(config.embeddings());
    let llm = config.openai_chat();
    let store = Arc::new(InMemoryVectorStore::new());
    // `store` and `embeddings` are not used again, so move them instead of cloning.
    let retriever = SimilarityRetriever::new(store, embeddings);

    let docs = vec![
        Document::new("Rust was created by Mozilla."),
        Document::new("Rust 1.0 was released in 2015."),
        Document::new("Rust focuses on memory safety."),
    ];
    retriever
        .add_documents(docs)
        .await
        .expect("failed to index documents");

    let relevant_docs = retriever
        .retrieve(question, 2)
        .await
        .expect("retrieval request failed");
    let context = relevant_docs
        .iter()
        .map(|d| d.page_content())
        .collect::<Vec<_>>()
        .join("\n");
    // Check retrieval on its own: a model that already knows the date could
    // otherwise pass the final assertion even if retrieval missed the fact.
    assert!(
        context.contains("2015"),
        "retrieved context lacks the release fact: {}",
        context
    );

    let messages = vec![
        Message::system("Answer based on the context provided."),
        Message::human(&format!("Context:\n{}\n\nQuestion: {}", context, question)),
    ];

    let response = llm
        .chat(messages, None)
        .await
        .expect("chat request failed");
    println!("Answer: {}", response.content);
    assert!(
        response.content.contains("2015"),
        "answer did not mention 2015: {}",
        response.content
    );
}