pathwise_core/scheme/
milstein_nd.rs1use super::Scheme;
2use crate::process::markov::Drift;
3use crate::state::{Diffusion, Increment};
4use nalgebra::SVector;
5
/// Milstein stepping scheme for `N`-dimensional states.
///
/// The correction term needs a derivative of the diffusion with respect to
/// the state; this type estimates it numerically with a central finite
/// difference of half-width [`h`](Self::h) rather than asking the caller for
/// an analytic derivative.
#[derive(Debug, Clone, Copy, PartialEq)]
pub struct MilsteinNd<const N: usize> {
    /// Perturbation added to and subtracted from a single state component
    /// when estimating the diffusion derivative.
    ///
    /// The estimate is divided by `2.0 * h`, so `h` must be non-zero (and
    /// finite) for the step to produce finite values.
    pub h: f64,
}
18
19impl<const N: usize> MilsteinNd<N> {
20 pub fn new(h: f64) -> Self {
21 Self { h }
22 }
23}
24
25impl<const N: usize> Scheme<SVector<f64, N>> for MilsteinNd<N> {
26 type Noise = SVector<f64, N>;
27
28 fn step<D, G>(
29 &self,
30 drift: &D,
31 diffusion: &G,
32 x: &SVector<f64, N>,
33 t: f64,
34 dt: f64,
35 inc: &Increment<SVector<f64, N>>,
36 ) -> SVector<f64, N>
37 where
38 D: Drift<SVector<f64, N>>,
39 G: Diffusion<SVector<f64, N>, SVector<f64, N>>,
40 {
41 let dw = &inc.dw;
42 let f_dt = drift(x, t) * dt;
43 let g_dw = diffusion.apply(x, t, dw);
44
45 let mut correction = SVector::<f64, N>::zeros();
46 let h = self.h;
47 for i in 0..N {
48 let mut ei = SVector::<f64, N>::zeros();
49 ei[i] = 1.0;
50 let g_ii = diffusion.apply(x, t, &ei)[i];
51 let mut xp = *x;
52 xp[i] += h;
53 let mut xm = *x;
54 xm[i] -= h;
55 let g_ii_plus = diffusion.apply(&xp, t, &ei)[i];
56 let g_ii_minus = diffusion.apply(&xm, t, &ei)[i];
57 let dg_ii = (g_ii_plus - g_ii_minus) / (2.0 * h);
58 correction[i] = 0.5 * g_ii * dg_ii * (dw[i] * dw[i] - dt);
59 }
60
61 x + f_dt + g_dw + correction
62 }
63}
64
65pub fn milstein_nd<const N: usize>() -> MilsteinNd<N> {
66 MilsteinNd::new(1e-5)
67}
68
#[cfg(test)]
mod tests {
    use super::*;
    use crate::state::Increment;
    use nalgebra::SVector;

    /// A diffusion that ignores the state has a zero finite-difference
    /// derivative, so the Milstein correction vanishes and one step must
    /// match the Euler–Maruyama step for the same inputs.
    #[test]
    fn milstein_nd_equals_euler_nd_for_constant_diffusion() {
        let t0 = 0.0;
        let dt = 0.01;

        let milstein: MilsteinNd<2> = milstein_nd();
        let euler = crate::scheme::euler::EulerMaruyama;

        let zero_drift = |_x: &SVector<f64, 2>, _t: f64| SVector::zeros();
        let unit_diffusion = |_x: &SVector<f64, 2>, _t: f64| SVector::from([1.0_f64, 1.0]);

        let start = SVector::from([0.5_f64, -0.5]);
        let dw = SVector::from([0.1_f64, -0.2]);
        let inc = Increment {
            dw,
            dz: SVector::zeros(),
        };

        let euler_next = euler.step(&zero_drift, &unit_diffusion, &start, t0, dt, &inc);
        let milstein_next = milstein.step(&zero_drift, &unit_diffusion, &start, t0, dt, &inc);

        let gap = (euler_next - milstein_next).norm();
        assert!(gap < 1e-8, "should be equal for constant diffusion");
    }
}