use crate::estimate::UnifiedFitResult;
use crate::inference::alo::AloDiagnostics;
use crate::inference::psis::pareto_smooth_weights;
use crate::types::{GlmLikelihoodSpec, LikelihoodSpec};
use ndarray::{Array1, ArrayView1, ArrayView2};
/// Summary of an approximate leave-one-out (ALO) expected log predictive
/// density, as produced by `alo_elpd`.
#[derive(Debug, Clone)]
pub struct AloElpd {
/// Sum of `pointwise`.
pub elpd: f64,
/// Standard error of `elpd`: `sqrt(n * var)`, where `var` is the population
/// (divide-by-n) variance of `pointwise`.
pub se: f64,
/// Per-observation ALO log-likelihoods (a copy of the `loglik_loo` input).
pub pointwise: Array1<f64>,
/// The `k_hat` reported by `pareto_smooth_weights` for the fitted-vs-ALO
/// likelihood ratios; `NaN` when smoothing returned `None`.
pub k_hat_max: f64,
/// Tail count from the Pareto fit when `k_hat_max > 0.7`, otherwise 0.
pub n_k_bad: usize,
}
/// Effective degrees of freedom (EDF), reported both without and with the
/// additive correction term computed by `corrected_edf`.
#[derive(Debug, Clone, Copy)]
pub struct CorrectedEdf {
    /// EDF as supplied by the fit, before any correction.
    pub conditional: f64,
    /// `conditional` plus the correction term (zero when no correction applies).
    pub corrected: f64,
}

