use super::Scheme;
use crate::process::markov::Drift;
use crate::state::{Diffusion, Increment};
/// Single-step scheme for scalar SDEs that estimates the spatial derivatives of
/// the drift and diffusion with central finite differences of step `h`.
#[derive(Debug, Clone, Copy, PartialEq)]
pub struct Sri {
    /// Finite-difference step used for the spatial derivative estimates.
    h: f64,
}

impl Sri {
    /// Creates a scheme that uses `h` as its finite-difference step.
    ///
    /// # Panics
    ///
    /// Panics if `h` is not a positive, finite number. A zero (or NaN/infinite)
    /// step makes the central-difference quotients divide by zero or lose all
    /// precision, which would turn every subsequent step into NaN.
    #[must_use]
    pub fn new(h: f64) -> Self {
        assert!(
            h.is_finite() && h > 0.0,
            "Sri finite-difference step must be positive and finite, got {}",
            h
        );
        Self { h }
    }
}
impl Scheme<f64> for Sri {
    type Noise = f64;

    /// Advances the scalar state `x` from time `t` to `t + dt`.
    ///
    /// The update has the term structure of the Kloeden–Platen order 1.5 strong
    /// Itô–Taylor scheme for `dX = f(X, t) dt + g(X, t) dW`:
    /// the Euler terms `f dt + g dW`, the Milstein correction `g g' I_(1,1)`,
    /// the mixed terms `L¹f · I_(1,0)` and `L⁰g · I_(0,1)`, the term
    /// `½ L⁰f dt²`, and the triple-noise term `L¹L¹g · I_(1,1,1)`.
    ///
    /// The spatial derivatives `f'`, `f''`, `g'`, `g''` are not supplied by the
    /// caller. They are estimated with central finite differences of step
    /// `self.h`, so each call evaluates the drift three times and the diffusion
    /// three times. All coefficients are evaluated at time `t`.
    ///
    /// NOTE(review): no ∂/∂t terms are included in `L⁰f` / `L⁰g`. That is exact
    /// only for coefficients that do not depend on `t`. Confirm this is intended
    /// for time-dependent processes.
    ///
    /// `inc.dw` is the Brownian increment over the step. `inc.dz` is the
    /// auxiliary increment used for the mixed double integrals (see the note
    /// on `i10`).
    fn step<D, G>(
        &self,
        drift: &D,
        diffusion: &G,
        x: &f64,
        t: f64,
        dt: f64,
        inc: &Increment<f64>,
    ) -> f64
    where
        D: Drift<f64>,
        G: Diffusion<f64, f64>,
    {
        let dw = inc.dw;
        let dz = inc.dz;
        let h = self.h;
        let two_h = 2.0 * h;
        let h_sq = h * h;

        // Drift and its first two spatial derivatives (central differences).
        let f = drift(x, t);
        let f_plus = drift(&(x + h), t);
        let f_minus = drift(&(x - h), t);
        let df_dx = (f_plus - f_minus) / two_h;
        // f'' is required for the Itô correction ½ g² f'' inside L⁰f. It reuses
        // the three drift samples above, so it costs no extra evaluations.
        let d2f_dx2 = (f_plus - 2.0 * f + f_minus) / h_sq;

        // Diffusion and its first two spatial derivatives. `apply` is called
        // with a unit noise argument. Presumably that yields the scalar
        // coefficient g(x, t) itself — confirm against `Diffusion::apply`.
        let g = diffusion.apply(x, t, &1.0_f64);
        let g_plus = diffusion.apply(&(x + h), t, &1.0_f64);
        let g_minus = diffusion.apply(&(x - h), t, &1.0_f64);
        let dg_dx = (g_plus - g_minus) / two_h;
        let d2g_dx2 = (g_plus - 2.0 * g + g_minus) / h_sq;
        let half_g_sq = 0.5 * g * g;

        // Multiple Itô integrals over [t, t + dt].
        //
        // NOTE(review): in Kloeden–Platen notation ΔZ = I_(1,0) = ∫∫ dW ds and
        // I_(0,1) = dt·ΔW − ΔZ. The value named `i10` below is `dt·dw − dz`. That
        // equals I_(1,0) only if `Increment::dz` holds I_(0,1). If `dz` is the
        // usual ΔZ, then the weights on `i10` and on `dz` (in `term_l0g`) are
        // swapped. Confirm against the definition of `Increment`.
        let i10 = dt * dw - dz;
        let i11 = (dw * dw - dt) * 0.5;
        let i111 = (dw * dw * dw - 3.0 * dt * dw) / 6.0;

        // L¹g · I_(1,1) with L¹g = g g' (Milstein correction).
        let milstein = g * dg_dx * i11;
        // L¹f · I_(1,0) with L¹f = g f'.
        let term_l1f = g * df_dx * i10;
        // L⁰g · I_(0,1) with L⁰g = f g' + ½ g² g''.
        let term_l0g = (f * dg_dx + half_g_sq * d2g_dx2) * dz;
        // ½ L⁰f · dt² with L⁰f = f f' + ½ g² f''.
        let term_l0f = 0.5 * (f * df_dx + half_g_sq * d2f_dx2) * dt * dt;
        // L¹L¹g · I_(1,1,1) with L¹L¹g = g (g')² + g² g''.
        let d111 = g * dg_dx * dg_dx + g * g * d2g_dx2;
        let term_l11g = d111 * i111;

        x + f * dt + g * dw + milstein + term_l1f + term_l0g + term_l0f + term_l11g
    }
}
/// Convenience constructor: an [`Sri`] scheme with a finite-difference step of `1e-4`.
#[must_use]
pub fn sri() -> Sri {
    const DEFAULT_FD_STEP: f64 = 1e-4;
    Sri::new(DEFAULT_FD_STEP)
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::process::markov::bm;
    use crate::state::Increment;

    /// For `bm()` the coefficients are (per this test's name) state-independent.
    /// Every finite-difference correction should therefore vanish, and one SRI
    /// step must agree with one Euler step.
    #[test]
    fn sri_equals_euler_for_constant_diffusion() {
        let brownian = bm();
        let higher_order = sri();
        let baseline = crate::scheme::euler::euler();

        let (x0, t0, dt) = (1.0_f64, 0.0, 0.01);
        let noise = Increment { dw: 0.3, dz: 0.001 };

        let euler_next = baseline.step(&brownian.drift, &brownian.diffusion, &x0, t0, dt, &noise);
        let sri_next =
            higher_order.step(&brownian.drift, &brownian.diffusion, &x0, t0, dt, &noise);

        let gap = (euler_next - sri_next).abs();
        assert!(gap < 1e-6, "BM: euler={} sri={}", euler_next, sri_next);
    }

    /// With a state-dependent diffusion (GBM, per this test's name) the extra
    /// SRI terms are non-zero, so the result must move away from a Milstein step.
    #[test]
    fn sri_differs_from_milstein_for_state_dependent_diffusion() {
        let geometric = crate::process::markov::gbm(0.05, 0.3);
        let higher_order = sri();
        let reference = crate::scheme::milstein::milstein();

        let x0 = 1.0_f64;
        let t0 = 0.0;
        let dt = 0.01;
        let noise = Increment {
            dw: 0.05,
            dz: 0.0001,
        };

        let via_sri =
            higher_order.step(&geometric.drift, &geometric.diffusion, &x0, t0, dt, &noise);
        let via_milstein =
            reference.step(&geometric.drift, &geometric.diffusion, &x0, t0, dt, &noise);

        let gap = (via_sri - via_milstein).abs();
        assert!(gap > 1e-10, "SRI and Milstein should differ");
    }
}