1use std::path::Path;
9use std::time::Instant;
10
11use crate::core::bm25_index::BM25Index;
12use crate::core::hybrid_search::HybridConfig;
13
14#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
15pub struct EvalQuery {
16 pub query: String,
17 pub expected_files: Vec<String>,
18 #[serde(default)]
19 pub category: String,
20}
21
22#[derive(Debug, Clone, serde::Serialize)]
23pub struct EvalResult {
24 pub query: String,
25 pub category: String,
26 pub recall_at_5: f64,
27 pub recall_at_10: f64,
28 pub mrr: f64,
29 pub latency_us: u64,
30 pub retrieved_files: Vec<String>,
31 pub expected_files: Vec<String>,
32}
33
34#[derive(Debug, Clone, serde::Serialize)]
35pub struct EvalScorecard {
36 pub project: String,
37 pub total_queries: usize,
38 pub avg_recall_at_5: f64,
39 pub avg_recall_at_10: f64,
40 pub avg_mrr: f64,
41 pub avg_latency_us: u64,
42 pub per_category: Vec<CategoryScore>,
43 pub results: Vec<EvalResult>,
44}
45
46#[derive(Debug, Clone, serde::Serialize)]
47pub struct CategoryScore {
48 pub category: String,
49 pub count: usize,
50 pub avg_recall_at_5: f64,
51 pub avg_mrr: f64,
52}
53
54impl std::fmt::Display for EvalScorecard {
55 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
56 writeln!(f, "Eval: {} ({} queries)", self.project, self.total_queries)?;
57 writeln!(f, " R@5: {:.1}%", self.avg_recall_at_5 * 100.0)?;
58 writeln!(f, " R@10: {:.1}%", self.avg_recall_at_10 * 100.0)?;
59 writeln!(f, " MRR: {:.3}", self.avg_mrr)?;
60 writeln!(f, " Latency: {}µs avg", self.avg_latency_us)?;
61 for cat in &self.per_category {
62 writeln!(
63 f,
64 " [{:12}] R@5={:.1}% MRR={:.3} (n={})",
65 cat.category,
66 cat.avg_recall_at_5 * 100.0,
67 cat.avg_mrr,
68 cat.count
69 )?;
70 }
71 Ok(())
72 }
73}
74
75pub fn run_eval(
78 project_root: &Path,
79 queries: &[EvalQuery],
80 index: &BM25Index,
81 config: &HybridConfig,
82) -> EvalScorecard {
83 let label = project_root
84 .file_name()
85 .and_then(|s| s.to_str())
86 .unwrap_or("unknown")
87 .to_string();
88
89 let mut results = Vec::with_capacity(queries.len());
90
91 for q in queries {
92 let start = Instant::now();
93 let retrieved = hybrid_eval_search(project_root, &q.query, index, config);
94 let latency = start.elapsed().as_micros() as u64;
95
96 let recall_5 = recall_at_k(&retrieved, &q.expected_files, 5);
97 let recall_10 = recall_at_k(&retrieved, &q.expected_files, 10);
98 let mrr = mean_reciprocal_rank(&retrieved, &q.expected_files);
99
100 results.push(EvalResult {
101 query: q.query.clone(),
102 category: q.category.clone(),
103 recall_at_5: recall_5,
104 recall_at_10: recall_10,
105 mrr,
106 latency_us: latency,
107 retrieved_files: retrieved.into_iter().take(10).collect(),
108 expected_files: q.expected_files.clone(),
109 });
110 }
111
112 let total = results.len();
113 let avg_r5 = results.iter().map(|r| r.recall_at_5).sum::<f64>() / total.max(1) as f64;
114 let avg_r10 = results.iter().map(|r| r.recall_at_10).sum::<f64>() / total.max(1) as f64;
115 let avg_mrr = results.iter().map(|r| r.mrr).sum::<f64>() / total.max(1) as f64;
116 let avg_lat = results.iter().map(|r| r.latency_us).sum::<u64>() / total.max(1) as u64;
117
118 let per_category = build_category_scores(&results);
119
120 EvalScorecard {
121 project: label,
122 total_queries: total,
123 avg_recall_at_5: avg_r5,
124 avg_recall_at_10: avg_r10,
125 avg_mrr,
126 avg_latency_us: avg_lat,
127 per_category,
128 results,
129 }
130}
131
132fn hybrid_eval_search(
135 project_root: &Path,
136 query: &str,
137 index: &BM25Index,
138 config: &HybridConfig,
139) -> Vec<String> {
140 #[cfg(feature = "embeddings")]
141 {
142 if let Ok(results) = try_hybrid_search(project_root, query, index, config) {
143 return results;
144 }
145 }
146 let _ = project_root;
147 index
148 .search(query, config.bm25_candidates)
149 .iter()
150 .map(|r| r.file_path.clone())
151 .collect()
152}
153
/// Runs the hybrid retrieval pipeline and returns up to ten file paths.
///
/// Steps, in order: load the embedding engine and index, make sure embeddings
/// exist for the BM25 index, pick the dense backend from the environment, run
/// the hybrid ranking, and, when `config.splade_weight` is positive and SPLADE
/// returns anything, apply a SPLADE boost to the ranking.
///
/// # Errors
///
/// Propagates (as `String`) any failure from loading the engine, ensuring
/// embeddings, selecting the backend, or computing the hybrid results.
#[cfg(feature = "embeddings")]
fn try_hybrid_search(
    project_root: &Path,
    query: &str,
    index: &BM25Index,
    config: &HybridConfig,
) -> Result<Vec<String>, String> {
    use crate::core::dense_backend;
    use crate::tools::ctx_semantic_search;

    let (engine, mut embedding_index) =
        ctx_semantic_search::load_engine_and_index_pub(project_root)?;

    let (aligned_vectors, _coverage, stale_files) =
        ctx_semantic_search::ensure_embeddings_for_eval(
            project_root,
            index,
            engine,
            &mut embedding_index,
        )?;

    let backend_kind = dense_backend::DenseBackendKind::try_from_env()?;
    // Use the larger of the two candidate budgets for every retrieval stage.
    let candidate_k = config.bm25_candidates.max(config.dense_candidates);

    let mut ranked = dense_backend::hybrid_results(
        backend_kind,
        project_root,
        index,
        engine,
        &aligned_vectors,
        &stale_files,
        query,
        candidate_k,
        config,
        None,
        None,
    )?;

    if config.splade_weight > 0.0 {
        let sparse_hits = crate::core::splade_retrieval::hybrid_retrieve(query, index, candidate_k);
        if !sparse_hits.is_empty() {
            ctx_semantic_search::boost_with_splade_pub(
                &mut ranked,
                &sparse_hits,
                config.splade_weight,
            );
        }
    }

    // Only the ten best entries are reported.
    let top_paths = ranked
        .iter()
        .take(10)
        .map(|entry| entry.file_path.clone())
        .collect();
    Ok(top_paths)
}
200
201pub fn generate_self_eval(index: &BM25Index, max_queries: usize) -> Vec<EvalQuery> {
204 let mut queries = Vec::new();
205
206 for chunk in index.chunks.iter().take(max_queries * 2) {
207 if queries.len() >= max_queries {
208 break;
209 }
210 if chunk.symbol_name.is_empty() || chunk.file_path.is_empty() {
211 continue;
212 }
213
214 let category = if chunk.symbol_name.starts_with("fn ") || chunk.symbol_name.contains("()") {
215 "function"
216 } else if chunk.symbol_name.starts_with("struct ")
217 || chunk.symbol_name.starts_with("class ")
218 {
219 "type"
220 } else {
221 "symbol"
222 };
223
224 let clean_name = chunk
225 .symbol_name
226 .replace("fn ", "")
227 .replace("struct ", "")
228 .replace("class ", "")
229 .replace("()", "");
230
231 queries.push(EvalQuery {
232 query: format!("where is {clean_name} defined"),
233 expected_files: vec![chunk.file_path.clone()],
234 category: category.to_string(),
235 });
236 }
237
238 queries
239}
240
/// Reports whether two file paths refer to the same file, allowing one of
/// them to be a trailing portion of the other.
///
/// The shorter path must be a suffix of the longer one *and* begin on a path
/// component boundary (`/` or `\`), so `a.rs` matches `src/a.rs` but not
/// `src/data.rs`. Empty paths never match anything.
fn path_suffix_match(a: &str, b: &str) -> bool {
    let (longer, shorter) = if a.len() >= b.len() { (a, b) } else { (b, a) };
    if shorter.is_empty() {
        return false;
    }
    let Some(prefix) = longer.strip_suffix(shorter) else {
        return false;
    };
    let is_sep = |c: char| c == '/' || c == '\\';
    // Identical paths, a prefix ending in a separator, or a suffix that itself
    // starts with a separator all mean the match starts at a component boundary.
    prefix.is_empty() || prefix.ends_with(is_sep) || shorter.starts_with(is_sep)
}

