stochastic-rs-stochastic 2.0.0

Stochastic process simulation.
Documentation
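//! # CustomJt
//!
//! Event-time generator: each sampled path is a running sum of draws from a
//! user-supplied distribution, starting at zero. It can supply, e.g., the jump
//! times of the counting process $N_t$ in a jump-diffusion of the form
//!
//! $$
//! dX_t=a(t,X_t)dt+b(t,X_t)dW_t+\sum_{k=1}^{dN_t}J_k
//! $$
//!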
use ndarray::Array1;
use rand_distr::Distribution;
use stochastic_rs_core::simd_rng::Deterministic;
use stochastic_rs_core::simd_rng::SeedExt;
use stochastic_rs_core::simd_rng::Unseeded;

use crate::traits::FloatExt;
use crate::traits::ProcessExt;

/// Generator of event times built from draws of a user-supplied distribution.
///
/// Each path starts at zero and accumulates draws from `distribution`. If `n`
/// is set, the path has exactly `n` points; otherwise draws are accumulated until
/// the running sum reaches `t_max`. When both are set, `n` takes precedence.
pub struct CustomJt<T, D, S = Unseeded>
where
  T: FloatExt,
  D: Distribution<T> + Send + Sync,
  S: SeedExt,
{
  /// Fixed number of points per path; takes precedence over `t_max` when set.
  pub n: Option<usize>,
  /// Time horizon for generation; consulted only when `n` is `None`.
  pub t_max: Option<T>,
  /// Distribution of the increments (inter-arrival draws) that are summed.
  pub distribution: D,
  /// Compile-time seed strategy: [`Unseeded`] or [`Deterministic`].
  pub seed: S,
}

/// Guards constructors: at least one generation mode (`n` or `t_max`) must be given.
///
/// # Panics
/// Panics with `"<type_name>: n or t_max must be provided"` when both are `None`.
#[inline]
fn validate_n_or_tmax<T: FloatExt>(n: Option<usize>, t_max: Option<T>, type_name: &'static str) {
  let has_generation_mode = n.is_some() || t_max.is_some();
  assert!(has_generation_mode, "{type_name}: n or t_max must be provided");
}

impl<T, D> CustomJt<T, D>
where
  T: FloatExt,
  D: Distribution<T> + Send + Sync,
{
  /// Creates an unseeded generator.
  ///
  /// # Panics
  /// Panics if both `n` and `t_max` are `None`.
  pub fn new(n: Option<usize>, t_max: Option<T>, distribution: D) -> Self {
    validate_n_or_tmax(n, t_max, "CustomJt");
    Self {
      seed: Unseeded,
      n,
      t_max,
      distribution,
    }
  }
}

impl<T, D> CustomJt<T, D, Deterministic>
where
  T: FloatExt,
  D: Distribution<T> + Send + Sync,
{
  /// Creates a generator whose randomness comes from `Deterministic::new(seed)`.
  ///
  /// # Panics
  /// Panics if both `n` and `t_max` are `None`.
  pub fn seeded(n: Option<usize>, t_max: Option<T>, distribution: D, seed: u64) -> Self {
    validate_n_or_tmax(n, t_max, "CustomJt");
    let seed = Deterministic::new(seed);
    Self {
      n,
      t_max,
      distribution,
      seed,
    }
  }
}

/// Python wrapper around [`CustomJt`] whose distribution is a Python object
/// wrapped in `CallableDist`.
///
/// Exactly one of the two inner generators is populated; the constructor picks
/// it from the `dtype` argument.
#[cfg(feature = "python")]
#[pyo3::prelude::pyclass]
pub struct PyCustomJt {
  /// Populated when constructed with `dtype="f32"`.
  inner_f32: Option<CustomJt<f32, crate::traits::CallableDist<f32>>>,
  /// Populated for any other `dtype` (the default is `"f64"`).
  inner_f64: Option<CustomJt<f64, crate::traits::CallableDist<f64>>>,
}

/// Stacks sampled paths into a `(paths.len(), longest_path_len)` matrix.
///
/// Paths generated in `t_max` mode have data-dependent lengths, so rows shorter
/// than the longest path are right-padded with `fill`. When every path has the
/// same length (fixed-`n` mode) no padding occurs. An empty input yields a
/// `(0, 0)` matrix rather than panicking.
#[cfg(feature = "python")]
fn stack_paths_padded<T: Copy>(paths: &[Array1<T>], fill: T) -> numpy::ndarray::Array2<T> {
  let rows = paths.len();
  let cols = paths.iter().map(|path| path.len()).max().unwrap_or(0);
  let mut stacked = numpy::ndarray::Array2::from_elem((rows, cols), fill);
  for (mut row, path) in stacked.outer_iter_mut().zip(paths) {
    for (dst, src) in row.iter_mut().zip(path) {
      *dst = *src;
    }
  }
  stacked
}

#[cfg(feature = "python")]
#[pyo3::prelude::pymethods]
impl PyCustomJt {
  /// Builds the generator. `dtype="f32"` selects single precision; any other
  /// value (default `"f64"`) selects double precision.
  #[new]
  #[pyo3(signature = (distribution, n=None, t_max=None, dtype=None))]
  fn new(
    distribution: pyo3::Py<pyo3::PyAny>,
    n: Option<usize>,
    t_max: Option<f64>,
    dtype: Option<&str>,
  ) -> Self {
    match dtype.unwrap_or("f64") {
      "f32" => Self {
        inner_f32: Some(CustomJt::new(
          n,
          t_max.map(|v| v as f32),
          crate::traits::CallableDist::new(distribution),
        )),
        inner_f64: None,
      },
      _ => Self {
        inner_f32: None,
        inner_f64: Some(CustomJt::new(
          n,
          t_max,
          crate::traits::CallableDist::new(distribution),
        )),
      },
    }
  }

  /// Draws a single path and returns it as a 1-D NumPy array.
  fn sample<'py>(&self, py: pyo3::Python<'py>) -> pyo3::Py<pyo3::PyAny> {
    use numpy::IntoPyArray;
    use pyo3::IntoPyObjectExt;

    use crate::traits::ProcessExt;
    if let Some(ref inner) = self.inner_f64 {
      inner.sample().into_pyarray(py).into_py_any(py).unwrap()
    } else if let Some(ref inner) = self.inner_f32 {
      inner.sample().into_pyarray(py).into_py_any(py).unwrap()
    } else {
      unreachable!()
    }
  }

  /// Draws `m` paths and returns them as a 2-D NumPy array with one row per path.
  ///
  /// In `t_max` mode the paths can differ in length, so shorter rows are
  /// right-padded with NaN. `m == 0` returns an empty `(0, 0)` array.
  fn sample_par<'py>(&self, py: pyo3::Python<'py>, m: usize) -> pyo3::Py<pyo3::PyAny> {
    use numpy::IntoPyArray;
    use pyo3::IntoPyObjectExt;

    use crate::traits::ProcessExt;
    if let Some(ref inner) = self.inner_f64 {
      let paths = inner.sample_par(m);
      stack_paths_padded(&paths, f64::NAN)
        .into_pyarray(py)
        .into_py_any(py)
        .unwrap()
    } else if let Some(ref inner) = self.inner_f32 {
      let paths = inner.sample_par(m);
      stack_paths_padded(&paths, f32::NAN)
        .into_pyarray(py)
        .into_py_any(py)
        .unwrap()
    } else {
      unreachable!()
    }
  }
}

impl<T, D, S: SeedExt> CustomJt<T, D, S>
where
  T: FloatExt,
  D: Distribution<T> + Send + Sync,
{
  /// Core sampling — monomorphised per seed strategy, zero runtime branching.
  ///
  /// Produces running sums of draws from `self.distribution`, starting at zero:
  ///
  /// - `n = Some(n)`: returns exactly `n` points built from `n - 1` draws. `n`
  ///   takes precedence over `t_max` when both are set.
  /// - `n = None, t_max = Some(t_max)`: keeps adding draws while the running sum
  ///   is below `t_max`. The last point is the first sum that is not below the
  ///   horizon; a NaN draw also stops the loop. If `t_max > 0` and every draw is
  ///   non-positive, the loop never terminates.
  ///
  /// # Panics
  /// Panics if both `n` and `t_max` are `None`. The constructors reject that
  /// state, but the fields are public, so a hand-built value can still reach it.
  pub(crate) fn sample_impl<S2: SeedExt>(&self, seed: &S2) -> Array1<T> {
    match (self.n, self.t_max) {
      (Some(n), _) => {
        // Accumulate straight into the output: x[0] = 0, x[i] = x[i-1] + draw.
        // Only n - 1 draws are needed, so no scratch buffer (and no wasted
        // trailing draw) is required; the values match the two-pass version.
        let mut rng = seed.rng();
        let mut x = Array1::<T>::zeros(n);
        let mut acc = T::zero();
        for slot in x.iter_mut().skip(1) {
          acc += self.distribution.sample(&mut rng);
          *slot = acc;
        }
        x
      }
      (None, Some(t_max)) => {
        let mut x = Vec::with_capacity(16);
        x.push(T::zero());
        let mut t = T::zero();
        // NOTE(review): only this branch calls `derive()` before `rng()`; the
        // fixed-`n` branch does not. Confirm whether the asymmetry is intended.
        seed.derive();
        let mut rng = seed.rng();

        while t < t_max {
          t += self.distribution.sample(&mut rng);
          x.push(t);
        }

        Array1::from(x)
      }
      (None, None) => panic!("CustomJt: n or t_max must be provided"),
    }
  }
}

impl<T, D, S> ProcessExt<T> for CustomJt<T, D, S>
where
  T: FloatExt,
  D: Distribution<T> + Send + Sync,
  S: SeedExt,
{
  type Output = Array1<T>;

  /// Draws one path using the seed strategy stored on `self`.
  fn sample(&self) -> Self::Output {
    let seed = &self.seed;
    Self::sample_impl(self, seed)
  }
}