Skip to main content

brainwires_eval/
case.rs

1//! The [`EvaluationCase`] trait — the unit of evaluation.
2//!
3//! Implement this trait for any scenario you want to evaluate N times.
4
5use async_trait::async_trait;
6
7use super::trial::TrialResult;
8
9/// A single evaluation scenario.
10///
11/// Implement this trait and pass instances to
12/// [`EvaluationSuite`](crate::suite::EvaluationSuite) to run N independent
13/// trials and compute statistics.
14///
15/// ```rust,ignore
16/// use brainwires_eval::{EvaluationCase, TrialResult};
17/// use async_trait::async_trait;
18///
19/// struct MyCase;
20///
21/// #[async_trait]
22/// impl EvaluationCase for MyCase {
23///     fn name(&self) -> &str { "my_case" }
24///     fn category(&self) -> &str { "smoke" }
25///     async fn run(&self, trial_id: usize) -> anyhow::Result<TrialResult> {
26///         let start = std::time::Instant::now();
27///         let ok = do_the_thing().await.is_ok();
28///         let ms = start.elapsed().as_millis() as u64;
29///         Ok(if ok {
30///             TrialResult::success(trial_id, ms)
31///         } else {
32///             TrialResult::failure(trial_id, ms, "thing failed")
33///         })
34///     }
35/// }
36/// ```
37#[async_trait]
38pub trait EvaluationCase: Send + Sync {
39    /// Short identifier used in reports and log output.
40    fn name(&self) -> &str;
41
42    /// Category label for grouping (e.g. `"smoke"`, `"adversarial"`,
43    /// `"budget_stress"`).
44    fn category(&self) -> &str;
45
46    /// Execute one trial and return its result.
47    ///
48    /// The implementation is responsible for measuring wall-clock duration and
49    /// encoding it in the returned [`TrialResult`].
50    async fn run(&self, trial_id: usize) -> anyhow::Result<TrialResult>;
51}
52
/// A minimal no-op evaluation case useful for unit-testing the evaluation
/// infrastructure itself.
///
/// Every trial succeeds and reports `duration_ms` as its simulated duration.
#[derive(Debug, Clone)]
pub struct AlwaysPassCase {
    /// Short identifier for this case.
    pub name: String,
    /// Category label for grouping.
    pub category: String,
    /// Simulated duration in milliseconds returned by each trial.
    pub duration_ms: u64,
}

impl AlwaysPassCase {
    /// Create a new always-passing case with the given name.
    ///
    /// The category defaults to `"test"` and the simulated duration to 0 ms.
    pub fn new(name: impl Into<String>) -> Self {
        Self {
            name: name.into(),
            category: "test".into(),
            duration_ms: 0,
        }
    }

    /// Set the simulated duration in milliseconds for each trial.
    ///
    /// Consumes and returns `self` so it can be chained after [`Self::new`];
    /// discarding the result drops the case, hence `#[must_use]`.
    #[must_use]
    pub fn with_duration(mut self, ms: u64) -> Self {
        self.duration_ms = ms;
        self
    }
}
80
81#[async_trait]
82impl EvaluationCase for AlwaysPassCase {
83    fn name(&self) -> &str {
84        &self.name
85    }
86    fn category(&self) -> &str {
87        &self.category
88    }
89    async fn run(&self, trial_id: usize) -> anyhow::Result<TrialResult> {
90        Ok(TrialResult::success(trial_id, self.duration_ms))
91    }
92}
93
/// A no-op evaluation case that always fails — useful for testing failure paths.
///
/// Every trial reports a failure carrying `error_msg` and a duration of 0 ms.
#[derive(Debug, Clone)]
pub struct AlwaysFailCase {
    /// Short identifier for this case.
    pub name: String,
    /// Category label for grouping.
    pub category: String,
    /// Error message returned by each trial.
    pub error_msg: String,
}

