use ndarray::{Array2, ArrayView2};
use serde::{Deserialize, Serialize};
use std::cmp::Ordering;
use crate::families::inverse_link::apply_inverse_link_vec;
use crate::util::quantile::quantile_from_sorted;
/// Serializable summary of posterior predictive bands, as assembled by
/// `posterior_eta_bands` from the output of `eta_bands_from_matrix`.
///
/// The linear-predictor vectors hold one entry per column of the draws matrix
/// (`n_rows` entries); the response-scale bounds are whatever
/// `apply_inverse_link_vec` returns for those vectors.
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct PosteriorPredictBandsPayload {
/// Per-row average of the linear-predictor draws.
pub linear_predictor: Vec<f64>,
/// Per-row lower quantile of the linear-predictor draws, taken at probability
/// `(1 - level) / 2`.
pub linear_predictor_lower: Vec<f64>,
/// Per-row upper quantile of the linear-predictor draws, taken at probability
/// `1 - (1 - level) / 2`.
pub linear_predictor_upper: Vec<f64>,
/// Per-row average of the inverse-link-transformed draws (the transform is
/// applied to each draw before averaging, not to the averaged linear predictor).
pub mean: Vec<f64>,
/// Inverse link applied to `linear_predictor_lower`.
pub mean_lower: Vec<f64>,
/// Inverse link applied to `linear_predictor_upper`.
pub mean_upper: Vec<f64>,
/// Number of columns of the draws matrix, as passed to `posterior_eta_bands`.
pub n_rows: usize,
/// Number of rows of the draws matrix (posterior draws), as passed to
/// `posterior_eta_bands`.
pub n_draws: usize,
/// Left as an empty string by `posterior_eta_bands`.
/// NOTE(review): presumably filled in by the caller afterwards — confirm.
pub model_class: String,
/// Copy of the `family_kind` string that selected the inverse link.
pub family_kind: String,
}
/// Orders two draws for quantile extraction, placing NaN after every number.
///
/// `f64::partial_cmp` yields `None` whenever a NaN is involved. Mapping that to
/// `Ordering::Equal` does not give a total order, and `slice::sort_by` may
/// panic on such comparators. Here all NaNs compare equal to each other and
/// greater than any number, which is a consistent ordering; for NaN-free input
/// the result is exactly that of `partial_cmp`.
fn nan_last_cmp(a: &f64, b: &f64) -> Ordering {
    a.partial_cmp(b).unwrap_or_else(|| match (a.is_nan(), b.is_nan()) {
        (true, true) => Ordering::Equal,
        (true, false) => Ordering::Greater,
        // `partial_cmp` only fails when a NaN is involved, so `b` is the NaN here.
        _ => Ordering::Less,
    })
}

