use async_trait::async_trait;
use super::trial::TrialResult;
/// A named evaluation scenario that can be executed as a series of numbered trials.
///
/// The trait requires `Send + Sync`, so trait objects and references to implementors can be
/// shared across threads.
#[async_trait]
pub trait EvaluationCase: Send + Sync {
    /// Human-readable identifier of this case.
    fn name(&self) -> &str;

    /// Grouping label for this case.
    fn category(&self) -> &str;

    /// Executes the trial numbered `trial_id` and reports its outcome.
    ///
    /// # Errors
    ///
    /// An `Err` indicates the trial could not be executed. The implementations in this module
    /// report a failing trial as `Ok(TrialResult::failure(..))` rather than as an `Err`.
    async fn run(&self, trial_id: usize) -> anyhow::Result<TrialResult>;
}
/// An [`EvaluationCase`] whose every trial succeeds.
///
/// No work is performed and no time elapses. Each trial reports `duration_ms` as its
/// duration. [`AlwaysPassCase::new`] sets it to `0`, and [`AlwaysPassCase::with_duration`]
/// changes it.
#[derive(Debug, Clone)]
pub struct AlwaysPassCase {
    /// Value returned by [`EvaluationCase::name`].
    pub name: String,
    /// Value returned by [`EvaluationCase::category`]; [`AlwaysPassCase::new`] uses `"test"`.
    pub category: String,
    /// Duration reported by every trial (milliseconds, per the field name).
    pub duration_ms: u64,
}

impl AlwaysPassCase {
    /// Creates a passing case named `name` in category `"test"` with a reported duration of 0.
    pub fn new(name: impl Into<String>) -> Self {
        Self {
            name: name.into(),
            category: "test".into(),
            duration_ms: 0,
        }
    }

    /// Returns this case with the per-trial reported duration set to `ms` (builder style).
    #[must_use]
    pub fn with_duration(mut self, ms: u64) -> Self {
        self.duration_ms = ms;
        self
    }
}

#[async_trait]
impl EvaluationCase for AlwaysPassCase {
    fn name(&self) -> &str {
        &self.name
    }

    fn category(&self) -> &str {
        &self.category
    }

    /// Always returns a successful result carrying `trial_id` and the configured duration.
    async fn run(&self, trial_id: usize) -> anyhow::Result<TrialResult> {
        Ok(TrialResult::success(trial_id, self.duration_ms))
    }
}
/// An [`EvaluationCase`] whose every trial fails with a fixed error message.
///
/// Failures are reported as `Ok(TrialResult::failure(..))` with a duration of `0`, never as
/// an `Err`.
#[derive(Debug, Clone)]
pub struct AlwaysFailCase {
    /// Value returned by [`EvaluationCase::name`].
    pub name: String,
    /// Value returned by [`EvaluationCase::category`]; [`AlwaysFailCase::new`] uses `"test"`.
    pub category: String,
    /// Error message attached to every failed trial.
    pub error_msg: String,
}

impl AlwaysFailCase {
    /// Creates a failing case named `name` in category `"test"` that reports `error`.
    pub fn new(name: impl Into<String>, error: impl Into<String>) -> Self {
        Self {
            name: name.into(),
            category: "test".into(),
            error_msg: error.into(),
        }
    }
}

#[async_trait]
impl EvaluationCase for AlwaysFailCase {
    fn name(&self) -> &str {
        &self.name
    }

    fn category(&self) -> &str {
        &self.category
    }

    /// Always returns a failed result carrying `trial_id`, duration `0`, and `error_msg`.
    async fn run(&self, trial_id: usize) -> anyhow::Result<TrialResult> {
        Ok(TrialResult::failure(trial_id, 0, self.error_msg.clone()))
    }
}
/// An [`EvaluationCase`] that passes or fails pseudo-randomly but reproducibly.
///
/// The outcome of a trial depends only on its `trial_id` and `success_rate`, so re-running
/// the same trial id always yields the same result. Across many trial ids, the pass fraction
/// is intended to approximate `success_rate`. Every trial reports a duration of `1`.
#[derive(Debug, Clone)]
pub struct StochasticCase {
    /// Value returned by [`EvaluationCase::name`].
    pub name: String,
    /// Fraction of trials that should succeed. [`StochasticCase::new`] clamps it into
    /// `[0.0, 1.0]`; assigning the field directly bypasses that clamp.
    pub success_rate: f64,
}

impl StochasticCase {
    /// Multiplier of the single LCG step used to scramble trial ids (Knuth's MMIX constant).
    const LCG_MULTIPLIER: u64 = 6_364_136_223_846_793_005;
    /// Increment of the single LCG step used to scramble trial ids (Knuth's MMIX constant).
    const LCG_INCREMENT: u64 = 1_442_695_040_888_963_407;

