use ndarray::Array1;
use stochastic_rs_core::simd_rng::Deterministic;
use stochastic_rs_core::simd_rng::SeedExt;
use stochastic_rs_core::simd_rng::Unseeded;
use crate::noise::cgns::Cgns;
use crate::traits::FloatExt;
use crate::traits::Fn1D;
use crate::traits::ProcessExt;
/// Two-factor Hull–White short-rate process simulated on a uniform time grid.
///
/// The state consists of the factor `x` and an auxiliary factor `u`, both driven by the
/// two noise paths produced by the internal `cgns` generator (see `ProcessExt::sample`).
///
/// Type parameters:
/// * `T` – floating-point type of the simulation.
/// * `S` – seeding strategy: `Unseeded` (built via `new`) or `Deterministic` (built via `seeded`).
pub struct HullWhite2F<T: FloatExt, S: SeedExt = Unseeded> {
/// Time-dependent drift term of `x`; called with a grid time during sampling.
pub k: Fn1D<T>,
/// Coefficient multiplying `-x` in the drift of `x` (mean-reversion speed of `x`).
pub theta: T,
/// Scale applied to the first noise path in the update of `x`.
pub sigma1: T,
/// Scale applied to the second noise path in the update of `u`.
pub sigma2: T,
/// Forwarded to the `Cgns` noise generator (presumably the correlation between the two noises — confirm).
pub rho: T,
/// Coefficient multiplying `-u` in the drift of `u` (mean-reversion speed of `u`).
pub b: T,
/// Initial value of `x`; `None` starts the path at zero.
pub x0: Option<T>,
/// Forwarded to the `Cgns` noise generator (presumably the time horizon — confirm).
pub t: Option<T>,
/// Number of points in each sampled path, including the initial value.
pub n: usize,
/// Seeding strategy marker / state (a `Deterministic` built from the user seed in `seeded`).
pub seed: S,
/// Noise generator built with `n - 1` steps; supplies the step size `dt` and both noise paths.
cgns: Cgns<T, S>,
}
impl<T: FloatExt> HullWhite2F<T> {
    /// Creates an unseeded two-factor Hull–White process.
    ///
    /// # Arguments
    /// * `k` – time-dependent drift term of `x`.
    /// * `theta` – mean-reversion coefficient of `x` (multiplies `-x` in its drift).
    /// * `sigma1` – scale of the first noise path, which drives `x`.
    /// * `sigma2` – scale of the second noise path, which drives `u`.
    /// * `rho` – forwarded to the correlated noise generator (presumably the noise correlation).
    /// * `b` – mean-reversion coefficient of `u` (multiplies `-u` in its drift).
    /// * `x0` – initial value of `x`; `None` means zero.
    /// * `t` – forwarded unchanged to the noise generator (presumably the time horizon).
    /// * `n` – number of points per sampled path, including the initial value.
    ///
    /// # Panics
    /// Panics if `n == 0`. The noise generator is sized with `n - 1` steps, and that
    /// subtraction would otherwise underflow: a debug-mode overflow panic, or a silent wrap to
    /// `usize::MAX` in release builds.
    pub fn new(
        k: impl Into<Fn1D<T>>,
        theta: T,
        sigma1: T,
        sigma2: T,
        rho: T,
        b: T,
        x0: Option<T>,
        t: Option<T>,
        n: usize,
    ) -> Self {
        assert!(n >= 1, "HullWhite2F::new: `n` must be at least 1");
        Self {
            k: k.into(),
            theta,
            sigma1,
            sigma2,
            rho,
            b,
            x0,
            t,
            n,
            seed: Unseeded,
            // `n` grid points need `n - 1` noise increments.
            cgns: Cgns::new(rho, n - 1, t),
        }
    }
}
impl<T: FloatExt> HullWhite2F<T, Deterministic> {
    /// Creates a reproducibly seeded two-factor Hull–White process.
    ///
    /// The arguments `k` through `n` have the same meaning as in `HullWhite2F::new`.
    /// `seed` is the master seed. It is stored as a `Deterministic`. The noise generator is
    /// seeded from a child value derived from that master seed.
    ///
    /// # Panics
    /// Panics if `n == 0`. The noise generator is sized with `n - 1` steps, and that
    /// subtraction would otherwise underflow.
    pub fn seeded(
        k: impl Into<Fn1D<T>>,
        theta: T,
        sigma1: T,
        sigma2: T,
        rho: T,
        b: T,
        x0: Option<T>,
        t: Option<T>,
        n: usize,
        seed: u64,
    ) -> Self {
        assert!(n >= 1, "HullWhite2F::seeded: `n` must be at least 1");
        let s = Deterministic::new(seed);
        // The noise generator gets a derived child seed rather than the raw master seed.
        let child = s.derive();
        Self {
            k: k.into(),
            theta,
            sigma1,
            sigma2,
            rho,
            b,
            x0,
            t,
            n,
            seed: Deterministic::new(seed),
            // `n` grid points need `n - 1` noise increments.
            cgns: Cgns::seeded(rho, n - 1, t, child.current()),
        }
    }
}
impl<T: FloatExt, S: SeedExt> ProcessExt<T> for HullWhite2F<T, S> {
    type Output = [Array1<T>; 2];