impl AlwaysFailCase {
    /// Create a new always-failing case with the given name and error message.
    ///
    /// The category defaults to `"test"`.
    pub fn new(name: impl Into<String>, error: impl Into<String>) -> Self {
        Self {
            name: name.into(),
            category: "test".into(),
            error_msg: error.into(),
        }
    }
}
114
115#[async_trait]
116impl EvaluationCase for AlwaysFailCase {
117    fn name(&self) -> &str {
118        &self.name
119    }
120    fn category(&self) -> &str {
121        &self.category
122    }
123    async fn run(&self, trial_id: usize) -> anyhow::Result<TrialResult> {
124        Ok(TrialResult::failure(trial_id, 0, self.error_msg.clone()))
125    }
126}
127
128/// A case that succeeds with a configurable probability (for testing statistics).
129pub struct StochasticCase {
130    /// Short identifier for this case.
131    pub name: String,
132    /// Probability of success per trial (0.0-1.0).
133    pub success_rate: f64,
134}
135
136impl StochasticCase {
137    /// Create a new stochastic case with the given name and success probability.
138    pub fn new(name: impl Into<String>, success_rate: f64) -> Self {
139        Self {
140            name: name.into(),
141            success_rate: success_rate.clamp(0.0, 1.0),
142        }
143    }
144}
145
146#[async_trait]
147impl EvaluationCase for StochasticCase {
148    fn name(&self) -> &str {
149        &self.name
150    }
151    fn category(&self) -> &str {
152        "stochastic"
153    }
154    async fn run(&self, trial_id: usize) -> anyhow::Result<TrialResult> {
155        // Deterministic per trial_id so tests are reproducible.
156        // Uses a simple LCG hash: seed = trial_id * prime, mapped to [0, 1).
157        let seed = (trial_id as u64)
158            .wrapping_mul(6364136223846793005)
159            .wrapping_add(1442695040888963407);
160        // Map the full u64 range to [0, 1) uniformly.
161        let norm = seed as f64 / u64::MAX as f64;
162        if norm < self.success_rate {
163            Ok(TrialResult::success(trial_id, 1))
164        } else {
165            Ok(TrialResult::failure(trial_id, 1, "stochastic failure"))
166        }
167    }
168}
169
#[cfg(test)]
mod tests {
    use super::*;

    #[tokio::test]
    async fn test_always_pass_case() {
        let case = AlwaysPassCase::new("test").with_duration(5);
        let result = case.run(0).await.unwrap();
        assert!(result.success);
        assert_eq!(result.trial_id, 0);
        assert_eq!(result.duration_ms, 5);
    }

    #[tokio::test]
    async fn test_always_pass_case_defaults() {
        let case = AlwaysPassCase::new("defaults");
        assert_eq!(case.name(), "defaults");
        assert_eq!(case.category(), "test");
        let result = case.run(7).await.unwrap();
        assert!(result.success);
        assert_eq!(result.trial_id, 7);
        assert_eq!(result.duration_ms, 0);
    }

    #[tokio::test]
    async fn test_always_fail_case() {
        let case = AlwaysFailCase::new("test", "oops");
        let result = case.run(3).await.unwrap();
        assert!(!result.success);
        assert_eq!(result.trial_id, 3);
        assert_eq!(result.error.as_deref(), Some("oops"));
    }

    #[tokio::test]
    async fn test_stochastic_case_reproducible() {
        let case = StochasticCase::new("test", 0.7);
        let r1 = case.run(42).await.unwrap();
        let r2 = case.run(42).await.unwrap();
        assert_eq!(
            r1.success, r2.success,
            "same trial_id must give same result"
        );
    }

    #[tokio::test]
    async fn test_stochastic_case_rate() {
        const TRIALS: usize = 200;
        let case = StochasticCase::new("test", 0.6);
        assert_eq!(case.category(), "stochastic");
        let mut successes = 0usize;
        for i in 0..TRIALS {
            if case.run(i).await.unwrap().success {
                successes += 1;
            }
        }
        // Expected ~120 successes (0.6 * 200); allow ±15 % (±18) around that.
        assert!(
            (102..=138).contains(&successes),
            "expected ~120 successes, got {}",
            successes
        );
    }

    #[tokio::test]
    async fn test_stochastic_case_extreme_rates() {
        let always = StochasticCase::new("always", 1.0);
        let never = StochasticCase::new("never", 0.0);
        for i in 0..100 {
            assert!(
                always.run(i).await.unwrap().success,
                "trial {} must pass at rate 1.0",
                i
            );
            assert!(
                !never.run(i).await.unwrap().success,
                "trial {} must fail at rate 0.0",
                i
            );
        }
    }

    #[test]
    fn test_stochastic_case_clamps_rate() {
        assert_eq!(StochasticCase::new("hi", 1.5).success_rate, 1.0);
        assert_eq!(StochasticCase::new("lo", -0.25).success_rate, 0.0);
    }
}
219}