Skip to main content

brainwires_eval/
regression.rs

1//! Regression testing infrastructure for CI integration.
2//!
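//! [`RegressionSuite`] compares current [`SuiteResult`] success rates against
//! stored per-category baselines. If any category drops more than
//! [`RegressionConfig::max_regression`] below its baseline, the check fails —
//! enabling CI pipelines to gate on evaluation regressions automatically.
//!
//! Note: [`SuiteResult`] does not record case categories, so baselines captured
//! with [`RegressionSuite::record_baselines`] — and the lookups performed by
//! [`RegressionSuite::check`] — are keyed by case name.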
7//!
8//! # Quick start
9//!
10//! ```rust,ignore
11//! use brainwires_eval::{
12//!     regression::{RegressionSuite, RegressionConfig, CategoryBaseline},
13//!     EvaluationSuite, AlwaysPassCase,
14//! };
15//! use std::sync::Arc;
16//!
17//! // 1. Run the evaluation suite.
18//! let suite = EvaluationSuite::new(30);
19//! let cases = vec![Arc::new(AlwaysPassCase::new("smoke_test")) as Arc<_>];
20//! let results = suite.run_suite(&cases).await;
21//!
22//! // 2. Build baselines from current results.
23//! let mut reg = RegressionSuite::new();
24//! reg.record_baselines(&results);
25//!
26//! // 3. On the next CI run, compare.
27//! let check = reg.check(&results);
28//! assert!(check.is_ci_passing());
29//! ```
30
31use std::collections::HashMap;
32
33use chrono::Utc;
34use serde::{Deserialize, Serialize};
35
36use super::suite::SuiteResult;
37use super::trial::EvaluationStats;
38
39// ── Baseline ──────────────────────────────────────────────────────────────────
40
41/// Per-category success-rate baseline stored for regression comparison.
42#[derive(Debug, Clone, Serialize, Deserialize)]
43pub struct CategoryBaseline {
44    /// Category label matching [`EvaluationCase::category`](crate::case::EvaluationCase::category).
45    pub category: String,
46    /// Baseline success rate in [0, 1].
47    pub baseline_success_rate: f64,
48    /// Unix timestamp (seconds) when this baseline was recorded.
49    pub measured_at_unix: i64,
50    /// Number of trials used to compute this baseline.
51    pub n_trials: usize,
52}
53
54impl CategoryBaseline {
55    /// Create a new baseline from measured stats.
56    pub fn new(category: impl Into<String>, stats: &EvaluationStats) -> Self {
57        Self {
58            category: category.into(),
59            baseline_success_rate: stats.success_rate,
60            measured_at_unix: Utc::now().timestamp(),
61            n_trials: stats.n_trials,
62        }
63    }
64}
65
66// ── Configuration ─────────────────────────────────────────────────────────────
67
/// Tunable thresholds used by [`RegressionSuite`] when checking a run.
#[derive(Debug, Clone)]
pub struct RegressionConfig {
    /// Largest tolerated drop below a baseline, as a fraction in [0, 1].
    /// Default: 0.05 (5 percentage points).
    pub max_regression: f64,
    /// Minimum number of current trials a category needs before it is
    /// compared at all; smaller samples are skipped as too noisy. Default: 30.
    pub min_trials: usize,
}

impl Default for RegressionConfig {
    /// Tolerate a 5 % drop and require at least 30 trials per category.
    fn default() -> Self {
        let min_trials = 30;
        let max_regression = 0.05;
        RegressionConfig {
            min_trials,
            max_regression,
        }
    }
}
86
87// ── Per-category result ───────────────────────────────────────────────────────
88
89/// Result of a single category's regression check.
90#[derive(Debug, Clone, Serialize, Deserialize)]
91pub struct CategoryRegressionResult {
92    /// Category name.
93    pub category: String,
94    /// Current measured success rate.
95    pub current_success_rate: f64,
96    /// Baseline success rate this was compared against.
97    pub baseline_success_rate: f64,
98    /// `baseline - current` (positive means regression, negative means improvement).
99    pub regression: f64,
100    /// Whether this category passed the regression threshold.
101    pub passed: bool,
102    /// Human-readable reason when `passed == false`.
103    pub reason: Option<String>,
104}
105
106// ── Aggregate result ──────────────────────────────────────────────────────────
107
108/// Aggregate result of a full regression check across all categories.
109#[derive(Debug, Clone, Serialize, Deserialize)]
110pub struct RegressionResult {
111    /// `true` when all checked categories passed.
112    pub passed: bool,
113    /// Per-category breakdown.
114    pub category_results: Vec<CategoryRegressionResult>,
115}
116
117impl RegressionResult {
118    /// Whether all categories passed (suitable for CI gate).
119    pub fn is_ci_passing(&self) -> bool {
120        self.passed
121    }
122
123    /// Categories that failed the regression threshold.
124    pub fn failing_categories(&self) -> Vec<&CategoryRegressionResult> {
125        self.category_results.iter().filter(|r| !r.passed).collect()
126    }
127
128    /// Categories with improvements (negative regression).
129    pub fn improved_categories(&self) -> Vec<&CategoryRegressionResult> {
130        self.category_results
131            .iter()
132            .filter(|r| r.regression < 0.0)
133            .collect()
134    }
135}
136
137// ── RegressionSuite ───────────────────────────────────────────────────────────
138
139/// Compares evaluation suite results against stored per-category baselines.
140///
141/// Fails the check if any category's success rate drops more than
142/// [`RegressionConfig::max_regression`] below its baseline.
143pub struct RegressionSuite {
144    config: RegressionConfig,
145    /// Baseline indexed by category name.
146    baselines: HashMap<String, CategoryBaseline>,
147}
148
149impl Default for RegressionSuite {
150    fn default() -> Self {
151        Self::new()
152    }
153}
154
155impl RegressionSuite {
156    /// Create a new regression suite with default configuration and no baselines.
157    pub fn new() -> Self {
158        Self {
159            config: RegressionConfig::default(),
160            baselines: HashMap::new(),
161        }
162    }
163
164    /// Create with custom configuration.
165    pub fn with_config(config: RegressionConfig) -> Self {
166        Self {
167            config,
168            baselines: HashMap::new(),
169        }
170    }
171
172    /// Manually register a baseline for a category.
173    pub fn with_baseline(mut self, baseline: CategoryBaseline) -> Self {
174        self.baselines.insert(baseline.category.clone(), baseline);
175        self
176    }
177
178    /// Register a baseline from an [`EvaluationStats`] object.
179    pub fn add_baseline(&mut self, category: impl Into<String>, stats: &EvaluationStats) {
180        let cat = category.into();
181        self.baselines
182            .insert(cat.clone(), CategoryBaseline::new(cat, stats));
183    }
184
185    /// Record baselines for ALL categories present in `suite_result`.
186    ///
187    /// Use this to capture the current run as the new baseline.
188    pub fn record_baselines(&mut self, suite_result: &SuiteResult) {
189        // Aggregate stats by category (not by case name)
190        let category_stats = Self::aggregate_by_category(suite_result);
191        for (category, stats) in &category_stats {
192            self.add_baseline(category.as_str(), stats);
193        }
194    }
195
196    /// Aggregate per-case stats into per-category stats.
197    ///
198    /// Combines all trial results across cases sharing the same category.
199    fn aggregate_by_category(suite_result: &SuiteResult) -> HashMap<String, EvaluationStats> {
200        // We need to look at trial results. Build a mapping:
201        // category → [all trial results from that category]
202        // NOTE: SuiteResult only has case-level stats. We infer per-category
203        // stats by re-aggregating from case_results using trial data.
204        let mut category_trials: HashMap<String, Vec<super::trial::TrialResult>> = HashMap::new();
205
206        // For aggregation we need to know which case belongs to which category.
207        // SuiteResult stores results keyed by case name. The category is embedded
208        // in EvaluationCase but not stored in SuiteResult. As a workaround, we
209        // aggregate per case name as a fallback. Callers can register baselines
210        // by category directly via `add_baseline`.
211        for (case_name, trials) in &suite_result.case_results {
212            category_trials
213                .entry(case_name.clone())
214                .or_default()
215                .extend(trials.iter().cloned());
216        }
217
218        category_trials
219            .into_iter()
220            .filter_map(|(cat, trials)| {
221                EvaluationStats::from_trials(&trials).map(|stats| (cat, stats))
222            })
223            .collect()
224    }
225
226    /// Serialize baselines to a JSON string.
227    pub fn baselines_to_json(&self) -> anyhow::Result<String> {
228        let list: Vec<&CategoryBaseline> = self.baselines.values().collect();
229        Ok(serde_json::to_string_pretty(&list)?)
230    }
231
232    /// Returns `true` when a baseline has been recorded for `category`.
233    pub fn has_baseline(&self, category: &str) -> bool {
234        self.baselines.contains_key(category)
235    }
236
237    /// Retrieve the stored baseline for `category`, or `None` if absent.
238    pub fn get_baseline(&self, category: &str) -> Option<&CategoryBaseline> {
239        self.baselines.get(category)
240    }
241
242    /// Load baselines from a JSON string (produced by [`Self::baselines_to_json`]).
243    pub fn load_baselines_from_json(json: &str) -> anyhow::Result<Self> {
244        let baselines: Vec<CategoryBaseline> = serde_json::from_str(json)?;
245        let mut map = HashMap::new();
246        for b in baselines {
247            map.insert(b.category.clone(), b);
248        }
249        Ok(Self {
250            config: RegressionConfig::default(),
251            baselines: map,
252        })
253    }
254
255    /// Run the regression check against a completed [`SuiteResult`].
256    ///
257    /// For each category with a stored baseline:
258    /// - Skip if `current_n_trials < min_trials`.
259    /// - Fail if `baseline_rate - current_rate > max_regression`.
260    pub fn check(&self, suite_result: &SuiteResult) -> RegressionResult {
261        let current_stats = Self::aggregate_by_category(suite_result);
262        let mut results = Vec::new();
263        let mut all_passed = true;
264
265        for (category, baseline) in &self.baselines {
266            let Some(current) = current_stats.get(category) else {
267                // Category present in baseline but absent from current run — skip.
268                continue;
269            };
270
271            if current.n_trials < self.config.min_trials {
272                // Not enough data — skip.
273                continue;
274            }
275
276            let regression = baseline.baseline_success_rate - current.success_rate;
277            let passed = regression <= self.config.max_regression;
278            let reason = if !passed {
279                Some(format!(
280                    "category '{}' dropped {:.1}% (from {:.1}% to {:.1}%), limit is {:.1}%",
281                    category,
282                    regression * 100.0,
283                    baseline.baseline_success_rate * 100.0,
284                    current.success_rate * 100.0,
285                    self.config.max_regression * 100.0,
286                ))
287            } else {
288                None
289            };
290
291            if !passed {
292                all_passed = false;
293            }
294
295            results.push(CategoryRegressionResult {
296                category: category.clone(),
297                current_success_rate: current.success_rate,
298                baseline_success_rate: baseline.baseline_success_rate,
299                regression,
300                passed,
301                reason,
302            });
303        }
304
305        results.sort_by(|a, b| a.category.cmp(&b.category));
306
307        RegressionResult {
308            passed: all_passed,
309            category_results: results,
310        }
311    }
312}
313
314// ── Tests ─────────────────────────────────────────────────────────────────────
315
#[cfg(test)]
mod tests {
    use super::*;
    use crate::trial::TrialResult;

    /// `total` synthetic trials: the first `successes` succeed, the rest fail.
    fn make_trials(successes: usize, total: usize) -> Vec<TrialResult> {
        (0..total)
            .map(|i| {
                if i < successes {
                    TrialResult::success(i, 10)
                } else {
                    TrialResult::failure(i, 10, "fail")
                }
            })
            .collect()
    }

    /// Stats for `successes` out of `total` synthetic trials.
    fn make_stats(successes: usize, total: usize) -> EvaluationStats {
        EvaluationStats::from_trials(&make_trials(successes, total)).unwrap()
    }

    /// A `SuiteResult` holding a single case named `case`.
    fn make_suite_result(case: &str, successes: usize, total: usize) -> SuiteResult {
        let trials = make_trials(successes, total);
        let stats = EvaluationStats::from_trials(&trials).unwrap();
        SuiteResult {
            case_results: std::collections::HashMap::from([(case.to_string(), trials)]),
            stats: std::collections::HashMap::from([(case.to_string(), stats)]),
        }
    }

    #[test]
    fn test_baseline_creation() {
        let baseline = CategoryBaseline::new("smoke", &make_stats(80, 100));
        assert_eq!(baseline.category, "smoke");
        assert!((baseline.baseline_success_rate - 0.8).abs() < 1e-9);
        assert_eq!(baseline.n_trials, 100);
    }

    #[test]
    fn test_check_passes_when_no_regression() {
        let mut reg = RegressionSuite::new();
        reg.add_baseline("smoke", &make_stats(80, 100));

        // Identical pass rate → regression of 0 → passes.
        let result = reg.check(&make_suite_result("smoke", 80, 100));
        assert!(result.is_ci_passing(), "no regression should pass");
        assert!(result.failing_categories().is_empty());
    }

    #[test]
    fn test_check_fails_on_regression_above_threshold() {
        let mut reg = RegressionSuite::new();
        reg.add_baseline("smoke", &make_stats(90, 100)); // 90 %

        // 80 % now: a 10-point drop, above the default 5 % limit.
        let result = reg.check(&make_suite_result("smoke", 80, 100));
        assert!(!result.is_ci_passing(), "10% regression should fail CI");
        let failing = result.failing_categories();
        assert_eq!(failing.len(), 1);
        assert!((failing[0].regression - 0.1).abs() < 1e-9);
    }

    #[test]
    fn test_check_passes_regression_within_threshold() {
        let mut reg = RegressionSuite::with_config(RegressionConfig {
            max_regression: 0.10, // 10 % allowed
            min_trials: 30,
        });
        reg.add_baseline("smoke", &make_stats(90, 100)); // 90 %

        // 82 % now: an 8-point drop, inside the 10 % allowance.
        let result = reg.check(&make_suite_result("smoke", 82, 100));
        assert!(
            result.is_ci_passing(),
            "8% drop within 10% threshold should pass"
        );
    }

    #[test]
    fn test_check_skips_low_trial_count() {
        let mut reg = RegressionSuite::new(); // min_trials = 30
        reg.add_baseline("smoke", &make_stats(90, 100)); // 90 %

        // 5 trials, all failing: below min_trials, so the entry is skipped.
        let result = reg.check(&make_suite_result("smoke", 0, 5));
        assert!(result.is_ci_passing(), "low trial count should be skipped");
        assert!(result.category_results.is_empty());
    }

    #[test]
    fn test_json_roundtrip() {
        let mut reg = RegressionSuite::new();
        reg.add_baseline("smoke", &make_stats(75, 100));

        let json = reg.baselines_to_json().unwrap();
        let loaded = RegressionSuite::load_baselines_from_json(&json).unwrap();
        let baseline = loaded.baselines.get("smoke").unwrap();
        assert!((baseline.baseline_success_rate - 0.75).abs() < 1e-9);
    }

    #[test]
    fn test_record_baselines_from_suite_result() {
        let mut reg = RegressionSuite::new();
        reg.record_baselines(&make_suite_result("my_case", 40, 50));

        assert!(reg.baselines.contains_key("my_case"));
        let b = &reg.baselines["my_case"];
        assert!((b.baseline_success_rate - 0.8).abs() < 1e-9);
    }

    #[test]
    fn test_improved_categories() {
        let mut reg = RegressionSuite::new();
        reg.add_baseline("smoke", &make_stats(70, 100)); // 70 %

        // 90 % now: an improvement.
        let result = reg.check(&make_suite_result("smoke", 90, 100));
        assert!(result.is_ci_passing());
        let improved = result.improved_categories();
        assert_eq!(improved.len(), 1);
        assert!(improved[0].regression < 0.0);
    }
}