use crate::{
FloatExt, SimulationError, XResult, check_duration_time_step, random::gamma,
simulation::prelude::*,
};
use rand_distr::{Distribution, Exp1, Open01, StandardNormal};
/// A Gamma process started at zero.
///
/// Increments over an interval of length `dt` are sampled with
/// `gamma::rand(shape * dt, 1 / rate)`, i.e. the shape grows linearly with
/// elapsed time and the sampler's scale is the reciprocal of `rate`.
/// Construct it with [`Gamma::new`], which requires both parameters to be
/// strictly positive.
#[derive(Debug, Clone)]
pub struct Gamma<T: FloatExt = f64> {
    // Shape per unit time; an increment over `dt` uses shape `shape * dt`.
    shape: T,
    // Rate parameter; the sampler is called with scale `1 / rate`.
    rate: T,
}
impl<T: FloatExt> Gamma<T> {
    /// Creates a Gamma process with the given `shape` (per unit time) and `rate`.
    ///
    /// # Errors
    ///
    /// Returns an error built from [`SimulationError::InvalidParameters`] if
    /// `shape` or `rate` is not strictly positive. NaN is rejected as well,
    /// since it is not a positive number.
    pub fn new(shape: T, rate: T) -> XResult<Self> {
        // `x <= 0` is false for NaN, so NaN must be tested explicitly;
        // otherwise it would slip through and poison every sampled increment.
        if shape.is_nan() || shape <= T::zero() {
            return Err(SimulationError::InvalidParameters(format!(
                "The `shape` must be positive, got {shape:?}"
            ))
            .into());
        }
        if rate.is_nan() || rate <= T::zero() {
            return Err(SimulationError::InvalidParameters(format!(
                "The `rate` must be positive, got {rate:?}"
            ))
            .into());
        }
        Ok(Self { shape, rate })
    }

    /// Returns the shape parameter (per unit time).
    pub fn get_shape(&self) -> T {
        self.shape
    }

    /// Returns the rate parameter.
    pub fn get_rate(&self) -> T {
        self.rate
    }
}
impl<T: FloatExt> ContinuousProcess<T> for Gamma<T>
where
    StandardNormal: Distribution<T>,
    Exp1: Distribution<T>,
    Open01: Distribution<T>,
{
    /// Every path begins at the origin.
    fn start(&self) -> T {
        T::zero()
    }

    /// Simulates a full path on `[0, duration]` by delegating to [`simulate_gamma`].
    fn simulate(&self, duration: T, time_step: T) -> XResult<(Vec<T>, Vec<T>)> {
        let &Self { shape, rate } = self;
        simulate_gamma(shape, rate, duration, time_step)
    }

    /// Samples the displacement after `duration` in a single draw.
    ///
    /// `time_step` is only validated (via `check_duration_time_step`); the
    /// draw itself uses shape `shape * duration` and scale `1 / rate`.
    fn displacement(&self, duration: T, time_step: T) -> XResult<T> {
        check_duration_time_step(duration, time_step)?;
        let &Self { shape, rate } = self;
        gamma::rand(shape * duration, T::one() / rate)
    }
}
/// Simulates a Gamma process path on `[0, duration]` with step `time_step`.
///
/// The path starts at `(0, 0)`. Each full step of length `time_step` adds an
/// increment drawn with shape `shape * time_step` and scale `1 / rate`. A
/// final partial step covers the remaining time up to `duration`, so the last
/// entry of the time vector is exactly `duration`.
///
/// Returns `(t, x)`, the time grid and the path values, of equal length.
///
/// # Errors
///
/// Propagates errors from `check_duration_time_step` and from the gamma
/// samplers (`gamma::rands` / `gamma::rand`).
///
/// # Panics
///
/// Panics if `ceil(duration / time_step)` cannot be converted to `usize`.
/// This presumably cannot happen after `check_duration_time_step` succeeds
/// (TODO confirm it rejects non-finite values).
pub fn simulate_gamma<T: FloatExt>(
    shape: T,
    rate: T,
    duration: T,
    time_step: T,
) -> XResult<(Vec<T>, Vec<T>)>
where
    StandardNormal: Distribution<T>,
    Exp1: Distribution<T>,
    Open01: Distribution<T>,
{
    check_duration_time_step(duration, time_step)?;
    let num_steps = (duration / time_step).ceil().to_usize().unwrap();
    let scale = T::one() / rate;
    // All steps but the last have full length `time_step`. `saturating_sub`
    // guards against underflow should `num_steps` ever be 0.
    let full_steps = num_steps.saturating_sub(1);
    let noise = gamma::rands(shape * time_step, scale, full_steps)?;
    let mut t = Vec::with_capacity(num_steps + 1);
    let mut x = Vec::with_capacity(num_steps + 1);
    t.push(T::zero());
    x.push(T::zero());
    let mut current_x = T::zero();
    let mut current_t = T::zero();
    for xi in noise {
        current_t += time_step;
        t.push(current_t);
        current_x += xi;
        x.push(current_x);
    }
    // Floating-point rounding (e.g. duration = 0.1 + 0.1 + 0.1 with
    // time_step = 0.1) can make the accumulated grid reach or overshoot
    // `duration`. The remaining step is then zero or negative, which is not
    // a valid gamma shape.
    let last_step = duration - current_t;
    if last_step > T::zero() {
        let xi = gamma::rand(shape * last_step, scale)?;
        current_x += xi;
        x.push(current_x);
        t.push(duration);
    } else if let Some(last_t) = t.last_mut() {
        // The grid already ends at `duration` up to rounding. Snap the final
        // time so the path still ends exactly at `duration`, without pushing
        // a duplicate time point.
        *last_t = duration;
    }
    Ok((t, x))
}
#[cfg(test)]
mod tests {
    use super::*;

