use std::collections::HashSet;
use std::hash::Hash;
/// Recall of `results` against `ground_truth`: the fraction of distinct relevant ids that
/// were retrieved.
///
/// `k` is implied by the length of `results` (callers pass the top-k list). Both inputs are
/// treated as sets: a relevant id returned several times is counted once, and duplicate
/// entries in `ground_truth` do not inflate the denominator. The result is therefore always
/// in `[0.0, 1.0]`.
///
/// Returns `0.0` when `ground_truth` is empty.
#[must_use]
#[allow(clippy::cast_precision_loss)] // counts are far below 2^52; the f64 ratio is exact enough
pub fn recall_at_k<T: Eq + Hash + Copy>(ground_truth: &[T], results: &[T]) -> f64 {
    if ground_truth.is_empty() {
        return 0.0;
    }
    let truth_set: HashSet<T> = ground_truth.iter().copied().collect();
    // Collect hits into a set so that repeated results cannot push recall above 1.0.
    // `copied()` first so the closure sees `&T`, which is what `HashSet::contains` expects.
    let found: HashSet<T> = results
        .iter()
        .copied()
        .filter(|id| truth_set.contains(id))
        .collect();
    found.len() as f64 / truth_set.len() as f64
}
/// Precision of `results` against `ground_truth`: the fraction of returned positions that
/// hold a relevant id.
///
/// `k` is implied by the length of `results`. Every position is judged independently, so a
/// relevant id returned twice counts twice; the denominator is always `results.len()`.
///
/// Returns `0.0` when `results` is empty.
#[must_use]
#[allow(clippy::cast_precision_loss)] // counts are far below 2^52; the f64 ratio is exact enough
pub fn precision_at_k<T: Eq + Hash + Copy>(ground_truth: &[T], results: &[T]) -> f64 {
    if results.is_empty() {
        return 0.0;
    }
    let truth_set: HashSet<T> = ground_truth.iter().copied().collect();
    // `filter` hands the closure `&&T`; destructure it so `contains` receives `&T`
    // (passing `&&T` would require `T: Borrow<&T>`, which does not hold).
    let relevant = results
        .iter()
        .filter(|&id| truth_set.contains(id))
        .count();
    relevant as f64 / results.len() as f64
}
/// Reciprocal rank of the first relevant id in `results`.
///
/// Scans `results` front to back and yields `1 / rank` (ranks are 1-based) for the first id
/// that also occurs in `ground_truth`. Yields `0.0` when no returned id is relevant, which
/// includes the cases where either slice is empty.
#[must_use]
#[allow(clippy::cast_precision_loss)] // ranks are far below 2^52
pub fn mrr<T: Eq + Hash + Copy>(ground_truth: &[T], results: &[T]) -> f64 {
    let relevant: HashSet<T> = ground_truth.iter().copied().collect();
    results
        .iter()
        .position(|candidate| relevant.contains(candidate))
        .map_or(0.0, |index| 1.0 / (index + 1) as f64)
}
/// Mean `(recall, precision, mrr)` over paired queries.
///
/// `ground_truths[i]` is paired with `results_list[i]` and scored with [`recall_at_k`],
/// [`precision_at_k`] and [`mrr`]. If the two slices differ in length, the extra entries of
/// the longer one are ignored and the means are taken over the shorter length.
///
/// Returns `(0.0, 0.0, 0.0)` when either slice is empty.
#[must_use]
pub fn average_metrics<T: Eq + Hash + Copy>(
    ground_truths: &[Vec<T>],
    results_list: &[Vec<T>],
) -> (f64, f64, f64) {
    // The shorter slice is empty exactly when either input is empty.
    let pairs = ground_truths.len().min(results_list.len());
    if pairs == 0 {
        return (0.0, 0.0, 0.0);
    }

    // `zip` stops at the shorter slice, so exactly `pairs` queries are accumulated.
    let (recall_sum, precision_sum, mrr_sum) = ground_truths
        .iter()
        .zip(results_list)
        .fold((0.0, 0.0, 0.0), |(r, p, m), (truth, retrieved)| {
            (
                r + recall_at_k(truth, retrieved),
                p + precision_at_k(truth, retrieved),
                m + mrr(truth, retrieved),
            )
        });

    #[allow(clippy::cast_precision_loss)]
    let count = pairs as f64;
    (recall_sum / count, precision_sum / count, mrr_sum / count)
}
/// Discounted cumulative gain of the first `k` entries of `relevances`.
///
/// Each entry at 0-based position `i` contributes `(2^rel - 1) / log2(i + 2)`, so the first
/// position is undiscounted (`log2(2) == 1`). `k` larger than the slice is fine: `take` stops
/// at the end.
#[allow(clippy::cast_precision_loss)] // positions are far below 2^52
fn compute_dcg(relevances: &[f64], k: usize) -> f64 {
    relevances
        .iter()
        .take(k)
        .enumerate()
        .map(|(i, &rel)| {
            let gain = 2.0_f64.powf(rel) - 1.0;
            let discount = (i as f64 + 2.0).log2();
            gain / discount
        })
        .sum()
}

