rustsim 0.0.1

High-performance agent-based modelling engine - top-level orchestration crate
Documentation
//! Ensemble runner for parameter sweeps and batch simulations.
//!
//! Inspired by FlameGPU2's [`CUDAEnsemble`](https://docs.flamegpu.com/api/classflamegpu_1_1CUDAEnsemble.html),
//! which runs many independent simulation instances concurrently across
//! multiple GPU devices.
//!
//! rustsim's [`Ensemble`] is framework-agnostic: it uses standard Rust threads
//! to run simulations in parallel on CPU, making it work on any platform
//! with or without GPU hardware.
//!
//! # Usage
//!
//! ```ignore
//! let plans: Vec<RunPlan> = (0..100).map(|i| RunPlan {
//!     id: i,
//!     seed: 42 + i as u64,
//!     steps: 1000,
//!     params: MyParams { temperature: 0.1 * i as f64 },
//! }).collect();
//!
//! let results = Ensemble::new()
//!     .concurrency(8)
//!     .run(plans, |plan| {
//!         // Build and run one simulation, return summary
//!         let model = build_model(plan.seed, plan.params);
//!         model.step_n(plan.steps);
//!         extract_metrics(&model)
//!     });
//! ```

use std::sync::atomic::{AtomicUsize, Ordering};
use std::sync::Arc;
use std::thread;
use tracing::{debug, info, warn};

/// Description of one simulation run inside an ensemble.
///
/// The runner itself never interprets `P`; it is an opaque, caller-chosen
/// parameter bundle that is handed to the user-supplied closure together
/// with the rest of the plan.
#[derive(Clone, Debug)]
pub struct RunPlan<P> {
    /// Identifier copied into the matching [`RunResult`] so that outputs can
    /// be traced back to the plan that produced them.
    pub id: usize,
    /// Seed the simulation should use for its random number generator, so a
    /// run can be reproduced.
    pub seed: u64,
    /// How many simulation steps the run is expected to execute.
    pub steps: usize,
    /// Caller-defined parameters for this particular run.
    pub params: P,
}

/// Outcome of one run executed by the ensemble.
#[derive(Debug)]
pub struct RunResult<R> {
    /// `id` of the [`RunPlan`] this outcome belongs to.
    pub plan_id: usize,
    /// `Ok` with the user-defined output when the runner returned normally,
    /// `Err` when the run failed.
    pub result: Result<R, RunError>,
    /// Wall-clock duration of the run, in milliseconds.
    pub elapsed_ms: u128,
}

/// Error describing why a single ensemble run failed.
///
/// The type is a thin wrapper around a message string. It derives `Clone`,
/// `PartialEq` and `Eq` so that callers can copy failures out of a result set
/// and compare them directly (for example in assertions).
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct RunError {
    /// Human-readable error message.
    pub message: String,
}

/// Displays the error as its bare message, with no prefix or decoration.
impl std::fmt::Display for RunError {
    fn fmt(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        formatter.write_str(&self.message)
    }
}

// Marker implementation: every method keeps the trait's default, so the error
// reports no underlying `source()`.
impl std::error::Error for RunError {}

/// Policy that decides how the ensemble reacts to failed runs.
///
/// Modelled after FlameGPU2's `CUDAEnsemble::EnsembleConfig::ErrorLevel`.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum ErrorLevel {
    /// Never panic because of a failed run; the caller inspects each
    /// result's `Err` value.
    Off,
    /// Let every run finish, then panic if at least one of them failed.
    Slow,
    /// Panic as soon as a failed run is collected; plans that have not been
    /// started yet are never launched.
    Fast,
}

/// Ensemble runner for executing many simulations concurrently.
///
/// Mirrors FlameGPU2's `CUDAEnsemble`, adapted for Rust's thread model.
/// Every simulation runs on its own freshly spawned OS thread, and the runner
/// never has more than `concurrency` of those threads in flight at once.
///
/// The configuration is plain data, so the type is `Debug` (for logging the
/// active settings) and `Clone` (so one configured runner can be reused for
/// several `run` calls, which consume `self`).
#[derive(Debug, Clone)]
pub struct Ensemble {
    /// Maximum number of simulation threads running at the same time.
    concurrency: usize,
    /// How failed runs are reported to the caller.
    error_level: ErrorLevel,
}

