use std::collections::BTreeSet;
use serde::{Deserialize, Serialize};
use crate::CoreError;
use crate::context::retrieval::{self, RuleSearchRetrievalOptions};
use crate::context::rule_source::RuleDocument;
use crate::domain::glob_match::{GlobErrorPolicy, glob_match};
/// Golden retrieval fixture embedded at compile time from
/// `tests/fixtures/rag-eval-seed-cases.json` (path relative to this source file).
/// Parse it with `parse_golden_fixture`.
pub const GOLDEN_SMOKE_FIXTURE: &str =
include_str!("../../../tests/fixtures/rag-eval-seed-cases.json");
/// Rank cutoff `k` applied by the golden scorer: only the first `GOLDEN_K` ranked
/// rule ids count toward rank, recall@k, precision@k, forbidden hits and the
/// strict-file check, and it is reported as `GoldenReport::k`.
pub const GOLDEN_K: usize = 3;
/// A complete golden evaluation fixture: the rule corpus plus the query cases
/// scored against it.
#[derive(Debug, Clone, Deserialize)]
pub struct GoldenFixture {
/// Rules to index; converted by `golden_rules_to_documents`.
pub rules: Vec<GoldenRule>,
/// Queries with their expected / forbidden rule ids.
pub cases: Vec<GoldenCase>,
}
/// One rule in the golden corpus, as written in the fixture JSON.
#[derive(Debug, Clone, Deserialize)]
pub struct GoldenRule {
/// Rule identifier; becomes `RuleDocument::skill_id` and is what cases reference.
pub id: String,
pub title: String,
pub body: String,
/// Glob patterns (JSON key `filePatterns`); missing key means an empty list.
#[serde(default, rename = "filePatterns")]
pub file_patterns: Vec<String>,
/// Optional source repository (JSON key `sourceRepo`). Note: the visible
/// `golden_rules_to_documents` does not map this onto the document.
#[serde(default, rename = "sourceRepo")]
pub source_repo: Option<String>,
}
/// One evaluation query. A case with no expected ids is a *negative* case, where
/// the ideal outcome is that nothing is retrieved.
#[derive(Debug, Clone, Deserialize)]
pub struct GoldenCase {
pub id: String,
/// Text sent to retrieval as both the semantic and lexical query.
pub query: String,
/// Optional file path the query is about; enables the strict-file check.
#[serde(default)]
pub file: Option<String>,
/// Rule ids that should appear in the top-k (JSON key `expectedRuleIds`).
#[serde(default, rename = "expectedRuleIds")]
pub expected_rule_ids: Vec<String>,
/// Rule ids that should NOT appear in the top-k (JSON key `forbiddenRuleIds`).
#[serde(default, rename = "forbiddenRuleIds")]
pub forbidden_rule_ids: Vec<String>,
}
/// Metrics for a single golden case, computed over the top `GOLDEN_K` ranked ids.
#[derive(Debug, Clone, Serialize)]
pub struct GoldenCaseResult {
pub case_id: String,
/// Number of distinct expected rule ids; `0` marks a negative case.
pub expected: usize,
/// 1-based rank of the first expected id within the top-k, if any.
pub first_relevant_rank: Option<usize>,
/// Share of expected ids found in the top-k; `None` for negative cases.
pub recall_at_k: Option<f64>,
/// Share of top-k ids that are expected; `None` for negative cases.
pub precision_at_k: Option<f64>,
/// Number of forbidden ids found in the top-k.
pub forbidden_hits: usize,
/// For positive cases with a `file`: whether every expected id in the top-k has
/// file patterns matching that file. `None` otherwise.
pub strict_file_match: Option<bool>,
/// For negative cases: `Some(true)` when nothing was retrieved. `None` for
/// positive cases.
pub abstained_correctly: Option<bool>,
/// First ranked rule id overall, if anything was retrieved.
pub top_rule: Option<String>,
}
/// Aggregate metrics over all golden cases plus the per-case breakdown.
#[derive(Debug, Clone, Serialize)]
pub struct GoldenReport {
/// Rank cutoff the metrics were computed at.
pub k: usize,
pub total_cases: usize,
/// Cases with at least one expected id.
pub positive_cases: usize,
/// Cases with no expected ids.
pub negative_cases: usize,
/// Means over positive cases only; `0.0` when there are none.
pub mean_recall_at_k: f64,
pub mean_precision_at_k: f64,
/// Misses contribute a reciprocal rank of 0.
pub mean_reciprocal_rank: f64,
/// Total forbidden ids retrieved across positive cases.
pub positive_forbidden_hits: usize,
/// Cases that correctly retrieved nothing.
pub negative_clean: usize,
/// Positive cases whose strict-file check passed, out of `strict_file_total`
/// positive cases that had a `file` to check.
pub strict_file_correct: usize,
pub strict_file_total: usize,
pub cases: Vec<GoldenCaseResult>,
}
pub fn parse_golden_fixture(json: &str) -> Result<GoldenFixture, CoreError> {
Ok(serde_json::from_str(json)?)
}
/// Converts every fixture rule into a [`RuleDocument`], preserving rule order.
///
/// The content is the title and body separated by a blank line, and confidence is
/// pinned to `1.0`. `file_patterns` holds the rule's globs encoded as a JSON array,
/// or `None` when the rule has no patterns. Language and repo scope are left unset.
#[must_use]
pub fn golden_rules_to_documents(fixture: &GoldenFixture) -> Vec<RuleDocument> {
    let mut documents = Vec::with_capacity(fixture.rules.len());
    for rule in &fixture.rules {
        // An empty pattern list becomes "no patterns" (None), never the string "[]".
        let patterns_json = match rule.file_patterns.as_slice() {
            [] => None,
            patterns => serde_json::to_string(patterns).ok(),
        };
        documents.push(RuleDocument {
            skill_id: rule.id.clone(),
            title: rule.title.clone(),
            content: format!("{}\n\n{}", rule.title, rule.body),
            confidence: 1.0,
            file_patterns: patterns_json,
            language: None,
            repo_scope: None,
        });
    }
    documents
}
/// Runs every case in `fixture` through rule retrieval and scores the results.
///
/// Each case's query is passed as both the semantic and lexical query to
/// `retrieval::retrieve_rules_for_search`, requesting `top_k` hits. ANN, local
/// query embedding, cold-start retry and adaptive pruning are all disabled, and no
/// confidence/age/effectiveness maps or scopes are supplied. Hits are collapsed to
/// distinct rule ids in first-seen order before scoring. Scoring applies the
/// [`GOLDEN_K`] cutoff, so `top_k` only controls how many hits are fetched.
///
/// # Errors
///
/// Propagates the first [`CoreError`] returned by retrieval; later cases are not run.
pub async fn score_golden_cases(
    index_pool: &crate::SqlitePool,
    fixture: &GoldenFixture,
    top_k: usize,
) -> Result<GoldenReport, CoreError> {
    // Rule id -> JSON-encoded glob list (None when the rule has no patterns),
    // encoded the same way `golden_rules_to_documents` fills `file_patterns`.
    let mut file_patterns: std::collections::HashMap<&str, Option<String>> =
        std::collections::HashMap::with_capacity(fixture.rules.len());
    for rule in &fixture.rules {
        let encoded = (!rule.file_patterns.is_empty())
            .then(|| serde_json::to_string(&rule.file_patterns).ok())
            .flatten();
        file_patterns.insert(rule.id.as_str(), encoded);
    }

    let mut results = Vec::with_capacity(fixture.cases.len());
    for case in &fixture.cases {
        let options = RuleSearchRetrievalOptions {
            query: &case.query,
            lexical_query: &case.query,
            top_k,
            confidence_map: None,
            age_days_map: None,
            effectiveness_map: None,
            target_scope: None,
            repo_scopes: &[],
            ann_enabled: false,
            local_query_embedding: false,
            embedding_timeout: None,
            cold_start_retry: false,
            adaptive_prune: false,
        };
        let hits = retrieval::retrieve_rules_for_search(index_pool, options).await?;

        // Several hits may belong to the same rule; keep each rule id once, at the
        // position where it first appeared.
        let mut seen: std::collections::HashSet<&str> = std::collections::HashSet::new();
        let mut ranked: Vec<String> = Vec::new();
        for hit in &hits {
            if seen.insert(hit.skill_id.as_str()) {
                ranked.push(hit.skill_id.clone());
            }
        }

        results.push(score_one_case(case, &ranked, &file_patterns));
    }
    Ok(aggregate(results))
}
/// Scores one golden case against a ranked list of retrieved rule ids.
///
/// Only the first [`GOLDEN_K`] *distinct* ids in `ranked` form the top-k. A
/// repeated id is ignored so that recall can never exceed `1.0` and precision is
/// not inflated by duplicates. `top_rule` is simply the first entry of `ranked`.
///
/// - Positive case (`expected_rule_ids` non-empty): recall = relevant-in-top-k /
///   expected, precision = relevant-in-top-k / top-k size (`0.0` when nothing was
///   retrieved). When the case has a `file`, `strict_file_match` reports whether
///   every relevant top-k rule's file patterns match that file. This is vacuously
///   `Some(true)` when no relevant rule reached the top-k.
/// - Negative case: recall/precision/strict are `None`, and `abstained_correctly`
///   is `Some(true)` exactly when nothing was retrieved.
///
/// `file_patterns` maps rule id to its JSON-encoded glob list. Rules missing from
/// the map, or mapped to `None`, are passed to `glob_match` as `None`.
fn score_one_case(
    case: &GoldenCase,
    ranked: &[String],
    file_patterns: &std::collections::HashMap<&str, Option<String>>,
) -> GoldenCaseResult {
    let expected: BTreeSet<&str> = case.expected_rule_ids.iter().map(String::as_str).collect();
    let forbidden: BTreeSet<&str> = case.forbidden_rule_ids.iter().map(String::as_str).collect();

    // First GOLDEN_K distinct ids. The linear `contains` is fine: k is tiny.
    let mut top: Vec<&str> = Vec::with_capacity(GOLDEN_K);
    for id in ranked.iter().map(String::as_str) {
        if top.len() == GOLDEN_K {
            break;
        }
        if !top.contains(&id) {
            top.push(id);
        }
    }

    // Expected ids that made it into the top-k, in rank order.
    let relevant: Vec<&str> = top
        .iter()
        .copied()
        .filter(|id| expected.contains(id))
        .collect();
    let first_relevant_rank = top
        .iter()
        .position(|id| expected.contains(id))
        .map(|pos| pos + 1);
    let forbidden_hits = top
        .iter()
        .copied()
        .filter(|id| forbidden.contains(id))
        .count();
    let expected_in_top = relevant.len();

    let (recall_at_k, precision_at_k, strict_file_match, abstained_correctly) =
        if expected.is_empty() {
            (None, None, None, Some(top.is_empty()))
        } else {
            let recall = expected_in_top as f64 / expected.len() as f64;
            let precision = if top.is_empty() {
                0.0
            } else {
                expected_in_top as f64 / top.len() as f64
            };
            let strict = case.file.as_deref().map(|file| {
                relevant.iter().all(|id| {
                    let blob = file_patterns.get(id).and_then(Option::as_deref);
                    glob_match(blob, file, GlobErrorPolicy::OverRecall)
                })
            });
            (Some(recall), Some(precision), strict, None)
        };

    GoldenCaseResult {
        case_id: case.id.clone(),
        expected: expected.len(),
        first_relevant_rank,
        recall_at_k,
        precision_at_k,
        forbidden_hits,
        strict_file_match,
        abstained_correctly,
        top_rule: ranked.first().cloned(),
    }
}
/// Folds per-case results into a [`GoldenReport`].
///
/// A case is positive when `expected > 0`. Recall, precision and reciprocal rank
/// are averaged over positive cases only, and each mean is `0.0` when there are no
/// positive cases. A positive case with no relevant hit in the top-k contributes a
/// reciprocal rank of `0`. `negative_clean` counts every case whose
/// `abstained_correctly` is `Some(true)`.
fn aggregate(cases: Vec<GoldenCaseResult>) -> GoldenReport {
    let safe_mean = |total: f64, count: usize| -> f64 {
        match count {
            0 => 0.0,
            n => total / n as f64,
        }
    };

    let total_cases = cases.len();
    let positives: Vec<&GoldenCaseResult> =
        cases.iter().filter(|result| result.expected > 0).collect();
    let positive_cases = positives.len();

    let negative_clean = cases
        .iter()
        .filter(|result| matches!(result.abstained_correctly, Some(true)))
        .count();
    let strict_file_correct = positives
        .iter()
        .filter(|result| matches!(result.strict_file_match, Some(true)))
        .count();
    let strict_file_total = positives
        .iter()
        .filter(|result| result.strict_file_match.is_some())
        .count();
    let positive_forbidden_hits: usize = positives.iter().map(|result| result.forbidden_hits).sum();

    let recall_total: f64 = positives.iter().filter_map(|result| result.recall_at_k).sum();
    let precision_total: f64 = positives
        .iter()
        .filter_map(|result| result.precision_at_k)
        .sum();
    let reciprocal_rank_total: f64 = positives
        .iter()
        .map(|result| match result.first_relevant_rank {
            Some(rank) => 1.0 / rank as f64,
            None => 0.0,
        })
        .sum();

    GoldenReport {
        k: GOLDEN_K,
        total_cases,
        positive_cases,
        negative_cases: total_cases - positive_cases,
        mean_recall_at_k: safe_mean(recall_total, positive_cases),
        mean_precision_at_k: safe_mean(precision_total, positive_cases),
        mean_reciprocal_rank: safe_mean(reciprocal_rank_total, positive_cases),
        positive_forbidden_hits,
        negative_clean,
        strict_file_correct,
        strict_file_total,
        cases,
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use std::collections::HashMap;

    /// Builds a minimal case with the given expected ids and no file/forbidden ids.
    fn golden_case(expected_rule_ids: Vec<&str>) -> GoldenCase {
        GoldenCase {
            id: "case-1".to_owned(),
            query: "query".to_owned(),
            file: None,
            expected_rule_ids: expected_rule_ids.into_iter().map(str::to_owned).collect(),
            forbidden_rule_ids: Vec::new(),
        }
    }

    /// Converts string literals into the owned ranked-id list `score_one_case` takes.
    fn ranked_ids(ids: &[&str]) -> Vec<String> {
        ids.iter().map(|id| (*id).to_owned()).collect()
    }

    #[test]
    fn score_one_case_does_not_credit_mrr_past_top_k() {
        let ranked = ["a", "b", "c", "target"]
            .into_iter()
            .map(str::to_owned)
            .collect::<Vec<_>>();
        let result = score_one_case(&golden_case(vec!["target"]), &ranked, &HashMap::new());
        assert_eq!(result.first_relevant_rank, None);
        assert_eq!(result.recall_at_k, Some(0.0));
        assert_eq!(result.precision_at_k, Some(0.0));
        let report = aggregate(vec![result]);
        assert!(report.mean_reciprocal_rank.abs() < f64::EPSILON);
    }

    #[test]
    fn score_one_case_credits_first_relevant_rank_within_top_k() {
        let result = score_one_case(
            &golden_case(vec!["target"]),
            &ranked_ids(&["a", "target", "b"]),
            &HashMap::new(),
        );
        assert_eq!(result.first_relevant_rank, Some(2));
        assert_eq!(result.recall_at_k, Some(1.0));
        let precision = result.precision_at_k.expect("positive case has precision");
        assert!((precision - 1.0 / 3.0).abs() < 1e-12);
        assert_eq!(result.strict_file_match, None);
        assert_eq!(result.top_rule.as_deref(), Some("a"));

        let report = aggregate(vec![result]);
        assert_eq!(report.positive_cases, 1);
        assert!((report.mean_reciprocal_rank - 0.5).abs() < 1e-12);
    }

    #[test]
    fn score_one_case_negative_case_reports_abstention() {
        let clean = score_one_case(&golden_case(Vec::new()), &ranked_ids(&[]), &HashMap::new());
        assert_eq!(clean.abstained_correctly, Some(true));
        assert_eq!(clean.recall_at_k, None);
        assert_eq!(clean.precision_at_k, None);
        assert_eq!(clean.top_rule, None);

        let noisy = score_one_case(&golden_case(Vec::new()), &ranked_ids(&["x"]), &HashMap::new());
        assert_eq!(noisy.abstained_correctly, Some(false));
        assert_eq!(noisy.top_rule.as_deref(), Some("x"));

        let report = aggregate(vec![clean, noisy]);
        assert_eq!(report.total_cases, 2);
        assert_eq!(report.positive_cases, 0);
        assert_eq!(report.negative_cases, 2);
        assert_eq!(report.negative_clean, 1);
        assert!(report.mean_recall_at_k.abs() < f64::EPSILON);
    }

    #[test]
    fn parse_golden_fixture_reads_camel_case_keys_and_defaults() {
        let json = r#"{
            "rules": [{"id": "r1", "title": "Title", "body": "Body", "filePatterns": ["*.rs"]}],
            "cases": [{"id": "c1", "query": "q", "expectedRuleIds": ["r1"]}]
        }"#;
        let fixture = match parse_golden_fixture(json) {
            Ok(fixture) => fixture,
            Err(_) => panic!("fixture should parse"),
        };
        assert_eq!(fixture.rules[0].file_patterns, vec!["*.rs".to_owned()]);
        assert_eq!(fixture.rules[0].source_repo, None);
        assert_eq!(fixture.cases[0].expected_rule_ids, vec!["r1".to_owned()]);
        assert!(fixture.cases[0].forbidden_rule_ids.is_empty());
        assert_eq!(fixture.cases[0].file, None);

        let documents = golden_rules_to_documents(&fixture);
        assert_eq!(documents.len(), 1);
        assert_eq!(documents[0].skill_id, "r1");
        assert_eq!(documents[0].content, "Title\n\nBody");
        assert_eq!(documents[0].file_patterns.as_deref(), Some(r#"["*.rs"]"#));
    }
}