use std::collections::HashMap;
use chrono::Utc;
use serde::{Deserialize, Serialize};
use super::suite::SuiteResult;
use super::trial::EvaluationStats;
/// Stored reference measurement for one category, against which later runs
/// are compared.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct CategoryBaseline {
    /// Name of the category this baseline belongs to; also used as the lookup key.
    pub category: String,
    /// Success rate recorded when the baseline was taken, as a fraction
    /// (e.g. `0.8` for 80%).
    pub baseline_success_rate: f64,
    /// Timestamp (from `Utc::now().timestamp()`) at which the baseline was created.
    pub measured_at_unix: i64,
    /// Number of trials the baseline statistics were computed from.
    pub n_trials: usize,
}
impl CategoryBaseline {
pub fn new(category: impl Into<String>, stats: &EvaluationStats) -> Self {
Self {
category: category.into(),
baseline_success_rate: stats.success_rate,
measured_at_unix: Utc::now().timestamp(),
n_trials: stats.n_trials,
}
}
}
/// Thresholds that control when a category counts as regressed.
#[derive(Debug, Clone)]
pub struct RegressionConfig {
    /// Largest tolerated drop in success rate (baseline minus current),
    /// expressed as a fraction: `0.05` means five percentage points.
    pub max_regression: f64,
    /// Minimum number of current trials a category needs before it is
    /// evaluated at all; categories below this are skipped.
    pub min_trials: usize,
}

impl Default for RegressionConfig {
    /// Defaults: at most a 5-point drop, and at least 30 trials per category.
    fn default() -> Self {
        let max_regression = 0.05;
        let min_trials = 30;
        RegressionConfig {
            max_regression,
            min_trials,
        }
    }
}
/// Outcome of comparing one category's current results with its baseline.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct CategoryRegressionResult {
    /// Category the comparison was made for.
    pub category: String,
    /// Success rate observed in the run being checked.
    pub current_success_rate: f64,
    /// Success rate stored in the baseline.
    pub baseline_success_rate: f64,
    /// Baseline rate minus current rate: positive values are drops,
    /// negative values are improvements.
    pub regression: f64,
    /// Whether the drop stayed within the configured limit.
    pub passed: bool,
    /// Human-readable explanation, present only when the category failed.
    pub reason: Option<String>,
}
/// Aggregate verdict of a regression check across all evaluated categories.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct RegressionResult {
    /// `true` when no evaluated category failed.
    pub passed: bool,
    /// Per-category outcomes for every category that was actually evaluated.
    pub category_results: Vec<CategoryRegressionResult>,
}
impl RegressionResult {
    /// Returns the overall verdict; `true` means CI may proceed.
    pub fn is_ci_passing(&self) -> bool {
        self.passed
    }

    /// Collects references to every category result that did not pass,
    /// in the order they appear in `category_results`.
    pub fn failing_categories(&self) -> Vec<&CategoryRegressionResult> {
        let mut failing = Vec::new();
        for entry in &self.category_results {
            if !entry.passed {
                failing.push(entry);
            }
        }
        failing
    }

