// rag_rat_core/eval.rs

1use std::{
2    collections::{BTreeMap, BTreeSet},
3    fs,
4    path::{Path, PathBuf},
5    time::Instant,
6};
7
8use serde::{Deserialize, Serialize};
9
10use crate::{Config, IndexDatabase, index::ai};
11
/// Result limit handed (as `u32`) to every search, impact, commit, and
/// rationale lookup made during evaluation; it matches the `_at_10` suffix of
/// the reported metrics.
12const TOP_K: usize = 10;
13
/// Contents of the eval queries file as parsed by `load_queries`.
14#[derive(Debug, Clone, Deserialize)]
15pub struct EvalSuite {
    /// Queries to evaluate; an absent `query` key deserializes to an empty list.
16    #[serde(default)]
17    pub query: Vec<EvalQuery>,
18}
19
/// Contents of the expected-hits file as parsed by `load_expected`.
20#[derive(Debug, Clone, Deserialize)]
21pub struct ExpectedSuite {
    /// Per-query expectations; an absent `expected` key deserializes to an empty list.
22    #[serde(default)]
23    pub expected: Vec<ExpectedQuery>,
24}
25
/// One eval query plus the evidence it is expected to surface.
///
/// Every list defaults to empty when absent. An entry that is not matched in
/// any `must_include_*` or `should_include_*` list makes the query fail
/// (`evaluate_query` requires all "missing" lists to be empty).
26#[derive(Debug, Clone, Deserialize)]
27pub struct EvalQuery {
    /// Identifier used to look up the matching `ExpectedQuery`.
28    pub id: String,
    /// Text passed to the search, impact, commit, and rationale lookups.
29    pub text: String,
    /// Carried through `merge_expected` unchanged; not otherwise read in this file.
30    #[serde(default)]
31    pub evidence_class: Option<String>,
    /// When true, the query is reported as skipped if the papertrail cache is empty.
32    #[serde(default)]
33    pub requires_papertrail_cache: bool,
    /// Paths that must equal the path of at least one search hit.
34    #[serde(default)]
35    pub must_include_paths: Vec<String>,
    /// Symbols that must equal, or be a suffix of, some hit's symbol path.
36    #[serde(default)]
37    pub must_include_symbols: Vec<String>,
    /// Targets checked against each hit's graph data via `graph_hit_matches`.
38    #[serde(default)]
39    pub must_include_graph_targets: Vec<String>,
    /// Categories that must equal the category of some impact-surface item.
40    #[serde(default)]
41    pub must_include_impact_categories: Vec<String>,
    /// Paths that must equal the path of some impact-surface item.
42    #[serde(default)]
43    pub must_include_impact_paths: Vec<String>,
    /// Symbols that must equal, or be a suffix of, some impact item's symbol.
44    #[serde(default)]
45    pub must_include_impact_symbols: Vec<String>,
    /// Substrings matched case-insensitively (ASCII) against commit subjects.
46    #[serde(default)]
47    pub should_include_git_subjects: Vec<String>,
    /// Classifications compared after `normalize_kind` against rationale results.
48    #[serde(default)]
49    pub should_include_papertrail_kinds: Vec<String>,
50}
51
/// Baseline expectations for one query id.
///
/// Read from the expected-hits file, unioned into the matching `EvalQuery` by
/// `merge_expected`, and written back by `write_baseline`. The lists mirror the
/// same-named fields on `EvalQuery`; each defaults to empty when absent.
52#[derive(Debug, Clone, Deserialize, Serialize)]
53pub struct ExpectedQuery {
    /// Id of the `EvalQuery` these expectations belong to.
54    pub id: String,
55    #[serde(default)]
56    pub must_include_paths: Vec<String>,
57    #[serde(default)]
58    pub must_include_symbols: Vec<String>,
59    #[serde(default)]
60    pub must_include_graph_targets: Vec<String>,
61    #[serde(default)]
62    pub must_include_impact_categories: Vec<String>,
63    #[serde(default)]
64    pub must_include_impact_paths: Vec<String>,
65    #[serde(default)]
66    pub must_include_impact_symbols: Vec<String>,
67    #[serde(default)]
68    pub should_include_git_subjects: Vec<String>,
69    #[serde(default)]
70    pub should_include_papertrail_kinds: Vec<String>,
71}
72
/// Inputs for `run`.
73#[derive(Debug, Clone)]
74pub struct EvalOptions {
    /// TOML file holding the queries (read by `load_queries`).
75    pub queries_path: PathBuf,
    /// TOML file holding the expected hits; may be absent. It is also the file
    /// rewritten when `update_baseline` is set.
76    pub expected_path: PathBuf,
    /// When true, `run` writes the observed results to `expected_path`.
77    pub update_baseline: bool,
78}
79
/// Result of a full eval run.
80#[derive(Debug, Serialize)]
81pub struct EvalReport {
    /// True when no stale current-source violations were counted and every
    /// per-query report (skipped ones included) is marked passed.
82    pub pass: bool,
    /// Number of per-query reports in `results`, including skipped queries.
83    pub queries: usize,
    /// Metrics aggregated over the non-skipped queries of the active search mode.
84    pub metrics: EvalMetrics,
    /// The same suite evaluated with the hash-baseline search mode.
85    pub hash_vector_baseline: EvalBaselineReport,
    /// One report per query, in suite order.
86    pub results: Vec<EvalQueryReport>,
87}
88
/// Metrics of the hash-baseline search mode and how the active mode compares.
///
/// Each `delta_*` field is the active-mode metric minus the baseline metric.
89#[derive(Debug, Serialize)]
90pub struct EvalBaselineReport {
    /// Copy of `ai::HASH_MODEL_ID`.
91    pub model_id: String,
    /// True when `current_artifacts` is greater than zero.
92    pub available: bool,
    /// Value returned by `current_embedding_count` for the hash model id.
93    pub current_artifacts: u64,
    /// Aggregated metrics of the baseline run.
94    pub metrics: EvalMetrics,
95    pub delta_mrr_at_10: f64,
96    pub delta_recall_at_10: f64,
97    pub delta_path_hit_rate: f64,
98    pub delta_symbol_hit_rate: f64,
99}
100
/// Metrics aggregated by `aggregate` over all non-skipped query reports.
101#[derive(Debug, Serialize)]
102pub struct EvalMetrics {
    /// Mean of the per-query reciprocal ranks.
103    pub mrr_at_10: f64,
    /// Mean of the per-query recall values.
104    pub recall_at_10: f64,
    /// Fraction of queries with no missing paths.
105    pub path_hit_rate: f64,
    /// Fraction of queries with no missing symbols.
106    pub symbol_hit_rate: f64,
    // The four evidence rates below only count queries that had expectations
    // in that category, and are 1.0 when no query had any.
107    pub graph_evidence_hit_rate: f64,
108    pub impact_hit_rate: f64,
109    pub git_evidence_hit_rate: f64,
110    pub papertrail_evidence_hit_rate: f64,
    /// Total violations divided by the total number of top hits (0.0 when there are none).
111    pub stale_hit_rate: f64,
    /// Sum of the per-query violation counts.
112    pub stale_current_source_violations: u64,
    /// Set to the same value as `stale_current_source_violations`.
113    pub current_source_violation_count: u64,
    /// Mean of the per-query samples, or `None` when no query produced one.
114    pub papertrail_precision_sample: Option<f64>,
115    pub latency_p50_ms: f64,
116    pub latency_p95_ms: f64,
117}
118
/// Outcome of evaluating a single query.
///
/// Each `*_hits` / `missing_*` pair splits the corresponding expectation list
/// of the query into matched and unmatched entries.
119#[derive(Debug, Serialize)]
120pub struct EvalQueryReport {
121    pub id: String,
122    pub text: String,
    /// True when there are no violations and every `missing_*` list is empty;
    /// skipped reports are also marked passed.
123    pub passed: bool,
    /// True when the query was not executed; such reports are excluded from aggregation.
124    pub skipped: bool,
    /// Why the query was skipped; omitted from serialized output when `None`.
125    #[serde(skip_serializing_if = "Option::is_none")]
126    pub skip_reason: Option<String>,
    /// `1 / rank` of the first relevant hit (1-based), or 0.0 when none is relevant.
127    pub reciprocal_rank_at_10: f64,
    /// Matched paths and symbols divided by expected paths and symbols
    /// (1.0 when nothing is expected).
128    pub recall_at_10: f64,
129    pub path_hits: Vec<String>,
130    pub missing_paths: Vec<String>,
131    pub symbol_hits: Vec<String>,
132    pub missing_symbols: Vec<String>,
133    pub graph_target_hits: Vec<String>,
134    pub missing_graph_targets: Vec<String>,
135    pub impact_category_hits: Vec<String>,
136    pub missing_impact_categories: Vec<String>,
137    pub impact_path_hits: Vec<String>,
138    pub missing_impact_paths: Vec<String>,
139    pub impact_symbol_hits: Vec<String>,
140    pub missing_impact_symbols: Vec<String>,
141    pub git_subject_hits: Vec<String>,
142    pub missing_git_subjects: Vec<String>,
143    pub papertrail_kind_hits: Vec<String>,
144    pub missing_papertrail_kinds: Vec<String>,
    /// Share of rationale results whose kind was expected; `None` when no kinds were expected.
145    pub papertrail_precision_sample: Option<f64>,
    /// Length of `current_source_violations`.
146    pub stale_current_source_violations: u64,
147    pub current_source_violations: Vec<CurrentSourceViolation>,
    /// Search time in milliseconds, including the retry search when one happened.
148    pub latency_ms: f64,
    /// The search hits in rank order.
149    pub top_hits: Vec<EvalSearchHit>,
150}
151
/// Serializable summary of one search hit, built by `top_hits`.
152#[derive(Debug, Serialize)]
153pub struct EvalSearchHit {
    /// 1-based position of the hit in the search results.
154    pub rank: usize,
155    pub chunk_id: i64,
156    pub path: String,
157    pub symbol_path: Option<String>,
158    pub start_line: i64,
159    pub end_line: i64,
160    pub score: f64,
161}
162
/// A search hit whose stored chunk could not be confirmed against the file on disk.
163#[derive(Debug, Serialize)]
164pub struct CurrentSourceViolation {
    /// Chunk id of the offending hit.
165    pub chunk_id: i64,
    /// Path of the stored chunk, or of the hit when the chunk could not be read.
166    pub path: String,
    /// Human-readable description of what failed.
167    pub reason: String,
168}
169
/// Serialization wrapper used by `write_baseline`; it has the same `expected`
/// key that `ExpectedSuite` reads back.
170#[derive(Debug, Serialize)]
171struct BaselineSuite {
172    expected: Vec<ExpectedQuery>,
173}
174
175pub fn run(config: &Config, options: &EvalOptions) -> anyhow::Result<EvalReport> {
176    let suite = load_queries(&options.queries_path)?;
177    let expected = load_expected(&options.expected_path)?;
178    let db = IndexDatabase::open_config(config)?;
179    let mut results = Vec::new();
180    let mut observed = Vec::new();
181
182    for query in &suite.query {
183        let expected_query = expected.get(&query.id);
184        let merged = merge_expected(query.clone(), expected_query);
185        let report = evaluate_query(config, &db, &merged, SearchMode::Active)?;
186        observed.push(observed_expected(&report));
187        results.push(report);
188    }
189
190    if options.update_baseline {
191        write_baseline(&options.expected_path, observed)?;
192    }
193
194    let metrics = aggregate(&results);
195    let baseline = hash_vector_baseline(config, &db, &suite.query, &expected, &metrics)?;
196    let pass = metrics.stale_current_source_violations == 0 && results.iter().all(|r| r.passed);
197    Ok(EvalReport {
198        pass,
199        queries: results.len(),
200        metrics,
201        hash_vector_baseline: baseline,
202        results,
203    })
204}
205
206fn load_queries(path: &Path) -> anyhow::Result<EvalSuite> {
207    let text = fs::read_to_string(path)
208        .map_err(|err| anyhow::anyhow!("failed to read eval queries {}: {err}", path.display()))?;
209    toml::from_str(&text)
210        .map_err(|err| anyhow::anyhow!("failed to parse eval queries {}: {err}", path.display()))
211}
212
213fn load_expected(path: &Path) -> anyhow::Result<BTreeMap<String, ExpectedQuery>> {
214    if !path.exists() {
215        return Ok(BTreeMap::new());
216    }
217    let text = fs::read_to_string(path).map_err(|err| {
218        anyhow::anyhow!("failed to read eval expected hits {}: {err}", path.display())
219    })?;
220    let suite: ExpectedSuite = toml::from_str(&text).map_err(|err| {
221        anyhow::anyhow!("failed to parse eval expected hits {}: {err}", path.display())
222    })?;
223    Ok(suite.expected.into_iter().map(|expected| (expected.id.clone(), expected)).collect())
224}
225
226fn merge_expected(query: EvalQuery, expected: Option<&ExpectedQuery>) -> EvalQuery {
227    let Some(expected) = expected else {
228        return query;
229    };
230    EvalQuery {
231        id: query.id,
232        text: query.text,
233        evidence_class: query.evidence_class,
234        requires_papertrail_cache: query.requires_papertrail_cache,
235        must_include_paths: union(query.must_include_paths, &expected.must_include_paths),
236        must_include_symbols: union(query.must_include_symbols, &expected.must_include_symbols),
237        must_include_graph_targets: union(
238            query.must_include_graph_targets,
239            &expected.must_include_graph_targets,
240        ),
241        must_include_impact_categories: union(
242            query.must_include_impact_categories,
243            &expected.must_include_impact_categories,
244        ),
245        must_include_impact_paths: union(
246            query.must_include_impact_paths,
247            &expected.must_include_impact_paths,
248        ),
249        must_include_impact_symbols: union(
250            query.must_include_impact_symbols,
251            &expected.must_include_impact_symbols,
252        ),
253        should_include_git_subjects: union(
254            query.should_include_git_subjects,
255            &expected.should_include_git_subjects,
256        ),
257        should_include_papertrail_kinds: union(
258            query.should_include_papertrail_kinds,
259            &expected.should_include_papertrail_kinds,
260        ),
261    }
262}
263
/// Appends to `values` every entry of `extra` that has not been seen yet.
///
/// The order of `values` is preserved and new entries follow in the order they
/// appear in `extra`. Duplicates already inside `values` are left in place.
fn union(mut values: Vec<String>, extra: &[String]) -> Vec<String> {
    let mut seen: BTreeSet<String> = values.iter().cloned().collect();
    let additions = extra.iter().filter(|candidate| seen.insert((*candidate).clone())).cloned();
    values.extend(additions);
    values
}
273
274fn evaluate_query(
275    config: &Config,
276    db: &IndexDatabase,
277    query: &EvalQuery,
278    mode: SearchMode,
279) -> anyhow::Result<EvalQueryReport> {
280    if query.requires_papertrail_cache && !papertrail_cache_available(db)? {
281        return Ok(skipped_report(
282            query,
283            "papertrail cache is empty; run `rag-rat github sync --from-refs`",
284        ));
285    }
286
287    let started = Instant::now();
288    let mut hits = search(db, mode, &query.text)?;
289    let mut latency_ms = started.elapsed().as_secs_f64() * 1000.0;
290    let mut current_source_violations = find_current_source_violations(config, db, &hits);
291    if !current_source_violations.is_empty() {
292        let retry_started = Instant::now();
293        hits = search(db, mode, &query.text)?;
294        latency_ms += retry_started.elapsed().as_secs_f64() * 1000.0;
295        current_source_violations = find_current_source_violations(config, db, &hits);
296    }
297    let top_hits = top_hits(&hits);
298
299    let path_hits = query
300        .must_include_paths
301        .iter()
302        .filter(|expected| hits.iter().any(|hit| hit.path == **expected))
303        .cloned()
304        .collect::<Vec<_>>();
305    let missing_paths = missing(&query.must_include_paths, &path_hits);
306    let symbol_hits = query
307        .must_include_symbols
308        .iter()
309        .filter(|expected| {
310            hits.iter()
311                .filter_map(|hit| hit.symbol_path.as_deref())
312                .any(|symbol| symbol == expected.as_str() || symbol.ends_with(expected.as_str()))
313        })
314        .cloned()
315        .collect::<Vec<_>>();
316    let missing_symbols = missing(&query.must_include_symbols, &symbol_hits);
317
318    let graph_target_hits = query
319        .must_include_graph_targets
320        .iter()
321        .filter(|expected| hits.iter().any(|hit| graph_hit_matches(hit, expected)))
322        .cloned()
323        .collect::<Vec<_>>();
324    let missing_graph_targets = missing(&query.must_include_graph_targets, &graph_target_hits);
325
326    let impact = if query.must_include_impact_categories.is_empty()
327        && query.must_include_impact_paths.is_empty()
328        && query.must_include_impact_symbols.is_empty()
329    {
330        Vec::new()
331    } else {
332        db.impact_surface(&query.text, TOP_K as u32).unwrap_or_default()
333    };
334    let impact_category_hits = query
335        .must_include_impact_categories
336        .iter()
337        .filter(|expected| impact.iter().any(|item| item.category == **expected))
338        .cloned()
339        .collect::<Vec<_>>();
340    let missing_impact_categories =
341        missing(&query.must_include_impact_categories, &impact_category_hits);
342    let impact_path_hits = query
343        .must_include_impact_paths
344        .iter()
345        .filter(|expected| impact.iter().any(|item| item.path == **expected))
346        .cloned()
347        .collect::<Vec<_>>();
348    let missing_impact_paths = missing(&query.must_include_impact_paths, &impact_path_hits);
349    let impact_symbol_hits = query
350        .must_include_impact_symbols
351        .iter()
352        .filter(|expected| {
353            impact
354                .iter()
355                .filter_map(|item| item.symbol.as_deref())
356                .any(|symbol| symbol == expected.as_str() || symbol.ends_with(expected.as_str()))
357        })
358        .cloned()
359        .collect::<Vec<_>>();
360    let missing_impact_symbols = missing(&query.must_include_impact_symbols, &impact_symbol_hits);
361
362    let commit_hits = db.commit_search(&query.text, TOP_K as u32).unwrap_or_default();
363    let git_subject_hits = query
364        .should_include_git_subjects
365        .iter()
366        .filter(|expected| {
367            let needle = expected.to_ascii_lowercase();
368            commit_hits.iter().any(|hit| hit.subject.to_ascii_lowercase().contains(&needle))
369        })
370        .cloned()
371        .collect::<Vec<_>>();
372    let missing_git_subjects = missing(&query.should_include_git_subjects, &git_subject_hits);
373
374    let papertrail = db.rationale_search(&query.text, TOP_K as u32).unwrap_or_default();
375    let papertrail_kind_hits = query
376        .should_include_papertrail_kinds
377        .iter()
378        .filter(|expected| {
379            let needle = normalize_kind(expected);
380            papertrail.iter().any(|item| normalize_kind(&item.classification) == needle)
381        })
382        .cloned()
383        .collect::<Vec<_>>();
384    let missing_papertrail_kinds =
385        missing(&query.should_include_papertrail_kinds, &papertrail_kind_hits);
386    let papertrail_precision_sample = if query.should_include_papertrail_kinds.is_empty() {
387        None
388    } else if papertrail.is_empty() {
389        Some(0.0)
390    } else {
391        let expected = query
392            .should_include_papertrail_kinds
393            .iter()
394            .map(|kind| normalize_kind(kind))
395            .collect::<BTreeSet<_>>();
396        let matched = papertrail
397            .iter()
398            .filter(|item| expected.contains(&normalize_kind(&item.classification)))
399            .count();
400        Some(matched as f64 / papertrail.len() as f64)
401    };
402
403    let stale_current_source_violations =
404        u64::try_from(current_source_violations.len()).unwrap_or(u64::MAX);
405    let relevant_rank = hits.iter().position(|hit| relevant(hit, query)).map(|rank| rank + 1);
406    let reciprocal_rank_at_10 = relevant_rank.map(|rank| 1.0 / rank as f64).unwrap_or(0.0);
407    let expected_relevant = query.must_include_paths.len() + query.must_include_symbols.len();
408    let found_relevant = path_hits.len() + symbol_hits.len();
409    let recall_at_10 =
410        if expected_relevant == 0 { 1.0 } else { found_relevant as f64 / expected_relevant as f64 };
411    let passed = stale_current_source_violations == 0
412        && missing_paths.is_empty()
413        && missing_symbols.is_empty()
414        && missing_graph_targets.is_empty()
415        && missing_impact_categories.is_empty()
416        && missing_impact_paths.is_empty()
417        && missing_impact_symbols.is_empty()
418        && missing_git_subjects.is_empty()
419        && missing_papertrail_kinds.is_empty();
420
421    Ok(EvalQueryReport {
422        id: query.id.clone(),
423        text: query.text.clone(),
424        passed,
425        skipped: false,
426        skip_reason: None,
427        reciprocal_rank_at_10,
428        recall_at_10,
429        path_hits,
430        missing_paths,
431        symbol_hits,
432        missing_symbols,
433        graph_target_hits,
434        missing_graph_targets,
435        impact_category_hits,
436        missing_impact_categories,
437        impact_path_hits,
438        missing_impact_paths,
439        impact_symbol_hits,
440        missing_impact_symbols,
441        git_subject_hits,
442        missing_git_subjects,
443        papertrail_kind_hits,
444        missing_papertrail_kinds,
445        papertrail_precision_sample,
446        stale_current_source_violations,
447        current_source_violations,
448        latency_ms,
449        top_hits,
450    })
451}
452
453fn skipped_report(query: &EvalQuery, reason: impl Into<String>) -> EvalQueryReport {
454    EvalQueryReport {
455        id: query.id.clone(),
456        text: query.text.clone(),
457        passed: true,
458        skipped: true,
459        skip_reason: Some(reason.into()),
460        reciprocal_rank_at_10: 0.0,
461        recall_at_10: 1.0,
462        path_hits: Vec::new(),
463        missing_paths: Vec::new(),
464        symbol_hits: Vec::new(),
465        missing_symbols: Vec::new(),
466        graph_target_hits: Vec::new(),
467        missing_graph_targets: Vec::new(),
468        impact_category_hits: Vec::new(),
469        missing_impact_categories: Vec::new(),
470        impact_path_hits: Vec::new(),
471        missing_impact_paths: Vec::new(),
472        impact_symbol_hits: Vec::new(),
473        missing_impact_symbols: Vec::new(),
474        git_subject_hits: Vec::new(),
475        missing_git_subjects: Vec::new(),
476        papertrail_kind_hits: Vec::new(),
477        missing_papertrail_kinds: Vec::new(),
478        papertrail_precision_sample: None,
479        stale_current_source_violations: 0,
480        current_source_violations: Vec::new(),
481        latency_ms: 0.0,
482        top_hits: Vec::new(),
483    }
484}
485
486fn papertrail_cache_available(db: &IndexDatabase) -> anyhow::Result<bool> {
487    let status = db.github_sync_status()?;
488    Ok(status.issues + status.comments + status.pulls + status.reviews + status.review_comments > 0)
489}
490
/// Selects which search entry point `search` calls.
491#[derive(Debug, Clone, Copy)]
492enum SearchMode {
    /// Dispatches to `IndexDatabase::search`.
493    Active,
    /// Dispatches to `IndexDatabase::search_hash_baseline`.
494    HashBaseline,
495}
496
497fn search(
498    db: &IndexDatabase,
499    mode: SearchMode,
500    query: &str,
501) -> anyhow::Result<Vec<crate::search::lexical::SearchHit>> {
502    match mode {
503        SearchMode::Active => db.search(query, TOP_K as u32, false),
504        SearchMode::HashBaseline => db.search_hash_baseline(query, TOP_K as u32, false),
505    }
506}
507
508fn hash_vector_baseline(
509    config: &Config,
510    db: &IndexDatabase,
511    queries: &[EvalQuery],
512    expected: &BTreeMap<String, ExpectedQuery>,
513    active_metrics: &EvalMetrics,
514) -> anyhow::Result<EvalBaselineReport> {
515    let mut results = Vec::new();
516    for query in queries {
517        let merged = merge_expected(query.clone(), expected.get(&query.id));
518        results.push(evaluate_query(config, db, &merged, SearchMode::HashBaseline)?);
519    }
520    let metrics = aggregate(&results);
521    let current_artifacts = db.current_embedding_count(ai::HASH_MODEL_ID)?;
522    Ok(EvalBaselineReport {
523        model_id: ai::HASH_MODEL_ID.to_string(),
524        available: current_artifacts > 0,
525        current_artifacts,
526        delta_mrr_at_10: active_metrics.mrr_at_10 - metrics.mrr_at_10,
527        delta_recall_at_10: active_metrics.recall_at_10 - metrics.recall_at_10,
528        delta_path_hit_rate: active_metrics.path_hit_rate - metrics.path_hit_rate,
529        delta_symbol_hit_rate: active_metrics.symbol_hit_rate - metrics.symbol_hit_rate,
530        metrics,
531    })
532}
533
534fn top_hits(hits: &[crate::search::lexical::SearchHit]) -> Vec<EvalSearchHit> {
535    hits.iter()
536        .enumerate()
537        .map(|(index, hit)| EvalSearchHit {
538            rank: index + 1,
539            chunk_id: hit.chunk_id,
540            path: hit.path.clone(),
541            symbol_path: hit.symbol_path.clone(),
542            start_line: hit.start_line,
543            end_line: hit.end_line,
544            score: hit.score,
545        })
546        .collect()
547}
548
549fn relevant(hit: &crate::search::lexical::SearchHit, query: &EvalQuery) -> bool {
550    query.must_include_paths.iter().any(|path| path == &hit.path)
551        || hit.symbol_path.as_deref().is_some_and(|symbol| {
552            query
553                .must_include_symbols
554                .iter()
555                .any(|expected| symbol == expected || symbol.ends_with(expected))
556        })
557        || query.must_include_graph_targets.iter().any(|expected| graph_hit_matches(hit, expected))
558}
559
560fn graph_hit_matches(hit: &crate::search::lexical::SearchHit, expected: &str) -> bool {
561    let Some(graph) = &hit.graph else {
562        return false;
563    };
564    graph.top_callers.iter().chain(graph.callers.iter()).any(|caller| {
565        caller.symbol_path.ends_with(expected) || caller.symbol_path.contains(expected)
566    }) || graph.top_callees.iter().chain(graph.callees.iter()).any(|callee| {
567        callee.target == expected
568            || callee.target.ends_with(expected)
569            || callee
570                .resolved_symbol_path
571                .as_deref()
572                .is_some_and(|symbol| symbol.ends_with(expected) || symbol.contains(expected))
573    }) || graph.imports.iter().any(|import| import.target.contains(expected))
574        || graph
575            .referenced_types
576            .iter()
577            .any(|ty| ty.name == expected || ty.name.ends_with(expected))
578}
579
/// Returns the entries of `expected` that do not occur in `found`, in their
/// original order.
fn missing(expected: &[String], found: &[String]) -> Vec<String> {
    let present: BTreeSet<&String> = found.iter().collect();
    let mut absent = Vec::new();
    for value in expected {
        if !present.contains(value) {
            absent.push(value.clone());
        }
    }
    absent
}
584
585fn find_current_source_violations(
586    config: &Config,
587    db: &IndexDatabase,
588    hits: &[crate::search::lexical::SearchHit],
589) -> Vec<CurrentSourceViolation> {
590    let mut violations = Vec::new();
591    let mut checked = BTreeSet::new();
592    for hit in hits {
593        if !checked.insert(hit.chunk_id) {
594            continue;
595        }
596        match db.read_chunk(hit.chunk_id) {
597            Ok(Some(chunk)) => {
598                let source_path = config.root.join(&chunk.path);
599                match fs::read_to_string(&source_path) {
600                    Ok(source) => {
601                        let current = slice_lines(&source, chunk.start_line, chunk.end_line);
602                        if current.as_deref() != Some(chunk.text.as_str()) {
603                            violations.push(CurrentSourceViolation {
604                                chunk_id: hit.chunk_id,
605                                path: chunk.path,
606                                reason: "read_chunk text differs from current source line span"
607                                    .to_string(),
608                            });
609                        }
610                    },
611                    Err(err) => violations.push(CurrentSourceViolation {
612                        chunk_id: hit.chunk_id,
613                        path: chunk.path,
614                        reason: format!("current source unreadable: {err}"),
615                    }),
616                }
617            },
618            Ok(None) => violations.push(CurrentSourceViolation {
619                chunk_id: hit.chunk_id,
620                path: hit.path.clone(),
621                reason: "search hit chunk is missing".to_string(),
622            }),
623            Err(err) => violations.push(CurrentSourceViolation {
624                chunk_id: hit.chunk_id,
625                path: hit.path.clone(),
626                reason: format!("read_chunk failed: {err}"),
627            }),
628        }
629    }
630    violations
631}
632
/// Extracts the 1-based, inclusive line span `start_line..=end_line` from `source`.
///
/// A start of 0 is treated as 1 and an end before the start is treated as the
/// start. The span is clipped to the end of `source`, the selected lines are
/// joined with `\n`, and a trailing `\n` is always appended.
///
/// Returns `None` when either bound is negative (or does not fit in `usize`)
/// or when the span starts past the last line.
fn slice_lines(source: &str, start_line: i64, end_line: i64) -> Option<String> {
    let first = usize::try_from(start_line).ok()?.max(1);
    let last = usize::try_from(end_line).ok()?.max(first);
    let selected: Vec<&str> = source.lines().skip(first - 1).take(last - first + 1).collect();
    if selected.is_empty() {
        // The span starts beyond the last line of `source`.
        return None;
    }
    let mut text = selected.join("\n");
    text.push('\n');
    Some(text)
}
644
/// Canonicalizes a classification name: surrounding whitespace is trimmed,
/// ASCII letters are lowercased, and every `-` or space becomes `_`.
fn normalize_kind(kind: &str) -> String {
    kind.trim()
        .chars()
        .map(|ch| match ch {
            '-' | ' ' => '_',
            other => other.to_ascii_lowercase(),
        })
        .collect()
}
648
649fn aggregate(results: &[EvalQueryReport]) -> EvalMetrics {
650    let measured = results.iter().filter(|result| !result.skipped).collect::<Vec<_>>();
651    let query_count = measured.len().max(1) as f64;
652    let total_hits = measured.iter().map(|r| r.top_hits.len() as u64).sum::<u64>();
653    let stale = measured.iter().map(|r| r.stale_current_source_violations).sum::<u64>();
654    let papertrail_samples =
655        measured.iter().filter_map(|r| r.papertrail_precision_sample).collect::<Vec<_>>();
656    EvalMetrics {
657        mrr_at_10: measured.iter().map(|r| r.reciprocal_rank_at_10).sum::<f64>() / query_count,
658        recall_at_10: measured.iter().map(|r| r.recall_at_10).sum::<f64>() / query_count,
659        path_hit_rate: hit_rate(&measured, |r| r.missing_paths.is_empty()),
660        symbol_hit_rate: hit_rate(&measured, |r| r.missing_symbols.is_empty()),
661        graph_evidence_hit_rate: expected_hit_rate(&measured, |r| {
662            (!r.graph_target_hits.is_empty() || !r.missing_graph_targets.is_empty())
663                .then_some(r.missing_graph_targets.is_empty())
664        }),
665        impact_hit_rate: expected_hit_rate(&measured, |r| {
666            (!r.impact_category_hits.is_empty()
667                || !r.missing_impact_categories.is_empty()
668                || !r.impact_path_hits.is_empty()
669                || !r.missing_impact_paths.is_empty()
670                || !r.impact_symbol_hits.is_empty()
671                || !r.missing_impact_symbols.is_empty())
672            .then_some(
673                r.missing_impact_categories.is_empty()
674                    && r.missing_impact_paths.is_empty()
675                    && r.missing_impact_symbols.is_empty(),
676            )
677        }),
678        git_evidence_hit_rate: expected_hit_rate(&measured, |r| {
679            (!r.git_subject_hits.is_empty() || !r.missing_git_subjects.is_empty())
680                .then_some(r.missing_git_subjects.is_empty())
681        }),
682        papertrail_evidence_hit_rate: expected_hit_rate(&measured, |r| {
683            (!r.papertrail_kind_hits.is_empty() || !r.missing_papertrail_kinds.is_empty())
684                .then_some(r.missing_papertrail_kinds.is_empty())
685        }),
686        stale_hit_rate: if total_hits == 0 { 0.0 } else { stale as f64 / total_hits as f64 },
687        stale_current_source_violations: stale,
688        current_source_violation_count: stale,
689        papertrail_precision_sample: (!papertrail_samples.is_empty())
690            .then(|| papertrail_samples.iter().sum::<f64>() / papertrail_samples.len() as f64),
691        latency_p50_ms: percentile(measured.iter().map(|r| r.latency_ms).collect(), 0.50),
692        latency_p95_ms: percentile(measured.iter().map(|r| r.latency_ms).collect(), 0.95),
693    }
694}
695
696fn hit_rate(results: &[&EvalQueryReport], predicate: fn(&EvalQueryReport) -> bool) -> f64 {
697    if results.is_empty() {
698        return 1.0;
699    }
700    results.iter().filter(|result| predicate(result)).count() as f64 / results.len() as f64
701}
702
703fn expected_hit_rate(
704    results: &[&EvalQueryReport],
705    predicate: fn(&EvalQueryReport) -> Option<bool>,
706) -> f64 {
707    let applicable = results.iter().filter_map(|result| predicate(result)).collect::<Vec<_>>();
708    if applicable.is_empty() {
709        return 1.0;
710    }
711    applicable.iter().filter(|passed| **passed).count() as f64 / applicable.len() as f64
712}
713
/// Returns the value at the given `percentile` (a fraction such as `0.95`) of
/// `values`.
///
/// The values are sorted ascending and the element at index
/// `ceil((len - 1) * percentile)` is returned, clamped to the last element.
/// An empty input yields `0.0`.
///
/// Sorting uses `f64::total_cmp`, which is a total order: positive NaN values
/// sort after every number instead of being treated as "equal" to whatever
/// they are compared with, which would leave the slice inconsistently ordered.
fn percentile(mut values: Vec<f64>, percentile: f64) -> f64 {
    if values.is_empty() {
        return 0.0;
    }
    values.sort_unstable_by(f64::total_cmp);
    let last = values.len() - 1;
    // Float-to-int `as` saturates, so a negative or NaN product maps to index 0.
    let index = (last as f64 * percentile).ceil() as usize;
    values[index.min(last)]
}
722
723fn observed_expected(report: &EvalQueryReport) -> ExpectedQuery {
724    let mut paths = report.top_hits.iter().map(|hit| hit.path.clone()).collect::<Vec<_>>();
725    dedup(&mut paths);
726    let mut symbols =
727        report.top_hits.iter().filter_map(|hit| hit.symbol_path.clone()).collect::<Vec<_>>();
728    dedup(&mut symbols);
729    ExpectedQuery {
730        id: report.id.clone(),
731        must_include_paths: paths,
732        must_include_symbols: symbols,
733        must_include_graph_targets: report.graph_target_hits.clone(),
734        must_include_impact_categories: report.impact_category_hits.clone(),
735        must_include_impact_paths: report.impact_path_hits.clone(),
736        must_include_impact_symbols: report.impact_symbol_hits.clone(),
737        should_include_git_subjects: report.git_subject_hits.clone(),
738        should_include_papertrail_kinds: report.papertrail_kind_hits.clone(),
739    }
740}
741
/// Removes repeated entries from `values`, keeping the first occurrence of
/// each and the relative order of the survivors.
fn dedup(values: &mut Vec<String>) {
    let mut seen: BTreeSet<String> = BTreeSet::new();
    let mut unique = Vec::with_capacity(values.len());
    for value in values.drain(..) {
        if seen.insert(value.clone()) {
            unique.push(value);
        }
    }
    *values = unique;
}
746
747fn write_baseline(path: &Path, expected: Vec<ExpectedQuery>) -> anyhow::Result<()> {
748    if let Some(parent) = path.parent() {
749        fs::create_dir_all(parent)?;
750    }
751    let text = toml::to_string_pretty(&BaselineSuite { expected })?;
752    fs::write(path, text)?;
753    Ok(())
754}
755
#[cfg(test)]
mod tests {
    use std::path::PathBuf;

