Skip to main content

lean_ctx/core/
eval_harness.rs

1//! Retrieval evaluation harness for lean-ctx hybrid search.
2//!
3//! Runs a standardized query→expected_file benchmark to measure Recall@k,
4//! MRR (Mean Reciprocal Rank), and latency. Outputs NDJSON scorecards.
5//!
6//! Usage: `lean-ctx benchmark --eval [path]`
7
8use std::path::Path;
9use std::time::Instant;
10
11use crate::core::bm25_index::BM25Index;
12use crate::core::hybrid_search::HybridConfig;
13
14#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
15pub struct EvalQuery {
16    pub query: String,
17    pub expected_files: Vec<String>,
18    #[serde(default)]
19    pub category: String,
20}
21
22#[derive(Debug, Clone, serde::Serialize)]
23pub struct EvalResult {
24    pub query: String,
25    pub category: String,
26    pub recall_at_5: f64,
27    pub recall_at_10: f64,
28    pub mrr: f64,
29    pub latency_us: u64,
30    pub retrieved_files: Vec<String>,
31    pub expected_files: Vec<String>,
32}
33
34#[derive(Debug, Clone, serde::Serialize)]
35pub struct EvalScorecard {
36    pub project: String,
37    pub total_queries: usize,
38    pub avg_recall_at_5: f64,
39    pub avg_recall_at_10: f64,
40    pub avg_mrr: f64,
41    pub avg_latency_us: u64,
42    pub per_category: Vec<CategoryScore>,
43    pub results: Vec<EvalResult>,
44}
45
46#[derive(Debug, Clone, serde::Serialize)]
47pub struct CategoryScore {
48    pub category: String,
49    pub count: usize,
50    pub avg_recall_at_5: f64,
51    pub avg_mrr: f64,
52}
53
54impl std::fmt::Display for EvalScorecard {
55    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
56        writeln!(f, "Eval: {} ({} queries)", self.project, self.total_queries)?;
57        writeln!(f, "  R@5:  {:.1}%", self.avg_recall_at_5 * 100.0)?;
58        writeln!(f, "  R@10: {:.1}%", self.avg_recall_at_10 * 100.0)?;
59        writeln!(f, "  MRR:  {:.3}", self.avg_mrr)?;
60        writeln!(f, "  Latency: {}µs avg", self.avg_latency_us)?;
61        for cat in &self.per_category {
62            writeln!(
63                f,
64                "  [{:12}] R@5={:.1}% MRR={:.3} (n={})",
65                cat.category,
66                cat.avg_recall_at_5 * 100.0,
67                cat.avg_mrr,
68                cat.count
69            )?;
70        }
71        Ok(())
72    }
73}
74
75/// Run evaluation using the full hybrid search pipeline (BM25 + embeddings + SPLADE).
76/// Falls back to BM25-only if embeddings are not available.
77pub fn run_eval(
78    project_root: &Path,
79    queries: &[EvalQuery],
80    index: &BM25Index,
81    config: &HybridConfig,
82) -> EvalScorecard {
83    let label = project_root
84        .file_name()
85        .and_then(|s| s.to_str())
86        .unwrap_or("unknown")
87        .to_string();
88
89    let mut results = Vec::with_capacity(queries.len());
90
91    for q in queries {
92        let start = Instant::now();
93        let retrieved = hybrid_eval_search(project_root, &q.query, index, config);
94        let latency = start.elapsed().as_micros() as u64;
95
96        let recall_5 = recall_at_k(&retrieved, &q.expected_files, 5);
97        let recall_10 = recall_at_k(&retrieved, &q.expected_files, 10);
98        let mrr = mean_reciprocal_rank(&retrieved, &q.expected_files);
99
100        results.push(EvalResult {
101            query: q.query.clone(),
102            category: q.category.clone(),
103            recall_at_5: recall_5,
104            recall_at_10: recall_10,
105            mrr,
106            latency_us: latency,
107            retrieved_files: retrieved.into_iter().take(10).collect(),
108            expected_files: q.expected_files.clone(),
109        });
110    }
111
112    let total = results.len();
113    let avg_r5 = results.iter().map(|r| r.recall_at_5).sum::<f64>() / total.max(1) as f64;
114    let avg_r10 = results.iter().map(|r| r.recall_at_10).sum::<f64>() / total.max(1) as f64;
115    let avg_mrr = results.iter().map(|r| r.mrr).sum::<f64>() / total.max(1) as f64;
116    let avg_lat = results.iter().map(|r| r.latency_us).sum::<u64>() / total.max(1) as u64;
117
118    let per_category = build_category_scores(&results);
119
120    EvalScorecard {
121        project: label,
122        total_queries: total,
123        avg_recall_at_5: avg_r5,
124        avg_recall_at_10: avg_r10,
125        avg_mrr,
126        avg_latency_us: avg_lat,
127        per_category,
128        results,
129    }
130}
131
132/// Full hybrid search for eval: BM25 + dense embeddings + SPLADE + RRF.
133/// Falls back to BM25-only when embeddings are unavailable.
134fn hybrid_eval_search(
135    project_root: &Path,
136    query: &str,
137    index: &BM25Index,
138    config: &HybridConfig,
139) -> Vec<String> {
140    #[cfg(feature = "embeddings")]
141    {
142        if let Ok(results) = try_hybrid_search(project_root, query, index, config) {
143            return results;
144        }
145    }
146    let _ = project_root;
147    index
148        .search(query, config.bm25_candidates)
149        .iter()
150        .map(|r| r.file_path.clone())
151        .collect()
152}
153
/// Runs the hybrid retrieval path and returns at most 10 file paths.
///
/// Steps, in order: load the engine and embedding index, ensure embeddings
/// for the BM25 index, pick the dense backend from the environment, fetch
/// hybrid results, then optionally apply the SPLADE boost when
/// `config.splade_weight` is positive and SPLADE returned something.
///
/// # Errors
/// Propagates the `String` error of any of the fallible steps above.
#[cfg(feature = "embeddings")]
fn try_hybrid_search(
    project_root: &Path,
    query: &str,
    index: &BM25Index,
    config: &HybridConfig,
) -> Result<Vec<String>, String> {
    use crate::core::dense_backend;
    use crate::tools::ctx_semantic_search;

    let (engine, mut embeddings) = ctx_semantic_search::load_engine_and_index_pub(project_root)?;

    let (aligned, _coverage, changed_files) = ctx_semantic_search::ensure_embeddings_for_eval(
        project_root,
        index,
        engine,
        &mut embeddings,
    )?;

    let backend = dense_backend::DenseBackendKind::try_from_env()?;
    // Use the larger of the two candidate budgets for every retrieval stage.
    let depth = config.bm25_candidates.max(config.dense_candidates);

    let mut ranked = dense_backend::hybrid_results(
        backend,
        project_root,
        index,
        engine,
        &aligned,
        &changed_files,
        query,
        depth,
        config,
        None,
        None,
    )?;

    if config.splade_weight > 0.0 {
        let sparse_hits = crate::core::splade_retrieval::hybrid_retrieve(query, index, depth);
        if !sparse_hits.is_empty() {
            ctx_semantic_search::boost_with_splade_pub(
                &mut ranked,
                &sparse_hits,
                config.splade_weight,
            );
        }
    }

    // Only the top 10 paths are reported.
    Ok(ranked
        .iter()
        .take(10)
        .map(|hit| hit.file_path.clone())
        .collect())
}
200
201/// Generate self-eval queries from an indexed codebase.
202/// Picks random symbols/files and constructs retrieval queries.
203pub fn generate_self_eval(index: &BM25Index, max_queries: usize) -> Vec<EvalQuery> {
204    let mut queries = Vec::new();
205
206    for chunk in index.chunks.iter().take(max_queries * 2) {
207        if queries.len() >= max_queries {
208            break;
209        }
210        if chunk.symbol_name.is_empty() || chunk.file_path.is_empty() {
211            continue;
212        }
213
214        let category = if chunk.symbol_name.starts_with("fn ") || chunk.symbol_name.contains("()") {
215            "function"
216        } else if chunk.symbol_name.starts_with("struct ")
217            || chunk.symbol_name.starts_with("class ")
218        {
219            "type"
220        } else {
221            "symbol"
222        };
223
224        let clean_name = chunk
225            .symbol_name
226            .replace("fn ", "")
227            .replace("struct ", "")
228            .replace("class ", "")
229            .replace("()", "");
230
231        queries.push(EvalQuery {
232            query: format!("where is {clean_name} defined"),
233            expected_files: vec![chunk.file_path.clone()],
234            category: category.to_string(),
235        });
236    }
237
238    queries
239}
240
/// Normalizes path separators so comparisons are platform-independent (the
/// retrieved paths use the OS separator — `\` on Windows — while expected paths
/// in eval fixtures use `/`).
fn normalize_sep(p: &str) -> String {
    p.replace('\\', "/")
}

