use serde::{Deserialize, Serialize};
use super::errors::PoissonError;
/// Configuration for Poisson rate estimation.
///
/// This is currently a field-less unit struct: `estimate_poisson_mle` accepts a
/// reference to it but does not read anything from it.
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
pub struct PoissonEstimationConfig;
/// Outcome of fitting a constant-rate Poisson model, as produced by
/// `estimate_poisson_mle`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct PoissonEstimationResult {
/// Estimated rate: number of inter-event gaps divided by the time between the
/// first and last event.
pub lambda: f64,
/// Value of `log_likelihood` evaluated at `lambda` for the fitted events.
pub log_likelihood: f64,
/// Iteration count; `estimate_poisson_mle` always reports 1.
pub iterations: usize,
/// Convergence flag; `estimate_poisson_mle` always reports `true`.
pub converged: bool,
/// Akaike information criterion, `2k - 2 * log_likelihood` with `k = 1`.
pub aic: f64,
/// Bayesian information criterion, `k * ln(n) - 2 * log_likelihood` with
/// `k = 1` and `n` the number of events.
pub bic: f64,
}
/// Checks that `events` is a usable sequence of event times.
///
/// The checks run in this order and the first failure is reported:
///
/// 1. the slice holds at least two events;
/// 2. every value is finite (no NaN, no infinities);
/// 3. the values are strictly increasing.
///
/// # Errors
///
/// Returns [`PoissonError::InvalidEventTimes`] with a message describing the
/// first violated condition.
pub fn validate_events(events: &[f64]) -> Result<(), PoissonError> {
    if events.len() < 2 {
        return Err(PoissonError::InvalidEventTimes(
            "Need at least 2 events for estimation".into(),
        ));
    }

    // Finiteness is checked over the whole slice before ordering, so a
    // non-finite value anywhere wins over an earlier ordering violation.
    let first_non_finite = events
        .iter()
        .copied()
        .enumerate()
        .find(|(_, value)| !value.is_finite());
    if let Some((idx, value)) = first_non_finite {
        return Err(PoissonError::InvalidEventTimes(format!(
            "Event at index {} is not finite: {}",
            idx, value
        )));
    }

    for (offset, pair) in events.windows(2).enumerate() {
        let (earlier, later) = (pair[0], pair[1]);
        if later <= earlier {
            // `offset` is the position of the window; the offending event is
            // the second element of that window.
            return Err(PoissonError::InvalidEventTimes(format!(
                "Events not strictly increasing at index {}: {} <= {}",
                offset + 1,
                later,
                earlier
            )));
        }
    }

    Ok(())
}
/// Log-likelihood of a constant rate `lambda` given the event times.
///
/// With `n` events the value is
/// `(n - 1) * ln(lambda) - lambda * (events[n - 1] - events[0])`, i.e. the
/// `n - 1` gaps between consecutive events are the observations. Only the
/// first and last timestamps are read; `events` is not validated here.
///
/// Returns `f64::NEG_INFINITY` when fewer than two events are supplied or when
/// `lambda` is not a finite, strictly positive number.
pub fn log_likelihood(lambda: f64, events: &[f64]) -> f64 {
    let (first, last) = match events {
        [first, .., last] => (*first, *last),
        _ => return f64::NEG_INFINITY,
    };

    // A plain `lambda <= 0.0` test lets NaN through (every comparison with NaN
    // is false) and lets +inf through, where `inf - inf` below would yield NaN.
    // Treat both like any other invalid rate.
    if !lambda.is_finite() || lambda <= 0.0 {
        return f64::NEG_INFINITY;
    }

    let n_gaps = (events.len() - 1) as f64;
    let t_span = last - first;
    n_gaps * lambda.ln() - lambda * t_span
}
/// Closed-form maximum-likelihood estimate of a constant event rate.
///
/// The events are first checked with `validate_events`. The rate is then the
/// number of inter-event gaps divided by the time between the first and last
/// event. No iteration is involved, so the result always reports
/// `iterations: 1` and `converged: true`.
///
/// AIC and BIC are computed with one free parameter (`k = 1`); BIC uses the
/// number of events as the sample size.
///
/// `_config` is accepted for interface stability and is currently unused.
///
/// # Errors
///
/// * Whatever `validate_events` returns for unusable input.
/// * [`PoissonError::EstimationFailed`] when the span between the first and
///   last event is not strictly positive, or is not finite.
pub fn estimate_poisson_mle(
    events: &[f64],
    _config: &PoissonEstimationConfig,
) -> Result<PoissonEstimationResult, PoissonError> {
    validate_events(events)?;

    let n = events.len();
    let t_span = events[n - 1] - events[0];
    if t_span <= 0.0 {
        return Err(PoissonError::EstimationFailed(
            "All events have the same timestamp".into(),
        ));
    }
    // Both endpoints are finite after validation, but their difference can
    // still overflow to +inf (e.g. -1e308 and 1e308). That would give a rate of
    // zero and a log-likelihood of -inf reported as a converged fit.
    if !t_span.is_finite() {
        return Err(PoissonError::EstimationFailed(
            "Event time span is not finite".into(),
        ));
    }

    let n_gaps = (n - 1) as f64;
    let lambda = n_gaps / t_span;
    let ll = log_likelihood(lambda, events);

    // One free parameter: the rate.
    let k = 1.0_f64;
    let n_f = n as f64;
    let aic = 2.0 * k - 2.0 * ll;
    let bic = k * n_f.ln() - 2.0 * ll;

    Ok(PoissonEstimationResult {
        lambda,
        log_likelihood: ll,
        iterations: 1,
        converged: true,
        aic,
        bic,
    })
}
/// Constant-rate compensator evaluated at time `t`.
///
/// Computes `lambda * max(t - events[0], 0)`: the rate multiplied by the time
/// elapsed since the first event, clamped at zero so that times before the
/// first event contribute nothing. Only `events[0]` is read.
///
/// Returns `0.0` when `events` is empty.
pub fn compensator(lambda: f64, events: &[f64], t: f64) -> f64 {
    match events.first() {
        None => 0.0,
        Some(&origin) => {
            let elapsed = (t - origin).max(0.0);
            lambda * elapsed
        }
    }
}
/// Rescaled inter-event gaps for a constant rate.
///
/// For each pair of consecutive events the residual is
/// `lambda * (events[i] - events[i - 1])`, so the output has one entry fewer
/// than `events`. The input is not validated or sorted.
///
/// Returns an empty vector when fewer than two events are supplied.
pub fn time_rescaling_residuals(lambda: f64, events: &[f64]) -> Vec<f64> {
    // `windows(2)` yields nothing for slices shorter than two elements, which
    // covers the empty and single-event cases without a separate guard.
    events
        .windows(2)
        .map(|pair| lambda * (pair[1] - pair[0]))
        .collect()
}