use ndarray::Array1;
use ndarray::Axis;
use rand_distr::Distribution;
use stochastic_rs_core::simd_rng::Deterministic;
use stochastic_rs_core::simd_rng::SeedExt;
use stochastic_rs_core::simd_rng::Unseeded;
use super::customjt::CustomJt;
use crate::traits::FloatExt;
use crate::traits::ProcessExt;
/// Compound process whose jump times are produced by a [`CustomJt`] generator and whose
/// jump sizes are drawn from `jumps_distribution`.
///
/// `S` selects the seeding strategy: [`Unseeded`] (the default, built by `new`) or
/// [`Deterministic`] (built by `seeded`).
pub struct CompoundCustom<
  T: FloatExt,
  D1: Distribution<T> + Send + Sync,
  D2: Distribution<T> + Send + Sync,
  S: SeedExt = Unseeded,
> {
  /// Optional path-size parameter; the constructors require `n` or `t_max` to be set.
  pub n: Option<usize>,
  /// Optional horizon (presumably the final time) — the constructors require `n` or `t_max`.
  pub t_max: Option<T>,
  /// Distribution of the individual jump sizes.
  pub jumps_distribution: D1,
  /// Jump-time distribution as supplied by the caller. `sample` in this file does not read
  /// it directly; jump times come from `customjt`.
  pub jump_times_distribution: D2,
  /// Generator of the jump-time array used by `sample`.
  pub customjt: CustomJt<T, D2>,
  /// Seeding state used to derive the RNGs in `sample`.
  pub seed: S,
}
impl<T, D1, D2> CompoundCustom<T, D1, D2>
where
  T: FloatExt,
  D1: Distribution<T> + Send + Sync,
  D2: Distribution<T> + Send + Sync,
{
  /// Builds an unseeded process from its parts.
  ///
  /// # Panics
  ///
  /// Panics with `"CompoundCustom: n or t_max must be provided"` when both `n` and `t_max`
  /// are `None`.
  pub fn new(
    n: Option<usize>,
    t_max: Option<T>,
    jumps_distribution: D1,
    jump_times_distribution: D2,
    customjt: CustomJt<T, D2>,
  ) -> Self {
    assert!(
      n.is_some() || t_max.is_some(),
      "CompoundCustom: n or t_max must be provided"
    );
    Self {
      seed: Unseeded,
      customjt,
      jump_times_distribution,
      jumps_distribution,
      t_max,
      n,
    }
  }
}
impl<T, D1, D2> CompoundCustom<T, D1, D2, Deterministic>
where
  T: FloatExt,
  D1: Distribution<T> + Send + Sync,
  D2: Distribution<T> + Send + Sync,
{
  /// Builds a process whose randomness is driven by `Deterministic::new(seed)`.
  ///
  /// # Panics
  ///
  /// Panics with `"CompoundCustom: n or t_max must be provided"` when both `n` and `t_max`
  /// are `None`.
  pub fn seeded(
    n: Option<usize>,
    t_max: Option<T>,
    jumps_distribution: D1,
    jump_times_distribution: D2,
    customjt: CustomJt<T, D2>,
    seed: u64,
  ) -> Self {
    assert!(
      n.is_some() || t_max.is_some(),
      "CompoundCustom: n or t_max must be provided"
    );
    Self {
      seed: Deterministic::new(seed),
      customjt,
      jump_times_distribution,
      jumps_distribution,
      t_max,
      n,
    }
  }
}
impl<T, D1, D2, S: SeedExt> ProcessExt<T> for CompoundCustom<T, D1, D2, S>
where
  T: FloatExt,
  D1: Distribution<T> + Send + Sync,
  D2: Distribution<T> + Send + Sync,
{
  /// `[jump_times, cumulative_jumps, jumps]`, all three of the same length.
  type Output = [Array1<T>; 3];

  /// Draws one path of the compound process.
  ///
  /// Jump times are sampled from `customjt` using `self.seed.derive()`. One jump size is then
  /// drawn from `jumps_distribution` for every jump time except index 0, whose jump stays
  /// zero. The second array is the running sum of the jump sizes.
  fn sample(&self) -> Self::Output {
    let p = self.customjt.sample_impl(&self.seed.derive());

    // Size the jump array from the sampled jump times, not from `self.n`.
    // The fill below writes one jump per entry of `p`, so any other length either
    // indexes out of bounds (when `p` is longer) or leaves trailing zero jumps that
    // are not aligned with any jump time (when `p` is shorter).
    let mut jumps = Array1::<T>::zeros(p.len());
    let mut rng = self.seed.rng();
    for jump in jumps.iter_mut().skip(1) {
      *jump = self.jumps_distribution.sample(&mut rng);
    }

    let mut cum_jumps = jumps.clone();
    cum_jumps.accumulate_axis_inplace(Axis(0), |&prev, curr| *curr += prev);
    [p, cum_jumps, jumps]
  }
}
/// Python binding for [`CompoundCustom`] driven by Python-callable distributions.
///
/// Exactly one of the two fields is `Some`; which one is chosen by the `dtype` argument of
/// the constructor.
#[cfg(feature = "python")]
#[pyo3::prelude::pyclass]
pub struct PyCompoundCustom {
  /// Double-precision process (selected unless `dtype == "f32"`).
  inner_f64:
    Option<CompoundCustom<f64, crate::traits::CallableDist<f64>, crate::traits::CallableDist<f64>>>,
  /// Single-precision process (selected when `dtype == "f32"`).
  inner_f32:
    Option<CompoundCustom<f32, crate::traits::CallableDist<f32>, crate::traits::CallableDist<f32>>>,
}
#[cfg(feature = "python")]
#[pyo3::prelude::pymethods]
impl PyCompoundCustom {
  /// Python constructor.
  ///
  /// Both distributions are wrapped in `crate::traits::CallableDist`. The single
  /// `jump_times_distribution` object backs both the process's `jump_times_distribution`
  /// field and its inner `CustomJt`; one of the two wrappers holds a `clone_ref` of it.
  /// `dtype="f32"` selects single precision (casting `t_max` to `f32`). Any other value,
  /// or `None`, selects `f64`.
  ///
  /// Panics (inside `CompoundCustom::new`) when both `n` and `t_max` are `None`.
  #[new]
  #[pyo3(signature = (jumps_distribution, jump_times_distribution, n=None, t_max=None, dtype=None))]
  fn new(
    jumps_distribution: pyo3::Py<pyo3::PyAny>,
    jump_times_distribution: pyo3::Py<pyo3::PyAny>,
    n: Option<usize>,
    t_max: Option<f64>,
    dtype: Option<&str>,
  ) -> Self {
    use crate::traits::CallableDist;

    if matches!(dtype, Some("f32")) {
      let t_max_f32 = t_max.map(|v| v as f32);
      // Two wrappers around the same Python callable: one stored on the process,
      // one handed to the jump-time generator.
      let (process_jt, generator_jt) = pyo3::Python::attach(|py| {
        (
          CallableDist::<f32>::new(jump_times_distribution.clone_ref(py)),
          CallableDist::<f32>::new(jump_times_distribution),
        )
      });
      let customjt = CustomJt::new(n, t_max_f32, generator_jt);
      return Self {
        inner_f64: None,
        inner_f32: Some(CompoundCustom::new(
          n,
          t_max_f32,
          CallableDist::new(jumps_distribution),
          process_jt,
          customjt,
        )),
      };
    }

    let (process_jt, generator_jt) = pyo3::Python::attach(|py| {
      (
        CallableDist::<f64>::new(jump_times_distribution.clone_ref(py)),
        CallableDist::<f64>::new(jump_times_distribution),
      )
    });
    let customjt = CustomJt::new(n, t_max, generator_jt);
    Self {
      inner_f32: None,
      inner_f64: Some(CompoundCustom::new(
        n,
        t_max,
        CallableDist::new(jumps_distribution),
        process_jt,
        customjt,
      )),
    }
  }

