1use crate::compiler;
7use crate::db::Database;
8use crate::index::VectorIndex;
9use crate::types::{CompilerConfig, EvalResult, EvalSummary, GoldenQuery, Result};
10use std::collections::HashSet;
11
12pub async fn evaluate(
25 queries: &[GoldenQuery],
26 db: &Database,
27 index: &VectorIndex,
28 config: &CompilerConfig,
29) -> Result<EvalSummary> {
30 let mut results = Vec::with_capacity(queries.len());
31 let mut latencies = Vec::with_capacity(queries.len());
32
33 for query in queries {
34 let start = std::time::Instant::now();
35
36 let working_set = compiler::compile(&query.query, config.clone(), db, index, None).await?;
38
39 let latency_ms = start.elapsed().as_millis() as u64;
40 latencies.push(latency_ms);
41
42 let result_paths: Vec<String> = working_set
44 .citations
45 .iter()
46 .map(|c| c.artifact_path.clone())
47 .collect();
48
49 let (recall, precision, mrr) = calculate_metrics(&query.expected_paths, &result_paths, query.k);
51
52 results.push(EvalResult {
53 query: query.query.clone(),
54 recall_at_k: recall,
55 precision_at_k: precision,
56 mrr,
57 latency_ms,
58 });
59 }
60
61 let mean_recall = results.iter().map(|r| r.recall_at_k).sum::<f32>() / results.len() as f32;
63 let mean_precision = results.iter().map(|r| r.precision_at_k).sum::<f32>() / results.len() as f32;
64 let mean_mrr = results.iter().map(|r| r.mrr).sum::<f32>() / results.len() as f32;
65
66 latencies.sort();
68 let p50_idx = latencies.len() / 2;
69 let p99_idx = (latencies.len() as f32 * 0.99) as usize;
70 let p50_latency_ms = latencies.get(p50_idx).copied().unwrap_or(0);
71 let p99_latency_ms = latencies.get(p99_idx.min(latencies.len().saturating_sub(1))).copied().unwrap_or(0);
72
73 Ok(EvalSummary {
74 mean_recall,
75 mean_precision,
76 mean_mrr,
77 p50_latency_ms,
78 p99_latency_ms,
79 query_count: queries.len(),
80 results,
81 })
82}
83
/// Scores one ranked result list against the expected (relevant) paths.
///
/// Returns `(recall@k, precision@k, mrr)`:
/// - **recall@k**: fraction of *distinct* expected paths that appear among the first `k`
///   results. It is `1.0` when `expected` is empty, since there is nothing to miss.
/// - **precision@k**: number of distinct expected paths found in the first `k` results,
///   divided by the number of results actually considered (`min(k, results.len())`).
///   It is `0.0` when no results are considered.
/// - **mrr**: reciprocal of the 1-based rank of the first expected path anywhere in
///   `results`. This is not limited to the first `k`. It is `0.0` if no expected path appears.
///
/// Duplicate entries in `expected` are counted once, so a repeated path cannot cap recall
/// below `1.0`.
fn calculate_metrics(expected: &[String], results: &[String], k: usize) -> (f32, f32, f32) {
    let expected_set: HashSet<&str> = expected.iter().map(String::as_str).collect();
    let top_k = &results[..k.min(results.len())];
    let top_k_set: HashSet<&str> = top_k.iter().map(String::as_str).collect();

    let found = expected_set.intersection(&top_k_set).count();

    // Divide by the deduplicated count: `found` counts distinct paths, so dividing by
    // `expected.len()` would under-report recall when `expected` has duplicates.
    let recall = if expected_set.is_empty() {
        1.0
    } else {
        found as f32 / expected_set.len() as f32
    };

    let precision = if top_k.is_empty() {
        0.0
    } else {
        found as f32 / top_k.len() as f32
    };

    let mrr = results
        .iter()
        .position(|r| expected_set.contains(r.as_str()))
        .map_or(0.0, |pos| 1.0 / (pos + 1) as f32);

    (recall, precision, mrr)
}
124
125pub fn load_golden_queries(path: &std::path::Path) -> Result<Vec<GoldenQuery>> {
135 let content = std::fs::read_to_string(path)?;
136 let queries: Vec<GoldenQuery> = serde_json::from_str(&content)?;
137 Ok(queries)
138}
139
#[cfg(test)]
mod tests {
    use super::*;

    /// Converts string literals into the owned path list `calculate_metrics` expects.
    fn paths(items: &[&str]) -> Vec<String> {
        items.iter().map(|item| (*item).to_owned()).collect()
    }

    #[test]
    fn test_calculate_metrics_perfect() {
        let expected = paths(&["a.md", "b.md"]);
        let ranked = paths(&["a.md", "b.md", "c.md"]);

        let (recall, precision, mrr) = calculate_metrics(&expected, &ranked, 3);

        assert_eq!(recall, 1.0);
        assert!((precision - 2.0 / 3.0).abs() < 0.001);
        assert_eq!(mrr, 1.0);
    }

    #[test]
    fn test_calculate_metrics_partial() {
        let expected = paths(&["a.md", "b.md"]);
        let ranked = paths(&["c.md", "a.md", "d.md"]);

        let (recall, precision, mrr) = calculate_metrics(&expected, &ranked, 3);

        assert_eq!(recall, 0.5);
        assert!((precision - 1.0 / 3.0).abs() < 0.001);
        assert_eq!(mrr, 0.5);
    }

    #[test]
    fn test_calculate_metrics_none_found() {
        let expected = paths(&["a.md", "b.md"]);
        let ranked = paths(&["c.md", "d.md"]);

        let (recall, precision, mrr) = calculate_metrics(&expected, &ranked, 3);

        assert_eq!(recall, 0.0);
        assert_eq!(precision, 0.0);
        assert_eq!(mrr, 0.0);
    }
}