1use crate::system::{NoiseType, SdeOptions, SdeResult, SdeSolver, SdeStats, SdeSystem};
13use crate::wiener::create_wiener;
14use numra_core::Scalar;
15
/// Adaptive two-stage stochastic Runge-Kutta solver (`SRA1`).
///
/// Each step advances with the trapezoidal average of the drift and diffusion evaluated at the
/// current state and at a predictor stage. The step size is controlled from the difference of
/// those two evaluations. Only diagonal and scalar noise are supported; other noise types make
/// `solve` return an error.
#[derive(Debug, Clone, Copy, Default)]
pub struct Sra1;
21
22impl<S: Scalar> SdeSolver<S> for Sra1 {
23 fn solve<Sys: SdeSystem<S>>(
24 system: &Sys,
25 t0: S,
26 tf: S,
27 x0: &[S],
28 options: &SdeOptions<S>,
29 seed: Option<u64>,
30 ) -> Result<SdeResult<S>, String> {
31 let dim = system.dim();
32 if x0.len() != dim {
33 return Err(format!(
34 "Initial state dimension {} doesn't match system dimension {}",
35 x0.len(),
36 dim
37 ));
38 }
39
40 match system.noise_type() {
42 NoiseType::Diagonal | NoiseType::Scalar => {}
43 _ => return Err("SRA1 currently only supports diagonal or scalar noise".to_string()),
44 }
45
46 let n_wiener = system.n_wiener();
47 let actual_seed = seed.or(options.seed);
48 let mut wiener = create_wiener(n_wiener, actual_seed);
49
50 let mut t = t0;
52 let mut x = x0.to_vec();
53 let mut h = options.dt.min(options.dt_max);
54
55 let mut f1 = vec![S::ZERO; dim];
57 let mut f2 = vec![S::ZERO; dim];
58 let mut g1 = vec![S::ZERO; dim];
59 let mut g2 = vec![S::ZERO; dim];
60 let mut x_stage = vec![S::ZERO; dim];
61 let mut x_new = vec![S::ZERO; dim];
62 let mut x_err = vec![S::ZERO; dim];
63
64 let mut t_out = Vec::new();
65 let mut y_out = Vec::new();
66 let mut stats = SdeStats::default();
67
68 let safety = S::from_f64(0.9);
70 let fac_min = S::from_f64(0.2);
71 let fac_max = S::from_f64(5.0);
72 let order = S::from_f64(1.5); if options.save_trajectory {
76 t_out.push(t);
77 y_out.extend_from_slice(&x);
78 }
79
80 let half = S::from_f64(0.5);
81 let one = S::ONE;
82 let mut step = 0;
83
84 while t < tf && step < options.max_steps {
85 h = h.min(tf - t).min(options.dt_max).max(options.dt_min);
87
88 let dw = wiener.increment(h);
90 let sqrt_h = h.sqrt();
91
92 system.drift(t, &x, &mut f1);
94 system.diffusion(t, &x, &mut g1);
95 stats.n_drift += 1;
96 stats.n_diffusion += 1;
97
98 for i in 0..dim {
101 x_stage[i] = x[i] + f1[i] * h + g1[i] * sqrt_h;
102 }
103 system.drift(t + h, &x_stage, &mut f2);
104 system.diffusion(t + h, &x_stage, &mut g2);
105 stats.n_drift += 1;
106 stats.n_diffusion += 1;
107
108 let is_scalar = matches!(system.noise_type(), NoiseType::Scalar);
114
115 for i in 0..dim {
116 let dw_i = if is_scalar { dw.dw[0] } else { dw.dw[i] };
117
118 x_new[i] = x[i] + half * (f1[i] + f2[i]) * h + half * (g1[i] + g2[i]) * dw_i;
120
121 x_err[i] = half * (f2[i] - f1[i]) * h + half * (g2[i] - g1[i]) * dw_i;
123 }
124
125 let mut err_sq = S::ZERO;
127 for i in 0..dim {
128 let scale = options.atol + options.rtol * x[i].abs().max(x_new[i].abs());
129 let ratio = x_err[i] / scale;
130 err_sq += ratio * ratio;
131 }
132 let err = (err_sq / S::from_usize(dim)).sqrt();
133
134 if err <= one {
136 t += h;
138 x[..dim].copy_from_slice(&x_new[..dim]);
139 stats.n_accept += 1;
140 step += 1;
141
142 if options.save_trajectory {
144 t_out.push(t);
145 y_out.extend_from_slice(&x);
146 }
147 } else {
148 stats.n_reject += 1;
150 }
151
152 let err_safe = err.max(S::from_f64(1e-10));
154 let fac = safety * err_safe.powf(-one / (order + one));
155 h *= fac.max(fac_min).min(fac_max);
156 }
157
158 if step >= options.max_steps && t < tf {
159 return Err(format!(
160 "Maximum steps ({}) exceeded at t = {}",
161 options.max_steps,
162 t.to_f64()
163 ));
164 }
165
166 if !options.save_trajectory {
168 t_out.push(t);
169 y_out.extend_from_slice(&x);
170 }
171
172 Ok(SdeResult::new(t_out, y_out, dim, stats))
173 }
174}
175
/// Adaptive stochastic Runge-Kutta solver (`SRA2`).
///
/// Each step combines a drift evaluation at the current state with one at an intermediate
/// stage at `t + 2h/3`, weighted 1/4 and 3/4. The diffusion is averaged between the current
/// state and that stage. A third drift evaluation at `t + h` feeds the local error estimate.
/// Only diagonal and scalar noise are supported; other noise types make `solve` return an
/// error.
#[derive(Debug, Clone, Copy, Default)]
pub struct Sra2;
181
182impl<S: Scalar> SdeSolver<S> for Sra2 {
183 fn solve<Sys: SdeSystem<S>>(
184 system: &Sys,
185 t0: S,
186 tf: S,
187 x0: &[S],
188 options: &SdeOptions<S>,
189 seed: Option<u64>,
190 ) -> Result<SdeResult<S>, String> {
191 let dim = system.dim();
192 if x0.len() != dim {
193 return Err(format!(
194 "Initial state dimension {} doesn't match system dimension {}",
195 x0.len(),
196 dim
197 ));
198 }
199
200 match system.noise_type() {
202 NoiseType::Diagonal | NoiseType::Scalar => {}
203 _ => return Err("SRA2 currently only supports diagonal or scalar noise".to_string()),
204 }
205
206 let n_wiener = system.n_wiener();
207 let actual_seed = seed.or(options.seed);
208 let mut wiener = create_wiener(n_wiener, actual_seed);
209
210 let mut t = t0;
212 let mut x = x0.to_vec();
213 let mut h = options.dt.min(options.dt_max);
214
215 let mut f1 = vec![S::ZERO; dim];
217 let mut f2 = vec![S::ZERO; dim];
218 let mut f3 = vec![S::ZERO; dim];
219 let mut g1 = vec![S::ZERO; dim];
220 let mut g2 = vec![S::ZERO; dim];
221 let mut x_stage = vec![S::ZERO; dim];
222 let mut x_new = vec![S::ZERO; dim];
223 let mut x_err = vec![S::ZERO; dim];
224
225 let mut t_out = Vec::new();
226 let mut y_out = Vec::new();
227 let mut stats = SdeStats::default();
228
229 let safety = S::from_f64(0.9);
231 let fac_min = S::from_f64(0.2);
232 let fac_max = S::from_f64(5.0);
233 let order = S::from_f64(2.0); let c2 = S::from_f64(2.0 / 3.0);
237 let a21 = S::from_f64(2.0 / 3.0);
238 let b1 = S::from_f64(0.25);
239 let b2 = S::from_f64(0.75);
240
241 if options.save_trajectory {
243 t_out.push(t);
244 y_out.extend_from_slice(&x);
245 }
246
247 let one = S::ONE;
248 let half = S::from_f64(0.5);
249 let mut step = 0;
250
251 while t < tf && step < options.max_steps {
252 h = h.min(tf - t).min(options.dt_max).max(options.dt_min);
253
254 let dw = wiener.increment(h);
256 let sqrt_h = h.sqrt();
257
258 system.drift(t, &x, &mut f1);
260 system.diffusion(t, &x, &mut g1);
261 stats.n_drift += 1;
262 stats.n_diffusion += 1;
263
264 let is_scalar = matches!(system.noise_type(), NoiseType::Scalar);
266 for i in 0..dim {
267 let dw_i = if is_scalar { dw.dw[0] } else { dw.dw[i] };
268 x_stage[i] = x[i] + a21 * f1[i] * h + g1[i] * sqrt_h;
269 let _ = dw_i; }
271 system.drift(t + c2 * h, &x_stage, &mut f2);
272 system.diffusion(t + c2 * h, &x_stage, &mut g2);
273 stats.n_drift += 1;
274 stats.n_diffusion += 1;
275
276 for i in 0..dim {
278 x_stage[i] = x[i] + f1[i] * h;
279 }
280 system.drift(t + h, &x_stage, &mut f3);
281 stats.n_drift += 1;
282
283 for i in 0..dim {
285 let dw_i = if is_scalar { dw.dw[0] } else { dw.dw[i] };
286
287 x_new[i] = x[i] + (b1 * f1[i] + b2 * f2[i]) * h + half * (g1[i] + g2[i]) * dw_i;
288
289 x_err[i] = (b2 * (f2[i] - f1[i]) + b1 * (f1[i] - f3[i])) * h;
291 }
292
293 let mut err_sq = S::ZERO;
295 for i in 0..dim {
296 let scale = options.atol + options.rtol * x[i].abs().max(x_new[i].abs());
297 let ratio = x_err[i] / scale;
298 err_sq += ratio * ratio;
299 }
300 let err = (err_sq / S::from_usize(dim)).sqrt();
301
302 if err <= one {
304 t += h;
305 x[..dim].copy_from_slice(&x_new[..dim]);
306 stats.n_accept += 1;
307 step += 1;
308
309 if options.save_trajectory {
310 t_out.push(t);
311 y_out.extend_from_slice(&x);
312 }
313 } else {
314 stats.n_reject += 1;
315 }
316
317 let err_safe = err.max(S::from_f64(1e-10));
319 let fac = safety * err_safe.powf(-one / (order + one));
320 h *= fac.max(fac_min).min(fac_max);
321 }
322
323 if step >= options.max_steps && t < tf {
324 return Err(format!(
325 "Maximum steps ({}) exceeded at t = {}",
326 options.max_steps,
327 t.to_f64()
328 ));
329 }
330
331 if !options.save_trajectory {
332 t_out.push(t);
333 y_out.extend_from_slice(&x);
334 }
335
336 Ok(SdeResult::new(t_out, y_out, dim, stats))
337 }
338}
339
#[cfg(test)]
mod tests {
    use super::*;

