use std::collections::HashMap;
/// One entry of a ranked retrieval result list, as consumed by the metric functions in this
/// module. The position of the entry in the slice passed to those functions is its rank.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct RankedResult {
    /// Identifier of the retrieved item; none of the metrics in this module read it.
    pub id: i64,
    /// Binary relevance judgement used by precision, recall and reciprocal rank.
    pub relevant: bool,
    /// Graded relevance used as the (linear) gain in nDCG; a grade of `0` adds nothing.
    pub relevance_grade: u8,
}
/// A labelled relevance judgement for one retrievable item.
///
/// The field names suggest the item is a code chunk identified by id, file and symbol.
/// NOTE(review): this type is not read by any metric function in this module — confirm how
/// callers map it onto [`RankedResult`].
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct GroundTruthResult {
    /// Identifier of the judged item; presumably matches [`RankedResult::id`] — confirm.
    pub chunk_id: i64,
    /// Path of the file the judged item comes from.
    pub file_path: String,
    /// Symbol name associated with the judged item.
    pub symbol: String,
    /// Relevance label; likely on the same scale as [`RankedResult::relevance_grade`] —
    /// TODO confirm.
    pub relevance: u8,
    /// Free-text justification for the label.
    pub rationale: String,
}
/// Aggregated evaluation scores for one ranked result list.
///
/// The per-cutoff maps are keyed by the cutoff `k` at which the metric was computed.
#[derive(Debug, Clone, PartialEq, Default)]
pub struct EvaluationMetrics {
    /// Precision at each requested cutoff `k`.
    pub precision_at_k: HashMap<usize, f64>,
    /// Recall at each requested cutoff `k`.
    pub recall_at_k: HashMap<usize, f64>,
    /// Normalized discounted cumulative gain at each requested cutoff `k`.
    pub ndcg_at_k: HashMap<usize, f64>,
    /// Reciprocal rank of the first relevant result (independent of `k`).
    pub mrr: f64,
}
/// Precision at cutoff `k`: the fraction of the first `k` ranked results flagged `relevant`.
///
/// The denominator is always `k`, even when `results` holds fewer than `k` entries, so a
/// short list is penalised for its empty slots. Returns `0.0` when `k` is zero.
pub fn calculate_precision_at_k(results: &[RankedResult], k: usize) -> f64 {
    match k {
        0 => 0.0,
        cutoff => {
            let hits = results
                .iter()
                .take(cutoff)
                .fold(0usize, |acc, item| acc + usize::from(item.relevant));
            hits as f64 / cutoff as f64
        }
    }
}
/// Recall at cutoff `k`: relevant items among the first `k` results, divided by
/// `total_relevant` (the caller-supplied count of relevant items overall).
///
/// Returns `0.0` when `total_relevant` is zero. The value is not clamped, so it exceeds `1.0`
/// if more retrieved items are flagged `relevant` than `total_relevant` accounts for.
pub fn calculate_recall_at_k(results: &[RankedResult], k: usize, total_relevant: usize) -> f64 {
    if total_relevant > 0 {
        let window = &results[..k.min(results.len())];
        let found = window.iter().filter(|item| item.relevant).count();
        found as f64 / total_relevant as f64
    } else {
        0.0
    }
}
/// Normalized discounted cumulative gain over the first `k` results.
///
/// Each result's gain is its linear `relevance_grade`, discounted by `log2(rank + 1)` with a
/// 1-based rank. The ideal DCG comes from re-ranking the *same* `results` by descending grade,
/// so relevant items that were never retrieved do not lower the score.
///
/// Returns `0.0` when `k` is zero or when no result carries a non-zero grade.
pub fn calculate_ndcg_at_k(results: &[RankedResult], k: usize) -> f64 {
    if k == 0 {
        return 0.0;
    }
    let dcg = dcg_from_grades(results.iter().map(|r| r.relevance_grade), k);

    // Only the grades matter for the ideal ordering, so sort those instead of cloning whole
    // results. Equal `u8` grades are indistinguishable, so an unstable sort is sufficient.
    let mut ideal_grades: Vec<u8> = results.iter().map(|r| r.relevance_grade).collect();
    ideal_grades.sort_unstable_by(|a, b| b.cmp(a));
    let idcg = dcg_from_grades(ideal_grades, k);

    // Grades are non-negative, so `idcg` is either positive or exactly zero.
    if idcg > 0.0 {
        dcg / idcg
    } else {
        0.0
    }
}

