Skip to main content

lean_ctx/core/
eval.rs

1//! Downstream task evaluation framework for search quality.
2//!
3//! Measures how well the search pipeline supports actual coding tasks:
4//! - Retrieval precision/recall against known-relevant chunks
5//! - Mean Reciprocal Rank (MRR) for expected top results
6//! - Normalized Discounted Cumulative Gain (nDCG)
7//!
8//! Designed to compare BM25-only vs hybrid search and track quality over time.
9
10use std::collections::HashSet;
11
/// A single evaluation query with expected relevant results.
#[derive(Debug, Clone)]
pub struct EvalQuery {
    // Query text, passed verbatim to the retrieval function.
    pub query: String,
    // File paths judged relevant for this query (binary relevance).
    pub relevant_files: Vec<String>,
    // The single result expected at rank 1, if any.
    // NOTE(review): not consumed by `evaluate` in this file — confirm intended use.
    pub expected_top: Option<String>,
}
19
/// Result of evaluating a search system against a query set.
///
/// Aggregate fields are macro-averages: each query contributes equally
/// regardless of how many relevant files it has.
#[derive(Debug, Clone)]
pub struct EvalReport {
    // Number of queries evaluated.
    pub query_count: usize,
    // Mean precision at rank cutoff 5.
    pub precision_at_5: f64,
    // Mean precision at rank cutoff 10.
    pub precision_at_10: f64,
    // Mean recall at rank cutoff 10.
    pub recall_at_10: f64,
    // Mean Reciprocal Rank of the first relevant result.
    pub mrr: f64,
    // Mean normalized DCG at rank cutoff 10.
    pub ndcg_at_10: f64,
    // Individual per-query scores, in input query order.
    pub per_query: Vec<QueryScore>,
}
31
/// Per-query metric breakdown backing an [`EvalReport`].
#[derive(Debug, Clone)]
pub struct QueryScore {
    // The query text this score belongs to.
    pub query: String,
    // Precision at k; `evaluate` stores the k=10 value here.
    pub precision_at_k: f64,
    // Recall; `evaluate` stores the @10 value here.
    pub recall: f64,
    // 1/rank of the first relevant result, or 0.0 if none retrieved.
    pub reciprocal_rank: f64,
    // nDCG; `evaluate` stores the @10 value here.
    pub ndcg: f64,
}
40
/// Retrieved result for evaluation (file path + score).
#[derive(Debug, Clone)]
pub struct RetrievedItem {
    // Path used for relevance matching against `EvalQuery::relevant_files`.
    pub file_path: String,
    // Retrieval score. The metrics in this module use only rank order,
    // so callers must pass results already sorted best-first.
    pub score: f64,
}
47
48/// Evaluate search results against a set of queries with known relevance.
49pub fn evaluate(queries: &[EvalQuery], retrieve_fn: &dyn Fn(&str) -> Vec<RetrievedItem>) -> EvalReport {
50    let mut per_query = Vec::with_capacity(queries.len());
51    let mut sum_p5 = 0.0;
52    let mut sum_p10 = 0.0;
53    let mut sum_r10 = 0.0;
54    let mut sum_mrr = 0.0;
55    let mut sum_ndcg = 0.0;
56
57    for q in queries {
58        let results = retrieve_fn(&q.query);
59        let relevant: HashSet<&str> = q.relevant_files.iter().map(|s| s.as_str()).collect();
60
61        let p5 = precision_at_k(&results, &relevant, 5);
62        let p10 = precision_at_k(&results, &relevant, 10);
63        let r10 = recall_at_k(&results, &relevant, 10);
64        let rr = reciprocal_rank(&results, &relevant);
65        let ndcg = ndcg_at_k(&results, &relevant, 10);
66
67        sum_p5 += p5;
68        sum_p10 += p10;
69        sum_r10 += r10;
70        sum_mrr += rr;
71        sum_ndcg += ndcg;
72
73        per_query.push(QueryScore {
74            query: q.query.clone(),
75            precision_at_k: p10,
76            recall: r10,
77            reciprocal_rank: rr,
78            ndcg,
79        });
80    }
81
82    let n = queries.len().max(1) as f64;
83    EvalReport {
84        query_count: queries.len(),
85        precision_at_5: sum_p5 / n,
86        precision_at_10: sum_p10 / n,
87        recall_at_10: sum_r10 / n,
88        mrr: sum_mrr / n,
89        ndcg_at_10: sum_ndcg / n,
90        per_query,
91    }
92}
93
94fn precision_at_k(results: &[RetrievedItem], relevant: &HashSet<&str>, k: usize) -> f64 {
95    let top_k: Vec<&RetrievedItem> = results.iter().take(k).collect();
96    if top_k.is_empty() {
97        return 0.0;
98    }
99    let hits = top_k
100        .iter()
101        .filter(|r| relevant.contains(r.file_path.as_str()))
102        .count();
103    hits as f64 / top_k.len() as f64
104}
105
106fn recall_at_k(results: &[RetrievedItem], relevant: &HashSet<&str>, k: usize) -> f64 {
107    if relevant.is_empty() {
108        return 0.0;
109    }
110    let hits = results
111        .iter()
112        .take(k)
113        .filter(|r| relevant.contains(r.file_path.as_str()))
114        .count();
115    hits as f64 / relevant.len() as f64
116}
117
118fn reciprocal_rank(results: &[RetrievedItem], relevant: &HashSet<&str>) -> f64 {
119    for (i, r) in results.iter().enumerate() {
120        if relevant.contains(r.file_path.as_str()) {
121            return 1.0 / (i + 1) as f64;
122        }
123    }
124    0.0
125}
126
127fn ndcg_at_k(results: &[RetrievedItem], relevant: &HashSet<&str>, k: usize) -> f64 {
128    let dcg = results
129        .iter()
130        .take(k)
131        .enumerate()
132        .map(|(i, r)| {
133            let gain = if relevant.contains(r.file_path.as_str()) { 1.0 } else { 0.0 };
134            gain / (2.0f64 + i as f64).log2()
135        })
136        .sum::<f64>();
137
138    let ideal_count = relevant.len().min(k);
139    let ideal_dcg: f64 = (0..ideal_count)
140        .map(|i| 1.0 / (2.0f64 + i as f64).log2())
141        .sum();
142
143    if ideal_dcg == 0.0 {
144        return 0.0;
145    }
146    dcg / ideal_dcg
147}
148
impl EvalReport {
    /// One-line summary of the aggregate metrics, suitable for logs or CI output.
    pub fn to_compact_string(&self) -> String {
        format!(
            "P@5={:.3} P@10={:.3} R@10={:.3} MRR={:.3} nDCG@10={:.3} (n={})",
            self.precision_at_5,
            self.precision_at_10,
            self.recall_at_10,
            self.mrr,
            self.ndcg_at_10,
            self.query_count,
        )
    }

