Skip to main content

lean_ctx/core/
eval_harness.rs

//! Retrieval evaluation harness for lean-ctx hybrid search.
//!
//! Runs a standardized query → expected-file benchmark to measure Recall@k,
//! MRR (Mean Reciprocal Rank), and latency. Outputs NDJSON scorecards.
//!
//! Usage: `lean-ctx benchmark --eval [path]`

8use std::path::Path;
9use std::time::Instant;
10
11use crate::core::bm25_index::BM25Index;
12use crate::core::hybrid_search::HybridConfig;
13
/// One benchmark case: a search query plus the file(s) a correct retrieval
/// should return.
#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
pub struct EvalQuery {
    /// Query text handed to the search pipeline.
    pub query: String,
    /// Paths of the files that should be retrieved; compared against the
    /// retrieved file paths by the metric functions in this module.
    pub expected_files: Vec<String>,
    /// Free-form grouping label used for per-category scores. Falls back to an
    /// empty string when the field is absent from the deserialized input.
    #[serde(default)]
    pub category: String,
}
21
/// Per-query outcome of an evaluation run.
#[derive(Debug, Clone, serde::Serialize)]
pub struct EvalResult {
    /// Query text that was searched.
    pub query: String,
    /// Category copied from the originating `EvalQuery`.
    pub category: String,
    /// Fraction of `expected_files` found among the top 5 retrieved paths.
    pub recall_at_5: f64,
    /// Fraction of `expected_files` found among the top 10 retrieved paths.
    pub recall_at_10: f64,
    /// Reciprocal rank (1 / rank) of the first retrieved path that matches any
    /// expected file, or 0.0 when none matches. This is a single-query value;
    /// the mean is taken in the scorecard.
    pub mrr: f64,
    /// Wall-clock time of the search call, in microseconds.
    pub latency_us: u64,
    /// At most the first 10 retrieved file paths, in rank order.
    pub retrieved_files: Vec<String>,
    /// Files the query was expected to retrieve.
    pub expected_files: Vec<String>,
}
33
/// Aggregate result of evaluating one project against a query set.
#[derive(Debug, Clone, serde::Serialize)]
pub struct EvalScorecard {
    /// Project label: the final component of the project root path, or
    /// `"unknown"` when it has none (or it is not valid UTF-8).
    pub project: String,
    /// Number of queries that were evaluated.
    pub total_queries: usize,
    /// Mean Recall@5 over all queries (0 when there are no queries).
    pub avg_recall_at_5: f64,
    /// Mean Recall@10 over all queries (0 when there are no queries).
    pub avg_recall_at_10: f64,
    /// Mean reciprocal rank over all queries (0 when there are no queries).
    pub avg_mrr: f64,
    /// Integer mean of per-query latency in microseconds (rounded down).
    pub avg_latency_us: u64,
    /// Per-category aggregates, sorted by category name.
    pub per_category: Vec<CategoryScore>,
    /// Individual per-query results, in input order.
    pub results: Vec<EvalResult>,
}
45
/// Metrics aggregated over all queries that share one category label.
#[derive(Debug, Clone, serde::Serialize)]
pub struct CategoryScore {
    /// The shared category label.
    pub category: String,
    /// Number of queries in this category.
    pub count: usize,
    /// Mean Recall@5 across the category's queries.
    pub avg_recall_at_5: f64,
    /// Mean reciprocal rank across the category's queries.
    pub avg_mrr: f64,
}
53
54impl std::fmt::Display for EvalScorecard {
55    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
56        writeln!(f, "Eval: {} ({} queries)", self.project, self.total_queries)?;
57        writeln!(f, "  R@5:  {:.1}%", self.avg_recall_at_5 * 100.0)?;
58        writeln!(f, "  R@10: {:.1}%", self.avg_recall_at_10 * 100.0)?;
59        writeln!(f, "  MRR:  {:.3}", self.avg_mrr)?;
60        writeln!(f, "  Latency: {}µs avg", self.avg_latency_us)?;
61        for cat in &self.per_category {
62            writeln!(
63                f,
64                "  [{:12}] R@5={:.1}% MRR={:.3} (n={})",
65                cat.category,
66                cat.avg_recall_at_5 * 100.0,
67                cat.avg_mrr,
68                cat.count
69            )?;
70        }
71        Ok(())
72    }
73}
74
75/// Run evaluation using the full hybrid search pipeline (BM25 + embeddings + SPLADE).
76/// Falls back to BM25-only if embeddings are not available.
77pub fn run_eval(
78    project_root: &Path,
79    queries: &[EvalQuery],
80    index: &BM25Index,
81    config: &HybridConfig,
82) -> EvalScorecard {
83    let label = project_root
84        .file_name()
85        .and_then(|s| s.to_str())
86        .unwrap_or("unknown")
87        .to_string();
88
89    let mut results = Vec::with_capacity(queries.len());
90
91    for q in queries {
92        let start = Instant::now();
93        let retrieved = hybrid_eval_search(project_root, &q.query, index, config);
94        let latency = start.elapsed().as_micros() as u64;
95
96        let recall_5 = recall_at_k(&retrieved, &q.expected_files, 5);
97        let recall_10 = recall_at_k(&retrieved, &q.expected_files, 10);
98        let mrr = mean_reciprocal_rank(&retrieved, &q.expected_files);
99
100        results.push(EvalResult {
101            query: q.query.clone(),
102            category: q.category.clone(),
103            recall_at_5: recall_5,
104            recall_at_10: recall_10,
105            mrr,
106            latency_us: latency,
107            retrieved_files: retrieved.into_iter().take(10).collect(),
108            expected_files: q.expected_files.clone(),
109        });
110    }
111
112    let total = results.len();
113    let avg_r5 = results.iter().map(|r| r.recall_at_5).sum::<f64>() / total.max(1) as f64;
114    let avg_r10 = results.iter().map(|r| r.recall_at_10).sum::<f64>() / total.max(1) as f64;
115    let avg_mrr = results.iter().map(|r| r.mrr).sum::<f64>() / total.max(1) as f64;
116    let avg_lat = results.iter().map(|r| r.latency_us).sum::<u64>() / total.max(1) as u64;
117
118    let per_category = build_category_scores(&results);
119
120    EvalScorecard {
121        project: label,
122        total_queries: total,
123        avg_recall_at_5: avg_r5,
124        avg_recall_at_10: avg_r10,
125        avg_mrr,
126        avg_latency_us: avg_lat,
127        per_category,
128        results,
129    }
130}
131
132/// Full hybrid search for eval: BM25 + dense embeddings + SPLADE + RRF.
133/// Falls back to BM25-only when embeddings are unavailable.
134fn hybrid_eval_search(
135    project_root: &Path,
136    query: &str,
137    index: &BM25Index,
138    config: &HybridConfig,
139) -> Vec<String> {
140    #[cfg(feature = "embeddings")]
141    {
142        if let Ok(results) = try_hybrid_search(project_root, query, index, config) {
143            return results;
144        }
145    }
146    let _ = project_root;
147    index
148        .search(query, config.bm25_candidates)
149        .iter()
150        .map(|r| r.file_path.clone())
151        .collect()
152}
153
/// Run the full hybrid retrieval pipeline for one query and return at most 10
/// file paths in rank order.
///
/// Steps, in order:
/// 1. Load the embedding engine and its index for `project_root`.
/// 2. Bring the embeddings up to date for the BM25 chunks.
/// 3. Pick the dense backend from the environment.
/// 4. Fuse BM25 and dense results.
/// 5. Optionally boost the fused results with SPLADE hits when
///    `config.splade_weight` is positive.
///
/// # Errors
/// Any failure in steps 1–4 is propagated as the `String` error, which lets the
/// caller fall back to BM25-only retrieval.
#[cfg(feature = "embeddings")]
fn try_hybrid_search(
    project_root: &Path,
    query: &str,
    index: &BM25Index,
    config: &HybridConfig,
) -> Result<Vec<String>, String> {
    use crate::core::dense_backend::{self, DenseBackendKind};
    use crate::core::splade_retrieval;
    use crate::tools::ctx_semantic_search as semantic;

    let (engine, mut embedding_index) = semantic::load_engine_and_index_pub(project_root)?;

    let (aligned, _coverage, changed_files) =
        semantic::ensure_embeddings_for_eval(project_root, index, engine, &mut embedding_index)?;

    let backend = DenseBackendKind::try_from_env()?;
    // Retrieve enough candidates to satisfy whichever source asks for more.
    let pool_size = config.bm25_candidates.max(config.dense_candidates);

    let mut ranked = dense_backend::hybrid_results(
        backend,
        project_root,
        index,
        engine,
        &aligned,
        &changed_files,
        query,
        pool_size,
        config,
        None,
        None,
    )?;

    if config.splade_weight > 0.0 {
        let splade_hits = splade_retrieval::hybrid_retrieve(query, index, pool_size);
        if !splade_hits.is_empty() {
            semantic::boost_with_splade_pub(&mut ranked, &splade_hits, config.splade_weight);
        }
    }

    Ok(ranked
        .iter()
        .take(10)
        .map(|hit| hit.file_path.clone())
        .collect())
}
200
201/// Generate self-eval queries from an indexed codebase.
202/// Picks random symbols/files and constructs retrieval queries.
203pub fn generate_self_eval(index: &BM25Index, max_queries: usize) -> Vec<EvalQuery> {
204    let mut queries = Vec::new();
205
206    for chunk in index.chunks.iter().take(max_queries * 2) {
207        if queries.len() >= max_queries {
208            break;
209        }
210        if chunk.symbol_name.is_empty() || chunk.file_path.is_empty() {
211            continue;
212        }
213
214        let category = if chunk.symbol_name.starts_with("fn ") || chunk.symbol_name.contains("()") {
215            "function"
216        } else if chunk.symbol_name.starts_with("struct ")
217            || chunk.symbol_name.starts_with("class ")
218        {
219            "type"
220        } else {
221            "symbol"
222        };
223
224        let clean_name = chunk
225            .symbol_name
226            .replace("fn ", "")
227            .replace("struct ", "")
228            .replace("class ", "")
229            .replace("()", "");
230
231        queries.push(EvalQuery {
232            query: format!("where is {clean_name} defined"),
233            expected_files: vec![chunk.file_path.clone()],
234            category: category.to_string(),
235        });
236    }
237
238    queries
239}
240
/// Returns `true` when `a` and `b` refer to the same file.
///
/// Either path may be a path-component-aligned suffix of the other, so a
/// repo-relative path matches an absolute or longer one (e.g. `core/a.rs` vs
/// `src/core/a.rs`). Plain string suffix matching is not enough, because
/// `src/data.rs` ends with `a.rs` but is a different file. Empty paths never match.
fn paths_match(a: &str, b: &str) -> bool {
    if a.is_empty() || b.is_empty() {
        return false;
    }
    is_component_suffix(a, b) || is_component_suffix(b, a)
}

