use ndarray::Array1;
use rand_distr::Distribution;
use stochastic_rs_core::simd_rng::Deterministic;
use stochastic_rs_core::simd_rng::SeedExt;
use stochastic_rs_core::simd_rng::Unseeded;
use stochastic_rs_distributions::poisson::SimdPoisson;
use crate::noise::cgns::Cgns;
use crate::traits::FloatExt;
use crate::traits::ProcessExt;
/// Heston stochastic-volatility model with Kou double-exponential jumps in the log-price.
///
/// Sampling (via `ProcessExt::sample`) returns `[s, v]`: the price path and the variance
/// path, each of length `n`, starting at `s0` (default `1`) and `max(v0, 0)`.
pub struct Hkde<T: FloatExt, S: SeedExt = Unseeded> {
    /// Drift of the log-price; the jump compensator `lambda * k_bar` is subtracted from it.
    pub mu: T,
    /// Mean-reversion speed of the variance.
    pub kappa: T,
    /// Level the variance mean-reverts towards.
    pub theta: T,
    /// Volatility of variance: scales `sqrt(v)` times the second Gaussian increment.
    pub sigma_v: T,
    /// Passed to the `Cgns` noise generator; presumably the correlation between the price
    /// and variance Gaussian drivers (`new` requires it to lie in `[-1, 1]`).
    pub rho: T,
    /// Initial variance; negative values are clamped to zero when sampling.
    pub v0: T,
    /// Jump intensity: the number of jumps per step is Poisson with mean `lambda * dt`.
    /// Zero disables jumps entirely.
    pub lambda: T,
    /// Probability that an individual jump is upward.
    pub p_up: T,
    /// Rate of the exponential distribution of upward log-jump sizes (mean `1 / eta1`).
    pub eta1: T,
    /// Rate of the exponential distribution of downward log-jump sizes (mean `1 / eta2`).
    pub eta2: T,
    /// Number of points in each sampled path, including the initial point.
    pub n: usize,
    /// Initial price; `None` means `1`.
    pub s0: Option<T>,
    /// Forwarded to `Cgns::new`; presumably the time horizon that determines `dt` — TODO confirm.
    pub t: Option<T>,
    /// Variance boundary scheme: `Some(true)` reflects negative values (`abs`); `None` or
    /// `Some(false)` truncates them to zero.
    pub use_sym: Option<bool>,
    /// Seeding strategy: `Unseeded` when built with `new`, `Deterministic` with `seeded`.
    pub seed: S,
    /// Generator of the two Gaussian increment sequences (created with length `n - 1`)
    /// that drive the price and the variance.
    cgns: Cgns<T>,
}
impl<T: FloatExt> Hkde<T> {
    /// Creates an unseeded Heston–Kou jump-diffusion model.
    ///
    /// * `mu` – log-price drift; `kappa`, `theta`, `sigma_v` – variance mean-reversion
    ///   speed, level and vol-of-vol; `rho` – correlation forwarded to the noise generator;
    ///   `v0` – initial variance.
    /// * `lambda` – jump intensity; `p_up` – probability of an upward jump; `eta1` / `eta2` –
    ///   rates of the exponential upward / downward log-jump sizes.
    /// * `n` – number of path points; `s0` – initial price (default `1`); `t` – forwarded to
    ///   the noise generator; `use_sym` – reflect (`true`) or truncate (default) the variance.
    ///
    /// # Panics
    ///
    /// Panics if `n < 2`, `rho` is outside `[-1, 1]`, `p_up` is outside `[0, 1]`,
    /// `eta1 <= 1`, `eta2 <= 0`, or `lambda < 0`. NaN values for these parameters also fail
    /// the checks.
    // The positional constructor mirrors the model's full parameter list.
    #[allow(clippy::too_many_arguments)]
    pub fn new(
        mu: T,
        kappa: T,
        theta: T,
        sigma_v: T,
        rho: T,
        v0: T,
        lambda: T,
        p_up: T,
        eta1: T,
        eta2: T,
        n: usize,
        s0: Option<T>,
        t: Option<T>,
        use_sym: Option<bool>,
    ) -> Self {
        assert!(n >= 2, "n must be at least 2");
        assert!(
            rho >= -T::one() && rho <= T::one(),
            "rho must be in [-1, 1]"
        );
        // `k_bar` weights the two jump branches by `p_up` and `1 - p_up`. Outside [0, 1] the
        // compensator no longer matches the jumps that are actually sampled.
        assert!(
            p_up >= T::zero() && p_up <= T::one(),
            "p_up must be in [0, 1]"
        );
        assert!(eta1 > T::one(), "eta1 must be > 1 for finite expectation");
        assert!(eta2 > T::zero(), "eta2 must be > 0");
        assert!(lambda >= T::zero(), "lambda must be >= 0");
        Self {
            mu,
            kappa,
            theta,
            sigma_v,
            rho,
            v0,
            lambda,
            p_up,
            eta1,
            eta2,
            n,
            s0,
            t,
            use_sym,
            seed: Unseeded,
            // One Gaussian increment per step between the `n` grid points.
            cgns: Cgns::new(rho, n - 1, t),
        }
    }
}
impl<T: FloatExt> Hkde<T, Deterministic> {
    /// Creates a Heston–Kou jump-diffusion model with a deterministic seed, so repeated
    /// sampling from identically-constructed models is reproducible.
    ///
    /// The parameters have the same meaning as in the unseeded constructor; `seed` initialises
    /// the `Deterministic` seeding strategy.
    ///
    /// # Panics
    ///
    /// Panics if `n < 2`, `rho` is outside `[-1, 1]`, `p_up` is outside `[0, 1]`,
    /// `eta1 <= 1`, `eta2 <= 0`, or `lambda < 0`. NaN values for these parameters also fail
    /// the checks. The checks are the same ones the unseeded constructor applies, so both
    /// constructors accept exactly the same parameter sets.
    // The positional constructor mirrors the model's full parameter list.
    #[allow(clippy::too_many_arguments)]
    pub fn seeded(
        mu: T,
        kappa: T,
        theta: T,
        sigma_v: T,
        rho: T,
        v0: T,
        lambda: T,
        p_up: T,
        eta1: T,
        eta2: T,
        n: usize,
        s0: Option<T>,
        t: Option<T>,
        use_sym: Option<bool>,
        seed: u64,
    ) -> Self {
        assert!(n >= 2, "n must be at least 2");
        assert!(
            rho >= -T::one() && rho <= T::one(),
            "rho must be in [-1, 1]"
        );
        // `k_bar` weights the two jump branches by `p_up` and `1 - p_up`. Outside [0, 1] the
        // compensator no longer matches the jumps that are actually sampled.
        assert!(
            p_up >= T::zero() && p_up <= T::one(),
            "p_up must be in [0, 1]"
        );
        assert!(eta1 > T::one(), "eta1 must be > 1 for finite expectation");
        assert!(eta2 > T::zero(), "eta2 must be > 0");
        // A negative intensity would hand a negative mean to the per-step Poisson sampler.
        assert!(lambda >= T::zero(), "lambda must be >= 0");
        Self {
            mu,
            kappa,
            theta,
            sigma_v,
            rho,
            v0,
            lambda,
            p_up,
            eta1,
            eta2,
            n,
            s0,
            t,
            use_sym,
            seed: Deterministic::new(seed),
            // One Gaussian increment per step between the `n` grid points.
            cgns: Cgns::new(rho, n - 1, t),
        }
    }
}
impl<T: FloatExt, S: SeedExt> Hkde<T, S> {
    /// Expected relative jump size `E[e^Y] - 1` for the Kou jump `Y` drawn by
    /// `sample_kou_jump`:
    /// `p_up * eta1 / (eta1 - 1) + (1 - p_up) * eta2 / (eta2 + 1) - 1`.
    ///
    /// Multiplied by `lambda`, this is the compensator subtracted from the log-price drift.
    #[inline]
    fn k_bar(&self) -> T {
        let one = T::one();
        let p_down = one - self.p_up;
        let up_term = self.p_up * self.eta1 / (self.eta1 - one);
        let down_term = p_down * self.eta2 / (self.eta2 + one);
        up_term + down_term - one
    }