    /// Geometric Brownian motion `dX = mu*X dt + sigma*X dW`, used as a smoke-test system.
    // "GBM" is the conventional name for this process, so keep it despite the acronym lint.
    #[allow(clippy::upper_case_acronyms)]
    struct GBM {
        mu: f64,
        sigma: f64,
    }

    impl SdeSystem<f64> for GBM {
        fn dim(&self) -> usize {
            1
        }
        fn drift(&self, _t: f64, x: &[f64], f: &mut [f64]) {
            f[0] = self.mu * x[0];
        }
        fn diffusion(&self, _t: f64, x: &[f64], g: &mut [f64]) {
            g[0] = self.sigma * x[0];
        }
    }

    #[test]
    fn test_sra1_gbm() {
        let gbm = GBM {
            mu: 0.05,
            sigma: 0.2,
        };
        let options = SdeOptions::default().dt(0.01).seed(42);

        let result = Sra1::solve(&gbm, 0.0, 1.0, &[100.0], &options, None).expect("Solve failed");

        assert!(result.success);
        let final_price = result.y_final().unwrap()[0];
        assert!(final_price > 0.0);
        assert!(result.stats.n_accept > 0);
    }

    #[test]
    fn test_sra2_gbm() {
        let gbm = GBM {
            mu: 0.05,
            sigma: 0.2,
        };
        let options = SdeOptions::default().dt(0.01).seed(42);

        let result = Sra2::solve(&gbm, 0.0, 1.0, &[100.0], &options, None).expect("Solve failed");

        assert!(result.success);
        let final_price = result.y_final().unwrap()[0];
        assert!(final_price > 0.0);
        assert!(result.stats.n_accept > 0);
    }