impl Ensemble {
    /// Create a new ensemble runner with default settings.
    ///
    /// Defaults:
    /// - `concurrency`: number of available CPU cores (or 4 if detection fails)
    /// - `error_level`: [`ErrorLevel::Slow`]
    pub fn new() -> Self {
        let concurrency = thread::available_parallelism()
            .map(|p| p.get())
            .unwrap_or(4);
        Self {
            concurrency,
            error_level: ErrorLevel::Slow,
        }
    }

    /// Set the maximum number of concurrent simulation threads.
    pub fn concurrency(mut self, n: usize) -> Self {
        assert!(n > 0, "concurrency must be at least 1");
        self.concurrency = n;
        self
    }

    /// Set the error handling behavior.
    pub fn error_level(mut self, level: ErrorLevel) -> Self {
        self.error_level = level;
        self
    }

    /// Execute all plans and collect results.
    ///
    /// The `runner` closure receives a `RunPlan` and returns the simulation
    /// output. Each plan is executed in a separate thread, with at most
    /// `concurrency` threads active at a time.
    ///
    /// # Panics
    ///
    /// - With `ErrorLevel::Fast`: panics as soon as any run fails.
    /// - With `ErrorLevel::Slow`: panics after all runs complete if any failed.
    /// - With `ErrorLevel::Off`: never panics (check results manually).
    pub fn run<P, R, F>(self, plans: Vec<RunPlan<P>>, runner: F) -> Vec<RunResult<R>>
    where
        P: Send + 'static,
        R: Send + 'static,
        F: Fn(RunPlan<P>) -> R + Send + Sync + 'static,
    {
        let total = plans.len();
        info!(
            total_plans = total,
            concurrency = self.concurrency,
            "ensemble starting"
        );

        let t0 = std::time::Instant::now();
        let runner = Arc::new(runner);
        let completed = Arc::new(AtomicUsize::new(0));
        let failed = Arc::new(AtomicUsize::new(0));

        // Partition plans into chunks for thread distribution
        let mut results: Vec<RunResult<R>> = Vec::with_capacity(total);
        let mut plan_queue: Vec<RunPlan<P>> = plans;

        // Process in batches of `concurrency` size
        while !plan_queue.is_empty() {
            let batch_size = plan_queue.len().min(self.concurrency);
            let batch: Vec<_> = plan_queue.drain(..batch_size).collect();

            let handles: Vec<_> = batch
                .into_iter()
                .map(|plan| {
                    let runner = Arc::clone(&runner);
                    let completed = Arc::clone(&completed);
                    let failed = Arc::clone(&failed);
                    let plan_id = plan.id;

                    thread::spawn(move || {
                        let t_run = std::time::Instant::now();
                        let result =
                            std::panic::catch_unwind(std::panic::AssertUnwindSafe(|| runner(plan)));
                        let elapsed_ms = t_run.elapsed().as_millis();

                        let c = completed.fetch_add(1, Ordering::Relaxed) + 1;

                        match result {
                            Ok(r) => {
                                debug!(plan_id, elapsed_ms, progress = c, "run completed");
                                RunResult {
                                    plan_id,
                                    result: Ok(r),
                                    elapsed_ms,
                                }
                            }
                            Err(e) => {
                                let msg = if let Some(s) = e.downcast_ref::<String>() {
                                    s.clone()
                                } else if let Some(s) = e.downcast_ref::<&str>() {
                                    s.to_string()
                                } else {
                                    "unknown panic".to_string()
                                };
                                failed.fetch_add(1, Ordering::Relaxed);
                                warn!(plan_id, error = %msg, "run failed");
                                RunResult {
                                    plan_id,
                                    result: Err(RunError { message: msg }),
                                    elapsed_ms,
                                }
                            }
                        }
                    })
                })
                .collect();

            for handle in handles {
                let result = handle.join().expect("thread panicked unexpectedly");
                if result.result.is_err() && self.error_level == ErrorLevel::Fast {
                    let msg = format!("ensemble run {} failed", result.plan_id,);
                    results.push(result);
                    panic!("{}", msg);
                }
                results.push(result);
            }
        }

        let total_ms = t0.elapsed().as_millis();
        let fail_count = failed.load(Ordering::Relaxed);

        info!(
            total_plans = total,
            completed = completed.load(Ordering::Relaxed),
            failed = fail_count,
            total_ms,
            "ensemble completed"
        );

        if self.error_level == ErrorLevel::Slow && fail_count > 0 {
            panic!(
                "ensemble: {fail_count} of {total} runs failed. \
                 Set ErrorLevel::Off to suppress this panic."
            );
        }

        results
    }
}

