use std::sync::atomic::{AtomicUsize, Ordering};
use std::sync::Arc;
use std::thread;
use tracing::{debug, info, warn};
/// One unit of work for an ensemble, handed by value to the runner closure.
#[derive(Clone, Debug)]
pub struct RunPlan<P> {
    /// Identifier that is copied into the matching `RunResult::plan_id`.
    pub id: usize,
    /// Seed made available to the runner; the ensemble itself never reads it.
    pub seed: u64,
    /// Step count made available to the runner; the ensemble itself never reads it.
    pub steps: usize,
    /// Caller-defined parameters for this particular run.
    pub params: P,
}
/// Outcome of executing a single plan.
#[derive(Debug)]
pub struct RunResult<R> {
    /// `id` of the plan that produced this outcome.
    pub plan_id: usize,
    /// The runner's return value, or the error captured when the runner panicked.
    pub result: Result<R, RunError>,
    /// Wall-clock time spent inside the runner, in milliseconds.
    pub elapsed_ms: u128,
}
/// Failure of a single run.
///
/// As produced by the ensemble, `message` is the text of the runner's panic payload,
/// or `"unknown panic"` when the payload was neither a `String` nor a `&str`.
#[derive(Debug)]
pub struct RunError {
    /// Human-readable description of the failure.
    pub message: String,
}

impl std::fmt::Display for RunError {
    /// Writes `message` verbatim; width/fill flags on the outer formatter are not applied.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.write_str(&self.message)
    }
}

