use ndarray::Array1;
use rand_distr::Distribution;
use stochastic_rs_core::simd_rng::Deterministic;
use stochastic_rs_core::simd_rng::SeedExt;
use stochastic_rs_core::simd_rng::Unseeded;
use crate::traits::FloatExt;
use crate::traits::ProcessExt;
/// Process whose path is a running sum of successive draws from a user-supplied distribution.
///
/// Sampling runs in one of two modes (see `sample_impl`):
/// * fixed size, when `n` is set: the path has exactly `n` points, starting at zero;
/// * horizon, when only `t_max` is set: draws are added until the running sum reaches `t_max`.
///
/// `S` selects how randomness is seeded. It defaults to `Unseeded`; `CustomJt::seeded`
/// builds a `Deterministic` variant from a `u64`.
pub struct CustomJt<T: FloatExt, D: Distribution<T> + Send + Sync, S: SeedExt = Unseeded> {
    /// Number of points per path (the leading zero included). Takes precedence over `t_max`.
    pub n: Option<usize>,
    /// Horizon used when `n` is `None`: the path stops at the first partial sum `>= t_max`.
    pub t_max: Option<T>,
    /// Distribution of the increment between consecutive path points.
    pub distribution: D,
    /// Seed strategy handed to `sample_impl` by `ProcessExt::sample`.
    pub seed: S,
}
/// Require that at least one of `n` (path length) or `t_max` (horizon) is provided.
///
/// `type_name` is used as the prefix of the panic message, so the calling constructor
/// can be identified.
///
/// # Panics
/// When both `n` and `t_max` are `None`.
#[inline]
fn validate_n_or_tmax<T>(n: Option<usize>, t_max: Option<T>, type_name: &'static str)
where
    T: FloatExt,
{
    let has_length_or_horizon = n.is_some() || t_max.is_some();
    assert!(has_length_or_horizon, "{type_name}: n or t_max must be provided");
}
impl<T, D> CustomJt<T, D>
where
T: FloatExt,
D: Distribution<T> + Send + Sync,
{
pub fn new(n: Option<usize>, t_max: Option<T>, distribution: D) -> Self {
validate_n_or_tmax(n, t_max, "CustomJt");
CustomJt {
n,
t_max,
distribution,
seed: Unseeded,
}
}
}
impl<T: FloatExt, D: Distribution<T> + Send + Sync> CustomJt<T, D, Deterministic> {
    /// Build a process whose randomness comes from `Deterministic::new(seed)`.
    ///
    /// # Panics
    /// If both `n` and `t_max` are `None`.
    pub fn seeded(n: Option<usize>, t_max: Option<T>, distribution: D, seed: u64) -> Self {
        validate_n_or_tmax(n, t_max, "CustomJt");
        let seed = Deterministic::new(seed);
        Self {
            seed,
            distribution,
            t_max,
            n,
        }
    }
}
/// Python handle around a `CustomJt` whose distribution is a Python callable.
///
/// The constructor fills exactly one of the two fields, depending on the requested dtype.
#[cfg(feature = "python")]
#[pyo3::prelude::pyclass]
pub struct PyCustomJt {
    /// Double-precision sampler (the default dtype).
    inner_f64: Option<CustomJt<f64, crate::traits::CallableDist<f64>>>,
    /// Single-precision sampler, used when `dtype == "f32"`.
    inner_f32: Option<CustomJt<f32, crate::traits::CallableDist<f32>>>,
}
#[cfg(feature = "python")]
#[pyo3::prelude::pymethods]
impl PyCustomJt {
    /// Construct the wrapper.
    ///
    /// * `distribution` is wrapped in `CallableDist` and used as the increment distribution.
    /// * `n` and `t_max` are forwarded to `CustomJt::new`; at least one must be given.
    /// * `dtype` `"f32"` selects single precision. Any other value, or `None`, selects `f64`.
    ///
    /// # Panics
    /// If both `n` and `t_max` are `None` (checked inside `CustomJt::new`).
    #[new]
    #[pyo3(signature = (distribution, n=None, t_max=None, dtype=None))]
    fn new(
        distribution: pyo3::Py<pyo3::PyAny>,
        n: Option<usize>,
        t_max: Option<f64>,
        dtype: Option<&str>,
    ) -> Self {
        match dtype.unwrap_or("f64") {
            "f32" => Self {
                inner_f32: Some(CustomJt::new(
                    n,
                    // Narrowing cast; precision loss is accepted for the f32 variant.
                    t_max.map(|v| v as f32),
                    crate::traits::CallableDist::new(distribution),
                )),
                inner_f64: None,
            },
            // NOTE(review): any unrecognised dtype string (e.g. "float32") silently maps to f64.
            _ => Self {
                inner_f32: None,
                inner_f64: Some(CustomJt::new(
                    n,
                    t_max,
                    crate::traits::CallableDist::new(distribution),
                )),
            },
        }
    }

    /// Draw a single path as a 1-D NumPy array of the configured dtype.
    fn sample<'py>(&self, py: pyo3::Python<'py>) -> pyo3::Py<pyo3::PyAny> {
        use numpy::IntoPyArray;
        use pyo3::IntoPyObjectExt;
        use crate::traits::ProcessExt;
        if let Some(ref inner) = self.inner_f64 {
            inner.sample().into_pyarray(py).into_py_any(py).unwrap()
        } else if let Some(ref inner) = self.inner_f32 {
            inner.sample().into_pyarray(py).into_py_any(py).unwrap()
        } else {
            unreachable!("PyCustomJt::new always sets exactly one inner sampler")
        }
    }

    /// Draw `m` paths.
    ///
    /// If every path has the same length, the result is a 2-D NumPy array of shape
    /// `(m, len)`. This is always the case when `n` was given. In `t_max` mode each path
    /// stops after a random number of jumps, so lengths generally differ. A 2-D array
    /// cannot hold them, and the result is then a Python `list` of 1-D arrays.
    /// `m == 0` yields an empty `(0, 0)` array.
    fn sample_par<'py>(&self, py: pyo3::Python<'py>, m: usize) -> pyo3::Py<pyo3::PyAny> {
        use numpy::IntoPyArray;
        use numpy::ndarray::Array2;
        use pyo3::IntoPyObjectExt;
        use crate::traits::ProcessExt;
        if let Some(ref inner) = self.inner_f64 {
            let paths = inner.sample_par(m);
            // `first()` rather than `paths[0]`: there are no paths when `m == 0`.
            let n = paths.first().map_or(0, |p| p.len());
            if paths.iter().all(|p| p.len() == n) {
                let mut result = Array2::<f64>::zeros((paths.len(), n));
                for (i, path) in paths.iter().enumerate() {
                    result.row_mut(i).assign(path);
                }
                result.into_pyarray(py).into_py_any(py).unwrap()
            } else {
                // Ragged paths: `assign` into fixed-width rows would panic on the shape mismatch.
                let arrays: Vec<pyo3::Py<pyo3::PyAny>> = paths
                    .into_iter()
                    .map(|p| p.into_pyarray(py).into_py_any(py).unwrap())
                    .collect();
                arrays.into_py_any(py).unwrap()
            }
        } else if let Some(ref inner) = self.inner_f32 {
            let paths = inner.sample_par(m);
            // `first()` rather than `paths[0]`: there are no paths when `m == 0`.
            let n = paths.first().map_or(0, |p| p.len());
            if paths.iter().all(|p| p.len() == n) {
                let mut result = Array2::<f32>::zeros((paths.len(), n));
                for (i, path) in paths.iter().enumerate() {
                    result.row_mut(i).assign(path);
                }
                result.into_pyarray(py).into_py_any(py).unwrap()
            } else {
                // Ragged paths: `assign` into fixed-width rows would panic on the shape mismatch.
                let arrays: Vec<pyo3::Py<pyo3::PyAny>> = paths
                    .into_iter()
                    .map(|p| p.into_pyarray(py).into_py_any(py).unwrap())
                    .collect();
                arrays.into_py_any(py).unwrap()
            }
        } else {
            unreachable!("PyCustomJt::new always sets exactly one inner sampler")
        }
    }
}
impl<T, D, S: SeedExt> CustomJt<T, D, S>
where
    T: FloatExt,
    D: Distribution<T> + Send + Sync,
{
    /// Draw one path, taking randomness from `seed` instead of `self.seed`.
    ///
    /// * Fixed-size mode (`self.n == Some(n)`, checked first): returns `n` points. The first
    ///   point is zero, and each later point equals its predecessor plus one draw from
    ///   `self.distribution`. Exactly `n - 1` draws are made (none when `n <= 1`).
    /// * Horizon mode (`self.n == None`, `self.t_max == Some(t_max)`): starts at zero and keeps
    ///   appending running sums while the sum is `< t_max`. The last point is therefore the
    ///   first partial sum that is not below `t_max` (it overshoots). A NaN draw also ends
    ///   the loop. If `t_max <= 0`, the result is just `[0]`.
    ///   If draws are not positive on average, or `t_max` is infinite, this loop may never end.
    ///
    /// # Panics
    /// If both `n` and `t_max` are `None`. The constructors reject that combination, but the
    /// fields are `pub` and can be cleared after construction.
    pub(crate) fn sample_impl<S2: SeedExt>(&self, seed: &S2) -> Array1<T> {
        if let Some(n) = self.n {
            let mut x = Array1::<T>::zeros(n);
            let mut rng = seed.rng();
            // Accumulate in place: point i = point i-1 + one fresh draw. Only the n-1
            // increments that actually appear in the path are drawn.
            for i in 1..n {
                x[i] = x[i - 1] + self.distribution.sample(&mut rng);
            }
            x
        } else if let Some(t_max) = self.t_max {
            let mut x = Vec::with_capacity(16);
            x.push(T::zero());
            let mut t = T::zero();
            // NOTE(review): any return value of `derive()` is discarded, and only this branch
            // calls it; confirm this is the intended seeding behaviour.
            seed.derive();
            let mut rng = seed.rng();
            while t < t_max {
                t += self.distribution.sample(&mut rng);
                x.push(t);
            }
            Array1::from(x)
        } else {
            // Reachable despite the constructors' validation, because `n`/`t_max` are public.
            panic!("CustomJt: n or t_max must be provided")
        }
    }
}
impl<T, D, S> ProcessExt<T> for CustomJt<T, D, S>
where
    T: FloatExt,
    D: Distribution<T> + Send + Sync,
    S: SeedExt,
{
    type Output = Array1<T>;

    /// Draw one path using the process's own seed strategy (`self.seed`).
    fn sample(&self) -> Self::Output {
        let own_seed = &self.seed;
        Self::sample_impl(self, own_seed)
    }
}