Skip to main content

swink_agent_eval/
gate.rs

1//! Post-evaluation gating for CI/CD pipelines.
2
3use std::time::Duration;
4
5use serde::{Deserialize, Serialize};
6
7use crate::types::EvalSetResult;
8
9/// Configuration for CI/CD gate checks against evaluation results.
10///
11/// Serde-ready so consumers (notably the `swink-eval gate` subcommand)
12/// can load thresholds from a YAML/JSON file without re-declaring the
13/// shape.
14#[derive(Debug, Clone, Default, Serialize, Deserialize)]
15#[serde(rename_all = "snake_case")]
16pub struct GateConfig {
17    /// Minimum fraction of cases that must pass (e.g. 0.95 for 95%).
18    #[serde(default, skip_serializing_if = "Option::is_none")]
19    pub min_pass_rate: Option<f64>,
20    /// Maximum total cost in dollars.
21    #[serde(default, skip_serializing_if = "Option::is_none")]
22    pub max_cost: Option<f64>,
23    /// Maximum total wall-clock duration.
24    #[serde(default, skip_serializing_if = "Option::is_none")]
25    pub max_duration: Option<Duration>,
26}
27
28impl GateConfig {
29    /// Create a new empty gate configuration (all checks disabled).
30    #[must_use]
31    pub const fn new() -> Self {
32        Self {
33            min_pass_rate: None,
34            max_cost: None,
35            max_duration: None,
36        }
37    }
38
39    /// Set the minimum pass rate threshold.
40    #[must_use]
41    pub fn with_min_pass_rate(mut self, rate: f64) -> Self {
42        assert!(
43            (0.0..=1.0).contains(&rate),
44            "pass rate must be in [0.0, 1.0], got {rate}"
45        );
46        self.min_pass_rate = Some(rate);
47        self
48    }
49
50    /// Set the maximum allowed cost in dollars.
51    #[must_use]
52    pub fn with_max_cost(mut self, cost: f64) -> Self {
53        assert!(cost >= 0.0, "cost must be non-negative, got {cost}");
54        self.max_cost = Some(cost);
55        self
56    }
57
58    /// Set the maximum allowed wall-clock duration.
59    #[must_use]
60    pub const fn with_max_duration(mut self, duration: Duration) -> Self {
61        self.max_duration = Some(duration);
62        self
63    }
64}
65
/// Result of a CI/CD gate check, as returned by [`check_gate`].
///
/// Derives `PartialEq`/`Eq` so callers and tests can compare whole results.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct GateResult {
    /// Whether the gate check passed.
    pub passed: bool,
    /// Process exit code: 0 for pass, 1 for fail.
    pub exit_code: i32,
    /// Human-readable summary of the gate result.
    pub summary: String,
}
76
77impl GateResult {
78    /// Exit the process with this result's exit code.
79    pub fn exit(&self) -> ! {
80        std::process::exit(self.exit_code)
81    }
82}
83
84/// Check evaluation results against gate configuration.
85///
86/// Returns a [`GateResult`] indicating whether all configured thresholds were met.
87#[must_use]
88#[allow(clippy::cast_precision_loss)]
89pub fn check_gate(result: &EvalSetResult, config: &GateConfig) -> GateResult {
90    let mut failures: Vec<String> = Vec::new();
91
92    let total = result.summary.total_cases;
93    let passed = result.summary.passed;
94
95    if let Some(min_rate) = config.min_pass_rate {
96        let rate = if total == 0 {
97            1.0
98        } else {
99            passed as f64 / total as f64
100        };
101        if rate < min_rate {
102            failures.push(format!(
103                "pass rate {rate:.2} < minimum {min_rate:.2} ({passed}/{total})"
104            ));
105        }
106    }
107
108    if let Some(max_cost) = config.max_cost
109        && result.summary.total_cost.total > max_cost
110    {
111        failures.push(format!(
112            "cost ${:.4} > max ${max_cost:.4}",
113            result.summary.total_cost.total
114        ));
115    }
116
117    if let Some(max_dur) = config.max_duration
118        && result.summary.total_duration > max_dur
119    {
120        failures.push(format!(
121            "duration {:.1}s > max {:.1}s",
122            result.summary.total_duration.as_secs_f64(),
123            max_dur.as_secs_f64()
124        ));
125    }
126
127    if failures.is_empty() {
128        GateResult {
129            passed: true,
130            exit_code: 0,
131            summary: format!("GATE PASSED: {passed}/{total} cases passed"),
132        }
133    } else {
134        GateResult {
135            passed: false,
136            exit_code: 1,
137            summary: format!("GATE FAILED: {}", failures.join("; ")),
138        }
139    }
140}
141
#[cfg(test)]
mod tests {
    use super::*;

    use swink_agent::{Cost, Usage};

    use crate::types::{EvalSetResult, EvalSummary};