    /// Collects references to every category whose success rate rose above
    /// its baseline (i.e. whose `regression` is strictly negative).
    pub fn improved_categories(&self) -> Vec<&CategoryRegressionResult> {
        let mut improved = Vec::new();
        for entry in &self.category_results {
            if entry.regression < 0.0 {
                improved.push(entry);
            }
        }
        improved
    }
}
/// Holds per-category baselines together with the thresholds used to
/// compare new suite results against them.
pub struct RegressionSuite {
    /// Thresholds applied during a check.
    config: RegressionConfig,
    /// Stored baselines, keyed by category name.
    baselines: HashMap<String, CategoryBaseline>,
}
impl Default for RegressionSuite {
fn default() -> Self {
Self::new()
}
}
impl RegressionSuite {
pub fn new() -> Self {
Self {
config: RegressionConfig::default(),
baselines: HashMap::new(),
}
}
pub fn with_config(config: RegressionConfig) -> Self {
Self {
config,
baselines: HashMap::new(),
}
}
pub fn with_baseline(mut self, baseline: CategoryBaseline) -> Self {
self.baselines.insert(baseline.category.clone(), baseline);
self
}
pub fn add_baseline(&mut self, category: impl Into<String>, stats: &EvaluationStats) {
let cat = category.into();
self.baselines
.insert(cat.clone(), CategoryBaseline::new(cat, stats));
}
pub fn record_baselines(&mut self, suite_result: &SuiteResult) {
let category_stats = Self::aggregate_by_category(suite_result);
for (category, stats) in &category_stats {
self.add_baseline(category.as_str(), stats);
}
}
fn aggregate_by_category(suite_result: &SuiteResult) -> HashMap<String, EvaluationStats> {
let mut category_trials: HashMap<String, Vec<super::trial::TrialResult>> = HashMap::new();
for (case_name, trials) in &suite_result.case_results {
category_trials
.entry(case_name.clone())
.or_default()
.extend(trials.iter().cloned());
}
category_trials
.into_iter()
.filter_map(|(cat, trials)| {
EvaluationStats::from_trials(&trials).map(|stats| (cat, stats))
})
.collect()
}
pub fn baselines_to_json(&self) -> anyhow::Result<String> {
let list: Vec<&CategoryBaseline> = self.baselines.values().collect();
Ok(serde_json::to_string_pretty(&list)?)
}
pub fn has_baseline(&self, category: &str) -> bool {
self.baselines.contains_key(category)
}
pub fn get_baseline(&self, category: &str) -> Option<&CategoryBaseline> {
self.baselines.get(category)
}
pub fn load_baselines_from_json(json: &str) -> anyhow::Result<Self> {
let baselines: Vec<CategoryBaseline> = serde_json::from_str(json)?;
let mut map = HashMap::new();
for b in baselines {
map.insert(b.category.clone(), b);
}
Ok(Self {
config: RegressionConfig::default(),
baselines: map,
})
}
pub fn check(&self, suite_result: &SuiteResult) -> RegressionResult {
let current_stats = Self::aggregate_by_category(suite_result);
let mut results = Vec::new();
let mut all_passed = true;
for (category, baseline) in &self.baselines {
let Some(current) = current_stats.get(category) else {
continue;
};
if current.n_trials < self.config.min_trials {
continue;
}
let regression = baseline.baseline_success_rate - current.success_rate;
let passed = regression <= self.config.max_regression;
let reason = if !passed {
Some(format!(
"category '{}' dropped {:.1}% (from {:.1}% to {:.1}%), limit is {:.1}%",
category,
regression * 100.0,
baseline.baseline_success_rate * 100.0,
current.success_rate * 100.0,
self.config.max_regression * 100.0,
))
} else {
None
};
if !passed {
all_passed = false;
}
results.push(CategoryRegressionResult {
category: category.clone(),
current_success_rate: current.success_rate,
baseline_success_rate: baseline.baseline_success_rate,
regression,
passed,
reason,
});
}
results.sort_by(|a, b| a.category.cmp(&b.category));
RegressionResult {
passed: all_passed,
category_results: results,
}
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::eval::trial::TrialResult;

    /// Builds `total` trials in which the first `successes` succeed and the
    /// rest fail.
    fn make_trials(successes: usize, total: usize) -> Vec<TrialResult> {
        (0..total)
            .map(|i| {
                if i < successes {
                    TrialResult::success(i, 10)
                } else {
                    TrialResult::failure(i, 10, "fail")
                }
            })
            .collect()
    }

    /// Statistics for `successes` successful trials out of `total`.
    fn make_stats(successes: usize, total: usize) -> EvaluationStats {
        EvaluationStats::from_trials(&make_trials(successes, total)).unwrap()
    }

    /// A suite result containing a single case named `case_name` with
    /// `successes` successful trials out of `total`.
    fn make_suite_result(case_name: &str, successes: usize, total: usize) -> SuiteResult {
        SuiteResult {
            case_results: std::collections::HashMap::from([(
                case_name.to_string(),
                make_trials(successes, total),
            )]),
            stats: std::collections::HashMap::from([(
                case_name.to_string(),
                make_stats(successes, total),
            )]),
        }
    }

    #[test]
    fn test_baseline_creation() {
        let stats = make_stats(80, 100);
        let baseline = CategoryBaseline::new("smoke", &stats);
        assert_eq!(baseline.category, "smoke");
        assert!((baseline.baseline_success_rate - 0.8).abs() < 1e-9);
        assert_eq!(baseline.n_trials, 100);
    }

    #[test]
    fn test_check_passes_when_no_regression() {
        let stats = make_stats(80, 100);
        let mut reg = RegressionSuite::new();
        reg.add_baseline("smoke", &stats);
        let suite_result = make_suite_result("smoke", 80, 100);
        let result = reg.check(&suite_result);
        assert!(result.is_ci_passing(), "no regression should pass");
        assert!(result.failing_categories().is_empty());
    }

    #[test]
    fn test_check_fails_on_regression_above_threshold() {
        let baseline_stats = make_stats(90, 100);
        let mut reg = RegressionSuite::new();
        reg.add_baseline("smoke", &baseline_stats);
        let suite_result = make_suite_result("smoke", 80, 100);
        let result = reg.check(&suite_result);
        assert!(!result.is_ci_passing(), "10% regression should fail CI");
        assert_eq!(result.failing_categories().len(), 1);
        let failing = &result.failing_categories()[0];
        assert!((failing.regression - 0.1).abs() < 1e-9);
    }

    #[test]
    fn test_check_passes_regression_within_threshold() {
        let baseline_stats = make_stats(90, 100);
        let config = RegressionConfig {
            max_regression: 0.10,
            min_trials: 30,
        };
        let mut reg = RegressionSuite::with_config(config);
        reg.add_baseline("smoke", &baseline_stats);
        let suite_result = make_suite_result("smoke", 82, 100);
        let result = reg.check(&suite_result);
        assert!(
            result.is_ci_passing(),
            "8% drop within 10% threshold should pass"
        );
    }

    #[test]
    fn test_check_skips_low_trial_count() {
        let baseline_stats = make_stats(90, 100);
        let mut reg = RegressionSuite::new();
        reg.add_baseline("smoke", &baseline_stats);
        // Only 5 trials, below the default minimum of 30.
        let suite_result = make_suite_result("smoke", 0, 5);
        let result = reg.check(&suite_result);
        assert!(result.is_ci_passing(), "low trial count should be skipped");
        assert!(result.category_results.is_empty());
    }

    #[test]
    fn test_json_roundtrip() {
        let stats = make_stats(75, 100);
        let mut reg = RegressionSuite::new();
        reg.add_baseline("smoke", &stats);
        let json = reg.baselines_to_json().unwrap();
        let loaded = RegressionSuite::load_baselines_from_json(&json).unwrap();
        let baseline = loaded.baselines.get("smoke").unwrap();
        assert!((baseline.baseline_success_rate - 0.75).abs() < 1e-9);
    }

    #[test]
    fn test_record_baselines_from_suite_result() {
        let suite_result = make_suite_result("my_case", 40, 50);
        let mut reg = RegressionSuite::new();
        reg.record_baselines(&suite_result);
        assert!(reg.baselines.contains_key("my_case"));
        let b = &reg.baselines["my_case"];
        assert!((b.baseline_success_rate - 0.8).abs() < 1e-9);
    }

    #[test]
    fn test_improved_categories() {
        let baseline_stats = make_stats(70, 100);
        let mut reg = RegressionSuite::new();
        reg.add_baseline("smoke", &baseline_stats);
        let suite_result = make_suite_result("smoke", 90, 100);
        let result = reg.check(&suite_result);
        assert!(result.is_ci_passing());
        assert_eq!(result.improved_categories().len(), 1);
        assert!(result.improved_categories()[0].regression < 0.0);
    }
}