use std::collections::HashMap;
use crate::OracleError;
/// BM25 `k1` parameter: scales the term-frequency component of the score
/// (used in `Bm25Scorer::score_document`).
const BM25_K1: f64 = 1.5;
/// BM25 `b` parameter: weight given to the `doc_len / avg_doc_len` ratio in
/// the BM25 denominator (used in `Bm25Scorer::score_document`).
const BM25_B: f64 = 0.75;
/// Constant added to every 1-based rank in the reciprocal-rank-fusion term
/// `1 / (RRF_K + rank)`.
const RRF_K: f64 = 60.0;
/// BM25 scorer over an in-memory corpus of tokenized documents.
///
/// A scorer built with `new()` is empty; all fields are (re)populated by `fit`.
#[derive(Debug, Clone)]
pub struct Bm25Scorer {
/// For each term, the number of fitted documents that contain it at least once.
doc_frequencies: HashMap<String, usize>,
/// Number of documents passed to the last successful `fit`.
num_docs: usize,
/// Mean token count per document in the fitted corpus.
avg_doc_len: f64,
/// IDF value per term, computed from `doc_frequencies` at the end of `fit`.
idf_cache: HashMap<String, f64>,
/// Token lists of the fitted documents, in the order they were passed to `fit`.
documents: Vec<Vec<String>>,
}
impl Bm25Scorer {
    /// Creates an empty scorer with no documents and no term statistics.
    #[must_use]
    pub fn new() -> Self {
        Self {
            doc_frequencies: HashMap::new(),
            num_docs: 0,
            avg_doc_len: 0.0,
            idf_cache: HashMap::new(),
            documents: Vec::new(),
        }
    }

    /// Tokenizes `documents` and rebuilds every corpus statistic from scratch.
    ///
    /// Any state from a previous fit is discarded.
    ///
    /// # Errors
    ///
    /// Returns `OracleError::Feature` when `documents` is empty; in that case
    /// the scorer is left untouched.
    pub fn fit<S: AsRef<str>>(&mut self, documents: &[S]) -> Result<(), OracleError> {
        if documents.is_empty() {
            return Err(OracleError::Feature(
                "Cannot fit BM25 on empty corpus".to_string(),
            ));
        }

        self.documents.clear();
        self.doc_frequencies.clear();
        self.documents.reserve(documents.len());

        let mut total_len = 0usize;
        for doc in documents {
            let tokens = tokenize(doc.as_ref());
            total_len += tokens.len();

            // Deduplicate by reference so a term is only cloned once per
            // document, when its document frequency is bumped.
            let unique_terms: std::collections::HashSet<&str> =
                tokens.iter().map(String::as_str).collect();
            for term in unique_terms {
                *self.doc_frequencies.entry(term.to_owned()).or_insert(0) += 1;
            }

            self.documents.push(tokens);
        }

        self.num_docs = documents.len();
        // May be 0.0 when every document tokenizes to nothing; scoring
        // handles that case explicitly.
        self.avg_doc_len = total_len as f64 / self.num_docs as f64;

        let num_docs = self.num_docs;
        self.idf_cache = self
            .doc_frequencies
            .iter()
            .map(|(term, df)| (term.clone(), compute_idf(*df, num_docs)))
            .collect();

        Ok(())
    }

    /// Scores every fitted document against `query`.
    ///
    /// Returns one `(document index, score)` pair per fitted document, sorted
    /// by descending score. The sort is stable, so documents with equal scores
    /// keep their corpus order. Scores are always finite and non-negative.
    #[must_use]
    pub fn score(&self, query: &str) -> Vec<(usize, f64)> {
        let query_tokens = tokenize(query);
        let mut scores: Vec<(usize, f64)> = self
            .documents
            .iter()
            .enumerate()
            .map(|(idx, doc_tokens)| (idx, self.score_document(&query_tokens, doc_tokens)))
            .collect();
        scores.sort_by(|a, b| b.1.total_cmp(&a.1));
        scores
    }

    /// BM25 score of one tokenized document for the given query tokens.
    ///
    /// Query terms that do not occur in the document contribute nothing. A
    /// repeated query term contributes once per repetition.
    fn score_document(&self, query_tokens: &[String], doc_tokens: &[String]) -> f64 {
        // An empty document cannot match anything. Returning early also keeps
        // the length normalisation below away from `0.0 / 0.0` when the whole
        // corpus is empty of tokens (`avg_doc_len == 0.0`).
        if doc_tokens.is_empty() {
            return 0.0;
        }

        let mut tf_counts: HashMap<&str, usize> = HashMap::new();
        for token in doc_tokens {
            *tf_counts.entry(token.as_str()).or_insert(0) += 1;
        }

        // Loop-invariant part of the BM25 denominator.
        let doc_len = doc_tokens.len() as f64;
        let length_norm = BM25_K1 * (1.0 - BM25_B + BM25_B * doc_len / self.avg_doc_len);

        let mut score = 0.0;
        for term in query_tokens {
            let tf = match tf_counts.get(term.as_str()) {
                Some(&count) => count as f64,
                // Zero term frequency means a zero contribution.
                None => continue,
            };
            let idf = self.idf_cache.get(term).copied().unwrap_or(0.0);
            score += idf * (tf * (BM25_K1 + 1.0)) / (tf + length_norm);
        }
        score
    }