impl Default for Ensemble {
    fn default() -> Self {
        Self::new()
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    /// Three plans (ids 0, 1, 2) where only the middle one is flagged to fail.
    fn plans_with_middle_failure() -> Vec<RunPlan<bool>> {
        vec![false, true, false]
            .into_iter()
            .enumerate()
            .map(|(id, should_fail)| RunPlan {
                id,
                seed: id as u64 + 1,
                steps: 10,
                params: should_fail,
            })
            .collect()
    }

    /// Runner that panics for flagged plans and returns the seed otherwise.
    fn fail_when_flagged(plan: RunPlan<bool>) -> u64 {
        if plan.params {
            panic!("intentional failure");
        }
        plan.seed
    }

    #[test]
    fn basic_ensemble() {
        let plans: Vec<RunPlan<f64>> = (0..10)
            .map(|i| RunPlan {
                id: i,
                seed: 42 + i as u64,
                steps: 100,
                params: 0.1 * i as f64,
            })
            .collect();

        let results = Ensemble::new()
            .concurrency(4)
            .error_level(ErrorLevel::Off)
            .run(plans, |plan| {
                // Trivial simulation: just return seed * params
                plan.seed as f64 * plan.params
            });

        assert_eq!(results.len(), 10);
        for r in &results {
            assert!(r.result.is_ok());
        }
    }

    #[test]
    fn ensemble_with_failure() {
        let results = Ensemble::new()
            .concurrency(2)
            .error_level(ErrorLevel::Off)
            .run(plans_with_middle_failure(), fail_when_flagged);

        assert_eq!(results.len(), 3);
        let failures: Vec<_> = results.iter().filter(|r| r.result.is_err()).collect();
        assert_eq!(failures.len(), 1);
        assert_eq!(failures[0].plan_id, 1);
        // The panic message of the runner is preserved in the error.
        assert_eq!(
            failures[0].result.as_ref().unwrap_err().message,
            "intentional failure"
        );
    }

    #[test]
    fn results_follow_plan_order() {
        // Seven plans with concurrency three exercises a partial last batch.
        let plans: Vec<RunPlan<()>> = (0..7)
            .map(|i| RunPlan {
                id: i,
                seed: i as u64,
                steps: 1,
                params: (),
            })
            .collect();

        let results = Ensemble::new()
            .concurrency(3)
            .error_level(ErrorLevel::Off)
            .run(plans, |plan| plan.seed * 2);

        let ids: Vec<usize> = results.iter().map(|r| r.plan_id).collect();
        assert_eq!(ids, (0..7).collect::<Vec<_>>());
        for r in &results {
            assert_eq!(*r.result.as_ref().unwrap(), r.plan_id as u64 * 2);
        }
    }

    #[test]
    fn empty_plan_list_yields_no_results() {
        let results = Ensemble::new()
            .concurrency(2)
            .run(Vec::<RunPlan<()>>::new(), |_| ());

        assert!(results.is_empty());
    }

    #[test]
    #[should_panic(expected = "1 of 3 runs failed")]
    fn slow_error_level_panics_after_all_runs() {
        let _ = Ensemble::new()
            .concurrency(2)
            .error_level(ErrorLevel::Slow)
            .run(plans_with_middle_failure(), fail_when_flagged);
    }

    #[test]
    #[should_panic(expected = "ensemble run 1 failed")]
    fn fast_error_level_panics_on_first_failure() {
        let _ = Ensemble::new()
            .concurrency(2)
            .error_level(ErrorLevel::Fast)
            .run(plans_with_middle_failure(), fail_when_flagged);
    }

    #[test]
    #[should_panic(expected = "concurrency must be at least 1")]
    fn zero_concurrency_is_rejected() {
        let _ = Ensemble::new().concurrency(0);
    }
}