1use crate::system::{SdeOptions, SdeResult, SdeSolver, SdeSystem};
10use numra_core::Scalar;
11use rayon::prelude::*;
12
13#[derive(Clone, Debug)]
15pub struct EnsembleResult<S: Scalar> {
16 pub trajectories: Vec<SdeResult<S>>,
18 pub n_success: usize,
20 pub n_failed: usize,
22 pub seeds: Vec<u64>,
24}
25
26impl<S: Scalar> EnsembleResult<S> {
27 pub fn new(trajectories: Vec<SdeResult<S>>, seeds: Vec<u64>) -> Self {
29 let n_success = trajectories.iter().filter(|r| r.success).count();
30 let n_failed = trajectories.len() - n_success;
31 Self {
32 trajectories,
33 n_success,
34 n_failed,
35 seeds,
36 }
37 }
38
39 pub fn final_values(&self, component: usize) -> Vec<Option<S>> {
43 self.trajectories
44 .iter()
45 .map(|r| {
46 r.y_final().map(|y| {
47 if component < y.len() {
48 y[component]
49 } else {
50 S::NAN
51 }
52 })
53 })
54 .collect()
55 }
56
57 pub fn successful_final_values(&self, component: usize) -> Vec<S> {
59 self.trajectories
60 .iter()
61 .filter_map(|r| {
62 if r.success {
63 r.y_final().map(|y| y[component])
64 } else {
65 None
66 }
67 })
68 .collect()
69 }
70
71 pub fn successful(&self) -> impl Iterator<Item = &SdeResult<S>> {
73 self.trajectories.iter().filter(|r| r.success)
74 }
75
76 pub fn get(&self, index: usize) -> Option<&SdeResult<S>> {
78 self.trajectories.get(index)
79 }
80
81 pub fn len(&self) -> usize {
83 self.trajectories.len()
84 }
85
86 pub fn is_empty(&self) -> bool {
88 self.trajectories.is_empty()
89 }
90}
91
92pub struct EnsembleRunner;
94
95impl EnsembleRunner {
96 pub fn run<S, Sys, Solver>(
109 system: &Sys,
110 t0: S,
111 tf: S,
112 x0: &[S],
113 options: &SdeOptions<S>,
114 n_trajectories: usize,
115 ) -> EnsembleResult<S>
116 where
117 S: Scalar + Send + Sync,
118 Sys: SdeSystem<S> + Sync,
119 Solver: SdeSolver<S>,
120 {
121 let base_seed = options.seed.unwrap_or(0);
123 let seeds: Vec<u64> = (0..n_trajectories)
124 .map(|i| base_seed.wrapping_add(i as u64))
125 .collect();
126
127 let results: Vec<SdeResult<S>> = seeds
129 .par_iter()
130 .map(|&seed| {
131 Solver::solve(system, t0, tf, x0, options, Some(seed))
132 .unwrap_or_else(|msg| SdeResult::failed(msg, Default::default()))
133 })
134 .collect();
135
136 EnsembleResult::new(results, seeds)
137 }
138
139 pub fn run_with_seeds<S, Sys, Solver>(
143 system: &Sys,
144 t0: S,
145 tf: S,
146 x0: &[S],
147 options: &SdeOptions<S>,
148 seeds: &[u64],
149 ) -> EnsembleResult<S>
150 where
151 S: Scalar + Send + Sync,
152 Sys: SdeSystem<S> + Sync,
153 Solver: SdeSolver<S>,
154 {
155 let results: Vec<SdeResult<S>> = seeds
156 .par_iter()
157 .map(|&seed| {
158 Solver::solve(system, t0, tf, x0, options, Some(seed))
159 .unwrap_or_else(|msg| SdeResult::failed(msg, Default::default()))
160 })
161 .collect();
162
163 EnsembleResult::new(results, seeds.to_vec())
164 }
165
166 pub fn run_sequential<S, Sys, Solver>(
168 system: &Sys,
169 t0: S,
170 tf: S,
171 x0: &[S],
172 options: &SdeOptions<S>,
173 n_trajectories: usize,
174 ) -> EnsembleResult<S>
175 where
176 S: Scalar,
177 Sys: SdeSystem<S>,
178 Solver: SdeSolver<S>,
179 {
180 let base_seed = options.seed.unwrap_or(0);
181 let seeds: Vec<u64> = (0..n_trajectories)
182 .map(|i| base_seed.wrapping_add(i as u64))
183 .collect();
184
185 let results: Vec<SdeResult<S>> = seeds
186 .iter()
187 .map(|&seed| {
188 Solver::solve(system, t0, tf, x0, options, Some(seed))
189 .unwrap_or_else(|msg| SdeResult::failed(msg, Default::default()))
190 })
191 .collect();
192
193 EnsembleResult::new(results, seeds)
194 }
195}
196
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{EulerMaruyama, SdeSystem};

    // "GBM" (geometric Brownian motion) is the conventional name in the literature.
    #[allow(clippy::upper_case_acronyms)]
    struct GBM {
        mu: f64,
        sigma: f64,
    }

    impl SdeSystem<f64> for GBM {
        fn dim(&self) -> usize {
            1
        }
        fn drift(&self, _t: f64, x: &[f64], f: &mut [f64]) {
            f[0] = self.mu * x[0];
        }
        fn diffusion(&self, _t: f64, x: &[f64], g: &mut [f64]) {
            g[0] = self.sigma * x[0];
        }
    }

    fn gbm() -> GBM {
        GBM {
            mu: 0.05,
            sigma: 0.2,
        }
    }

    #[test]
    fn test_ensemble_parallel() {
        let gbm = gbm();
        let options = SdeOptions::default().dt(0.01).seed(42);

        let result =
            EnsembleRunner::run::<_, _, EulerMaruyama>(&gbm, 0.0, 1.0, &[100.0], &options, 100);

        assert_eq!(result.len(), 100);
        assert_eq!(result.n_success, 100);
        assert_eq!(result.n_failed, 0);

        let finals = result.successful_final_values(0);
        assert_eq!(finals.len(), 100);
        for &price in &finals {
            assert!(price > 0.0);
        }
    }

    #[test]
    fn test_ensemble_sequential() {
        let gbm = gbm();
        let options = SdeOptions::default().dt(0.01).seed(42);

        let result = EnsembleRunner::run_sequential::<_, _, EulerMaruyama>(
            &gbm,
            0.0,
            1.0,
            &[100.0],
            &options,
            10,
        );

        assert_eq!(result.len(), 10);
        assert_eq!(result.n_success, 10);
    }

    #[test]
    fn test_ensemble_reproducibility() {
        let gbm = gbm();
        let options = SdeOptions::default().dt(0.01).seed(12345);

        let r1 = EnsembleRunner::run_sequential::<_, _, EulerMaruyama>(
            &gbm,
            0.0,
            1.0,
            &[100.0],
            &options,
            5,
        );
        let r2 = EnsembleRunner::run_sequential::<_, _, EulerMaruyama>(
            &gbm,
            0.0,
            1.0,
            &[100.0],
            &options,
            5,
        );

        for i in 0..5 {
            let y1 = r1.get(i).unwrap().y_final().unwrap()[0];
            let y2 = r2.get(i).unwrap().y_final().unwrap()[0];
            assert!((y1 - y2).abs() < 1e-10);
        }
    }

    #[test]
    fn test_parallel_matches_sequential() {
        let gbm = gbm();
        let options = SdeOptions::default().dt(0.01).seed(7);

        let par =
            EnsembleRunner::run::<_, _, EulerMaruyama>(&gbm, 0.0, 1.0, &[100.0], &options, 8);
        let seq = EnsembleRunner::run_sequential::<_, _, EulerMaruyama>(
            &gbm,
            0.0,
            1.0,
            &[100.0],
            &options,
            8,
        );

        assert_eq!(par.seeds, seq.seeds);
        for i in 0..8 {
            let a = par.get(i).unwrap().y_final().unwrap()[0];
            let b = seq.get(i).unwrap().y_final().unwrap()[0];
            assert!((a - b).abs() < 1e-10);
        }
    }

    #[test]
    fn test_ensemble_with_explicit_seeds() {
        let gbm = gbm();
        let options = SdeOptions::default().dt(0.01);
        let seeds = [7_u64, 11, 13];

        let result = EnsembleRunner::run_with_seeds::<_, _, EulerMaruyama>(
            &gbm,
            0.0,
            1.0,
            &[100.0],
            &options,
            &seeds,
        );

        assert_eq!(result.len(), seeds.len());
        assert_eq!(result.seeds, seeds);
        assert_eq!(result.n_success + result.n_failed, seeds.len());
    }

    #[test]
    fn test_final_values_agree_when_all_succeed() {
        let gbm = gbm();
        let options = SdeOptions::default().dt(0.01).seed(3);

        let result =
            EnsembleRunner::run::<_, _, EulerMaruyama>(&gbm, 0.0, 1.0, &[100.0], &options, 10);
        assert_eq!(result.n_failed, 0);

        let all: Vec<f64> = result
            .final_values(0)
            .into_iter()
            .map(|v| v.expect("every successful trajectory has a final state"))
            .collect();
        assert_eq!(all, result.successful_final_values(0));
    }

    #[test]
    fn test_ensemble_statistics_sample() {
        let gbm = gbm();
        let options = SdeOptions::default().dt(0.001).seed(0);

        let result =
            EnsembleRunner::run::<_, _, EulerMaruyama>(&gbm, 0.0, 1.0, &[100.0], &options, 1000);

        let finals = result.successful_final_values(0);
        let n = finals.len() as f64;

        let mean: f64 = finals.iter().sum::<f64>() / n;
        let variance: f64 = finals.iter().map(|x| (x - mean).powi(2)).sum::<f64>() / n;

        // Exact GBM moments at time t:
        //   E[S_t]   = S0 * exp(mu * t)
        //   Var[S_t] = S0^2 * exp(2 * mu * t) * (exp(sigma^2 * t) - 1)
        let s0 = 100.0_f64;
        let t = 1.0_f64;
        let expected_mean = s0 * (gbm.mu * t).exp();
        let expected_var =
            s0 * s0 * (2.0 * gbm.mu * t).exp() * ((gbm.sigma * gbm.sigma * t).exp() - 1.0);

        // Sample mean should lie within 3 standard errors of the exact mean.
        let se_mean = (variance / n).sqrt();
        assert!((mean - expected_mean).abs() < 3.0 * se_mean);

        // Sample variance is noisier; allow a 20% relative tolerance.
        assert!((variance - expected_var).abs() < expected_var * 0.2);
    }
}