use langchainrust::{
BaseChain, ChainError,
ConversationRetrievalChain, StuffDocumentsChain,
RefineDocumentsChain, MapReduceDocumentsChain, MapRerankDocumentsChain,
ConversationBufferMemory, Document,
};
use langchainrust::language_models::{OpenAIChat, OpenAIConfig};
use std::collections::HashMap;
use std::sync::Arc;
/// Builds the `OpenAIConfig` shared by the tests in this file.
///
/// The configuration uses the placeholder key `sk-test`, the
/// `gpt-3.5-turbo` model, has streaming switched off, and leaves every
/// optional setting as `None`.
fn create_test_config() -> OpenAIConfig {
    // Mandatory connection settings.
    let api_key = String::from("sk-test");
    let base_url = String::from("https://api.openai.com/v1");
    let model = String::from("gpt-3.5-turbo");

    OpenAIConfig {
        api_key,
        base_url,
        model,
        streaming: false,
        // Every optional knob is left unset.
        organization: None,
        frequency_penalty: None,
        max_tokens: None,
        presence_penalty: None,
        temperature: None,
        top_p: None,
        tools: None,
        tool_choice: None,
    }
}
/// Creates an `OpenAIChat` model from the shared test configuration.
fn create_test_llm() -> OpenAIChat {
    let config = create_test_config();
    OpenAIChat::new(config)
}
/// A freshly constructed `ConversationRetrievalChain` is expected to report
/// the `query` input key, the `result` output key and the name
/// `conversation_retrieval`.
#[test]
fn test_conversation_retrieval_new() {
    let model = create_test_llm();

    // Retriever backed by an in-memory store and 64-dimensional mock embeddings.
    let vector_store = Arc::new(langchainrust::InMemoryVectorStore::new());
    let embedder = Arc::new(langchainrust::MockEmbeddings::new(64));
    let retriever = Arc::new(langchainrust::SimilarityRetriever::new(
        vector_store,
        embedder,
    ));

    let history = ConversationBufferMemory::new();
    let chain = ConversationRetrievalChain::new(model, retriever, history);

    let expected_inputs = vec!["query"];
    let expected_outputs = vec!["result"];
    assert_eq!(chain.input_keys(), expected_inputs);
    assert_eq!(chain.output_keys(), expected_outputs);
    assert_eq!(chain.name(), "conversation_retrieval");
}
/// After applying every builder option, the chain is expected to expose the
/// custom `question` input key and the output keys `answer` followed by
/// `source_documents`.
#[test]
fn test_conversation_retrieval_with_options() {
    let model = create_test_llm();

    let vector_store = Arc::new(langchainrust::InMemoryVectorStore::new());
    let embedder = Arc::new(langchainrust::MockEmbeddings::new(64));
    let retriever = Arc::new(langchainrust::SimilarityRetriever::new(
        vector_store,
        embedder,
    ));

    let history = ConversationBufferMemory::new();
    let base_chain = ConversationRetrievalChain::new(model, retriever, history);

    // Builder calls are applied in the same order as before.
    let chain = base_chain
        .with_system_prompt("你是一个 Rust 专家")
        .with_k(5)
        .with_input_key("question")
        .with_output_key("answer")
        .with_return_source_documents(true)
        .with_verbose(true);

    let expected_inputs = vec!["question"];
    let expected_outputs = vec!["answer", "source_documents"];
    assert_eq!(chain.input_keys(), expected_inputs);
    assert_eq!(chain.output_keys(), expected_outputs);
}
/// Invoking the chain with an empty input map must fail with
/// `ChainError::MissingInput`.
#[tokio::test]
async fn test_conversation_retrieval_missing_input() {
    let llm = create_test_llm();
    let store = Arc::new(langchainrust::InMemoryVectorStore::new());
    let embeddings = Arc::new(langchainrust::MockEmbeddings::new(64));
    let retriever = Arc::new(langchainrust::SimilarityRetriever::new(store, embeddings));
    let memory = ConversationBufferMemory::new();
    let chain = ConversationRetrievalChain::new(llm, retriever, memory);

    // No keys at all: the chain's input key is absent from the map.
    let inputs = HashMap::new();
    let result = chain.invoke(inputs).await;

    // A single check covers both "is an error" and "is the right variant".
    assert!(
        matches!(result, Err(ChainError::MissingInput(_))),
        "应当返回 MissingInput 错误"
    );
}
/// A default `StuffDocumentsChain` is expected to take `input` and
/// `documents`, produce `output`, and be named `stuff_documents`.
#[test]
fn test_stuff_documents_new() {
    let chain = StuffDocumentsChain::new(create_test_llm());

    let expected_inputs = vec!["input", "documents"];
    let expected_outputs = vec!["output"];
    assert_eq!(chain.input_keys(), expected_inputs);
    assert_eq!(chain.output_keys(), expected_outputs);
    assert_eq!(chain.name(), "stuff_documents");
}
/// Custom input/output keys set through the builder are expected to show up
/// in `input_keys()` and `output_keys()`; `documents` stays as second input.
#[test]
fn test_stuff_documents_with_options() {
    let base_chain = StuffDocumentsChain::new(create_test_llm());
    let chain = base_chain
        .with_input_key("question")
        .with_output_key("answer")
        .with_max_doc_length(500)
        .with_verbose(true);

    let expected_inputs = vec!["question", "documents"];
    let expected_outputs = vec!["answer"];
    assert_eq!(chain.input_keys(), expected_inputs);
    assert_eq!(chain.output_keys(), expected_outputs);
}
/// Formatting two documents is expected to yield text that contains a
/// numbered header (`文档 1:`, `文档 2:`) and the content of each document.
#[test]
fn test_stuff_documents_format_documents() {
    let chain = StuffDocumentsChain::new(create_test_llm());

    let first = Document::new("文档一的内容");
    let second = Document::new("文档二的内容");
    let docs = vec![first, second];

    let rendered = chain.format_documents(&docs);

    // First document: header and body.
    assert!(rendered.contains("文档 1:"));
    assert!(rendered.contains("文档一的内容"));
    // Second document: header and body.
    assert!(rendered.contains("文档 2:"));
    assert!(rendered.contains("文档二的内容"));
}
/// With `with_max_doc_length(10)`, formatting a longer document is expected
/// to produce output containing the `[文档已截断]` marker and shorter than 100.
#[test]
fn test_stuff_documents_truncation() {
    let base_chain = StuffDocumentsChain::new(create_test_llm());
    let chain = base_chain.with_max_doc_length(10);

    let long_doc = Document::new("这是一段超过十个字符的文档内容");
    let docs = vec![long_doc];

    let rendered = chain.format_documents(&docs);

    assert!(rendered.contains("[文档已截断]"));
    // NOTE(review): if `rendered` is a `String`, `len()` is a UTF-8 byte
    // count, so this bound is in bytes rather than characters — confirm.
    assert!(rendered.len() < 100);
}
/// For a chain built without a custom template, `build_prompt` must insert
/// both arguments into the prompt and leave no `{context}` or `{input}`
/// placeholder behind.
#[test]
fn test_stuff_documents_build_prompt() {
    let llm = create_test_llm();
    let chain = StuffDocumentsChain::new(llm);

    let prompt = chain.build_prompt("这是上下文", "这是问题");

    assert!(prompt.contains("这是上下文"));
    assert!(prompt.contains("这是问题"));
    // A surviving placeholder would mean the template was not fully rendered.
    assert!(!prompt.contains("{context}"));
    assert!(!prompt.contains("{input}"));
}
/// A custom template using `{context}` and `{input}` is expected to be
/// rendered with the values passed to `build_prompt`.
#[test]
fn test_stuff_documents_custom_template() {
    let base_chain = StuffDocumentsChain::new(create_test_llm());
    let chain = base_chain
        .with_prompt_template("背景:{context}\n问题:{input}")
        .with_document_variable("context");

    let rendered = chain.build_prompt("测试背景", "测试问题");

    let has_context = rendered.contains("背景:测试背景");
    let has_question = rendered.contains("问题:测试问题");
    assert!(has_context);
    assert!(has_question);
}
/// Formatting an empty document list is expected to give an empty result.
#[test]
fn test_stuff_documents_empty_docs() {
    let chain = StuffDocumentsChain::new(create_test_llm());

    let no_docs: Vec<Document> = Vec::new();
    let rendered = chain.format_documents(&no_docs);

    assert!(rendered.is_empty());
}
/// A default `RefineDocumentsChain` is expected to take `input` and
/// `documents`, produce `output`, and be named `refine_documents`.
#[test]
fn test_refine_documents_new() {
    let chain = RefineDocumentsChain::new(create_test_llm());

    let expected_inputs = vec!["input", "documents"];
    let expected_outputs = vec!["output"];
    assert_eq!(chain.input_keys(), expected_inputs);
    assert_eq!(chain.output_keys(), expected_outputs);
    assert_eq!(chain.name(), "refine_documents");
}
/// Custom initial and refine templates are expected to be rendered with the
/// document text, the question and (for refine) the existing answer.
#[test]
fn test_refine_documents_build_prompts() {
    let base_chain = RefineDocumentsChain::new(create_test_llm());
    let chain = base_chain
        .with_initial_prompt("初始:{context} - {input}")
        .with_refine_prompt("优化:{context} - {input} - 已有:{existing_answer}")
        .with_document_variable("context");

    // Initial prompt: document + question.
    let initial_prompt = chain.build_initial_prompt("文档内容", "我的问题");
    assert!(initial_prompt.contains("初始:文档内容 - 我的问题"));

    // Refine prompt: new document + question + previously produced answer.
    let refine_prompt = chain.build_refine_prompt("新文档", "问题", "已有答案");
    assert!(refine_prompt.contains("优化:新文档 - 问题 - 已有:已有答案"));
}
/// `invoke_with_documents` on an empty document list is expected to return
/// an error.
#[tokio::test]
async fn test_refine_documents_empty_docs() {
    let chain = RefineDocumentsChain::new(create_test_llm());

    let outcome = chain.invoke_with_documents(Vec::new(), "测试").await;

    assert!(outcome.is_err());
}
/// A default `MapReduceDocumentsChain` is expected to take `input` and
/// `documents`, produce `output`, and be named `map_reduce_documents`.
#[test]
fn test_map_reduce_new() {
    let chain = MapReduceDocumentsChain::new(create_test_llm());

    let expected_inputs = vec!["input", "documents"];
    let expected_outputs = vec!["output"];
    assert_eq!(chain.input_keys(), expected_inputs);
    assert_eq!(chain.output_keys(), expected_outputs);
    assert_eq!(chain.name(), "map_reduce_documents");
}
/// Custom map and reduce templates are expected to be rendered with the
/// supplied document text, partial answers and question.
#[test]
fn test_map_reduce_build_prompts() {
    let base_chain = MapReduceDocumentsChain::new(create_test_llm());
    let chain = base_chain
        .with_map_prompt("处理:{context} - {input}")
        .with_reduce_prompt("合并:{summaries}\n问题:{input}")
        .with_document_variable("context");

    // Map step: one document plus the question.
    let rendered_map = chain.build_map_prompt("文档内容", "问题");
    assert!(rendered_map.contains("处理:文档内容 - 问题"));

    // Reduce step: both partial answers and the original question must appear.
    let rendered_reduce =
        chain.build_reduce_prompt(&["答案1".into(), "答案2".into()], "原始问题");
    assert!(rendered_reduce.contains("合并:"));
    assert!(rendered_reduce.contains("答案1"));
    assert!(rendered_reduce.contains("答案2"));
    assert!(rendered_reduce.contains("原始问题"));
}
/// `invoke_with_documents` on an empty document list is expected to return
/// an error.
#[tokio::test]
async fn test_map_reduce_empty_docs() {
    let chain = MapReduceDocumentsChain::new(create_test_llm());

    let outcome = chain.invoke_with_documents(Vec::new(), "测试").await;

    assert!(outcome.is_err());
}
/// Every chain type in this file is expected to report at least one input
/// key and at least one output key.
#[test]
fn test_all_chains_implement_base_chain() {
    // Document-combining chains only need a model.
    let stuff_chain = StuffDocumentsChain::new(create_test_llm());
    assert!(!stuff_chain.input_keys().is_empty());
    assert!(!stuff_chain.output_keys().is_empty());

    let refine_chain = RefineDocumentsChain::new(create_test_llm());
    assert!(!refine_chain.input_keys().is_empty());
    assert!(!refine_chain.output_keys().is_empty());

    let map_reduce_chain = MapReduceDocumentsChain::new(create_test_llm());
    assert!(!map_reduce_chain.input_keys().is_empty());
    assert!(!map_reduce_chain.output_keys().is_empty());

    let map_rerank_chain = MapRerankDocumentsChain::new(create_test_llm());
    assert!(!map_rerank_chain.input_keys().is_empty());
    assert!(!map_rerank_chain.output_keys().is_empty());

    // The conversational chain additionally needs a retriever and a memory.
    let vector_store = Arc::new(langchainrust::InMemoryVectorStore::new());
    let embedder = Arc::new(langchainrust::MockEmbeddings::new(64));
    let retriever = Arc::new(langchainrust::SimilarityRetriever::new(
        vector_store,
        embedder,
    ));
    let history = ConversationBufferMemory::new();
    let conversation_chain =
        ConversationRetrievalChain::new(create_test_llm(), retriever, history);
    assert!(!conversation_chain.input_keys().is_empty());
    assert!(!conversation_chain.output_keys().is_empty());
}
/// A default `MapRerankDocumentsChain` is expected to take `input` and
/// `documents`, produce `output`, and be named `map_rerank_documents`.
#[test]
fn test_map_rerank_new() {
    let chain = MapRerankDocumentsChain::new(create_test_llm());

    let expected_inputs = vec!["input", "documents"];
    let expected_outputs = vec!["output"];
    assert_eq!(chain.input_keys(), expected_inputs);
    assert_eq!(chain.output_keys(), expected_outputs);
    assert_eq!(chain.name(), "map_rerank_documents");
}
/// Custom input/output keys set through the builder are expected to show up
/// in `input_keys()` and `output_keys()`; `documents` stays as second input.
#[test]
fn test_map_rerank_with_options() {
    let base_chain = MapRerankDocumentsChain::new(create_test_llm());
    let chain = base_chain
        .with_top_k(3)
        .with_input_key("question")
        .with_output_key("ranked_results")
        .with_verbose(true);

    let expected_inputs = vec!["question", "documents"];
    let expected_outputs = vec!["ranked_results"];
    assert_eq!(chain.input_keys(), expected_inputs);
    assert_eq!(chain.output_keys(), expected_outputs);
}
/// `invoke_with_documents` on an empty document list is expected to return
/// an error.
#[tokio::test]
async fn test_map_rerank_empty_docs() {
    let chain = MapRerankDocumentsChain::new(create_test_llm());

    let outcome = chain.invoke_with_documents(Vec::new(), "测试").await;

    assert!(outcome.is_err());
}
/// A custom map template using `{context}` and `{input}` is expected to be
/// rendered with the values passed to `build_map_prompt`.
#[test]
fn test_map_rerank_build_prompt() {
    let base_chain = MapRerankDocumentsChain::new(create_test_llm());
    let chain = base_chain
        .with_map_prompt("评分:{context}\n问题:{input}")
        .with_document_variable("context");

    let rendered = chain.build_map_prompt("文档内容", "我的问题");

    let has_document = rendered.contains("评分:文档内容");
    let has_question = rendered.contains("问题:我的问题");
    assert!(has_document);
    assert!(has_question);
}
/// Pins the expected behaviour of `MapRerankDocumentsChain::extract_score`
/// for four kinds of model output:
///
/// * a Chinese-labelled score and answer → score `85`, answer mentions `Rust`;
/// * an English `Score:` / `Answer:` pair → score `92`;
/// * text without any score → `50`;
/// * a score of `150` → `100`.
#[test]
fn test_map_rerank_extract_score() {
    let (score, answer) =
        MapRerankDocumentsChain::extract_score("相关性评分:85\n答案:Rust 是一门系统编程语言");
    assert_eq!(score, 85);
    assert!(answer.contains("Rust"));

    // Only the score is checked for the remaining inputs, so the answer half
    // of each tuple is discarded.
    let (score2, _) =
        MapRerankDocumentsChain::extract_score("Score: 92\nAnswer: It's a programming language");
    assert_eq!(score2, 92);

    let (score3, _) = MapRerankDocumentsChain::extract_score("这是一段普通文本");
    assert_eq!(score3, 50);

    let (score4, _) = MapRerankDocumentsChain::extract_score("相关性评分:150");
    assert_eq!(score4, 100);
}