1use std::time::Duration;
4
5use serde::{Deserialize, Serialize};
6
7use crate::types::EvalSetResult;
8
9#[derive(Debug, Clone, Default, Serialize, Deserialize)]
15#[serde(rename_all = "snake_case")]
16pub struct GateConfig {
17 #[serde(default, skip_serializing_if = "Option::is_none")]
19 pub min_pass_rate: Option<f64>,
20 #[serde(default, skip_serializing_if = "Option::is_none")]
22 pub max_cost: Option<f64>,
23 #[serde(default, skip_serializing_if = "Option::is_none")]
25 pub max_duration: Option<Duration>,
26}
27
28impl GateConfig {
29 #[must_use]
31 pub const fn new() -> Self {
32 Self {
33 min_pass_rate: None,
34 max_cost: None,
35 max_duration: None,
36 }
37 }
38
39 #[must_use]
41 pub fn with_min_pass_rate(mut self, rate: f64) -> Self {
42 assert!(
43 (0.0..=1.0).contains(&rate),
44 "pass rate must be in [0.0, 1.0], got {rate}"
45 );
46 self.min_pass_rate = Some(rate);
47 self
48 }
49
50 #[must_use]
52 pub fn with_max_cost(mut self, cost: f64) -> Self {
53 assert!(cost >= 0.0, "cost must be non-negative, got {cost}");
54 self.max_cost = Some(cost);
55 self
56 }
57
58 #[must_use]
60 pub const fn with_max_duration(mut self, duration: Duration) -> Self {
61 self.max_duration = Some(duration);
62 self
63 }
64}
65
/// Outcome of checking an evaluation run against a [`GateConfig`].
///
/// Derives `PartialEq`/`Eq` so that outcomes can be compared directly, for example
/// with `assert_eq!` in tests.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct GateResult {
    /// `true` when no configured threshold was violated.
    pub passed: bool,
    /// Process exit status associated with this outcome; `check_gate` uses `0` for a
    /// passing gate and `1` for a failing one.
    pub exit_code: i32,
    /// Human-readable, single-line description of the outcome.
    pub summary: String,
}
76
77impl GateResult {
78 pub fn exit(&self) -> ! {
80 std::process::exit(self.exit_code)
81 }
82}
83
84#[must_use]
88#[allow(clippy::cast_precision_loss)]
89pub fn check_gate(result: &EvalSetResult, config: &GateConfig) -> GateResult {
90 let mut failures: Vec<String> = Vec::new();
91
92 let total = result.summary.total_cases;
93 let passed = result.summary.passed;
94
95 if let Some(min_rate) = config.min_pass_rate {
96 let rate = if total == 0 {
97 1.0
98 } else {
99 passed as f64 / total as f64
100 };
101 if rate < min_rate {
102 failures.push(format!(
103 "pass rate {rate:.2} < minimum {min_rate:.2} ({passed}/{total})"
104 ));
105 }
106 }
107
108 if let Some(max_cost) = config.max_cost
109 && result.summary.total_cost.total > max_cost
110 {
111 failures.push(format!(
112 "cost ${:.4} > max ${max_cost:.4}",
113 result.summary.total_cost.total
114 ));
115 }
116
117 if let Some(max_dur) = config.max_duration
118 && result.summary.total_duration > max_dur
119 {
120 failures.push(format!(
121 "duration {:.1}s > max {:.1}s",
122 result.summary.total_duration.as_secs_f64(),
123 max_dur.as_secs_f64()
124 ));
125 }
126
127 if failures.is_empty() {
128 GateResult {
129 passed: true,
130 exit_code: 0,
131 summary: format!("GATE PASSED: {passed}/{total} cases passed"),
132 }
133 } else {
134 GateResult {
135 passed: false,
136 exit_code: 1,
137 summary: format!("GATE FAILED: {}", failures.join("; ")),
138 }
139 }
140}
141
#[cfg(test)]
mod tests {
    use super::*;

    use swink_agent::{Cost, Usage};

    use crate::types::{EvalSetResult, EvalSummary};

    /// Builds a result whose summary carries the given counts, total cost and duration.
    ///
    /// No individual case results are attached; `check_gate` only reads the summary.
    fn make_result(passed: usize, failed: usize, cost: f64, duration: Duration) -> EvalSetResult {
        let total_cost = Cost {
            total: cost,
            ..Default::default()
        };
        let summary = EvalSummary {
            total_cases: passed + failed,
            passed,
            failed,
            total_cost,
            total_usage: Usage::default(),
            total_duration: duration,
        };
        EvalSetResult {
            eval_set_id: "test".to_string(),
            case_results: Vec::new(),
            summary,
            timestamp: 0,
        }
    }

    // With no thresholds configured the gate passes even though some cases failed.
    #[test]
    fn all_pass_no_config() {
        let outcome = check_gate(
            &make_result(5, 2, 1.0, Duration::from_secs(10)),
            &GateConfig::new(),
        );
        assert!(outcome.passed);
        assert_eq!(outcome.exit_code, 0);
    }

    // A pass rate exactly at the minimum is accepted.
    #[test]
    fn pass_rate_met() {
        let thresholds = GateConfig::new().with_min_pass_rate(0.9);
        let outcome = check_gate(&make_result(9, 1, 0.5, Duration::from_secs(5)), &thresholds);
        assert!(outcome.passed);
    }

    #[test]
    fn pass_rate_not_met() {
        let thresholds = GateConfig::new().with_min_pass_rate(0.9);
        let outcome = check_gate(&make_result(8, 2, 0.5, Duration::from_secs(5)), &thresholds);
        assert!(!outcome.passed);
        assert_eq!(outcome.exit_code, 1);
        assert!(outcome.summary.contains("pass rate"));
    }

    #[test]
    fn cost_exceeded() {
        let thresholds = GateConfig::new().with_max_cost(2.0);
        let outcome = check_gate(&make_result(10, 0, 5.0, Duration::from_secs(5)), &thresholds);
        assert!(!outcome.passed);
        assert!(outcome.summary.contains("cost"));
    }

    #[test]
    fn duration_exceeded() {
        let thresholds = GateConfig::new().with_max_duration(Duration::from_secs(30));
        let outcome = check_gate(&make_result(10, 0, 0.5, Duration::from_mins(1)), &thresholds);
        assert!(!outcome.passed);
        assert!(outcome.summary.contains("duration"));
    }

    // Every violated threshold shows up in the summary, not just the first one.
    #[test]
    fn multiple_failures_reported() {
        let thresholds = GateConfig::new().with_min_pass_rate(0.9).with_max_cost(1.0);
        let outcome = check_gate(&make_result(5, 5, 10.0, Duration::from_secs(5)), &thresholds);
        assert!(!outcome.passed);
        assert!(outcome.summary.contains("pass rate"));
        assert!(outcome.summary.contains("cost"));
    }

    // An empty run is treated as a pass rate of 1.0.
    #[test]
    fn zero_cases_passes() {
        let thresholds = GateConfig::new().with_min_pass_rate(0.95);
        let outcome = check_gate(&make_result(0, 0, 0.0, Duration::from_secs(0)), &thresholds);
        assert!(outcome.passed);
    }
}