use std::collections::HashSet;
/// Computes recall@k: the fraction of the top-`k` ground-truth ids that also
/// appear among the top-`k` retrieved ids.
///
/// Only the first `k` entries of each slice are considered, and duplicate ids
/// are collapsed before comparison. The denominator is the number of distinct
/// ground-truth ids actually considered, so a perfect retrieval scores `1.0`
/// even when `ground_truth` holds fewer than `k` ids.
///
/// # Returns
///
/// A value in `[0.0, 1.0]`; `0.0` when `k == 0` or `ground_truth` is empty.
pub fn recall_at_k(ground_truth: &[u32], retrieved: &[u32], k: usize) -> f32 {
    if k == 0 || ground_truth.is_empty() {
        return 0.0;
    }

    let relevant: HashSet<u32> = ground_truth.iter().take(k).copied().collect();
    let candidates: HashSet<u32> = retrieved.iter().take(k).copied().collect();
    let hits = relevant.intersection(&candidates).count();

    // `relevant` is non-empty here (non-empty input, k >= 1). Dividing by its
    // size instead of `k` keeps recall attainable when fewer than `k`
    // ground-truth ids exist.
    hits as f32 / relevant.len() as f32
}
/// Computes precision@k: the fraction of the top-`k` retrieved ids that are
/// present among the top-`k` ground-truth ids.
///
/// Only the first `k` entries of each slice are considered. The denominator is
/// the number of retrieved ids actually inspected (`min(k, retrieved.len())`),
/// and a retrieved id that appears more than once is counted each time.
///
/// # Returns
///
/// A value in `[0.0, 1.0]`; `0.0` when `k == 0` or `retrieved` is empty.
pub fn precision_at_k(ground_truth: &[u32], retrieved: &[u32], k: usize) -> f32 {
    // Without the `k == 0` guard the function would divide 0 by 0 and return
    // NaN; returning 0.0 matches how `recall_at_k` treats the same input.
    if k == 0 || retrieved.is_empty() {
        return 0.0;
    }

    let relevant: HashSet<u32> = ground_truth.iter().take(k).copied().collect();
    let considered = &retrieved[..k.min(retrieved.len())];
    let hits = considered
        .iter()
        .copied()
        .filter(|id| relevant.contains(id))
        .count();

    hits as f32 / considered.len() as f32
}
/// Averages `recall_at_k` over a batch of queries.
///
/// `ground_truths[i]` is paired with `retrievals[i]`. Pairing stops at the
/// shorter of the two slices, but the sum is always divided by
/// `ground_truths.len()`, so a query without a matching retrieval list
/// contributes zero to the mean.
///
/// Returns `0.0` when `ground_truths` is empty.
pub fn mean_recall(ground_truths: &[Vec<u32>], retrievals: &[Vec<u32>], k: usize) -> f32 {
    let n_queries = ground_truths.len();
    if n_queries == 0 {
        return 0.0;
    }

    let per_query_recall = ground_truths
        .iter()
        .zip(retrievals)
        .map(|(expected, found)| recall_at_k(expected, found, k));

    per_query_recall.sum::<f32>() / n_queries as f32
}
/// Evaluates `recall_at_k` at every cutoff in `k_values`.
///
/// Returns one `(k, recall)` pair per entry of `k_values`, in the same order
/// as the input (duplicates and unsorted cutoffs are passed through as-is).
pub fn recall_curve(
    ground_truth: &[u32],
    retrieved: &[u32],
    k_values: &[usize],
) -> Vec<(usize, f32)> {
    let mut curve = Vec::with_capacity(k_values.len());
    for &cutoff in k_values {
        let recall = recall_at_k(ground_truth, retrieved, cutoff);
        curve.push((cutoff, recall));
    }
    curve
}
/// Metrics recorded for a single query; a slice of these is aggregated by
/// `EvaluationSummary::from_evaluations`.
#[derive(Debug, Clone)]
pub struct QueryEvaluation {
/// Recall score of the query; feeds the recall statistics of the summary.
pub recall: f32,
/// Precision score of the query (not read by the summary aggregation in this file).
pub precision: f32,
/// Latency of the query in microseconds (the summary converts the total to
/// seconds by dividing by 1_000_000).
pub latency_us: u64,
/// Number of distance computations for the query, when recorded.
/// NOTE(review): presumably filled in by the search routine — confirm against callers.
pub n_distance_computations: Option<usize>,
}
/// Aggregate recall and latency statistics over a batch of query evaluations,
/// as produced by `EvaluationSummary::from_evaluations`.
#[derive(Debug, Clone)]
pub struct EvaluationSummary {
/// Number of evaluations that were aggregated.
pub n_queries: usize,
/// The `k` value passed to `from_evaluations`; stored as-is.
pub k: usize,
/// Arithmetic mean of the per-query recall values.
pub mean_recall: f32,
/// Smallest per-query recall.
pub min_recall: f32,
/// Largest per-query recall.
pub max_recall: f32,
/// Population standard deviation of recall (variance divided by `n_queries`).
pub recall_std: f32,
/// Arithmetic mean of the per-query latencies, in microseconds.
pub mean_latency_us: f64,
/// Latency at index `n / 2` of the ascending-sorted latencies.
pub p50_latency_us: u64,
/// Latency at index `floor(0.95 * n)` of the ascending-sorted latencies.
pub p95_latency_us: u64,
/// Latency at index `floor(0.99 * n)` of the ascending-sorted latencies.
pub p99_latency_us: u64,
/// Largest per-query latency, in microseconds.
pub max_latency_us: u64,
/// Queries per second: `n_queries` divided by the summed latency in seconds;
/// `0.0` when the summed latency is zero.
pub qps: f64,
}
impl EvaluationSummary {
    /// Aggregates per-query evaluations into summary statistics.
    ///
    /// * Recall: mean, min, max and population standard deviation.
    /// * Latency: mean, max and the p50/p95/p99 values, read from the
    ///   ascending-sorted latencies at indices `n / 2`, `floor(0.95 * n)` and
    ///   `floor(0.99 * n)`.
    /// * `qps`: number of queries divided by the summed latency in seconds,
    ///   or `0.0` when the summed latency is zero.
    ///
    /// `k` is stored in the summary unchanged. An empty `evaluations` slice
    /// yields a summary with every statistic set to zero.
    pub fn from_evaluations(evaluations: &[QueryEvaluation], k: usize) -> Self {
        let n_queries = evaluations.len();
        if evaluations.is_empty() {
            return Self {
                n_queries,
                k,
                mean_recall: 0.0,
                min_recall: 0.0,
                max_recall: 0.0,
                recall_std: 0.0,
                mean_latency_us: 0.0,
                p50_latency_us: 0,
                p95_latency_us: 0,
                p99_latency_us: 0,
                max_latency_us: 0,
                qps: 0.0,
            };
        }

        // --- recall statistics ---
        let count_f32 = n_queries as f32;
        let mean_recall = evaluations.iter().map(|e| e.recall).sum::<f32>() / count_f32;

        let mut min_recall = f32::INFINITY;
        let mut max_recall = f32::NEG_INFINITY;
        for evaluation in evaluations {
            min_recall = min_recall.min(evaluation.recall);
            max_recall = max_recall.max(evaluation.recall);
        }

        let squared_deviations: f32 = evaluations
            .iter()
            .map(|e| (e.recall - mean_recall).powi(2))
            .sum();
        let recall_std = (squared_deviations / count_f32).sqrt();

        // --- latency statistics ---
        let mut sorted_latencies: Vec<u64> = evaluations.iter().map(|e| e.latency_us).collect();
        sorted_latencies.sort_unstable();

        let total_latency_us: u64 = sorted_latencies.iter().sum();
        let mean_latency_us = total_latency_us as f64 / n_queries as f64;

        // Index is floor(fraction * n), which stays below n for fraction < 1.
        let latency_at = |fraction: f64| sorted_latencies[(n_queries as f64 * fraction) as usize];
        let p50_latency_us = sorted_latencies[n_queries / 2];
        let p95_latency_us = latency_at(0.95);
        let p99_latency_us = latency_at(0.99);
        // The slice is non-empty here, so the last element is the maximum.
        let max_latency_us = sorted_latencies[n_queries - 1];

        let qps = match total_latency_us {
            0 => 0.0,
            total => n_queries as f64 / (total as f64 / 1_000_000.0),
        };

        Self {
            n_queries,
            k,
            mean_recall,
            min_recall,
            max_recall,
            recall_std,
            mean_latency_us,
            p50_latency_us,
            p95_latency_us,
            p99_latency_us,
            max_latency_us,
            qps,
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    // Partial, perfect and fully-missed retrievals against a 5-item ground truth.
    #[test]
    fn test_recall_at_k() {
        let gt = vec![1, 2, 3, 4, 5];
        let retrieved = vec![1, 2, 3, 6, 7];
        assert!((recall_at_k(&gt, &retrieved, 5) - 0.6).abs() < 0.001);

        let perfect = vec![1, 2, 3, 4, 5];
        assert!((recall_at_k(&gt, &perfect, 5) - 1.0).abs() < 0.001);

        let miss = vec![6, 7, 8, 9, 10];
        assert!((recall_at_k(&gt, &miss, 5) - 0.0).abs() < 0.001);
    }

    // Two of the five retrieved ids are relevant.
    #[test]
    fn test_precision_at_k() {
        let gt = vec![1, 2, 3, 4, 5];
        let retrieved = vec![1, 2, 6, 7, 8];
        assert!((precision_at_k(&gt, &retrieved, 5) - 0.4).abs() < 0.001);
    }

    // Recall evaluated at k = 1, 5 and 10 over the same ranked lists.
    #[test]
    fn test_recall_curve() {
        let gt = vec![1, 2, 3, 4, 5, 6, 7, 8, 9, 10];
        let retrieved = vec![1, 2, 3, 11, 12, 6, 7, 13, 14, 15];
        let curve = recall_curve(&gt, &retrieved, &[1, 5, 10]);
        assert_eq!(curve.len(), 3);
        assert!((curve[0].1 - 1.0).abs() < 0.001);
        assert!((curve[1].1 - 0.6).abs() < 0.001);
        assert!((curve[2].1 - 0.5).abs() < 0.001);
    }

    // Three evaluations with recalls 0.9 / 0.8 / 1.0 and latencies 100 / 200 / 50 us.
    #[test]
    fn test_evaluation_summary() {
        let evals = vec![
            QueryEvaluation {
                recall: 0.9,
                precision: 0.9,
                latency_us: 100,
                n_distance_computations: None,
            },
            QueryEvaluation {
                recall: 0.8,
                precision: 0.8,
                latency_us: 200,
                n_distance_computations: None,
            },
            QueryEvaluation {
                recall: 1.0,
                precision: 1.0,
                latency_us: 50,
                n_distance_computations: None,
            },
        ];
        let summary = EvaluationSummary::from_evaluations(&evals, 10);
        assert_eq!(summary.n_queries, 3);
        assert_eq!(summary.k, 10);
        assert!((summary.mean_recall - 0.9).abs() < 0.001);
        assert!((summary.min_recall - 0.8).abs() < 0.001);
        assert!((summary.max_recall - 1.0).abs() < 0.001);
        assert_eq!(summary.p50_latency_us, 100);
        assert_eq!(summary.max_latency_us, 200);
    }
}