/// Normalized DCG of the ranking `relevances` (graded relevance per returned position),
/// truncated at `k`.
///
/// The ideal DCG is computed from the same relevance values sorted in descending order, so
/// the score is `1.0` when the list is already ideally ordered. Relevances are presumably
/// non-negative grades; negative values give a negative gain and can make the ratio
/// meaningless.
///
/// Returns `0.0` for an empty slice, for `k == 0`, or when the ideal DCG is zero (no relevant
/// entry). A NaN relevance propagates to a NaN result.
#[must_use]
pub fn ndcg_at_k(relevances: &[f64], k: usize) -> f64 {
    if relevances.is_empty() || k == 0 {
        return 0.0;
    }
    let dcg = compute_dcg(relevances, k);

    let mut ideal = relevances.to_vec();
    // `total_cmp` is a genuine total order (NaN included). A comparator built from
    // `partial_cmp(..).unwrap_or(Equal)` is not, and the std sorts may panic on such input.
    ideal.sort_unstable_by(|a, b| b.total_cmp(a));
    let idcg = compute_dcg(&ideal, k);

    if idcg == 0.0 {
        return 0.0;
    }
    dcg / idcg
}
/// Fraction of queries that have at least one relevant id among their first `k` results.
///
/// Each element of `query_results` is a `(ground_truth, results)` pair. A query is a hit when
/// any entry of `results[..k]` (or of all of `results`, if it is shorter than `k`) also occurs
/// in `ground_truth`. With `k == 0` no query can be a hit.
///
/// Returns `0.0` when `query_results` is empty.
#[must_use]
#[allow(clippy::cast_precision_loss)] // query counts are far below 2^52
pub fn hit_rate<T: Eq + Hash + Copy>(query_results: &[(Vec<T>, Vec<T>)], k: usize) -> f64 {
    if query_results.is_empty() {
        return 0.0;
    }
    let mut hits = 0_usize;
    for (ground_truth, results) in query_results {
        let relevant: HashSet<T> = ground_truth.iter().copied().collect();
        // Clamp so the slice never goes past the end of a short result list.
        let top_k = &results[..k.min(results.len())];
        if top_k.iter().any(|candidate| relevant.contains(candidate)) {
            hits += 1;
        }
    }
    hits as f64 / query_results.len() as f64
}
/// Mean of the per-list average precision over `relevance_lists`.
///
/// Each inner list holds one relevance flag per ranked position. Its average precision is the
/// mean of precision@i taken at every relevant position i. It is normalized by the number of
/// relevant positions *in that list*, not by a corpus-wide relevant count. A list with no
/// relevant position scores `0.0`.
///
/// Returns `0.0` when `relevance_lists` is empty.
#[must_use]
pub fn mean_average_precision(relevance_lists: &[Vec<bool>]) -> f64 {
    /// Average precision of a single ranked list of relevance flags.
    fn average_precision(flags: &[bool]) -> f64 {
        let (hits, precision_sum) = flags
            .iter()
            .enumerate()
            .filter(|&(_, &is_relevant)| is_relevant)
            .fold((0_i32, 0.0_f64), |(hits, sum), (position, _)| {
                let hits = hits + 1;
                #[allow(clippy::cast_precision_loss)]
                let precision_at_position = f64::from(hits) / (position + 1) as f64;
                (hits, sum + precision_at_position)
            });
        if hits == 0 {
            0.0
        } else {
            precision_sum / f64::from(hits)
        }
    }

    if relevance_lists.is_empty() {
        return 0.0;
    }
    let total: f64 = relevance_lists
        .iter()
        .map(Vec::as_slice)
        .map(average_precision)
        .sum();
    #[allow(clippy::cast_precision_loss)]
    let list_count = relevance_lists.len() as f64;
    total / list_count
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Absolute tolerance for every floating-point comparison in this module.
    const EPS: f64 = 1e-9;

    /// Asserts that `actual` is within [`EPS`] of `expected`, reporting both values on failure.
    fn assert_close(actual: f64, expected: f64) {
        assert!(
            (actual - expected).abs() < EPS,
            "expected {expected}, got {actual}"
        );
    }

    #[test]
    fn test_recall_at_k_perfect() {
        let ground_truth = vec![1, 2, 3, 4, 5];
        let results = vec![1, 2, 3, 4, 5];
        assert_close(recall_at_k(&ground_truth, &results), 1.0);
    }

    #[test]
    fn test_recall_at_k_partial() {
        let ground_truth = vec![1, 2, 3, 4, 5];
        let results = vec![1, 3, 6, 2, 7];
        assert_close(recall_at_k(&ground_truth, &results), 0.6);
    }

    #[test]
    fn test_recall_at_k_empty_truth() {
        let ground_truth: Vec<u64> = vec![];
        let results = vec![1, 2, 3];
        assert_close(recall_at_k(&ground_truth, &results), 0.0);
    }

    #[test]
    fn test_precision_at_k_perfect() {
        let ground_truth = vec![1, 2, 3, 4, 5];
        let results = vec![1, 2, 3];
        assert_close(precision_at_k(&ground_truth, &results), 1.0);
    }

    #[test]
    fn test_precision_at_k_partial() {
        let ground_truth = vec![1, 2, 3];
        let results = vec![1, 4, 5, 6, 7];
        assert_close(precision_at_k(&ground_truth, &results), 0.2);
    }

    #[test]
    fn test_precision_at_k_empty_results() {
        let ground_truth = vec![1, 2, 3];
        let results: Vec<u64> = vec![];
        assert_close(precision_at_k(&ground_truth, &results), 0.0);
    }

    #[test]
    fn test_mrr_first_relevant() {
        let ground_truth = vec![1, 2, 3];
        let results = vec![1, 4, 5];
        assert_close(mrr(&ground_truth, &results), 1.0);
    }

    #[test]
    fn test_mrr_second_relevant() {
        let ground_truth = vec![1, 2, 3];
        let results = vec![4, 1, 5];
        assert_close(mrr(&ground_truth, &results), 0.5);
    }

    #[test]
    fn test_mrr_no_relevant() {
        let ground_truth = vec![1, 2, 3];
        let results = vec![4, 5, 6];
        assert_close(mrr(&ground_truth, &results), 0.0);
    }

    #[test]
    fn test_average_metrics_mixed() {
        // Query 1: recall 1/2, precision 1/2, rr 1. Query 2: recall 1, precision 1/2, rr 1/2.
        let ground_truths = vec![vec![1, 2], vec![3]];
        let results_list = vec![vec![1, 4], vec![5, 3]];
        let (recall, precision, rr) = average_metrics(&ground_truths, &results_list);
        assert_close(recall, 0.75);
        assert_close(precision, 0.5);
        assert_close(rr, 0.75);
    }

    #[test]
    fn test_average_metrics_truncates_to_shorter_input() {
        let ground_truths = vec![vec![1, 2], vec![3]];
        let results_list = vec![vec![1, 4]];
        let (recall, precision, rr) = average_metrics(&ground_truths, &results_list);
        assert_close(recall, 0.5);
        assert_close(precision, 0.5);
        assert_close(rr, 1.0);
    }

    #[test]
    fn test_average_metrics_empty() {
        let (recall, precision, rr) = average_metrics::<u64>(&[], &[]);
        assert_close(recall, 0.0);
        assert_close(precision, 0.0);
        assert_close(rr, 0.0);
    }

    #[test]
    fn test_ndcg_perfect_order() {
        assert_close(ndcg_at_k(&[3.0, 2.0, 1.0], 3), 1.0);
    }

    #[test]
    fn test_ndcg_truncated_imperfect_order() {
        // DCG@1 = 2^1 - 1 = 1; ideal DCG@1 = 2^2 - 1 = 3.
        assert_close(ndcg_at_k(&[1.0, 0.0, 2.0], 1), 1.0 / 3.0);
    }

    #[test]
    fn test_ndcg_degenerate_inputs() {
        assert_close(ndcg_at_k(&[], 5), 0.0);
        assert_close(ndcg_at_k(&[0.0, 0.0], 2), 0.0);
    }

    #[test]
    fn test_hit_rate() {
        let queries = vec![(vec![1, 2], vec![3, 1]), (vec![5], vec![6, 7])];
        assert_close(hit_rate(&queries, 2), 0.5);
        assert_close(hit_rate(&queries, 1), 0.0);
        assert_close(hit_rate::<u64>(&[], 3), 0.0);
    }

    #[test]
    fn test_mean_average_precision() {
        // AP of [rel, non-rel, rel] = (1/1 + 2/3) / 2 = 5/6; a list with no hits scores 0.
        assert_close(mean_average_precision(&[vec![true, false, true]]), 5.0 / 6.0);
        assert_close(
            mean_average_precision(&[vec![true, false, true], vec![false]]),
            5.0 / 12.0,
        );
        assert_close(mean_average_precision(&[]), 0.0);
    }
}