Skip to main content

numra_sde/
sra.rs

//! Adaptive SDE solvers based on Stochastic Runge-Kutta methods.
//!
//! These methods provide adaptive step size control for SDEs.
//!
//! - [`Sra1`]: targets pathwise (strong) accuracy; see its type-level docs for
//!   the attainable order and the supported noise structures.
//! - [`Sra2`]: targets weak accuracy, i.e. computing expectations such as means
//!   and variances; see its type-level docs for details.
//!
//! Author: Moussa Leblouba
//! Date: 4 February 2026
//! Modified: 2 May 2026
11
12use crate::system::{NoiseType, SdeOptions, SdeResult, SdeSolver, SdeStats, SdeSystem};
13use crate::wiener::create_wiener;
14use numra_core::Scalar;
15
16/// SRA1 - Adaptive strong order 1.0-1.5 method.
17///
18/// Uses a two-stage Runge-Kutta scheme with embedded error estimation.
19/// Strong order 1.5 for additive noise, 1.0 for multiplicative.
20pub struct Sra1;
21
22impl<S: Scalar> SdeSolver<S> for Sra1 {
23    fn solve<Sys: SdeSystem<S>>(
24        system: &Sys,
25        t0: S,
26        tf: S,
27        x0: &[S],
28        options: &SdeOptions<S>,
29        seed: Option<u64>,
30    ) -> Result<SdeResult<S>, String> {
31        let dim = system.dim();
32        if x0.len() != dim {
33            return Err(format!(
34                "Initial state dimension {} doesn't match system dimension {}",
35                x0.len(),
36                dim
37            ));
38        }
39
40        // Currently only supports diagonal noise
41        match system.noise_type() {
42            NoiseType::Diagonal | NoiseType::Scalar => {}
43            _ => return Err("SRA1 currently only supports diagonal or scalar noise".to_string()),
44        }
45
46        let n_wiener = system.n_wiener();
47        let actual_seed = seed.or(options.seed);
48        let mut wiener = create_wiener(n_wiener, actual_seed);
49
50        // Allocate storage
51        let mut t = t0;
52        let mut x = x0.to_vec();
53        let mut h = options.dt.min(options.dt_max);
54
55        // Working arrays
56        let mut f1 = vec![S::ZERO; dim];
57        let mut f2 = vec![S::ZERO; dim];
58        let mut g1 = vec![S::ZERO; dim];
59        let mut g2 = vec![S::ZERO; dim];
60        let mut x_stage = vec![S::ZERO; dim];
61        let mut x_new = vec![S::ZERO; dim];
62        let mut x_err = vec![S::ZERO; dim];
63
64        let mut t_out = Vec::new();
65        let mut y_out = Vec::new();
66        let mut stats = SdeStats::default();
67
68        // Constants for step size control
69        let safety = S::from_f64(0.9);
70        let fac_min = S::from_f64(0.2);
71        let fac_max = S::from_f64(5.0);
72        let order = S::from_f64(1.5); // Strong order for error estimation
73
74        // Save initial state
75        if options.save_trajectory {
76            t_out.push(t);
77            y_out.extend_from_slice(&x);
78        }
79
80        let half = S::from_f64(0.5);
81        let one = S::ONE;
82        let mut step = 0;
83
84        while t < tf && step < options.max_steps {
85            // Limit step to not overshoot tf
86            h = h.min(tf - t).min(options.dt_max).max(options.dt_min);
87
88            // Generate Wiener increment
89            let dw = wiener.increment(h);
90            let sqrt_h = h.sqrt();
91
92            // Stage 1: Evaluate at current point
93            system.drift(t, &x, &mut f1);
94            system.diffusion(t, &x, &mut g1);
95            stats.n_drift += 1;
96            stats.n_diffusion += 1;
97
98            // Stage 2: Evaluate at predicted point
99            // x_stage = x + f1*h + g1*sqrt(h) (deterministic predictor)
100            for i in 0..dim {
101                x_stage[i] = x[i] + f1[i] * h + g1[i] * sqrt_h;
102            }
103            system.drift(t + h, &x_stage, &mut f2);
104            system.diffusion(t + h, &x_stage, &mut g2);
105            stats.n_drift += 1;
106            stats.n_diffusion += 1;
107
108            // Two-stage SRK update with Rößler-style coefficients
109            // Higher order: x_new = x + 0.5*(f1+f2)*h + 0.5*(g1+g2)*dW
110            // Lower order:  x_lo  = x + f1*h + g1*dW
111            // Error: x_err = 0.5*(f2-f1)*h + 0.5*(g2-g1)*dW
112
113            let is_scalar = matches!(system.noise_type(), NoiseType::Scalar);
114
115            for i in 0..dim {
116                let dw_i = if is_scalar { dw.dw[0] } else { dw.dw[i] };
117
118                // Higher-order solution
119                x_new[i] = x[i] + half * (f1[i] + f2[i]) * h + half * (g1[i] + g2[i]) * dw_i;
120
121                // Error estimate (difference from lower-order solution)
122                x_err[i] = half * (f2[i] - f1[i]) * h + half * (g2[i] - g1[i]) * dw_i;
123            }
124
125            // Compute error norm
126            let mut err_sq = S::ZERO;
127            for i in 0..dim {
128                let scale = options.atol + options.rtol * x[i].abs().max(x_new[i].abs());
129                let ratio = x_err[i] / scale;
130                err_sq += ratio * ratio;
131            }
132            let err = (err_sq / S::from_usize(dim)).sqrt();
133
134            // Accept or reject step
135            if err <= one {
136                // Accept step
137                t += h;
138                x[..dim].copy_from_slice(&x_new[..dim]);
139                stats.n_accept += 1;
140                step += 1;
141
142                // Save state
143                if options.save_trajectory {
144                    t_out.push(t);
145                    y_out.extend_from_slice(&x);
146                }
147            } else {
148                // Reject step
149                stats.n_reject += 1;
150            }
151
152            // Compute new step size
153            let err_safe = err.max(S::from_f64(1e-10));
154            let fac = safety * err_safe.powf(-one / (order + one));
155            h *= fac.max(fac_min).min(fac_max);
156        }
157
158        if step >= options.max_steps && t < tf {
159            return Err(format!(
160                "Maximum steps ({}) exceeded at t = {}",
161                options.max_steps,
162                t.to_f64()
163            ));
164        }
165
166        // If not saving trajectory, just save final state
167        if !options.save_trajectory {
168            t_out.push(t);
169            y_out.extend_from_slice(&x);
170        }
171
172        Ok(SdeResult::new(t_out, y_out, dim, stats))
173    }
174}
175
176/// SRA2 - Adaptive weak order 2.0 method.
177///
178/// Better for computing expectations (mean, variance) rather than pathwise accuracy.
179/// Uses a three-stage scheme optimized for weak convergence.
180pub struct Sra2;
181
182impl<S: Scalar> SdeSolver<S> for Sra2 {
183    fn solve<Sys: SdeSystem<S>>(
184        system: &Sys,
185        t0: S,
186        tf: S,
187        x0: &[S],
188        options: &SdeOptions<S>,
189        seed: Option<u64>,
190    ) -> Result<SdeResult<S>, String> {
191        let dim = system.dim();
192        if x0.len() != dim {
193            return Err(format!(
194                "Initial state dimension {} doesn't match system dimension {}",
195                x0.len(),
196                dim
197            ));
198        }
199
200        // Currently only supports diagonal noise
201        match system.noise_type() {
202            NoiseType::Diagonal | NoiseType::Scalar => {}
203            _ => return Err("SRA2 currently only supports diagonal or scalar noise".to_string()),
204        }
205
206        let n_wiener = system.n_wiener();
207        let actual_seed = seed.or(options.seed);
208        let mut wiener = create_wiener(n_wiener, actual_seed);
209
210        // Allocate storage
211        let mut t = t0;
212        let mut x = x0.to_vec();
213        let mut h = options.dt.min(options.dt_max);
214
215        // Working arrays
216        let mut f1 = vec![S::ZERO; dim];
217        let mut f2 = vec![S::ZERO; dim];
218        let mut f3 = vec![S::ZERO; dim];
219        let mut g1 = vec![S::ZERO; dim];
220        let mut g2 = vec![S::ZERO; dim];
221        let mut x_stage = vec![S::ZERO; dim];
222        let mut x_new = vec![S::ZERO; dim];
223        let mut x_err = vec![S::ZERO; dim];
224
225        let mut t_out = Vec::new();
226        let mut y_out = Vec::new();
227        let mut stats = SdeStats::default();
228
229        // Constants for step size control
230        let safety = S::from_f64(0.9);
231        let fac_min = S::from_f64(0.2);
232        let fac_max = S::from_f64(5.0);
233        let order = S::from_f64(2.0); // Weak order
234
235        // Coefficients for 3-stage weak order 2 method
236        let c2 = S::from_f64(2.0 / 3.0);
237        let a21 = S::from_f64(2.0 / 3.0);
238        let b1 = S::from_f64(0.25);
239        let b2 = S::from_f64(0.75);
240
241        // Save initial state
242        if options.save_trajectory {
243            t_out.push(t);
244            y_out.extend_from_slice(&x);
245        }
246
247        let one = S::ONE;
248        let half = S::from_f64(0.5);
249        let mut step = 0;
250
251        while t < tf && step < options.max_steps {
252            h = h.min(tf - t).min(options.dt_max).max(options.dt_min);
253
254            // Generate Wiener increment
255            let dw = wiener.increment(h);
256            let sqrt_h = h.sqrt();
257
258            // Stage 1
259            system.drift(t, &x, &mut f1);
260            system.diffusion(t, &x, &mut g1);
261            stats.n_drift += 1;
262            stats.n_diffusion += 1;
263
264            // Stage 2 at t + c2*h
265            let is_scalar = matches!(system.noise_type(), NoiseType::Scalar);
266            for i in 0..dim {
267                let dw_i = if is_scalar { dw.dw[0] } else { dw.dw[i] };
268                x_stage[i] = x[i] + a21 * f1[i] * h + g1[i] * sqrt_h;
269                let _ = dw_i; // Used for stochastic part
270            }
271            system.drift(t + c2 * h, &x_stage, &mut f2);
272            system.diffusion(t + c2 * h, &x_stage, &mut g2);
273            stats.n_drift += 1;
274            stats.n_diffusion += 1;
275
276            // Final stage at t + h for error estimation
277            for i in 0..dim {
278                x_stage[i] = x[i] + f1[i] * h;
279            }
280            system.drift(t + h, &x_stage, &mut f3);
281            stats.n_drift += 1;
282
283            // Compute solution: x_new = x + (b1*f1 + b2*f2)*h + g1*dW
284            for i in 0..dim {
285                let dw_i = if is_scalar { dw.dw[0] } else { dw.dw[i] };
286
287                x_new[i] = x[i] + (b1 * f1[i] + b2 * f2[i]) * h + half * (g1[i] + g2[i]) * dw_i;
288
289                // Error estimate (embedded method difference)
290                x_err[i] = (b2 * (f2[i] - f1[i]) + b1 * (f1[i] - f3[i])) * h;
291            }
292
293            // Compute error norm
294            let mut err_sq = S::ZERO;
295            for i in 0..dim {
296                let scale = options.atol + options.rtol * x[i].abs().max(x_new[i].abs());
297                let ratio = x_err[i] / scale;
298                err_sq += ratio * ratio;
299            }
300            let err = (err_sq / S::from_usize(dim)).sqrt();
301
302            // Accept or reject step
303            if err <= one {
304                t += h;
305                x[..dim].copy_from_slice(&x_new[..dim]);
306                stats.n_accept += 1;
307                step += 1;
308
309                if options.save_trajectory {
310                    t_out.push(t);
311                    y_out.extend_from_slice(&x);
312                }
313            } else {
314                stats.n_reject += 1;
315            }
316
317            // Compute new step size
318            let err_safe = err.max(S::from_f64(1e-10));
319            let fac = safety * err_safe.powf(-one / (order + one));
320            h *= fac.max(fac_min).min(fac_max);
321        }
322
323        if step >= options.max_steps && t < tf {
324            return Err(format!(
325                "Maximum steps ({}) exceeded at t = {}",
326                options.max_steps,
327                t.to_f64()
328            ));
329        }
330
331        if !options.save_trajectory {
332            t_out.push(t);
333            y_out.extend_from_slice(&x);
334        }
335
336        Ok(SdeResult::new(t_out, y_out, dim, stats))
337    }
338}
339
#[cfg(test)]
mod tests {
    use super::*;

