1use std::path::Path;
9use std::time::Instant;
10
11use crate::core::bm25_index::BM25Index;
12use crate::core::hybrid_search::HybridConfig;
13
14#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
15pub struct EvalQuery {
16 pub query: String,
17 pub expected_files: Vec<String>,
18 #[serde(default)]
19 pub category: String,
20}
21
22#[derive(Debug, Clone, serde::Serialize)]
23pub struct EvalResult {
24 pub query: String,
25 pub category: String,
26 pub recall_at_5: f64,
27 pub recall_at_10: f64,
28 pub mrr: f64,
29 pub latency_us: u64,
30 pub retrieved_files: Vec<String>,
31 pub expected_files: Vec<String>,
32}
33
34#[derive(Debug, Clone, serde::Serialize)]
35pub struct EvalScorecard {
36 pub project: String,
37 pub total_queries: usize,
38 pub avg_recall_at_5: f64,
39 pub avg_recall_at_10: f64,
40 pub avg_mrr: f64,
41 pub avg_latency_us: u64,
42 pub per_category: Vec<CategoryScore>,
43 pub results: Vec<EvalResult>,
44}
45
46#[derive(Debug, Clone, serde::Serialize)]
47pub struct CategoryScore {
48 pub category: String,
49 pub count: usize,
50 pub avg_recall_at_5: f64,
51 pub avg_mrr: f64,
52}
53
54impl std::fmt::Display for EvalScorecard {
55 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
56 writeln!(f, "Eval: {} ({} queries)", self.project, self.total_queries)?;
57 writeln!(f, " R@5: {:.1}%", self.avg_recall_at_5 * 100.0)?;
58 writeln!(f, " R@10: {:.1}%", self.avg_recall_at_10 * 100.0)?;
59 writeln!(f, " MRR: {:.3}", self.avg_mrr)?;
60 writeln!(f, " Latency: {}µs avg", self.avg_latency_us)?;
61 for cat in &self.per_category {
62 writeln!(
63 f,
64 " [{:12}] R@5={:.1}% MRR={:.3} (n={})",
65 cat.category,
66 cat.avg_recall_at_5 * 100.0,
67 cat.avg_mrr,
68 cat.count
69 )?;
70 }
71 Ok(())
72 }
73}
74
75pub fn run_eval(
78 project_root: &Path,
79 queries: &[EvalQuery],
80 index: &BM25Index,
81 config: &HybridConfig,
82) -> EvalScorecard {
83 let label = project_root
84 .file_name()
85 .and_then(|s| s.to_str())
86 .unwrap_or("unknown")
87 .to_string();
88
89 let mut results = Vec::with_capacity(queries.len());
90
91 for q in queries {
92 let start = Instant::now();
93 let retrieved = hybrid_eval_search(project_root, &q.query, index, config);
94 let latency = start.elapsed().as_micros() as u64;
95
96 let recall_5 = recall_at_k(&retrieved, &q.expected_files, 5);
97 let recall_10 = recall_at_k(&retrieved, &q.expected_files, 10);
98 let mrr = mean_reciprocal_rank(&retrieved, &q.expected_files);
99
100 results.push(EvalResult {
101 query: q.query.clone(),
102 category: q.category.clone(),
103 recall_at_5: recall_5,
104 recall_at_10: recall_10,
105 mrr,
106 latency_us: latency,
107 retrieved_files: retrieved.into_iter().take(10).collect(),
108 expected_files: q.expected_files.clone(),
109 });
110 }
111
112 let total = results.len();
113 let avg_r5 = results.iter().map(|r| r.recall_at_5).sum::<f64>() / total.max(1) as f64;
114 let avg_r10 = results.iter().map(|r| r.recall_at_10).sum::<f64>() / total.max(1) as f64;
115 let avg_mrr = results.iter().map(|r| r.mrr).sum::<f64>() / total.max(1) as f64;
116 let avg_lat = results.iter().map(|r| r.latency_us).sum::<u64>() / total.max(1) as u64;
117
118 let per_category = build_category_scores(&results);
119
120 EvalScorecard {
121 project: label,
122 total_queries: total,
123 avg_recall_at_5: avg_r5,
124 avg_recall_at_10: avg_r10,
125 avg_mrr,
126 avg_latency_us: avg_lat,
127 per_category,
128 results,
129 }
130}
131
132fn hybrid_eval_search(
135 project_root: &Path,
136 query: &str,
137 index: &BM25Index,
138 config: &HybridConfig,
139) -> Vec<String> {
140 #[cfg(feature = "embeddings")]
141 {
142 if let Ok(results) = try_hybrid_search(project_root, query, index, config) {
143 return results;
144 }
145 }
146 let _ = project_root;
147 index
148 .search(query, config.bm25_candidates)
149 .iter()
150 .map(|r| r.file_path.clone())
151 .collect()
152}
153
/// Attempts a hybrid (BM25 + dense) search and returns up to ten file paths.
///
/// Steps, in order: load the embedding engine and index for the project,
/// make sure embeddings are available for evaluation, pick the dense backend
/// from the environment, and run the hybrid retrieval with a candidate count
/// equal to the larger of `config.bm25_candidates` and
/// `config.dense_candidates`. When `config.splade_weight` is positive and
/// SPLADE retrieval returns anything, the results are boosted with it.
///
/// # Errors
///
/// Returns the error message of whichever step failed.
#[cfg(feature = "embeddings")]
fn try_hybrid_search(
    project_root: &Path,
    query: &str,
    index: &BM25Index,
    config: &HybridConfig,
) -> Result<Vec<String>, String> {
    use crate::core::dense_backend;
    use crate::tools::ctx_semantic_search as semantic;

    let (engine, mut embedding_index) = semantic::load_engine_and_index_pub(project_root)?;

    let (aligned, _coverage, changed_files) =
        semantic::ensure_embeddings_for_eval(project_root, index, engine, &mut embedding_index)?;

    let backend = dense_backend::DenseBackendKind::try_from_env()?;
    let candidate_k = config.bm25_candidates.max(config.dense_candidates);

    let mut ranked = dense_backend::hybrid_results(
        backend,
        project_root,
        index,
        engine,
        &aligned,
        &changed_files,
        query,
        candidate_k,
        config,
        None,
        None,
    )?;

    if config.splade_weight > 0.0 {
        let splade_hits = crate::core::splade_retrieval::hybrid_retrieve(query, index, candidate_k);
        if !splade_hits.is_empty() {
            semantic::boost_with_splade_pub(&mut ranked, &splade_hits, config.splade_weight);
        }
    }

    // Only the ten best-ranked entries are reported.
    Ok(ranked
        .iter()
        .take(10)
        .map(|hit| hit.file_path.clone())
        .collect())
}
200
201pub fn generate_self_eval(index: &BM25Index, max_queries: usize) -> Vec<EvalQuery> {
204 let mut queries = Vec::new();
205
206 for chunk in index.chunks.iter().take(max_queries * 2) {
207 if queries.len() >= max_queries {
208 break;
209 }
210 if chunk.symbol_name.is_empty() || chunk.file_path.is_empty() {
211 continue;
212 }
213
214 let category = if chunk.symbol_name.starts_with("fn ") || chunk.symbol_name.contains("()") {
215 "function"
216 } else if chunk.symbol_name.starts_with("struct ")
217 || chunk.symbol_name.starts_with("class ")
218 {
219 "type"
220 } else {
221 "symbol"
222 };
223
224 let clean_name = chunk
225 .symbol_name
226 .replace("fn ", "")
227 .replace("struct ", "")
228 .replace("class ", "")
229 .replace("()", "");
230
231 queries.push(EvalQuery {
232 query: format!("where is {clean_name} defined"),
233 expected_files: vec![chunk.file_path.clone()],
234 category: category.to_string(),
235 });
236 }
237
238 queries
239}
240
/// Rewrites backslash path separators as forward slashes so paths recorded on
/// different platforms can be compared as plain strings.
fn normalize_sep(p: &str) -> String {
    p.replace('\\', "/")
}

