use crate::estimate::UnifiedFitResult;
use crate::inference::alo::AloDiagnostics;
use crate::inference::psis::pareto_smooth_weights;
use crate::types::{GlmLikelihoodSpec, LikelihoodSpec};
use ndarray::{Array1, ArrayView1, ArrayView2};
/// Leave-one-out summary as filled in by `psis_loo`.
#[derive(Debug, Clone)]
pub struct PsisLoo {
/// Sum of the entries of `pointwise`.
pub elpd: f64,
/// Standard error attached to `elpd`; `psis_loo` computes it as
/// `sqrt(n * var)` where `var` is the population (divide-by-`n`) variance
/// of `pointwise`.
pub se: f64,
/// One contribution per observation, in input order.
pub pointwise: Array1<f64>,
/// Shape estimate (`k_hat`) reported by the Pareto smoother; `psis_loo`
/// stores `NaN` here when the smoother returned nothing.
pub k_hat_max: f64,
/// Tail count reported by the smoother when its `k_hat` exceeds 0.7,
/// otherwise 0.
pub n_k_bad: usize,
}
/// Effective degrees of freedom before and after an additive correction.
#[derive(Debug, Clone, Copy)]
pub struct CorrectedEdf {
    /// EDF as supplied by the caller, without any correction.
    pub conditional: f64,
    /// `conditional` plus the correction term.
    pub corrected: f64,
}