/// Fraction of `expected` files that appear among the first `k` entries of
/// `retrieved`.
///
/// Paths are compared with component-aware suffix matching, so relative and
/// absolute spellings of the same file count as equal. Returns `0.0` when
/// `expected` is empty.
fn recall_at_k(retrieved: &[String], expected: &[String], k: usize) -> f64 {
    if expected.is_empty() {
        return 0.0;
    }
    let top_k = &retrieved[..retrieved.len().min(k)];
    let hits = expected
        .iter()
        .filter(|want| top_k.iter().any(|got| path_suffix_match(got, want)))
        .count();
    hits as f64 / expected.len() as f64
}

/// Reciprocal rank (`1 / position`, 1-based) of the first retrieved path that
/// matches any expected file, or `0.0` when nothing matches.
///
/// Uses the same component-aware suffix matching as `recall_at_k`.
fn mean_reciprocal_rank(retrieved: &[String], expected: &[String]) -> f64 {
    retrieved
        .iter()
        .position(|got| expected.iter().any(|want| path_suffix_match(got, want)))
        .map_or(0.0, |rank| 1.0 / (rank as f64 + 1.0))
}
272
273fn build_category_scores(results: &[EvalResult]) -> Vec<CategoryScore> {
274 use std::collections::HashMap;
275 let mut cat_map: HashMap<&str, Vec<&EvalResult>> = HashMap::new();
276 for r in results {
277 cat_map.entry(r.category.as_str()).or_default().push(r);
278 }
279
280 let mut scores: Vec<CategoryScore> = cat_map
281 .into_iter()
282 .map(|(cat, items)| {
283 let n = items.len();
284 CategoryScore {
285 category: cat.to_string(),
286 count: n,
287 avg_recall_at_5: items.iter().map(|r| r.recall_at_5).sum::<f64>() / n as f64,
288 avg_mrr: items.iter().map(|r| r.mrr).sum::<f64>() / n as f64,
289 }
290 })
291 .collect();
292 scores.sort_by(|a, b| a.category.cmp(&b.category));
293 scores
294}
295
#[cfg(test)]
mod tests {
    use super::*;

