pathwise_core/scheme/
milstein_nd.rs1use super::Scheme;
2use crate::state::{Diffusion, Increment};
3use crate::process::markov::Drift;
4use nalgebra::SVector;
5
/// Milstein-type stepping scheme for an `N`-dimensional state vector.
///
/// On top of the drift and diffusion terms, the scheme adds a per-component
/// correction built from the diagonal diffusion coefficient and its derivative
/// with respect to the matching state component. That derivative is not
/// supplied by the caller; it is approximated with a central finite
/// difference whose half-width is [`MilsteinNd::h`].
#[derive(Debug, Clone, Copy, PartialEq)]
pub struct MilsteinNd<const N: usize> {
    /// Finite-difference half-width used when differentiating the diffusion
    /// coefficient. The scheme divides by `2.0 * h`, so `h` must be non-zero
    /// for the correction term to be a finite number.
    pub h: f64,
}

impl<const N: usize> MilsteinNd<N> {
    /// Creates a scheme that uses `h` as its finite-difference half-width.
    ///
    /// No validation is performed on `h`; the value is stored as given.
    #[must_use]
    pub fn new(h: f64) -> Self {
        Self { h }
    }
}
22
23impl<const N: usize> Scheme<SVector<f64, N>> for MilsteinNd<N> {
24 type Noise = SVector<f64, N>;
25
26 fn step<D, G>(
27 &self,
28 drift: &D,
29 diffusion: &G,
30 x: &SVector<f64, N>,
31 t: f64,
32 dt: f64,
33 inc: &Increment<SVector<f64, N>>,
34 ) -> SVector<f64, N>
35 where
36 D: Drift<SVector<f64, N>>,
37 G: Diffusion<SVector<f64, N>, SVector<f64, N>>,
38 {
39 let dw = &inc.dw;
40 let f_dt = drift(x, t) * dt;
41 let g_dw = diffusion.apply(x, t, dw);
42
43 let mut correction = SVector::<f64, N>::zeros();
44 let h = self.h;
45 for i in 0..N {
46 let mut ei = SVector::<f64, N>::zeros();
47 ei[i] = 1.0;
48 let g_ii = diffusion.apply(x, t, &ei)[i];
49 let mut xp = *x; xp[i] += h;
50 let mut xm = *x; xm[i] -= h;
51 let g_ii_plus = diffusion.apply(&xp, t, &ei)[i];
52 let g_ii_minus = diffusion.apply(&xm, t, &ei)[i];
53 let dg_ii = (g_ii_plus - g_ii_minus) / (2.0 * h);
54 correction[i] = 0.5 * g_ii * dg_ii * (dw[i] * dw[i] - dt);
55 }
56
57 x + f_dt + g_dw + correction
58 }
59}
60
61pub fn milstein_nd<const N: usize>() -> MilsteinNd<N> { MilsteinNd::new(1e-5) }
62
#[cfg(test)]
mod tests {
    use super::*;
    use crate::state::Increment;
    use nalgebra::SVector;

    /// A state-independent diffusion gives a zero finite-difference slope, so
    /// the correction term vanishes and the step must agree with the Euler
    /// scheme run on identical inputs.
    #[test]
    fn milstein_nd_equals_euler_nd_for_constant_diffusion() {
        // Model: zero drift, constant unit diffusion in both components.
        let zero_drift = |_x: &SVector<f64, 2>, _t: f64| SVector::zeros();
        let unit_diffusion = |_x: &SVector<f64, 2>, _t: f64| SVector::from([1.0_f64, 1.0]);

        // Shared starting state and noise increment.
        let start = SVector::from([0.5_f64, -0.5]);
        let dw = SVector::from([0.1_f64, -0.2]);
        let inc = Increment { dw, dz: SVector::zeros() };

        let milstein: MilsteinNd<2> = milstein_nd();
        let euler = crate::scheme::euler::EulerMaruyama;

        let xe = euler.step(&zero_drift, &unit_diffusion, &start, 0.0, 0.01, &inc);
        let xm = milstein.step(&zero_drift, &unit_diffusion, &start, 0.0, 0.01, &inc);

        let gap = (xe - xm).norm();
        assert!(gap < 1e-8, "should be equal for constant diffusion");
    }
}