use crate::config::EnsembleStatistic;
use crate::hawkes::kernel::ExcitationKernel;
use crate::hawkes::HawkesProcess;
use serde::{Deserialize, Serialize};
/// Point-forecast accuracy summary produced by [`forecast_errors`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ForecastMetrics {
    /// Mean absolute error over the paired gaps (NaN when no pairs).
    pub mae: f64,
    /// Root-mean-square error over the paired gaps (NaN when no pairs).
    pub rmse: f64,
    /// Number of (actual, forecast) pairs actually compared.
    pub n: usize,
}
/// Computes MAE and RMSE between actual and forecast gap sequences.
///
/// Only the first `min(len, len)` pairs are compared; surplus entries in the
/// longer slice are ignored. With no pairs at all, both errors are NaN.
pub fn forecast_errors(actual_gaps: &[f64], forecast_gaps: &[f64]) -> ForecastMetrics {
    let errors: Vec<f64> = actual_gaps
        .iter()
        .zip(forecast_gaps.iter())
        .map(|(a, f)| a - f)
        .collect();
    if errors.is_empty() {
        // Nothing to compare: error metrics are undefined.
        return ForecastMetrics {
            mae: f64::NAN,
            rmse: f64::NAN,
            n: 0,
        };
    }
    let count = errors.len() as f64;
    let mae = errors.iter().map(|e| e.abs()).sum::<f64>() / count;
    let rmse = (errors.iter().map(|e| e * e).sum::<f64>() / count).sqrt();
    ForecastMetrics {
        mae,
        rmse,
        n: errors.len(),
    }
}
/// Simulates `mc_paths` Monte-Carlo continuations of the Hawkes process and
/// collapses them into a single forecast of the next `n_events` gaps
/// (each gap is measured relative to `current_ts`).
///
/// For event index `i`, the i-th gap of every path that produced at least
/// `i + 1` events forms a column; the column is reduced with `statistic`.
/// An empty column (no path reached that event) yields 0.0 for every
/// statistic.
pub fn ensemble_forecast<K: ExcitationKernel>(
    hp: &HawkesProcess<K>,
    current_ts: f64,
    history: &[f64],
    n_events: usize,
    mc_paths: usize,
    statistic: EnsembleStatistic,
) -> Vec<f64> {
    let mut all_gaps: Vec<Vec<f64>> = Vec::with_capacity(mc_paths);
    for _ in 0..mc_paths {
        let path = hp.generate_values_conditioned(current_ts, history, n_events);
        let gaps: Vec<f64> = path.iter().map(|&t| t - current_ts).collect();
        all_gaps.push(gaps);
    }
    (0..n_events)
        .map(|i| {
            // Column i = the i-th forecast gap from every path that got that far.
            let mut column: Vec<f64> = all_gaps
                .iter()
                .filter_map(|path| path.get(i).copied())
                .collect();
            // BUGFIX: the percentile branch previously forwarded an empty
            // column to `percentile`, whose non-empty assert panics. Treat
            // an empty column uniformly as 0.0, as the Mean branch already did.
            if column.is_empty() {
                return 0.0;
            }
            match statistic {
                EnsembleStatistic::Mean => {
                    column.iter().sum::<f64>() / column.len() as f64
                }
                EnsembleStatistic::Median | EnsembleStatistic::P25 | EnsembleStatistic::P75 => {
                    // NaN-tolerant sort: incomparable pairs are treated as equal.
                    column.sort_unstable_by(|a, b| {
                        a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal)
                    });
                    let q = match statistic {
                        EnsembleStatistic::Median => 0.5,
                        EnsembleStatistic::P25 => 0.25,
                        EnsembleStatistic::P75 => 0.75,
                        EnsembleStatistic::Mean => unreachable!("Mean handled above"),
                    };
                    percentile(&column, q)
                }
            }
        })
        .collect()
}
/// Linearly interpolated quantile of a pre-sorted (ascending) slice.
///
/// `q` is the quantile in `[0, 1]`; out-of-range values are clamped to that
/// interval (previously `q > 1.0` could index past the slice and panic, and
/// `q < 0.0` extrapolated below the first element).
///
/// # Panics
/// Panics when `sorted` is empty.
pub fn percentile(sorted: &[f64], q: f64) -> f64 {
    assert!(!sorted.is_empty(), "percentile called on empty slice");
    let n = sorted.len();
    if n == 1 {
        return sorted[0];
    }
    // BUGFIX: clamp q so `lo` can never exceed n - 1 (out-of-bounds panic
    // for q > 1.0 before this fix). NaN q still propagates as NaN output.
    let q = q.clamp(0.0, 1.0);
    let pos = q * (n - 1) as f64;
    let lo = pos.floor() as usize;
    // `min` kept as a belt-and-braces guard against float rounding.
    let hi = pos.ceil().min((n - 1) as f64) as usize;
    let frac = pos - lo as f64;
    sorted[lo] * (1.0 - frac) + sorted[hi] * frac
}
/// Outcome of [`likelihood_ratio_test`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct LRTestResult {
    /// Test statistic 2 * (ll_full - ll_restricted).
    pub statistic: f64,
    /// Degrees of freedom of the reference chi-squared distribution.
    pub df: usize,
    /// Chi-squared 0.95 critical value used for the decision.
    pub critical_value: f64,
    /// True when `statistic > critical_value` (H0 rejected at the 5% level).
    pub reject_h0: bool,
}
/// Likelihood-ratio test of a restricted model against a full model at the
/// 5% significance level.
///
/// The statistic 2 * (ll_full - ll_restricted) is compared against the
/// chi-squared 0.95 quantile with `df` degrees of freedom: a lookup table
/// for df = 1..=5, the Wilson-Hilferty cube approximation otherwise.
pub fn likelihood_ratio_test(ll_full: f64, ll_restricted: f64, df: usize) -> LRTestResult {
    // Chi-squared 0.95 critical values for df = 1, 2, 3, 4, 5.
    const CHI2_95: [f64; 5] = [3.841, 5.991, 7.815, 9.488, 11.070];
    let statistic = 2.0 * (ll_full - ll_restricted);
    let critical_value = if (1..=5).contains(&df) {
        CHI2_95[df - 1]
    } else {
        // Wilson-Hilferty approximation; z = Phi^{-1}(0.95) = 1.6449.
        let z = 1.6449;
        let k = df as f64;
        let a = 2.0 / (9.0 * k);
        k * (1.0 - a + z * a.sqrt()).powi(3)
    };
    LRTestResult {
        statistic,
        df,
        critical_value,
        reject_h0: statistic > critical_value,
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    // MAE over a hand-computed 3-pair example: (|0.5| + |0.5| + |1.0|) / 3.
    #[test]
    fn test_forecast_errors_basic() {
        let actual = vec![1.0, 3.0, 6.0];
        let forecast = vec![1.5, 2.5, 7.0];
        let m = forecast_errors(&actual, &forecast);
        assert_eq!(m.n, 3);
        let expected_mae = (0.5 + 0.5 + 1.0) / 3.0;
        assert!((m.mae - expected_mae).abs() < 1e-12);
    }

    // Odd-length slice: the median falls exactly on the middle element.
    #[test]
    fn test_percentile_median() {
        let sorted = vec![1.0, 2.0, 3.0, 4.0, 5.0];
        assert!((percentile(&sorted, 0.5) - 3.0).abs() < 1e-12);
    }

    // Two elements: quantiles interpolate linearly between them.
    #[test]
    fn test_percentile_interpolation() {
        let sorted = vec![10.0, 20.0];
        assert!((percentile(&sorted, 0.5) - 15.0).abs() < 1e-12);
        assert!((percentile(&sorted, 0.25) - 12.5).abs() < 1e-12);
    }

    // Statistic 2 * (-100 - (-200)) = 200 far exceeds the df=2 critical 5.991.
    #[test]
    fn test_lr_test_significant() {
        let result = likelihood_ratio_test(-100.0, -200.0, 2);
        assert_eq!(result.statistic, 200.0);
        assert!(result.reject_h0);
    }

    // Statistic 2 * (-100 - (-100.5)) = 1.0 is below 5.991: H0 retained.
    #[test]
    fn test_lr_test_not_significant() {
        let result = likelihood_ratio_test(-100.0, -100.5, 2);
        assert_eq!(result.statistic, 1.0);
        assert!(!result.reject_h0);
    }
}