1use std::collections::HashMap;
32
33use chrono::Utc;
34use serde::{Deserialize, Serialize};
35
36use super::suite::SuiteResult;
37use super::trial::EvaluationStats;
38
39#[derive(Debug, Clone, Serialize, Deserialize)]
43pub struct CategoryBaseline {
44 pub category: String,
46 pub baseline_success_rate: f64,
48 pub measured_at_unix: i64,
50 pub n_trials: usize,
52}
53
54impl CategoryBaseline {
55 pub fn new(category: impl Into<String>, stats: &EvaluationStats) -> Self {
57 Self {
58 category: category.into(),
59 baseline_success_rate: stats.success_rate,
60 measured_at_unix: Utc::now().timestamp(),
61 n_trials: stats.n_trials,
62 }
63 }
64}
65
/// Thresholds that control how a regression check is evaluated.
#[derive(Debug, Clone)]
pub struct RegressionConfig {
    /// Largest tolerated drop in success rate relative to the baseline,
    /// expressed as a fraction (`0.05` means five percentage points).
    pub max_regression: f64,
    /// Minimum number of trials a category must have in the current run
    /// before it is compared against its baseline.
    pub min_trials: usize,
}

impl Default for RegressionConfig {
    /// Returns the default thresholds: a maximum regression of `0.05` and a
    /// minimum of `30` trials.
    fn default() -> Self {
        const DEFAULT_MAX_REGRESSION: f64 = 0.05;
        const DEFAULT_MIN_TRIALS: usize = 30;

        RegressionConfig {
            max_regression: DEFAULT_MAX_REGRESSION,
            min_trials: DEFAULT_MIN_TRIALS,
        }
    }
}
86
/// Outcome of comparing one category's current success rate with its baseline.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct CategoryRegressionResult {
    /// Name of the category that was compared.
    pub category: String,
    /// Success rate measured in the run being checked.
    pub current_success_rate: f64,
    /// Success rate stored in the baseline.
    pub baseline_success_rate: f64,
    /// `baseline_success_rate - current_success_rate`: positive when the rate
    /// dropped, negative when it improved.
    pub regression: f64,
    /// Whether the regression stayed within the configured limit.
    pub passed: bool,
    /// Human-readable explanation of the failure; `None` when `passed` is true.
    pub reason: Option<String>,
}
105
106#[derive(Debug, Clone, Serialize, Deserialize)]
110pub struct RegressionResult {
111 pub passed: bool,
113 pub category_results: Vec<CategoryRegressionResult>,
115}
116
117impl RegressionResult {
118 pub fn is_ci_passing(&self) -> bool {
120 self.passed
121 }
122
123 pub fn failing_categories(&self) -> Vec<&CategoryRegressionResult> {
125 self.category_results.iter().filter(|r| !r.passed).collect()
126 }
127
128 pub fn improved_categories(&self) -> Vec<&CategoryRegressionResult> {
130 self.category_results
131 .iter()
132 .filter(|r| r.regression < 0.0)
133 .collect()
134 }
135}
136
137pub struct RegressionSuite {
144 config: RegressionConfig,
145 baselines: HashMap<String, CategoryBaseline>,
147}
148
149impl Default for RegressionSuite {
150 fn default() -> Self {
151 Self::new()
152 }
153}
154
155impl RegressionSuite {
156 pub fn new() -> Self {
158 Self {
159 config: RegressionConfig::default(),
160 baselines: HashMap::new(),
161 }
162 }
163
164 pub fn with_config(config: RegressionConfig) -> Self {
166 Self {
167 config,
168 baselines: HashMap::new(),
169 }
170 }
171
172 pub fn with_baseline(mut self, baseline: CategoryBaseline) -> Self {
174 self.baselines.insert(baseline.category.clone(), baseline);
175 self
176 }
177
178 pub fn add_baseline(&mut self, category: impl Into<String>, stats: &EvaluationStats) {
180 let cat = category.into();
181 self.baselines
182 .insert(cat.clone(), CategoryBaseline::new(cat, stats));
183 }
184
185 pub fn record_baselines(&mut self, suite_result: &SuiteResult) {
189 let category_stats = Self::aggregate_by_category(suite_result);
191 for (category, stats) in &category_stats {
192 self.add_baseline(category.as_str(), stats);
193 }
194 }
195
196 fn aggregate_by_category(suite_result: &SuiteResult) -> HashMap<String, EvaluationStats> {
200 let mut category_trials: HashMap<String, Vec<super::trial::TrialResult>> = HashMap::new();
205
206 for (case_name, trials) in &suite_result.case_results {
212 category_trials
213 .entry(case_name.clone())
214 .or_default()
215 .extend(trials.iter().cloned());
216 }
217
218 category_trials
219 .into_iter()
220 .filter_map(|(cat, trials)| {
221 EvaluationStats::from_trials(&trials).map(|stats| (cat, stats))
222 })
223 .collect()
224 }
225
226 pub fn baselines_to_json(&self) -> anyhow::Result<String> {
228 let list: Vec<&CategoryBaseline> = self.baselines.values().collect();
229 Ok(serde_json::to_string_pretty(&list)?)
230 }
231
232 pub fn has_baseline(&self, category: &str) -> bool {
234 self.baselines.contains_key(category)
235 }
236
237 pub fn get_baseline(&self, category: &str) -> Option<&CategoryBaseline> {
239 self.baselines.get(category)
240 }
241
242 pub fn load_baselines_from_json(json: &str) -> anyhow::Result<Self> {
244 let baselines: Vec<CategoryBaseline> = serde_json::from_str(json)?;
245 let mut map = HashMap::new();
246 for b in baselines {
247 map.insert(b.category.clone(), b);
248 }
249 Ok(Self {
250 config: RegressionConfig::default(),
251 baselines: map,
252 })
253 }
254
255 pub fn check(&self, suite_result: &SuiteResult) -> RegressionResult {
261 let current_stats = Self::aggregate_by_category(suite_result);
262 let mut results = Vec::new();
263 let mut all_passed = true;
264
265 for (category, baseline) in &self.baselines {
266 let Some(current) = current_stats.get(category) else {
267 continue;
269 };
270
271 if current.n_trials < self.config.min_trials {
272 continue;
274 }
275
276 let regression = baseline.baseline_success_rate - current.success_rate;
277 let passed = regression <= self.config.max_regression;
278 let reason = if !passed {
279 Some(format!(
280 "category '{}' dropped {:.1}% (from {:.1}% to {:.1}%), limit is {:.1}%",
281 category,
282 regression * 100.0,
283 baseline.baseline_success_rate * 100.0,
284 current.success_rate * 100.0,
285 self.config.max_regression * 100.0,
286 ))
287 } else {
288 None
289 };
290
291 if !passed {
292 all_passed = false;
293 }
294
295 results.push(CategoryRegressionResult {
296 category: category.clone(),
297 current_success_rate: current.success_rate,
298 baseline_success_rate: baseline.baseline_success_rate,
299 regression,
300 passed,
301 reason,
302 });
303 }
304
305 results.sort_by(|a, b| a.category.cmp(&b.category));
306
307 RegressionResult {
308 passed: all_passed,
309 category_results: results,
310 }
311 }
312}
313
#[cfg(test)]
mod tests {
    use super::*;
    use crate::trial::TrialResult;

