use ndarray::Array1;
use ndarray::Array2;
use stochastic_rs_core::simd_rng::Deterministic;
use stochastic_rs_core::simd_rng::SeedExt;
use stochastic_rs_core::simd_rng::Unseeded;
use super::markov_lift::RoughSimd;
use super::rl_fbm::RlFBm;
use crate::traits::FloatExt;
use crate::traits::ProcessExt;
/// Fractional Ornstein–Uhlenbeck-type process driven by an `RlFBm` path
/// (presumably Riemann–Liouville fBm, going by the type name — confirm in `rl_fbm`).
///
/// Paths are produced by the explicit Euler step
/// `x[i] = x[i-1] + kappa * (mu - x[i-1]) * dt + sigma * (B[i] - B[i-1])`
/// with `dt = t / (n - 1)`, where `B` is a path drawn from `fbm`.
///
/// `S` selects the seeding strategy: `Unseeded` via [`RlFOU::new`],
/// `Deterministic` via [`RlFOU::seeded`].
#[derive(Clone)]
pub struct RlFOU<T: FloatExt, S: SeedExt = Unseeded> {
/// Hurst parameter; forwarded to `RlFBm::new` at construction time.
pub hurst: T,
/// Coefficient of the drift term `kappa * (mu - x)`.
pub kappa: T,
/// Level that the drift term `kappa * (mu - x)` is measured against.
pub mu: T,
/// Multiplier applied to each fBm increment.
pub sigma: T,
/// Number of points per path, including the initial value (constructors require >= 2).
pub n: usize,
/// Initial value of every path; `None` means zero.
pub x0: Option<T>,
/// Time horizon used for `dt = t / (n - 1)`; `None` means one.
pub t: Option<T>,
/// Forwarded unchanged to `RlFBm::new`; its meaning is defined there (TODO confirm).
pub degree: Option<usize>,
/// Seeding strategy; `seed.derive()` is passed to the fBm sampler on every draw.
pub seed: S,
/// fBm path generator built from `hurst`, `n`, `t` and `degree` in the constructors.
/// NOTE(review): those fields are `pub`, but mutating them afterwards does not rebuild
/// this generator, so the two can drift out of sync.
fbm: RlFBm<T>,
}
impl<T: FloatExt> RlFOU<T> {
    /// Creates a process that uses the default `Unseeded` seeding strategy.
    ///
    /// `x0` and `t` fall back to zero and one respectively when sampling;
    /// `hurst`, `n`, `t` and `degree` are also used to build the internal
    /// `RlFBm` driver right here.
    ///
    /// # Panics
    ///
    /// Panics with `"n must be at least 2"` when `n < 2`, since the step size
    /// is `t / (n - 1)`.
    #[must_use]
    pub fn new(
        hurst: T,
        kappa: T,
        mu: T,
        sigma: T,
        n: usize,
        x0: Option<T>,
        t: Option<T>,
        degree: Option<usize>,
    ) -> Self {
        assert!(n >= 2, "n must be at least 2");
        let fbm = RlFBm::new(hurst, n, t, degree);
        Self {
            fbm,
            seed: Unseeded,
            degree,
            t,
            x0,
            n,
            sigma,
            mu,
            kappa,
            hurst,
        }
    }
}
impl<T: FloatExt> RlFOU<T, Deterministic> {
    /// Creates a process whose seeding strategy is `Deterministic::new(seed)`.
    ///
    /// Every other argument is handled exactly as in [`RlFOU::new`].
    ///
    /// # Panics
    ///
    /// Panics with `"n must be at least 2"` when `n < 2`.
    #[must_use]
    pub fn seeded(
        hurst: T,
        kappa: T,
        mu: T,
        sigma: T,
        n: usize,
        x0: Option<T>,
        t: Option<T>,
        degree: Option<usize>,
        seed: u64,
    ) -> Self {
        assert!(n >= 2, "n must be at least 2");
        // Seed first, then the driver: same evaluation order as a field-by-field literal.
        let seed = Deterministic::new(seed);
        let fbm = RlFBm::new(hurst, n, t, degree);
        Self {
            fbm,
            seed,
            degree,
            t,
            x0,
            n,
            sigma,
            mu,
            kappa,
            hurst,
        }
    }
}
impl<T: FloatExt + RoughSimd, S: SeedExt> RlFOU<T, S> {
    /// Draws `m` paths at once, returned as the rows of an `(m, n)` array.
    ///
    /// A single batched fBm draw is requested with a seed derived from
    /// `self.seed`. Row `p` of the result starts at `x0` (zero when `None`).
    /// It then follows the same Euler step as `sample`, driven by row `p` of
    /// that batch:
    /// `x[i] = x[i-1] + kappa * (mu - x[i-1]) * dt + sigma * (B[i] - B[i-1])`,
    /// with `dt = t / (n - 1)` and `t` defaulting to one.
    ///
    /// # Panics
    ///
    /// Panics if the fBm batch has fewer than `m` rows or fewer than `n` columns.
    #[must_use]
    pub fn sample_batch(&self, m: usize) -> Array2<T> {
        let fbm = self.fbm.sample_batch_impl(&self.seed.derive(), m);
        let dt = self.t.unwrap_or(T::one()) / T::from_usize_(self.n - 1);
        let x0 = self.x0.unwrap_or(T::zero());
        let mut out = Array2::<T>::zeros((m, self.n));
        for p in 0..m {
            // 1-D row views avoid re-computing the 2-D offset on every access;
            // `row(p)` still panics if the fBm batch is short, as before.
            let path = fbm.row(p);
            let mut row = out.row_mut(p);
            // Carry the previous state instead of reading it back from `out`;
            // the arithmetic (and its association order) is unchanged.
            let mut prev = x0;
            row[0] = prev;
            for i in 1..self.n {
                let dfbm = path[i] - path[i - 1];
                prev = prev + self.kappa * (self.mu - prev) * dt + self.sigma * dfbm;
                row[i] = prev;
            }
        }
        out
    }
}
impl<T: FloatExt + RoughSimd, S: SeedExt> ProcessExt<T> for RlFOU<T, S> {
    type Output = Array1<T>;