    /// Turns string literals into the owned path list the metric functions take.
    fn paths(items: &[&str]) -> Vec<String> {
        items.iter().map(|s| (*s).to_string()).collect()
    }

    // The only expected file is present in the top k.
    #[test]
    fn recall_at_k_full_match() {
        let got = paths(&["a.rs", "b.rs", "c.rs"]);
        let want = paths(&["a.rs"]);
        assert_eq!(recall_at_k(&got, &want, 5), 1.0);
    }

    // None of the retrieved files is expected.
    #[test]
    fn recall_at_k_no_match() {
        let got = paths(&["x.rs", "y.rs"]);
        let want = paths(&["a.rs"]);
        assert_eq!(recall_at_k(&got, &want, 5), 0.0);
    }

    // One of two expected files is retrieved.
    #[test]
    fn recall_at_k_partial() {
        let got = paths(&["a.rs", "x.rs"]);
        let want = paths(&["a.rs", "b.rs"]);
        assert_eq!(recall_at_k(&got, &want, 5), 0.5);
    }

    // A hit at rank 1 gives a reciprocal rank of 1.
    #[test]
    fn mrr_first_hit() {
        let got = paths(&["a.rs", "b.rs"]);
        let want = paths(&["a.rs"]);
        assert_eq!(mean_reciprocal_rank(&got, &want), 1.0);
    }

    // A hit at rank 2 gives a reciprocal rank of 1/2.
    #[test]
    fn mrr_second_hit() {
        let got = paths(&["x.rs", "a.rs"]);
        let want = paths(&["a.rs"]);
        assert_eq!(mean_reciprocal_rank(&got, &want), 0.5);
    }

    // No hit at all gives zero.
    #[test]
    fn mrr_no_hit() {
        let got = paths(&["x.rs"]);
        let want = paths(&["a.rs"]);
        assert_eq!(mean_reciprocal_rank(&got, &want), 0.0);
    }

    // An empty expectation list yields zero recall.
    #[test]
    fn empty_expected() {
        let got = paths(&["a.rs"]);
        assert_eq!(recall_at_k(&got, &[], 5), 0.0);
    }

    // The rendered scorecard shows recall as a percentage and MRR to three decimals.
    #[test]
    fn scorecard_display() {
        let card = EvalScorecard {
            project: String::from("test"),
            total_queries: 10,
            avg_recall_at_5: 0.8,
            avg_recall_at_10: 0.9,
            avg_mrr: 0.75,
            avg_latency_us: 100,
            per_category: Vec::new(),
            results: Vec::new(),
        };
        let rendered = card.to_string();
        assert!(rendered.contains("80.0%"));
        assert!(rendered.contains("0.750"));
    }
}