impl CorrectedEdf {
    /// Degrees of freedom contributed by the correction alone, i.e. the gap
    /// `corrected - conditional`.
    pub fn rho_uncertainty_df(&self) -> f64 {
        let Self {
            conditional,
            corrected,
        } = *self;
        corrected - conditional
    }
}
/// Per-model summary consumed by `compare`.
#[derive(Debug, Clone)]
pub struct ModelComparison {
/// Log-likelihood of the fitted model.
pub log_lik: f64,
/// Conditional and corrected effective degrees of freedom.
pub edf: CorrectedEdf,
/// `-2 * log_lik + 2 * edf.conditional` when built by
/// `model_comparison_from_unified`.
pub aic_conditional: f64,
/// `-2 * log_lik + 2 * edf.corrected` when built by
/// `model_comparison_from_unified`.
pub aic_corrected: f64,
/// Leave-one-out summary; `None` when it was not or could not be computed.
pub loo: Option<PsisLoo>,
}
/// Builds a [`CorrectedEdf`] by adding the `wps_correction_term` result to
/// `edf_conditional`.
///
/// The three matrices and `phi` are forwarded unchanged to
/// `wps_correction_term`; when that helper yields `0.0` (for instance because
/// one of the matrices is `None`), `corrected` equals `conditional`.
pub fn corrected_edf(
    edf_conditional: f64,
    penalized_hessian: Option<ArrayView2<'_, f64>>,
    coefficient_influence: Option<ArrayView2<'_, f64>>,
    smoothing_correction: Option<ArrayView2<'_, f64>>,
    phi: f64,
) -> CorrectedEdf {
    let corrected = edf_conditional
        + wps_correction_term(
            penalized_hessian,
            coefficient_influence,
            smoothing_correction,
            phi,
        );
    CorrectedEdf {
        conditional: edf_conditional,
        corrected,
    }
}
/// Computes `trace(H · F · (C / phi))` for the penalized Hessian `H`, the
/// coefficient-influence matrix `F` and the smoothing-correction matrix `C`.
///
/// Returns `0.0` instead of failing when:
/// - any of the three matrices is `None`;
/// - the matrices are not all `k × k` with `k > 0` (`k` = rows of `H`);
/// - `phi` is not a finite, strictly positive number;
/// - the resulting trace is not finite.
fn wps_correction_term(
    penalized_hessian: Option<ArrayView2<'_, f64>>,
    coefficient_influence: Option<ArrayView2<'_, f64>>,
    smoothing_correction: Option<ArrayView2<'_, f64>>,
    phi: f64,
) -> f64 {
    let (hessian, influence, correction) =
        match (penalized_hessian, coefficient_influence, smoothing_correction) {
            (Some(h), Some(f), Some(c)) => (h, f, c),
            _ => return 0.0,
        };

    let dim = hessian.nrows();
    let is_dim_square = |rows: usize, cols: usize| rows == dim && cols == dim;
    let shapes_ok = dim != 0
        && hessian.ncols() == dim
        && is_dim_square(influence.nrows(), influence.ncols())
        && is_dim_square(correction.nrows(), correction.ncols());
    let phi_ok = phi.is_finite() && phi > 0.0;
    if !(shapes_ok && phi_ok) {
        return 0.0;
    }

    // Scale the correction matrix by the dispersion, then form F · (C / phi).
    let scaled = correction.mapv(|value| value / phi);
    let product = influence.dot(&scaled);

    // trace(H · M) = Σ_i Σ_j H[i, j] · M[j, i]; accumulate row-major over (i, j).
    let trace = (0..dim)
        .flat_map(|i| (0..dim).map(move |j| (i, j)))
        .fold(0.0, |acc, (i, j)| acc + hessian[[i, j]] * product[[j, i]]);

    if !trace.is_finite() {
        return 0.0;
    }
    trace
}
/// Builds a [`PsisLoo`] summary from per-observation log-likelihoods at the
/// fitted values (`loglik_fitted`) and at the leave-one-out values
/// (`loglik_loo`).
///
/// Returns `None` when the inputs are empty, differ in length, contain a
/// non-finite entry, or when the largest log ratio
/// (`loglik_fitted - loglik_loo`) is not finite.
///
/// The log ratios are shifted by their maximum and exponentiated before being
/// handed to `pareto_smooth_weights`. If smoothing yields a result, each
/// pointwise value is `loglik_loo[i] + ln(smoothed[i] / raw[i])` (the shift is
/// dropped when either weight is not strictly positive). If it yields nothing,
/// the pointwise values are `loglik_loo` unchanged, `k_hat_max` is `NaN` and
/// `n_k_bad` is 0.
pub fn psis_loo(
    loglik_fitted: ArrayView1<'_, f64>,
    loglik_loo: ArrayView1<'_, f64>,
) -> Option<PsisLoo> {
    let n = loglik_loo.len();
    if n == 0 || loglik_fitted.len() != n {
        return None;
    }
    let all_finite = loglik_fitted
        .iter()
        .chain(loglik_loo.iter())
        .all(|value| value.is_finite());
    if !all_finite {
        return None;
    }

    let log_ratio: Vec<f64> = loglik_fitted
        .iter()
        .zip(loglik_loo.iter())
        .map(|(&fitted, &held_out)| fitted - held_out)
        .collect();
    // Two finite inputs can still overflow on subtraction, hence the re-check.
    let max_lr = log_ratio.iter().copied().fold(f64::NEG_INFINITY, f64::max);
    if !max_lr.is_finite() {
        return None;
    }
    // Subtracting the maximum keeps every exponent <= 0, so `exp` cannot overflow.
    let raw: Vec<f64> = log_ratio.iter().map(|&lr| (lr - max_lr).exp()).collect();

    let (pointwise, k_hat_max, n_k_bad) = match pareto_smooth_weights(&raw) {
        Some(psis) => {
            let adjusted: Array1<f64> = (0..n)
                .map(|i| {
                    let (unsmoothed, smoothed) = (raw[i], psis.smoothed[i]);
                    let shift = if unsmoothed > 0.0 && smoothed > 0.0 {
                        (smoothed / unsmoothed).ln()
                    } else {
                        0.0
                    };
                    loglik_loo[i] + shift
                })
                .collect();
            let flagged = if psis.k_hat > 0.7 {
                psis.tail_count
            } else {
                0
            };
            (adjusted, psis.k_hat, flagged)
        }
        None => (loglik_loo.to_owned(), f64::NAN, 0),
    };

    let count = n as f64;
    let elpd = pointwise.iter().sum::<f64>();
    let mean = elpd / count;
    let sum_sq_dev = pointwise
        .iter()
        .map(|&value| {
            let dev = value - mean;
            dev * dev
        })
        .sum::<f64>();
    let var = sum_sq_dev / count;
    let se = (count * var).sqrt();

    Some(PsisLoo {
        elpd,
        se,
        pointwise,
        k_hat_max,
        n_k_bad,
    })
}
/// Outcome of `compare(a, b)`; every delta is oriented as `a - b`.
#[derive(Debug, Clone)]
pub struct ComparisonReport {
/// Sum of the element-wise differences of the two pointwise LOO vectors;
/// `NaN` when the vectors could not be paired.
pub delta_elpd: f64,
/// `sqrt(n * var)` of those differences (population variance); `NaN` when
/// the vectors could not be paired.
pub delta_elpd_se: f64,
/// Difference of the two `aic_corrected` values; always computed.
pub delta_aic_corrected: f64,
/// `true` only when both models carry a LOO summary and their pointwise
/// vectors have the same non-zero length.
pub rows_aligned: bool,
}
/// Compares model `a` against model `b`, reporting every delta as `a - b`.
///
/// The corrected-AIC difference is always reported. The ELPD difference and
/// its standard error are only computed when both models carry a LOO summary
/// whose pointwise vectors are non-empty and of equal length; otherwise they
/// are `NaN` and `rows_aligned` is `false`.
pub fn compare(a: &ModelComparison, b: &ModelComparison) -> ComparisonReport {
    let delta_aic_corrected = a.aic_corrected - b.aic_corrected;

    let paired = a
        .loo
        .as_ref()
        .zip(b.loo.as_ref())
        .filter(|(la, lb)| {
            !la.pointwise.is_empty() && la.pointwise.len() == lb.pointwise.len()
        });
    let Some((la, lb)) = paired else {
        return ComparisonReport {
            delta_elpd: f64::NAN,
            delta_elpd_se: f64::NAN,
            delta_aic_corrected,
            rows_aligned: false,
        };
    };

    let count = la.pointwise.len() as f64;
    let diff: Vec<f64> = la
        .pointwise
        .iter()
        .zip(lb.pointwise.iter())
        .map(|(&lhs, &rhs)| lhs - rhs)
        .collect();
    let delta_elpd = diff.iter().sum::<f64>();
    let mean = delta_elpd / count;
    let sum_sq_dev = diff
        .iter()
        .map(|&d| {
            let dev = d - mean;
            dev * dev
        })
        .sum::<f64>();
    let var = sum_sq_dev / count;

    ComparisonReport {
        delta_elpd,
        delta_elpd_se: (count * var).sqrt(),
        delta_aic_corrected,
        rows_aligned: true,
    }
}
/// Assembles a [`ModelComparison`] from a unified fit result.
///
/// - `log_lik` is read from `fit.log_likelihood`.
/// - The conditional EDF is `fit.edf_total()`, or `NaN` when that is absent;
///   the corrected EDF comes from `corrected_edf` using the fit's penalized
///   Hessian, coefficient influence, smoothing correction and dispersion.
/// - Both AIC variants are `-2 * log_lik + 2 * edf`.
/// - `loo` is filled only when `alo` is provided, the fit exposes a likelihood
///   family, and `psis_loo_from_family` succeeds on `y`, `eta_hat`,
///   `alo.eta_tilde` and `prior_weights`.
pub fn model_comparison_from_unified(
    fit: &UnifiedFitResult,
    y: ArrayView1<'_, f64>,
    eta_hat: ArrayView1<'_, f64>,
    prior_weights: ArrayView1<'_, f64>,
    alo: Option<&AloDiagnostics>,
) -> ModelComparison {
    let log_lik = fit.log_likelihood;
    let phi = fit.dispersion_phi();
    let edf = corrected_edf(
        fit.edf_total().unwrap_or(f64::NAN),
        fit.penalized_hessian().map(|hessian| hessian.view()),
        fit.coefficient_influence().map(|influence| influence.view()),
        fit.smoothing_correction().map(|correction| correction.view()),
        phi,
    );

    let aic = |df: f64| -2.0 * log_lik + 2.0 * df;
    let aic_conditional = aic(edf.conditional);
    let aic_corrected = aic(edf.corrected);

    let loo = alo.and_then(|diagnostics| {
        fit.likelihood_family.clone().and_then(|spec| {
            psis_loo_from_family(
                y,
                eta_hat,
                diagnostics.eta_tilde.view(),
                prior_weights,
                &spec,
                fit.likelihood_scale.clone(),
            )
        })
    });

    ModelComparison {
        log_lik,
        edf,
        aic_conditional,
        aic_corrected,
        loo,
    }
}
/// Runs `psis_loo` on pointwise log-likelihoods derived from linear
/// predictors.
///
/// `eta_hat` and `eta_loo` are mapped through the inverse link of the strategy
/// selected for `spec`; the pointwise log-likelihood (constants omitted) is
/// then evaluated at both sets of means using `y`, `prior_weights` and a
/// `GlmLikelihoodSpec` built from `spec` and `scale`.
///
/// Returns `None` when `y` is empty, when any input length differs from
/// `y.len()`, when either inverse-link evaluation fails, or when `psis_loo`
/// itself returns `None`.
pub fn psis_loo_from_family(
    y: ArrayView1<'_, f64>,
    eta_hat: ArrayView1<'_, f64>,
    eta_loo: ArrayView1<'_, f64>,
    prior_weights: ArrayView1<'_, f64>,
    spec: &LikelihoodSpec,
    scale: crate::types::LikelihoodScaleMetadata,
) -> Option<PsisLoo> {
    use crate::families::strategy::{FamilyStrategy, strategy_for_spec};
    use crate::pirls::pointwise_loglikelihood_omitting_constants;

    let n = y.len();
    let length_mismatch = [eta_hat.len(), eta_loo.len(), prior_weights.len()]
        .iter()
        .any(|&len| len != n);
    if length_mismatch || n == 0 {
        return None;
    }

    let strategy = strategy_for_spec(spec);
    let mu_hat = strategy.inverse_link_array(eta_hat).ok()?;
    let mu_loo = strategy.inverse_link_array(eta_loo).ok()?;

    let glm = GlmLikelihoodSpec {
        spec: spec.clone(),
        scale,
    };
    let fitted_ll = pointwise_loglikelihood_omitting_constants(y, &mu_hat, &glm, prior_weights);
    let held_out_ll = pointwise_loglikelihood_omitting_constants(y, &mu_loo, &glm, prior_weights);

    psis_loo(fitted_ll.view(), held_out_ll.view())
}
// Unit tests for the corrected-EDF, PSIS-LOO and model-comparison helpers.
#[cfg(test)]
mod tests {
use super::*;
use ndarray::{Array2, array};
// With H = I, F = I, C = diag(2, 4, 6) and phi = 2 the correction is
// trace(diag(1, 2, 3)) = 6, so a conditional EDF of 3 becomes 9.
#[test]
fn wps_correction_is_trace_of_h_f_sigma_over_phi() {
let h = Array2::<f64>::eye(3);
let f = Array2::<f64>::eye(3);
let corr = array![[2.0, 0.0, 0.0], [0.0, 4.0, 0.0], [0.0, 0.0, 6.0]];
let phi = 2.0;
let edf = corrected_edf(3.0, Some(h.view()), Some(f.view()), Some(corr.view()), phi);
assert!((edf.corrected - 9.0).abs() < 1e-12);
assert!((edf.rho_uncertainty_df() - 6.0).abs() < 1e-12);
assert!((edf.conditional - 3.0).abs() < 1e-12);
}
// Without any matrices the correction is zero and both EDF values coincide.
#[test]
fn corrected_edf_falls_back_to_conditional_without_inputs() {
let edf = corrected_edf(5.5, None, None, None, 1.0);
assert_eq!(edf.conditional, 5.5);
assert_eq!(edf.corrected, 5.5);
assert_eq!(edf.rho_uncertainty_df(), 0.0);
}
// Identical fitted and LOO log-likelihoods: the test expects the ELPD to be
// the plain sum of the inputs and no observations flagged as bad.
#[test]
fn psis_loo_elpd_sums_pointwise_and_flags_no_tail() {
let ll: Array1<f64> = array![-1.0, -2.0, -0.5, -1.5, -0.8, -1.2, -0.9, -1.1, -0.7, -1.3];
let loo = psis_loo(ll.view(), ll.view()).expect("psis-loo");
let expected: f64 = ll.iter().sum();
assert!((loo.elpd - expected).abs() < 1e-9);
assert_eq!(loo.pointwise.len(), ll.len());
assert_eq!(loo.n_k_bad, 0);
}
// A uniform drop from -1.0 (fitted) to -1.5 (LOO) on every row: the ELPD is
// expected to be the sum of the LOO values, -1.5 * n.
#[test]
fn psis_loo_penalizes_loo_drop_relative_to_fitted() {
let n = 40;
let ll_hat: Array1<f64> = Array1::from_elem(n, -1.0);
let ll_loo: Array1<f64> = Array1::from_elem(n, -1.5);
let loo = psis_loo(ll_hat.view(), ll_loo.view()).expect("psis-loo");
assert!((loo.elpd - (-1.5 * n as f64)).abs() < 1e-6);
}
// Equal-length pointwise vectors are differenced as a - b: four rows of
// (-1) - (-2) give delta_elpd = 4, the AIC delta is 10 - 14 = -4, and the
// standard error is zero because every difference is the same.
#[test]
fn compare_pairs_pointwise_and_orients_a_minus_b() {
// Builds a ModelComparison whose only meaningful fields are the pointwise
// LOO vector and the AIC value.
let mk = |pw: Array1<f64>, aic: f64| ModelComparison {
log_lik: 0.0,
edf: CorrectedEdf {
conditional: 0.0,
corrected: 0.0,
},
aic_conditional: aic,
aic_corrected: aic,
loo: Some(PsisLoo {
elpd: pw.iter().sum(),
se: 0.0,
pointwise: pw,
k_hat_max: 0.1,
n_k_bad: 0,
}),
};
let a = mk(array![-1.0, -1.0, -1.0, -1.0], 10.0);
let b = mk(array![-2.0, -2.0, -2.0, -2.0], 14.0);
let rep = compare(&a, &b);
assert!(rep.rows_aligned);
assert!((rep.delta_elpd - 4.0).abs() < 1e-12);
assert!((rep.delta_aic_corrected + 4.0).abs() < 1e-12);
assert!(rep.delta_elpd_se.abs() < 1e-12);
}
// Pointwise vectors of length 3 and 2 cannot be paired: the ELPD delta is
// NaN and rows_aligned is false, while the AIC delta (5 - 5 = 0) is still
// reported.
#[test]
fn compare_refuses_unpaired_rows() {
// Builds a ModelComparison with a fixed corrected AIC of 5.0.
let mk = |pw: Array1<f64>| ModelComparison {
log_lik: 0.0,
edf: CorrectedEdf {
conditional: 0.0,
corrected: 0.0,
},
aic_conditional: 0.0,
aic_corrected: 5.0,
loo: Some(PsisLoo {
elpd: pw.iter().sum(),
se: 0.0,
pointwise: pw,
k_hat_max: 0.1,
n_k_bad: 0,
}),
};
let a = mk(array![-1.0, -1.0, -1.0]);
let b = mk(array![-1.0, -1.0]);
let rep = compare(&a, &b);
assert!(!rep.rows_aligned);
assert!(rep.delta_elpd.is_nan());
assert!((rep.delta_aic_corrected - 0.0).abs() < 1e-12);
}
}