    /// Draws one Kou log-jump.
    ///
    /// With probability `p_up` the jump is `+Exp(eta1)`; otherwise it is `-Exp(eta2)`.
    /// Exactly one uniform draw is consumed, followed by one exponential draw.
    #[inline]
    fn sample_kou_jump<R: rand::Rng + ?Sized>(&self, rng: &mut R) -> T {
        let coin: f64 = rng.random();
        let upward = coin < self.p_up.to_f64().unwrap();
        let rate = if upward { self.eta1 } else { self.eta2 };
        let size: f64 = rand_distr::Exp::new(rate.to_f64().unwrap())
            .unwrap()
            .sample(rng);
        let magnitude = T::from_f64_fast(size);
        if upward {
            magnitude
        } else {
            -magnitude
        }
    }
}
impl<T: FloatExt, S: SeedExt> ProcessExt<T> for Hkde<T, S> {
    type Output = [Array1<T>; 2];

    /// Simulates one path of the Heston–Kou model on a grid of `n` points.
    ///
    /// Returns `[price, variance]`. The scheme is an Euler step of the log-price with a
    /// compensated compound-Poisson jump term. The variance is stepped with Euler and then
    /// kept non-negative by reflection (`use_sym == Some(true)`) or truncation (otherwise).
    ///
    /// # Panics
    ///
    /// Panics if the initial price (`s0`, default `1`) is not strictly positive.
    fn sample(&self) -> Self::Output {
        let dt = self.cgns.dt();
        let [gauss_price, gauss_var] = &self.cgns.sample_impl(&self.seed.derive());

        let mut price = Array1::<T>::zeros(self.n);
        let mut var = Array1::<T>::zeros(self.n);

        let start = self.s0.unwrap_or(T::one());
        assert!(start > T::zero(), "s0 must be > 0");
        price[0] = start;
        var[0] = self.v0.max(T::zero());

        // The boundary scheme is fixed for the whole path, so resolve it once.
        let reflect = self.use_sym.unwrap_or(false);
        let bound_variance = |x: T| -> T {
            if reflect {
                x.abs()
            } else {
                x.max(T::zero())
            }
        };

        // These are loop-invariant pieces of the log-price drift.
        let compensator = self.lambda * self.k_bar();
        let half = T::from_f64_fast(0.5);

        let mut rng = self.seed.rng();
        let jump_counts = (self.lambda > T::zero())
            .then(|| SimdPoisson::<u32>::new((self.lambda * dt).to_f64().unwrap()));

        for i in 1..self.n {
            let v_prev = bound_variance(var[i - 1]);
            let sqrt_v = v_prev.sqrt();

            // Sum of this step's Kou log-jumps (zero when jumps are disabled).
            let jump_log = match &jump_counts {
                Some(pois) => {
                    let count: u32 = pois.sample(&mut rng);
                    (0..count).fold(T::zero(), |acc, _| acc + self.sample_kou_jump(&mut rng))
                }
                None => T::zero(),
            };

            let drift = (self.mu - compensator - half * v_prev) * dt;
            let log_inc = drift + sqrt_v * gauss_price[i - 1] + jump_log;
            price[i] = price[i - 1] * log_inc.exp();

            let dv = self.kappa * (self.theta - v_prev) * dt
                + self.sigma_v * sqrt_v * gauss_var[i - 1];
            var[i] = bound_variance(v_prev + dv);
        }

        [price, var]
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Unseeded model with the baseline parameter set used by the smoke tests.
    fn default_hkde() -> Hkde<f64> {
        Hkde::new(
            0.05,        // mu
            1.5,         // kappa
            0.04,        // theta
            0.3,         // sigma_v
            -0.7,        // rho
            0.04,        // v0
            0.5,         // lambda
            0.4,         // p_up
            5.0,         // eta1
            5.0,         // eta2
            256,         // n
            Some(100.0), // s0
            Some(1.0),   // t
            Some(false), // use_sym
        )
    }

    /// Seeded model sharing the baseline diffusion and jump-size parameters; only the
    /// jump intensity, up-probability, grid size, boundary scheme and seed vary.
    fn seeded_hkde(
        lambda: f64,
        p_up: f64,
        n: usize,
        use_sym: Option<bool>,
        seed: u64,
    ) -> Hkde<f64, Deterministic> {
        Hkde::seeded(
            0.05,
            1.5,
            0.04,
            0.3,
            -0.7,
            0.04,
            lambda,
            p_up,
            5.0,
            5.0,
            n,
            Some(100.0),
            Some(1.0),
            use_sym,
            seed,
        )
    }

    #[test]
    fn price_stays_positive() {
        let [prices, _] = default_hkde().sample();
        assert!(prices.iter().all(|&x| x > 0.0));
    }

    #[test]
    fn variance_non_negative() {
        let [_, variances] = default_hkde().sample();
        assert!(variances.iter().all(|&x| x >= 0.0));
    }

    #[test]
    fn no_jumps_reduces_to_heston() {
        let model = seeded_hkde(0.0, 0.5, 1000, Some(false), 42);
        let [prices, _] = model.sample();
        let final_price = prices.last().copied().unwrap();
        assert!(
            final_price > 20.0 && final_price < 500.0,
            "final={final_price}"
        );
    }

    #[test]
    fn seeded_is_deterministic() {
        let first = seeded_hkde(0.5, 0.4, 100, None, 123);
        let second = seeded_hkde(0.5, 0.4, 100, None, 123);
        let [s1, v1] = first.sample();
        let [s2, v2] = second.sample();
        assert_eq!(s1, s2);
        assert_eq!(v1, v2);
    }
}