    #[test]
    fn test_sra1_adapts_step() {
        /// Strongly damped linear drift with constant (additive) noise.
        struct Stiff;

        impl SdeSystem<f64> for Stiff {
            fn dim(&self) -> usize {
                1
            }
            fn drift(&self, _t: f64, x: &[f64], f: &mut [f64]) {
                f[0] = -50.0 * x[0];
            }
            fn diffusion(&self, _t: f64, _x: &[f64], g: &mut [f64]) {
                g[0] = 0.1;
            }
        }

        // The initial dt is far too large for the tolerances, so the controller must shrink it.
        let options = SdeOptions::default()
            .dt(0.1)
            .rtol(1e-4)
            .atol(1e-6)
            .seed(42);

        let result = Sra1::solve(&Stiff, 0.0, 1.0, &[1.0], &options, None).expect("Solve failed");

        assert!(result.success);
        assert!(result.stats.n_accept >= 10);
    }

    #[test]
    fn test_reproducibility() {
        let gbm = GBM {
            mu: 0.05,
            sigma: 0.2,
        };
        let options = SdeOptions::default().dt(0.01);

        let r1 = Sra1::solve(&gbm, 0.0, 1.0, &[100.0], &options, Some(42)).expect("Solve failed");
        let r2 = Sra1::solve(&gbm, 0.0, 1.0, &[100.0], &options, Some(42)).expect("Solve failed");

        let y1 = r1.y_final().unwrap()[0];
        let y2 = r2.y_final().unwrap()[0];
        assert!((y1 - y2).abs() < 1e-10);
    }

    #[test]
    fn test_sra2_reproducibility() {
        let gbm = GBM {
            mu: 0.05,
            sigma: 0.2,
        };
        let options = SdeOptions::default().dt(0.01);

        let r1 = Sra2::solve(&gbm, 0.0, 1.0, &[100.0], &options, Some(7)).expect("Solve failed");
        let r2 = Sra2::solve(&gbm, 0.0, 1.0, &[100.0], &options, Some(7)).expect("Solve failed");

        let y1 = r1.y_final().unwrap()[0];
        let y2 = r2.y_final().unwrap()[0];
        assert!((y1 - y2).abs() < 1e-10);
    }

    #[test]
    fn test_dimension_mismatch_is_rejected() {
        let gbm = GBM {
            mu: 0.05,
            sigma: 0.2,
        };
        let options = SdeOptions::default();

        // GBM is one-dimensional; a two-component initial state must be refused up front.
        assert!(Sra1::solve(&gbm, 0.0, 1.0, &[1.0, 2.0], &options, None).is_err());
        assert!(Sra2::solve(&gbm, 0.0, 1.0, &[1.0, 2.0], &options, None).is_err());
    }
}