/// Sums `grade / log2(index + 2)` over the first `k` grades (0-based `index`).
fn dcg_from_grades(grades: impl IntoIterator<Item = u8>, k: usize) -> f64 {
    grades
        .into_iter()
        .take(k)
        .enumerate()
        .map(|(index, grade)| f64::from(grade) / ((index + 2) as f64).log2())
        .sum()
}
/// Reciprocal rank of the first result flagged `relevant`: `1 / rank` with a 1-based rank.
///
/// This scores a single ranked list; averaging it across queries gives the mean reciprocal
/// rank. Returns `0.0` when no result is relevant, including for an empty slice.
pub fn calculate_mrr(results: &[RankedResult]) -> f64 {
    results
        .iter()
        .position(|candidate| candidate.relevant)
        .map_or(0.0, |index| 1.0 / (index + 1) as f64)
}
/// Evaluates one ranked list at every cutoff in `k_values` and bundles the scores.
///
/// Each cutoff gets a precision, recall (against `total_relevant`) and nDCG entry keyed by `k`;
/// a cutoff listed twice just maps to the same key. Reciprocal rank does not depend on `k` and
/// is computed once over the full list.
pub fn calculate_all_metrics(
    results: &[RankedResult],
    total_relevant: usize,
    k_values: &[usize],
) -> EvaluationMetrics {
    let precision_at_k: HashMap<usize, f64> = k_values
        .iter()
        .map(|&cutoff| (cutoff, calculate_precision_at_k(results, cutoff)))
        .collect();
    let recall_at_k: HashMap<usize, f64> = k_values
        .iter()
        .map(|&cutoff| (cutoff, calculate_recall_at_k(results, cutoff, total_relevant)))
        .collect();
    let ndcg_at_k: HashMap<usize, f64> = k_values
        .iter()
        .map(|&cutoff| (cutoff, calculate_ndcg_at_k(results, cutoff)))
        .collect();

    EvaluationMetrics {
        precision_at_k,
        recall_at_k,
        ndcg_at_k,
        mrr: calculate_mrr(results),
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a `RankedResult` without struct-literal boilerplate.
    fn ranked(id: i64, relevant: bool, relevance_grade: u8) -> RankedResult {
        RankedResult {
            id,
            relevant,
            relevance_grade,
        }
    }

    /// Five-item ranking with grades [3, 2, 0, 2, 0]; items 1, 2 and 4 are relevant.
    fn create_test_results() -> Vec<RankedResult> {
        vec![
            ranked(1, true, 3),
            ranked(2, true, 2),
            ranked(3, false, 0),
            ranked(4, true, 2),
            ranked(5, false, 0),
        ]
    }

    /// Asserts two floats agree to within a tight absolute tolerance.
    fn assert_close(actual: f64, expected: f64) {
        assert!(
            (actual - expected).abs() < 1e-10,
            "expected {expected}, got {actual}"
        );
    }

    #[test]
    fn test_precision_at_k() {
        let results = create_test_results();
        assert_eq!(calculate_precision_at_k(&results, 1), 1.0);
        assert_close(calculate_precision_at_k(&results, 3), 2.0 / 3.0);
        assert_eq!(calculate_precision_at_k(&results, 5), 0.6);
        // The denominator stays `k` even when fewer than `k` results exist.
        assert_close(calculate_precision_at_k(&results, 10), 0.3);
        assert_eq!(calculate_precision_at_k(&results, 0), 0.0);
    }

    #[test]
    fn test_recall_at_k() {
        let results = create_test_results();
        // Five relevant items exist overall, but only three of them were retrieved.
        let total_relevant = 5;
        assert_eq!(calculate_recall_at_k(&results, 1, total_relevant), 0.2);
        assert_eq!(calculate_recall_at_k(&results, 3, total_relevant), 0.4);
        assert_eq!(calculate_recall_at_k(&results, 5, total_relevant), 0.6);
        assert_eq!(calculate_recall_at_k(&results, 3, 0), 0.0);
    }

    #[test]
    fn test_ndcg_at_k() {
        let results = create_test_results();
        // Grades [3, 2, 0, 2, 0]; ideal order [3, 2, 2, 0, 0]; discount is log2(rank + 1).
        let log2_3 = 3f64.log2();
        let ideal_3 = 3.0 + 2.0 / log2_3 + 2.0 / 4f64.log2();
        assert_close(calculate_ndcg_at_k(&results, 3), (3.0 + 2.0 / log2_3) / ideal_3);
        assert_close(
            calculate_ndcg_at_k(&results, 5),
            (3.0 + 2.0 / log2_3 + 2.0 / 5f64.log2()) / ideal_3,
        );

        let perfect_results = [ranked(1, true, 3), ranked(2, true, 2), ranked(3, true, 1)];
        assert_close(calculate_ndcg_at_k(&perfect_results, 3), 1.0);
        assert_eq!(calculate_ndcg_at_k(&results, 0), 0.0);
    }

    #[test]
    fn test_mrr() {
        let results = create_test_results();
        assert_eq!(calculate_mrr(&results), 1.0);

        let results2 = [ranked(1, false, 0), ranked(2, false, 0), ranked(3, true, 2)];
        assert_close(calculate_mrr(&results2), 1.0 / 3.0);

        let no_relevant = [ranked(1, false, 0)];
        assert_eq!(calculate_mrr(&no_relevant), 0.0);
    }

    #[test]
    fn test_calculate_all_metrics() {
        let results = create_test_results();
        let k_values = [1, 3, 5];
        let metrics = calculate_all_metrics(&results, 5, &k_values);

        // Every requested cutoff is present and agrees with the standalone functions.
        for &k in &k_values {
            assert_eq!(metrics.precision_at_k[&k], calculate_precision_at_k(&results, k));
            assert_eq!(metrics.recall_at_k[&k], calculate_recall_at_k(&results, k, 5));
            assert_eq!(metrics.ndcg_at_k[&k], calculate_ndcg_at_k(&results, k));
        }
        assert_eq!(metrics.precision_at_k.len(), k_values.len());
        assert_eq!(metrics.recall_at_k.len(), k_values.len());
        assert_eq!(metrics.ndcg_at_k.len(), k_values.len());

        assert_eq!(metrics.precision_at_k[&1], 1.0);
        assert_eq!(metrics.mrr, 1.0);
    }

    #[test]
    fn test_edge_cases() {
        let empty: &[RankedResult] = &[];
        assert_eq!(calculate_precision_at_k(empty, 10), 0.0);
        assert_eq!(calculate_recall_at_k(empty, 10, 5), 0.0);
        assert_eq!(calculate_ndcg_at_k(empty, 10), 0.0);
        assert_eq!(calculate_mrr(empty), 0.0);

        let all_irrelevant = [ranked(1, false, 0), ranked(2, false, 0)];
        assert_eq!(calculate_precision_at_k(&all_irrelevant, 2), 0.0);
        assert_eq!(calculate_ndcg_at_k(&all_irrelevant, 2), 0.0);
        assert_eq!(calculate_mrr(&all_irrelevant), 0.0);
    }
}