    /// Builds `total` trials of which the first `successes` succeed and the
    /// rest fail.
    fn make_trials(successes: usize, total: usize) -> Vec<TrialResult> {
        (0..total)
            .map(|i| {
                if i < successes {
                    TrialResult::success(i, 10)
                } else {
                    TrialResult::failure(i, 10, "fail")
                }
            })
            .collect()
    }

    /// Builds stats for `successes` successful trials out of `total`.
    fn make_stats(successes: usize, total: usize) -> EvaluationStats {
        EvaluationStats::from_trials(&make_trials(successes, total)).unwrap()
    }

    /// Builds a suite result containing a single case named `case` with
    /// `successes` successful trials out of `total`.
    fn make_suite_result(case: &str, successes: usize, total: usize) -> SuiteResult {
        let trials = make_trials(successes, total);
        let stats = EvaluationStats::from_trials(&trials).unwrap();
        SuiteResult {
            case_results: std::collections::HashMap::from([(case.to_string(), trials)]),
            stats: std::collections::HashMap::from([(case.to_string(), stats)]),
        }
    }

    #[test]
    fn test_baseline_creation() {
        let stats = make_stats(80, 100);
        let baseline = CategoryBaseline::new("smoke", &stats);
        assert_eq!(baseline.category, "smoke");
        assert!((baseline.baseline_success_rate - 0.8).abs() < 1e-9);
        assert_eq!(baseline.n_trials, 100);
    }