/// Returns `true` when `tail` equals `path` or is a run of whole trailing
/// path components of `path`. Both inputs must already be `/`-normalized.
///
/// A plain `ends_with` is not enough: `src/oauth.rs` ends with `auth.rs`
/// but is a different file. The suffix therefore has to start at a `/`
/// boundary. Empty strings never match, otherwise an empty path would be
/// a suffix of everything.
fn ends_with_components(path: &str, tail: &str) -> bool {
    if tail.is_empty() {
        return false;
    }
    match path.strip_suffix(tail) {
        Some(head) => head.is_empty() || head.ends_with('/') || tail.starts_with('/'),
        None => false,
    }
}

/// Two normalized paths refer to the same file when either one is a
/// component-aligned suffix of the other (fixtures may be relative to a
/// different root than the retrieved paths).
fn paths_match(a: &str, b: &str) -> bool {
    ends_with_components(a, b) || ends_with_components(b, a)
}

/// Fraction of `expected` paths that match one of the first `k` entries of
/// `retrieved`. Returns 0.0 when `expected` is empty.
fn recall_at_k(retrieved: &[String], expected: &[String], k: usize) -> f64 {
    if expected.is_empty() {
        return 0.0;
    }
    let top_k: Vec<String> = retrieved.iter().take(k).map(|r| normalize_sep(r)).collect();
    let hits = expected
        .iter()
        .filter(|e| {
            let e = normalize_sep(e);
            top_k.iter().any(|r| paths_match(r, &e))
        })
        .count();
    hits as f64 / expected.len() as f64
}