    // Geometric Brownian motion: drift `mu * x`, diffusion `sigma * x`.
    struct Gbm {
        mu: f64,
        sigma: f64,
    }

    impl SdeSystem<f64> for Gbm {
        fn dim(&self) -> usize {
            1
        }
        fn drift(&self, _t: f64, x: &[f64], f: &mut [f64]) {
            f[0] = self.mu * x[0];
        }
        fn diffusion(&self, _t: f64, x: &[f64], g: &mut [f64]) {
            g[0] = self.sigma * x[0];
        }
    }

    // GBM parameters (5% drift, 20% volatility) shared by several tests.
    fn sample_gbm() -> Gbm {
        Gbm {
            mu: 0.05,
            sigma: 0.2,
        }
    }

    #[test]
    fn test_sra1_gbm() {
        let options = SdeOptions::default().dt(0.01).seed(42);
        let result =
            Sra1::solve(&sample_gbm(), 0.0, 1.0, &[100.0], &options, None).expect("Solve failed");

        assert!(result.success);
        assert!(result.y_final().unwrap()[0] > 0.0);
        assert!(result.stats.n_accept > 0);
    }

    #[test]
    fn test_sra2_gbm() {
        let options = SdeOptions::default().dt(0.01).seed(42);
        let result =
            Sra2::solve(&sample_gbm(), 0.0, 1.0, &[100.0], &options, None).expect("Solve failed");

        assert!(result.success);
        assert!(result.y_final().unwrap()[0] > 0.0);
    }