/// Returns `true` when `shorter` is a suffix of `longer` that begins on a
/// path-component boundary.
///
/// Both arguments must already use `/` separators. A plain `ends_with` is not
/// enough: it would treat `auth.rs` as matching `src/oauth.rs`. An empty
/// `shorter` never matches.
fn path_suffix_matches(longer: &str, shorter: &str) -> bool {
    if shorter.is_empty() || !longer.ends_with(shorter) {
        return false;
    }
    let cut = longer.len() - shorter.len();
    // Boundary holds when the strings are equal, when the suffix itself
    // starts with a separator, or when the byte just before it is one.
    // `/` is ASCII, so inspecting a single byte is valid in UTF-8.
    cut == 0 || shorter.starts_with('/') || longer.as_bytes()[cut - 1] == b'/'
}

/// Returns `true` when two normalized paths refer to the same file, allowing
/// either one to carry extra leading directories.
fn eval_paths_match(a: &str, b: &str) -> bool {
    path_suffix_matches(a, b) || path_suffix_matches(b, a)
}

/// Fraction of `expected` files that appear among the first `k` entries of
/// `retrieved`.
///
/// Paths are compared after separator normalization and match when one is a
/// whole-component suffix of the other. Returns `0.0` when `expected` is
/// empty.
fn recall_at_k(retrieved: &[String], expected: &[String], k: usize) -> f64 {
    if expected.is_empty() {
        return 0.0;
    }
    let top_k: Vec<String> = retrieved.iter().take(k).map(|r| normalize_sep(r)).collect();
    let hits = expected
        .iter()
        .filter(|e| {
            let e = normalize_sep(e);
            top_k.iter().any(|r| eval_paths_match(r, &e))
        })
        .count();
    hits as f64 / expected.len() as f64
}

