use langchainrust::retrieval::bm25::{bm25_score, compute_idf};
use langchainrust::{BM25Index, BM25Params, BM25Retriever, Document, Tokenizer};
/// Checks the relative ordering and the degenerate cases of `compute_idf`.
///
/// Judging by the assertion messages, the first argument is the number of
/// documents containing the term and the second is the corpus size.
#[test]
fn test_bm25_idf_calculation() {
    // A term that occurs in every document must score below 1.0.
    let common = compute_idf(100, 100);
    assert!(common < 1.0, "常见词 IDF 应小于 1.0");

    // A term that occurs in a single document must outrank the common one.
    let rare = compute_idf(1, 100);
    assert!(rare > common, "稀有词 IDF 应大于常见词");

    // Degenerate inputs: unseen term, then an empty corpus.
    let unseen = compute_idf(0, 100);
    assert_eq!(unseen, 0.0, "不存在的词 IDF 应为 0");

    let empty_corpus = compute_idf(1, 0);
    assert_eq!(empty_corpus, 0.0, "空文档集 IDF 应为 0");
}
/// Scores a document that contains both query terms and asserts the result
/// is strictly positive.
///
/// The positional arguments `10` and `15.0` are presumably the document
/// length and the average document length — confirm against the signature
/// of `bm25_score`.
#[test]
fn test_bm25_score_calculation() {
    use std::collections::HashMap;

    let params = BM25Params::default();
    let query_terms = vec!["rust".to_string(), "programming".to_string()];

    // Term frequencies of the document under test.
    let doc_term_freqs = HashMap::from([
        ("rust".to_string(), 2),
        ("programming".to_string(), 1),
    ]);

    // Hand-picked IDF values, so the test does not depend on an index.
    let idf_values = HashMap::from([
        ("rust".to_string(), 2.0),
        ("programming".to_string(), 1.5),
    ]);

    // The last argument must be a reference to `params`; the previous text
    // here was a mis-encoded `&params` and did not compile.
    let score = bm25_score(
        &query_terms,
        &doc_term_freqs,
        10,
        15.0,
        &idf_values,
        &params,
    );

    assert!(score > 0.0, "BM25 评分应大于 0");
}
/// Asserts that, with every other input held equal, a document with a higher
/// term frequency for the query term receives a higher BM25 score.
#[test]
fn test_bm25_high_term_frequency() {
    use std::collections::HashMap;

    let params = BM25Params::default();
    let query = vec!["rust".to_string()];
    let idf = HashMap::from([("rust".to_string(), 2.0)]);

    // Two documents that differ only in how often "rust" occurs.
    let low_tf = HashMap::from([("rust".to_string(), 1)]);
    let high_tf = HashMap::from([("rust".to_string(), 5)]);

    // The last argument must be a reference to `params`; the previous text
    // here was a mis-encoded `&params` and did not compile.
    let score_low = bm25_score(&query, &low_tf, 10, 15.0, &idf, &params);
    let score_high = bm25_score(&query, &high_tf, 10, 15.0, &idf, &params);

    assert!(score_high > score_low, "高词频文档应得更高分");
}
/// Verifies the default `k1`/`b` values and that `with_values` stores the
/// values it is given.
#[test]
fn test_bm25_parameters() {
    let defaults = BM25Params::default();
    assert_eq!(defaults.k1, 1.5, "默认 k1 应为 1.5");
    assert_eq!(defaults.b, 0.75, "默认 b 应为 0.75");

    let tuned = BM25Params::with_values(2.0, 0.5);
    assert_eq!(tuned.k1, 2.0, "自定义 k1 应为 2.0");
    assert_eq!(tuned.b, 0.5, "自定义 b 应为 0.5");
}
/// Adds a single three-term document to an index and checks the document
/// count, the stored length and retrieval by id `0`.
#[test]
fn test_bm25_index_basic_operations() {
    let terms: Vec<String> = ["rust", "programming", "language"]
        .map(String::from)
        .to_vec();

    let mut index = BM25Index::new();
    index.add_document(Document::new("Rust programming language"), terms);

    assert_eq!(index.n_docs(), 1, "文档数量应为 1");
    assert_eq!(index.get_doc_length(0), 3, "文档长度应为 3");
    assert!(index.get_document(0).is_some(), "应能获取文档");
}
/// Indexes two documents that share the term "language" and asserts that
/// "rust", present in only one of them, gets the higher IDF.
#[test]
fn test_bm25_index_idf_values() {
    // (document text, pre-tokenized terms), inserted in this order.
    let fixtures = [
        ("Rust programming language", ["rust", "programming", "language"]),
        ("Python scripting language", ["python", "scripting", "language"]),
    ];

    let mut index = BM25Index::new();
    for (text, terms) in fixtures {
        let terms: Vec<String> = terms.map(String::from).to_vec();
        index.add_document(Document::new(text), terms);
    }

    let idf_language = index.compute_idf_for_term("language");
    let idf_rust = index.compute_idf_for_term("rust");
    assert!(idf_rust > idf_language, "稀有词 IDF 应高于常见词");
}
/// Indexes a one-term and a three-term document and expects `avgdl()` to
/// report 2.0.
#[test]
fn test_bm25_index_average_document_length() {
    let mut index = BM25Index::new();

    // Each fixture text is already space-separated, so its terms are the
    // individual words.
    for text in ["a", "a b c"] {
        let terms: Vec<String> = text.split(' ').map(String::from).collect();
        index.add_document(Document::new(text), terms);
    }

    assert_eq!(index.avgdl(), 2.0, "平均文档长度应为 2.0");
}
/// Expects `tokenize_english` to turn "Hello World Rust" into the three
/// lowercase tokens, in input order.
#[test]
fn test_tokenizer_english() {
    let tokenizer = Tokenizer::new();
    let expected = vec!["hello", "world", "rust"];

    let actual = tokenizer.tokenize_english("Hello World Rust");

    assert_eq!(actual, expected, "英文分词结果");
}
/// Asserts that a tokenizer built with `Tokenizer::new()` drops "the", "is"
/// and "a" from English text while keeping the content words.
#[test]
fn test_tokenizer_english_stopwords() {
    let tokenizer = Tokenizer::new();
    let terms = tokenizer.tokenize_english("The Rust is a programming language");
    let has = |word: &str| terms.contains(&word.to_string());

    // Words that must be absent from the output.
    let filtered = [
        ("the", "'the' 应被过滤"),
        ("is", "'is' 应被过滤"),
        ("a", "'a' 应被过滤"),
    ];
    for (word, message) in filtered {
        assert!(!has(word), "{}", message);
    }

    // Words that must be present in the output.
    let kept = [
        ("rust", "'rust' 应保留"),
        ("programming", "'programming' 应保留"),
        ("language", "'language' 应保留"),
    ];
    for (word, message) in kept {
        assert!(has(word), "{}", message);
    }
}
/// Asserts that `tokenize_chinese` emits every single character and every
/// adjacent two-character pair of the input "编程语言".
#[test]
fn test_tokenizer_chinese() {
    let tokenizer = Tokenizer::new();
    let terms = tokenizer.tokenize_chinese("编程语言");

    // (expected token, failure message): single characters, then pairs.
    let expectations = [
        ("编", "应包含单字 '编'"),
        ("程", "应包含单字 '程'"),
        ("语", "应包含单字 '语'"),
        ("言", "应包含单字 '言'"),
        ("编程", "应包含双字 '编程'"),
        ("程语", "应包含双字 '程语'"),
        ("语言", "应包含双字 '语言'"),
    ];

    for (token, message) in expectations {
        assert!(terms.contains(&token.to_string()), "{}", message);
    }
}
/// Asserts that "的" is removed from Chinese tokenization output while the
/// surrounding characters "编" and "程" are retained.
#[test]
fn test_tokenizer_chinese_stopwords() {
    let tokenizer = Tokenizer::new();
    let terms = tokenizer.tokenize_chinese("编程的语言");
    let has = |token: &str| terms.contains(&token.to_string());

    assert!(!has("的"), "'的' 应被过滤");

    for (token, message) in [("编", "'编' 应保留"), ("程", "'程' 应保留")] {
        assert!(has(token), "{}", message);
    }
}
/// Runs the general `tokenize` entry point on mixed English/Chinese text and
/// asserts that both the lowercase English word and the Chinese tokens
/// appear in the output.
#[test]
fn test_tokenizer_mixed_chinese_english() {
    let tokenizer = Tokenizer::new();
    let terms = tokenizer.tokenize("Rust 编程语言");

    // (expected token, failure message)
    let expectations = [
        ("rust", "应包含 'rust'"),
        ("编", "应包含 '编'"),
        ("程", "应包含 '程'"),
        ("编程", "应包含 '编程'"),
        ("语言", "应包含 '语言'"),
    ];

    for (token, message) in expectations {
        assert!(terms.contains(&token.to_string()), "{}", message);
    }
}
/// Asserts that a tokenizer built with `Tokenizer::with_stopwords()` keeps
/// "the" in its output alongside the content words.
#[test]
fn test_tokenizer_keep_stopwords() {
    let tokenizer = Tokenizer::with_stopwords();
    let terms = tokenizer.tokenize("The programming language");

    // (expected token, failure message)
    let kept = [
        ("the", "'the' 应保留"),
        ("programming", "'programming' 应保留"),
        ("language", "'language' 应保留"),
    ];

    for (token, message) in kept {
        assert!(terms.contains(&token.to_string()), "{}", message);
    }
}
/// Indexes three English documents, searches for "programming language" with
/// a limit of 2, and checks the result count and that the top hit mentions
/// "programming".
#[test]
fn test_bm25_retriever_basic_search() {
    let corpus = vec![
        Document::new("Rust is a systems programming language"),
        Document::new("Python is a scripting language"),
        Document::new("JavaScript is used for web development"),
    ];

    let mut retriever = BM25Retriever::new();
    retriever.add_documents_sync(corpus);
    assert_eq!(retriever.len(), 3, "文档数量应为 3");

    let results = retriever.search("programming language", 2);
    assert_eq!(results.len(), 2, "应返回 2 个结果");

    // Indexing is safe here: the length was asserted just above.
    let top_hit = &results[0];
    assert!(
        top_hit.document.content.contains("programming"),
        "第一个结果应包含 'programming'"
    );
}
/// Indexes three documents containing Chinese text, searches for "编程语言"
/// and asserts that at least one result comes back and that the top hit
/// contains "编程".
#[test]
fn test_bm25_retriever_chinese_search() {
    let corpus = vec![
        Document::new("Rust 是一门系统编程语言"),
        Document::new("Python 是脚本语言"),
        Document::new("JavaScript 用于网页开发"),
    ];

    let mut retriever = BM25Retriever::new();
    retriever.add_documents_sync(corpus);

    let results = retriever.search("编程语言", 2);
    assert!(!results.is_empty(), "应返回至少 1 个结果");

    // Non-emptiness was asserted above, so the first element exists.
    let top_hit = &results[0];
    assert!(
        top_hit.document.content.contains("编程"),
        "结果应包含 '编程'"
    );
}
/// Searching a retriever that has no documents must yield no results.
#[test]
fn test_bm25_retriever_empty_index() {
    let mut retriever = BM25Retriever::new();

    let hits = retriever.search("test query", 5);

    assert!(hits.is_empty(), "空索引应返回空结果");
}
/// Builds a retriever via `with_params(2.0, 0.5)` and checks that a search
/// limited to one result returns exactly one.
#[test]
fn test_bm25_retriever_custom_parameters() {
    let corpus = vec![
        Document::new("Rust programming"),
        Document::new("Python scripting"),
    ];

    let mut retriever = BM25Retriever::with_params(2.0, 0.5);
    retriever.add_documents_sync(corpus);

    let hits = retriever.search("programming", 1);
    assert_eq!(hits.len(), 1, "应返回 1 个结果");
}
/// A query whose terms appear in none of the indexed documents must return
/// an empty result set.
#[test]
fn test_bm25_retriever_no_matching_documents() {
    let corpus = vec![
        Document::new("Rust programming language"),
        Document::new("Python scripting language"),
    ];

    let mut retriever = BM25Retriever::new();
    retriever.add_documents_sync(corpus);

    let hits = retriever.search("javascript typescript", 5);
    assert!(hits.is_empty(), "无匹配时应返回空结果");
}
/// Asserts that, when a search returns at least two results, the first one
/// does not score lower than the second.
#[test]
fn test_bm25_retriever_score_ordering() {
    let corpus = vec![
        Document::new("Rust Rust Rust programming"),
        Document::new("Python programming"),
    ];

    let mut retriever = BM25Retriever::new();
    retriever.add_documents_sync(corpus);

    let results = retriever.search("rust", 2);

    // NOTE(review): only one fixture document contains "rust". If the
    // retriever omits non-matching documents, fewer than two results come
    // back and the ordering assertion below never runs — confirm intent.
    if results.len() < 2 {
        return;
    }

    assert!(results[0].score >= results[1].score, "结果应按评分降序排列");
}
/// Adds one document, clears the retriever, and checks that both `len()` and
/// `is_empty()` report an empty index afterwards.
#[test]
fn test_bm25_retriever_clear_index() {
    let mut retriever = BM25Retriever::new();

    let corpus = vec![Document::new("Test document")];
    retriever.add_documents_sync(corpus);
    assert_eq!(retriever.len(), 1, "添加后应有 1 个文档");

    retriever.clear();

    assert_eq!(retriever.len(), 0, "清空后应为 0 个文档");
    assert!(retriever.is_empty(), "is_empty() 应返回 true");
}
/// Indexes a short and a long document that both contain "rust", in two
/// separate `add_documents_sync` calls, and asserts the search finds a match.
///
/// NOTE(review): despite the name, nothing here compares the scores of the
/// short and long document; only non-emptiness is checked.
#[test]
fn test_bm25_retriever_document_length_normalization() {
    let mut retriever = BM25Retriever::new();

    // Short document first, then the long one, each in its own batch.
    let short_batch = vec![Document::new("Rust")];
    retriever.add_documents_sync(short_batch);

    let long_batch = vec![Document::new(
        "Rust is a systems programming language with many features",
    )];
    retriever.add_documents_sync(long_batch);

    let results = retriever.search("rust", 2);
    assert!(!results.is_empty(), "应有匹配结果");
}