/// Reduces a matrix of linear-predictor draws to pointwise means and bands.
///
/// `eta` is read with draws along its rows and prediction rows along its
/// columns: column `j` holds every draw of the linear predictor for row `j`.
///
/// On success the tuple contains, in order:
/// 1. the per-column mean of the draws,
/// 2. the per-column quantile at probability `alpha = (1 - level) / 2`,
/// 3. the per-column quantile at probability `1 - alpha`,
/// 4. the per-column mean of the inverse-link-transformed draws,
/// 5. the inverse link applied to (2),
/// 6. the inverse link applied to (3).
///
/// NOTE(review): (5) and (6) are ordered as lower/upper only if the inverse
/// link selected by `family_kind` is non-decreasing — confirm for every
/// supported family.
///
/// Draws containing NaN no longer risk a sort panic: NaNs are ordered after
/// all numbers, so they propagate into the means (and possibly the quantiles)
/// instead.
///
/// # Errors
///
/// Returns an error when `level` is not strictly between 0 and 1 (NaN
/// included), when `eta` has no draws, or when `apply_inverse_link_vec` fails.
pub fn eta_bands_from_matrix(
    eta: ArrayView2<'_, f64>,
    family_kind: &str,
    level: f64,
) -> Result<(Vec<f64>, Vec<f64>, Vec<f64>, Vec<f64>, Vec<f64>, Vec<f64>), String> {
    // Phrased as a negated conjunction so that a NaN `level` is rejected as well.
    if !(level > 0.0 && level < 1.0) {
        return Err(format!("interval level must lie in (0, 1); got {level}"));
    }
    let n_draws = eta.nrows();
    let n_rows = eta.ncols();
    if n_draws == 0 {
        return Err("posterior bands unavailable: zero draws".to_string());
    }
    let alpha = (1.0 - level) / 2.0;
    let inv_n = 1.0 / n_draws as f64;

    let mut eta_mean: Vec<f64> = Vec::with_capacity(n_rows);
    let mut eta_lower: Vec<f64> = Vec::with_capacity(n_rows);
    let mut eta_upper: Vec<f64> = Vec::with_capacity(n_rows);
    let mut response_mean: Vec<f64> = Vec::with_capacity(n_rows);
    // Scratch buffer reused for every column; it is overwritten in full at the
    // top of each iteration, so the in-place sort below never leaks across columns.
    let mut column = vec![0.0_f64; n_draws];

    for j in 0..n_rows {
        let draws = eta.column(j);
        for (slot, &value) in column.iter_mut().zip(draws.iter()) {
            *slot = value;
        }
        eta_mean.push(column.iter().fold(0.0_f64, |acc, &v| acc + v) * inv_n);

        // The response mean averages the transformed draws; it must be taken
        // before the buffer is sorted only for clarity — the sum does not
        // depend on draw identity, but the link is applied to the raw draws.
        let response_draws = apply_inverse_link_vec(&column, family_kind)?;
        response_mean.push(response_draws.iter().fold(0.0_f64, |acc, &v| acc + v) * inv_n);

        // Stable sort with a total order (NaN last) ahead of the quantile lookups.
        column.sort_by(nan_last_cmp);
        eta_lower.push(quantile_from_sorted(&column, alpha));
        eta_upper.push(quantile_from_sorted(&column, 1.0 - alpha));
    }

    // Response-scale bounds are the transformed linear-predictor quantiles.
    let mean_lower = apply_inverse_link_vec(&eta_lower, family_kind)?;
    let mean_upper = apply_inverse_link_vec(&eta_upper, family_kind)?;
    Ok((
        eta_mean,
        eta_lower,
        eta_upper,
        response_mean,
        mean_lower,
        mean_upper,
    ))
}
/// Builds a [`PosteriorPredictBandsPayload`] from a flat buffer of
/// linear-predictor draws.
///
/// `eta_flat` is reshaped to `(n_draws, n_rows)` with `Array2::from_shape_vec`,
/// which per the ndarray documentation reads the vector in row-major order, so
/// the values of one draw are expected to be contiguous. The matrix is then
/// summarised by `eta_bands_from_matrix` using `family_kind` and `level`.
///
/// The payload's `model_class` is left as an empty string.
///
/// # Errors
///
/// Returns an error when `eta_flat.len()` differs from `n_draws * n_rows`
/// (including when that product does not fit in `usize`), when the reshape
/// fails, or when `eta_bands_from_matrix` fails.
pub fn posterior_eta_bands(
    eta_flat: Vec<f64>,
    n_draws: usize,
    n_rows: usize,
    family_kind: &str,
    level: f64,
) -> Result<PosteriorPredictBandsPayload, String> {
    // `checked_mul` keeps absurd dimensions from panicking (debug) or wrapping
    // (release); an overflowing product can never equal a real buffer length,
    // so it is reported as the same shape mismatch.
    if n_draws.checked_mul(n_rows) != Some(eta_flat.len()) {
        return Err(format!(
            "posterior_eta_bands shape mismatch: got {} floats, expected {} * {}",
            eta_flat.len(),
            n_draws,
            n_rows
        ));
    }
    let eta = Array2::<f64>::from_shape_vec((n_draws, n_rows), eta_flat)
        .map_err(|err| format!("failed to reshape eta matrix: {err}"))?;
    let (eta_mean, eta_lower, eta_upper, mean, mean_lower, mean_upper) =
        eta_bands_from_matrix(eta.view(), family_kind, level)?;
    Ok(PosteriorPredictBandsPayload {
        linear_predictor: eta_mean,
        linear_predictor_lower: eta_lower,
        linear_predictor_upper: eta_upper,
        mean,
        mean_lower,
        mean_upper,
        n_rows,
        n_draws,
        model_class: String::new(),
        family_kind: family_kind.to_string(),
    })
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::util::quantile::quantile_from_sorted;
    use ndarray::Array2;

    /// Recomputes every band by hand for a small 5-draw, 2-row matrix under the
    /// "log" family and compares against `eta_bands_from_matrix`.
    #[test]
    fn eta_bands_match_shared_quantile_and_response_mean_semantics() {
        const N_DRAWS: usize = 5;
        const N_ROWS: usize = 2;
        // One line per draw, one entry per row.
        let draws = vec![
            -2.0, 0.0, //
            -1.0, 0.5, //
            0.0, 1.0, //
            1.0, 1.5, //
            2.0, 4.0, //
        ];
        let eta = Array2::from_shape_vec((N_DRAWS, N_ROWS), draws).expect("shape");
        let level = 0.80;
        let alpha = (1.0 - level) / 2.0;
        let inv_n = 1.0 / 5.0;

        let (eta_mean, eta_lower, eta_upper, mean, mean_lower, mean_upper) =
            eta_bands_from_matrix(eta.view(), "log", level).expect("bands");

        for j in 0..N_ROWS {
            let mut col: Vec<f64> = eta.column(j).iter().copied().collect();

            // Linear-predictor mean is the plain average of the draws.
            let mean_eta: f64 = col.iter().sum::<f64>() * inv_n;
            let eta_mean_gap = (eta_mean[j] - mean_eta).abs();
            assert!(eta_mean_gap < 1e-12, "eta mean mismatch col {j}");

            // Response mean averages exp(draw), not exp(average draw).
            let resp_mean: f64 = col.iter().map(|e| e.exp()).sum::<f64>() * inv_n;
            let resp_mean_gap = (mean[j] - resp_mean).abs();
            assert!(
                resp_mean_gap < 1e-12,
                "response mean must be E[g^-1(eta)] for col {j}"
            );

            // Bands must agree with the shared quantile helper on sorted draws.
            col.sort_by(|a, b| a.partial_cmp(b).expect("finite"));
            let lo = quantile_from_sorted(&col, alpha);
            let hi = quantile_from_sorted(&col, 1.0 - alpha);
            let lower_gap = (eta_lower[j] - lo).abs();
            let upper_gap = (eta_upper[j] - hi).abs();
            assert!(
                lower_gap < 1e-12,
                "lower band must be shared linear quantile col {j}"
            );
            assert!(
                upper_gap < 1e-12,
                "upper band must be shared linear quantile col {j}"
            );

            // For this fixture both means fall inside their bands.
            let eta_inside = eta_lower[j] <= eta_mean[j] && eta_mean[j] <= eta_upper[j];
            assert!(
                eta_inside,
                "eta mean must sit inside nonzero interval col {j}"
            );
            let response_inside = mean_lower[j] <= mean[j] && mean[j] <= mean_upper[j];
            assert!(
                response_inside,
                "response mean must sit inside nonzero interval col {j}"
            );
        }

        // Jensen gap on the second row: averaging after exp beats exp of the average.
        let mean_eta_col1: f64 = eta.column(1).iter().sum::<f64>() / 5.0;
        assert!(
            mean[1] > mean_eta_col1.exp() + 1e-9,
            "E[exp(eta)] must exceed exp(E[eta]) for the convex link"
        );
    }

    /// A matrix with zero draws has nothing to summarise and must be rejected.
    #[test]
    fn empty_draws_reject_posterior_bands() {
        let no_draws = Array2::<f64>::zeros((0, 3));
        let outcome = eta_bands_from_matrix(no_draws.view(), "identity", 0.95);
        let err = outcome.expect_err("zero draws must fail");
        assert!(err.contains("zero draws"));
    }

    /// With a single draw of 705 under the "log" family, every response-scale
    /// output must equal `exp(705)` exactly rather than `exp(700)`.
    #[test]
    fn response_bands_use_exact_log_inverse_link_not_solver_clamp() {
        let single_draw = Array2::from_shape_vec((1, 1), vec![705.0]).expect("shape");
        let (_eta_mean, _eta_lower, _eta_upper, mean, mean_lower, mean_upper) =
            eta_bands_from_matrix(single_draw.view(), "log", 0.90).expect("bands");

        let exact = 705.0_f64.exp();
        assert!(exact.is_finite(), "exp(705) must be representable in f64");
        let clamped = 700.0_f64.exp();

        let outputs = [
            ("mean", mean[0]),
            ("mean_lower", mean_lower[0]),
            ("mean_upper", mean_upper[0]),
        ];
        for (label, v) in outputs {
            assert_eq!(
                v, exact,
                "{label} must be exact exp(705), not the solver-clamped exp(700)"
            );
            assert!(
                v > clamped * 100.0,
                "{label} must exceed the clamped exp(700) by ~exp(5); got {v} vs {clamped}"
            );
        }
    }

    /// Both endpoints of the unit interval are invalid interval levels.
    #[test]
    fn level_must_lie_in_open_unit_interval() {
        let eta = Array2::<f64>::zeros((4, 2));
        for level in [0.0, 1.0] {
            assert!(eta_bands_from_matrix(eta.view(), "identity", level).is_err());
        }
    }
}