    use super::*;
    use crate::{Config, IndexDatabase};

    /// Directory two levels above this crate's manifest directory.
    fn workspace_root() -> PathBuf {
        let manifest_dir = PathBuf::from(env!("CARGO_MANIFEST_DIR"));
        manifest_dir.ancestors().nth(2).unwrap().to_path_buf()
    }

    /// Fixture project that is indexed for the eval run.
    fn fixture_root() -> PathBuf {
        workspace_root().join("tests/fixtures/held-mini")
    }

    /// Rebuilds the fixture index, runs the checked-in eval suite against it,
    /// and expects no stale-source violations plus non-zero MRR and recall.
    #[test]
    fn eval_suite_reports_search_quality_and_current_source_safety() {
        let config = Config::load(fixture_root().join("rag-rat.toml")).unwrap();
        IndexDatabase::rebuild(&config).unwrap();

        let evals = workspace_root();
        let options = EvalOptions {
            queries_path: evals.join("evals/queries.toml"),
            expected_path: evals.join("evals/expected_hits.toml"),
            update_baseline: false,
        };
        let report = run(&config, &options).unwrap();

        assert_eq!(report.metrics.stale_current_source_violations, 0);
        assert!(report.metrics.mrr_at_10 > 0.0);
        assert!(report.metrics.recall_at_10 > 0.0);
    }
}