    /// Number of documents in the fitted corpus (0 before the first fit).
    #[must_use]
    pub fn num_docs(&self) -> usize {
        self.num_docs
    }

    /// Mean number of tokens per fitted document (0.0 before the first fit).
    #[must_use]
    pub fn avg_doc_len(&self) -> f64 {
        self.avg_doc_len
    }
}
impl Default for Bm25Scorer {
fn default() -> Self {
Self::new()
}
}
/// Inverse document frequency of a term, computed as
/// `ln((N - df + 0.5) / (df + 0.5) + 1)`.
///
/// `doc_freq` is the number of documents containing the term and `num_docs`
/// the corpus size. The `+ 1` inside the logarithm keeps the value positive
/// even when the term occurs in every document.
fn compute_idf(doc_freq: usize, num_docs: usize) -> f64 {
    let corpus_size = num_docs as f64;
    let containing = doc_freq as f64;
    let odds = (corpus_size - containing + 0.5) / (containing + 0.5);
    (odds + 1.0).ln()
}
/// Splits `text` into lowercase tokens.
///
/// The text is lowercased and split on whitespace. Each piece then has its
/// leading and trailing non-alphanumeric characters removed, and pieces that
/// end up empty are dropped. Punctuation inside a piece is kept.
fn tokenize(text: &str) -> Vec<String> {
    let lowered = text.to_lowercase();
    let mut tokens = Vec::new();
    for piece in lowered.split_whitespace() {
        let core = piece.trim_matches(|ch: char| !ch.is_alphanumeric());
        if !core.is_empty() {
            tokens.push(core.to_string());
        }
    }
    tokens
}
/// One fused entry produced by `reciprocal_rank_fusion`.
#[derive(Debug, Clone)]
pub struct RrfResult {
/// Document index as it appeared in the input rankings.
pub doc_idx: usize,
/// Fused score: the sum of `1 / (RRF_K + rank)` over the rankings that contain the document.
pub score: f64,
/// 1-based position in the BM25 ranking, or 0 when the document is absent from it.
pub bm25_rank: usize,
/// 1-based position in the TF-IDF ranking, or 0 when the document is absent from it.
pub tfidf_rank: usize,
}
/// Fuses a BM25 ranking and a TF-IDF ranking with reciprocal rank fusion.
///
/// Each input is a list of `(document index, score)` pairs ordered best-first.
/// Only the position in the list matters; the scores themselves are ignored.
/// A document's fused score is the sum of `1 / (RRF_K + rank)` over the
/// rankings that contain it, with 1-based ranks. If a document appears more
/// than once in one ranking, its last position is used.
///
/// The result holds every document seen in either ranking, sorted by
/// descending fused score and cut to at most `top_k` entries. Ties are broken
/// by ascending `doc_idx`, so the output (including which documents survive
/// truncation) is the same on every run.
#[must_use]
pub fn reciprocal_rank_fusion(
    bm25_ranking: &[(usize, f64)],
    tfidf_ranking: &[(usize, f64)],
    top_k: usize,
) -> Vec<RrfResult> {
    let bm25_ranks = rank_positions(bm25_ranking);
    let tfidf_ranks = rank_positions(tfidf_ranking);

    // Sorted, de-duplicated candidate list. A hash set here would hand tied
    // documents to the sort in a different order on each run.
    let mut doc_ids: Vec<usize> = bm25_ranking
        .iter()
        .chain(tfidf_ranking)
        .map(|(doc_idx, _)| *doc_idx)
        .collect();
    doc_ids.sort_unstable();
    doc_ids.dedup();

    let mut results: Vec<RrfResult> = doc_ids
        .into_iter()
        .map(|doc_idx| {
            // Rank 0 is the "not present in this ranking" marker.
            let bm25_rank = bm25_ranks.get(&doc_idx).copied().unwrap_or(0);
            let tfidf_rank = tfidf_ranks.get(&doc_idx).copied().unwrap_or(0);
            RrfResult {
                doc_idx,
                score: reciprocal_rank_term(bm25_rank) + reciprocal_rank_term(tfidf_rank),
                bm25_rank,
                tfidf_rank,
            }
        })
        .collect();

    results.sort_by(|a, b| {
        b.score
            .total_cmp(&a.score)
            .then_with(|| a.doc_idx.cmp(&b.doc_idx))
    });
    results.truncate(top_k);
    results
}

/// Maps each document index in `ranking` to its 1-based position; later
/// occurrences of the same index overwrite earlier ones.
fn rank_positions(ranking: &[(usize, f64)]) -> HashMap<usize, usize> {
    ranking
        .iter()
        .enumerate()
        .map(|(position, (doc_idx, _))| (*doc_idx, position + 1))
        .collect()
}

/// Contribution of a single ranking to the fused score; rank 0 (absent)
/// contributes nothing.
fn reciprocal_rank_term(rank: usize) -> f64 {
    if rank == 0 {
        0.0
    } else {
        1.0 / (RRF_K + rank as f64)
    }
}
/// Retriever that combines a BM25 ranking and a TF-IDF cosine-similarity
/// ranking through `reciprocal_rank_fusion`.
pub struct HybridRetriever {
/// BM25 scorer fitted on the corpus.
bm25: Bm25Scorer,
/// TF-IDF feature extractor fitted on the same corpus.
tfidf: crate::tfidf::TfidfFeatureExtractor,
/// Original document texts, in the order they were passed to `fit`.
documents: Vec<String>,
/// Set to `true` once `fit` has completed successfully.
is_fitted: bool,
}
impl HybridRetriever {
    /// Creates an unfitted retriever with no documents.
    #[must_use]
    pub fn new() -> Self {
        Self {
            bm25: Bm25Scorer::new(),
            tfidf: crate::tfidf::TfidfFeatureExtractor::new(),
            documents: Vec::new(),
            is_fitted: false,
        }
    }