    /// Creates a case that succeeds for roughly `success_rate` of its trials.
    ///
    /// The rate is clamped into `[0.0, 1.0]`. A NaN rate stays NaN (`f64::clamp` returns NaN
    /// for NaN input), and no sample compares below NaN, so every trial then fails.
    pub fn new(name: impl Into<String>, success_rate: f64) -> Self {
        Self {
            name: name.into(),
            success_rate: success_rate.clamp(0.0, 1.0),
        }
    }

    /// Deterministically maps `trial_id` to a value in the half-open interval `[0.0, 1.0)`.
    fn unit_sample(trial_id: usize) -> f64 {
        let seed = (trial_id as u64)
            .wrapping_mul(Self::LCG_MULTIPLIER)
            .wrapping_add(Self::LCG_INCREMENT);
        // Keep only the top 53 bits. They fit an f64 mantissa exactly, so the quotient is
        // strictly below 1.0. Dividing the full 64-bit seed by `u64::MAX` instead rounds to
        // exactly 1.0 for seeds near the top of the range, which made a `success_rate` of 1.0
        // fail for those trial ids.
        (seed >> 11) as f64 / (1u64 << 53) as f64
    }
}

#[async_trait]
impl EvaluationCase for StochasticCase {
    fn name(&self) -> &str {
        &self.name
    }

    fn category(&self) -> &str {
        "stochastic"
    }

    /// Succeeds when this trial's sample falls below `success_rate`; otherwise fails with the
    /// message `"stochastic failure"`. Both outcomes report a duration of `1`.
    async fn run(&self, trial_id: usize) -> anyhow::Result<TrialResult> {
        if Self::unit_sample(trial_id) < self.success_rate {
            Ok(TrialResult::success(trial_id, 1))
        } else {
            Ok(TrialResult::failure(trial_id, 1, "stochastic failure"))
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    #[tokio::test]
    async fn test_always_pass_case() {
        let case = AlwaysPassCase::new("test").with_duration(5);
        let result = case.run(0).await.unwrap();
        assert!(result.success);
        assert_eq!(result.trial_id, 0);
        assert_eq!(result.duration_ms, 5);
        assert_eq!(case.category(), "test");
    }

    #[tokio::test]
    async fn test_always_fail_case() {
        let case = AlwaysFailCase::new("test", "oops");
        let result = case.run(3).await.unwrap();
        assert!(!result.success);
        assert_eq!(result.trial_id, 3);
        assert_eq!(result.duration_ms, 0);
        assert_eq!(result.error.as_deref(), Some("oops"));
        assert_eq!(case.category(), "test");
    }

    #[tokio::test]
    async fn test_stochastic_case_reproducible() {
        let case = StochasticCase::new("test", 0.7);
        let r1 = case.run(42).await.unwrap();
        let r2 = case.run(42).await.unwrap();
        assert_eq!(
            r1.success, r2.success,
            "same trial_id must give same result"
        );
    }

    #[tokio::test]
    async fn test_stochastic_case_rate() {
        let case = StochasticCase::new("test", 0.6);
        let mut successes = 0usize;
        for i in 0..200 {
            if case.run(i).await.unwrap().success {
                successes += 1;
            }
        }
        assert!(
            successes > 90 && successes < 170,
            "expected ~120 successes, got {}",
            successes
        );
    }

    #[test]
    fn test_stochastic_case_clamps_rate() {
        assert_eq!(StochasticCase::new("high", 1.5).success_rate, 1.0);
        assert_eq!(StochasticCase::new("low", -0.25).success_rate, 0.0);
        assert_eq!(StochasticCase::new("mid", 0.3).success_rate, 0.3);
        assert_eq!(StochasticCase::new("cat", 0.5).category(), "stochastic");
    }

    #[tokio::test]
    async fn test_stochastic_case_extreme_rates() {
        let always = StochasticCase::new("always", 1.0);
        let never = StochasticCase::new("never", 0.0);
        for i in 0..500 {
            let pass = always.run(i).await.unwrap();
            assert!(pass.success, "rate 1.0 must always succeed (trial {})", i);

            let fail = never.run(i).await.unwrap();
            assert!(!fail.success, "rate 0.0 must never succeed (trial {})", i);
            assert_eq!(fail.trial_id, i);
            assert_eq!(fail.duration_ms, 1);
            assert_eq!(fail.error.as_deref(), Some("stochastic failure"));
        }
    }
}