use serde::{Deserialize, Serialize};
use crate::domain::eval::EvalThresholds;
/// Outcome of a single evaluation case inside an [`EvalReport`].
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
pub struct CaseResult {
/// Identifier of the case; listed in `RequireTag` violation messages when the case failed.
pub case_id: String,
/// Numeric score recorded for the case. Not read by the gate logic in this file.
// NOTE(review): scale/range of `score` is not established here — confirm with the producer.
pub score: f32,
/// Whether the case passed; this is the flag the `RequireTag` rule inspects.
pub passed: bool,
/// Tags attached to the case; a `RequireTag` rule selects cases whose tags contain its tag.
pub tags: Vec<String>,
}
/// Evaluation results that the gate rules are checked against.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
pub struct EvalReport {
/// Per-case outcomes; consulted by the `RequireTag` rule.
pub case_results: Vec<CaseResult>,
/// Overall pass rate, compared against `EvalThresholds::min_pass_rate` and multiplied
/// by 100 when shown as a percentage in violation messages.
///
/// This is a stored value: the gate code in this file does not recompute it from
/// `case_results`.
// NOTE(review): presumably a fraction in 0.0..=1.0 — confirm with the report producer.
pub pass_rate: f32,
/// Pass rate of the baseline run, if one is available. When `None`, the
/// `MaxRegression` rule produces no violation.
pub baseline_pass_rate: Option<f32>,
}
/// A single check the gate can apply to an [`EvalReport`].
///
/// Serialized as an internally tagged enum: the variant name is written in
/// snake_case under a `type` field (e.g. in JSON, `{"type": "require_tag", "tag": "..."}`).
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
#[serde(tag = "type", rename_all = "snake_case")]
pub enum GateRule {
/// Violated when the report's pass rate is below `EvalThresholds::min_pass_rate`.
MinPassRate,
/// Violated when `baseline_pass_rate - pass_rate` exceeds
/// `EvalThresholds::max_regression`; has no effect when the report has no baseline.
MaxRegression,
/// Violated when any case carrying `tag` did not pass.
RequireTag { tag: String },
}
/// A list of gate rules together with the thresholds they are evaluated against.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
pub struct GateRuleSet {
/// Numeric limits and the `fail_fast` switch read during evaluation.
pub thresholds: EvalThresholds,
/// Rules to check; `evaluate_gate` walks them in this order.
pub rules: Vec<GateRule>,
}
impl GateRuleSet {
    /// Builds the standard rule set: default thresholds with the `MinPassRate`
    /// and `MaxRegression` rules, in that order.
    pub fn standard() -> Self {
        let thresholds = EvalThresholds::default();
        let rules = vec![GateRule::MinPassRate, GateRule::MaxRegression];
        Self { thresholds, rules }
    }

    /// Returns the rule set with `rule` appended after the existing rules.
    pub fn with_rule(self, rule: GateRule) -> Self {
        let Self {
            thresholds,
            mut rules,
        } = self;
        rules.push(rule);
        Self { thresholds, rules }
    }

    /// Returns the rule set with its thresholds replaced by `thresholds`;
    /// the rules are carried over unchanged.
    pub fn with_thresholds(self, thresholds: EvalThresholds) -> Self {
        Self { thresholds, ..self }
    }
}
/// A gate rule that was not satisfied, with an explanation.
#[derive(Debug, Clone, PartialEq)]
pub struct Violation {
/// Copy of the rule that was violated.
pub rule: GateRule,
/// Human-readable description of why the rule failed (percentages, case ids, ...).
pub reason: String,
}
/// Result of evaluating a [`GateRuleSet`] against an [`EvalReport`].
#[derive(Debug, Clone, PartialEq)]
pub struct GateVerdict {
/// `true` when the gate passed.
pub passed: bool,
/// Violations collected during evaluation; empty for a verdict built by `GateVerdict::pass`.
pub violations: Vec<Violation>,
}
impl GateVerdict {
    /// Verdict for a gate in which no rule was violated: `passed` is `true`
    /// and the violation list is empty.
    fn pass() -> Self {
        let violations = Vec::new();
        Self {
            violations,
            passed: true,
        }
    }

