/// Identifies the ranking strategy that produced a score.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum RankingAlgorithm {
    /// Probabilistic ranking based on term frequency and document length.
    BM25,
    /// Term-importance weighting based on frequency and rarity.
    TFIDF,
    /// Similarity between query and document vectors.
    CosineSimilarity,
    /// Weighted combination of multiple ranking signals.
    LinearCombination,
    /// Machine-learning-based ranking model.
    LearningToRank,
    /// Combination of multiple ranked lists.
    ReciprocalRankFusion,
}

impl RankingAlgorithm {
    /// Returns the human-readable display name of this algorithm.
    pub fn name(&self) -> &'static str {
        match *self {
            Self::BM25 => "BM25",
            Self::TFIDF => "TF-IDF",
            Self::CosineSimilarity => "Cosine Similarity",
            Self::LinearCombination => "Linear Combination",
            Self::LearningToRank => "Learning to Rank",
            Self::ReciprocalRankFusion => "Reciprocal Rank Fusion",
        }
    }

    /// Returns a one-line summary of how this algorithm ranks documents.
    pub fn description(&self) -> &'static str {
        match *self {
            Self::BM25 => {
                "Probabilistic ranking function based on term frequency and document length"
            }
            Self::TFIDF => "Statistical measure of term importance based on frequency and rarity",
            Self::CosineSimilarity => "Similarity measure between query and document vectors",
            Self::LinearCombination => "Weighted combination of multiple ranking signals",
            Self::LearningToRank => "Machine learning-based ranking model",
            Self::ReciprocalRankFusion => "Method for combining multiple ranked lists",
        }
    }
}
/// A relevance score produced by one ranking algorithm, together with the
/// individual signals behind it, a confidence factor, and an optional explanation.
#[derive(Debug, Clone)]
pub struct RankingScore {
    /// Raw score reported by the algorithm.
    pub score: f64,
    /// Algorithm that produced `score`.
    pub algorithm: RankingAlgorithm,
    /// Signals appended via `with_signal`, in insertion order.
    pub signal_scores: Vec<SignalScore>,
    /// Multiplier used by `weighted_score`; `new` sets it to `1.0`.
    pub confidence: f64,
    /// Optional free-form explanation of the score.
    pub explanation: Option<String>,
}

impl RankingScore {
    /// Creates a score with no signals, confidence `1.0`, and no explanation.
    pub fn new(score: f64, algorithm: RankingAlgorithm) -> Self {
        Self {
            algorithm,
            score,
            confidence: 1.0,
            explanation: None,
            signal_scores: Vec::new(),
        }
    }

    /// Appends a named signal (score and weight stored as given) and returns `self`.
    pub fn with_signal(mut self, name: String, score: f64, weight: f64) -> Self {
        let signal = SignalScore { name, score, weight };
        self.signal_scores.push(signal);
        self
    }

    /// Replaces the confidence factor; the value is stored unchanged (no clamping).
    pub fn with_confidence(self, confidence: f64) -> Self {
        Self { confidence, ..self }
    }

    /// Attaches an explanation, replacing any previous one.
    pub fn with_explanation(self, explanation: String) -> Self {
        Self {
            explanation: Some(explanation),
            ..self
        }
    }

    /// Returns the raw score scaled by the confidence factor.
    pub fn weighted_score(&self) -> f64 {
        let &Self {
            score, confidence, ..
        } = self;
        score * confidence
    }