    /// True when both aggregate MRR and nDCG@10 meet the given minimums.
    /// Useful as a regression gate when comparing search configurations.
    pub fn passed_threshold(&self, min_mrr: f64, min_ndcg: f64) -> bool {
        self.mrr >= min_mrr && self.ndcg_at_10 >= min_ndcg
    }
}
166
#[cfg(test)]
mod tests {
    use super::*;

    /// Build retrieved items from file names with strictly decreasing
    /// scores (10.0, 9.0, ...) so rank order matches slice order.
    fn items(files: &[&str]) -> Vec<RetrievedItem> {
        files
            .iter()
            .enumerate()
            .map(|(i, f)| RetrievedItem {
                file_path: f.to_string(),
                score: 10.0 - i as f64,
            })
            .collect()
    }

    // All relevant items inside the cutoff -> precision 1.0.
    #[test]
    fn precision_at_k_perfect() {
        let relevant: HashSet<&str> = ["a.rs", "b.rs"].into_iter().collect();
        let results = items(&["a.rs", "b.rs", "c.rs"]);
        assert!((precision_at_k(&results, &relevant, 2) - 1.0).abs() < 1e-6);
    }

    // One relevant hit out of two considered -> precision 0.5.
    #[test]
    fn precision_at_k_half() {
        let relevant: HashSet<&str> = ["a.rs"].into_iter().collect();
        let results = items(&["a.rs", "b.rs"]);
        assert!((precision_at_k(&results, &relevant, 2) - 0.5).abs() < 1e-6);
    }

    // No overlap between retrieved and relevant -> precision 0.0.
    #[test]
    fn precision_at_k_none() {
        let relevant: HashSet<&str> = ["x.rs"].into_iter().collect();
        let results = items(&["a.rs", "b.rs"]);
        assert!((precision_at_k(&results, &relevant, 2) - 0.0).abs() < 1e-6);
    }