/// `true` if `path` ends with `suffix` and the cut falls on a path boundary.
///
/// A boundary is one of: the whole string, a `/` or `\` separator just before
/// the cut, or a `suffix` that itself starts with a separator.
fn is_component_suffix(path: &str, suffix: &str) -> bool {
    let is_sep = |c: char| c == '/' || c == '\\';
    match path.strip_suffix(suffix) {
        Some(rest) => rest.is_empty() || rest.ends_with(is_sep) || suffix.starts_with(is_sep),
        None => false,
    }
}

/// Recall@k: the fraction of `expected` files found among the first `k`
/// entries of `retrieved`. Returns 0.0 when `expected` is empty.
fn recall_at_k(retrieved: &[String], expected: &[String], k: usize) -> f64 {
    if expected.is_empty() {
        return 0.0;
    }
    let hits = expected
        .iter()
        .filter(|e| {
            retrieved
                .iter()
                .take(k)
                .any(|r| paths_match(r.as_str(), e.as_str()))
        })
        .count();
    hits as f64 / expected.len() as f64
}

/// Reciprocal rank for a single query: `1 / rank` of the first retrieved path
/// that matches any expected file (ranks start at 1), or 0.0 if none matches.
/// The mean over queries is taken by the caller.
fn mean_reciprocal_rank(retrieved: &[String], expected: &[String]) -> f64 {
    retrieved
        .iter()
        .position(|r| expected.iter().any(|e| paths_match(r.as_str(), e.as_str())))
        .map_or(0.0, |rank| 1.0 / (rank as f64 + 1.0))
}
272
273fn build_category_scores(results: &[EvalResult]) -> Vec<CategoryScore> {
274    use std::collections::HashMap;
275    let mut cat_map: HashMap<&str, Vec<&EvalResult>> = HashMap::new();
276    for r in results {
277        cat_map.entry(r.category.as_str()).or_default().push(r);
278    }
279
280    let mut scores: Vec<CategoryScore> = cat_map
281        .into_iter()
282        .map(|(cat, items)| {
283            let n = items.len();
284            CategoryScore {
285                category: cat.to_string(),
286                count: n,
287                avg_recall_at_5: items.iter().map(|r| r.recall_at_5).sum::<f64>() / n as f64,
288                avg_mrr: items.iter().map(|r| r.mrr).sum::<f64>() / n as f64,
289            }
290        })
291        .collect();
292    scores.sort_by(|a, b| a.category.cmp(&b.category));
293    scores
294}
295
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds an owned list of paths from string literals.
    fn paths(names: &[&str]) -> Vec<String> {
        names.iter().map(|name| (*name).to_owned()).collect()
    }

    #[test]
    fn recall_at_k_full_match() {
        let score = recall_at_k(&paths(&["a.rs", "b.rs", "c.rs"]), &paths(&["a.rs"]), 5);
        assert_eq!(score, 1.0);
    }

    #[test]
    fn recall_at_k_no_match() {
        let score = recall_at_k(&paths(&["x.rs", "y.rs"]), &paths(&["a.rs"]), 5);
        assert_eq!(score, 0.0);
    }

    #[test]
    fn recall_at_k_partial() {
        let score = recall_at_k(&paths(&["a.rs", "x.rs"]), &paths(&["a.rs", "b.rs"]), 5);
        assert_eq!(score, 0.5);
    }

    #[test]
    fn mrr_first_hit() {
        let rank = mean_reciprocal_rank(&paths(&["a.rs", "b.rs"]), &paths(&["a.rs"]));
        assert_eq!(rank, 1.0);
    }

    #[test]
    fn mrr_second_hit() {
        let rank = mean_reciprocal_rank(&paths(&["x.rs", "a.rs"]), &paths(&["a.rs"]));
        assert_eq!(rank, 0.5);
    }

    #[test]
    fn mrr_no_hit() {
        let rank = mean_reciprocal_rank(&paths(&["x.rs"]), &paths(&["a.rs"]));
        assert_eq!(rank, 0.0);
    }

    #[test]
    fn empty_expected() {
        assert_eq!(recall_at_k(&paths(&["a.rs"]), &[], 5), 0.0);
    }

    #[test]
    fn scorecard_display() {
        let card = EvalScorecard {
            project: "test".into(),
            total_queries: 10,
            avg_recall_at_5: 0.8,
            avg_recall_at_10: 0.9,
            avg_mrr: 0.75,
            avg_latency_us: 100,
            per_category: Vec::new(),
            results: Vec::new(),
        };
        let rendered = card.to_string();
        assert!(rendered.contains("80.0%"));
        assert!(rendered.contains("0.750"));
    }
}