    /// Returns `true` when the raw score is greater than or equal to `threshold`.
    pub fn is_relevant(&self, threshold: f64) -> bool {
        threshold <= self.score
    }
}
/// One named signal contributing to a ranking score.
#[derive(Debug, Clone)]
pub struct SignalScore {
    /// Signal identifier, e.g. `"term_match"` or `"semantic"`.
    pub name: String,
    /// Score reported by this signal.
    pub score: f64,
    /// Weight attached to this signal; stored as given (not normalized here).
    pub weight: f64,
}
/// Tuning parameters for [`bm25_score`].
#[derive(Debug, Clone)]
pub struct BM25Params {
    /// Term-frequency saturation: in `tf * (k1 + 1) / (tf + k1 * norm)`, larger
    /// values let repeated occurrences keep adding score for longer.
    pub k1: f64,
    /// Length-normalization strength (conventionally in `[0, 1]`); `0.0` disables
    /// document-length normalization entirely.
    pub b: f64,
    /// Lower bound applied to the IDF component.
    pub min_idf: f64,
}

impl Default for BM25Params {
    /// `k1 = 1.5`, `b = 0.75`, `min_idf = 0.0`.
    fn default() -> Self {
        Self {
            k1: 1.5,
            b: 0.75,
            min_idf: 0.0,
        }
    }
}