  /// Draws one path and returns `(jump_times, cumulative_jumps, jumps)` as NumPy arrays.
  ///
  /// The arrays use whichever precision the instance was constructed with.
  fn sample<'py>(
    &self,
    py: pyo3::Python<'py>,
  ) -> (
    pyo3::Py<pyo3::PyAny>,
    pyo3::Py<pyo3::PyAny>,
    pyo3::Py<pyo3::PyAny>,
  ) {
    use crate::traits::ProcessExt;
    use numpy::IntoPyArray;
    use pyo3::IntoPyObjectExt;

    match (&self.inner_f64, &self.inner_f32) {
      (Some(process), _) => {
        let [times, cumulative, jumps] = process.sample();
        (
          times.into_pyarray(py).into_py_any(py).unwrap(),
          cumulative.into_pyarray(py).into_py_any(py).unwrap(),
          jumps.into_pyarray(py).into_py_any(py).unwrap(),
        )
      }
      (None, Some(process)) => {
        let [times, cumulative, jumps] = process.sample();
        (
          times.into_pyarray(py).into_py_any(py).unwrap(),
          cumulative.into_pyarray(py).into_py_any(py).unwrap(),
          jumps.into_pyarray(py).into_py_any(py).unwrap(),
        )
      }
      // `new` always populates exactly one of the two fields.
      (None, None) => unreachable!(),
    }
  }
}