    /// A simulated path must be aligned, start at the origin, end at
    /// `duration` and never decrease (gamma increments are non-negative).
    #[test]
    fn test_simulate_gamma() {
        let process = Gamma::new(0.5, 1.0).unwrap();
        let time_step = 0.1;
        let duration = 1.0;
        let (t, x) = process.simulate(duration, time_step).unwrap();
        assert_eq!(t.len(), x.len(), "time and path vectors must align");
        // 9 full steps + 1 final partial step + the initial point.
        assert_eq!(t.len(), 11, "unexpected grid size: {t:?}");
        assert_eq!(t.first().copied(), Some(0.0));
        assert_eq!(x.first().copied(), Some(0.0));
        assert_eq!(t.last().copied(), Some(duration));
        assert!(
            t.windows(2).all(|w| w[0] < w[1]),
            "time grid must be strictly increasing: {t:?}"
        );
        assert!(
            x.windows(2).all(|w| w[0] <= w[1]),
            "gamma path must be non-decreasing: {x:?}"
        );
    }

    /// With shape 1000 and rate 1000 over unit time, the displacement
    /// concentrates tightly around 1, so it must be positive and well below 2.
    #[test]
    fn test_displacement_scales_shape_by_elapsed_time() {
        let process = Gamma::new(1000.0, 1000.0).unwrap();
        let displacement = process.displacement(1.0, 0.1).unwrap();
        assert!(
            displacement > 0.0,
            "gamma displacement must be positive, got {displacement}"
        );
        assert!(
            displacement < 2.0,
            "displacement should be O(shape / rate * duration), got {displacement}"
        );
    }

    /// Starting at 0 inside (-1, 1), the first exit cannot happen at time 0.
    #[test]
    fn test_fpt() {
        let process = Gamma::new(0.5, 1.0).unwrap();
        let time_step = 0.1;
        let fpt = process
            .fpt((-1.0, 1.0), 1000.0, time_step)
            .unwrap()
            .unwrap();
        assert!(fpt > 0.0, "first passage time must be positive, got {fpt}");
    }

    /// Occupation time is a duration and therefore never negative.
    #[test]
    fn test_occupation_time() {
        let process = Gamma::new(0.5, 1.0).unwrap();
        let time_step = 0.1;
        let ot = process
            .occupation_time((-1.0, 1.0), 10.0, time_step)
            .unwrap();
        assert!(ot >= 0.0, "occupation time must be non-negative, got {ot:?}");
    }

    /// `Gamma` must be shareable across threads.
    #[test]
    fn test_send_sync() {
        fn assert_send_sync<T: Send + Sync>() {}
        assert_send_sync::<Gamma>();
    }
}