use rand::Rng;
use rand::distr::Open01;
use std::fmt;
/// Distribution whose hazard rate is constant on consecutive time intervals.
///
/// Interval `i` begins at `cumulative_time[i]` and has hazard rate `rates[i]`;
/// `cumulative_hazard[i]` is the hazard integrated from time `0` up to that start.
#[derive(Clone, Debug)]
pub struct PiecewiseExponential {
    /// Hazard rate of each interval, in interval order.
    rates: Vec<f64>,
    /// Start time of each interval (the first interval starts at `0.0`).
    cumulative_time: Vec<f64>,
    /// Integrated hazard at the start of each interval (`0.0` for the first).
    cumulative_hazard: Vec<f64>,
}
impl PiecewiseExponential {
    /// Builds a distribution whose hazard rate is `rates[i]` on the `i`-th interval.
    ///
    /// The intervals are laid end to end starting at time `0`; interval `i` (for every `i`
    /// before the last) has length `durations[i]`. The final interval is open-ended: its
    /// duration is validated but otherwise unused, and the last rate applies to every time
    /// beyond the preceding boundaries.
    ///
    /// # Errors
    ///
    /// * [`PiecewiseExponentialError::EmptyIntervals`] if `durations` is empty.
    /// * [`PiecewiseExponentialError::LengthMismatch`] if the two slices differ in length.
    /// * [`PiecewiseExponentialError::NonFiniteDuration`] for a NaN or infinite duration
    ///   before the last one, and [`PiecewiseExponentialError::NonPositiveDuration`] for a
    ///   non-positive one.
    /// * [`PiecewiseExponentialError::FinalDurationInvalid`] if the last duration is NaN, and
    ///   [`PiecewiseExponentialError::NonPositiveFinalDuration`] if it is `<= 0` (incl. `-inf`).
    /// * [`PiecewiseExponentialError::NonFiniteRate`] /
    ///   [`PiecewiseExponentialError::NonPositiveRate`] for a rate that is not finite and
    ///   strictly positive.
    ///
    /// All durations are checked before any rate is.
    pub fn new(durations: &[f64], rates: &[f64]) -> Result<Self, PiecewiseExponentialError> {
        let interval_count = durations.len();
        if interval_count == 0 {
            return Err(PiecewiseExponentialError::EmptyIntervals);
        }
        if interval_count != rates.len() {
            return Err(PiecewiseExponentialError::LengthMismatch {
                durations: interval_count,
                rates: rates.len(),
            });
        }
        let last_index = interval_count - 1;
        for (index, &duration) in durations.iter().enumerate() {
            if index == last_index {
                // The final interval may be open-ended (`f64::INFINITY`); only NaN and
                // non-positive values (which include `-inf`) are rejected.
                if duration.is_nan() {
                    return Err(PiecewiseExponentialError::FinalDurationInvalid);
                }
                if duration <= 0.0 {
                    return Err(PiecewiseExponentialError::NonPositiveFinalDuration);
                }
            } else {
                // NaN and ±inf are both reported as non-finite, ahead of the sign check.
                if !duration.is_finite() {
                    return Err(PiecewiseExponentialError::NonFiniteDuration { index });
                }
                if duration <= 0.0 {
                    return Err(PiecewiseExponentialError::NonPositiveDuration { index });
                }
            }
        }
        for (index, &rate) in rates.iter().enumerate() {
            if !rate.is_finite() {
                return Err(PiecewiseExponentialError::NonFiniteRate { index });
            }
            if rate <= 0.0 {
                return Err(PiecewiseExponentialError::NonPositiveRate { index });
            }
        }

        // One entry per interval: the start time of each interval and the cumulative hazard
        // accumulated up to that start. The final duration never contributes.
        let mut cumulative_time = Vec::with_capacity(interval_count);
        let mut cumulative_hazard = Vec::with_capacity(interval_count);
        let mut time_acc = 0.0_f64;
        let mut hazard_acc = 0.0_f64;
        cumulative_time.push(time_acc);
        cumulative_hazard.push(hazard_acc);
        for (&duration, &rate) in durations[..last_index].iter().zip(rates) {
            time_acc += duration;
            hazard_acc += duration * rate;
            cumulative_time.push(time_acc);
            cumulative_hazard.push(hazard_acc);
        }
        Ok(Self {
            rates: rates.to_vec(),
            cumulative_time,
            cumulative_hazard,
        })
    }

    /// Draws one variate from the distribution.
    ///
    /// `-ln(U)` with `U` uniform on the open interval `(0, 1)` is a standard exponential
    /// draw. Mapping it through the inverse cumulative hazard yields a sample of this
    /// distribution.
    pub fn sample<R>(&self, rng: &mut R) -> f64
    where
        R: Rng + ?Sized,
    {
        let uniform: f64 = rng.sample(Open01);
        let hazard = -uniform.ln();
        self.sample_from_hazard(hazard)
    }

