use ndarray::Array1;
use crate::traits::FloatExt;
/// Coefficient functions of a (jump-)diffusion process `X`.
///
/// Every coefficient is a boxed closure taking the current time first and the
/// current state second.
pub struct DiffusionProcessFn {
    /// Drift coefficient `μ(t, x)`.
    pub drift: Box<dyn Fn(f64, f64) -> f64 + 'static>,
    /// Diffusion (volatility) coefficient `σ(t, x)`.
    pub diffusion: Box<dyn Fn(f64, f64) -> f64 + 'static>,
    /// Optional jump increment `J(t, x, dt)`, added to the state on each
    /// simulation step; `None` means a pure diffusion.
    pub jump_term: Option<Box<dyn Fn(f64, f64, f64) -> f64 + 'static>>,
}
/// A scalar function `f(t, x)` of time and state.
pub struct Function2D {
    /// Evaluates `f`; the first argument is the time `t`, the second the state `x`.
    pub eval: Box<dyn Fn(f64, f64) -> f64 + 'static>,
}
/// Coefficients of `dY_t = drift_term * dt + diffusion_term * dW_t` for
/// `Y_t = f(t, X_t)`, as produced by [`ItoCalculator::ito_transform`].
///
/// Plain data, so it derives the usual value traits for printing, copying and
/// comparison.
#[derive(Debug, Clone, Copy, PartialEq)]
pub struct ItoResult {
    /// The `dt` coefficient: `∂f/∂t + μ ∂f/∂x + ½ σ² ∂²f/∂x²`.
    pub drift_term: f64,
    /// The `dW` coefficient: `σ ∂f/∂x`.
    pub diffusion_term: f64,
}
/// Applies Itô's lemma to `f(t, X_t)` and simulates paths of the underlying
/// process `X`.
pub struct ItoCalculator {
    /// Drift, diffusion and optional jump term of `X`.
    pub process: DiffusionProcessFn,
    /// The transformed function `f(t, x)`.
    pub function: Function2D,
    /// Step of the central finite differences used to approximate the partial
    /// derivatives of `function`.
    pub h: f64,
}
impl ItoCalculator {
    /// Creates a calculator for `Y_t = f(t, X_t)` where `X_t` follows `process`.
    ///
    /// `h` is the step of the central finite differences used to approximate
    /// the partial derivatives of `function`. It must be non-zero: with
    /// `h == 0.0` every derivative evaluates to `NaN`.
    pub fn new(process: DiffusionProcessFn, function: Function2D, h: f64) -> Self {
        Self {
            process,
            function,
            h,
        }
    }

    /// Central-difference approximation of `∂f/∂x` at `(t, x)`.
    fn dfdx(&self, t: f64, x: f64) -> f64 {
        let f = &self.function.eval;
        (f(t, x + self.h) - f(t, x - self.h)) / (2.0 * self.h)
    }

    /// Central-difference approximation of `∂²f/∂x²` at `(t, x)`.
    ///
    /// Dividing by `h²` amplifies rounding error, so a very small `h` can make
    /// this noisier than `dfdx`.
    fn d2fdx2(&self, t: f64, x: f64) -> f64 {
        let f = &self.function.eval;
        (f(t, x + self.h) - 2.0 * f(t, x) + f(t, x - self.h)) / self.h.powi(2)
    }

    /// Central-difference approximation of `∂f/∂t` at `(t, x)`.
    ///
    /// This evaluates `f` at `t - h`, which may lie before the simulation start.
    fn dfdt(&self, t: f64, x: f64) -> f64 {
        let f = &self.function.eval;
        (f(t + self.h, x) - f(t - self.h, x)) / (2.0 * self.h)
    }

    /// Applies Itô's lemma to `Y_t = f(t, X_t)` at the point `(t, x)`.
    ///
    /// With `μ = drift(t, x)` and `σ = diffusion(t, x)`, returns the
    /// coefficients of
    ///
    /// `dY = (∂f/∂t + μ ∂f/∂x + ½ σ² ∂²f/∂x²) dt + σ ∂f/∂x dW`.
    ///
    /// Every derivative is a finite-difference approximation with step
    /// `self.h`. `process.jump_term` is not taken into account here.
    pub fn ito_transform(&self, t: f64, x: f64) -> ItoResult {
        let mu = (self.process.drift)(t, x);
        let sigma = (self.process.diffusion)(t, x);
        let dfdx = self.dfdx(t, x);
        let d2fdx2 = self.d2fdx2(t, x);
        let dfdt = self.dfdt(t, x);
        ItoResult {
            drift_term: dfdt + mu * dfdx + 0.5 * sigma.powi(2) * d2fdx2,
            diffusion_term: sigma * dfdx,
        }
    }

