use langchainrust::{BM25Retriever, Document};
use std::fs;
/// Loads a fixture file and turns every non-blank line into one [`Document`].
///
/// Each line is trimmed before being wrapped, so leading/trailing whitespace
/// (including a trailing `\r`) never reaches the retriever.
///
/// # Panics
///
/// Panics if the file cannot be read; the message names the offending path
/// and includes the underlying I/O error.
fn load_documents_from_file(path: &str) -> Vec<Document> {
    // `unwrap_or_else` only formats the message on the error path, whereas
    // `expect(&format!(..))` allocated it on every successful call too.
    let content = fs::read_to_string(path)
        .unwrap_or_else(|err| panic!("Failed to load file: {}: {}", path, err));
    content
        .lines()
        .map(str::trim)
        .filter(|line| !line.is_empty())
        .map(Document::new)
        .collect()
}
/// English corpus: the query "systems programming language" should place a
/// document mentioning both "Rust" and "systems" within the top 3 hits.
#[test]
fn test_bm25_english_programming_languages() {
    let documents = load_documents_from_file("tests/data/programming_languages_en.txt");
    let mut retriever = BM25Retriever::new();
    retriever.add_documents_sync(documents);
    // `!is_empty()` rather than `len() > 0` (clippy::len_zero).
    assert!(!retriever.is_empty(), "应加载至少 1 个文档");

    let results = retriever.search("systems programming language", 3);
    assert!(!results.is_empty(), "应返回匹配结果");

    let rust_found = results.iter().any(|r| {
        r.document.content.contains("Rust") && r.document.content.contains("systems")
    });
    assert!(rust_found, "应找到 Rust 相关文档");
}
/// English corpus: an exact two-word phrase query should return, within the
/// top 2 hits, a document containing that literal phrase.
#[test]
fn test_bm25_exact_keyword_match() {
    let documents = load_documents_from_file("tests/data/programming_languages_en.txt");
    let mut retriever = BM25Retriever::new();
    retriever.add_documents_sync(documents);

    let results = retriever.search("garbage collection", 2);
    // The predicate only checks for the phrase, so the flag is named after
    // that. `any` replaces `find(..).is_some()` (clippy::search_is_some).
    let gc_doc_found = results
        .iter()
        .any(|r| r.document.content.contains("garbage collection"));
    assert!(gc_doc_found, "应找到包含 garbage collection 的文档");
}
/// Chinese corpus: the query "系统级编程" (systems-level programming) should
/// return at least one hit in the top 3 that mentions "Rust".
#[test]
fn test_bm25_chinese_programming_languages() {
    let documents = load_documents_from_file("tests/data/programming_languages_zh.txt");
    let mut retriever = BM25Retriever::new();
    retriever.add_documents_sync(documents);
    // `!is_empty()` rather than `len() > 0` (clippy::len_zero).
    assert!(!retriever.is_empty(), "应加载中文文档");

    let results = retriever.search("系统级编程", 3);
    assert!(!results.is_empty(), "应返回中文匹配结果");

    let rust_found = results.iter().any(|r| r.document.content.contains("Rust"));
    assert!(rust_found, "应找到 Rust 中文文档");
}
/// Chinese corpus: keyword queries should surface matching documents.
/// "垃圾回收" (garbage collection) must hit a document containing that term,
/// and "微服务" (microservices) must hit a document mentioning "Go".
#[test]
fn test_bm25_chinese_keywords() {
    let mut retriever = BM25Retriever::new();
    retriever.add_documents_sync(load_documents_from_file(
        "tests/data/programming_languages_zh.txt",
    ));

    let gc_hits = retriever.search("垃圾回收", 2);
    let has_gc_doc = gc_hits
        .iter()
        .any(|hit| hit.document.content.contains("垃圾回收"));
    assert!(has_gc_doc, "应找到垃圾回收相关文档");

    let microservice_hits = retriever.search("微服务", 2);
    let has_go_doc = microservice_hits
        .iter()
        .any(|hit| hit.document.content.contains("Go"));
    assert!(has_go_doc, "微服务应关联到 Go 文档");
}
/// Short Chinese documents: "机器学习" (machine learning) should hit a document
/// mentioning "Python", and "前端开发" (front-end development) one mentioning
/// "JavaScript", each within the top 2 results.
#[test]
fn test_bm25_chinese_short_documents() {
    let docs = load_documents_from_file("tests/data/programming_short_zh.txt");
    let mut retriever = BM25Retriever::new();
    retriever.add_documents_sync(docs);

    let ml_hits = retriever.search("机器学习", 2);
    let mentions_python = ml_hits
        .iter()
        .any(|hit| hit.document.content.contains("Python"));
    assert!(mentions_python, "机器学习应关联到 Python");

    let frontend_hits = retriever.search("前端开发", 2);
    let mentions_js = frontend_hits
        .iter()
        .any(|hit| hit.document.content.contains("JavaScript"));
    assert!(mentions_js, "前端开发应关联到 JavaScript");
}
/// Framework docs: queries about LangGraph and the memory system should each
/// return a hit that names the corresponding component.
#[test]
fn test_bm25_langchainrust_docs() {
    let documents = load_documents_from_file("tests/data/langchainrust_docs.txt");
    let mut retriever = BM25Retriever::new();
    retriever.add_documents_sync(documents);

    // (query, top_k, substring expected in at least one hit, failure message)
    let cases = [
        ("LangGraph workflow", 5, "LangGraph", "应找到 LangGraph 文档"),
        ("memory system", 3, "Memory", "应找到 Memory 文档"),
    ];
    for (query, top_k, needle, message) in cases {
        let hits = retriever.search(query, top_k);
        let found = hits.iter().any(|hit| hit.document.content.contains(needle));
        assert!(found, "{}", message);
    }
}
/// Framework docs: feature queries ("Human in the loop", "Qdrant vector store")
/// should each return a hit that mentions the feature's key term.
#[test]
fn test_bm25_framework_features() {
    let documents = load_documents_from_file("tests/data/langchainrust_docs.txt");
    let mut retriever = BM25Retriever::new();
    retriever.add_documents_sync(documents);

    // (query, top_k, substring expected in at least one hit, failure message)
    let cases = [
        ("Human in the loop", 3, "Human", "应找到 Human-in-the-loop 文档"),
        ("Qdrant vector store", 2, "Qdrant", "应找到 Qdrant 文档"),
    ];
    for (query, top_k, needle, message) in cases {
        let hits = retriever.search(query, top_k);
        let found = hits.iter().any(|hit| hit.document.content.contains(needle));
        assert!(found, "{}", message);
    }
}
/// BM25 reference docs: queries about IDF and the `k1` parameter should each
/// return a hit containing the corresponding term.
#[test]
fn test_bm25_algorithm_docs() {
    let mut retriever = BM25Retriever::new();
    retriever.add_documents_sync(load_documents_from_file("tests/data/bm25_docs.txt"));

    let idf_hits = retriever.search("IDF inverse document frequency", 3);
    let idf_covered = idf_hits.iter().any(|hit| hit.document.content.contains("IDF"));
    assert!(idf_covered, "应找到 IDF 文档");

    let k1_hits = retriever.search("k1 parameter", 2);
    let k1_covered = k1_hits.iter().any(|hit| hit.document.content.contains("k1"));
    assert!(k1_covered, "应找到 k1 参数文档");
}
/// BM25 reference docs: queries about term-frequency saturation and length
/// normalization should each return a hit containing the key word.
#[test]
fn test_bm25_algorithm_principles() {
    let documents = load_documents_from_file("tests/data/bm25_docs.txt");
    let mut retriever = BM25Retriever::new();
    retriever.add_documents_sync(documents);

    // (query, top_k, substring expected in at least one hit, failure message)
    let cases = [
        ("term frequency saturation", 3, "saturation", "应找到词频饱和文档"),
        ("document length normalization", 2, "length", "应找到文档长度归一化文档"),
    ];
    for (query, top_k, needle, message) in cases {
        let found = retriever
            .search(query, top_k)
            .iter()
            .any(|hit| hit.document.content.contains(needle));
        assert!(found, "{}", message);
    }
}
/// Mixed corpus: documents from two fixture files share one index. The
/// combined collection must contain more than 20 documents, and a "Rust"
/// query must return at least one hit.
#[test]
fn test_bm25_multi_file_collection() {
    // Gather every fixture's documents in one pass instead of repeated `extend`.
    let all_documents: Vec<Document> = [
        "tests/data/programming_languages_en.txt",
        "tests/data/langchainrust_docs.txt",
    ]
    .iter()
    .flat_map(|path| load_documents_from_file(path))
    .collect();

    let mut retriever = BM25Retriever::new();
    retriever.add_documents_sync(all_documents);
    let total_docs = retriever.len();
    assert!(total_docs > 20, "应加载超过 20 个文档");

    let results = retriever.search("Rust", 5);
    // `!is_empty()` rather than `len() > 0` (clippy::len_zero).
    assert!(!results.is_empty(), "应返回 Rust 相关结果");
}
/// Mixed short/long Chinese corpus: a "Python" query must return hits, and
/// those hits must be ordered by non-increasing score.
///
/// NOTE(review): despite its name, this test only checks result ordering. It
/// does not compare how short and long documents are scored against each other.
#[test]
fn test_bm25_document_length_effect() {
    let short_docs = load_documents_from_file("tests/data/programming_short_zh.txt");
    let long_docs = load_documents_from_file("tests/data/programming_languages_zh.txt");
    let mut retriever = BM25Retriever::new();
    retriever.add_documents_sync(short_docs);
    retriever.add_documents_sync(long_docs);

    let results = retriever.search("Python", 5);
    // `!is_empty()` rather than `len() > 0` (clippy::len_zero).
    assert!(!results.is_empty(), "应返回 Python 相关文档");

    // Compare every adjacent pair. `windows(2)` replaces the manual
    // `0..len - 1` index loop and its `results[i + 1]` lookups.
    for pair in results.windows(2) {
        assert!(pair[0].score >= pair[1].score, "结果应按评分降序排列");
    }
}
/// Adding an empty document set leaves the index empty, and searching an
/// empty index returns no results.
///
/// NOTE(review): no file is actually read here. The "empty file" is simulated
/// with an empty `Vec`.
#[test]
fn test_bm25_empty_file_handling() {
    let mut retriever = BM25Retriever::new();
    retriever.add_documents_sync(Vec::<Document>::new());
    assert!(retriever.is_empty(), "空文档集合应返回 is_empty = true");

    let hits = retriever.search("test", 5);
    assert!(hits.is_empty(), "空索引搜索应返回空结果");
}
/// A one-document index reports a length of 1, and a matching query with
/// `top_k = 1` returns exactly one hit.
#[test]
fn test_bm25_single_document() {
    let mut retriever = BM25Retriever::new();
    retriever.add_documents_sync(vec![Document::new(
        "This is a single test document about Rust programming",
    )]);
    assert_eq!(retriever.len(), 1, "应有 1 个文档");

    let hits = retriever.search("Rust", 1);
    assert_eq!(hits.len(), 1, "应返回 1 个结果");
}