    /// Returns the quantile function `F⁻¹(uniform)`: the time `t` with `P(T <= t) = uniform`.
    ///
    /// Because `F(t) = 1 - exp(-H(t))`, the target cumulative hazard is `-ln(1 - uniform)`.
    /// An input of `1.0` therefore maps to `f64::INFINITY`.
    ///
    /// # Errors
    ///
    /// [`PiecewiseExponentialSampleError::UniformOutOfRange`] if `uniform` is NaN or lies
    /// outside `(0, 1]`.
    pub fn inverse_cdf(&self, uniform: f64) -> Result<f64, PiecewiseExponentialSampleError> {
        if !(uniform > 0.0 && uniform <= 1.0) {
            return Err(PiecewiseExponentialSampleError::UniformOutOfRange { value: uniform });
        }
        // `-ln(uniform)` would invert the *survival* function (giving the `1 - uniform`
        // quantile). `ln_1p(-uniform)` computes `ln(1 - uniform)` without losing precision
        // for small probabilities.
        let hazard = -(-uniform).ln_1p();
        Ok(self.sample_from_hazard(hazard))
    }

    /// Maps a non-negative cumulative-hazard value to the time at which `H(t)` reaches it.
    ///
    /// `H` is piecewise linear with slope `rates[i]` on interval `i`. After locating the
    /// last interval whose starting hazard is `<= hazard`, the remainder is divided by that
    /// interval's rate.
    fn sample_from_hazard(&self, hazard: f64) -> f64 {
        // An infinite target is reached only as t -> +inf. Returning early also avoids
        // `inf - inf = NaN` when the accumulated hazard itself overflowed to infinity.
        if hazard == f64::INFINITY {
            return f64::INFINITY;
        }
        // `cumulative_hazard` is non-decreasing, so `value <= hazard` holds on a prefix.
        let idx = self
            .cumulative_hazard
            .partition_point(|&value| value <= hazard)
            .saturating_sub(1);
        let base_time = self.cumulative_time[idx];
        let offset = (hazard - self.cumulative_hazard[idx]) / self.rates[idx];
        base_time + offset
    }

    /// Draws `n` independent variates using [`Self::sample`].
    pub fn sample_n<R>(&self, n: usize, rng: &mut R) -> Vec<f64>
    where
        R: Rng + ?Sized,
    {
        (0..n).map(|_| self.sample(rng)).collect()
    }
}
/// Reasons a `PiecewiseExponential` cannot be constructed from the given inputs.
#[derive(Clone, Debug, Eq, PartialEq)]
pub enum PiecewiseExponentialError {
    /// No intervals were supplied.
    EmptyIntervals,
    /// `durations` and `rates` have different lengths.
    LengthMismatch { durations: usize, rates: usize },
    /// A duration before the last one is NaN or infinite.
    NonFiniteDuration { index: usize },
    /// A duration before the last one is zero or negative.
    NonPositiveDuration { index: usize },
    /// The last duration is zero or negative.
    NonPositiveFinalDuration,
    /// The last duration is neither finite nor positive infinity.
    FinalDurationInvalid,
    /// A rate is NaN or infinite.
    NonFiniteRate { index: usize },
    /// A rate is zero or negative.
    NonPositiveRate { index: usize },
}

impl fmt::Display for PiecewiseExponentialError {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        // All payloads are `usize`, so matching on `*self` copies them out directly.
        match *self {
            Self::EmptyIntervals => f.write_str("durations must contain at least one interval"),
            Self::LengthMismatch { durations, rates } => write!(
                f,
                "durations and rates must have the same length ({} vs {})",
                durations, rates
            ),
            Self::NonFiniteDuration { index } => {
                write!(f, "duration at index {} must be finite", index)
            }
            Self::NonPositiveDuration { index } => {
                write!(f, "duration at index {} must be positive", index)
            }
            Self::NonPositiveFinalDuration => f.write_str("final duration must be positive"),
            Self::FinalDurationInvalid => f.write_str(
                "final duration must be finite or positive infinity (use f64::INFINITY)",
            ),
            Self::NonFiniteRate { index } => write!(f, "rate at index {} must be finite", index),
            Self::NonPositiveRate { index } => {
                write!(f, "rate at index {} must be strictly positive", index)
            }
        }
    }
}

// Uses the derived `Debug` plus the `Display` impl; the default `source` (none) is kept.
impl std::error::Error for PiecewiseExponentialError {}
/// Error reported when a caller-supplied uniform variate is NaN or outside `(0, 1]`.
#[derive(Clone, Debug, PartialEq)]
pub enum PiecewiseExponentialSampleError {
    /// The rejected input value.
    UniformOutOfRange { value: f64 },
}

impl fmt::Display for PiecewiseExponentialSampleError {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        // Single-variant enum, so the pattern is irrefutable.
        let Self::UniformOutOfRange { value } = self;
        write!(
            f,
            "uniform variate {} must lie within the interval (0, 1]",
            value
        )
    }
}

// Uses the derived `Debug` plus the `Display` impl; the default `source` (none) is kept.
impl std::error::Error for PiecewiseExponentialSampleError {}