use ndarray::Array1;
use stochastic_rs_core::simd_rng::Deterministic;
use stochastic_rs_core::simd_rng::SeedExt;
use stochastic_rs_core::simd_rng::Unseeded;
use stochastic_rs_distributions::normal::SimdNormal;
use crate::traits::FloatExt;
use crate::traits::ProcessExt;
/// EGARCH(p, q) process parameters.
///
/// The sampler in this file evolves the log-variance as
/// `ln sigma_t^2 = omega + sum_i [alpha_i (|z_{t-i}| - E|z|) + gamma_i z_{t-i}]
/// + sum_j beta_j ln sigma_{t-j}^2` and returns `x_t = sigma_t * z_t`, where
/// `z_t` are draws from a normal distribution with mean 0 and std 1.
///
/// `S` selects the seed source handed to the normal sampler.
#[derive(Debug, Clone)]
pub struct Egarch<T, S = Unseeded>
where
  T: FloatExt,
  S: SeedExt,
{
  /// Intercept of the log-variance recursion; also used as the initial
  /// log-variance at `t = 0`.
  pub omega: T,
  /// Coefficients on the centered absolute standardized shock
  /// `|z| - E|z|`, one per lag (its length is `p`).
  pub alpha: Array1<T>,
  /// Coefficients on the signed standardized shock `z`, one per lag; paired
  /// index-by-index with `alpha`.
  pub gamma: Array1<T>,
  /// Coefficients on the lagged log-variance, one per lag (its length is `q`).
  pub beta: Array1<T>,
  /// Number of observations produced by one call to `sample`.
  pub n: usize,
  /// Seed source passed to the normal sampler.
  pub seed: S,
}
impl<T: FloatExt> Egarch<T> {
  /// Creates an EGARCH process that uses the `Unseeded` seed source.
  ///
  /// `alpha` and `gamma` are consumed pairwise per lag by the sampler, so
  /// they must have the same number of entries.
  ///
  /// # Panics
  ///
  /// Panics if `alpha.len() != gamma.len()`.
  pub fn new(omega: T, alpha: Array1<T>, gamma: Array1<T>, beta: Array1<T>, n: usize) -> Self {
    let alpha_lags = alpha.len();
    let gamma_lags = gamma.len();
    assert!(
      alpha_lags == gamma_lags,
      "Egarch requires alpha.len() == gamma.len()"
    );

    let seed = Unseeded;
    Self {
      omega,
      alpha,
      gamma,
      beta,
      n,
      seed,
    }
  }
}
impl<T: FloatExt> Egarch<T, Deterministic> {
  /// Creates an EGARCH process whose seed source is `Deterministic::new(seed)`.
  ///
  /// Takes the same model parameters as the unseeded constructor plus the
  /// `u64` seed value.
  ///
  /// # Panics
  ///
  /// Panics if `alpha.len() != gamma.len()`.
  pub fn seeded(
    omega: T,
    alpha: Array1<T>,
    gamma: Array1<T>,
    beta: Array1<T>,
    n: usize,
    seed: u64,
  ) -> Self {
    let alpha_lags = alpha.len();
    let gamma_lags = gamma.len();
    assert!(
      alpha_lags == gamma_lags,
      "Egarch requires alpha.len() == gamma.len()"
    );

    // Shadow the raw integer with the seed source built from it.
    let seed = Deterministic::new(seed);
    Self {
      omega,
      alpha,
      gamma,
      beta,
      n,
      seed,
    }
  }
}
impl<T: FloatExt, S: SeedExt> ProcessExt<T> for Egarch<T, S> {
  type Output = Array1<T>;

