use std::collections::HashSet;
/// Removes a single leading `./` from `path` and returns the result as an owned `String`.
///
/// Only one prefix is stripped (`"././a"` becomes `"./a"`); paths without the prefix are
/// returned unchanged.
pub fn normalize_path(path: &str) -> String {
    match path.strip_prefix("./") {
        Some(rest) => rest.to_owned(),
        None => path.to_owned(),
    }
}
/// Returns the file-path part of a document id, normalized with `normalize_path`.
///
/// The file path is everything before the first `::`; an id without `::` is treated as a
/// bare path. For example, `"./src/lib.rs::foo"` yields `"src/lib.rs"`.
pub fn extract_file_from_doc_id(doc_id: &str) -> String {
    let file_part = doc_id
        .split_once("::")
        .map_or(doc_id, |(file, _symbol)| file);
    normalize_path(file_part)
}
/// One evaluation query together with the files a good retrieval should return for it.
///
/// Derives `serde::Deserialize`; the concrete on-disk format is chosen by whoever loads the
/// cases (not visible here).
#[derive(serde::Deserialize, Debug)]
pub struct QueryCase {
    /// Query text, passed verbatim to the search callback (`results_fn`).
    pub query: String,
    /// Relevant file paths. They are run through `normalize_path` before comparison, so a
    /// leading `./` does not matter.
    pub expected: Vec<String>,
    /// Category label. Not read by the metric functions in this file (presumably used for
    /// grouping reports — verify against callers).
    pub category: String,
    /// Split label. Not read by the metric functions in this file (presumably a dataset
    /// partition name — TODO confirm against the data).
    pub split: String,
}
/// Aggregate retrieval metrics for one evaluated configuration, as built by `compute_metrics`.
///
/// Every rate field is a mean over the queries and lies in `[0, 1]`. All rate fields are
/// `0.0` when no queries were evaluated.
#[derive(Debug, Clone)]
pub struct EvalMetrics {
    /// Label for this run (the `name` argument given to `compute_metrics`).
    pub name: String,
    /// Number of queries the metrics were averaged over.
    pub num_queries: usize,
    /// Mean P@5: the number of distinct expected files among the top 5 results, divided by
    /// `min(expected.len(), 5)` (with a floor of 1).
    pub p5: f32,
    /// Mean P@10: like `p5`, but over the top 10 results with denominator
    /// `min(expected.len(), 10)`.
    pub p10: f32,
    /// Mean reciprocal rank of the first result whose file is expected (0 for a query with
    /// no match).
    pub mrr: f32,
    /// Fraction of queries with at least one expected file anywhere in the returned results.
    pub hit_rate: f32,
}
/// Retrieval scores for a single query, as produced by [`score_query`].
#[derive(Debug, Clone, Copy, PartialEq)]
struct QueryScores {
    /// Distinct expected files in the top 5 results, over `min(|expected|, 5)` (floor 1).
    p5: f32,
    /// Distinct expected files in the top 10 results, over `min(|expected|, 10)` (floor 1).
    p10: f32,
    /// `1 / rank` of the first result whose file is expected, or `0.0` if none is.
    rr: f32,
}

/// Scores one query's ranked `result_ids` against its `expected` file paths.
///
/// Result ids are mapped to file paths with `extract_file_from_doc_id`, and expected paths
/// with `normalize_path`. Several chunks of the same file therefore count only once toward
/// precision. This is the single source of truth for per-query scoring, so that
/// [`compute_metrics`] and [`print_per_query_detail`] always agree.
fn score_query(expected: &[String], result_ids: &[String]) -> QueryScores {
    let expected: HashSet<String> = expected.iter().map(|p| normalize_path(p)).collect();

    // Precision is normalized by min(|expected|, k) rather than k: a query with one relevant
    // file scores 100% when that file appears anywhere in the top k. The floor of 1 avoids
    // division by zero when `expected` is empty.
    let precision_at = |k: usize| -> f32 {
        let found: HashSet<String> = result_ids
            .iter()
            .take(k)
            .map(|id| extract_file_from_doc_id(id))
            .filter(|file| expected.contains(file))
            .collect();
        found.len() as f32 / expected.len().clamp(1, k) as f32
    };

    let rr = result_ids
        .iter()
        .position(|id| expected.contains(&extract_file_from_doc_id(id)))
        .map_or(0.0, |i| 1.0 / (i as f32 + 1.0));

    QueryScores {
        p5: precision_at(5),
        p10: precision_at(10),
        rr,
    }
}