    // Every relevant file retrieved within k -> recall 1.0.
    #[test]
    fn recall_at_k_full() {
        let relevant: HashSet<&str> = ["a.rs"].into_iter().collect();
        let results = items(&["a.rs", "b.rs", "c.rs"]);
        assert!((recall_at_k(&results, &relevant, 3) - 1.0).abs() < 1e-6);
    }

    // One of two relevant files retrieved -> recall 0.5.
    #[test]
    fn recall_at_k_partial() {
        let relevant: HashSet<&str> = ["a.rs", "d.rs"].into_iter().collect();
        let results = items(&["a.rs", "b.rs", "c.rs"]);
        assert!((recall_at_k(&results, &relevant, 3) - 0.5).abs() < 1e-6);
    }

    // Relevant item at rank 1 -> reciprocal rank 1.0.
    #[test]
    fn mrr_first_position() {
        let relevant: HashSet<&str> = ["a.rs"].into_iter().collect();
        let results = items(&["a.rs", "b.rs"]);
        assert!((reciprocal_rank(&results, &relevant) - 1.0).abs() < 1e-6);
    }

    // Relevant item at rank 2 -> reciprocal rank 0.5.
    #[test]
    fn mrr_second_position() {
        let relevant: HashSet<&str> = ["b.rs"].into_iter().collect();
        let results = items(&["a.rs", "b.rs"]);
        assert!((reciprocal_rank(&results, &relevant) - 0.5).abs() < 1e-6);
    }

    // No relevant item retrieved -> reciprocal rank 0.0.
    #[test]
    fn mrr_not_found() {
        let relevant: HashSet<&str> = ["x.rs"].into_iter().collect();
        let results = items(&["a.rs", "b.rs"]);
        assert!((reciprocal_rank(&results, &relevant) - 0.0).abs() < 1e-6);
    }

    // Relevant items occupy the top ranks -> DCG equals ideal DCG.
    #[test]
    fn ndcg_perfect() {
        let relevant: HashSet<&str> = ["a.rs", "b.rs"].into_iter().collect();
        let results = items(&["a.rs", "b.rs", "c.rs"]);
        let score = ndcg_at_k(&results, &relevant, 3);
        assert!((score - 1.0).abs() < 1e-6, "perfect ranking should give nDCG=1.0, got {score}");
    }

    // Relevant item below rank 1 -> nDCG strictly between 0 and 1.
    #[test]
    fn ndcg_imperfect() {
        let relevant: HashSet<&str> = ["b.rs"].into_iter().collect();
        let results = items(&["a.rs", "b.rs", "c.rs"]);
        let score = ndcg_at_k(&results, &relevant, 3);
        assert!(score > 0.0 && score < 1.0, "imperfect ranking: {score}");
    }

    // End-to-end: evaluate() with a stub retriever that always ranks the
    // relevant files first, so aggregate MRR/nDCG should be high.
    #[test]
    fn evaluate_pipeline() {
        let queries = vec![
            EvalQuery {
                query: "authentication".to_string(),
                relevant_files: vec!["auth.rs".to_string()],
                expected_top: Some("auth.rs".to_string()),
            },
            EvalQuery {
                query: "database connection".to_string(),
                relevant_files: vec!["db.rs".to_string(), "pool.rs".to_string()],
                expected_top: Some("db.rs".to_string()),
            },
        ];

        let report = evaluate(&queries, &|q| {
            if q.contains("auth") {
                items(&["auth.rs", "user.rs", "session.rs"])
            } else {
                items(&["db.rs", "pool.rs", "config.rs"])
            }
        });

        assert_eq!(report.query_count, 2);
        assert!(report.mrr > 0.5, "MRR should be high: {}", report.mrr);
        assert!(report.ndcg_at_10 > 0.5, "nDCG should be high: {}", report.ndcg_at_10);
    }

    // Threshold gate: a perfect single-query run passes moderate thresholds
    // and fails impossible ones (>1.0).
    #[test]
    fn report_threshold() {
        let queries = vec![EvalQuery {
            query: "test".to_string(),
            relevant_files: vec!["test.rs".to_string()],
            expected_top: None,
        }];
        let report = evaluate(&queries, &|_| items(&["test.rs", "other.rs"]));
        assert!(report.passed_threshold(0.5, 0.5));
        assert!(!report.passed_threshold(2.0, 2.0));
    }
}