impl CorrectedEdf {
    /// Degrees of freedom contributed by the correction alone,
    /// i.e. `corrected - conditional`.
    pub fn rho_uncertainty_df(&self) -> f64 {
        let Self {
            conditional,
            corrected,
        } = *self;
        corrected - conditional
    }
}
/// Information-criterion and LOO summary for one fitted model, as assembled
/// by `model_comparison_from_unified`.
#[derive(Debug, Clone)]
pub struct ModelComparison {
/// Log-likelihood of the fitted model.
pub log_lik: f64,
/// Conditional and corrected effective degrees of freedom.
pub edf: CorrectedEdf,
/// `-2 * log_lik + 2 * edf.conditional`.
pub aic_conditional: f64,
/// `-2 * log_lik + 2 * edf.corrected`.
pub aic_corrected: f64,
/// ALO ELPD summary; `None` when it was not requested or could not be computed.
pub loo: Option<AloElpd>,
}
/// Builds a [`CorrectedEdf`] from the conditional EDF plus an optional correction.
///
/// The correction is `tr(weighted_gram · smoothing_correction) / phi`. It
/// contributes nothing (the corrected EDF equals `edf_conditional`) when either
/// matrix is missing, the shapes disagree, `phi` is not finite and positive, or the
/// resulting trace is not finite.
pub fn corrected_edf(
    edf_conditional: f64,
    weighted_gram: Option<ArrayView2<'_, f64>>,
    smoothing_correction: Option<ArrayView2<'_, f64>>,
    phi: f64,
) -> CorrectedEdf {
    let extra_df = wps_correction_term(weighted_gram, smoothing_correction, phi);
    let corrected = edf_conditional + extra_df;
    CorrectedEdf {
        conditional: edf_conditional,
        corrected,
    }
}
/// Correction term added to the conditional EDF: `tr(G · C) / phi`, where `G` is
/// `weighted_gram` and `C` is `smoothing_correction`.
///
/// Degrades to `0.0` (no correction) rather than failing when either matrix is
/// absent, when they are not both `k × k` with `k > 0`, when `phi` is not a finite
/// positive number, or when the resulting value is not finite.
fn wps_correction_term(
    weighted_gram: Option<ArrayView2<'_, f64>>,
    smoothing_correction: Option<ArrayView2<'_, f64>>,
    phi: f64,
) -> f64 {
    let (Some(gram), Some(correction)) = (weighted_gram, smoothing_correction) else {
        return 0.0;
    };
    let k = gram.nrows();
    let shapes_ok = k > 0 && gram.ncols() == k && correction.dim() == (k, k);
    let phi_ok = phi.is_finite() && phi > 0.0;
    if !(shapes_ok && phi_ok) {
        return 0.0;
    }
    // tr(G·C) = Σ_i Σ_j G[i,j]·C[j,i]. Walking `gram` and `correctionᵀ` in lockstep
    // (both iterate in logical row-major order) visits exactly those pairs, in the
    // same i-outer / j-inner order as an explicit double loop.
    let trace = gram
        .iter()
        .zip(correction.t().iter())
        .fold(0.0, |acc, (&g, &c)| acc + g * c)
        / phi;
    if trace.is_finite() { trace } else { 0.0 }
}
/// Summarises approximate leave-one-out (ALO) log-likelihoods as an ELPD estimate.
///
/// `loglik_fitted[i]` and `loglik_loo[i]` are the log-likelihoods of observation `i`
/// under the full fit and under the ALO approximation, respectively. The returned
/// `elpd` is the sum of `loglik_loo`. Its standard error is `sqrt(n * var)`, where
/// `var` is the population (divide-by-`n`) variance of the pointwise values.
///
/// The per-observation ratios `exp(loglik_fitted - loglik_loo)`, rescaled so the
/// largest is 1, are passed to `pareto_smooth_weights`. Its `k_hat` is reported as
/// `k_hat_max`, and its tail count is reported as `n_k_bad` when `k_hat > 0.7`.
/// The ratios only feed this diagnostic; they do not reweight `pointwise`.
///
/// Returns `None` in any of these cases:
/// - the inputs are empty or differ in length;
/// - any input is non-finite;
/// - the largest log ratio is not finite (the subtraction overflowed).
pub fn alo_elpd(
    loglik_fitted: ArrayView1<'_, f64>,
    loglik_loo: ArrayView1<'_, f64>,
) -> Option<AloElpd> {
    // Pareto shape above which the ratio tail is treated as influential.
    const K_HAT_THRESHOLD: f64 = 0.7;

    let n = loglik_loo.len();
    if n == 0 || loglik_fitted.len() != n {
        return None;
    }
    if loglik_fitted
        .iter()
        .chain(loglik_loo.iter())
        .any(|v| !v.is_finite())
    {
        return None;
    }
    // Differences are taken directly from the views, so neither input is copied.
    let log_ratio: Vec<f64> = loglik_fitted
        .iter()
        .zip(loglik_loo.iter())
        .map(|(&fitted, &loo)| fitted - loo)
        .collect();
    let max_lr = log_ratio.iter().copied().fold(f64::NEG_INFINITY, f64::max);
    // Finite inputs can still overflow when subtracted (e.g. 1e308 - -1e308).
    if !max_lr.is_finite() {
        return None;
    }
    // Shift by the maximum before exponentiating: the largest ratio becomes 1 and
    // `exp` cannot overflow.
    let raw: Vec<f64> = log_ratio
        .into_iter()
        .map(|lr| (lr - max_lr).exp())
        .collect();
    let (k_hat_max, n_k_bad) = match pareto_smooth_weights(&raw) {
        Some(psis) => {
            let flagged = if psis.k_hat > K_HAT_THRESHOLD {
                psis.tail_count
            } else {
                0
            };
            (psis.k_hat, flagged)
        }
        None => (f64::NAN, 0),
    };
    let pointwise = loglik_loo.to_owned();
    let elpd: f64 = pointwise.iter().sum();
    let n_f = n as f64;
    let mean = elpd / n_f;
    let var = pointwise
        .iter()
        .map(|&p| (p - mean) * (p - mean))
        .sum::<f64>()
        / n_f;
    let se = (n_f * var).sqrt();
    Some(AloElpd {
        elpd,
        se,
        pointwise,
        k_hat_max,
        n_k_bad,
    })
}
/// Result of `compare(a, b)`; every delta is oriented as `a - b`.
#[derive(Debug, Clone)]
pub struct ComparisonReport {
/// Sum of pointwise ELPD differences; `NaN` when `rows_aligned` is false.
pub delta_elpd: f64,
/// `sqrt(n * var)` of the pointwise differences, with `var` the population
/// variance; `NaN` when `rows_aligned` is false.
pub delta_elpd_se: f64,
/// `a.aic_corrected - b.aic_corrected`; always computed.
pub delta_aic_corrected: f64,
/// True when both models carry LOO results with the same non-zero length.
/// Only lengths are compared; matching row order is the caller's responsibility.
pub rows_aligned: bool,
}
/// Compares two models, with every delta oriented as `a − b`.
///
/// `delta_aic_corrected` is always `a.aic_corrected − b.aic_corrected`. The ELPD
/// difference and its standard error are computed only when both models carry LOO
/// results with the same non-zero number of pointwise values. Otherwise both are
/// `NaN` and `rows_aligned` is `false`.
///
/// Only the lengths are checked. The caller must ensure the pointwise entries refer
/// to the same observations in the same order. The SE is `sqrt(n · var)`, where
/// `var` is the population (divide-by-`n`) variance of the pointwise differences.
pub fn compare(a: &ModelComparison, b: &ModelComparison) -> ComparisonReport {
    let delta_aic_corrected = a.aic_corrected - b.aic_corrected;
    let paired = match (a.loo.as_ref(), b.loo.as_ref()) {
        (Some(la), Some(lb))
            if la.pointwise.len() == lb.pointwise.len() && !la.pointwise.is_empty() =>
        {
            Some((la, lb))
        }
        _ => None,
    };
    let Some((la, lb)) = paired else {
        return ComparisonReport {
            delta_elpd: f64::NAN,
            delta_elpd_se: f64::NAN,
            delta_aic_corrected,
            rows_aligned: false,
        };
    };
    let diffs: Vec<f64> = la
        .pointwise
        .iter()
        .zip(lb.pointwise.iter())
        .map(|(&x, &y)| x - y)
        .collect();
    let count = diffs.len() as f64;
    let delta_elpd: f64 = diffs.iter().sum();
    let mean = delta_elpd / count;
    let sum_sq: f64 = diffs.iter().map(|&d| (d - mean) * (d - mean)).sum();
    let var = sum_sq / count;
    ComparisonReport {
        delta_elpd,
        delta_elpd_se: (count * var).sqrt(),
        delta_aic_corrected,
        rows_aligned: true,
    }
}
/// Assembles a [`ModelComparison`] for a fitted model.
///
/// * `log_lik` is taken from `fit.log_likelihood`.
/// * The conditional EDF is `fit.edf_total()`, or `NaN` when unavailable.
/// * The corrected EDF comes from [`corrected_edf`], using `fit.weighted_gram()`,
///   `fit.smoothing_correction()` and `fit.dispersion_phi()`.
/// * AIC is `-2·log_lik + 2·edf`, reported for both the conditional and the
///   corrected EDF.
/// * `loo` is populated only when all of these hold:
///   - `alo` diagnostics are supplied;
///   - the fit records a likelihood family;
///   - [`alo_elpd_from_family`] succeeds on `y`, `eta_hat`, the ALO linear
///     predictor and `prior_weights`.
pub fn model_comparison_from_unified(
    fit: &UnifiedFitResult,
    y: ArrayView1<'_, f64>,
    eta_hat: ArrayView1<'_, f64>,
    prior_weights: ArrayView1<'_, f64>,
    alo: Option<&AloDiagnostics>,
) -> ModelComparison {
    let log_lik = fit.log_likelihood;
    let phi = fit.dispersion_phi();
    let edf_conditional = fit.edf_total().unwrap_or(f64::NAN);
    let edf = corrected_edf(
        edf_conditional,
        fit.weighted_gram().map(|g| g.view()),
        fit.smoothing_correction().map(|c| c.view()),
        phi,
    );
    let aic_conditional = -2.0 * log_lik + 2.0 * edf.conditional;
    let aic_corrected = -2.0 * log_lik + 2.0 * edf.corrected;
    let loo = alo.and_then(|alo| {
        // `alo_elpd_from_family` only needs a reference to the spec, so borrow
        // it instead of cloning it.
        let spec = fit.likelihood_family.as_ref()?;
        alo_elpd_from_family(
            y,
            eta_hat,
            alo.eta_tilde.view(),
            prior_weights,
            spec,
            fit.likelihood_scale.clone(),
        )
    });
    ModelComparison {
        log_lik,
        edf,
        aic_conditional,
        aic_corrected,
        loo,
    }
}
/// Computes an ALO ELPD summary from linear predictors under a likelihood family.
///
/// Both `eta_hat` (full fit) and `eta_loo` (ALO) are mapped through the family's
/// inverse link. Each is then scored against `y` and `prior_weights` with
/// `pointwise_loglikelihood_omitting_constants`, and the two pointwise
/// log-likelihood vectors are handed to [`alo_elpd`].
///
/// Returns `None` in any of these cases:
/// - `y` is empty;
/// - `eta_hat`, `eta_loo` or `prior_weights` differs in length from `y`;
/// - either inverse-link evaluation fails;
/// - `alo_elpd` returns `None`.
///
/// NOTE(review): judging by the callee's name, the pointwise values omit
/// normalising constants. ELPDs built here may therefore not be comparable across
/// families or dispersions; confirm before comparing such models.
pub fn alo_elpd_from_family(
    y: ArrayView1<'_, f64>,
    eta_hat: ArrayView1<'_, f64>,
    eta_loo: ArrayView1<'_, f64>,
    prior_weights: ArrayView1<'_, f64>,
    spec: &LikelihoodSpec,
    scale: crate::types::LikelihoodScaleMetadata,
) -> Option<AloElpd> {
    use crate::families::strategy::{FamilyStrategy, strategy_for_spec};
    use crate::pirls::pointwise_loglikelihood_omitting_constants;
    let n = y.len();
    let same_length = [eta_hat.len(), eta_loo.len(), prior_weights.len()]
        .iter()
        .all(|&len| len == n);
    if n == 0 || !same_length {
        return None;
    }
    let strategy = strategy_for_spec(spec);
    let fitted_mean = strategy.inverse_link_array(eta_hat).ok()?;
    let loo_mean = strategy.inverse_link_array(eta_loo).ok()?;
    let glm_spec = GlmLikelihoodSpec {
        spec: spec.clone(),
        scale,
    };
    let fitted_ll =
        pointwise_loglikelihood_omitting_constants(y, &fitted_mean, &glm_spec, prior_weights);
    let loo_ll =
        pointwise_loglikelihood_omitting_constants(y, &loo_mean, &glm_spec, prior_weights);
    alo_elpd(fitted_ll.view(), loo_ll.view())
}
#[cfg(test)]
mod tests {
use super::*;
use ndarray::{Array2, array};
// With G = I and C = diag(2, 4, 6), tr(G·C) = 12. Dividing by phi = 2 gives a
// correction of 6, so the corrected EDF is 3 + 6 = 9.
#[test]
fn wps_correction_is_trace_of_h_f_sigma_over_phi() {
let xwx = Array2::<f64>::eye(3);
let corr = array![[2.0, 0.0, 0.0], [0.0, 4.0, 0.0], [0.0, 0.0, 6.0]];
let edf = corrected_edf(3.0, Some(xwx.view()), Some(corr.view()), 2.0);
assert!((edf.corrected - 9.0).abs() < 1e-12);
assert!((edf.rho_uncertainty_df() - 6.0).abs() < 1e-12);
assert!((edf.conditional - 3.0).abs() < 1e-12);
}
// Missing matrices mean no correction: corrected == conditional.
#[test]
fn corrected_edf_falls_back_to_conditional_without_inputs() {
let edf = corrected_edf(5.5, None, None, 1.0);
assert_eq!(edf.conditional, 5.5);
assert_eq!(edf.corrected, 5.5);
assert_eq!(edf.rho_uncertainty_df(), 0.0);
}
// Identical fitted and ALO log-likelihoods make every ratio equal. The test
// expects no flagged observations and an elpd equal to the plain sum of inputs.
#[test]
fn alo_elpd_sums_pointwise_and_flags_no_tail() {
let ll: Array1<f64> = array![-1.0, -2.0, -0.5, -1.5, -0.8, -1.2, -0.9, -1.1, -0.7, -1.3];
let loo = alo_elpd(ll.view(), ll.view()).expect("alo elpd");
let expected: f64 = ll.iter().sum();
assert!((loo.elpd - expected).abs() < 1e-9);
assert_eq!(loo.pointwise.len(), ll.len());
assert_eq!(loo.n_k_bad, 0);
}
// Perturbing the fitted log-likelihood of one row affects only the diagnostic
// ratios; pointwise values and elpd must still equal the ALO input exactly.
#[test]
fn alo_elpd_pointwise_is_local_to_alo_loglikelihoods() {
let ll_loo: Array1<f64> = array![
-1.0, -1.1, -1.2, -1.3, -1.4, -1.5, -1.6, -1.7, -1.8, -1.9, -2.0, -2.1
];
let ll_hat = ll_loo.clone();
let mut ll_hat_perturbed = ll_loo.clone();
ll_hat_perturbed[7] += 10.0;
let base = alo_elpd(ll_hat.view(), ll_loo.view()).expect("alo elpd");
let perturbed = alo_elpd(ll_hat_perturbed.view(), ll_loo.view()).expect("alo elpd");
for i in 0..ll_loo.len() {
assert_eq!(base.pointwise[i], ll_loo[i]);
assert_eq!(perturbed.pointwise[i], ll_loo[i]);
if i != 7 {
assert_eq!(base.pointwise[i], perturbed.pointwise[i]);
}
}
assert_eq!(perturbed.elpd, base.elpd);
}
// Quantile function of a generalized Pareto distribution with shape `k` and
// scale `sigma`, evaluated at probability `u` (requires `k != 0`).
fn gpd_sample(u: f64, k: f64, sigma: f64) -> f64 {
sigma * ((1.0 - u).powf(-k) - 1.0) / k
}
// 200 unit ratios plus 120 deterministic GPD quantiles (shape 1.2) form a heavy
// right tail of fitted/ALO ratios. The test expects this to push k_hat above 0.7
// and flag tail observations.
#[test]
fn alo_elpd_influence_diagnostic_fires_on_heavy_tailed_ratios() {
let mut ratios = vec![1.0; 200];
for i in 1..=120 {
let u = (i as f64 - 0.5) / 120.0;
ratios.push(1.0 + gpd_sample(u, 1.2, 0.5));
}
let ll_loo: Array1<f64> = Array1::from_elem(ratios.len(), -1.0);
// ll_hat - ll_loo = ln(ratio) (up to rounding), so the log ratios seen by
// `alo_elpd` encode `ratios`.
let ll_hat: Array1<f64> = Array1::from_iter(
ll_loo
.iter()
.zip(ratios.iter())
.map(|(&ll, &ratio)| ll + ratio.ln()),
);
let loo = alo_elpd(ll_hat.view(), ll_loo.view()).expect("alo elpd");
assert_eq!(loo.pointwise, ll_loo);
assert!((loo.elpd - -(ratios.len() as f64)).abs() < 1e-12);
assert!(
loo.k_hat_max > 0.7,
"heavy fitted-vs-ALO ratio tail should fire influence diagnostic; got k_hat={}",
loo.k_hat_max
);
assert!(
loo.n_k_bad > 0,
"heavy fitted-vs-ALO ratio tail should count influential tail observations"
);
}
// Pointwise -1 vs -2 over 4 rows gives delta_elpd = 4. The constant differences
// give zero SE, and delta AIC = 10 - 14 = -4 (a - b orientation).
#[test]
fn compare_pairs_pointwise_and_orients_a_minus_b() {
let mk = |pw: Array1<f64>, aic: f64| ModelComparison {
log_lik: 0.0,
edf: CorrectedEdf {
conditional: 0.0,
corrected: 0.0,
},
aic_conditional: aic,
aic_corrected: aic,
loo: Some(AloElpd {
elpd: pw.iter().sum(),
se: 0.0,
pointwise: pw,
k_hat_max: 0.1,
n_k_bad: 0,
}),
};
let a = mk(array![-1.0, -1.0, -1.0, -1.0], 10.0);
let b = mk(array![-2.0, -2.0, -2.0, -2.0], 14.0);
let rep = compare(&a, &b);
assert!(rep.rows_aligned);
assert!((rep.delta_elpd - 4.0).abs() < 1e-12);
assert!((rep.delta_aic_corrected + 4.0).abs() < 1e-12);
assert!(rep.delta_elpd_se.abs() < 1e-12);
}
// Pointwise lengths that differ (3 vs 2) must not be paired, so delta_elpd is
// NaN. The AIC difference is still reported.
#[test]
fn compare_refuses_unpaired_rows() {
let mk = |pw: Array1<f64>| ModelComparison {
log_lik: 0.0,
edf: CorrectedEdf {
conditional: 0.0,
corrected: 0.0,
},
aic_conditional: 0.0,
aic_corrected: 5.0,
loo: Some(AloElpd {
elpd: pw.iter().sum(),
se: 0.0,
pointwise: pw,
k_hat_max: 0.1,
n_k_bad: 0,
}),
};
let a = mk(array![-1.0, -1.0, -1.0]);
let b = mk(array![-1.0, -1.0]);
let rep = compare(&a, &b);
assert!(!rep.rows_aligned);
assert!(rep.delta_elpd.is_nan());
assert!((rep.delta_aic_corrected - 0.0).abs() < 1e-12);
}
}