/// Reciprocal of the 1-based rank of the first `retrieved` path that matches
/// any `expected` path, or 0.0 when nothing matches.
fn mean_reciprocal_rank(retrieved: &[String], expected: &[String]) -> f64 {
    // Normalize the expected side once instead of once per retrieved path.
    let expected: Vec<String> = expected.iter().map(|e| normalize_sep(e)).collect();
    retrieved
        .iter()
        .map(|r| normalize_sep(r))
        .position(|r| expected.iter().any(|e| paths_match(&r, e)))
        .map_or(0.0, |rank| 1.0 / (rank as f64 + 1.0))
}
275
276fn build_category_scores(results: &[EvalResult]) -> Vec<CategoryScore> {
277    use std::collections::HashMap;
278    let mut cat_map: HashMap<&str, Vec<&EvalResult>> = HashMap::new();
279    for r in results {
280        cat_map.entry(r.category.as_str()).or_default().push(r);
281    }
282
283    let mut scores: Vec<CategoryScore> = cat_map
284        .into_iter()
285        .map(|(cat, items)| {
286            let n = items.len();
287            CategoryScore {
288                category: cat.to_string(),
289                count: n,
290                avg_recall_at_5: items.iter().map(|r| r.recall_at_5).sum::<f64>() / n as f64,
291                avg_mrr: items.iter().map(|r| r.mrr).sum::<f64>() / n as f64,
292            }
293        })
294        .collect();
295    scores.sort_by(|a, b| a.category.cmp(&b.category));
296    scores
297}
298
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds an owned path list from string literals.
    fn paths(items: &[&str]) -> Vec<String> {
        items.iter().map(|item| (*item).to_string()).collect()
    }

    #[test]
    fn recall_at_k_full_match() {
        let got = recall_at_k(&paths(&["a.rs", "b.rs", "c.rs"]), &paths(&["a.rs"]), 5);
        assert_eq!(got, 1.0);
    }

    #[test]
    fn recall_at_k_matches_across_path_separators() {
        // Retrieved paths may carry the OS separator (backslash on Windows)
        // while fixtures are written with '/'; both spellings must match.
        let retrieved = paths(&["proj\\src\\auth.rs", "proj\\src\\db.rs"]);
        let expected = paths(&["src/auth.rs"]);
        assert_eq!(recall_at_k(&retrieved, &expected, 5), 1.0);
        assert_eq!(mean_reciprocal_rank(&retrieved, &expected), 1.0);
    }

    #[test]
    fn recall_at_k_no_match() {
        let got = recall_at_k(&paths(&["x.rs", "y.rs"]), &paths(&["a.rs"]), 5);
        assert_eq!(got, 0.0);
    }

    #[test]
    fn recall_at_k_partial() {
        // Only one of the two expected files is retrieved.
        let got = recall_at_k(&paths(&["a.rs", "x.rs"]), &paths(&["a.rs", "b.rs"]), 5);
        assert_eq!(got, 0.5);
    }

    #[test]
    fn mrr_first_hit() {
        let got = mean_reciprocal_rank(&paths(&["a.rs", "b.rs"]), &paths(&["a.rs"]));
        assert_eq!(got, 1.0);
    }

    #[test]
    fn mrr_second_hit() {
        let got = mean_reciprocal_rank(&paths(&["x.rs", "a.rs"]), &paths(&["a.rs"]));
        assert_eq!(got, 0.5);
    }

    #[test]
    fn mrr_no_hit() {
        let got = mean_reciprocal_rank(&paths(&["x.rs"]), &paths(&["a.rs"]));
        assert_eq!(got, 0.0);
    }

    #[test]
    fn empty_expected() {
        assert_eq!(recall_at_k(&paths(&["a.rs"]), &[], 5), 0.0);
    }

    #[test]
    fn scorecard_display() {
        let card = EvalScorecard {
            project: "test".into(),
            total_queries: 10,
            avg_recall_at_5: 0.8,
            avg_recall_at_10: 0.9,
            avg_mrr: 0.75,
            avg_latency_us: 100,
            per_category: vec![],
            results: vec![],
        };
        let rendered = card.to_string();
        assert!(rendered.contains("80.0%"));
        assert!(rendered.contains("0.750"));
    }
}