impl std::error::Error for RunError {}
/// Policy controlling how `Ensemble::run` reacts to runs whose runner panicked.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum ErrorLevel {
    /// Failures are only recorded in the returned results; `run` never panics because of them.
    Off,
    /// Every plan is allowed to finish; `run` then panics if any run failed.
    Slow,
    /// `run` panics once a failure is observed, without starting the remaining plans.
    Fast,
}
/// Builder-style executor that runs a list of `RunPlan`s on worker threads.
///
/// Construct it with `Ensemble::new` (or `Default`), adjust it with the builder
/// methods, then call `run`.
#[derive(Debug, Clone)]
pub struct Ensemble {
    /// Maximum number of plans executing at the same time (the public builder rejects 0).
    concurrency: usize,
    /// How failed runs are reported by `run`.
    error_level: ErrorLevel,
}
impl Ensemble {
pub fn new() -> Self {
let concurrency = thread::available_parallelism()
.map(|p| p.get())
.unwrap_or(4);
Self {
concurrency,
error_level: ErrorLevel::Slow,
}
}
pub fn concurrency(mut self, n: usize) -> Self {
assert!(n > 0, "concurrency must be at least 1");
self.concurrency = n;
self
}
pub fn error_level(mut self, level: ErrorLevel) -> Self {
self.error_level = level;
self
}
pub fn run<P, R, F>(self, plans: Vec<RunPlan<P>>, runner: F) -> Vec<RunResult<R>>
where
P: Send + 'static,
R: Send + 'static,
F: Fn(RunPlan<P>) -> R + Send + Sync + 'static,
{
let total = plans.len();
info!(
total_plans = total,
concurrency = self.concurrency,
"ensemble starting"
);
let t0 = std::time::Instant::now();
let runner = Arc::new(runner);
let completed = Arc::new(AtomicUsize::new(0));
let failed = Arc::new(AtomicUsize::new(0));
let mut results: Vec<RunResult<R>> = Vec::with_capacity(total);
let mut plan_queue: Vec<RunPlan<P>> = plans;
while !plan_queue.is_empty() {
let batch_size = plan_queue.len().min(self.concurrency);
let batch: Vec<_> = plan_queue.drain(..batch_size).collect();
let handles: Vec<_> = batch
.into_iter()
.map(|plan| {
let runner = Arc::clone(&runner);
let completed = Arc::clone(&completed);
let failed = Arc::clone(&failed);
let plan_id = plan.id;
thread::spawn(move || {
let t_run = std::time::Instant::now();
let result =
std::panic::catch_unwind(std::panic::AssertUnwindSafe(|| runner(plan)));
let elapsed_ms = t_run.elapsed().as_millis();
let c = completed.fetch_add(1, Ordering::Relaxed) + 1;
match result {
Ok(r) => {
debug!(plan_id, elapsed_ms, progress = c, "run completed");
RunResult {
plan_id,
result: Ok(r),
elapsed_ms,
}
}
Err(e) => {
let msg = if let Some(s) = e.downcast_ref::<String>() {
s.clone()
} else if let Some(s) = e.downcast_ref::<&str>() {
s.to_string()
} else {
"unknown panic".to_string()
};
failed.fetch_add(1, Ordering::Relaxed);
warn!(plan_id, error = %msg, "run failed");
RunResult {
plan_id,
result: Err(RunError { message: msg }),
elapsed_ms,
}
}
}
})
})
.collect();
for handle in handles {
let result = handle.join().expect("thread panicked unexpectedly");
if result.result.is_err() && self.error_level == ErrorLevel::Fast {
let msg = format!("ensemble run {} failed", result.plan_id,);
results.push(result);
panic!("{}", msg);
}
results.push(result);
}
}
let total_ms = t0.elapsed().as_millis();
let fail_count = failed.load(Ordering::Relaxed);
info!(
total_plans = total,
completed = completed.load(Ordering::Relaxed),
failed = fail_count,
total_ms,
"ensemble completed"
);
if self.error_level == ErrorLevel::Slow && fail_count > 0 {
panic!(
"ensemble: {fail_count} of {total} runs failed. \
Set ErrorLevel::Off to suppress this panic."
);
}
results
}
}
impl Default for Ensemble {
fn default() -> Self {
Self::new()
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::atomic::{AtomicUsize, Ordering};
    use std::sync::Arc;

    /// Three plans (ids 0..3, seeds 1..=3) where only plan 1 has `params == true`.
    fn plans_with_one_failure() -> Vec<RunPlan<bool>> {
        (0..3)
            .map(|i| RunPlan {
                id: i,
                seed: i as u64 + 1,
                steps: 10,
                params: i == 1,
            })
            .collect()
    }

    /// Runner used by the failure tests: panics when `params` is true, else returns the seed.
    fn seed_or_panic(plan: RunPlan<bool>) -> u64 {
        if plan.params {
            panic!("intentional failure");
        }
        plan.seed
    }

    #[test]
    fn basic_ensemble() {
        let plans: Vec<RunPlan<f64>> = (0..10)
            .map(|i| RunPlan {
                id: i,
                seed: 42 + i as u64,
                steps: 100,
                params: 0.1 * i as f64,
            })
            .collect();
        let results = Ensemble::new()
            .concurrency(4)
            .error_level(ErrorLevel::Off)
            .run(plans, |plan| plan.seed as f64 * plan.params);
        assert_eq!(results.len(), 10);
        assert!(results.iter().all(|r| r.result.is_ok()));
    }

    #[test]
    fn results_preserve_plan_order() {
        let plans: Vec<RunPlan<()>> = (0..7)
            .map(|i| RunPlan {
                id: i,
                seed: 0,
                steps: 1,
                params: (),
            })
            .collect();
        let results = Ensemble::new()
            .concurrency(3)
            .error_level(ErrorLevel::Off)
            .run(plans, |plan| plan.id * 2);
        let ids: Vec<usize> = results.iter().map(|r| r.plan_id).collect();
        assert_eq!(ids, (0..7).collect::<Vec<_>>());
        for r in &results {
            assert_eq!(*r.result.as_ref().expect("run should succeed"), r.plan_id * 2);
        }
    }

    #[test]
    fn empty_plan_list_returns_no_results() {
        let results = Ensemble::new().run(Vec::<RunPlan<()>>::new(), |_| 0u8);
        assert!(results.is_empty());
    }

    #[test]
    fn ensemble_with_failure() {
        let results = Ensemble::new()
            .concurrency(2)
            .error_level(ErrorLevel::Off)
            .run(plans_with_one_failure(), seed_or_panic);
        assert_eq!(results.len(), 3);
        let failures: Vec<_> = results.iter().filter(|r| r.result.is_err()).collect();
        assert_eq!(failures.len(), 1);
        assert_eq!(failures[0].plan_id, 1);
        // The panic payload text is preserved as the error message.
        let err = failures[0].result.as_ref().unwrap_err();
        assert_eq!(err.message, "intentional failure");
    }

    #[test]
    #[should_panic(expected = "1 of 3 runs failed")]
    fn slow_mode_panics_when_any_run_fails() {
        Ensemble::new()
            .concurrency(2)
            .error_level(ErrorLevel::Slow)
            .run(plans_with_one_failure(), seed_or_panic);
    }

    #[test]
    fn slow_mode_runs_every_plan_before_panicking() {
        let calls = Arc::new(AtomicUsize::new(0));
        let counter = Arc::clone(&calls);
        let outcome = std::panic::catch_unwind(std::panic::AssertUnwindSafe(|| {
            Ensemble::new()
                .concurrency(2)
                .error_level(ErrorLevel::Slow)
                .run(plans_with_one_failure(), move |plan| {
                    counter.fetch_add(1, Ordering::SeqCst);
                    seed_or_panic(plan)
                })
        }));
        assert!(outcome.is_err());
        assert_eq!(calls.load(Ordering::SeqCst), 3);
    }

    #[test]
    #[should_panic(expected = "ensemble run 1 failed")]
    fn fast_mode_panics_on_failure() {
        Ensemble::new()
            .concurrency(2)
            .error_level(ErrorLevel::Fast)
            .run(plans_with_one_failure(), seed_or_panic);
    }

    #[test]
    fn fast_mode_does_not_start_later_plans() {
        let calls = Arc::new(AtomicUsize::new(0));
        let counter = Arc::clone(&calls);
        let outcome = std::panic::catch_unwind(std::panic::AssertUnwindSafe(|| {
            Ensemble::new()
                .concurrency(2)
                .error_level(ErrorLevel::Fast)
                .run(plans_with_one_failure(), move |plan| {
                    counter.fetch_add(1, Ordering::SeqCst);
                    seed_or_panic(plan)
                })
        }));
        assert!(outcome.is_err());
        // Only the first wave (plans 0 and 1) ran; plan 2 was never started.
        assert_eq!(calls.load(Ordering::SeqCst), 2);
    }
}