    /// Verdict for a failed gate carrying the given violations.
    ///
    /// `passed` is set to `false` unconditionally; the contents of
    /// `violations` are not inspected.
    fn fail(violations: Vec<Violation>) -> Self {
        Self {
            violations,
            passed: false,
        }
    }
}
pub fn evaluate_gate(rule_set: &GateRuleSet, report: &EvalReport) -> GateVerdict {
let mut violations = Vec::new();
let fail_fast = rule_set.thresholds.fail_fast;
for rule in &rule_set.rules {
if let Some(v) = check_rule(rule, &rule_set.thresholds, report) {
violations.push(v);
if fail_fast {
return GateVerdict::fail(violations);
}
}
}
if violations.is_empty() {
GateVerdict::pass()
} else {
GateVerdict::fail(violations)
}
}
/// Checks a single `rule` against `report`, using the limits in `thresholds`.
///
/// Returns `Some(Violation)` (carrying a clone of `rule` and a human-readable
/// reason) when the rule is violated, `None` when it is satisfied.
///
/// Per rule:
/// - `MinPassRate`: violated when `report.pass_rate` is below
///   `thresholds.min_pass_rate`, or when the pass rate is NaN.
/// - `MaxRegression`: satisfied when the report has no baseline. Otherwise
///   violated when `baseline - pass_rate` exceeds `thresholds.max_regression`,
///   or when that difference is NaN.
/// - `RequireTag`: violated when at least one case carrying the tag did not
///   pass. A report with no case carrying the tag satisfies the rule.
///
/// NaN is treated as a violation on purpose: every ordered comparison with NaN
/// is `false`, so a plain `<` / `>` test would let an invalid report (for
/// example a 0/0 pass rate) slip through the gate.
fn check_rule(
    rule: &GateRule,
    thresholds: &EvalThresholds,
    report: &EvalReport,
) -> Option<Violation> {
    // Each arm either yields the violation text or returns `None` early.
    let reason = match rule {
        GateRule::MinPassRate => {
            if report.pass_rate.is_nan() {
                // Fail closed: a NaN pass rate cannot be shown to meet the minimum.
                format!(
                    "pass rate is NaN; required {:.2}%",
                    thresholds.min_pass_rate * 100.0,
                )
            } else if report.pass_rate < thresholds.min_pass_rate {
                format!(
                    "pass rate {:.2}% < required {:.2}%",
                    report.pass_rate * 100.0,
                    thresholds.min_pass_rate * 100.0,
                )
            } else {
                return None;
            }
        }
        GateRule::MaxRegression => {
            // No baseline means there is nothing to regress against.
            let baseline = report.baseline_pass_rate?;
            let regression = baseline - report.pass_rate;
            if regression.is_nan() {
                // Fail closed: NaN on either side makes the regression unmeasurable.
                format!(
                    "regression is NaN (baseline {:.2}% → current {:.2}%); allowed {:.2}%",
                    baseline * 100.0,
                    report.pass_rate * 100.0,
                    thresholds.max_regression * 100.0,
                )
            } else if regression > thresholds.max_regression {
                format!(
                    "regression {:.2}% > allowed {:.2}% (baseline {:.2}% → current {:.2}%)",
                    regression * 100.0,
                    thresholds.max_regression * 100.0,
                    baseline * 100.0,
                    report.pass_rate * 100.0,
                )
            } else {
                return None;
            }
        }
        GateRule::RequireTag { tag } => {
            // Single pass: count tagged cases and collect the ids of those that failed.
            let mut tagged_count = 0usize;
            let mut failed: Vec<&str> = Vec::new();
            for case in report.case_results.iter().filter(|c| c.tags.contains(tag)) {
                tagged_count += 1;
                if !case.passed {
                    failed.push(case.case_id.as_str());
                }
            }
            // NOTE(review): zero tagged cases passes vacuously — confirm that is intended.
            if failed.is_empty() {
                return None;
            }
            format!(
                "{} of {} cases tagged '{}' failed: [{}]",
                failed.len(),
                tagged_count,
                tag,
                failed.join(", "),
            )
        }
    };

    Some(Violation {
        rule: rule.clone(),
        reason,
    })
}