    #[test]
    fn test_sra1_adapts_step() {
        // Fast linear decay with small additive noise: the large initial dt must
        // be cut down by the step-size controller.
        struct FastDecay;

        impl SdeSystem<f64> for FastDecay {
            fn dim(&self) -> usize {
                1
            }
            fn drift(&self, _t: f64, x: &[f64], f: &mut [f64]) {
                f[0] = -50.0 * x[0];
            }
            fn diffusion(&self, _t: f64, _x: &[f64], g: &mut [f64]) {
                g[0] = 0.1;
            }
        }

        let options = SdeOptions::default()
            .dt(0.1)
            .rtol(1e-4)
            .atol(1e-6)
            .seed(42);
        let result =
            Sra1::solve(&FastDecay, 0.0, 1.0, &[1.0], &options, None).expect("Solve failed");

        assert!(result.success);
        // Whether through rejections or many small steps, at least ten steps are accepted.
        assert!(result.stats.n_accept >= 10);
    }

    #[test]
    fn test_reproducibility() {
        let gbm = sample_gbm();
        let options = SdeOptions::default().dt(0.01);

        // Same explicit seed on every run, so the final states must agree.
        let final_value = || {
            Sra1::solve(&gbm, 0.0, 1.0, &[100.0], &options, Some(42))
                .expect("Solve failed")
                .y_final()
                .unwrap()[0]
        };

        let first = final_value();
        let second = final_value();
        assert!((first - second).abs() < 1e-10);
    }
}