use ndarray::Array1;
use ndarray::s;
use stochastic_rs_core::simd_rng::Deterministic;
use stochastic_rs_core::simd_rng::SeedExt;
use stochastic_rs_core::simd_rng::Unseeded;
use stochastic_rs_distributions::normal::SimdNormal;
use crate::traits::FloatExt;
use crate::traits::Fn1D;
use crate::traits::ProcessExt;
/// Hull–White short-rate process `dx = (theta(t) - alpha * x) dt + sigma dW`,
/// simulated on a uniform time grid by [`ProcessExt::sample`].
///
/// `S` selects the randomness source: [`Unseeded`] by default, or
/// [`Deterministic`] when built with `HullWhite::seeded`.
pub struct HullWhite<T: FloatExt, S: SeedExt = Unseeded> {
/// Time-dependent drift level, called with the grid time inside each Euler step.
pub theta: Fn1D<T>,
/// Mean-reversion coefficient; multiplies the current state in the drift.
pub alpha: T,
/// Diffusion coefficient applied to each Gaussian increment.
pub sigma: T,
/// Number of points in a sampled path (including the initial value).
pub n: usize,
/// Initial value of the path; `None` starts the path at zero.
pub x0: Option<T>,
/// Time horizon of the path; `None` uses a horizon of one.
pub t: Option<T>,
/// Seed source handed to the normal generator on every `sample` call.
pub seed: S,
}
impl<T: FloatExt> HullWhite<T> {
    /// Creates an unseeded Hull–White process.
    ///
    /// * `theta` – drift level as a function of time (anything convertible into [`Fn1D`]).
    /// * `alpha` – coefficient multiplying the current state in the drift.
    /// * `sigma` – diffusion coefficient.
    /// * `n` – number of points in each sampled path.
    /// * `x0` – initial value; `None` means the path starts at zero.
    /// * `t` – time horizon; `None` means a horizon of one.
    pub fn new(
        theta: impl Into<Fn1D<T>>,
        alpha: T,
        sigma: T,
        n: usize,
        x0: Option<T>,
        t: Option<T>,
    ) -> Self {
        let theta = theta.into();
        Self {
            theta,
            seed: Unseeded,
            alpha,
            sigma,
            n,
            x0,
            t,
        }
    }
}
impl<T: FloatExt> HullWhite<T, Deterministic> {
    /// Creates a Hull–White process whose randomness comes from a
    /// [`Deterministic`] source built from `seed`.
    ///
    /// All other parameters have the same meaning as in `HullWhite::new`.
    pub fn seeded(
        theta: impl Into<Fn1D<T>>,
        alpha: T,
        sigma: T,
        n: usize,
        x0: Option<T>,
        t: Option<T>,
        seed: u64,
    ) -> Self {
        let theta = theta.into();
        let seed = Deterministic::new(seed);
        Self {
            theta,
            seed,
            alpha,
            sigma,
            n,
            x0,
            t,
        }
    }
}
impl<T: FloatExt, S: SeedExt> ProcessExt<T> for HullWhite<T, S> {
    type Output = Array1<T>;

