Skip to main content

numra_sde/
ensemble.rs

1//! Ensemble runner for Monte Carlo SDE simulations.
2//!
3//! Provides parallel execution of multiple SDE trajectories using rayon.
4//!
5//! Author: Moussa Leblouba
6//! Date: 4 February 2026
7//! Modified: 2 May 2026
8
9use crate::system::{SdeOptions, SdeResult, SdeSolver, SdeSystem};
10use numra_core::Scalar;
11use rayon::prelude::*;
12
13/// Result of ensemble simulation.
14#[derive(Clone, Debug)]
15pub struct EnsembleResult<S: Scalar> {
16    /// Individual trajectory results
17    pub trajectories: Vec<SdeResult<S>>,
18    /// Number of successful simulations
19    pub n_success: usize,
20    /// Number of failed simulations
21    pub n_failed: usize,
22    /// Seeds used for each trajectory (for reproducibility)
23    pub seeds: Vec<u64>,
24}
25
26impl<S: Scalar> EnsembleResult<S> {
27    /// Create a new ensemble result.
28    pub fn new(trajectories: Vec<SdeResult<S>>, seeds: Vec<u64>) -> Self {
29        let n_success = trajectories.iter().filter(|r| r.success).count();
30        let n_failed = trajectories.len() - n_success;
31        Self {
32            trajectories,
33            n_success,
34            n_failed,
35            seeds,
36        }
37    }
38
39    /// Get all final values for a specific component.
40    ///
41    /// Returns None for failed trajectories.
42    pub fn final_values(&self, component: usize) -> Vec<Option<S>> {
43        self.trajectories
44            .iter()
45            .map(|r| {
46                r.y_final().map(|y| {
47                    if component < y.len() {
48                        y[component]
49                    } else {
50                        S::NAN
51                    }
52                })
53            })
54            .collect()
55    }
56
57    /// Get all successful final values for a component.
58    pub fn successful_final_values(&self, component: usize) -> Vec<S> {
59        self.trajectories
60            .iter()
61            .filter_map(|r| {
62                if r.success {
63                    r.y_final().map(|y| y[component])
64                } else {
65                    None
66                }
67            })
68            .collect()
69    }
70
71    /// Iterate over successful trajectories.
72    pub fn successful(&self) -> impl Iterator<Item = &SdeResult<S>> {
73        self.trajectories.iter().filter(|r| r.success)
74    }
75
76    /// Get trajectory at index.
77    pub fn get(&self, index: usize) -> Option<&SdeResult<S>> {
78        self.trajectories.get(index)
79    }
80
81    /// Number of trajectories.
82    pub fn len(&self) -> usize {
83        self.trajectories.len()
84    }
85
86    /// Is the ensemble empty?
87    pub fn is_empty(&self) -> bool {
88        self.trajectories.is_empty()
89    }
90}
91
/// Ensemble runner for parallel Monte Carlo simulations.
///
/// A data-less unit type: all functionality is provided by associated
/// functions such as [`EnsembleRunner::run`].
#[derive(Debug, Clone, Copy, Default)]
pub struct EnsembleRunner;
94
95impl EnsembleRunner {
96    /// Run an ensemble of SDE simulations in parallel.
97    ///
98    /// # Arguments
99    /// * `system` - The SDE system to simulate
100    /// * `t0` - Initial time
101    /// * `tf` - Final time
102    /// * `x0` - Initial state
103    /// * `options` - Solver options (seed is used as base seed)
104    /// * `n_trajectories` - Number of Monte Carlo paths
105    ///
106    /// # Returns
107    /// An `EnsembleResult` containing all trajectory results.
108    pub fn run<S, Sys, Solver>(
109        system: &Sys,
110        t0: S,
111        tf: S,
112        x0: &[S],
113        options: &SdeOptions<S>,
114        n_trajectories: usize,
115    ) -> EnsembleResult<S>
116    where
117        S: Scalar + Send + Sync,
118        Sys: SdeSystem<S> + Sync,
119        Solver: SdeSolver<S>,
120    {
121        // Generate seeds for each trajectory
122        let base_seed = options.seed.unwrap_or(0);
123        let seeds: Vec<u64> = (0..n_trajectories)
124            .map(|i| base_seed.wrapping_add(i as u64))
125            .collect();
126
127        // Run simulations in parallel
128        let results: Vec<SdeResult<S>> = seeds
129            .par_iter()
130            .map(|&seed| {
131                Solver::solve(system, t0, tf, x0, options, Some(seed))
132                    .unwrap_or_else(|msg| SdeResult::failed(msg, Default::default()))
133            })
134            .collect();
135
136        EnsembleResult::new(results, seeds)
137    }
138
139    /// Run ensemble with custom seeds.
140    ///
141    /// Useful for resuming simulations or specific reproducibility needs.
142    pub fn run_with_seeds<S, Sys, Solver>(
143        system: &Sys,
144        t0: S,
145        tf: S,
146        x0: &[S],
147        options: &SdeOptions<S>,
148        seeds: &[u64],
149    ) -> EnsembleResult<S>
150    where
151        S: Scalar + Send + Sync,
152        Sys: SdeSystem<S> + Sync,
153        Solver: SdeSolver<S>,
154    {
155        let results: Vec<SdeResult<S>> = seeds
156            .par_iter()
157            .map(|&seed| {
158                Solver::solve(system, t0, tf, x0, options, Some(seed))
159                    .unwrap_or_else(|msg| SdeResult::failed(msg, Default::default()))
160            })
161            .collect();
162
163        EnsembleResult::new(results, seeds.to_vec())
164    }
165
166    /// Run ensemble sequentially (for debugging or when parallelism isn't needed).
167    pub fn run_sequential<S, Sys, Solver>(
168        system: &Sys,
169        t0: S,
170        tf: S,
171        x0: &[S],
172        options: &SdeOptions<S>,
173        n_trajectories: usize,
174    ) -> EnsembleResult<S>
175    where
176        S: Scalar,
177        Sys: SdeSystem<S>,
178        Solver: SdeSolver<S>,
179    {
180        let base_seed = options.seed.unwrap_or(0);
181        let seeds: Vec<u64> = (0..n_trajectories)
182            .map(|i| base_seed.wrapping_add(i as u64))
183            .collect();
184
185        let results: Vec<SdeResult<S>> = seeds
186            .iter()
187            .map(|&seed| {
188                Solver::solve(system, t0, tf, x0, options, Some(seed))
189                    .unwrap_or_else(|msg| SdeResult::failed(msg, Default::default()))
190            })
191            .collect();
192
193        EnsembleResult::new(results, seeds)
194    }
195}
196
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{EulerMaruyama, SdeSystem};

    /// Geometric Brownian motion `dX = mu X dt + sigma X dW`, used as the test model.
    // "GBM" is the conventional name for this process, so keep the acronym.
    #[allow(clippy::upper_case_acronyms)]
    struct GBM {
        mu: f64,
        sigma: f64,
    }

    impl SdeSystem<f64> for GBM {
        fn dim(&self) -> usize {
            1
        }
        fn drift(&self, _t: f64, x: &[f64], f: &mut [f64]) {
            f[0] = self.mu * x[0];
        }
        fn diffusion(&self, _t: f64, x: &[f64], g: &mut [f64]) {
            g[0] = self.sigma * x[0];
        }
    }

    /// Model parameters shared by every test below.
    fn gbm() -> GBM {
        GBM {
            mu: 0.05,
            sigma: 0.2,
        }
    }

    #[test]
    fn test_ensemble_parallel() {
        let model = gbm();
        let options = SdeOptions::default().dt(0.01).seed(42);

        let result =
            EnsembleRunner::run::<_, _, EulerMaruyama>(&model, 0.0, 1.0, &[100.0], &options, 100);

        assert_eq!(result.len(), 100);
        assert_eq!(result.n_success, 100);
        assert_eq!(result.n_failed, 0);
        assert_eq!(result.successful().count(), 100);
        assert!(result.final_values(0).iter().all(|v| v.is_some()));

        // All final prices should be positive
        let finals = result.successful_final_values(0);
        assert_eq!(finals.len(), 100);
        for &price in &finals {
            assert!(price > 0.0);
        }
    }

    #[test]
    fn test_ensemble_sequential() {
        let model = gbm();
        let options = SdeOptions::default().dt(0.01).seed(42);

        let result = EnsembleRunner::run_sequential::<_, _, EulerMaruyama>(
            &model,
            0.0,
            1.0,
            &[100.0],
            &options,
            10,
        );

        assert_eq!(result.len(), 10);
        assert_eq!(result.n_success, 10);
    }

    #[test]
    fn test_ensemble_reproducibility() {
        let model = gbm();
        let options = SdeOptions::default().dt(0.01).seed(12345);

        // Run twice with same base seed
        let r1 = EnsembleRunner::run_sequential::<_, _, EulerMaruyama>(
            &model,
            0.0,
            1.0,
            &[100.0],
            &options,
            5,
        );
        let r2 = EnsembleRunner::run_sequential::<_, _, EulerMaruyama>(
            &model,
            0.0,
            1.0,
            &[100.0],
            &options,
            5,
        );

        // Results should be identical
        assert_eq!(r1.seeds, r2.seeds);
        for i in 0..5 {
            let y1 = r1.get(i).unwrap().y_final().unwrap()[0];
            let y2 = r2.get(i).unwrap().y_final().unwrap()[0];
            assert!((y1 - y2).abs() < 1e-10);
        }
    }

    #[test]
    fn test_parallel_matches_sequential() {
        let model = gbm();
        let options = SdeOptions::default().dt(0.01).seed(7);

        let par =
            EnsembleRunner::run::<_, _, EulerMaruyama>(&model, 0.0, 1.0, &[100.0], &options, 8);
        let seq = EnsembleRunner::run_sequential::<_, _, EulerMaruyama>(
            &model,
            0.0,
            1.0,
            &[100.0],
            &options,
            8,
        );

        // Both runners derive trajectory i's seed as base + i.
        let expected_seeds: Vec<u64> = (7..15).collect();
        assert_eq!(par.seeds, expected_seeds);
        assert_eq!(seq.seeds, expected_seeds);

        // Parallel execution must not reorder or perturb trajectories.
        for (a, b) in par.trajectories.iter().zip(&seq.trajectories) {
            let ya = a.y_final().unwrap()[0];
            let yb = b.y_final().unwrap()[0];
            assert!((ya - yb).abs() < 1e-10);
        }
    }

    #[test]
    fn test_run_with_seeds_replays_run() {
        let model = gbm();
        let options = SdeOptions::default().dt(0.01).seed(100);

        let original =
            EnsembleRunner::run::<_, _, EulerMaruyama>(&model, 0.0, 1.0, &[100.0], &options, 4);
        // Re-run only the last two paths from their recorded seeds.
        let replay = EnsembleRunner::run_with_seeds::<_, _, EulerMaruyama>(
            &model,
            0.0,
            1.0,
            &[100.0],
            &options,
            &original.seeds[2..],
        );

        assert_eq!(replay.len(), 2);
        assert_eq!(replay.seeds, vec![102, 103]);
        for (offset, traj) in replay.trajectories.iter().enumerate() {
            let y_replay = traj.y_final().unwrap()[0];
            let y_orig = original.get(offset + 2).unwrap().y_final().unwrap()[0];
            assert!((y_replay - y_orig).abs() < 1e-10);
        }
    }

    #[test]
    fn test_empty_ensemble() {
        let model = gbm();
        let options = SdeOptions::default().dt(0.01).seed(1);

        let result =
            EnsembleRunner::run::<_, _, EulerMaruyama>(&model, 0.0, 1.0, &[100.0], &options, 0);

        assert!(result.is_empty());
        assert_eq!(result.len(), 0);
        assert_eq!(result.n_success, 0);
        assert_eq!(result.n_failed, 0);
        assert!(result.seeds.is_empty());
        assert!(result.successful_final_values(0).is_empty());
    }

    #[test]
    fn test_ensemble_statistics_sample() {
        let model = gbm();
        let options = SdeOptions::default().dt(0.001).seed(0);

        // Large ensemble for statistical test
        let result =
            EnsembleRunner::run::<_, _, EulerMaruyama>(&model, 0.0, 1.0, &[100.0], &options, 1000);

        let finals = result.successful_final_values(0);
        assert_eq!(finals.len(), 1000);

        // Sample mean and unbiased (Bessel-corrected, n - 1) sample variance.
        let n = finals.len() as f64;
        let mean: f64 = finals.iter().sum::<f64>() / n;
        let variance: f64 = finals.iter().map(|x| (x - mean).powi(2)).sum::<f64>() / (n - 1.0);

        // For GBM: E[S_T] = S_0 * exp(μT)
        // Var[S_T] = S_0² * exp(2μT) * (exp(σ²T) - 1)
        let s0 = 100.0;
        let expected_mean = s0 * (0.05 * 1.0_f64).exp(); // ~105.13
        let expected_var =
            s0 * s0 * (2.0 * 0.05 * 1.0_f64).exp() * ((0.2 * 0.2 * 1.0_f64).exp() - 1.0); // ~448.9

        // Check within 3 standard errors
        let se_mean = (variance / n).sqrt();
        assert!((mean - expected_mean).abs() < 3.0 * se_mean);

        // Variance estimate should be in the right ballpark
        assert!((variance - expected_var).abs() < expected_var * 0.2); // Within 20%
    }
}