  /// Simulates one path of length `n` from the EGARCH recursion.
  ///
  /// Innovations `z_t` are drawn from a normal distribution with mean 0 and
  /// std 1 using the configured seed source. The log-variance starts at
  /// `omega` for `t = 0` and afterwards follows
  ///
  /// `ln sigma_t^2 = omega
  ///   + sum_{i=1..p} [alpha_i (|e_{t-i}| - sqrt(2/pi)) + gamma_i e_{t-i}]
  ///   + sum_{j=1..q} beta_j ln sigma_{t-j}^2`
  ///
  /// where `e_s = x_s / sigma_s` is the standardized residual, `p` is
  /// `alpha.len()` and `q` is `beta.len()`. Lags that would reach before
  /// `t = 0` are skipped. The returned array holds `x_t = sigma_t * z_t`.
  ///
  /// # Panics
  ///
  /// Panics if any log-variance is non-finite, if any `sigma_t` is
  /// non-finite or not strictly positive, or if `gamma` has fewer entries
  /// than `alpha` (the fields are public and can be changed after
  /// construction).
  fn sample(&self) -> Self::Output {
    let p = self.alpha.len();
    let q = self.beta.len();

    let mut z = Array1::<T>::zeros(self.n);
    if self.n > 0 {
      let slice = z.as_slice_mut().expect("contiguous");
      let normal = SimdNormal::<T>::from_seed_source(T::zero(), T::one(), &self.seed);
      normal.fill_slice_fast(slice);
    }

    let mut x = Array1::<T>::zeros(self.n);
    let mut log_sigma2 = Array1::<T>::zeros(self.n);
    // Standardized residuals x_t / sigma_t, stored once per step. Recomputing
    // them per lag costs an extra exp + sqrt + division for every (t, i)
    // pair; the cached value is the result of the very same expression, so
    // the simulated path is unchanged.
    let mut std_resid = Array1::<T>::zeros(self.n);
    // E|z| for a standard normal variable: sqrt(2 / pi).
    let e_abs_z = T::from_f64_fast((2.0_f64 / std::f64::consts::PI).sqrt());

    for t in 0..self.n {
      let log_var = if t == 0 {
        self.omega
      } else {
        // Only lags that stay inside the already simulated history
        // contribute, hence the `min(t)` upper bounds.
        let mut shock_term = T::zero();
        for i in 1..=p.min(t) {
          let z_lag = std_resid[t - i];
          shock_term += self.alpha[i - 1] * (z_lag.abs() - e_abs_z) + self.gamma[i - 1] * z_lag;
        }

        let mut persistence_term = T::zero();
        for j in 1..=q.min(t) {
          persistence_term += self.beta[j - 1] * log_sigma2[t - j];
        }

        self.omega + shock_term + persistence_term
      };
      log_sigma2[t] = log_var;

      assert!(
        log_var.is_finite(),
        "Egarch produced non-finite log-variance at t={}",
        t
      );

      let sigma_t = (log_var.exp()).sqrt();
      assert!(
        sigma_t.is_finite() && sigma_t > T::zero(),
        "Egarch produced non-positive or non-finite sigma at t={}",
        t
      );

      x[t] = sigma_t * z[t];
      // Deliberately x_t / sigma_t rather than z_t: this reproduces the
      // exact rounding of the value the recursion has always consumed.
      std_resid[t] = x[t] / sigma_t;
    }

    x
  }
}
// Invokes the `py_process_1d!` macro (defined outside this file) with the
// wrapper name `PyEgarch` and the process type `Egarch`.
// NOTE(review): presumably this generates the Python-facing wrapper; `sig`
// looks like the Python call signature (with optional `seed` and `dtype`) and
// `params` like the Rust-side argument types — confirm against the macro
// definition.
// NOTE(review): the trailing underscore in `gamma_` is presumably there to
// avoid a name clash on the Python or macro side — confirm before renaming.
py_process_1d!(PyEgarch, Egarch,
sig: (omega, alpha, gamma_, beta, n, seed=None, dtype=None),
params: (omega: f64, alpha: Vec<f64>, gamma_: Vec<f64>, beta: Vec<f64>, n: usize)
);