    /// Simulates one path of the factor `x` and the auxiliary factor `u`.
    ///
    /// Uses an explicit Euler step on the grid `t_i = i * dt`, where `dt` comes from the noise
    /// generator:
    ///
    /// ```text
    /// x_i = x_{i-1} + (k(t_{i-1}) + u_{i-1} - theta * x_{i-1}) * dt + sigma1 * e1_{i-1}
    /// u_i = u_{i-1} - b * u_{i-1} * dt                               + sigma2 * e2_{i-1}
    /// ```
    ///
    /// `e1` and `e2` are the two noise paths from `cgns`. They are presumably correlated
    /// Brownian increments; confirm against `Cgns`.
    ///
    /// `x` starts at `x0` (zero when `None`) and `u` starts at zero. Returns `[x, u]`, each of
    /// length `n`. If `n == 0`, returns two empty arrays instead of panicking on `x[0]`.
    fn sample(&self) -> Self::Output {
        if self.n == 0 {
            return [Array1::<T>::zeros(0), Array1::<T>::zeros(0)];
        }

        let dt = self.cgns.dt();
        let [cgn1, cgn2] = &self.cgns.sample();
        let mut x = Array1::<T>::zeros(self.n);
        let mut u = Array1::<T>::zeros(self.n);
        x[0] = self.x0.unwrap_or(T::zero());

        for i in 1..self.n {
            // Explicit Euler evaluates every drift term at the left end of the step. The state
            // values x[i-1] and u[i-1] live at t_{i-1}, so `k` must be evaluated there too, not
            // at t_i.
            let t_prev = T::from_usize_(i - 1) * dt;
            x[i] = x[i - 1]
                + (self.k.call(t_prev) + u[i - 1] - self.theta * x[i - 1]) * dt
                + self.sigma1 * cgn1[i - 1];
            u[i] = u[i - 1] - self.b * u[i - 1] * dt + self.sigma2 * cgn2[i - 1];
        }

        [x, u]
    }
}
#[cfg(feature = "python")]
/// Python-facing wrapper around `HullWhite2F<f64>`.
///
/// Exactly one of the two fields is populated by the constructor:
/// * `inner` when no seed was given;
/// * `seeded` when a seed was given.
#[pyo3::prelude::pyclass]
pub struct PyHullWhite2F {
    inner: Option<HullWhite2F<f64>>,
    // Use the same `Deterministic` type that `HullWhite2F::seeded` produces (imported at the
    // top of this file), rather than a second path to it.
    seeded: Option<HullWhite2F<f64, Deterministic>>,
}
#[cfg(feature = "python")]
#[pyo3::prelude::pymethods]
impl PyHullWhite2F {
    /// Python constructor.
    ///
    /// Builds a seeded process when `seed` is given and an unseeded one otherwise. `k` is a
    /// Python callable used as the time-dependent drift term.
    ///
    /// Raises `ValueError` when `n == 0`.
    #[new]
    #[pyo3(signature = (k, theta, sigma1, sigma2, rho, b, n, x0=None, t=None, seed=None))]
    fn new(
        k: pyo3::Py<pyo3::PyAny>,
        theta: f64,
        sigma1: f64,
        sigma2: f64,
        rho: f64,
        b: f64,
        n: usize,
        x0: Option<f64>,
        t: Option<f64>,
        seed: Option<u64>,
    ) -> pyo3::PyResult<Self> {
        // The Rust constructors compute `n - 1`, so `n == 0` would panic there. Across the FFI
        // that panic becomes a `PanicException`, which derives from `BaseException` and is not
        // caught by `except Exception`. Reject it up front with a regular `ValueError` instead.
        if n == 0 {
            return Err(pyo3::exceptions::PyValueError::new_err(
                "n must be at least 1",
            ));
        }
        Ok(match seed {
            Some(s) => Self {
                inner: None,
                seeded: Some(HullWhite2F::seeded(
                    Fn1D::Py(k),
                    theta,
                    sigma1,
                    sigma2,
                    rho,
                    b,
                    x0,
                    t,
                    n,
                    s,
                )),
            },
            None => Self {
                inner: Some(HullWhite2F::new(
                    Fn1D::Py(k),
                    theta,
                    sigma1,
                    sigma2,
                    rho,
                    b,
                    x0,
                    t,
                    n,
                )),
                seeded: None,
            },
        })
    }

    /// Samples one `(x, u)` path pair and returns it as two NumPy arrays.
    fn sample<'py>(&self, py: pyo3::Python<'py>) -> (pyo3::Py<pyo3::PyAny>, pyo3::Py<pyo3::PyAny>) {
        use numpy::IntoPyArray;
        use pyo3::IntoPyObjectExt;
        use crate::traits::ProcessExt;
        py_dispatch_f64!(self, |inner| {
            let [a, b] = inner.sample();
            (
                a.into_pyarray(py).into_py_any(py).unwrap(),
                b.into_pyarray(py).into_py_any(py).unwrap(),
            )
        })
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Time-independent drift term used by the tests.
    fn const_k(_t: f64) -> f64 {
        0.5
    }

    #[test]
    fn sample_returns_two_paths() {
        const POINTS: usize = 64;
        let drift: fn(f64) -> f64 = const_k;
        let process = HullWhite2F::<f64>::new(
            drift,
            0.04,
            0.01,
            0.005,
            -0.3,
            0.4,
            Some(0.04),
            Some(1.0),
            POINTS,
        );

        // Both returned paths (x and u) must span the full grid.
        let paths = process.sample();
        for path in &paths {
            assert_eq!(path.len(), POINTS);
        }
    }
}