    /// Fits both the BM25 scorer and the TF-IDF extractor on `documents` and
    /// stores the document texts for later retrieval.
    ///
    /// # Errors
    ///
    /// Propagates any error from the BM25 fit or the TF-IDF fit. If the
    /// TF-IDF fit fails after the BM25 scorer has already been refitted, the
    /// retriever is marked as not fitted and its stored documents are cleared,
    /// so `query` reports an error rather than pairing new BM25 indices with
    /// stale document texts.
    pub fn fit<S: AsRef<str> + Clone>(&mut self, documents: &[S]) -> Result<(), OracleError> {
        self.bm25.fit(documents)?;

        let tfidf_outcome = self.tfidf.fit(documents);
        if tfidf_outcome.is_err() {
            // BM25 now describes the new corpus while the rest still
            // describes the old one; refuse queries until a fit succeeds.
            self.documents.clear();
            self.is_fitted = false;
        }
        tfidf_outcome?;

        self.documents = documents.iter().map(|d| d.as_ref().to_string()).collect();
        self.is_fitted = true;
        Ok(())
    }

    /// Returns up to `top_k` documents for `query`, best first, each paired
    /// with its fused ranking details.
    ///
    /// # Errors
    ///
    /// Returns `OracleError::Feature` if the retriever has not been fitted,
    /// and propagates any error from the TF-IDF transform.
    pub fn query(
        &self,
        query: &str,
        top_k: usize,
    ) -> Result<Vec<(String, RrfResult)>, OracleError> {
        if !self.is_fitted {
            return Err(OracleError::Feature(
                "HybridRetriever not fitted. Call fit() first".to_string(),
            ));
        }

        let bm25_ranking = self.bm25.score(query);
        let tfidf_ranking = self.tfidf_rank(query)?;
        let fused = reciprocal_rank_fusion(&bm25_ranking, &tfidf_ranking, top_k);

        // Indices that do not map to a stored document are dropped.
        let results = fused
            .into_iter()
            .filter_map(|hit| {
                self.documents
                    .get(hit.doc_idx)
                    .map(|doc| (doc.clone(), hit))
            })
            .collect();
        Ok(results)
    }

    /// Ranks every stored document by cosine similarity between its TF-IDF
    /// vector and the query's TF-IDF vector, best first.
    ///
    /// NOTE(review): the whole stored corpus is transformed again on every
    /// call; caching the document matrix at fit time would avoid that but
    /// needs an extra field on the struct.
    fn tfidf_rank(&self, query: &str) -> Result<Vec<(usize, f64)>, OracleError> {
        let query_vec = self.tfidf.transform(&[query])?;
        let doc_vecs = self.tfidf.transform(&self.documents)?;

        let mut rankings: Vec<(usize, f64)> = (0..self.documents.len())
            .map(|idx| (idx, cosine_similarity(&query_vec, 0, &doc_vecs, idx)))
            .collect();
        rankings.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
        Ok(rankings)
    }

    /// Whether the last call to `fit` completed successfully.
    #[must_use]
    pub fn is_fitted(&self) -> bool {
        self.is_fitted
    }