    /// Draws a single path of `n` points.
    ///
    /// The path starts at `x0` (zero when `None`) and advances with the Euler
    /// step `x_new = x + kappa * (mu - x) * dt + sigma * dB`.
    /// Here `dt = t / (n - 1)`, with `t` defaulting to one.
    /// `dB` is the increment of an fBm path drawn with a seed derived from
    /// `self.seed`.
    fn sample(&self) -> Self::Output {
        let horizon = self.t.unwrap_or(T::one());
        let dt = horizon / T::from_usize_(self.n - 1);
        let driver = self.fbm.sample_impl(&self.seed.derive());
        let (kappa, mu, sigma) = (self.kappa, self.mu, self.sigma);

        let mut path = Array1::<T>::zeros(self.n);
        let mut level = self.x0.unwrap_or(T::zero());
        path[0] = level;
        for step in 1..self.n {
            let increment = driver[step] - driver[step - 1];
            level = level + kappa * (mu - level) * dt + sigma * increment;
            path[step] = level;
        }
        path
    }
}
#[cfg(test)]
mod tests {
    use super::RlFOU;
    use crate::traits::ProcessExt;

    /// Reference path for `sigma = 0`: the scheme reduces to the deterministic
    /// Euler recursion `x += kappa * (mu - x) * dt`, starting from `x0`.
    fn deterministic_euler(x0: f64, kappa: f64, mu: f64, dt: f64, n: usize) -> Vec<f64> {
        let mut path = Vec::with_capacity(n);
        let mut x = x0;
        path.push(x);
        for _ in 1..n {
            x += kappa * (mu - x) * dt;
            path.push(x);
        }
        path
    }

    #[test]
    fn fou_sigma_zero_matches_deterministic_euler() {
        let kappa = 1.3_f64;
        let mu = 0.8_f64;
        let n = 129;
        let x0 = 0.2_f64;
        let t = 1.0_f64;
        let p = RlFOU::<f64>::new(0.3, kappa, mu, 0.0, n, Some(x0), Some(t), None);
        let x = p.sample();
        let expected = deterministic_euler(x0, kappa, mu, t / (n as f64 - 1.0), n);
        assert_eq!(x.len(), n);
        for i in 1..n {
            assert!((x[i] - expected[i]).abs() < 1e-12, "mismatch at {i}");
        }
    }

    #[test]
    fn sample_batch_sigma_zero_matches_deterministic_euler() {
        let kappa = 1.3_f64;
        let mu = 0.8_f64;
        let n = 65;
        let m = 3;
        let x0 = 0.2_f64;
        let t = 1.0_f64;
        let p = RlFOU::<f64>::new(0.3, kappa, mu, 0.0, n, Some(x0), Some(t), None);
        let paths = p.sample_batch(m);
        assert_eq!(paths.dim(), (m, n));
        let expected = deterministic_euler(x0, kappa, mu, t / (n as f64 - 1.0), n);
        for (row_idx, row) in paths.outer_iter().enumerate() {
            for (i, (&got, &want)) in row.iter().zip(&expected).enumerate() {
                assert!(
                    (got - want).abs() < 1e-12,
                    "mismatch at path {row_idx}, step {i}"
                );
            }
        }
    }

    #[test]
    #[should_panic(expected = "n must be at least 2")]
    fn new_rejects_fewer_than_two_points() {
        let _ = RlFOU::<f64>::new(0.3, 1.0, 0.0, 0.2, 1, None, None, None);
    }

    #[test]
    fn finite_output_at_typical_rfsv_parameters() {
        let p = RlFOU::seeded(
            0.1_f64,
            2.0,
            0.15_f64.ln(),
            0.25,
            512,
            Some(0.15_f64.ln()),
            Some(1.0),
            None,
            7,
        );
        let x = p.sample();
        assert_eq!(x.len(), 512);
        assert!(x.iter().all(|v| v.is_finite()));
    }
}