avocado_core/
eval.rs

1//! Evaluation module for retrieval quality metrics
2//!
3//! Provides tools for evaluating retrieval quality using golden queries
4//! with expected results. Calculates recall@k, precision@k, MRR, and latency.
5
6use crate::compiler;
7use crate::db::Database;
8use crate::index::VectorIndex;
9use crate::types::{CompilerConfig, EvalResult, EvalSummary, GoldenQuery, Result};
10use std::collections::HashSet;
11
12/// Evaluate a set of golden queries and compute metrics
13///
14/// # Arguments
15///
16/// * `queries` - Golden queries with expected results
17/// * `db` - Database handle
18/// * `index` - Vector index
19/// * `config` - Compiler configuration
20///
21/// # Returns
22///
23/// Evaluation summary with aggregate metrics
24pub async fn evaluate(
25    queries: &[GoldenQuery],
26    db: &Database,
27    index: &VectorIndex,
28    config: &CompilerConfig,
29) -> Result<EvalSummary> {
30    let mut results = Vec::with_capacity(queries.len());
31    let mut latencies = Vec::with_capacity(queries.len());
32
33    for query in queries {
34        let start = std::time::Instant::now();
35
36        // Compile context for this query
37        let working_set = compiler::compile(&query.query, config.clone(), db, index, None).await?;
38
39        let latency_ms = start.elapsed().as_millis() as u64;
40        latencies.push(latency_ms);
41
42        // Get artifact paths from results
43        let result_paths: Vec<String> = working_set
44            .citations
45            .iter()
46            .map(|c| c.artifact_path.clone())
47            .collect();
48
49        // Calculate metrics
50        let (recall, precision, mrr) = calculate_metrics(&query.expected_paths, &result_paths, query.k);
51
52        results.push(EvalResult {
53            query: query.query.clone(),
54            recall_at_k: recall,
55            precision_at_k: precision,
56            mrr,
57            latency_ms,
58        });
59    }
60
61    // Calculate aggregate metrics
62    let mean_recall = results.iter().map(|r| r.recall_at_k).sum::<f32>() / results.len() as f32;
63    let mean_precision = results.iter().map(|r| r.precision_at_k).sum::<f32>() / results.len() as f32;
64    let mean_mrr = results.iter().map(|r| r.mrr).sum::<f32>() / results.len() as f32;
65
66    // Calculate latency percentiles
67    latencies.sort();
68    let p50_idx = latencies.len() / 2;
69    let p99_idx = (latencies.len() as f32 * 0.99) as usize;
70    let p50_latency_ms = latencies.get(p50_idx).copied().unwrap_or(0);
71    let p99_latency_ms = latencies.get(p99_idx.min(latencies.len().saturating_sub(1))).copied().unwrap_or(0);
72
73    Ok(EvalSummary {
74        mean_recall,
75        mean_precision,
76        mean_mrr,
77        p50_latency_ms,
78        p99_latency_ms,
79        query_count: queries.len(),
80        results,
81    })
82}
83
/// Calculate recall@k, precision@k, and MRR for a single query
///
/// # Arguments
///
/// * `expected` - Expected artifact paths
/// * `results` - Retrieved artifact paths (in order)
/// * `k` - Top k to consider
///
/// # Returns
///
/// (recall_at_k, precision_at_k, mrr)
fn calculate_metrics(expected: &[String], results: &[String], k: usize) -> (f32, f32, f32) {
    let wanted: HashSet<&String> = expected.iter().collect();

    // Truncate the ranking to the top k (or fewer, if results is short).
    let retrieved = &results[..k.min(results.len())];

    // Count *distinct* retrieved paths that were expected — duplicates on
    // either side contribute only once.
    let hits: HashSet<&String> = retrieved.iter().filter(|r| wanted.contains(*r)).collect();
    let found = hits.len();

    // Recall@k: fraction of expected items present in the top k.
    // An empty expectation counts as perfect recall.
    let recall = match expected.len() {
        0 => 1.0,
        n => found as f32 / n as f32,
    };

    // Precision@k: fraction of the top k that was expected.
    let precision = match retrieved.len() {
        0 => 0.0,
        n => found as f32 / n as f32,
    };

    // MRR: reciprocal rank of the first relevant result, scanning the full
    // result list (not just the top k); 0.0 when nothing relevant appears.
    let mrr = match results.iter().position(|r| wanted.contains(r)) {
        Some(rank0) => 1.0 / (rank0 as f32 + 1.0),
        None => 0.0,
    };

    (recall, precision, mrr)
}
124
125/// Load golden queries from JSON file
126///
127/// # Arguments
128///
129/// * `path` - Path to JSON file
130///
131/// # Returns
132///
133/// Vector of golden queries
134pub fn load_golden_queries(path: &std::path::Path) -> Result<Vec<GoldenQuery>> {
135    let content = std::fs::read_to_string(path)?;
136    let queries: Vec<GoldenQuery> = serde_json::from_str(&content)?;
137    Ok(queries)
138}
139
#[cfg(test)]
mod tests {
    use super::*;

    /// Build an owned `Vec<String>` from string slices (test convenience).
    fn paths(items: &[&str]) -> Vec<String> {
        items.iter().map(|s| s.to_string()).collect()
    }

    #[test]
    fn test_calculate_metrics_perfect() {
        let expected = paths(&["a.md", "b.md"]);
        let results = paths(&["a.md", "b.md", "c.md"]);

        let (recall, precision, mrr) = calculate_metrics(&expected, &results, 3);

        // Every expected path is retrieved, and the top hit is relevant.
        assert_eq!(recall, 1.0);
        assert!((precision - 2.0 / 3.0).abs() < 0.001);
        assert_eq!(mrr, 1.0);
    }

    #[test]
    fn test_calculate_metrics_partial() {
        let expected = paths(&["a.md", "b.md"]);
        let results = paths(&["c.md", "a.md", "d.md"]);

        let (recall, precision, mrr) = calculate_metrics(&expected, &results, 3);

        // Only "a.md" is found, at rank 2.
        assert_eq!(recall, 0.5);
        assert!((precision - 1.0 / 3.0).abs() < 0.001);
        assert_eq!(mrr, 0.5);
    }

    #[test]
    fn test_calculate_metrics_none_found() {
        let expected = paths(&["a.md", "b.md"]);
        let results = paths(&["c.md", "d.md"]);

        let (recall, precision, mrr) = calculate_metrics(&expected, &results, 3);

        // Nothing relevant retrieved: all metrics bottom out at zero.
        assert_eq!(recall, 0.0);
        assert_eq!(precision, 0.0);
        assert_eq!(mrr, 0.0);
    }
}
179}