    /// Number of stored documents.
    #[must_use]
    pub fn num_docs(&self) -> usize {
        self.documents.len()
    }
}
impl Default for HybridRetriever {
fn default() -> Self {
Self::new()
}
}
fn cosine_similarity(
a_matrix: &aprender::primitives::Matrix<f64>,
a_row: usize,
b_matrix: &aprender::primitives::Matrix<f64>,
b_row: usize,
) -> f64 {
let cols = a_matrix.n_cols();
if cols != b_matrix.n_cols() {
return 0.0;
}
let mut dot = 0.0;
let mut norm_a = 0.0;
let mut norm_b = 0.0;
for col in 0..cols {
let a_val = a_matrix.get(a_row, col);
let b_val = b_matrix.get(b_row, col);
dot += a_val * b_val;
norm_a += a_val * a_val;
norm_b += b_val * b_val;
}
let denom = norm_a.sqrt() * norm_b.sqrt();
if denom == 0.0 {
0.0
} else {
dot / denom
}
}
// Unit and property tests for the BM25 scorer, reciprocal rank fusion,
// the hybrid retriever and the cosine-similarity helper.
#[cfg(test)]
mod tests {
use super::*;
// A new scorer has no documents and a zero average length.
#[test]
fn test_bm25_scorer_new() {
let scorer = Bm25Scorer::new();
assert_eq!(scorer.num_docs(), 0);
assert_eq!(scorer.avg_doc_len(), 0.0);
}
// Fitting an empty corpus is an error.
#[test]
fn test_bm25_fit_empty_corpus() {
let mut scorer = Bm25Scorer::new();
let empty: Vec<&str> = vec![];
let result = scorer.fit(&empty);
assert!(result.is_err());
}
// Fitting a small corpus records the document count and a positive average length.
#[test]
fn test_bm25_fit_success() {
let mut scorer = Bm25Scorer::new();
let docs = vec![
"expected i32 found str",
"cannot borrow as mutable",
"lifetime does not live long enough",
];
let result = scorer.fit(&docs);
assert!(result.is_ok());
assert_eq!(scorer.num_docs(), 3);
assert!(scorer.avg_doc_len() > 0.0);
}
// Querying with the exact text of document 0 ranks it first, strictly above the runner-up.
#[test]
fn test_bm25_score_exact_match_highest() {
let mut scorer = Bm25Scorer::new();
let docs = vec![
"expected i32 found str",
"cannot borrow as mutable",
"type mismatch error",
];
scorer.fit(&docs).unwrap();
let scores = scorer.score("expected i32 found str");
assert!(!scores.is_empty());
assert_eq!(scores[0].0, 0); assert!(scores[0].1 > scores[1].1); }
// Both documents containing the query term must occupy the top two positions.
#[test]
fn test_bm25_score_partial_match() {
let mut scorer = Bm25Scorer::new();
let docs = vec![
"type mismatch expected i32",
"cannot borrow mutably",
"expected value found reference",
];
scorer.fit(&docs).unwrap();
let scores = scorer.score("expected");
let top_indices: Vec<usize> = scores.iter().take(2).map(|(idx, _)| *idx).collect();
assert!(top_indices.contains(&0)); assert!(top_indices.contains(&2)); }
// A term present in three of four documents gets a lower IDF than a term present in one.
#[test]
fn test_bm25_idf_common_terms_lower() {
let mut scorer = Bm25Scorer::new();
let docs = vec![
"error error error",
"error message here",
"error type found",
"unique distinct different",
];
scorer.fit(&docs).unwrap();
let error_idf = scorer.idf_cache.get("error").copied().unwrap_or(0.0);
let unique_idf = scorer.idf_cache.get("unique").copied().unwrap_or(f64::MAX);
assert!(unique_idf > error_idf, "Rare terms should have higher IDF");
}
// Tokens are lowercased and split on whitespace.
#[test]
fn test_tokenize_basic() {
let tokens = tokenize("Hello World");
assert_eq!(tokens, vec!["hello", "world"]);
}
// Leading and trailing punctuation (backticks, commas) is stripped from tokens.
#[test]
fn test_tokenize_with_punctuation() {
let tokens = tokenize("error[E0308]: expected `i32`, found `str`");
assert!(tokens.contains(&"expected".to_string()));
assert!(tokens.contains(&"i32".to_string()));
assert!(tokens.contains(&"str".to_string()));
}
// Empty input yields no tokens.
#[test]
fn test_tokenize_empty() {
let tokens = tokenize("");
assert!(tokens.is_empty());
}
// A term in 1 of 100 documents has a high IDF.
#[test]
fn test_compute_idf_rare_term() {
let idf = compute_idf(1, 100);
assert!(idf > 4.0, "Rare term should have high IDF");
}
// A term in 90 of 100 documents has a low IDF.
#[test]
fn test_compute_idf_common_term() {
let idf = compute_idf(90, 100);
assert!(idf < 1.0, "Common term should have low IDF");
}
// Even a term present in every document keeps a positive IDF.
#[test]
fn test_compute_idf_all_docs() {
let idf = compute_idf(100, 100);
assert!(idf > 0.0, "IDF should still be positive with smoothing");
}
// Two empty rankings fuse to an empty result.
#[test]
fn test_rrf_empty_rankings() {
let bm25: Vec<(usize, f64)> = vec![];
let tfidf: Vec<(usize, f64)> = vec![];
let result = reciprocal_rank_fusion(&bm25, &tfidf, 10);
assert!(result.is_empty());
}
// With only a BM25 ranking, its order is kept and the TF-IDF rank is reported as 0.
#[test]
fn test_rrf_single_ranking() {
let bm25 = vec![(0, 1.0), (1, 0.5), (2, 0.3)];
let tfidf: Vec<(usize, f64)> = vec![];
let result = reciprocal_rank_fusion(&bm25, &tfidf, 10);
assert_eq!(result.len(), 3);
assert_eq!(result[0].doc_idx, 0);
assert!(result[0].bm25_rank > 0);
assert_eq!(result[0].tfidf_rank, 0);
}
// A document ranked first by both inputs comes out on top with score 2 / (RRF_K + 1).
#[test]
fn test_rrf_fusion_boosts_agreement() {
let bm25 = vec![(0, 1.0), (1, 0.5), (2, 0.3)];
let tfidf = vec![(0, 0.9), (2, 0.4), (1, 0.2)];
let result = reciprocal_rank_fusion(&bm25, &tfidf, 10);
assert_eq!(result[0].doc_idx, 0);
let expected_score = 1.0 / (RRF_K + 1.0) + 1.0 / (RRF_K + 1.0);
assert!((result[0].score - expected_score).abs() < 0.001);
}
// The output is truncated to top_k entries.
#[test]
fn test_rrf_top_k_limiting() {
let bm25: Vec<(usize, f64)> = (0..100).map(|i| (i, 1.0 / (i as f64 + 1.0))).collect();
let tfidf: Vec<(usize, f64)> = (0..100).map(|i| (i, 1.0 / (i as f64 + 1.0))).collect();
let result = reciprocal_rank_fusion(&bm25, &tfidf, 5);
assert_eq!(result.len(), 5);
}
// Disjoint rankings contribute all their documents; the two rank-1 documents tie.
#[test]
fn test_rrf_disjoint_rankings() {
let bm25 = vec![(0, 1.0), (1, 0.5)];
let tfidf = vec![(2, 0.9), (3, 0.4)];
let result = reciprocal_rank_fusion(&bm25, &tfidf, 10);
assert_eq!(result.len(), 4);
let top_score = result[0].score;
let second_score = result[1].score;
assert!((top_score - second_score).abs() < 0.001);
}
// A new retriever is unfitted and holds no documents.
#[test]
fn test_hybrid_retriever_new() {
let retriever = HybridRetriever::new();
assert!(!retriever.is_fitted());
assert_eq!(retriever.num_docs(), 0);
}
// Querying before fit is an error.
#[test]
fn test_hybrid_retriever_query_without_fit() {
let retriever = HybridRetriever::new();
let result = retriever.query("test query", 5);
assert!(result.is_err());
}
// After fitting, a query returns at most top_k documents and the best one shares a query term.
#[test]
fn test_hybrid_retriever_fit_and_query() {
let mut retriever = HybridRetriever::new();
let docs = vec![
"expected i32 found str type mismatch",
"cannot borrow as mutable borrow checker error",
"lifetime does not live long enough",
"missing lifetime specifier",
];
retriever.fit(&docs).unwrap();
assert!(retriever.is_fitted());
assert_eq!(retriever.num_docs(), 4);
let results = retriever.query("type mismatch expected", 3).unwrap();
assert!(!results.is_empty());
assert!(results.len() <= 3);
let (top_doc, _) = &results[0];
assert!(
top_doc.contains("type") || top_doc.contains("expected"),
"Top result should match query terms"
);
}
// Every returned text is one of the fitted documents.
#[test]
fn test_hybrid_retriever_returns_documents() {
let mut retriever = HybridRetriever::new();
let docs = vec!["document one", "document two", "document three"];
retriever.fit(&docs).unwrap();
let results = retriever.query("one", 5).unwrap();
for (doc, _) in &results {
assert!(docs.contains(&doc.as_str()));
}
}
// Two identical rows have similarity 1.
#[test]
fn test_cosine_similarity_identical() {
let matrix =
aprender::primitives::Matrix::from_vec(2, 3, vec![1.0, 2.0, 3.0, 1.0, 2.0, 3.0])
.unwrap();
let sim = cosine_similarity(&matrix, 0, &matrix, 1);
assert!(
(sim - 1.0).abs() < 0.001,
"Identical vectors should have similarity 1.0"
);
}
// Orthogonal rows have similarity 0.
#[test]
fn test_cosine_similarity_orthogonal() {
let matrix =
aprender::primitives::Matrix::from_vec(2, 2, vec![1.0, 0.0, 0.0, 1.0]).unwrap();
let sim = cosine_similarity(&matrix, 0, &matrix, 1);
assert!(
(sim - 0.0).abs() < 0.001,
"Orthogonal vectors should have similarity 0.0"
);
}
// An all-zero row takes the zero-denominator branch and yields exactly 0.
#[test]
fn test_cosine_similarity_zero_vector() {
let matrix =
aprender::primitives::Matrix::from_vec(2, 2, vec![1.0, 2.0, 0.0, 0.0]).unwrap();
let sim = cosine_similarity(&matrix, 0, &matrix, 1);
assert_eq!(sim, 0.0, "Zero vector should return 0 similarity");
}
use proptest::prelude::*;
proptest! {
// BM25 scores are never negative for lowercase-letter corpora and queries.
#[test]
fn prop_bm25_scores_non_negative(
doc1 in "[a-z ]{5,50}",
doc2 in "[a-z ]{5,50}",
query in "[a-z ]{1,20}"
) {
let mut scorer = Bm25Scorer::new();
scorer.fit(&[doc1.as_str(), doc2.as_str()]).unwrap();
let scores = scorer.score(&query);
for (_, score) in scores {
prop_assert!(score >= 0.0, "BM25 scores should be non-negative");
}
}
// Fused scores stay within [0, 2 / (RRF_K + 1)].
#[test]
fn prop_rrf_scores_bounded(
n_docs in 1usize..50
) {
let bm25: Vec<(usize, f64)> = (0..n_docs)
.map(|i| (i, 1.0 / (i as f64 + 1.0)))
.collect();
let tfidf: Vec<(usize, f64)> = (0..n_docs)
.map(|i| (i, 1.0 / (i as f64 + 1.0)))
.collect();
let results = reciprocal_rank_fusion(&bm25, &tfidf, n_docs);
for r in results {
let max_score = 2.0 / (RRF_K + 1.0);
prop_assert!(r.score <= max_score + 0.001);
prop_assert!(r.score >= 0.0);
}
}
// Tokenizing the same text twice gives the same tokens.
#[test]
fn prop_tokenize_deterministic(text in "[a-zA-Z ]{0,100}") {
let tokens1 = tokenize(&text);
let tokens2 = tokenize(&text);
prop_assert_eq!(tokens1, tokens2);
}
// IDF does not increase as document frequency increases.
#[test]
fn prop_idf_monotonic(
df1 in 1usize..50,
df2 in 1usize..50
) {
let n_docs = 100;
let idf1 = compute_idf(df1, n_docs);
let idf2 = compute_idf(df2, n_docs);
if df1 < df2 {
prop_assert!(idf1 >= idf2, "IDF should decrease as DF increases");
}
}
}
// End-to-end check on compiler-error-like messages: the top hit contains a query term and has a positive score.
#[test]
fn test_hybrid_retrieval_full_pipeline() {
let mut retriever = HybridRetriever::new();
let corpus = vec![
"error[E0308]: expected `i32`, found `&str`",
"error[E0308]: mismatched types expected i32 found String",
"error[E0502]: cannot borrow `x` as mutable because it is also borrowed as immutable",
"error[E0597]: `x` does not live long enough",
"error[E0106]: missing lifetime specifier",
"error[E0277]: the trait bound `Foo: Clone` is not satisfied",
"error[E0425]: cannot find value `foo` in this scope",
];
retriever.fit(&corpus).unwrap();
let results = retriever.query("type mismatch expected found", 3).unwrap();
assert!(!results.is_empty());
let (top_doc, top_result) = &results[0];
assert!(
top_doc.contains("expected") || top_doc.contains("found"),
"Top result should match type mismatch query"
);
assert!(top_result.score > 0.0);
}
// NOTE(review): same scenario as `test_rrf_empty_rankings` above.
#[test]
fn test_rrf_both_rankings_empty() {
let bm25: Vec<(usize, f64)> = vec![];
let tfidf: Vec<(usize, f64)> = vec![];
let result = reciprocal_rank_fusion(&bm25, &tfidf, 10);
assert!(
result.is_empty(),
"Empty rankings should produce empty result"
);
}
// BM25-only input: order, ranks and the exact single-ranking score are checked.
#[test]
fn test_rrf_bm25_only_ranking() {
let bm25 = vec![(5, 2.5), (3, 1.8), (7, 0.9)];
let tfidf: Vec<(usize, f64)> = vec![];
let result = reciprocal_rank_fusion(&bm25, &tfidf, 10);
assert_eq!(result.len(), 3);
assert_eq!(result[0].doc_idx, 5);
assert_eq!(result[0].bm25_rank, 1);
assert_eq!(result[0].tfidf_rank, 0);
let expected_score = 1.0 / (RRF_K + 1.0);
assert!((result[0].score - expected_score).abs() < 0.0001);
}
// TF-IDF-only input: the BM25 rank is reported as 0.
#[test]
fn test_rrf_tfidf_only_ranking() {
let bm25: Vec<(usize, f64)> = vec![];
let tfidf = vec![(2, 0.95), (8, 0.75), (1, 0.50)];
let result = reciprocal_rank_fusion(&bm25, &tfidf, 10);
assert_eq!(result.len(), 3);
assert_eq!(result[0].doc_idx, 2);
assert_eq!(result[0].bm25_rank, 0);
assert_eq!(result[0].tfidf_rank, 1);
}
// Mirror-image rankings give both documents the same fused score.
// NOTE(review): only score equality is asserted; the output order of the tied documents is not checked.
#[test]
fn test_rrf_tie_breaking_by_earlier_appearance() {
let bm25 = vec![(0, 1.0), (1, 0.5)];
let tfidf = vec![(1, 1.0), (0, 0.5)];
let result = reciprocal_rank_fusion(&bm25, &tfidf, 10);
let doc0_score = result.iter().find(|r| r.doc_idx == 0).unwrap().score;
let doc1_score = result.iter().find(|r| r.doc_idx == 1).unwrap().score;
assert!(
(doc0_score - doc1_score).abs() < 0.0001,
"Symmetric rankings should produce equal scores"
);
}
// A single document ranked first by both inputs scores 2 / (RRF_K + 1).
#[test]
fn test_rrf_single_document_both_rankings() {
let bm25 = vec![(42, 5.0)];
let tfidf = vec![(42, 0.99)];
let result = reciprocal_rank_fusion(&bm25, &tfidf, 10);
assert_eq!(result.len(), 1);
assert_eq!(result[0].doc_idx, 42);
assert_eq!(result[0].bm25_rank, 1);
assert_eq!(result[0].tfidf_rank, 1);
let expected = 2.0 / (RRF_K + 1.0);
assert!((result[0].score - expected).abs() < 0.0001);
}
// top_k = 0 yields an empty result.
#[test]
fn test_rrf_top_k_zero_returns_empty() {
let bm25 = vec![(0, 1.0), (1, 0.5)];
let tfidf = vec![(0, 0.9), (1, 0.4)];
let result = reciprocal_rank_fusion(&bm25, &tfidf, 0);
assert!(result.is_empty(), "top_k=0 should return empty result");
}
// A top_k larger than the number of documents returns all of them.
#[test]
fn test_rrf_top_k_larger_than_corpus() {
let bm25 = vec![(0, 1.0), (1, 0.5)];
let tfidf = vec![(0, 0.9), (1, 0.4)];
let result = reciprocal_rank_fusion(&bm25, &tfidf, 100);
assert_eq!(
result.len(),
2,
"Should return all docs when top_k > corpus"
);
}
// The union of both rankings (with one shared document) is preserved.
#[test]
fn test_rrf_preserves_all_unique_docs() {
let bm25 = vec![(0, 1.0), (1, 0.8), (2, 0.6)];
let tfidf = vec![(2, 0.95), (3, 0.7), (4, 0.5)];
let result = reciprocal_rank_fusion(&bm25, &tfidf, 10);
assert_eq!(result.len(), 5);
let doc_ids: std::collections::HashSet<_> = result.iter().map(|r| r.doc_idx).collect();
assert!(doc_ids.contains(&0));
assert!(doc_ids.contains(&1));
assert!(doc_ids.contains(&2));
assert!(doc_ids.contains(&3));
assert!(doc_ids.contains(&4));
}
// A document present in both rankings beats documents present in only one.
#[test]
fn test_rrf_overlapping_doc_ranks_higher() {
let bm25 = vec![(0, 1.0), (2, 0.5)];
let tfidf = vec![(2, 0.9), (1, 0.4)];
let result = reciprocal_rank_fusion(&bm25, &tfidf, 10);
assert_eq!(
result[0].doc_idx, 2,
"Doc in both rankings should rank first"
);
assert!(result[0].score > result[1].score);
}
// Exact expected values with RRF_K = 60: 2/61 for rank 1 in both, 2/62 for rank 2 in both.
#[test]
fn test_rrf_score_calculation_precision() {
let bm25 = vec![(0, 1.0), (1, 0.5), (2, 0.3)];
let tfidf = vec![(0, 0.9), (1, 0.4), (2, 0.2)];
let result = reciprocal_rank_fusion(&bm25, &tfidf, 10);
let doc0 = result.iter().find(|r| r.doc_idx == 0).unwrap();
let expected_doc0 = 2.0 / 61.0;
assert!(
(doc0.score - expected_doc0).abs() < 0.00001,
"Doc 0 score precision: {} vs {}",
doc0.score,
expected_doc0
);
let doc1 = result.iter().find(|r| r.doc_idx == 1).unwrap();
let expected_doc1 = 2.0 / 62.0;
assert!(
(doc1.score - expected_doc1).abs() < 0.00001,
"Doc 1 score precision: {} vs {}",
doc1.score,
expected_doc1
);
}
// Rank fields are 1-based positions, with 0 marking absence from a ranking.
#[test]
fn test_rrf_rank_fields_populated() {
let bm25 = vec![(0, 1.0), (1, 0.5)];
let tfidf = vec![(1, 0.9), (2, 0.4)];
let result = reciprocal_rank_fusion(&bm25, &tfidf, 10);
let doc0 = result.iter().find(|r| r.doc_idx == 0).unwrap();
assert_eq!(doc0.bm25_rank, 1, "Doc 0 should be rank 1 in BM25");
assert_eq!(doc0.tfidf_rank, 0, "Doc 0 should not be in TF-IDF");
let doc1 = result.iter().find(|r| r.doc_idx == 1).unwrap();
assert_eq!(doc1.bm25_rank, 2, "Doc 1 should be rank 2 in BM25");
assert_eq!(doc1.tfidf_rank, 1, "Doc 1 should be rank 1 in TF-IDF");
let doc2 = result.iter().find(|r| r.doc_idx == 2).unwrap();
assert_eq!(doc2.bm25_rank, 0, "Doc 2 should not be in BM25");
assert_eq!(doc2.tfidf_rank, 2, "Doc 2 should be rank 2 in TF-IDF");
}
// Rank 100 in a single ranking scores 1 / (RRF_K + 100).
#[test]
fn test_rrf_large_rank_values() {
let bm25: Vec<(usize, f64)> = (0..100).map(|i| (i, 1.0 / (i as f64 + 1.0))).collect();
let tfidf: Vec<(usize, f64)> = vec![];
let result = reciprocal_rank_fusion(&bm25, &tfidf, 100);
let doc99 = result.iter().find(|r| r.doc_idx == 99).unwrap();
let expected = 1.0 / (RRF_K + 100.0);
assert!(
(doc99.score - expected).abs() < 0.00001,
"Large rank calculation: {} vs {}",
doc99.score,
expected
);
}
// Output scores are non-increasing even when the two rankings are reversed.
#[test]
fn test_rrf_descending_order_guaranteed() {
let bm25 = vec![(0, 1.0), (1, 0.9), (2, 0.8), (3, 0.7), (4, 0.6)];
let tfidf = vec![(4, 1.0), (3, 0.9), (2, 0.8), (1, 0.7), (0, 0.6)];
let result = reciprocal_rank_fusion(&bm25, &tfidf, 10);
for i in 1..result.len() {
assert!(
result[i - 1].score >= result[i].score,
"Results should be in descending order: {} >= {}",
result[i - 1].score,
result[i].score
);
}
}
// A document listed twice in one ranking appears once in the output and keeps the rank of its last occurrence.
#[test]
fn test_rrf_duplicate_doc_in_same_ranking() {
let bm25 = vec![(0, 1.0), (0, 0.5)]; let tfidf = vec![(1, 0.9)];
let result = reciprocal_rank_fusion(&bm25, &tfidf, 10);
let doc0_count = result.iter().filter(|r| r.doc_idx == 0).count();
assert_eq!(doc0_count, 1, "Duplicate should be deduplicated");
let doc0 = result.iter().find(|r| r.doc_idx == 0).unwrap();
assert_eq!(doc0.bm25_rank, 2, "Should use last rank for duplicates");
}
// A one-document corpus still produces a positive score for a matching term.
#[test]
fn test_bm25_single_document_corpus() {
let mut scorer = Bm25Scorer::new();
let docs = vec!["only document in corpus"];
scorer.fit(&docs).unwrap();
assert_eq!(scorer.num_docs(), 1);
assert!(scorer.avg_doc_len() > 0.0);
let scores = scorer.score("only");
assert_eq!(scores.len(), 1);
assert!(scores[0].1 > 0.0);
}
// An empty query scores every document as exactly 0.
#[test]
fn test_bm25_empty_query() {
let mut scorer = Bm25Scorer::new();
let docs = vec!["document one", "document two"];
scorer.fit(&docs).unwrap();
let scores = scorer.score("");
assert_eq!(scores.len(), 2);
assert_eq!(scores[0].1, 0.0);
assert_eq!(scores[1].1, 0.0);
}
// Query terms absent from the corpus contribute nothing.
#[test]
fn test_bm25_query_term_not_in_corpus() {
let mut scorer = Bm25Scorer::new();
let docs = vec!["apple banana cherry", "dog elephant fox"];
scorer.fit(&docs).unwrap();
let scores = scorer.score("zebra xyz unknown");
for (_, score) in &scores {
assert_eq!(*score, 0.0, "Unknown terms should produce zero score");
}
}
// With the same term frequency, the shorter document scores higher.
#[test]
fn test_bm25_document_length_normalization() {
let mut scorer = Bm25Scorer::new();
let docs = vec![
"target",
"target word word word word word word word word word word word",
];
scorer.fit(&docs).unwrap();
let scores = scorer.score("target");
let short_score = scores.iter().find(|(idx, _)| *idx == 0).unwrap().1;
let long_score = scores.iter().find(|(idx, _)| *idx == 1).unwrap().1;
assert!(
short_score > long_score,
"Short doc should score higher: {} vs {}",
short_score,
long_score
);
}
// A document repeating the term scores above one that contains it once.
// NOTE(review): despite the name, only the ordering is asserted, not that the gain levels off.
#[test]
fn test_bm25_term_frequency_saturation() {
let mut scorer = Bm25Scorer::new();
let docs = vec![
"word word word word word word word word word word",
"word other text here different content various",
];
scorer.fit(&docs).unwrap();
let scores = scorer.score("word");
let high_tf_score = scores.iter().find(|(idx, _)| *idx == 0).unwrap().1;
let low_tf_score = scores.iter().find(|(idx, _)| *idx == 1).unwrap().1;
assert!(
high_tf_score > low_tf_score,
"High TF doc should score higher"
);
}
// Refitting replaces the document count.
// NOTE(review): only `num_docs` is checked; stale term statistics would go unnoticed.
#[test]
fn test_bm25_refit_clears_state() {
let mut scorer = Bm25Scorer::new();
scorer.fit(&["doc one", "doc two"]).unwrap();
assert_eq!(scorer.num_docs(), 2);
scorer.fit(&["new doc", "another", "third"]).unwrap();
assert_eq!(scorer.num_docs(), 3);
}
// A one-document corpus returns that document even when top_k is larger.
#[test]
fn test_hybrid_retriever_single_doc_corpus() {
let mut retriever = HybridRetriever::new();
retriever.fit(&["single document"]).unwrap();
let results = retriever.query("single", 5).unwrap();
assert_eq!(results.len(), 1);
assert_eq!(results[0].0, "single document");
}
// A query with no matching terms still returns documents, each with a small fused score.
#[test]
fn test_hybrid_retriever_query_not_matching() {
let mut retriever = HybridRetriever::new();
retriever.fit(&["apple banana", "cherry date"]).unwrap();
let results = retriever.query("zebra xyz unknown", 5).unwrap();
assert!(!results.is_empty());
for (_, rrf) in &results {
assert!(
rrf.score < 0.05,
"Non-matching query should have low scores"
);
}
}
// `Default` gives an unfitted retriever.
#[test]
fn test_hybrid_retriever_default_trait() {
let retriever = HybridRetriever::default();
assert!(!retriever.is_fitted());
}
// `Default` gives an empty scorer.
#[test]
fn test_bm25_default_trait() {
let scorer = Bm25Scorer::default();
assert_eq!(scorer.num_docs(), 0);
}
// `Clone` copies every field of `RrfResult`.
#[test]
fn test_rrf_result_clone() {
let result = RrfResult {
doc_idx: 42,
score: 0.5,
bm25_rank: 1,
tfidf_rank: 2,
};
let cloned = result.clone();
assert_eq!(cloned.doc_idx, 42);
assert_eq!(cloned.score, 0.5);
assert_eq!(cloned.bm25_rank, 1);
assert_eq!(cloned.tfidf_rank, 2);
}
// The derived `Debug` output names the fields.
#[test]
fn test_rrf_result_debug() {
let result = RrfResult {
doc_idx: 1,
score: 0.033,
bm25_rank: 1,
tfidf_rank: 1,
};
let debug_str = format!("{:?}", result);
assert!(debug_str.contains("doc_idx"));
assert!(debug_str.contains("score"));
}
}