    #[test]
    fn test_check_passes_when_no_regression() {
        let stats = make_stats(80, 100);
        let mut reg = RegressionSuite::new();
        reg.add_baseline("smoke", &stats);

        let suite_result = make_suite_result("smoke", 80, 100);

        let result = reg.check(&suite_result);
        assert!(result.is_ci_passing(), "no regression should pass");
        assert!(result.failing_categories().is_empty());
    }

    #[test]
    fn test_check_fails_on_regression_above_threshold() {
        // Baseline 90%, current 80%: a 10% drop against the default 5% limit.
        let baseline_stats = make_stats(90, 100);
        let mut reg = RegressionSuite::new();
        reg.add_baseline("smoke", &baseline_stats);

        let suite_result = make_suite_result("smoke", 80, 100);

        let result = reg.check(&suite_result);
        assert!(!result.is_ci_passing(), "10% regression should fail CI");
        assert_eq!(result.failing_categories().len(), 1);
        let failing = &result.failing_categories()[0];
        assert!((failing.regression - 0.1).abs() < 1e-9);
    }

    #[test]
    fn test_check_passes_regression_within_threshold() {
        // Baseline 90%, current 82%: an 8% drop against a 10% limit.
        let baseline_stats = make_stats(90, 100);
        let config = RegressionConfig {
            max_regression: 0.10,
            min_trials: 30,
        };
        let mut reg = RegressionSuite::with_config(config);
        reg.add_baseline("smoke", &baseline_stats);

        let suite_result = make_suite_result("smoke", 82, 100);

        let result = reg.check(&suite_result);
        assert!(
            result.is_ci_passing(),
            "8% drop within 10% threshold should pass"
        );
    }

    #[test]
    fn test_check_skips_low_trial_count() {
        let baseline_stats = make_stats(90, 100);
        let mut reg = RegressionSuite::new();
        reg.add_baseline("smoke", &baseline_stats);

        // Only 5 trials, below the default minimum of 30.
        let suite_result = make_suite_result("smoke", 0, 5);

        let result = reg.check(&suite_result);
        assert!(result.is_ci_passing(), "low trial count should be skipped");
        assert!(result.category_results.is_empty());
    }

    #[test]
    fn test_json_roundtrip() {
        let stats = make_stats(75, 100);
        let mut reg = RegressionSuite::new();
        reg.add_baseline("smoke", &stats);

        let json = reg.baselines_to_json().unwrap();
        let loaded = RegressionSuite::load_baselines_from_json(&json).unwrap();
        let baseline = loaded.baselines.get("smoke").unwrap();
        assert!((baseline.baseline_success_rate - 0.75).abs() < 1e-9);
    }

    #[test]
    fn test_record_baselines_from_suite_result() {
        let suite_result = make_suite_result("my_case", 40, 50);

        let mut reg = RegressionSuite::new();
        reg.record_baselines(&suite_result);

        assert!(reg.baselines.contains_key("my_case"));
        let b = &reg.baselines["my_case"];
        assert!((b.baseline_success_rate - 0.8).abs() < 1e-9);
    }

    #[test]
    fn test_improved_categories() {
        // Baseline 70%, current 90%: the rate went up.
        let baseline_stats = make_stats(70, 100);
        let mut reg = RegressionSuite::new();
        reg.add_baseline("smoke", &baseline_stats);

        let suite_result = make_suite_result("smoke", 90, 100);

        let result = reg.check(&suite_result);
        assert!(result.is_ci_passing());
        assert_eq!(result.improved_categories().len(), 1);
        assert!(result.improved_categories()[0].regression < 0.0);
    }
}