    /// Build a synthetic `EvalSetResult` with the given pass/fail counts, total
    /// cost in dollars and total wall-clock duration (no per-case results).
    fn make_result(passed: usize, failed: usize, cost: f64, duration: Duration) -> EvalSetResult {
        EvalSetResult {
            eval_set_id: "test".to_string(),
            case_results: Vec::new(),
            summary: EvalSummary {
                total_cases: passed + failed,
                passed,
                failed,
                total_cost: Cost {
                    total: cost,
                    ..Default::default()
                },
                total_usage: Usage::default(),
                total_duration: duration,
            },
            timestamp: 0,
        }
    }

    #[test]
    fn all_pass_no_config() {
        // With no thresholds configured, the gate passes even with failing cases.
        let result = make_result(5, 2, 1.0, Duration::from_secs(10));
        let config = GateConfig::new();
        let gate = check_gate(&result, &config);
        assert!(gate.passed);
        assert_eq!(gate.exit_code, 0);
    }

    #[test]
    fn pass_rate_met() {
        let result = make_result(9, 1, 0.5, Duration::from_secs(5));
        let config = GateConfig::new().with_min_pass_rate(0.9);
        let gate = check_gate(&result, &config);
        assert!(gate.passed);
    }

    #[test]
    fn pass_rate_not_met() {
        let result = make_result(8, 2, 0.5, Duration::from_secs(5));
        let config = GateConfig::new().with_min_pass_rate(0.9);
        let gate = check_gate(&result, &config);
        assert!(!gate.passed);
        assert_eq!(gate.exit_code, 1);
        assert!(gate.summary.contains("pass rate"));
    }

    #[test]
    fn cost_exceeded() {
        let result = make_result(10, 0, 5.0, Duration::from_secs(5));
        let config = GateConfig::new().with_max_cost(2.0);
        let gate = check_gate(&result, &config);
        assert!(!gate.passed);
        assert!(gate.summary.contains("cost"));
    }

    #[test]
    fn duration_exceeded() {
        let result = make_result(10, 0, 0.5, Duration::from_mins(1));
        let config = GateConfig::new().with_max_duration(Duration::from_secs(30));
        let gate = check_gate(&result, &config);
        assert!(!gate.passed);
        assert!(gate.summary.contains("duration"));
    }

    #[test]
    fn multiple_failures_reported() {
        let result = make_result(5, 5, 10.0, Duration::from_secs(5));
        let config = GateConfig::new().with_min_pass_rate(0.9).with_max_cost(1.0);
        let gate = check_gate(&result, &config);
        assert!(!gate.passed);
        assert!(gate.summary.contains("pass rate"));
        assert!(gate.summary.contains("cost"));
    }

    #[test]
    fn zero_cases_passes() {
        let result = make_result(0, 0, 0.0, Duration::from_secs(0));
        let config = GateConfig::new().with_min_pass_rate(0.95);
        let gate = check_gate(&result, &config);
        assert!(gate.passed);
    }

    #[test]
    fn passing_summary_reports_counts() {
        let result = make_result(9, 1, 0.5, Duration::from_secs(5));
        let gate = check_gate(&result, &GateConfig::new().with_min_pass_rate(0.9));
        assert_eq!(gate.summary, "GATE PASSED: 9/10 cases passed");
    }

    #[test]
    fn cost_at_limit_passes() {
        // Limits are inclusive: spending exactly the budget is allowed.
        let result = make_result(10, 0, 2.0, Duration::from_secs(5));
        let gate = check_gate(&result, &GateConfig::new().with_max_cost(2.0));
        assert!(gate.passed);
    }

    #[test]
    fn duration_at_limit_passes() {
        let result = make_result(10, 0, 0.5, Duration::from_secs(30));
        let config = GateConfig::new().with_max_duration(Duration::from_secs(30));
        assert!(check_gate(&result, &config).passed);
    }

    #[test]
    fn failure_summary_joins_all_reasons() {
        let result = make_result(5, 5, 10.0, Duration::from_mins(1));
        let config = GateConfig::new()
            .with_min_pass_rate(0.9)
            .with_max_cost(1.0)
            .with_max_duration(Duration::from_secs(30));
        let gate = check_gate(&result, &config);
        assert!(!gate.passed);
        assert!(gate.summary.starts_with("GATE FAILED: "));
        // Three violations -> two separators.
        assert_eq!(gate.summary.matches("; ").count(), 2);
    }

    #[test]
    #[should_panic(expected = "pass rate must be in [0.0, 1.0]")]
    fn min_pass_rate_above_one_panics() {
        let _ = GateConfig::new().with_min_pass_rate(1.5);
    }

    #[test]
    #[should_panic(expected = "pass rate must be in [0.0, 1.0]")]
    fn min_pass_rate_nan_panics() {
        let _ = GateConfig::new().with_min_pass_rate(f64::NAN);
    }

    #[test]
    #[should_panic(expected = "cost must be non-negative")]
    fn negative_max_cost_panics() {
        let _ = GateConfig::new().with_max_cost(-1.0);
    }
}