    /// Simulates one Hull–White path with the explicit Euler–Maruyama scheme.
    ///
    /// The path has `n` points on the uniform grid `t_i = i * dt`, where
    /// `dt = t / (n - 1)` (`t` defaults to `1`). The first point is `x0`
    /// (default `0`), and each subsequent point is
    ///
    /// `x_{k+1} = x_k + (theta(t_k) - alpha * x_k) * dt + sigma * dW_k`.
    ///
    /// Both drift terms are evaluated at the left endpoint `t_k` of the step.
    ///
    /// Returns an empty array when `n == 0` and `[x0]` when `n == 1`.
    fn sample(&self) -> Self::Output {
        let mut hw = Array1::<T>::zeros(self.n);
        if self.n == 0 {
            return hw;
        }
        hw[0] = self.x0.unwrap_or(T::zero());
        if self.n == 1 {
            return hw;
        }

        let n_increments = self.n - 1;
        let dt = self.t.unwrap_or(T::one()) / T::from_usize_(n_increments);
        let sqrt_dt = dt.sqrt();

        let mut prev = hw[0];
        let mut tail_view = hw.slice_mut(s![1..]);
        // A freshly allocated 1-D array is contiguous, so its tail slice is too.
        let tail = tail_view
            .as_slice_mut()
            .expect("HullWhite output tail must be contiguous");

        // Fill the tail with Gaussian increments (scale sqrt(dt); assumes the
        // constructor takes (mean, std_dev)), then overwrite each increment in
        // place with the state it drives.
        let normal = SimdNormal::<T>::from_seed_source(T::zero(), sqrt_dt, &self.seed);
        normal.fill_slice_fast(tail);

        for (k, dw) in tail.iter_mut().enumerate() {
            // `prev` is x(t_k); evaluate theta at the same left endpoint t_k = k * dt
            // so the drift is not taken from the end of the step.
            let t_k = T::from_usize_(k) * dt;
            let drift = (self.theta.call(t_k) - self.alpha * prev) * dt;
            let next = prev + drift + self.sigma * *dw;
            *dw = next;
            prev = next;
        }
        hw
    }
}
/// Python-facing wrapper around an `f64` [`HullWhite`] process.
///
/// The constructor populates exactly one of the two slots: `inner` when no
/// seed is given, `seeded` when a seed is supplied.
#[cfg(feature = "python")]
#[pyo3::prelude::pyclass]
pub struct PyHullWhite {
    inner: Option<HullWhite<f64>>,
    // Same `Deterministic` type that `HullWhite::seeded` produces (imported at file top).
    seeded: Option<HullWhite<f64, Deterministic>>,
}
#[cfg(feature = "python")]
#[pyo3::prelude::pymethods]
impl PyHullWhite {
    #[new]
    #[pyo3(signature = (theta, alpha, sigma, n, x0=None, t=None, seed=None))]
    fn new(
        theta: pyo3::Py<pyo3::PyAny>,
        alpha: f64,
        sigma: f64,
        n: usize,
        x0: Option<f64>,
        t: Option<f64>,
        seed: Option<u64>,
    ) -> Self {
        // Wrap the Python callable once; the seed decides which slot gets the process.
        let theta = Fn1D::Py(theta);
        if let Some(s) = seed {
            Self {
                inner: None,
                seeded: Some(HullWhite::seeded(theta, alpha, sigma, n, x0, t, s)),
            }
        } else {
            Self {
                inner: Some(HullWhite::new(theta, alpha, sigma, n, x0, t)),
                seeded: None,
            }
        }
    }

    fn sample<'py>(&self, py: pyo3::Python<'py>) -> pyo3::Py<pyo3::PyAny> {
        use crate::traits::ProcessExt;
        use numpy::IntoPyArray;
        use pyo3::IntoPyObjectExt;
        // Dispatch to whichever of `inner` / `seeded` is populated and hand the
        // sampled path back to Python as a NumPy array.
        py_dispatch_f64!(self, |inner| inner
            .sample()
            .into_pyarray(py)
            .into_py_any(py)
            .unwrap())
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    fn const_theta(_t: f64) -> f64 {
        0.04
    }

    /// Drift level equal to `alpha * x0` for `alpha = 0.5`, `x0 = 0.04`,
    /// so the deterministic part of every Euler step is exactly zero.
    fn equilibrium_theta(_t: f64) -> f64 {
        0.02
    }

    #[test]
    fn sample_length_matches_n() {
        let hw = HullWhite::<f64>::new(
            const_theta as fn(f64) -> f64,
            0.5,
            0.01,
            64,
            Some(0.04),
            Some(1.0),
        );
        let path = hw.sample();
        assert_eq!(path.len(), 64);
    }

    #[test]
    fn empty_path_when_n_is_zero() {
        let hw = HullWhite::<f64>::new(
            const_theta as fn(f64) -> f64,
            0.5,
            0.01,
            0,
            Some(0.04),
            Some(1.0),
        );
        assert!(hw.sample().is_empty());
    }

    #[test]
    fn single_point_path_is_x0() {
        let hw = HullWhite::<f64>::new(
            const_theta as fn(f64) -> f64,
            0.5,
            0.01,
            1,
            Some(0.04),
            Some(1.0),
        );
        let path = hw.sample();
        assert_eq!(path.len(), 1);
        assert_eq!(path[0], 0.04);
    }

    #[test]
    fn missing_x0_starts_at_zero() {
        let hw = HullWhite::<f64>::new(const_theta as fn(f64) -> f64, 0.5, 0.01, 8, None, None);
        let path = hw.sample();
        assert_eq!(path.len(), 8);
        assert_eq!(path[0], 0.0);
    }

    #[test]
    fn zero_volatility_at_equilibrium_stays_constant() {
        // sigma = 0 removes the noise and theta == alpha * x0 (exactly, since
        // 0.02 and 0.04 differ by a power of two) removes the drift.
        let hw = HullWhite::<f64>::new(
            equilibrium_theta as fn(f64) -> f64,
            0.5,
            0.0,
            32,
            Some(0.04),
            Some(1.0),
        );
        let path = hw.sample();
        assert_eq!(path.len(), 32);
        assert!(path.iter().all(|&x| x == 0.04));
    }
}