/// Reciprocal of the 1-based rank of the first entry in `retrieved` that
/// matches any expected file, or `0.0` when nothing matches.
///
/// Uses the same path comparison as [`recall_at_k`].
fn mean_reciprocal_rank(retrieved: &[String], expected: &[String]) -> f64 {
    // Normalize the expected paths once instead of once per retrieved entry.
    let expected: Vec<String> = expected.iter().map(|e| normalize_sep(e)).collect();
    retrieved
        .iter()
        .map(|r| normalize_sep(r))
        .position(|r| expected.iter().any(|e| eval_paths_match(&r, e)))
        .map_or(0.0, |rank| 1.0 / (rank as f64 + 1.0))
}
275
276fn build_category_scores(results: &[EvalResult]) -> Vec<CategoryScore> {
277 use std::collections::HashMap;
278 let mut cat_map: HashMap<&str, Vec<&EvalResult>> = HashMap::new();
279 for r in results {
280 cat_map.entry(r.category.as_str()).or_default().push(r);
281 }
282
283 let mut scores: Vec<CategoryScore> = cat_map
284 .into_iter()
285 .map(|(cat, items)| {
286 let n = items.len();
287 CategoryScore {
288 category: cat.to_string(),
289 count: n,
290 avg_recall_at_5: items.iter().map(|r| r.recall_at_5).sum::<f64>() / n as f64,
291 avg_mrr: items.iter().map(|r| r.mrr).sum::<f64>() / n as f64,
292 }
293 })
294 .collect();
295 scores.sort_by(|a, b| a.category.cmp(&b.category));
296 scores
297}
298
#[cfg(test)]
mod tests {
    use super::*;

    /// Turns string literals into the owned path lists the metric functions take.
    fn paths(items: &[&str]) -> Vec<String> {
        items.iter().map(|item| item.to_string()).collect()
    }

    #[test]
    fn recall_at_k_full_match() {
        let got = paths(&["a.rs", "b.rs", "c.rs"]);
        let want = paths(&["a.rs"]);
        assert_eq!(recall_at_k(&got, &want, 5), 1.0);
    }

    #[test]
    fn recall_at_k_matches_across_path_separators() {
        // Retrieved paths use backslashes and an extra leading directory.
        let got = paths(&["proj\\src\\auth.rs", "proj\\src\\db.rs"]);
        let want = paths(&["src/auth.rs"]);
        assert_eq!(recall_at_k(&got, &want, 5), 1.0);
        assert_eq!(mean_reciprocal_rank(&got, &want), 1.0);
    }

    #[test]
    fn recall_at_k_no_match() {
        let got = paths(&["x.rs", "y.rs"]);
        let want = paths(&["a.rs"]);
        assert_eq!(recall_at_k(&got, &want, 5), 0.0);
    }

    #[test]
    fn recall_at_k_partial() {
        let got = paths(&["a.rs", "x.rs"]);
        let want = paths(&["a.rs", "b.rs"]);
        assert_eq!(recall_at_k(&got, &want, 5), 0.5);
    }

    #[test]
    fn mrr_first_hit() {
        let got = paths(&["a.rs", "b.rs"]);
        let want = paths(&["a.rs"]);
        assert_eq!(mean_reciprocal_rank(&got, &want), 1.0);
    }

    #[test]
    fn mrr_second_hit() {
        let got = paths(&["x.rs", "a.rs"]);
        let want = paths(&["a.rs"]);
        assert_eq!(mean_reciprocal_rank(&got, &want), 0.5);
    }

    #[test]
    fn mrr_no_hit() {
        let got = paths(&["x.rs"]);
        let want = paths(&["a.rs"]);
        assert_eq!(mean_reciprocal_rank(&got, &want), 0.0);
    }

    #[test]
    fn empty_expected() {
        let got = paths(&["a.rs"]);
        assert_eq!(recall_at_k(&got, &[], 5), 0.0);
    }

    #[test]
    fn scorecard_display() {
        let scorecard = EvalScorecard {
            project: String::from("test"),
            total_queries: 10,
            avg_recall_at_5: 0.8,
            avg_recall_at_10: 0.9,
            avg_mrr: 0.75,
            avg_latency_us: 100,
            per_category: Vec::new(),
            results: Vec::new(),
        };
        let rendered = scorecard.to_string();
        assert!(rendered.contains("80.0%"));
        assert!(rendered.contains("0.750"));
    }
}