/// Shortens `query` for the fixed-width table in [`print_per_query_detail`].
///
/// A query longer than 53 characters is cut to its first 50 characters plus `...`. Lengths
/// are counted in `char`s, not bytes: byte slicing panics when the cut lands inside a
/// multi-byte UTF-8 character, and the `{:<55}` column width is measured in chars anyway.
fn truncate_for_display(query: &str) -> String {
    if query.chars().count() > 53 {
        let mut shown: String = query.chars().take(50).collect();
        shown.push_str("...");
        shown
    } else {
        query.to_string()
    }
}

/// Evaluates a retrieval function over `queries` and returns the averaged metrics.
///
/// `results_fn` is called once per query with its text. It must return ranked document ids;
/// a document id is a file path optionally followed by `::` and a suffix. See [`score_query`]
/// for how each query is scored. The returned [`EvalMetrics`] holds:
/// - the mean P@5 and P@10;
/// - the mean reciprocal rank;
/// - the hit rate: the share of queries whose results contain any expected file at any rank.
///
/// With no queries, every rate is `0.0` instead of NaN.
pub fn compute_metrics(
    queries: &[QueryCase],
    results_fn: &dyn Fn(&str) -> Vec<String>,
    name: &str,
) -> EvalMetrics {
    let n = queries.len();
    let mut total_p5 = 0.0f32;
    let mut total_p10 = 0.0f32;
    let mut total_rr = 0.0f32;
    let mut hits = 0usize;

    for case in queries {
        let scores = score_query(&case.expected, &results_fn(&case.query));
        total_p5 += scores.p5;
        total_p10 += scores.p10;
        total_rr += scores.rr;
        if scores.rr > 0.0 {
            hits += 1;
        }
    }

    // Guard the empty case so the averages are 0.0 rather than 0/0 = NaN.
    let mean = |total: f32| if n > 0 { total / n as f32 } else { 0.0 };

    EvalMetrics {
        name: name.to_string(),
        num_queries: n,
        p5: mean(total_p5),
        p10: mean(total_p10),
        mrr: mean(total_rr),
        hit_rate: mean(hits as f32),
    }
}

/// Prints the numeric fields of `metrics` to stdout, with rates shown as percentages.
///
/// `metrics.name` is not printed here.
pub fn print_metrics(metrics: &EvalMetrics) {
    println!(" Queries: {}", metrics.num_queries);
    println!(" P@5: {:.1}%", metrics.p5 * 100.0);
    println!(" P@10: {:.1}%", metrics.p10 * 100.0);
    println!(" MRR: {:.3}", metrics.mrr);
    println!(" Hit rate: {:.1}%", metrics.hit_rate * 100.0);
}

/// Prints a per-query table (query text, P@5, P@10, reciprocal rank) to stdout.
///
/// Scores come from the same [`score_query`] helper as [`compute_metrics`], so the rows match
/// the aggregates. Long queries are shortened on character boundaries, so non-ASCII queries
/// cannot cause a panic.
pub fn print_per_query_detail(cases: &[QueryCase], results_fn: &dyn Fn(&str) -> Vec<String>) {
    println!("{:<55} {:>6} {:>6} {:>6}", "Query", "P@5", "P@10", "RR");
    println!("{}", "─".repeat(77));
    for case in cases {
        let scores = score_query(&case.expected, &results_fn(&case.query));
        println!(
            "{:<55} {:>5.0}% {:>5.0}% {:>.3}",
            truncate_for_display(&case.query),
            scores.p5 * 100.0,
            scores.p10 * 100.0,
            scores.rr
        );
    }
}