/// Computes the BM25 contribution of a single query term to one document.
///
/// * IDF: `ln(1 + (N - df + 0.5) / (df + 0.5))`, floored at `params.min_idf`.
/// * TF:  `tf * (k1 + 1) / (tf + k1 * (1 - b + b * dl / avgdl))`.
///
/// # Arguments
/// * `term_freq` – occurrences of the term in the document (`tf`).
/// * `doc_length` – length of the document (`dl`).
/// * `avg_doc_length` – mean document length in the collection (`avgdl`). A
///   non-positive value (e.g. an empty collection) is treated as "this document
///   has average length" instead of dividing by zero.
/// * `num_docs` – number of documents in the collection (`N`).
/// * `doc_freq` – number of documents containing the term (`df`).
/// * `params` – tuning parameters.
///
/// Returns `0.0` when the term does not occur (`term_freq <= 0.0`).
pub fn bm25_score(
    term_freq: f64,
    doc_length: f64,
    avg_doc_length: f64,
    num_docs: f64,
    doc_freq: f64,
    params: &BM25Params,
) -> f64 {
    // An absent term contributes nothing. Returning early also avoids a 0/0 NaN
    // in the TF part when `b == 1.0` and `doc_length == 0.0`.
    if term_freq <= 0.0 {
        return 0.0;
    }
    let idf = ((num_docs - doc_freq + 0.5) / (doc_freq + 0.5) + 1.0)
        .ln()
        .max(params.min_idf);
    // Dividing by a zero average length would yield inf/NaN and poison any
    // later sort of the scores; fall back to a neutral ratio instead.
    let length_ratio = if avg_doc_length > 0.0 {
        doc_length / avg_doc_length
    } else {
        1.0
    };
    let norm_length = 1.0 - params.b + params.b * length_ratio;
    let tf = (term_freq * (params.k1 + 1.0)) / (term_freq + params.k1 * norm_length);
    idf * tf
}
/// Computes a raw TF-IDF weight: `term_freq * max(0, ln(num_docs / (doc_freq + 1)))`.
///
/// The `+ 1` in the denominator avoids division by zero for unseen terms. The
/// IDF factor is floored at `0.0`:
/// * a term found in (nearly) every document, where `doc_freq + 1 > num_docs`,
///   contributes nothing instead of a negative weight;
/// * an empty collection (`num_docs == 0`) yields `0.0` rather than `-inf`
///   (or `NaN` when `term_freq == 0`).
pub fn tfidf_score(term_freq: f64, doc_freq: f64, num_docs: f64) -> f64 {
    let idf = (num_docs / (doc_freq + 1.0)).ln().max(0.0);
    term_freq * idf
}
/// Returns the weighted mean `Σ(score · weight) / Σ(weight)` of `(score, weight)` pairs.
///
/// If the summed weight is below `1e-10` (including an empty slice), returns
/// `0.0` instead of dividing by a (near-)zero denominator.
pub fn linear_combination(scores: &[(f64, f64)]) -> f64 {
    let total_weight: f64 = scores.iter().map(|&(_, weight)| weight).sum();
    // A vanishing total weight carries no information; avoid dividing by ~0.
    if total_weight < 1e-10 {
        0.0
    } else {
        let weighted_total: f64 = scores
            .iter()
            .map(|&(score, weight)| score * weight)
            .sum();
        weighted_total / total_weight
    }
}
/// Fuses several ranked lists with Reciprocal Rank Fusion.
///
/// Every appearance of a document at 0-based position `rank` in any list adds
/// `1 / (k + rank + 1)` to that document's fused score. A document listed more
/// than once in the same list is credited once per appearance.
///
/// Returns every distinct document id with its fused score, sorted by
/// descending score. Equal scores are ordered by ascending document id, so the
/// output is deterministic. Without the tie-break it would follow the
/// randomized `HashMap` iteration order, and ties are common: a document ranked
/// 1st/2nd in one list and 2nd/1st in another scores exactly the same.
pub fn reciprocal_rank_fusion(ranked_lists: &[Vec<String>], k: f64) -> Vec<(String, f64)> {
    use std::collections::HashMap;
    let mut scores: HashMap<String, f64> = HashMap::new();
    for ranked_list in ranked_lists {
        for (rank, doc_id) in ranked_list.iter().enumerate() {
            let rrf_score = 1.0 / (k + (rank as f64 + 1.0));
            *scores.entry(doc_id.clone()).or_insert(0.0) += rrf_score;
        }
    }
    let mut result: Vec<(String, f64)> = scores.into_iter().collect();
    // `total_cmp` is a total order even if a degenerate `k` produced NaN, so the
    // comparator never violates `sort`'s contract. Ids are unique map keys, so
    // the id tie-break makes an unstable sort fully deterministic.
    result.sort_unstable_by(|a, b| b.1.total_cmp(&a.1).then_with(|| a.0.cmp(&b.0)));
    result
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_bm25_score() {
        let params = BM25Params::default();
        let score = bm25_score(5.0, 100.0, 80.0, 1000.0, 10.0, &params);
        assert!(score > 0.0);
    }

    #[test]
    fn test_tfidf_score() {
        let score = tfidf_score(3.0, 10.0, 100.0);
        assert!(score > 0.0);
    }

    #[test]
    fn test_linear_combination() {
        let scores = vec![(0.8, 0.5), (0.6, 0.3), (0.9, 0.2)];
        let combined = linear_combination(&scores);
        assert!(combined > 0.0 && combined < 1.0);
        // (0.8*0.5 + 0.6*0.3 + 0.9*0.2) / (0.5 + 0.3 + 0.2) = 0.76
        assert!((combined - 0.76).abs() < 1e-12);
    }

    #[test]
    fn test_reciprocal_rank_fusion() {
        let list1 = vec!["doc1".to_string(), "doc2".to_string(), "doc3".to_string()];
        let list2 = vec!["doc2".to_string(), "doc1".to_string(), "doc4".to_string()];
        let result = reciprocal_rank_fusion(&[list1, list2], 60.0);
        assert!(!result.is_empty());
        assert!(result[0].1 > 0.0);
    }

    #[test]
    fn test_ranking_score() {
        let score = RankingScore::new(0.85, RankingAlgorithm::BM25)
            .with_signal("term_match".to_string(), 0.9, 0.6)
            .with_signal("semantic".to_string(), 0.8, 0.4)
            .with_confidence(0.95);
        assert_eq!(score.score, 0.85);
        assert_eq!(score.signal_scores.len(), 2);
        assert_eq!(score.confidence, 0.95);
        assert!((score.weighted_score() - 0.8075).abs() < 1e-6);
    }

    #[test]
    fn test_ranking_algorithm_name() {
        assert_eq!(RankingAlgorithm::BM25.name(), "BM25");
        assert_eq!(RankingAlgorithm::TFIDF.name(), "TF-IDF");
    }
}