    /// Simulates one path of `X` on `[t0, t1]` with the Euler–Maruyama scheme.
    ///
    /// Each step of length `h` (equal to `dt` except possibly for the last one)
    /// updates `x += μ(t, x)·h + σ(t, x)·√h·z + jump(t, x, h)`.
    /// - `z` is drawn by `FloatExt::fill_standard_normal_slice`. These are
    ///   presumably i.i.d. standard normals; confirm.
    /// - `jump` is `0.0` when `process.jump_term` is `None`.
    ///
    /// The last step is shortened so the path ends exactly at `t1` instead of
    /// overshooting it.
    ///
    /// Returns the `(t, x)` pairs, starting with `(t0, x0)`. Only `(t0, x0)` is
    /// returned when the interval is empty, reversed or NaN, or when `dt` is
    /// negative or NaN.
    ///
    /// NOTE(review): `_rng` is not used. The normals come from
    /// `FloatExt::fill_standard_normal_slice`, so seeding `_rng` does not make
    /// the path reproducible. Confirm whether that is intended.
    ///
    /// # Panics
    ///
    /// Panics if `dt == 0.0` while `t1 > t0`.
    pub fn simulate(
        &self,
        x0: f64,
        t0: f64,
        t1: f64,
        dt: f64,
        _rng: &mut impl rand::Rng,
    ) -> Array1<(f64, f64)> {
        let steps = Self::step_count(t0, t1, dt);
        let mut normals = vec![0.0; steps];
        <f64 as FloatExt>::fill_standard_normal_slice(&mut normals);

        let mut path = Vec::with_capacity(steps + 1);
        let mut x = x0;
        path.push((t0, x));
        for (i, z) in normals.into_iter().enumerate() {
            // Grid times are computed from `t0` rather than accumulated, so
            // rounding error does not build up over many steps.
            let t = t0 + i as f64 * dt;
            let (h, t_next) = if i + 1 == steps {
                // Final step: land exactly on `t1`.
                (t1 - t, t1)
            } else {
                (dt, t0 + (i + 1) as f64 * dt)
            };
            let mu = (self.process.drift)(t, x);
            let sigma = (self.process.diffusion)(t, x);
            let jump = self
                .process
                .jump_term
                .as_ref()
                .map_or(0.0, |jump_fn| jump_fn(t, x, h));
            x += mu * h + sigma * h.sqrt() * z + jump;
            path.push((t_next, x));
        }
        Array1::from(path)
    }

    /// Number of steps of size at most `dt` needed to cover `[t0, t1]`.
    ///
    /// Returns 0 when there is nothing to integrate. `simulate` then yields just
    /// the initial point.
    fn step_count(t0: f64, t1: f64, dt: f64) -> usize {
        let span = t1 - t0;
        // Degenerate requests: empty, reversed or NaN interval, or a negative
        // or NaN step.
        if !(span > 0.0) || dt.is_nan() || dt.is_sign_negative() {
            return 0;
        }
        // A zero step can never reach `t1`. Fail with a clear message instead
        // of trying to allocate `usize::MAX` normals.
        assert!(
            dt != 0.0,
            "ItoCalculator::simulate: dt must be non-zero when t1 > t0"
        );
        let ratio = span / dt;
        let nearest = ratio.round();
        // Absorb rounding noise such as (0.4 - 0.1) / 0.1 = 3.0000000000000004,
        // which `ceil` alone would turn into an extra step.
        let count = if (ratio - nearest).abs() <= 1e-9 * nearest.max(1.0) {
            nearest
        } else {
            ratio.ceil()
        };
        // Take at least one step: `span > 0` here even if it is tiny next to `dt`.
        count.max(1.0) as usize
    }
}