pub mod input;
pub mod interval_policy;
pub mod linalg;
use crate::estimate::{BlockRole, EstimationError, FittedLinkState, UnifiedFitResult};
use crate::families::bms::{EmpiricalZGrid, LatentMeasureKind};
use crate::families::bms::{bernoulli_marginal_link_map, empirical_intercept_from_marginal};
use crate::families::family_runtime::{
FamilyStrategy, ResolvedFamilyStrategy, strategy_for_family, strategy_for_spec,
strategy_from_fit,
};
use crate::families::marginal_slope_shared::{
ObservedDenestedCellPartials, eval_coeff4_at,
probit_frailty_scale as marginal_slope_probit_frailty_scale, scale_coeff4,
};
use crate::families::survival::lognormal_kernel::FrailtySpec;
use crate::inference::model::{
SavedCompiledFlexBlock, SavedLatentZNormalization, SavedLinkWiggleRuntime,
};
use crate::inference::predict::interval_policy::{
EtaInterval, LinearState, MeanBoundMethod, PredictPass, PredictionTransform, ResponseBounds,
ResponseInterval, assemble_posterior_mean_bounds, predict_full_uncertainty_generic,
predict_plugin_response_generic, predict_posterior_mean_generic,
predict_with_uncertainty_generic,
};
use crate::inference::predict::linalg::{
PredictionCovarianceBackend, design_row_chunk, prediction_chunk_rows,
rowwise_local_covariances_parallel,
};
use crate::linalg::utils::predict_gam_dimension_mismatch_message;
use crate::matrix::{DesignMatrix, SymmetricMatrix};
use crate::mixture_link::{
InverseLinkJet, beta_logistic_inverse_link_jetwith_param_partials,
mixture_inverse_link_jetwith_rho_partials_into, sas_inverse_link_jetwith_param_partials,
};
use crate::probability::{
beta_moment_matched_interval, gamma_moment_matched_interval,
negative_binomial_moment_matched_interval, normal_cdf, normal_pdf,
poisson_moment_matched_interval, standard_normal_quantile, tweedie_moment_matched_interval,
};
use crate::quadrature::QuadratureContext;
use crate::types::{InverseLink, LikelihoodScaleMetadata, LikelihoodSpec, ResponseFamily};
use ndarray::{Array1, Array2, ArrayView1, ArrayView2};
use rayon::iter::{IntoParallelIterator, ParallelIterator};
// Thread-local `QuadratureContext`: every thread that touches this static gets
// its own instance, built with `QuadratureContext::new()` on first access.
thread_local! {
static PREDICT_QUADRATURE_CONTEXT: QuadratureContext = QuadratureContext::new();
}
/// Standard errors implied by a covariance matrix.
///
/// Takes the diagonal of `cov`, clamps each entry at zero from below, and
/// returns the element-wise square root.
pub fn se_from_covariance(cov: &Array2<f64>) -> Array1<f64> {
    let diagonal = cov.diag();
    Array1::from_iter(diagonal.iter().map(|&variance| f64::sqrt(variance.max(0.0))))
}
/// Maps a vector of linear predictors through the inverse link that
/// `strategy_for_spec` resolves for `family`.
///
/// # Errors
/// Propagates whatever error the strategy's `inverse_link_array` reports.
fn apply_family_inverse_link(
    eta: &Array1<f64>,
    family: &LikelihoodSpec,
) -> Result<Array1<f64>, EstimationError> {
    let strategy = strategy_for_spec(family);
    strategy.inverse_link_array(eta.view())
}
/// Builds the likelihood spec used for prediction.
///
/// When `link_kind` is supplied, the result keeps `family`'s response and
/// swaps in a clone of that link; otherwise `family` is returned untouched.
fn spec_from_family_link(
    family: LikelihoodSpec,
    link_kind: Option<&InverseLink>,
) -> LikelihoodSpec {
    if let Some(link) = link_kind {
        LikelihoodSpec::new(family.response, link.clone())
    } else {
        family
    }
}
/// Thin adapter around `rowwise_local_covariances_parallel` that converts its
/// string error into `EstimationError::InvalidInput`.
///
/// `build_chunk` receives a row range and must return the per-row gradient
/// matrices for that range; it is forwarded unchanged.
fn local_covariances_with_backend<F>(
    backend: &PredictionCovarianceBackend<'_>,
    n_rows: usize,
    local_dim: usize,
    build_chunk: F,
) -> Result<Vec<Vec<Array1<f64>>>, EstimationError>
where
    F: Fn(std::ops::Range<usize>) -> Result<Vec<Array2<f64>>, String> + Sync,
{
    match rowwise_local_covariances_parallel(backend, n_rows, local_dim, build_chunk) {
        Ok(blocks) => Ok(blocks),
        Err(message) => Err(EstimationError::InvalidInput(message)),
    }
}
/// Returns the fit's penalized Hessian only if it can actually be used.
///
/// The Hessian is rejected (with a logged warning) when its shape is not
/// `expected_dim x expected_dim`, or when no entry has a strictly positive
/// absolute value (treated as a zero placeholder). Returns `None` silently
/// when the fit carries no Hessian at all.
fn usable_penalized_hessian<'a>(
    fit: &'a UnifiedFitResult,
    expected_dim: usize,
    label: &str,
) -> Option<&'a Array2<f64>> {
    let hessian = fit.penalized_hessian()?;
    let (rows, cols) = (hessian.nrows(), hessian.ncols());
    let shape_ok = rows == expected_dim && cols == expected_dim;
    if !shape_ok {
        log::warn!(
            "{label}: ignoring penalized Hessian with shape {}x{}; expected {}x{}",
            rows,
            cols,
            expected_dim,
            expected_dim
        );
        return None;
    }
    // `!(|v| > 0)` rather than `v == 0` so NaN entries also count as "no signal".
    let is_placeholder = hessian.iter().all(|entry| !(entry.abs() > 0.0));
    if is_placeholder {
        log::warn!("{label}: ignoring zero penalized Hessian placeholder");
        return None;
    }
    Some(hessian)
}
/// Picks a backend for the conditional coefficient covariance.
///
/// Preference order:
/// 1. the fit's dense `beta_covariance`, when its shape matches `expected_dim`;
/// 2. a factorized backend built from the usable penalized Hessian, scaled by
///    `coefficient_covariance_scale`.
///
/// Shape mismatches and factorization failures are logged and skipped; `None`
/// means neither source was usable.
fn conditional_prediction_backend<'a>(
    fit: &'a UnifiedFitResult,
    expected_dim: usize,
    label: &str,
) -> Option<PredictionCovarianceBackend<'a>> {
    if let Some(covariance) = fit.beta_covariance() {
        let shape_matches =
            covariance.nrows() == expected_dim && covariance.ncols() == expected_dim;
        if shape_matches {
            return Some(PredictionCovarianceBackend::from_dense(covariance.view()));
        }
        log::warn!(
            "{label}: ignoring conditional covariance with shape {}x{}; expected {}x{}",
            covariance.nrows(),
            covariance.ncols(),
            expected_dim,
            expected_dim
        );
    }

    let hessian = usable_penalized_hessian(fit, expected_dim, label)?;
    let scale = fit.coefficient_covariance_scale();
    let factorized = PredictionCovarianceBackend::from_factorized_hessian_scaled(
        SymmetricMatrix::Dense(hessian.clone()),
        scale,
    );
    match factorized {
        Ok(backend) => Some(backend),
        Err(err) => {
            log::warn!("{label}: failed to build factorized prediction precision backend: {err}");
            None
        }
    }
}
/// Resolves the covariance backend for a requested inference mode.
///
/// The boolean in the result is `true` when the smoothing-corrected covariance
/// is used and `false` when the conditional backend is used.
///
/// * `Conditional` — conditional backend, or an error if none is usable.
/// * `ConditionalPlusSmoothingPreferred` — corrected covariance when the fit
///   has one (a wrong shape is an error), otherwise the conditional path.
/// * `ConditionalPlusSmoothingRequired` — corrected covariance or an error.
fn selected_uncertainty_backend<'a>(
    fit: &'a UnifiedFitResult,
    expected_dim: usize,
    requested_mode: InferenceCovarianceMode,
    label: &str,
) -> Result<(PredictionCovarianceBackend<'a>, bool), EstimationError> {
    // Only the smoothing-aware modes ever look at the corrected covariance.
    let corrected = match requested_mode {
        InferenceCovarianceMode::Conditional => None,
        InferenceCovarianceMode::ConditionalPlusSmoothingPreferred
        | InferenceCovarianceMode::ConditionalPlusSmoothingRequired => {
            fit.beta_covariance_corrected()
        }
    };

    match (requested_mode, corrected) {
        (_, Some(covariance)) => {
            let (rows, cols) = (covariance.nrows(), covariance.ncols());
            if rows != expected_dim || cols != expected_dim {
                return Err(EstimationError::InvalidInput(format!(
                    "{label}: corrected covariance dimension mismatch: expected {}x{}, got {}x{}",
                    expected_dim, expected_dim, rows, cols
                )));
            }
            Ok((
                PredictionCovarianceBackend::from_dense(covariance.view()),
                true,
            ))
        }
        (InferenceCovarianceMode::ConditionalPlusSmoothingRequired, None) => {
            Err(EstimationError::InvalidInput(
                "fit result does not contain smoothing-corrected covariance".to_string(),
            ))
        }
        (
            InferenceCovarianceMode::Conditional
            | InferenceCovarianceMode::ConditionalPlusSmoothingPreferred,
            None,
        ) => match conditional_prediction_backend(fit, expected_dim, label) {
            Some(backend) => Ok((backend, false)),
            None => Err(EstimationError::InvalidInput(
                "fit result does not contain conditional covariance or a usable penalized Hessian"
                    .to_string(),
            )),
        },
    }
}
/// Anything that can supply a coefficient covariance backend (plus optional
/// observation-scale metadata) for prediction uncertainty.
pub trait UncertaintyCovarianceSource {
/// Returns a covariance backend of dimension `expected_dim` for `mode`,
/// together with a flag. In the implementations in this file the flag is
/// `true` only when a smoothing-corrected covariance backs the result.
/// `label` is used as a prefix in diagnostics.
fn select_uncertainty_backend(
&self,
expected_dim: usize,
mode: InferenceCovarianceMode,
label: &str,
) -> Result<(PredictionCovarianceBackend<'_>, bool), EstimationError>;
/// Fitted link state for `family`, if this source can provide one.
fn resolved_fitted_link_state(&self, family: &LikelihoodSpec) -> Option<FittedLinkState>;
/// Optional coefficient-scale bias correction; defaults to `None`.
fn resolved_bias_correction_beta(&self) -> Option<ArrayView1<'_, f64>> {
None
}
/// Observation standard deviation; defaults to `0.0`.
fn observation_standard_deviation(&self) -> f64 {
0.0
}
/// Observation dispersion `phi`, if known; defaults to `None`.
fn observation_phi(&self) -> Option<f64> {
None
}
/// Observation `theta` (read for the negative-binomial variance in this
/// file), if known; defaults to `None`.
fn observation_theta(&self) -> Option<f64> {
None
}
}
// A full fit result as covariance source: every method forwards to data or
// methods already on `UnifiedFitResult`.
impl UncertaintyCovarianceSource for UnifiedFitResult {
// Delegates to `selected_uncertainty_backend` with `self` as the fit.
fn select_uncertainty_backend(
&self,
expected_dim: usize,
mode: InferenceCovarianceMode,
label: &str,
) -> Result<(PredictionCovarianceBackend<'_>, bool), EstimationError> {
selected_uncertainty_backend(self, expected_dim, mode, label)
}
// A failure from `fitted_link_state` is discarded and reported as `None`.
fn resolved_fitted_link_state(&self, family: &LikelihoodSpec) -> Option<FittedLinkState> {
UnifiedFitResult::fitted_link_state(self, family).ok()
}
// Borrows the fit's bias-correction vector as a view when present.
fn resolved_bias_correction_beta(&self) -> Option<ArrayView1<'_, f64>> {
UnifiedFitResult::bias_correction_beta(self).map(|b| b.view())
}
// Reads the `standard_deviation` field of the fit directly.
fn observation_standard_deviation(&self) -> f64 {
self.standard_deviation
}
// Returned as reported by `likelihood_scale`; no positivity/finite filtering
// is applied here (unlike `ObservationScaleHints::from_likelihood_scale`).
fn observation_phi(&self) -> Option<f64> {
self.likelihood_scale.fixed_phi()
}
// Returned as reported by `likelihood_scale`, unfiltered.
fn observation_theta(&self) -> Option<f64> {
self.likelihood_scale.negbin_theta()
}
}
/// Optional observation-scale parameters carried alongside a raw covariance.
///
/// Both values are stored only when strictly positive and finite (see the
/// constructors); otherwise the corresponding hint is `None`.
#[derive(Clone, Copy, Debug, Default)]
pub struct ObservationScaleHints {
    observation_phi: Option<f64>,
    observation_theta: Option<f64>,
}

impl ObservationScaleHints {
    /// No hints at all.
    pub const fn none() -> Self {
        Self {
            observation_phi: None,
            observation_theta: None,
        }
    }

    /// Extracts `fixed_phi` and `negbin_theta` from likelihood-scale metadata,
    /// dropping values that are not positive and finite.
    pub fn from_likelihood_scale(scale: LikelihoodScaleMetadata) -> Self {
        let observation_phi = positive_finite(scale.fixed_phi());
        let observation_theta = positive_finite(scale.negbin_theta());
        Self {
            observation_phi,
            observation_theta,
        }
    }

    /// Same as [`Self::from_likelihood_scale`] using a clone of the fit's
    /// `likelihood_scale`.
    pub fn from_fit(fit: &UnifiedFitResult) -> Self {
        let scale = fit.likelihood_scale.clone();
        Self::from_likelihood_scale(scale)
    }

    /// Hints carrying only `phi` (kept only if positive and finite).
    pub fn with_phi(phi: f64) -> Self {
        Self {
            observation_phi: positive_finite(Some(phi)),
            ..Self::none()
        }
    }

    /// Hints carrying only `theta` (kept only if positive and finite).
    pub fn with_theta(theta: f64) -> Self {
        Self {
            observation_theta: positive_finite(Some(theta)),
            ..Self::none()
        }
    }
}
/// Keeps `value` only when it holds a finite number strictly greater than
/// zero; NaN, infinities, zero and negatives all become `None`.
fn positive_finite(value: Option<f64>) -> Option<f64> {
    match value {
        Some(candidate) if candidate.is_finite() && candidate > 0.0 => Some(candidate),
        _ => None,
    }
}
/// A borrowed coefficient covariance matrix paired with observation-scale
/// hints, usable as an [`UncertaintyCovarianceSource`].
pub struct PredictionCovarianceWithScale<'a> {
    covariance: ArrayView2<'a, f64>,
    scale: ObservationScaleHints,
}

impl<'a> PredictionCovarianceWithScale<'a> {
    /// Pairs `covariance` with explicitly supplied `scale` hints.
    pub fn new(covariance: ArrayView2<'a, f64>, scale: ObservationScaleHints) -> Self {
        PredictionCovarianceWithScale { covariance, scale }
    }

    /// Pairs `covariance` with the hints derived from `fit`.
    pub fn from_fit(covariance: ArrayView2<'a, f64>, fit: &UnifiedFitResult) -> Self {
        let scale = ObservationScaleHints::from_fit(fit);
        Self::new(covariance, scale)
    }
}
/// Raw covariance (with scale hints) as an uncertainty source.
impl UncertaintyCovarianceSource for PredictionCovarianceWithScale<'_> {
    /// Validates the shape first, then serves the stored covariance for the
    /// `Conditional` and `...Preferred` modes (flag `false`). The `...Required`
    /// mode is rejected because a raw covariance has no corrected variant.
    fn select_uncertainty_backend(
        &self,
        expected_dim: usize,
        mode: InferenceCovarianceMode,
        label: &str,
    ) -> Result<(PredictionCovarianceBackend<'_>, bool), EstimationError> {
        let (rows, cols) = (self.covariance.nrows(), self.covariance.ncols());
        if rows != expected_dim || cols != expected_dim {
            return Err(EstimationError::InvalidInput(format!(
                "{label}: covariance dimension mismatch: expected {expected_dim}x{expected_dim}, got {}x{}",
                rows, cols
            )));
        }
        match mode {
            InferenceCovarianceMode::ConditionalPlusSmoothingRequired => {
                Err(EstimationError::InvalidInput(format!(
                    "{label}: raw covariance source cannot provide smoothing-corrected covariance"
                )))
            }
            InferenceCovarianceMode::Conditional
            | InferenceCovarianceMode::ConditionalPlusSmoothingPreferred => {
                let backend = PredictionCovarianceBackend::from_dense(self.covariance);
                Ok((backend, false))
            }
        }
    }

    /// A raw covariance never carries fitted link state. The match is kept
    /// exhaustive so a new `InverseLink` variant forces a decision here.
    fn resolved_fitted_link_state(&self, family: &LikelihoodSpec) -> Option<FittedLinkState> {
        match &family.link {
            InverseLink::Standard(_)
            | InverseLink::LatentCLogLog(_)
            | InverseLink::Sas(_)
            | InverseLink::BetaLogistic(_)
            | InverseLink::Mixture(_) => None,
        }
    }

    /// The stored `phi` hint.
    fn observation_phi(&self) -> Option<f64> {
        self.scale.observation_phi
    }

    /// The stored `theta` hint.
    fn observation_theta(&self) -> Option<f64> {
        self.scale.observation_theta
    }
}
/// A bare dense covariance matrix as an uncertainty source (no scale hints).
impl UncertaintyCovarianceSource for Array2<f64> {
    /// Validates the shape first, then serves the matrix for the `Conditional`
    /// and `...Preferred` modes (flag `false`); `...Required` is an error.
    fn select_uncertainty_backend(
        &self,
        expected_dim: usize,
        mode: InferenceCovarianceMode,
        label: &str,
    ) -> Result<(PredictionCovarianceBackend<'_>, bool), EstimationError> {
        let (rows, cols) = (self.nrows(), self.ncols());
        if rows != expected_dim || cols != expected_dim {
            return Err(EstimationError::InvalidInput(format!(
                "{label}: covariance dimension mismatch: expected {expected_dim}x{expected_dim}, got {}x{}",
                rows, cols
            )));
        }
        match mode {
            InferenceCovarianceMode::ConditionalPlusSmoothingRequired => {
                Err(EstimationError::InvalidInput(format!(
                    "{label}: raw covariance source cannot provide smoothing-corrected covariance"
                )))
            }
            InferenceCovarianceMode::Conditional
            | InferenceCovarianceMode::ConditionalPlusSmoothingPreferred => {
                let backend = PredictionCovarianceBackend::from_dense(self.view());
                Ok((backend, false))
            }
        }
    }

    /// A raw covariance never carries fitted link state. The match is kept
    /// exhaustive so a new `InverseLink` variant forces a decision here.
    fn resolved_fitted_link_state(&self, family: &LikelihoodSpec) -> Option<FittedLinkState> {
        match &family.link {
            InverseLink::Standard(_)
            | InverseLink::LatentCLogLog(_)
            | InverseLink::Sas(_)
            | InverseLink::BetaLogistic(_)
            | InverseLink::Mixture(_) => None,
        }
    }
}
/// Quadratic form `gradᵀ · cov · grad` for a plain gradient slice.
///
/// Delegates to `quadratic_form_indexed`, which validates that `cov` is square
/// with side `grad.len()`.
#[inline]
fn quadratic_form(cov: &Array2<f64>, grad: &[f64]) -> Result<f64, EstimationError> {
    let dimension = grad.len();
    quadratic_form_indexed(cov, dimension, "gradient", |index| grad[index])
}
/// Quadratic form over `cov` where the gradient entries are the `mu` fields of
/// the supplied jets.
///
/// Delegates to `quadratic_form_indexed`, which validates that `cov` is square
/// with side `partials.len()`.
#[inline]
fn quadratic_form_from_jetmu(
    cov: &Array2<f64>,
    partials: &[InverseLinkJet],
) -> Result<f64, EstimationError> {
    let dimension = partials.len();
    quadratic_form_indexed(cov, dimension, "mixture gradient", |index| partials[index].mu)
}
/// Evaluates the quadratic form `gᵀ · cov · g` with the gradient supplied as
/// an index function.
///
/// Only the diagonal and the strict upper triangle of `cov` are read; the
/// off-diagonal contribution is doubled, so `cov` is treated as symmetric.
/// The result is clamped at zero from below.
///
/// Rows are walked through ndarray's stride-aware iterator rather than
/// `as_slice()`, so the function works for any memory layout of `cov`
/// (including column-major arrays) instead of panicking on non-contiguous rows.
///
/// # Arguments
/// * `cov` — square matrix of side `m`.
/// * `m` — gradient length.
/// * `label` — name of the gradient used in the mismatch error message.
/// * `g` — returns the `i`-th gradient entry for `i < m`.
///
/// # Errors
/// `EstimationError::InvalidInput` when `cov` is not `m x m`.
#[inline]
fn quadratic_form_indexed(
    cov: &Array2<f64>,
    m: usize,
    label: &str,
    g: impl Fn(usize) -> f64,
) -> Result<f64, EstimationError> {
    if cov.nrows() != m || cov.ncols() != m {
        return Err(EstimationError::InvalidInput(format!(
            "covariance/{label} dimension mismatch: covariance is {}x{}, {label} length is {}",
            cov.nrows(),
            cov.ncols(),
            m
        )));
    }
    let mut diag_acc = 0.0_f64;
    let mut off_acc = 0.0_f64;
    for i in 0..m {
        let row = cov.row(i);
        let gi = g(i);
        diag_acc += gi * gi * row[i];
        // Strict upper triangle of row `i`, accumulated in column order.
        let mut row_off = 0.0_f64;
        for (j, &cov_ij) in row.iter().enumerate().skip(i + 1) {
            row_off += g(j) * cov_ij;
        }
        off_acc += gi * row_off;
    }
    Ok((diag_acc + 2.0 * off_acc).max(0.0))
}
/// Per-row variance of the linear predictor `x · beta` under `backend`.
///
/// Each chunk of design rows is handed to the backend as a single gradient
/// block; the resulting variances are clamped at zero from below.
fn linear_predictorvariance_from_backend(
    x: &DesignMatrix,
    backend: &PredictionCovarianceBackend<'_>,
) -> Result<Array1<f64>, EstimationError> {
    let blocks = local_covariances_with_backend(backend, x.nrows(), 1, |rows| {
        let chunk = design_row_chunk(x, rows)?;
        Ok(vec![chunk])
    })?;
    let variances = &blocks[0][0];
    Ok(variances.mapv(|value| value.max(0.0)))
}
// Variances at or below this are treated as degenerate (no integration over
// that coordinate) in `projected_bivariate_posterior_mean_result`.
const POSTERIOR_MEAN_VARIANCE_TOL: f64 = 1e-10;
// Absolute cross-covariance at or below this is treated as zero when deciding
// whether a bivariate integral can be reduced to one dimension.
const POSTERIOR_MEAN_CROSS_TOL: f64 = 1e-10;
// NOTE(review): not referenced in this part of the file; presumably a clamp on
// standardized arguments in the survival predictor — confirm at the use site.
const SURVIVAL_STANDARDIZED_ARG_CLAMP: f64 = 1e6;
/// Finds a covariance backend for posterior-mean prediction, logging why each
/// rejected candidate was skipped.
///
/// Candidates are tried in order: the fit's `beta_covariance`, the caller's
/// `fallback` covariance, then `conditional_prediction_backend`. Returns
/// `None` (after a warning) when nothing is usable.
fn posterior_mean_backend_or_warn<'a>(
    fit: &'a UnifiedFitResult,
    fallback: Option<&'a Array2<f64>>,
    expected_dim: usize,
    label: &str,
) -> Option<PredictionCovarianceBackend<'a>> {
    let dense_candidates = [
        ("fit result", fit.beta_covariance()),
        ("predictor state", fallback),
    ];
    for (source, candidate) in dense_candidates {
        if let Some(covariance) = candidate {
            let (rows, cols) = (covariance.nrows(), covariance.ncols());
            if rows == expected_dim && cols == expected_dim {
                return Some(PredictionCovarianceBackend::from_dense(covariance.view()));
            }
            log::warn!(
                "{label}: ignoring {source} covariance with shape {}x{}; expected {}x{}",
                rows,
                cols,
                expected_dim,
                expected_dim
            );
        }
    }

    match conditional_prediction_backend(fit, expected_dim, label) {
        Some(backend) => Some(backend),
        None => {
            log::warn!(
                "{label}: covariance/precision unavailable; falling back to plug-in point prediction"
            );
            None
        }
    }
}
/// Like `posterior_mean_backend_or_warn`, but a missing backend is an error.
///
/// # Errors
/// `EstimationError::InvalidInput` when no covariance or penalized Hessian is
/// usable.
fn require_posterior_mean_backend<'a>(
    fit: &'a UnifiedFitResult,
    fallback: Option<&'a Array2<f64>>,
    expected_dim: usize,
    label: &str,
) -> Result<PredictionCovarianceBackend<'a>, EstimationError> {
    match posterior_mean_backend_or_warn(fit, fallback, expected_dim, label) {
        Some(backend) => Ok(backend),
        None => Err(EstimationError::InvalidInput(format!(
            "{label} requires covariance or penalized Hessian for posterior-mean prediction"
        ))),
    }
}
/// Projects a joint coefficient covariance onto two linear predictors.
///
/// The parameter vector is laid out as `[first block | second block]` with
/// sizes `p_first` and `p_second`. For every row the function returns
/// `(var_first, var_second, cov_first_second)`; both variances are clamped at
/// zero from below, the cross term is returned as-is.
///
/// # Errors
/// `EstimationError::InvalidInput` when the backend dimension is not
/// `p_first + p_second` or a design has the wrong number of columns.
fn project_two_block_linear_predictor_covariance(
    design_first: &DesignMatrix,
    design_second: &DesignMatrix,
    backend: &PredictionCovarianceBackend<'_>,
    p_first: usize,
    p_second: usize,
    label: &str,
) -> Result<(Array1<f64>, Array1<f64>, Array1<f64>), EstimationError> {
    let p_total = p_first + p_second;
    let backend_dim = backend.nrows();
    if backend_dim != p_total {
        return Err(EstimationError::InvalidInput(format!(
            "{label} covariance dimension mismatch: expected parameter dimension {}, got {}",
            p_total, backend_dim
        )));
    }
    let (cols_first, cols_second) = (design_first.ncols(), design_second.ncols());
    if cols_first != p_first || cols_second != p_second {
        return Err(EstimationError::InvalidInput(format!(
            "{label} design dimension mismatch: threshold/location design has {} columns (expected {}), scale design has {} columns (expected {})",
            cols_first, p_first, cols_second, p_second
        )));
    }

    let blocks = local_covariances_with_backend(backend, design_first.nrows(), 2, |rows| {
        let x_first = design_row_chunk(design_first, rows.clone())?;
        let x_second = design_row_chunk(design_second, rows.clone())?;
        let chunk_len = rows.end - rows.start;
        // Embed each design into the full parameter space; the other block's
        // columns stay zero.
        let mut grad_first = Array2::<f64>::zeros((chunk_len, p_total));
        grad_first
            .slice_mut(ndarray::s![.., 0..p_first])
            .assign(&x_first);
        let mut grad_second = Array2::<f64>::zeros((chunk_len, p_total));
        grad_second
            .slice_mut(ndarray::s![.., p_first..p_total])
            .assign(&x_second);
        Ok(vec![grad_first, grad_second])
    })?;

    let var_first = blocks[0][0].mapv(|value| value.max(0.0));
    let var_second = blocks[1][1].mapv(|value| value.max(0.0));
    let cross = blocks[0][1].clone();
    Ok((var_first, var_second, cross))
}
/// Per-row standard error of a scalar predictor whose gradient rows are
/// produced chunk-by-chunk by `build_chunk`.
///
/// Variances from the backend are clamped at zero before the square root.
fn linear_predictor_se_from_backend<F>(
    backend: &PredictionCovarianceBackend<'_>,
    n_rows: usize,
    build_chunk: F,
) -> Result<Array1<f64>, EstimationError>
where
    F: Fn(std::ops::Range<usize>) -> Result<Vec<Array2<f64>>, String> + Sync,
{
    let blocks = local_covariances_with_backend(backend, n_rows, 1, build_chunk)?;
    let variances = &blocks[0][0];
    Ok(variances.mapv(|value| value.max(0.0).sqrt()))
}
/// Column layout of the gradient built in `link_wiggle_eta_se_from_backend`.
#[derive(Clone, Copy)]
struct LinkWiggleGradientLayout {
// Number of leading columns filled from the main design (scaled per row).
p_main: usize,
// Total gradient width; must equal the backend's parameter dimension.
p_total: usize,
// First column of the span that receives the wiggle design.
wiggle_col_start: usize,
}
/// Standard error of a link-wiggle-adjusted predictor.
///
/// For each row the gradient is assembled as:
/// * columns `0..layout.p_main`: the main design row scaled by
///   `runtime.derivative_q0` at that row's base predictor `q0_base`;
/// * columns `wiggle_col_start..wiggle_col_start + runtime.beta.len()`: the
///   wiggle design from `runtime.design`.
///
/// # Errors
/// `EstimationError::InvalidInput` when the backend dimension differs from
/// `layout.p_total`, or when a chunk builder fails.
fn link_wiggle_eta_se_from_backend(
    backend: &PredictionCovarianceBackend<'_>,
    n_rows: usize,
    design: &DesignMatrix,
    q0_base: &Array1<f64>,
    runtime: &SavedLinkWiggleRuntime,
    layout: LinkWiggleGradientLayout,
    dimension_label: &str,
) -> Result<Array1<f64>, EstimationError> {
    let backend_dim = backend.nrows();
    if backend_dim != layout.p_total {
        return Err(EstimationError::InvalidInput(format!(
            "{dimension_label}: expected parameter dimension {}, got {}",
            layout.p_total, backend_dim
        )));
    }
    let wiggle_width = runtime.beta.len();
    linear_predictor_se_from_backend(backend, n_rows, |rows| {
        let q0_chunk = q0_base.slice(ndarray::s![rows.clone()]).to_owned();
        let x_main = design_row_chunk(design, rows.clone())?;
        let wiggle_design = runtime.design(&q0_chunk)?;
        let dq_dq0 = runtime.derivative_q0(&q0_chunk)?;
        let chunk_len = q0_chunk.len();
        let mut gradient = Array2::<f64>::zeros((chunk_len, layout.p_total));
        for row in 0..chunk_len {
            let chain_factor = dq_dq0[row];
            for col in 0..layout.p_main {
                gradient[[row, col]] = chain_factor * x_main[[row, col]];
            }
        }
        let wiggle_cols = layout.wiggle_col_start..layout.wiggle_col_start + wiggle_width;
        gradient
            .slice_mut(ndarray::s![.., wiggle_cols])
            .assign(&wiggle_design);
        Ok(vec![gradient])
    })
}
/// Standard errors for a design that occupies a contiguous slice of a larger
/// parameter vector.
///
/// The design's columns are placed after `leading_zeros` zero columns and
/// followed by `trailing_zeros` zero columns, so the padded width matches the
/// backend's parameter dimension.
///
/// # Errors
/// `EstimationError::InvalidInput` when the padded width differs from
/// `backend.nrows()`.
fn padded_design_standard_errors_from_backend(
    design: &DesignMatrix,
    backend: &PredictionCovarianceBackend<'_>,
    leading_zeros: usize,
    trailing_zeros: usize,
    label: &str,
) -> Result<Array1<f64>, EstimationError> {
    let p_design = design.ncols();
    let p_total = leading_zeros + p_design + trailing_zeros;
    if backend.nrows() != p_total {
        return Err(EstimationError::InvalidInput(format!(
            "{label} covariance dimension mismatch: expected parameter dimension {p_total}, got {}",
            backend.nrows()
        )));
    }
    let n_rows = design.nrows();
    linear_predictor_se_from_backend(backend, n_rows, |rows| {
        let chunk = design_row_chunk(design, rows)?;
        let mut padded = Array2::<f64>::zeros((chunk.nrows(), p_total));
        let design_cols = leading_zeros..leading_zeros + p_design;
        padded
            .slice_mut(ndarray::s![.., design_cols])
            .assign(&chunk);
        Ok(vec![padded])
    })
}
/// Expectation of `integrand` under a bivariate normal with mean `mu` and
/// covariance `cov`, with cheap shortcuts for degenerate cases.
///
/// * both variances negligible — evaluate at the mean;
/// * one variance negligible and the cross term negligible — 1-D adaptive
///   quadrature (21 as the node argument) over the other coordinate;
/// * otherwise — full 2-D adaptive quadrature with the caller's `cov`.
///
/// "Negligible" is `<= POSTERIOR_MEAN_VARIANCE_TOL` for variances (after
/// clamping at zero) and `<= POSTERIOR_MEAN_CROSS_TOL` for `|cov[0][1]|`.
fn projected_bivariate_posterior_mean_result<F>(
    quadctx: &crate::quadrature::QuadratureContext,
    mu: [f64; 2],
    cov: [[f64; 2]; 2],
    integrand: F,
) -> Result<f64, EstimationError>
where
    F: Fn(f64, f64) -> Result<f64, EstimationError>,
{
    let var0 = cov[0][0].max(0.0);
    let var1 = cov[1][1].max(0.0);
    let cross = cov[0][1];

    let first_fixed = var0 <= POSTERIOR_MEAN_VARIANCE_TOL;
    let second_fixed = var1 <= POSTERIOR_MEAN_VARIANCE_TOL;
    let uncorrelated = cross.abs() <= POSTERIOR_MEAN_CROSS_TOL;

    match (first_fixed, second_fixed, uncorrelated) {
        (true, true, _) => integrand(mu[0], mu[1]),
        (true, false, true) => {
            crate::quadrature::normal_expectation_nd_adaptive_result::<1, _, _, EstimationError>(
                quadctx,
                [mu[1]],
                [[var1]],
                21,
                |x| integrand(mu[0], x[0]),
            )
        }
        (false, true, true) => {
            crate::quadrature::normal_expectation_nd_adaptive_result::<1, _, _, EstimationError>(
                quadctx,
                [mu[0]],
                [[var0]],
                21,
                |x| integrand(x[0], mu[1]),
            )
        }
        _ => crate::quadrature::normal_expectation_2d_adaptive_result(quadctx, mu, cov, integrand),
    }
}
/// Plug-in point prediction.
pub struct PredictResult {
/// Linear predictor per row (`predict_gam` fills this with `X·beta + offset`).
pub eta: Array1<f64>,
/// Response-scale mean per row (`predict_gam` applies the family's inverse
/// link to `eta`).
pub mean: Array1<f64>,
}
/// Inputs for a prediction call. All populated fields are indexed by the same
/// rows: `slice_predict_input` slices each of them with one row range.
pub struct PredictInput {
/// Primary design matrix.
pub design: DesignMatrix,
/// Offset, one entry per design row.
pub offset: Array1<f64>,
/// Optional second design — presumably for a noise/scale predictor; confirm
/// against the model implementations.
pub design_noise: Option<DesignMatrix>,
/// Optional offset accompanying `design_noise`.
pub offset_noise: Option<Array1<f64>>,
/// Optional per-row scalar whose meaning is model-specific (not visible here).
pub auxiliary_scalar: Option<Array1<f64>>,
/// Optional per-row matrix whose meaning is model-specific (not visible here).
pub auxiliary_matrix: Option<Array2<f64>>,
}
fn slice_predict_input(
input: &PredictInput,
rows: std::ops::Range<usize>,
) -> Result<PredictInput, EstimationError> {
Ok(PredictInput {
design: DesignMatrix::Dense(crate::matrix::DenseDesignMatrix::from(
design_row_chunk(&input.design, rows.clone()).map_err(EstimationError::InvalidInput)?,
)),
offset: input.offset.slice(ndarray::s![rows.clone()]).to_owned(),
design_noise: input
.design_noise
.as_ref()
.map(|design| {
design_row_chunk(design, rows.clone())
.map(|d| DesignMatrix::Dense(crate::matrix::DenseDesignMatrix::from(d)))
.map_err(EstimationError::InvalidInput)
})
.transpose()?,
offset_noise: input
.offset_noise
.as_ref()
.map(|offset| offset.slice(ndarray::s![rows.clone()]).to_owned()),
auxiliary_scalar: input
.auxiliary_scalar
.as_ref()
.map(|values| values.slice(ndarray::s![rows.clone()]).to_owned()),
auxiliary_matrix: input
.auxiliary_matrix
.as_ref()
.map(|values| values.slice(ndarray::s![rows, ..]).to_owned()),
})
}
/// Point prediction with optional standard errors.
pub struct PredictionWithSE {
/// Linear predictor per row.
pub eta: Array1<f64>,
/// Response-scale mean per row.
pub mean: Array1<f64>,
/// Standard error of `eta`, when available.
pub eta_se: Option<Array1<f64>>,
/// Standard error of `mean`, when available.
pub mean_se: Option<Array1<f64>>,
}
/// Common prediction interface implemented by the model-specific predictors.
pub trait PredictableModel {
/// Plug-in prediction: linear predictor and response-scale mean.
fn predict_plugin_response(
&self,
input: &PredictInput,
) -> Result<PredictResult, EstimationError>;
/// Linear predictor only. The default runs `predict_plugin_response` and
/// keeps its `eta`.
fn predict_linear_predictor(
&self,
input: &PredictInput,
) -> Result<Array1<f64>, EstimationError> {
self.predict_plugin_response(input).map(|pred| pred.eta)
}
/// Prediction together with (optional) standard errors.
fn predict_with_uncertainty(
&self,
input: &PredictInput,
) -> Result<PredictionWithSE, EstimationError>;
/// Per-row noise scale, or `None` when the model does not provide one.
fn predict_noise_scale(
&self,
input: &PredictInput,
) -> Result<Option<Array1<f64>>, EstimationError>;
/// Per-row dispersion scale; the default reports `None`.
fn predict_dispersion_scale(
&self,
_input: &PredictInput,
) -> Result<Option<Array1<f64>>, EstimationError> {
Ok(None)
}
/// Full uncertainty summary for `input`, using `fit` and `options`.
fn predict_full_uncertainty(
&self,
input: &PredictInput,
fit: &UnifiedFitResult,
options: &PredictUncertaintyOptions,
) -> Result<PredictUncertaintyResult, EstimationError>;
/// Posterior-mean prediction for `input`, using `fit` and `options`.
fn predict_posterior_mean(
&self,
input: &PredictInput,
fit: &UnifiedFitResult,
options: &PosteriorMeanOptions,
) -> Result<PredictPosteriorMeanResult, EstimationError>;
/// Number of coefficient blocks in the model.
fn n_blocks(&self) -> usize;
/// Role of each coefficient block.
fn block_roles(&self) -> Vec<BlockRole>;
}
mod bernoulli_marginal_slope;
mod binomial_location_scale;
mod dispersion_location_scale;
mod gaussian_location_scale;
mod standard;
mod survival;
mod transformation_normal;
pub use bernoulli_marginal_slope::*;
pub use binomial_location_scale::*;
pub use dispersion_location_scale::*;
pub use gaussian_location_scale::*;
pub use standard::*;
pub use survival::*;
pub use transformation_normal::*;
/// Per-row standard error of the linear predictor `x · beta` under `backend`:
/// the square root of the (non-negative) projected variance.
fn eta_standard_errors_from_backend(
    x: &DesignMatrix,
    backend: &PredictionCovarianceBackend<'_>,
) -> Result<Array1<f64>, EstimationError> {
    let variances = linear_predictorvariance_from_backend(x, backend)?;
    let standard_errors = variances.mapv(|variance| variance.max(0.0).sqrt());
    Ok(standard_errors)
}
/// Evaluates the inverse link and its first derivative at every entry of
/// `eta`, in parallel.
///
/// Returns `(mean, d1)` where each element comes from the `mu` and `d1`
/// fields of `strategy.inverse_link_jet`.
///
/// # Errors
/// Propagates an error from any row's jet evaluation.
fn inverse_link_mean_and_d1(
    strategy: &(dyn FamilyStrategy + Sync),
    eta: ndarray::ArrayView1<'_, f64>,
) -> Result<(Array1<f64>, Array1<f64>), EstimationError> {
    use rayon::iter::{IntoParallelIterator, ParallelIterator};
    let jets = (0..eta.len())
        .into_par_iter()
        .map(|row| {
            let jet = strategy.inverse_link_jet(eta[row])?;
            Ok((jet.mu, jet.d1))
        })
        .collect::<Result<Vec<(f64, f64)>, EstimationError>>()?;
    let (mean, d1): (Vec<f64>, Vec<f64>) = jets.into_iter().unzip();
    Ok((Array1::from_vec(mean), Array1::from_vec(d1)))
}
/// Delta-method standard error of the mean: `|dmu/deta · se(eta)|` per row.
///
/// The output has the length of `dmu_deta`; `eta_se` is indexed with the same
/// positions and must be at least as long.
fn delta_method_mean_se_from_d1(dmu_deta: &Array1<f64>, eta_se: &Array1<f64>) -> Array1<f64> {
    let rows = 0..dmu_deta.len();
    Array1::from_iter(rows.map(|row| (dmu_deta[row] * eta_se[row]).abs()))
}
/// Result of a posterior-mean prediction. The bound fields are optional and
/// are left as `None` by `predict_gam_posterior_mean_from_backendwith_bc`.
pub struct PredictPosteriorMeanResult {
/// Linear predictor per row.
pub eta: Array1<f64>,
/// Standard error of `eta` per row.
pub eta_standard_error: Array1<f64>,
/// Posterior mean on the response scale per row.
pub mean: Array1<f64>,
/// Lower bound for the mean, when computed.
pub mean_lower: Option<Array1<f64>>,
/// Upper bound for the mean, when computed.
pub mean_upper: Option<Array1<f64>>,
/// Lower bound for a new observation, when computed.
pub observation_lower: Option<Array1<f64>>,
/// Upper bound for a new observation, when computed.
pub observation_upper: Option<Array1<f64>>,
}
/// Options for posterior-mean prediction.
#[derive(Clone, Copy, Debug)]
pub struct PosteriorMeanOptions {
    /// Confidence level for bounds; `None` requests point predictions only.
    pub confidence_level: Option<f64>,
    /// Which coefficient covariance to use.
    pub covariance_mode: InferenceCovarianceMode,
    /// Whether observation-level bounds should be produced as well.
    pub include_observation_interval: bool,
}

impl PosteriorMeanOptions {
    /// Point predictions only: no confidence level, smoothing-corrected
    /// covariance preferred, no observation interval.
    pub fn point_only() -> Self {
        Self {
            confidence_level: None,
            covariance_mode: InferenceCovarianceMode::ConditionalPlusSmoothingPreferred,
            include_observation_interval: false,
        }
    }

    /// Same as [`Self::point_only`] but with bounds at `level`.
    pub fn with_level(level: f64) -> Self {
        Self {
            confidence_level: Some(level),
            ..Self::point_only()
        }
    }
}
/// Adds mean bounds to an existing posterior-mean result.
///
/// The effective spec is `family` with its link replaced by `link_kind` when
/// one is given. Bounds are assembled from a symmetric eta interval at
/// `confidence_level`, mapped through that spec's inverse link and restricted
/// by the response family's bounds.
pub fn enrich_posterior_mean_bounds(
    result: &mut PredictPosteriorMeanResult,
    confidence_level: f64,
    family: crate::types::LikelihoodSpec,
    link_kind: Option<&InverseLink>,
) -> Result<(), EstimationError> {
    let spec = spec_from_family_link(family, link_kind);
    let bounds = ResponseBounds::for_family(&spec.response);
    let method = MeanBoundMethod::TransformEta {
        bounds,
        response_map: &|eta: &Array1<f64>| apply_family_inverse_link(eta, &spec),
    };
    assemble_posterior_mean_bounds(
        result,
        Some(confidence_level),
        EtaInterval::Symmetric,
        method,
    )
}
/// Which coefficient covariance prediction uncertainty should be based on.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum InferenceCovarianceMode {
/// Conditional covariance (or a usable penalized Hessian) only.
Conditional,
/// Smoothing-corrected covariance when the source has one, otherwise fall
/// back to the conditional covariance.
ConditionalPlusSmoothingPreferred,
/// Smoothing-corrected covariance; its absence is an error.
ConditionalPlusSmoothingRequired,
}
/// Per-axis lower and upper limits; the field names and the inflation helpers
/// in this file suggest these describe the training data's range — confirm
/// where the value is constructed.
#[derive(Clone, Debug)]
pub struct TrainingSupport {
/// Lower limit per predictor axis.
pub axis_min: Array1<f64>,
/// Upper limit per predictor axis.
pub axis_max: Array1<f64>,
}
/// Options for full-uncertainty prediction. See the `Default` impl below for
/// the default of every field. How each switch is consumed is not visible in
/// this part of the file; the per-field notes are limited to what the names
/// and defaults show.
#[derive(Clone)]
pub struct PredictUncertaintyOptions {
/// Confidence level for the reported intervals.
pub confidence_level: f64,
/// Which coefficient covariance to use.
pub covariance_mode: InferenceCovarianceMode,
/// How the mean interval is formed (`Delta` or `TransformEta`).
pub mean_interval_method: MeanIntervalMethod,
/// Whether observation-level bounds are requested. NOTE(review): spelled
/// without an underscore, unlike `PosteriorMeanOptions::include_observation_interval`.
pub includeobservation_interval: bool,
pub apply_bias_correction: bool,
pub edgeworth_one_sided: bool,
pub boundary_correction: bool,
pub ood_inflation: bool,
pub multi_point_joint: bool,
/// Optional predictor rows for the corrections above — presumably one row
/// per prediction row; confirm at the use site.
pub predictor_x_for_corrections: Option<Array2<f64>>,
pub training_support: Option<TrainingSupport>,
pub extrapolation_variance: Option<Array1<f64>>,
pub eta_skewness_for_corrections: Option<Array1<f64>>,
pub joint_query_count: Option<usize>,
pub boundary_alpha: f64,
pub boundary_band_fraction: f64,
pub ood_gamma: f64,
pub conformal_level: Option<f64>,
}
impl Default for PredictUncertaintyOptions {
// Defaults: 95% level, smoothing-corrected covariance preferred,
// transform-eta mean intervals, observation interval on; bias, Edgeworth
// and boundary corrections on; OOD inflation and joint multi-point
// adjustment off; no auxiliary inputs.
fn default() -> Self {
Self {
confidence_level: 0.95,
covariance_mode: InferenceCovarianceMode::ConditionalPlusSmoothingPreferred,
mean_interval_method: MeanIntervalMethod::TransformEta,
includeobservation_interval: true,
apply_bias_correction: true,
edgeworth_one_sided: true,
boundary_correction: true,
ood_inflation: false,
multi_point_joint: false,
predictor_x_for_corrections: None,
training_support: None,
extrapolation_variance: None,
eta_skewness_for_corrections: None,
joint_query_count: None,
boundary_alpha: 0.25,
boundary_band_fraction: 0.05,
ood_gamma: 1.0,
conformal_level: None,
}
}
}
/// Skewness-adjusted critical values for the lower and upper interval ends.
#[derive(Clone, Copy, Debug)]
pub(crate) struct EdgeworthZ {
    pub z_lower: f64,
    pub z_upper: f64,
}

/// Shifts a symmetric critical value `z` by the first-order skewness term
/// `(z² − 1) · κ₃ / 6`.
///
/// The term is subtracted for the lower multiplier and added for the upper
/// one; each multiplier is clamped at zero from below.
pub(crate) fn edgeworth_one_sided_quantile(z: f64, skew_kappa3: f64) -> EdgeworthZ {
    let correction = (z * z - 1.0) * skew_kappa3 / 6.0;
    let z_lower = f64::max(z - correction, 0.0);
    let z_upper = f64::max(z + correction, 0.0);
    EdgeworthZ { z_lower, z_upper }
}
/// Variance inflation factor for a query row that sits close to, on, or
/// beyond the edge of the per-axis range `[axis_min, axis_max]`.
///
/// Returns `1.0` when the row is empty, the range vectors do not match the
/// row's length, or `band_fraction <= 0`. Otherwise each axis with a positive
/// range contributes:
/// * nothing when the distance to the nearest edge is non-finite or at least
///   `band_fraction · range`;
/// * `1` when the point is on or outside an edge;
/// * `(1 − distance / (band_fraction · range))²` inside the band.
///
/// The factor is `1 + alpha · Σ contributions`, floored at `1`.
pub(crate) fn boundary_variance_inflation_factor(
    x_row: ArrayView1<'_, f64>,
    axis_min: ArrayView1<'_, f64>,
    axis_max: ArrayView1<'_, f64>,
    alpha: f64,
    band_fraction: f64,
) -> f64 {
    let d = x_row.len();
    if d == 0 || axis_min.len() != d || axis_max.len() != d || band_fraction <= 0.0 {
        return 1.0;
    }
    let mut excess = 0.0_f64;
    let axes = x_row.iter().zip(axis_min.iter()).zip(axis_max.iter());
    for ((&x, &lo), &hi) in axes {
        let range = hi - lo;
        // Written as `!(range > 0)` so a NaN range is skipped too.
        if !(range > 0.0) {
            continue;
        }
        let d_edge = (x - lo).min(hi - x);
        let band_width = band_fraction * range;
        if !d_edge.is_finite() || d_edge >= band_width {
            continue;
        }
        excess += if d_edge <= 0.0 {
            1.0
        } else {
            let shortfall = 1.0 - d_edge / band_width;
            shortfall * shortfall
        };
    }
    (1.0 + alpha * excess).max(1.0)
}
/// Variance inflation factor for a query row lying outside the per-axis range
/// `[axis_min, axis_max]`.
///
/// Returns `1.0` when the row is empty or the range vectors do not match the
/// row's length. Otherwise each axis with a positive range contributes the
/// squared overshoot beyond the nearest limit, measured in units of that
/// axis's range (zero when the point is inside). The factor is
/// `1 + gamma · Σ contributions`, floored at `1`.
pub(crate) fn ood_variance_inflation_factor(
    x_row: ArrayView1<'_, f64>,
    axis_min: ArrayView1<'_, f64>,
    axis_max: ArrayView1<'_, f64>,
    gamma: f64,
) -> f64 {
    let d = x_row.len();
    if d == 0 || axis_min.len() != d || axis_max.len() != d {
        return 1.0;
    }
    let mut sq_excess = 0.0_f64;
    let axes = x_row.iter().zip(axis_min.iter()).zip(axis_max.iter());
    for ((&x, &lo), &hi) in axes {
        let range = hi - lo;
        // Written as `!(range > 0)` so a NaN range is skipped too.
        if !(range > 0.0) {
            continue;
        }
        let overshoot = if x < lo {
            lo - x
        } else if x > hi {
            x - hi
        } else {
            0.0
        };
        let relative = overshoot / range;
        sq_excess += relative * relative;
    }
    (1.0 + gamma * sq_excess).max(1.0)
}
pub(crate) fn multi_point_joint_z(level: f64, m: usize) -> Result<f64, String> {
if m <= 1 || !(level.is_finite() && level > 0.0 && level < 1.0) {
return standard_normal_quantile(0.5 + 0.5 * level);
}
let alpha = 1.0 - level;
let per_row_alpha = alpha / (m as f64);
let per_row_level = 1.0 - per_row_alpha;
standard_normal_quantile(0.5 + 0.5 * per_row_level)
}
/// How the interval for the response-scale mean is constructed.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum MeanIntervalMethod {
/// Delta-method interval (presumably mean ± z · mean SE — confirm at the
/// use site).
Delta,
/// Interval obtained by transforming the eta interval; this is the default
/// in `PredictUncertaintyOptions::default`.
TransformEta,
}
/// Full uncertainty summary for a set of prediction rows.
#[derive(Debug)]
pub struct PredictUncertaintyResult {
/// Linear predictor per row.
pub eta: Array1<f64>,
/// Response-scale mean per row.
pub mean: Array1<f64>,
/// Standard error of `eta`.
pub eta_standard_error: Array1<f64>,
/// Standard error of `mean`.
pub mean_standard_error: Array1<f64>,
/// Interval bounds on the linear-predictor scale.
pub eta_lower: Array1<f64>,
pub eta_upper: Array1<f64>,
/// Interval bounds on the response scale.
pub mean_lower: Array1<f64>,
pub mean_upper: Array1<f64>,
/// Observation-level bounds, when computed.
pub observation_lower: Option<Array1<f64>>,
pub observation_upper: Option<Array1<f64>>,
/// The covariance mode the caller asked for.
pub covariance_mode_requested: InferenceCovarianceMode,
/// Presumably the flag returned by `select_uncertainty_backend` (whether the
/// smoothing-corrected covariance was used) — confirm where it is set.
pub covariance_corrected_used: bool,
}
/// Posterior-mean prediction without any coefficient bias correction.
///
/// Forwards to `predict_gam_posterior_mean_from_backendwith_bc` with the
/// bias-correction argument set to `None`.
fn predict_gam_posterior_mean_from_backend(
    x: DesignMatrix,
    beta: ArrayView1<'_, f64>,
    offset: ArrayView1<'_, f64>,
    backend: &PredictionCovarianceBackend<'_>,
    strategy: &(dyn FamilyStrategy + Sync),
    label: &str,
) -> Result<PredictPosteriorMeanResult, EstimationError> {
    let no_bias_correction: Option<ArrayView1<'_, f64>> = None;
    predict_gam_posterior_mean_from_backendwith_bc(
        x,
        beta,
        offset,
        backend,
        strategy,
        label,
        no_bias_correction,
    )
}
fn predict_gam_posterior_mean_from_backendwith_bc(
x: DesignMatrix,
beta: ArrayView1<'_, f64>,
offset: ArrayView1<'_, f64>,
backend: &PredictionCovarianceBackend<'_>,
strategy: &(dyn FamilyStrategy + Sync),
label: &str,
bias_correction_beta: Option<ArrayView1<'_, f64>>,
) -> Result<PredictPosteriorMeanResult, EstimationError> {
if x.ncols() != beta.len() {
return Err(EstimationError::InvalidInput(format!(
"{label} dimension mismatch: X has {} columns but beta has length {}",
x.ncols(),
beta.len()
)));
}
if x.nrows() != offset.len() {
return Err(EstimationError::InvalidInput(format!(
"{label} dimension mismatch: X has {} rows but offset has length {}",
x.nrows(),
offset.len()
)));
}
if backend.nrows() != beta.len() {
return Err(EstimationError::InvalidInput(format!(
"{label} covariance/backend dimension mismatch: expected parameter dimension {}, got {}",
beta.len(),
backend.nrows()
)));
}
let mut eta = x.matrixvectormultiply(&beta.to_owned());
eta += &offset;
if let Some(bc) = bias_correction_beta {
if bc.len() != beta.len() {
return Err(EstimationError::InvalidInput(format!(
"{label} bias-correction dimension mismatch: beta has length {} but bias_correction_beta has length {}",
beta.len(),
bc.len()
)));
}
let bc_owned = bc.to_owned();
let delta = x.matrixvectormultiply(&bc_owned);
eta += δ
}
let etavar = linear_predictorvariance_from_backend(&x, backend)?;
let eta_standard_error = etavar.mapv(|v| v.max(0.0).sqrt());
let quadctx = crate::quadrature::QuadratureContext::new();
let means: Result<Vec<f64>, EstimationError> = (0..eta.len())
.into_par_iter()
.map(|i| strategy.posterior_mean(&quadctx, eta[i], eta_standard_error[i]))
.collect();
Ok(PredictPosteriorMeanResult {
eta,
eta_standard_error,
mean: Array1::from_vec(means?),
mean_lower: None,
mean_upper: None,
observation_lower: None,
observation_upper: None,
})
}
/// Uncertainty summary for the coefficients themselves.
pub struct CoefficientUncertaintyResult {
/// Coefficient estimates.
pub estimate: Array1<f64>,
/// Standard error per coefficient.
pub standard_error: Array1<f64>,
/// Lower interval bound per coefficient.
pub lower: Array1<f64>,
/// Upper interval bound per coefficient.
pub upper: Array1<f64>,
/// Presumably whether the smoothing-corrected covariance was used — confirm
/// where this struct is populated.
pub corrected: bool,
/// The covariance mode the caller asked for.
pub covariance_mode_requested: InferenceCovarianceMode,
}
/// Plug-in GAM prediction: `eta = x · beta + offset`, `mean = g⁻¹(eta)` using
/// the inverse link of `family`.
///
/// # Errors
/// `EstimationError::InvalidInput` when the dimensions of `x`, `beta` and
/// `offset` are inconsistent; inverse-link errors are propagated.
pub fn predict_gam<X>(
    x: X,
    beta: ArrayView1<'_, f64>,
    offset: ArrayView1<'_, f64>,
    family: LikelihoodSpec,
) -> Result<PredictResult, EstimationError>
where
    X: Into<DesignMatrix>,
{
    let design: DesignMatrix = x.into();
    let mismatch = predict_gam_dimension_mismatch_message(
        design.nrows(),
        design.ncols(),
        beta.len(),
        offset.len(),
    );
    if let Some(message) = mismatch {
        return Err(EstimationError::InvalidInput(message));
    }
    let coefficients = beta.to_owned();
    let mut eta = design.matrixvectormultiply(&coefficients);
    eta += &offset;
    let mean = apply_family_inverse_link(&eta, &family)?;
    Ok(PredictResult { eta, mean })
}
/// Posterior-mean GAM prediction from a dense coefficient covariance.
///
/// The family strategy is resolved from `family` alone via `strategy_for_spec`.
pub fn predict_gam_posterior_mean<X>(
    x: X,
    beta: ArrayView1<'_, f64>,
    offset: ArrayView1<'_, f64>,
    family: LikelihoodSpec,
    covariance: ArrayView2<'_, f64>,
) -> Result<PredictPosteriorMeanResult, EstimationError>
where
    X: Into<DesignMatrix>,
{
    let design: DesignMatrix = x.into();
    let dense_backend = PredictionCovarianceBackend::from_dense(covariance.view());
    let family_strategy = strategy_for_spec(&family);
    predict_gam_posterior_mean_from_backend(
        design,
        beta,
        offset,
        &dense_backend,
        &family_strategy,
        "predict_gam_posterior_mean",
    )
}
/// Posterior-mean prediction using a caller-supplied covariance backend.
///
/// Identical to the dense-covariance entry point except that the
/// [`PredictionCovarianceBackend`] is provided by the caller instead of being
/// built from a dense matrix. The family strategy is resolved from `family`.
pub fn predict_gam_posterior_meanwith_backend<X>(
    x: X,
    beta: ArrayView1<'_, f64>,
    offset: ArrayView1<'_, f64>,
    family: LikelihoodSpec,
    backend: &PredictionCovarianceBackend<'_>,
) -> Result<PredictPosteriorMeanResult, EstimationError>
where
    X: Into<DesignMatrix>,
{
    let design: DesignMatrix = x.into();
    let family_strategy = strategy_for_spec(&family);
    predict_gam_posterior_mean_from_backend(
        design,
        beta,
        offset,
        backend,
        &family_strategy,
        "predict_gam_posterior_meanwith_backend",
    )
}
/// Posterior-mean prediction whose family strategy is resolved from a fit.
///
/// Uses `strategy_from_fit(&family, fit)` rather than the spec-only strategy,
/// and a dense backend built from `covariance`.
///
/// # Errors
///
/// Propagates failures from `strategy_from_fit` and from
/// `predict_gam_posterior_mean_from_backend`.
pub fn predict_gam_posterior_meanwith_fit<X>(
    x: X,
    beta: ArrayView1<'_, f64>,
    offset: ArrayView1<'_, f64>,
    family: LikelihoodSpec,
    covariance: ArrayView2<'_, f64>,
    fit: &UnifiedFitResult,
) -> Result<PredictPosteriorMeanResult, EstimationError>
where
    X: Into<DesignMatrix>,
{
    let design: DesignMatrix = x.into();
    let dense_backend = PredictionCovarianceBackend::from_dense(covariance.view());
    let fitted_strategy = strategy_from_fit(&family, fit)?;
    predict_gam_posterior_mean_from_backend(
        design,
        beta,
        offset,
        &dense_backend,
        &fitted_strategy,
        "predict_gam_posterior_meanwith_fit",
    )
}
/// Per-row response (observation-noise) variance implied by `response` at the
/// given `mean`.
///
/// Returns `None` for `RoystonParmar`, and whenever the family needs a
/// dispersion parameter (`theta` for an estimated negative binomial, `phi` for
/// Tweedie/Gamma/Beta) that `source` cannot supply.
pub(crate) fn family_response_variance<S>(
    response: &ResponseFamily,
    mean: &Array1<f64>,
    source: &S,
) -> Option<Array1<f64>>
where
    S: UncertaintyCovarianceSource + ?Sized,
{
    let variance = match response {
        ResponseFamily::Gaussian => {
            // Constant variance: the squared (non-negative) residual scale.
            let sigma = source.observation_standard_deviation().max(0.0);
            Array1::from_elem(mean.len(), sigma.powi(2))
        }
        ResponseFamily::Poisson => mean.mapv(|mu| mu.max(0.0)),
        ResponseFamily::NegativeBinomial { theta, theta_fixed } => {
            // A fixed theta comes from the spec; an estimated one must be
            // provided by the covariance source.
            let shape = if *theta_fixed {
                *theta
            } else {
                source.observation_theta()?
            };
            mean.mapv(|mu| mu + mu.powi(2) / shape)
        }
        ResponseFamily::Tweedie { p } => {
            let phi = source.observation_phi()?;
            mean.mapv(|mu| phi * mu.powf(*p))
        }
        ResponseFamily::Gamma => {
            let phi = source.observation_phi()?;
            mean.mapv(|mu| phi * mu.powi(2))
        }
        ResponseFamily::Beta { .. } => {
            let phi = source.observation_phi()?;
            mean.mapv(|mu| mu * (1.0 - mu) / (1.0 + phi))
        }
        ResponseFamily::Binomial => mean.mapv(|mu| {
            let prob = mu.clamp(0.0, 1.0);
            prob * (1.0 - prob)
        }),
        ResponseFamily::RoystonParmar => return None,
    };
    Some(variance)
}
/// Builds the per-row observation (predictive) band for `response`.
///
/// * Gaussian: normal band around `eta` with variance `etavar + sigma^2`.
/// * Binomial: normal band around `mean` with variance
///   `mean_standard_error^2 + p(1-p)`.
/// * Poisson / NegativeBinomial / Tweedie / Gamma / Beta: moment-matched
///   predictive quantiles, falling back to the normal band for any row where
///   the moment-matched routine returns `None`.
/// * RoystonParmar, or a family whose dispersion is unavailable from `source`:
///   `(None, None)`.
///
/// `z_lower_per_row` / `z_upper_per_row` are the per-row critical values for
/// the lower and upper side. Every band is clamped to the response support.
pub(crate) fn family_observation_band<S>(
    response: &ResponseFamily,
    eta: &Array1<f64>,
    etavar: &Array1<f64>,
    mean: &Array1<f64>,
    mean_standard_error: &Array1<f64>,
    z_lower_per_row: &Array1<f64>,
    z_upper_per_row: &Array1<f64>,
    source: &S,
) -> (Option<Array1<f64>>, Option<Array1<f64>>)
where
    S: UncertaintyCovarianceSource + ?Sized,
{
    let support = ResponseBounds::response_support(response);

    // Clamp both sides to the response support and wrap them for the caller.
    let finish = |mut lo: Array1<f64>, mut hi: Array1<f64>| {
        support.clamp_in_place(&mut lo);
        support.clamp_in_place(&mut hi);
        (Some(lo), Some(hi))
    };

    // Normal-approximation band: `center -/+ z * scale`, row by row.
    let normal_band = |center: &Array1<f64>, scale: &Array1<f64>| {
        let lo = Array1::from_iter(
            center
                .iter()
                .zip(scale.iter())
                .zip(z_lower_per_row.iter())
                .map(|((&c, &s), &z)| c - z * s),
        );
        let hi = Array1::from_iter(
            center
                .iter()
                .zip(scale.iter())
                .zip(z_upper_per_row.iter())
                .map(|((&c, &s), &z)| c + z * s),
        );
        finish(lo, hi)
    };

    // Skew-aware band: ask the family-specific moment-matched routine for the
    // two tail quantiles and use the normal form only when it declines.
    let skewed_band = |response_var: Array1<f64>,
                       quantiles: &dyn Fn(f64, f64, f64, f64) -> Option<(f64, f64)>| {
        let n = mean.len();
        let (lo, hi): (Vec<f64>, Vec<f64>) = (0..n)
            .map(|i| {
                let mu = mean[i];
                let total_var = (mean_standard_error[i].powi(2) + response_var[i]).max(0.0);
                let p_lower = normal_cdf(-z_lower_per_row[i]);
                let p_upper = normal_cdf(z_upper_per_row[i]);
                quantiles(mu, total_var, p_lower, p_upper).unwrap_or_else(|| {
                    let s = total_var.sqrt();
                    (mu - z_lower_per_row[i] * s, mu + z_upper_per_row[i] * s)
                })
            })
            .unzip();
        finish(Array1::from_vec(lo), Array1::from_vec(hi))
    };

    match response {
        ResponseFamily::Gaussian => {
            let noise_var = source.observation_standard_deviation().max(0.0).powi(2);
            let scale = etavar.mapv(|v| (v + noise_var).max(0.0).sqrt());
            normal_band(eta, &scale)
        }
        ResponseFamily::Poisson => {
            let response_var = mean.mapv(|mu| mu.max(0.0));
            skewed_band(response_var, &|mu, total_var, p_lo, p_hi| {
                poisson_moment_matched_interval(mu, total_var, p_lo, p_hi)
            })
        }
        ResponseFamily::NegativeBinomial { theta, theta_fixed } => {
            let resolved = if *theta_fixed {
                Some(*theta)
            } else {
                source.observation_theta()
            };
            let Some(shape) = resolved else {
                return (None, None);
            };
            let response_var = mean.mapv(|mu| mu + mu.powi(2) / shape);
            skewed_band(response_var, &|mu, total_var, p_lo, p_hi| {
                negative_binomial_moment_matched_interval(mu, shape, total_var, p_lo, p_hi)
            })
        }
        ResponseFamily::Tweedie { p } => {
            let Some(phi) = source.observation_phi() else {
                return (None, None);
            };
            let power = *p;
            let response_var = mean.mapv(|mu| phi * mu.powf(power));
            skewed_band(response_var, &|mu, total_var, p_lo, p_hi| {
                tweedie_moment_matched_interval(mu, phi, power, total_var, p_lo, p_hi)
            })
        }
        ResponseFamily::Gamma => {
            let Some(phi) = source.observation_phi() else {
                return (None, None);
            };
            let response_var = mean.mapv(|mu| phi * mu.powi(2));
            skewed_band(response_var, &|mu, total_var, p_lo, p_hi| {
                gamma_moment_matched_interval(mu, total_var, p_lo, p_hi)
            })
        }
        ResponseFamily::Beta { .. } => {
            let Some(phi) = source.observation_phi() else {
                return (None, None);
            };
            let response_var = mean.mapv(|mu| mu * (1.0 - mu) / (1.0 + phi));
            skewed_band(response_var, &|mu, total_var, p_lo, p_hi| {
                beta_moment_matched_interval(mu, total_var, p_lo, p_hi)
            })
        }
        ResponseFamily::Binomial => {
            let response_var = mean.mapv(|mu| {
                let prob = mu.clamp(0.0, 1.0);
                prob * (1.0 - prob)
            });
            let scale = Array1::from_iter(
                mean_standard_error
                    .iter()
                    .zip(response_var.iter())
                    .map(|(&mean_se, &obsvar)| (mean_se.powi(2) + obsvar).max(0.0).sqrt()),
            );
            normal_band(mean, &scale)
        }
        ResponseFamily::RoystonParmar => (None, None),
    }
}
/// Observation band for families whose dispersion varies by row.
///
/// Supported families are Gamma, Beta, NegativeBinomial and Tweedie; any other
/// family, or any input whose length differs from `mean.len()`, yields
/// `(None, None)`. `dispersion[i]` is forwarded as `theta` for the negative
/// binomial and as `phi` for Tweedie, and is ignored by Gamma and Beta.
///
/// For each row the family's moment-matched routine supplies the two tail
/// quantiles; when it returns `None` the row falls back to
/// `mean -/+ z * sqrt(total_var)`. Both sides are clamped to the response
/// support.
#[allow(clippy::too_many_arguments)]
pub(crate) fn family_observation_band_per_row(
    response: &ResponseFamily,
    mean: &Array1<f64>,
    mean_standard_error: &Array1<f64>,
    response_var: &Array1<f64>,
    dispersion: &Array1<f64>,
    z_lower_per_row: &Array1<f64>,
    z_upper_per_row: &Array1<f64>,
) -> (Option<Array1<f64>>, Option<Array1<f64>>) {
    let n = mean.len();
    let companion_lengths = [
        mean_standard_error.len(),
        response_var.len(),
        dispersion.len(),
        z_lower_per_row.len(),
        z_upper_per_row.len(),
    ];
    if companion_lengths.iter().any(|&len| len != n) {
        return (None, None);
    }

    // Arguments: (mu, per-row dispersion, total variance, p_lower, p_upper).
    let quantiles: Box<dyn Fn(f64, f64, f64, f64, f64) -> Option<(f64, f64)>> = match response {
        ResponseFamily::Gamma => Box::new(|mu, _disp, total_var, p_lo, p_hi| {
            gamma_moment_matched_interval(mu, total_var, p_lo, p_hi)
        }),
        ResponseFamily::Beta { .. } => Box::new(|mu, _disp, total_var, p_lo, p_hi| {
            beta_moment_matched_interval(mu, total_var, p_lo, p_hi)
        }),
        ResponseFamily::NegativeBinomial { .. } => {
            Box::new(|mu, theta, total_var, p_lo, p_hi| {
                negative_binomial_moment_matched_interval(mu, theta, total_var, p_lo, p_hi)
            })
        }
        ResponseFamily::Tweedie { p } => {
            let power = *p;
            Box::new(move |mu, phi, total_var, p_lo, p_hi| {
                tweedie_moment_matched_interval(mu, phi, power, total_var, p_lo, p_hi)
            })
        }
        _ => return (None, None),
    };

    let support = ResponseBounds::response_support(response);
    let (lo, hi): (Vec<f64>, Vec<f64>) = (0..n)
        .map(|i| {
            let mu = mean[i];
            let total_var = (mean_standard_error[i].powi(2) + response_var[i]).max(0.0);
            let p_lower = normal_cdf(-z_lower_per_row[i]);
            let p_upper = normal_cdf(z_upper_per_row[i]);
            quantiles(mu, dispersion[i], total_var, p_lower, p_upper).unwrap_or_else(|| {
                // Normal-approximation fallback for this row.
                let s = total_var.sqrt();
                (mu - z_lower_per_row[i] * s, mu + z_upper_per_row[i] * s)
            })
        })
        .unzip();

    let mut lower = Array1::from_vec(lo);
    let mut upper = Array1::from_vec(hi);
    support.clamp_in_place(&mut lower);
    support.clamp_in_place(&mut upper);
    (Some(lower), Some(upper))
}
/// Predicts on both the linear-predictor and response scale and attaches
/// standard errors, confidence intervals and (optionally) observation bands.
///
/// Processing order:
///
/// 1. Validate shapes of `x`, `beta`, `offset` and the confidence level.
/// 2. Ask `source` for the covariance backend matching `options.covariance_mode`.
/// 3. Compute `eta = X * beta + offset`, adding `X * bias_correction` when
///    `options.apply_bias_correction` is set and the source supplies one of the
///    right length (a length mismatch is logged and the correction skipped).
/// 4. Resolve the fitted link state from `source` and rebuild the likelihood
///    spec with it; compute the plug-in `mean`.
/// 5. Compute the per-row variance of `eta`, multiply by the boundary / OOD
///    inflation factors when enabled and when `predictor_x_for_corrections` and
///    `training_support` have consistent shapes (otherwise the factors stay at
///    one), then add `options.extrapolation_variance` if present. The
///    multiplicative OOD inflation is skipped, with a warning, whenever an
///    additive extrapolation variance is supplied.
/// 6. Pick the critical value (pointwise or joint), optionally replacing it per
///    row with Edgeworth one-sided values when a skewness vector of matching
///    length is supplied.
/// 7. Build the `eta` interval, the mean standard error (posterior-mean variance
///    plus link-parameter variance for SAS / beta-logistic / mixture links) and
///    the mean interval (`Delta` or `TransformEta`), clamped to the family's
///    response bounds.
/// 8. Build the observation band when `options.includeobservation_interval`.
///
/// # Errors
///
/// Returns [`EstimationError::InvalidInput`] for dimension mismatches, a
/// confidence level outside `(0, 1)`, an `extrapolation_variance` whose length
/// differs from the batch size, or an invalid quantile request. Errors from the
/// covariance source, the inverse link, the variance backend and the family
/// strategy are propagated unchanged.
pub fn predict_gamwith_uncertainty<X, S>(
    x: X,
    beta: ArrayView1<'_, f64>,
    offset: ArrayView1<'_, f64>,
    family: LikelihoodSpec,
    source: &S,
    options: &PredictUncertaintyOptions,
) -> Result<PredictUncertaintyResult, EstimationError>
where
    X: Into<DesignMatrix>,
    S: UncertaintyCovarianceSource + ?Sized,
{
    let x = x.into();
    if x.ncols() != beta.len() {
        return Err(EstimationError::InvalidInput(format!(
            "predict_gamwith_uncertainty dimension mismatch: X has {} columns but beta has length {}",
            x.ncols(),
            beta.len()
        )));
    }
    if x.nrows() != offset.len() {
        return Err(EstimationError::InvalidInput(format!(
            "predict_gamwith_uncertainty dimension mismatch: X has {} rows but offset has length {}",
            x.nrows(),
            offset.len()
        )));
    }
    if !(options.confidence_level.is_finite()
        && options.confidence_level > 0.0
        && options.confidence_level < 1.0)
    {
        return Err(EstimationError::InvalidInput(format!(
            "confidence_level must be in (0,1), got {}",
            options.confidence_level
        )));
    }
    let requested_mode = options.covariance_mode;
    let (backend, covariance_corrected_used) = source.select_uncertainty_backend(
        beta.len(),
        requested_mode,
        "predict_gamwith_uncertainty",
    )?;

    // Linear predictor, optionally shifted by the bias-correction coefficients.
    let mut eta = x.matrixvectormultiply(&beta.to_owned());
    eta += &offset;
    if options.apply_bias_correction
        && let Some(bc) = source.resolved_bias_correction_beta()
    {
        if bc.len() == beta.len() {
            let bc_owned = bc.to_owned();
            let correction = x.matrixvectormultiply(&bc_owned);
            eta += &correction;
        } else {
            // Best effort: a mismatched correction is skipped, not fatal.
            log::warn!(
                "predict_gamwith_uncertainty: bias-correction dimension mismatch \
                 (beta {}, bc {}); skipping bias correction",
                beta.len(),
                bc.len()
            );
        }
    }

    // Resolve the fitted link and the pieces of its state needed further down.
    let fitted_link_state = source.resolved_fitted_link_state(&family);
    let mixture_state = match fitted_link_state.as_ref() {
        Some(FittedLinkState::Mixture { state, .. }) => Some(state.clone()),
        _ => None,
    };
    let sas_state = match fitted_link_state.as_ref() {
        Some(FittedLinkState::Sas { state, .. })
        | Some(FittedLinkState::BetaLogistic { state, .. }) => Some(*state),
        _ => None,
    };
    let link_kind = match fitted_link_state.as_ref() {
        Some(FittedLinkState::Standard(Some(link))) => Some(InverseLink::Standard(*link)),
        Some(FittedLinkState::LatentCLogLog { state }) => Some(InverseLink::LatentCLogLog(*state)),
        Some(FittedLinkState::Sas { state, .. }) => Some(InverseLink::Sas(*state)),
        Some(FittedLinkState::BetaLogistic { state, .. }) => {
            Some(InverseLink::BetaLogistic(*state))
        }
        Some(FittedLinkState::Mixture { state, .. }) => Some(InverseLink::Mixture(state.clone())),
        Some(FittedLinkState::Standard(None)) | None => None,
    };
    // Prefer the fitted link over whatever link the caller's spec carries.
    let likelihood = if let Some(link) = link_kind.clone() {
        LikelihoodSpec::new(family.response.clone(), link)
    } else {
        family.clone()
    };
    let strategy = strategy_for_spec(&likelihood);
    let mean = apply_family_inverse_link(&eta, &likelihood)?;

    // Variance of eta, with optional multiplicative and additive inflation.
    let etavar_raw = linear_predictorvariance_from_backend(&x, &backend)?;
    let n_rows = etavar_raw.len();
    let ood_inflation_active = options.ood_inflation && options.extrapolation_variance.is_none();
    if options.ood_inflation && !ood_inflation_active {
        log::warn!(
            "predict_gamwith_uncertainty: ood_inflation is enabled but an additive \
             extrapolation_variance is supplied; skipping the multiplicative OOD \
             inflation to avoid double-counting off-support uncertainty"
        );
    }
    let mut variance_inflation = Array1::<f64>::ones(n_rows);
    if (options.boundary_correction || ood_inflation_active)
        && let (Some(predictor_x), Some(support)) = (
            options.predictor_x_for_corrections.as_ref(),
            options.training_support.as_ref(),
        )
        && predictor_x.nrows() == n_rows
        && predictor_x.ncols() == support.axis_min.len()
        && support.axis_min.len() == support.axis_max.len()
    {
        for i in 0..n_rows {
            let row = predictor_x.row(i);
            let mut factor = 1.0_f64;
            if options.boundary_correction {
                factor *= boundary_variance_inflation_factor(
                    row,
                    support.axis_min.view(),
                    support.axis_max.view(),
                    options.boundary_alpha,
                    options.boundary_band_fraction,
                );
            }
            if ood_inflation_active {
                factor *= ood_variance_inflation_factor(
                    row,
                    support.axis_min.view(),
                    support.axis_max.view(),
                    options.ood_gamma,
                );
            }
            variance_inflation[i] = factor;
        }
    }
    let mut etavar = if variance_inflation.iter().all(|&f| f == 1.0) {
        // Nothing was inflated: reuse the raw variances without copying them.
        etavar_raw
    } else {
        Array1::from_iter(
            etavar_raw
                .iter()
                .zip(variance_inflation.iter())
                .map(|(&v, &f)| v * f),
        )
    };
    if let Some(extra) = options.extrapolation_variance.as_ref() {
        if extra.len() != n_rows {
            return Err(EstimationError::InvalidInput(format!(
                "extrapolation_variance length {} does not match prediction batch {}",
                extra.len(),
                n_rows
            )));
        }
        etavar += extra;
    }
    let eta_standard_error = etavar.mapv(|v| v.max(0.0).sqrt());

    // Critical values: one central value, optionally adjusted per row.
    let level = options.confidence_level;
    let z_central = if options.multi_point_joint {
        let m = options.joint_query_count.unwrap_or(n_rows).max(1);
        multi_point_joint_z(level, m).map_err(EstimationError::InvalidInput)?
    } else {
        standard_normal_quantile(0.5 + 0.5 * level).map_err(EstimationError::InvalidInput)?
    };
    let mut z_lower_per_row = Array1::<f64>::from_elem(n_rows, z_central);
    let mut z_upper_per_row = Array1::<f64>::from_elem(n_rows, z_central);
    if options.edgeworth_one_sided
        && let Some(skew) = options.eta_skewness_for_corrections.as_ref()
        && skew.len() == n_rows
    {
        for i in 0..n_rows {
            let adj = edgeworth_one_sided_quantile(z_central, skew[i]);
            z_lower_per_row[i] = adj.z_lower;
            z_upper_per_row[i] = adj.z_upper;
        }
    }
    let eta_lower = Array1::from_iter(
        eta.iter()
            .zip(eta_standard_error.iter())
            .zip(z_lower_per_row.iter())
            .map(|((&e, &s), &zl)| e - zl * s),
    );
    let eta_upper = Array1::from_iter(
        eta.iter()
            .zip(eta_standard_error.iter())
            .zip(z_upper_per_row.iter())
            .map(|((&e, &s), &zu)| e + zu * s),
    );

    // Mean standard error: posterior-mean variance from the family strategy,
    // plus the contribution of link-parameter uncertainty where a covariance
    // for those parameters was fitted.
    let quadctx = crate::quadrature::QuadratureContext::new();
    let mean_standard_error = Array1::from_vec(
        (0..eta.len())
            .into_par_iter()
            .map(|i| -> Result<f64, EstimationError> {
                let se_i = etavar[i].max(0.0).sqrt();
                let (_, mut meanvar) = strategy.posterior_meanvariance(&quadctx, eta[i], se_i)?;
                if family.is_binomial_sas()
                    && let Some(cov_theta) = fitted_link_state.as_ref().and_then(|s| match s {
                        FittedLinkState::Sas { covariance, .. } => covariance.as_ref(),
                        _ => None,
                    })
                {
                    let sas = sas_state.ok_or_else(|| {
                        EstimationError::InvalidInput(
                            "BinomialSas uncertainty requires fitted sas_epsilon/sas_log_delta"
                                .to_string(),
                        )
                    })?;
                    let jets =
                        sas_inverse_link_jetwith_param_partials(eta[i], sas.epsilon, sas.log_delta);
                    let g = [jets.djet_depsilon.mu, jets.djet_dlog_delta.mu];
                    meanvar += quadratic_form(cov_theta, &g)?;
                }
                if family.is_binomial_beta_logistic()
                    && let Some(cov_theta) = fitted_link_state.as_ref().and_then(|s| match s {
                        FittedLinkState::BetaLogistic { covariance, .. } => covariance.as_ref(),
                        _ => None,
                    })
                {
                    let sas = sas_state.ok_or_else(|| {
                        EstimationError::InvalidInput(
                            "BinomialBetaLogistic uncertainty requires fitted parameters"
                                .to_string(),
                        )
                    })?;
                    // NOTE(review): the parameters are passed as
                    // (log_delta, epsilon) here but as (epsilon, log_delta) in
                    // the SAS call above; confirm this matches the callee's
                    // signature.
                    let jets = beta_logistic_inverse_link_jetwith_param_partials(
                        eta[i],
                        sas.log_delta,
                        sas.epsilon,
                    );
                    let g = [jets.djet_depsilon.mu, jets.djet_dlog_delta.mu];
                    meanvar += quadratic_form(cov_theta, &g)?;
                }
                if family.is_binomial_mixture()
                    && let Some(cov_theta) = fitted_link_state.as_ref().and_then(|s| match s {
                        FittedLinkState::Mixture { covariance, .. } => covariance.as_ref(),
                        _ => None,
                    })
                    && let Some(state) = mixture_state.as_ref()
                {
                    let mut mix_partials = vec![
                        InverseLinkJet {
                            mu: 0.0,
                            d1: 0.0,
                            d2: 0.0,
                            d3: 0.0,
                        };
                        state.rho.len()
                    ];
                    mixture_inverse_link_jetwith_rho_partials_into(
                        state,
                        eta[i],
                        &mut mix_partials,
                    );
                    meanvar += quadratic_form_from_jetmu(cov_theta, &mix_partials)?;
                }
                Ok(meanvar.max(0.0).sqrt())
            })
            .collect::<Result<Vec<_>, _>>()?,
    );

    // Mean interval, either by the delta method or by transforming the eta
    // interval end points (ordered so that lower <= upper for decreasing links).
    let (mut mean_lower, mut mean_upper) = match options.mean_interval_method {
        MeanIntervalMethod::Delta => (
            Array1::from_iter(
                mean.iter()
                    .zip(mean_standard_error.iter())
                    .zip(z_lower_per_row.iter())
                    .map(|((&m, &s), &zl)| m - zl * s),
            ),
            Array1::from_iter(
                mean.iter()
                    .zip(mean_standard_error.iter())
                    .zip(z_upper_per_row.iter())
                    .map(|((&m, &s), &zu)| m + zu * s),
            ),
        ),
        MeanIntervalMethod::TransformEta => {
            let transformed_lower = apply_family_inverse_link(&eta_lower, &likelihood)?;
            let transformed_upper = apply_family_inverse_link(&eta_upper, &likelihood)?;
            (
                Array1::from_iter(
                    transformed_lower
                        .iter()
                        .zip(transformed_upper.iter())
                        .map(|(&lo, &hi)| lo.min(hi)),
                ),
                Array1::from_iter(
                    transformed_lower
                        .iter()
                        .zip(transformed_upper.iter())
                        .map(|(&lo, &hi)| lo.max(hi)),
                ),
            )
        }
    };
    let spec = &likelihood;
    let response_bounds = ResponseBounds::for_family(&spec.response);
    response_bounds.clamp_in_place(&mut mean_lower);
    response_bounds.clamp_in_place(&mut mean_upper);
    let (observation_lower, observation_upper) = if options.includeobservation_interval {
        family_observation_band(
            &spec.response,
            &eta,
            &etavar,
            &mean,
            &mean_standard_error,
            &z_lower_per_row,
            &z_upper_per_row,
            source,
        )
    } else {
        (None, None)
    };
    Ok(PredictUncertaintyResult {
        eta,
        mean,
        eta_standard_error,
        mean_standard_error,
        eta_lower,
        eta_upper,
        mean_lower,
        mean_upper,
        observation_lower,
        observation_upper,
        covariance_mode_requested: requested_mode,
        covariance_corrected_used,
    })
}
/// Held-out calibration data consumed by `predict_full_uncertainty_conformal`.
pub struct ConformalCalibrationFold<'a> {
/// Prediction input for the calibration rows; the model is evaluated on it
/// to obtain calibration means and mean standard errors.
pub input: PredictInput,
/// Observed responses for the calibration rows; its length must equal the
/// number of calibration means the model predicts.
pub y: ArrayView1<'a, f64>,
}
/// Runs the model's full-uncertainty prediction and, when
/// `options.conformal_level` is set, overwrites `mean_lower` / `mean_upper`
/// with a conformally calibrated interval.
///
/// The calibration fold is predicted with the same options except that
/// conformal calibration and observation intervals are switched off. Both the
/// calibration and the test scales come from `predictive_standard_error`. The
/// calibrator is built from the held-out responses, the calibration means and
/// the calibration scales at miscoverage `1 - level`.
///
/// # Errors
///
/// Returns [`EstimationError::InvalidInput`] when the conformal level is not in
/// `(0, 1)` or when the number of calibration predictions differs from
/// `calibration.y.len()`. Model and calibrator errors are propagated.
pub fn predict_full_uncertainty_conformal<M: PredictableModel + ?Sized>(
    model: &M,
    input: &PredictInput,
    fit: &UnifiedFitResult,
    family: &LikelihoodSpec,
    options: &PredictUncertaintyOptions,
    calibration: &ConformalCalibrationFold<'_>,
) -> Result<PredictUncertaintyResult, EstimationError> {
    let mut prediction = model.predict_full_uncertainty(input, fit, options)?;

    // Without a conformal level the plain prediction is the answer.
    let level = match options.conformal_level {
        Some(level) => level,
        None => return Ok(prediction),
    };
    let level_is_valid = level.is_finite() && level > 0.0 && level < 1.0;
    if !level_is_valid {
        return Err(EstimationError::InvalidInput(format!(
            "conformal_level must be in (0,1), got {level}"
        )));
    }
    let miscoverage = 1.0 - level;

    // Predict the calibration fold without recursion into conformal logic and
    // without observation bands, which are not needed for calibration.
    let calibration_options = PredictUncertaintyOptions {
        conformal_level: None,
        includeobservation_interval: false,
        ..options.clone()
    };
    let held_out =
        model.predict_full_uncertainty(&calibration.input, fit, &calibration_options)?;
    if held_out.mean.len() != calibration.y.len() {
        return Err(EstimationError::InvalidInput(format!(
            "conformal calibration: predicted {} calibration means but y_cal has length {}",
            held_out.mean.len(),
            calibration.y.len()
        )));
    }

    let held_out_scale =
        predictive_standard_error(family, &held_out.mean, &held_out.mean_standard_error, fit);
    let target_scale = predictive_standard_error(
        family,
        &prediction.mean,
        &prediction.mean_standard_error,
        fit,
    );
    let calibrator = crate::inference::conformal::ConformalCalibrator::from_held_out_fold(
        calibration.y,
        held_out.mean.view(),
        held_out_scale.view(),
        miscoverage,
    )?;
    let bounds = ResponseBounds::for_family(&family.response);
    let (calibrated_lower, calibrated_upper) =
        calibrator.calibrated_interval(&prediction.mean, &target_scale, bounds)?;
    prediction.mean_lower = calibrated_lower;
    prediction.mean_upper = calibrated_upper;
    Ok(prediction)
}
/// Per-row predictive standard error: `sqrt(mean_se^2 + max(response_var, 0))`.
///
/// When `family_response_variance` cannot produce a response variance for the
/// family, the mean standard error is returned unchanged.
fn predictive_standard_error<S>(
    family: &LikelihoodSpec,
    mean: &Array1<f64>,
    mean_standard_error: &Array1<f64>,
    source: &S,
) -> Array1<f64>
where
    S: UncertaintyCovarianceSource + ?Sized,
{
    let Some(response_var) = family_response_variance(&family.response, mean, source) else {
        return mean_standard_error.clone();
    };
    mean_standard_error
        .iter()
        .zip(response_var.iter())
        .map(|(&se, &var)| (se.powi(2) + var.max(0.0)).max(0.0).sqrt())
        .collect()
}
/// Coefficient standard errors and confidence bounds for a fit.
///
/// Forwards all three arguments unchanged to
/// `coefficient_uncertaintywith_mode` and returns its result, including any
/// error it reports.
pub fn coefficient_uncertainty(
fit: &UnifiedFitResult,
confidence_level: f64,
covariance_mode: InferenceCovarianceMode,
) -> Result<CoefficientUncertaintyResult, EstimationError> {
coefficient_uncertaintywith_mode(fit, confidence_level, covariance_mode)
}
/// Builds symmetric normal-quantile intervals for every coefficient of `fit`.
///
/// The standard errors are chosen by `covariance_mode`:
///
/// * `Conditional`: the conditional errors, which must be present.
/// * `ConditionalPlusSmoothingPreferred`: the smoothing-corrected errors when
///   available, otherwise the conditional ones.
/// * `ConditionalPlusSmoothingRequired`: the smoothing-corrected errors, which
///   must be present.
///
/// # Errors
///
/// Returns [`EstimationError::InvalidInput`] when `confidence_level` is not in
/// `(0, 1)`, when the required standard errors are missing from the fit, when
/// their length differs from `fit.beta`, or when the normal quantile cannot be
/// evaluated.
pub fn coefficient_uncertaintywith_mode(
    fit: &UnifiedFitResult,
    confidence_level: f64,
    covariance_mode: InferenceCovarianceMode,
) -> Result<CoefficientUncertaintyResult, EstimationError> {
    let level_is_valid =
        confidence_level.is_finite() && confidence_level > 0.0 && confidence_level < 1.0;
    if !level_is_valid {
        return Err(EstimationError::InvalidInput(format!(
            "confidence_level must be in (0,1), got {}",
            confidence_level
        )));
    }

    let missing = |message: &str| EstimationError::InvalidInput(message.to_string());
    let (se, corrected) = match covariance_mode {
        InferenceCovarianceMode::Conditional => {
            let conditional = fit.beta_standard_errors().cloned().ok_or_else(|| {
                missing("fit result does not contain conditional coefficient standard errors")
            })?;
            (conditional, false)
        }
        InferenceCovarianceMode::ConditionalPlusSmoothingPreferred => fit
            .beta_standard_errors_corrected()
            .map(|se_corr| (se_corr.clone(), true))
            // Only consult the conditional errors when no corrected ones exist.
            .or_else(|| {
                fit.beta_standard_errors()
                    .map(|se_base| (se_base.clone(), false))
            })
            .ok_or_else(|| missing("fit result does not contain coefficient standard errors"))?,
        InferenceCovarianceMode::ConditionalPlusSmoothingRequired => {
            let smoothed = fit
                .beta_standard_errors_corrected()
                .cloned()
                .ok_or_else(|| {
                    missing(
                        "fit result does not contain smoothing-corrected coefficient standard errors",
                    )
                })?;
            (smoothed, true)
        }
    };

    if se.len() != fit.beta.len() {
        return Err(EstimationError::InvalidInput(format!(
            "standard error length mismatch: beta has {}, se has {}",
            fit.beta.len(),
            se.len()
        )));
    }

    let z = standard_normal_quantile(0.5 + 0.5 * confidence_level)
        .map_err(EstimationError::InvalidInput)?;
    let half_width = se.mapv(|s| z * s);
    let lower = &fit.beta - &half_width;
    let upper = &fit.beta + &half_width;
    Ok(CoefficientUncertaintyResult {
        estimate: fit.beta.clone(),
        standard_error: se,
        lower,
        upper,
        corrected,
        covariance_mode_requested: covariance_mode,
    })
}
#[cfg(test)]
mod tests {
use super::*;
use crate::estimate::{
BlockRole, FitArtifacts, FittedBlock, FittedLinkState, UnifiedFitResult,
UnifiedFitResultParts,
};
use crate::inference::model::SavedCompiledFlexBlock;
use crate::pirls::PirlsStatus;
use crate::types::{LinkFunction, StandardLink};
use ndarray::{Array1, Array2, array};
// Test helper: converts an in-memory `DeviationRuntime` into the serialisable
// `SavedCompiledFlexBlock` form, tagged with the anchored-deviation kernel
// name. Each span coefficient matrix is copied row by row into nested `Vec`s;
// no anchor correction or anchor components are attached.
fn saved_runtime_from_deviation_runtime(
runtime: &crate::families::bms::DeviationRuntime,
) -> SavedCompiledFlexBlock {
SavedCompiledFlexBlock {
kernel: crate::families::cubic_cell_kernel::ANCHORED_DEVIATION_KERNEL.to_string(),
breakpoints: runtime.breakpoints().to_vec(),
basis_dim: runtime.basis_dim(),
// One inner Vec per row of the corresponding coefficient matrix.
span_c0: runtime
.span_c0()
.outer_iter()
.map(|row| row.to_vec())
.collect(),
span_c1: runtime
.span_c1()
.outer_iter()
.map(|row| row.to_vec())
.collect(),
span_c2: runtime
.span_c2()
.outer_iter()
.map(|row| row.to_vec())
.collect(),
span_c3: runtime
.span_c3()
.outer_iter()
.map(|row| row.to_vec())
.collect(),
anchor_correction: None,
anchor_components: Vec::new(),
}
}
// A bare covariance matrix used as the uncertainty source carries no fitted
// dispersion. For families that need one (Beta `phi`, estimated negative
// binomial `theta`) the prediction must still succeed but must not produce an
// observation interval from the seed value stored in the family spec.
#[test]
fn raw_covariance_observation_intervals_require_fitted_scale_hints() {
// Single-row, single-coefficient problem with zero coefficient covariance.
let x = array![[1.0_f64]];
let beta = array![0.0_f64];
let offset = array![0.0_f64];
let covariance = Array2::<f64>::zeros((1, 1));
// Observation intervals requested; every optional correction switched off.
let options = PredictUncertaintyOptions {
confidence_level: 0.95,
covariance_mode: InferenceCovarianceMode::Conditional,
mean_interval_method: MeanIntervalMethod::Delta,
includeobservation_interval: true,
apply_bias_correction: false,
edgeworth_one_sided: false,
boundary_correction: false,
ood_inflation: false,
multi_point_joint: false,
..PredictUncertaintyOptions::default()
};
// Beta family: `phi: 1.0` is only the seed value in the spec.
let beta_seed = crate::types::LikelihoodSpec::new(
ResponseFamily::Beta { phi: 1.0 },
InverseLink::Standard(StandardLink::Logit),
);
let beta_raw = predict_gamwith_uncertainty(
x.view(),
beta.view(),
offset.view(),
beta_seed,
&covariance,
&options,
)
.expect("raw beta covariance prediction");
assert!(
beta_raw.observation_lower.is_none() && beta_raw.observation_upper.is_none(),
"bare Vb must not build a Beta observation interval from the seed phi"
);
// Negative binomial with an estimated (not fixed) theta: same expectation.
let nb_seed = crate::types::LikelihoodSpec::new(
ResponseFamily::NegativeBinomial {
theta: 1.0,
theta_fixed: false,
},
InverseLink::Standard(StandardLink::Log),
);
let nb_raw = predict_gamwith_uncertainty(
x.view(),
beta.view(),
offset.view(),
nb_seed,
&covariance,
&options,
)
.expect("raw NB covariance prediction");
assert!(
nb_raw.observation_lower.is_none() && nb_raw.observation_upper.is_none(),
"bare Vb must not build an estimated-NB observation interval from the seed theta"
);
}
// When the covariance source is wrapped together with explicit scale hints,
// the observation interval must be built from the hinted dispersion rather
// than from the seed value in the family spec.
//
// NOTE(review): the expected values below are normal-approximation widths
// (`z * sqrt(variance)`). `family_observation_band` sends Beta and
// NegativeBinomial through moment-matched intervals first and only uses the
// normal form when those return `None`. Confirm that the moment-matched
// routines decline (or agree to 1e-12) for these inputs.
#[test]
fn raw_covariance_with_scale_hints_drives_observation_interval_width() {
// eta = 0 with zero coefficient covariance, so only observation noise
// contributes to the interval width.
let x = array![[1.0_f64]];
let beta = array![0.0_f64];
let offset = array![0.0_f64];
let covariance = Array2::<f64>::zeros((1, 1));
let z = standard_normal_quantile(0.975).expect("z");
let options = PredictUncertaintyOptions {
confidence_level: 0.95,
covariance_mode: InferenceCovarianceMode::Conditional,
mean_interval_method: MeanIntervalMethod::Delta,
includeobservation_interval: true,
apply_bias_correction: false,
edgeworth_one_sided: false,
boundary_correction: false,
ood_inflation: false,
multi_point_joint: false,
..PredictUncertaintyOptions::default()
};
// Beta: hinted phi differs from the seed phi (1.0) in the spec.
let beta_phi = 31.0;
let beta_source = PredictionCovarianceWithScale::new(
covariance.view(),
ObservationScaleHints::with_phi(beta_phi),
);
let beta_pred = predict_gamwith_uncertainty(
x.view(),
beta.view(),
offset.view(),
crate::types::LikelihoodSpec::new(
ResponseFamily::Beta { phi: 1.0 },
InverseLink::Standard(StandardLink::Logit),
),
&beta_source,
&options,
)
.expect("hinted beta covariance prediction");
let beta_lower = beta_pred.observation_lower.expect("beta lower");
let beta_upper = beta_pred.observation_upper.expect("beta upper");
let beta_half_width = 0.5 * (beta_upper[0] - beta_lower[0]);
// 0.25 is mu * (1 - mu) at mu = 0.5 (logit link, eta = 0).
let expected_beta_half_width = z * (0.25 / (1.0 + beta_phi)).sqrt();
assert!(
(beta_half_width - expected_beta_half_width).abs() < 1e-12,
"Beta observation interval must use fitted phi hint"
);
// Negative binomial: hinted theta differs from the seed theta (1.0).
let theta_hat = 4.0;
let nb_source = PredictionCovarianceWithScale::new(
covariance.view(),
ObservationScaleHints::with_theta(theta_hat),
);
let nb_pred = predict_gamwith_uncertainty(
x.view(),
beta.view(),
offset.view(),
crate::types::LikelihoodSpec::new(
ResponseFamily::NegativeBinomial {
theta: 1.0,
theta_fixed: false,
},
InverseLink::Standard(StandardLink::Log),
),
&nb_source,
&options,
)
.expect("hinted NB covariance prediction");
let nb_upper = nb_pred.observation_upper.expect("nb upper");
// mu = 1 (log link, eta = 0); the variance is mu + mu^2 / theta.
let expected_nb_upper = 1.0 + z * (1.0 + 1.0 / theta_hat).sqrt();
assert!(
(nb_upper[0] - expected_nb_upper).abs() < 1e-12,
"NB observation interval must use fitted theta hint"
);
}
// Test fixture: a minimal single-block (`BlockRole::Mean`) Gaussian-identity
// fit carrying `beta` and a conditional covariance. There is no corrected
// covariance and the residual standard deviation is 1.0; all remaining fields
// are neutral placeholders.
fn test_fit_with_covariance(beta: Array1<f64>, covariance: Array2<f64>) -> UnifiedFitResult {
UnifiedFitResult::try_from_parts(UnifiedFitResultParts {
blocks: vec![FittedBlock {
beta: beta.clone(),
role: BlockRole::Mean,
edf: 0.0,
lambdas: Array1::zeros(0),
}],
log_lambdas: Array1::zeros(0),
lambdas: Array1::zeros(0),
likelihood_family: Some(crate::types::LikelihoodSpec::gaussian_identity()),
likelihood_scale: crate::types::LikelihoodScaleMetadata::ProfiledGaussian,
log_likelihood_normalization: crate::types::LogLikelihoodNormalization::Full,
log_likelihood: 0.0,
deviance: 0.0,
reml_score: 0.0,
stable_penalty_term: 0.0,
penalized_objective: 0.0,
used_device: false,
outer_iterations: 0,
outer_converged: true,
outer_gradient_norm: None,
standard_deviation: 1.0,
// Only the conditional covariance is populated.
covariance_conditional: Some(covariance),
covariance_corrected: None,
inference: None,
fitted_link: FittedLinkState::Standard(None),
geometry: None,
block_states: Vec::new(),
pirls_status: PirlsStatus::Converged,
max_abs_eta: 0.0,
constraint_kkt: None,
artifacts: FitArtifacts {
pirls: None,
..Default::default()
},
inner_cycles: 0,
})
.expect("test fit")
}
// Test fixture: Gaussian location-scale fit with a conditional covariance
// only. Delegates to the `_and_corrected` variant with no corrected
// covariance.
fn gaussian_location_scale_fit_with_covariance(
beta_mu: Array1<f64>,
beta_noise: Array1<f64>,
covariance: Array2<f64>,
) -> UnifiedFitResult {
gaussian_location_scale_fit_with_covariance_and_corrected(
beta_mu, beta_noise, covariance, None,
)
}
// Test fixture: two-block Gaussian fit with a `Location` block (`beta_mu`) and
// a `Scale` block (`beta_noise`). It carries the given conditional covariance
// and an optional smoothing-corrected covariance; all remaining fields are
// neutral placeholders.
fn gaussian_location_scale_fit_with_covariance_and_corrected(
beta_mu: Array1<f64>,
beta_noise: Array1<f64>,
covariance: Array2<f64>,
covariance_corrected: Option<Array2<f64>>,
) -> UnifiedFitResult {
UnifiedFitResult::try_from_parts(UnifiedFitResultParts {
blocks: vec![
FittedBlock {
beta: beta_mu,
role: BlockRole::Location,
edf: 0.0,
lambdas: Array1::zeros(0),
},
FittedBlock {
beta: beta_noise,
role: BlockRole::Scale,
edf: 0.0,
lambdas: Array1::zeros(0),
},
],
log_lambdas: Array1::zeros(0),
lambdas: Array1::zeros(0),
likelihood_family: Some(crate::types::LikelihoodSpec::gaussian_identity()),
likelihood_scale: crate::types::LikelihoodScaleMetadata::ProfiledGaussian,
log_likelihood_normalization: crate::types::LogLikelihoodNormalization::Full,
log_likelihood: 0.0,
deviance: 0.0,
reml_score: 0.0,
stable_penalty_term: 0.0,
penalized_objective: 0.0,
used_device: false,
outer_iterations: 0,
outer_converged: true,
outer_gradient_norm: None,
standard_deviation: 1.0,
covariance_conditional: Some(covariance),
// Passed straight through from the caller (may be `None`).
covariance_corrected,
inference: None,
fitted_link: FittedLinkState::Standard(None),
geometry: None,
block_states: Vec::new(),
pirls_status: PirlsStatus::Converged,
max_abs_eta: 0.0,
constraint_kkt: None,
artifacts: FitArtifacts {
pirls: None,
..Default::default()
},
inner_cycles: 0,
})
.expect("gaussian location-scale fit")
}
// Test fixture: two-block Royston-Parmar survival fit with a `Threshold` block
// and a `Scale` block (log-sigma coefficients). Dispersion is fixed at
// `phi = 1.0` and only the conditional covariance is populated.
fn survival_fit_with_covariance(
beta_threshold: Array1<f64>,
beta_log_sigma: Array1<f64>,
covariance: Array2<f64>,
) -> UnifiedFitResult {
UnifiedFitResult::try_from_parts(UnifiedFitResultParts {
blocks: vec![
FittedBlock {
beta: beta_threshold,
role: BlockRole::Threshold,
edf: 0.0,
lambdas: Array1::zeros(0),
},
FittedBlock {
beta: beta_log_sigma,
role: BlockRole::Scale,
edf: 0.0,
lambdas: Array1::zeros(0),
},
],
log_lambdas: Array1::zeros(0),
lambdas: Array1::zeros(0),
likelihood_family: Some(crate::types::LikelihoodSpec::royston_parmar()),
likelihood_scale: crate::types::LikelihoodScaleMetadata::FixedDispersion { phi: 1.0 },
log_likelihood_normalization: crate::types::LogLikelihoodNormalization::Full,
log_likelihood: 0.0,
deviance: 0.0,
reml_score: 0.0,
stable_penalty_term: 0.0,
penalized_objective: 0.0,
used_device: false,
outer_iterations: 0,
outer_converged: true,
outer_gradient_norm: None,
standard_deviation: 1.0,
covariance_conditional: Some(covariance),
covariance_corrected: None,
inference: None,
fitted_link: FittedLinkState::Standard(None),
geometry: None,
block_states: Vec::new(),
pirls_status: PirlsStatus::Converged,
max_abs_eta: 0.0,
constraint_kkt: None,
artifacts: FitArtifacts {
pirls: None,
..Default::default()
},
inner_cycles: 0,
})
.expect("survival fit")
}
#[test]
fn predict_posterior_mean_probit_matches_closed_form_reference() {
let x = array![[1.0], [1.0]];
let beta = array![0.7];
let offset = array![0.0, 0.0];
let covariance = Array2::from_diag(&array![0.25]);
let out = predict_gam_posterior_mean(
x,
beta.view(),
offset.view(),
crate::types::LikelihoodSpec::binomial_probit(),
covariance.view(),
)
.expect("predict posterior mean");
let expected = crate::quadrature::probit_posterior_meanwith_deriv_exact(0.7, 0.5).mean;
assert!((out.mean[0] - expected).abs() <= 1e-12);
assert!((out.mean[1] - expected).abs() <= 1e-12);
}
#[test]
fn predict_posterior_mean_logit_uses_integrated_dispatch() {
let x = array![[1.0], [1.0]];
let beta = array![0.4];
let offset = array![0.0, 0.0];
let covariance = Array2::from_diag(&array![0.16]);
let out = predict_gam_posterior_mean(
x,
beta.view(),
offset.view(),
crate::types::LikelihoodSpec::binomial_logit(),
covariance.view(),
)
.expect("predict posterior mean");
let quadctx = crate::quadrature::QuadratureContext::new();
let expected = crate::quadrature::integrated_inverse_link_mean_and_derivative(
&quadctx,
LinkFunction::Logit,
0.4,
0.4,
)
.expect("logit integrated inverse-link moments should evaluate")
.mean;
assert!((out.mean[0] - expected).abs() <= 1e-12);
assert!((out.mean[1] - expected).abs() <= 1e-12);
}
// Exercises the validation paths around saved score-warp runtimes:
//   1. an unrecognised kernel name is rejected by `design`,
//   2. a non-cubic deviation block is rejected at build time,
//   3. a runtime with a truncated coefficient row is rejected by `design`,
//   4. the untouched production runtime is accepted.
#[test]
fn bernoulli_marginal_slope_predictor_rejects_structurally_invalid_or_unknown_runtime_kernel() {
let seed = array![-1.5, -0.2, 0.6, 1.4];
// Reference runtime built through the production path (degree 3).
let prepared = crate::families::bms::build_score_warp_deviation_block_from_seed(
&seed,
&crate::families::bms::DeviationBlockConfig {
degree: 3,
num_internal_knots: 3,
..Default::default()
},
)
.expect("production score-warp runtime");
let production_runtime = saved_runtime_from_deviation_runtime(&prepared.runtime);
// Predictor whose score-warp runtime carries an unknown kernel name; only
// the `score_warp_runtime` field is exercised below.
let score_only = BernoulliMarginalSlopePredictor {
beta_marginal: array![0.8],
beta_logslope: array![1.6],
beta_score_warp: Some(array![0.7, -0.4]),
beta_link_dev: None,
base_link: InverseLink::Standard(crate::types::StandardLink::Probit),
z_column: "z".to_string(),
latent_z_normalization: SavedLatentZNormalization { mean: 0.0, sd: 1.0 },
latent_measure: LatentMeasureKind::StandardNormal,
baseline_marginal: 0.0,
baseline_logslope: 0.0,
covariance: None,
score_warp_runtime: Some(SavedCompiledFlexBlock {
kernel: "OldQuadrature".to_string(),
..production_runtime.clone()
}),
link_deviation_runtime: None,
gaussian_frailty_sd: None,
latent_z_calibration: None,
latent_z_conditional_calibration: None,
};
// Case 1: unknown kernel name; the error text must mention the expected kernel.
let err = score_only
.score_warp_runtime
.as_ref()
.unwrap()
.design(&array![0.0])
.unwrap_err();
assert!(err.to_string().contains("DenestedCubicTransport"));
// Case 2: a degree-2 deviation block cannot be built at all.
let err = crate::families::bms::build_score_warp_deviation_block_from_seed(
&seed,
&crate::families::bms::DeviationBlockConfig {
degree: 2,
num_internal_knots: 3,
..Default::default()
},
)
.expect_err("non-cubic deviation runtimes should be rejected");
assert!(err.contains("degree must be 3"));
// Case 3: drop one entry from the first c0 row to break the row width.
let mut structurally_invalid = production_runtime.clone();
structurally_invalid.span_c0[0].pop();
let err = structurally_invalid.design(&array![0.0]).unwrap_err();
assert!(err.to_string().contains("c0 row 0 has width"));
// Case 4: the unmodified production runtime evaluates successfully.
let cubic = production_runtime;
assert!(cubic.design(&array![0.0]).is_ok());
}
/// For every span, the local cubic must reproduce `design · beta` and
/// `first_derivative_design · beta` at the span's left edge, midpoint and right edge,
/// and `local_cubic_at` must select the span this test expects for each point.
#[test]
fn saved_anchored_deviation_runtime_local_cubic_reconstructs_values() {
    let seed = array![-2.0, -0.75, 0.0, 1.0, 3.0];
    let config = crate::families::bms::DeviationBlockConfig {
        num_internal_knots: 4,
        ..Default::default()
    };
    let prepared = crate::families::bms::build_score_warp_deviation_block_from_seed(&seed, &config)
        .expect("build saved anchored deviation runtime");
    let runtime = saved_runtime_from_deviation_runtime(&prepared.runtime);

    // Alternating-sign coefficients with growing magnitude.
    let beta = Array1::from_iter(
        (0..runtime.basis_dim).map(|k| 0.02 * (k as f64 + 1.0) * (-1.0_f64).powi(k as i32)),
    );

    let n_spans = runtime.span_count().expect("span count");
    assert!(n_spans >= 2);

    for span_idx in 0..n_spans {
        let cubic = runtime
            .local_cubic_on_span(&beta, span_idx)
            .expect("local cubic");
        let x_eval = array![cubic.left, 0.5 * (cubic.left + cubic.right), cubic.right];
        let expected = runtime.design(&x_eval).expect("design").dot(&beta);
        let expected_d1 = runtime
            .first_derivative_design(&x_eval)
            .expect("d1 design")
            .dot(&beta);

        for (i, &x) in x_eval.iter().enumerate() {
            assert!((cubic.evaluate(x) - expected[i]).abs() < 1e-10);
            assert!((cubic.first_derivative(x) - expected_d1[i]).abs() < 1e-10);

            let selected = runtime.local_cubic_at(&beta, x).expect("local cubic at x");
            // A shared interior knot (left edge of any span but the first) is expected
            // to resolve to the span on its left.
            let expected_span_idx = if i != 0 || span_idx == 0 {
                span_idx
            } else {
                span_idx - 1
            };
            let expected_cubic = runtime
                .local_cubic_on_span(&beta, expected_span_idx)
                .expect("expected local cubic on span");
            assert_eq!(selected.left, expected_cubic.left);
            assert_eq!(selected.right, expected_cubic.right);
        }
    }
}
/// With an anchor correction matrix `M` installed, `design_with_anchor_rows` must equal
/// `design_uncorrected − anchor_rows · M`, and `anchor_correction_matrix` must return
/// exactly the subtracted term.
#[test]
fn saved_anchored_deviation_runtime_design_with_anchor_rows_applies_residual() {
    use crate::families::bms::deviation_runtime::ParametricAnchorBlock;
    use crate::inference::model::{SavedAnchorComponent, SavedAnchorKind};

    let seed = array![-2.0, -0.75, 0.0, 1.0, 3.0];
    let config = crate::families::bms::DeviationBlockConfig {
        num_internal_knots: 4,
        ..Default::default()
    };
    let prepared = crate::families::bms::build_score_warp_deviation_block_from_seed(&seed, &config)
        .expect("build saved anchored deviation runtime");
    let mut runtime = saved_runtime_from_deviation_runtime(&prepared.runtime);

    // Synthetic d × basis_dim correction, stored row-by-row as the runtime expects.
    let d = 3usize;
    let m: Vec<Vec<f64>> = (0..d)
        .map(|i| {
            (0..runtime.basis_dim)
                .map(|j| 0.1 * (i as f64 + 1.0) - 0.05 * (j as f64 + 1.0))
                .collect()
        })
        .collect();
    // Dense copy of the same matrix for the reference product below.
    let m_dense = Array2::<f64>::from_shape_fn((d, runtime.basis_dim), |(i, j)| m[i][j]);

    runtime.anchor_correction = Some(m);
    runtime.anchor_components = vec![SavedAnchorComponent {
        kind: SavedAnchorKind::Parametric {
            block: ParametricAnchorBlock::Marginal,
            ncols: d,
        },
    }];

    let values = array![-1.0, 0.0, 0.5, 2.0];
    let n = values.len();
    let anchor_rows = Array2::from_shape_fn((n, d), |(i, j)| {
        0.3 * (i as f64 + 1.0) - 0.1 * (j as f64 + 1.0)
    });

    let raw = runtime
        .design_uncorrected(&values)
        .expect("uncorrected design");
    let corrected = runtime
        .design_with_anchor_rows(&values, anchor_rows.view())
        .expect("design with anchor rows");

    let expected = &raw - &anchor_rows.dot(&m_dense);
    for i in 0..n {
        for j in 0..runtime.basis_dim {
            let got = corrected[[i, j]];
            let exp = expected[[i, j]];
            assert!(
                (got - exp).abs() < 1e-12,
                "residual-corrected design mismatch at ({i}, {j}): got {got}, expected {exp}",
            );
        }
    }

    let correction = runtime
        .anchor_correction_matrix(anchor_rows.view())
        .expect("anchor correction matrix")
        .expect("Some correction when residual is present");
    for i in 0..n {
        for j in 0..runtime.basis_dim {
            let residual = raw[[i, j]] - correction[[i, j]] - corrected[[i, j]];
            assert!(residual.abs() < 1e-12);
        }
    }
}
/// Rigid (no score-warp, no link-deviation) predictor with a Gaussian frailty:
/// η and its gradient must match the closed form
/// `η = m·sqrt(1 + (s·b)²) + s·b·z`, with `s = probit_frailty_scale()`.
#[test]
fn bernoulli_marginal_slope_rigid_gaussian_frailty_uses_scaled_closed_form() {
    let predictor = BernoulliMarginalSlopePredictor {
        beta_marginal: array![0.7],
        beta_logslope: array![-0.4],
        beta_score_warp: None,
        beta_link_dev: None,
        base_link: InverseLink::Standard(crate::types::StandardLink::Probit),
        z_column: "z".to_string(),
        latent_z_normalization: SavedLatentZNormalization { mean: 0.0, sd: 1.0 },
        latent_measure: LatentMeasureKind::StandardNormal,
        baseline_marginal: 0.1,
        baseline_logslope: -0.2,
        covariance: None,
        score_warp_runtime: None,
        link_deviation_runtime: None,
        gaussian_frailty_sd: Some(0.8),
        latent_z_calibration: None,
        latent_z_conditional_calibration: None,
    };
    let theta = predictor.theta();
    let input = PredictInput {
        design: DesignMatrix::from(array![[1.0], [1.0]]),
        offset: array![0.0, 0.05],
        design_noise: Some(DesignMatrix::from(array![[1.0], [1.0]])),
        offset_noise: Some(array![0.0, -0.1]),
        auxiliary_scalar: Some(array![-0.3, 1.2]),
        auxiliary_matrix: None,
    };
    let (eta, grad) = predictor
        .final_eta_and_gradient_from_theta(&input, &theta, true)
        .expect("rigid frailty path should evaluate");
    let jacobian = grad.as_ref().expect("gradient should be returned");

    let scale = predictor.probit_frailty_scale();
    // Hand-computed linear predictors: 0.7 + 0.1 + {0, 0.05} and -0.4 - 0.2 + {0, -0.1}.
    let marginal_eta = array![0.8, 0.85];
    let logslope_eta = array![-0.6, -0.7];
    // Same values as `auxiliary_scalar` above.
    let z = array![-0.3, 1.2];

    for i in 0..eta.len() {
        let sb = scale * logslope_eta[i];
        let c = (1.0 + sb * sb).sqrt();

        let expected_eta = marginal_eta[i] * c + sb * z[i];
        assert!((eta[i] - expected_eta).abs() <= 1e-12);

        // dη/d(marginal) = c and dη/d(logslope) = m·s²·b/c + s·z.
        let expected_d_marginal = c;
        let expected_d_logslope =
            marginal_eta[i] * scale * scale * logslope_eta[i] / c + scale * z[i];
        assert!((jacobian[[i, 0]] - expected_d_marginal).abs() <= 1e-12);
        assert!((jacobian[[i, 1]] - expected_d_logslope).abs() <= 1e-12);
    }
}
/// With a `LocalEmpirical` latent measure and `top_k = 1`, each prediction row sits on
/// one of the two centers, and its η must equal the empirical intercept computed from
/// the grid with the same index. The q-chain path must agree with the direct path.
#[test]
fn bernoulli_marginal_slope_predictor_uses_local_empirical_latent_law() {
    let grids = vec![
        EmpiricalZGrid {
            nodes: vec![-1.2, -0.2, 0.7],
            weights: vec![0.45, 0.35, 0.20],
        },
        EmpiricalZGrid {
            nodes: vec![-0.4, 0.6, 2.4],
            weights: vec![0.20, 0.35, 0.45],
        },
    ];
    let latent_measure = LatentMeasureKind::LocalEmpirical {
        feature_cols: vec![0],
        input_scales: None,
        centers: vec![vec![-1.0], vec![1.0]],
        grids: grids.clone(),
        top_k: 1,
        bandwidth: 0.25,
        train_row_mixtures: std::sync::Arc::new(Vec::new()),
    };
    let predictor = BernoulliMarginalSlopePredictor {
        beta_marginal: array![0.2],
        beta_logslope: array![0.9],
        beta_score_warp: None,
        beta_link_dev: None,
        base_link: InverseLink::Standard(crate::types::StandardLink::Probit),
        z_column: "z".to_string(),
        latent_z_normalization: SavedLatentZNormalization { mean: 0.0, sd: 1.0 },
        latent_measure,
        baseline_marginal: 0.0,
        baseline_logslope: 0.0,
        covariance: None,
        score_warp_runtime: None,
        link_deviation_runtime: None,
        gaussian_frailty_sd: None,
        latent_z_calibration: None,
        latent_z_conditional_calibration: None,
    };
    // Row 0 has feature -1.0 and row 1 has feature +1.0, matching the two centers.
    let input = PredictInput {
        design: DesignMatrix::from(array![[1.0], [1.0]]),
        offset: array![0.0, 0.0],
        design_noise: Some(DesignMatrix::from(array![[1.0], [1.0]])),
        offset_noise: Some(array![0.0, 0.0]),
        auxiliary_scalar: Some(array![0.0, 0.0]),
        auxiliary_matrix: Some(array![[-1.0], [1.0]]),
    };

    let theta = predictor.theta();
    let (eta, _) = predictor
        .final_eta_and_gradient_from_theta(&input, &theta, true)
        .expect("local empirical prediction");
    let (chain_eta, deta_dq) = predictor
        .predict_eta_and_q_chain(&input)
        .expect("local empirical q chain");

    for (row, grid) in grids.iter().enumerate() {
        let expected_intercept = empirical_intercept_from_marginal(
            normal_cdf(0.2),
            0.2,
            0.9,
            1.0,
            &grid.nodes,
            &grid.weights,
            None,
        )
        .expect("expected empirical intercept");
        assert!((eta[row] - expected_intercept).abs() <= 1e-10);
        assert!((chain_eta[row] - eta[row]).abs() <= 1e-12);
        assert!(deta_dq[row].is_finite() && deta_dq[row] > 0.0);
    }
}
/// A logit base link combined with a Gaussian frailty must be refused, with an error
/// that mentions the probit requirement.
#[test]
fn bernoulli_marginal_slope_predictor_rejects_nonprobit_base_link_scale() {
    let logit_predictor = BernoulliMarginalSlopePredictor {
        beta_marginal: array![0.7],
        beta_logslope: array![-0.4],
        beta_score_warp: None,
        beta_link_dev: None,
        base_link: InverseLink::Standard(crate::types::StandardLink::Logit),
        z_column: "z".to_string(),
        latent_z_normalization: SavedLatentZNormalization { mean: 0.0, sd: 1.0 },
        latent_measure: LatentMeasureKind::StandardNormal,
        baseline_marginal: 0.1,
        baseline_logslope: -0.2,
        covariance: None,
        score_warp_runtime: None,
        link_deviation_runtime: None,
        gaussian_frailty_sd: Some(0.8),
        latent_z_calibration: None,
        latent_z_conditional_calibration: None,
    };
    let input = PredictInput {
        design: DesignMatrix::from(array![[1.0], [1.0]]),
        offset: array![0.0, 0.05],
        design_noise: Some(DesignMatrix::from(array![[1.0], [1.0]])),
        offset_noise: Some(array![0.0, -0.1]),
        auxiliary_scalar: Some(array![-0.3, 1.2]),
        auxiliary_matrix: None,
    };
    let theta = logit_predictor.theta();
    let rejection = logit_predictor
        .final_eta_and_gradient_from_theta(&input, &theta, true)
        .expect_err("non-probit marginal-slope prediction should be rejected");
    let message = rejection.to_string();
    assert!(message.contains("requires link(type=probit)"));
}
/// When the predictor carries a covariance, `point_state` must emit an η-scale SE equal
/// to `sqrt(gᵀ Σ g)` (with `g` the analytic gradient row) and a mean-scale SE equal to
/// that value times `φ(η)`. The implied 95% band on the probit scale must bracket the
/// mean, stay inside [0, 1] and have positive width.
#[test]
fn bernoulli_marginal_slope_point_state_emits_covariance_based_interval() {
    let predictor = BernoulliMarginalSlopePredictor {
        beta_marginal: array![0.7],
        beta_logslope: array![-0.4],
        beta_score_warp: None,
        beta_link_dev: None,
        base_link: InverseLink::Standard(crate::types::StandardLink::Probit),
        z_column: "z".to_string(),
        latent_z_normalization: SavedLatentZNormalization { mean: 0.0, sd: 1.0 },
        latent_measure: LatentMeasureKind::StandardNormal,
        baseline_marginal: 0.1,
        baseline_logslope: -0.2,
        covariance: Some(array![[0.040, 0.010], [0.010, 0.090]]),
        score_warp_runtime: None,
        link_deviation_runtime: None,
        gaussian_frailty_sd: None,
        latent_z_calibration: None,
        latent_z_conditional_calibration: None,
    };
    let theta = predictor.theta();
    assert_eq!(
        theta.len(),
        2,
        "rigid marginal-slope θ is [marginal | logslope]"
    );
    let input = PredictInput {
        design: DesignMatrix::from(array![[1.0], [1.0], [1.0]]),
        offset: array![0.0, 0.05, -0.10],
        design_noise: Some(DesignMatrix::from(array![[1.0], [1.0], [1.0]])),
        offset_noise: Some(array![0.0, -0.1, 0.2]),
        auxiliary_scalar: Some(array![-0.3, 1.2, 0.4]),
        auxiliary_matrix: None,
    };

    let state = predictor
        .point_state(&input)
        .expect("marginal-slope point_state should evaluate with a covariance");
    let eta = &state.eta;
    let eta_se = state
        .eta_se
        .as_ref()
        .expect("issue #1049: covariance-backed point_state must emit an η-scale SE");
    let mean_se = state
        .mean_se
        .as_ref()
        .expect("issue #1049: covariance-backed point_state must emit a mean SE");

    // Independent oracle: delta method on the analytic gradient.
    let cov = predictor.covariance.as_ref().unwrap();
    let (_, grad) = predictor
        .final_eta_and_gradient_from_theta(&input, &theta, true)
        .expect("analytic gradient");
    let grad = grad.expect("gradient rows");
    // Two-sided 95% normal quantile; it does not depend on the row.
    let z = crate::probability::standard_normal_quantile(0.975).unwrap();

    for i in 0..eta.len() {
        let g = grad.row(i).to_owned();
        let cg = cov.dot(&g);
        let var = g.dot(&cg);
        let se_oracle = var.max(0.0).sqrt();
        assert!(se_oracle > 0.0, "row {i} SE collapsed to zero");
        assert!(
            (eta_se[i] - se_oracle).abs() <= 1e-10,
            "row {i}: η-SE {} != oracle gᵀΣg^{{1/2}} {}",
            eta_se[i],
            se_oracle
        );

        let mean_se_oracle = se_oracle * normal_pdf(eta[i]);
        assert!(
            (mean_se[i] - mean_se_oracle).abs() <= 1e-10,
            "row {i}: mean-SE {} != eta_se·φ(η) {}",
            mean_se[i],
            mean_se_oracle
        );

        let lo = normal_cdf(eta[i] - z * se_oracle).clamp(0.0, 1.0);
        let hi = normal_cdf(eta[i] + z * se_oracle).clamp(0.0, 1.0);
        let mean = normal_cdf(eta[i]);
        assert!(
            lo <= mean + 1e-12 && hi >= mean - 1e-12,
            "row {i}: band brackets mean"
        );
        assert!((0.0..=1.0).contains(&lo) && (0.0..=1.0).contains(&hi));
        assert!(
            hi - lo > 0.0,
            "row {i}: TransformEta band has positive width"
        );
    }
}
/// The per-basis cubic for span 0, basis column 1, must reproduce that column of the
/// design and first-derivative design at the span's edges and midpoint, and
/// `basis_cubic_at` must select span 0 for all three points.
#[test]
fn saved_anchored_deviation_runtime_basis_cubic_matches_basis_column() {
    let seed = array![-2.0, -0.75, 0.0, 1.0, 3.0];
    let config = crate::families::bms::DeviationBlockConfig {
        num_internal_knots: 4,
        ..Default::default()
    };
    let prepared = crate::families::bms::build_score_warp_deviation_block_from_seed(&seed, &config)
        .expect("build saved anchored deviation runtime");
    let runtime = saved_runtime_from_deviation_runtime(&prepared.runtime);

    let cubic = runtime.basis_span_cubic(0, 1).expect("basis span cubic");
    let x_eval = array![cubic.left, 0.5 * (cubic.left + cubic.right), cubic.right];
    let design = runtime.design(&x_eval).expect("basis design");
    let d1 = runtime
        .first_derivative_design(&x_eval)
        .expect("basis d1 design");

    // Every evaluation point, including the right edge, is expected to map to span 0.
    let expected_span_idx = 0;
    let expected_cubic = runtime
        .basis_span_cubic(expected_span_idx, 1)
        .expect("expected basis span cubic");

    for (i, &x) in x_eval.iter().enumerate() {
        assert!((cubic.evaluate(x) - design[[i, 1]]).abs() < 1e-10);
        assert!((cubic.first_derivative(x) - d1[[i, 1]]).abs() < 1e-10);

        let selected = runtime.basis_cubic_at(1, x).expect("basis cubic at x");
        assert_eq!(selected.left, expected_cubic.left);
        assert_eq!(selected.right, expected_cubic.right);
    }
}
/// Royston-Parmar point prediction: η is `x·β + offset`, and the reported mean is
/// `exp(-exp(η))` clamped to [0, 1].
#[test]
fn predict_royston_parmar_point_prediction_returns_survival_probability() {
    let design = array![[1.0], [1.0]];
    let coefficients = array![0.4];
    let offset = array![0.0, 0.8];
    let prediction = predict_gam(
        design,
        coefficients.view(),
        offset.view(),
        crate::types::LikelihoodSpec::royston_parmar(),
    )
    .expect("royston-parmar point prediction");

    // 0.4 + {0.0, 0.8}
    let expected_eta = array![0.4, 1.2];
    let expected_mean = expected_eta.mapv(|eta: f64| (-(eta.exp())).exp().clamp(0.0, 1.0));

    for i in 0..prediction.eta.len() {
        let eta_gap = (prediction.eta[i] - expected_eta[i]).abs();
        assert!(eta_gap <= 1e-14, "eta[{i}] mismatch");
    }
    for i in 0..prediction.mean.len() {
        let mean_gap = (prediction.mean[i] - expected_mean[i]).abs();
        assert!(mean_gap <= 1e-12);
    }
}
/// The Royston-Parmar posterior mean must match `survival_posterior_mean`, and the
/// covariance-only and fit-backed entry points must agree on both the mean and the
/// η standard error.
#[test]
fn predict_royston_parmar_posterior_mean_matches_quadrature_and_fit_path() {
    let design = array![[1.0], [1.0]];
    let coefficients = array![0.35];
    let zero_offset = array![0.0, 0.0];
    // Coefficient variance 0.09; the reference below is evaluated at (0.35, 0.3).
    let beta_cov = Array2::from_diag(&array![0.09]);
    let fit = test_fit_with_covariance(coefficients.clone(), beta_cov.clone());

    let direct = predict_gam_posterior_mean(
        design.clone(),
        coefficients.view(),
        zero_offset.view(),
        crate::types::LikelihoodSpec::royston_parmar(),
        beta_cov.view(),
    )
    .expect("royston-parmar posterior mean");
    let via_fit = predict_gam_posterior_meanwith_fit(
        design,
        coefficients.view(),
        zero_offset.view(),
        crate::types::LikelihoodSpec::royston_parmar(),
        beta_cov.view(),
        &fit,
    )
    .expect("royston-parmar posterior mean with fit");

    let quadrature = crate::quadrature::QuadratureContext::new();
    let reference = crate::quadrature::survival_posterior_mean(&quadrature, 0.35, 0.3);

    for row in 0..direct.mean.len() {
        assert!((direct.mean[row] - reference).abs() <= 1e-12);
        assert!((via_fit.mean[row] - reference).abs() <= 1e-12);
        assert!((via_fit.mean[row] - direct.mean[row]).abs() <= 1e-12);
        let se_gap = (via_fit.eta_standard_error[row] - direct.eta_standard_error[row]).abs();
        assert!(se_gap <= 1e-12);
    }
}
/// Royston-Parmar uncertainty output: plug-in mean `exp(-exp(η))`, η SE of 0.5, a mean SE
/// equal to the square root of the quadrature posterior variance, and an ordered
/// interval contained in [0, 1].
#[test]
fn predict_royston_parmar_uncertainty_clamps_and_orders_intervals() {
    let design = array![[1.0]];
    let coefficients = array![0.6];
    let zero_offset = array![0.0];
    let beta_cov = Array2::from_diag(&array![0.25]);
    let fit = test_fit_with_covariance(coefficients.clone(), beta_cov);

    // Every optional adjustment is switched off explicitly.
    let options = PredictUncertaintyOptions {
        confidence_level: 0.95,
        covariance_mode: InferenceCovarianceMode::Conditional,
        mean_interval_method: MeanIntervalMethod::TransformEta,
        includeobservation_interval: false,
        apply_bias_correction: false,
        edgeworth_one_sided: false,
        boundary_correction: false,
        ood_inflation: false,
        multi_point_joint: false,
        ..PredictUncertaintyOptions::default()
    };
    let prediction = predict_gamwith_uncertainty(
        design,
        coefficients.view(),
        zero_offset.view(),
        crate::types::LikelihoodSpec::royston_parmar(),
        &fit,
        &options,
    )
    .expect("royston-parmar uncertainty");

    let quadrature = crate::quadrature::QuadratureContext::new();
    let (_, variance) = crate::quadrature::survival_posterior_meanvariance(&quadrature, 0.6, 0.5);

    assert!((prediction.mean[0] - (-(0.6_f64.exp())).exp()).abs() <= 1e-12);
    assert!((prediction.eta_standard_error[0] - 0.5).abs() <= 1e-12);
    assert!((prediction.mean_standard_error[0] - variance.sqrt()).abs() <= 1e-12);

    let (lower, upper) = (prediction.mean_lower[0], prediction.mean_upper[0]);
    assert!(lower <= upper);
    assert!((0.0..=1.0).contains(&lower));
    assert!((0.0..=1.0).contains(&upper));
}
/// `extrapolation_variance` is added per row to the η variance: 0.16 + 0.09 = 0.25, so
/// the affected row's SE rises from 0.4 to 0.5 and its interval widens. A vector whose
/// length differs from the row count must be rejected.
#[test]
fn extrapolation_variance_adds_to_eta_variance_after_inflations() {
    let design = array![[1.0], [1.0]];
    let coefficients = array![0.5];
    let zero_offset = array![0.0, 0.0];
    let beta_cov = Array2::from_diag(&array![0.16]);
    let fit = test_fit_with_covariance(coefficients.clone(), beta_cov);

    let base_options = PredictUncertaintyOptions {
        confidence_level: 0.95,
        covariance_mode: InferenceCovarianceMode::Conditional,
        mean_interval_method: MeanIntervalMethod::TransformEta,
        includeobservation_interval: false,
        apply_bias_correction: false,
        edgeworth_one_sided: false,
        boundary_correction: false,
        ood_inflation: false,
        multi_point_joint: false,
        ..PredictUncertaintyOptions::default()
    };
    // Extra variance only on the second row.
    let options_fused = PredictUncertaintyOptions {
        extrapolation_variance: Some(array![0.0, 0.09]),
        ..base_options.clone()
    };

    let baseline = predict_gamwith_uncertainty(
        design.clone(),
        coefficients.view(),
        zero_offset.view(),
        crate::types::LikelihoodSpec::gaussian_identity(),
        &fit,
        &base_options,
    )
    .expect("baseline gaussian uncertainty");
    let fused = predict_gamwith_uncertainty(
        design.clone(),
        coefficients.view(),
        zero_offset.view(),
        crate::types::LikelihoodSpec::gaussian_identity(),
        &fit,
        &options_fused,
    )
    .expect("fused gaussian uncertainty");

    assert!((baseline.eta_standard_error[0] - 0.4).abs() <= 1e-12);
    assert!((baseline.eta_standard_error[1] - 0.4).abs() <= 1e-12);
    assert!((fused.eta_standard_error[0] - 0.4).abs() <= 1e-12);
    assert!((fused.eta_standard_error[1] - 0.5).abs() <= 1e-12);
    assert!((fused.mean_standard_error[1] - 0.5).abs() <= 1e-12);

    let fused_width = fused.mean_upper[1] - fused.mean_lower[1];
    let baseline_width = baseline.mean_upper[1] - baseline.mean_lower[1];
    assert!(fused_width > baseline_width);

    // One entry for two rows: must be refused.
    let options_mismatched = PredictUncertaintyOptions {
        extrapolation_variance: Some(array![0.09]),
        ..base_options
    };
    let err = predict_gamwith_uncertainty(
        design,
        coefficients.view(),
        zero_offset.view(),
        crate::types::LikelihoodSpec::gaussian_identity(),
        &fit,
        &options_mismatched,
    )
    .expect_err("length mismatch must be rejected");
    assert!(
        err.to_string().contains("extrapolation_variance length"),
        "unexpected error: {err}"
    );
}
/// With zero noise coefficients, the predicted sigma must be driven by the noise offset
/// alone, and a predictor without covariance must report no standard errors.
#[test]
fn gaussian_location_scale_sigma_includes_noise_offset() {
    let predictor = GaussianLocationScalePredictor {
        beta_mu: array![0.0],
        beta_noise: array![0.0],
        sigma_floor: crate::families::sigma_link::LOGB_SIGMA_FLOOR,
        covariance: None,
        link_wiggle: None,
    };
    let input = PredictInput {
        design: DesignMatrix::from(array![[1.0], [1.0]]),
        offset: array![0.0, 0.0],
        design_noise: Some(DesignMatrix::from(array![[1.0], [1.0]])),
        offset_noise: Some(array![(3.0f64).ln(), (5.0f64).ln()]),
        auxiliary_scalar: None,
        auxiliary_matrix: None,
    };

    let sigma = predictor
        .predict_noise_scale(&input)
        .expect("gaussian location-scale sigma")
        .expect("sigma should be returned");
    // exp(offset) is 3 and 5; the expected values sit 0.01 above that, presumably the
    // sigma floor -- confirm against LOGB_SIGMA_FLOOR.
    let expected_sigma = [3.01, 5.01];
    for (row, want) in expected_sigma.iter().enumerate() {
        assert!((sigma[row] - want).abs() <= 1e-12);
    }

    let uncertainty = predictor
        .predict_with_uncertainty(&input)
        .expect("gaussian location-scale uncertainty");
    assert!(uncertainty.eta_se.is_none());
    assert!(uncertainty.mean_se.is_none());
}
/// Joint covariance diag(4, 9) over [mu | noise]: the η standard error of the mean block
/// must be sqrt(4) = 2, unaffected by the noise block's variance.
#[test]
fn gaussian_location_scale_eta_se_pads_scale_block_without_wiggle() {
    let predictor = GaussianLocationScalePredictor {
        beta_mu: array![0.5],
        beta_noise: array![0.1],
        sigma_floor: crate::families::sigma_link::LOGB_SIGMA_FLOOR,
        covariance: Some(array![[4.0, 0.0], [0.0, 9.0]]),
        link_wiggle: None,
    };
    // Fit object carrying the same coefficients and covariance as the predictor.
    let fit = gaussian_location_scale_fit_with_covariance(
        array![0.5],
        array![0.1],
        array![[4.0, 0.0], [0.0, 9.0]],
    );
    let input = PredictInput {
        design: DesignMatrix::from(array![[1.0]]),
        offset: array![0.0],
        design_noise: Some(DesignMatrix::from(array![[1.0]])),
        offset_noise: None,
        auxiliary_scalar: None,
        auxiliary_matrix: None,
    };
    let posterior = predictor
        .predict_posterior_mean(&input, &fit, &PosteriorMeanOptions::point_only())
        .expect("gaussian location-scale posterior mean");
    let eta_se = posterior.eta_standard_error[0];
    assert!((eta_se - 2.0).abs() <= 1e-12);
}
/// In `ConditionalPlusSmoothingRequired` mode the corrected covariance (variance 9 gives
/// SE 3) must be used when the fit provides it, and prediction must fail with a
/// descriptive error when the fit does not.
#[test]
fn gaussian_location_scale_required_corrected_covariance_uses_corrected_backend() {
    let predictor = GaussianLocationScalePredictor {
        beta_mu: array![0.0],
        beta_noise: array![0.0],
        sigma_floor: crate::families::sigma_link::LOGB_SIGMA_FLOOR,
        covariance: Some(array![[1.0, 0.0], [0.0, 0.0]]),
        link_wiggle: None,
    };
    let input = PredictInput {
        design: DesignMatrix::from(array![[1.0]]),
        offset: array![0.0],
        design_noise: Some(DesignMatrix::from(array![[1.0]])),
        offset_noise: None,
        auxiliary_scalar: None,
        auxiliary_matrix: None,
    };
    let options = PredictUncertaintyOptions {
        covariance_mode: InferenceCovarianceMode::ConditionalPlusSmoothingRequired,
        includeobservation_interval: false,
        apply_bias_correction: false,
        edgeworth_one_sided: false,
        boundary_correction: false,
        ood_inflation: false,
        multi_point_joint: false,
        ..PredictUncertaintyOptions::default()
    };

    // Fit with both the conditional (variance 1) and corrected (variance 9) covariance.
    let corrected_fit = gaussian_location_scale_fit_with_covariance_and_corrected(
        array![0.0],
        array![0.0],
        array![[1.0, 0.0], [0.0, 0.0]],
        Some(array![[9.0, 0.0], [0.0, 0.0]]),
    );
    let with_corrected = predictor
        .predict_full_uncertainty(&input, &corrected_fit, &options)
        .expect("required corrected covariance should be available");
    assert!((with_corrected.eta_standard_error[0] - 3.0).abs() <= 1e-12);
    assert!(with_corrected.covariance_corrected_used);

    // Fit with only the conditional covariance: the required mode has nothing to use.
    let missing_fit = gaussian_location_scale_fit_with_covariance(
        array![0.0],
        array![0.0],
        array![[1.0, 0.0], [0.0, 0.0]],
    );
    let err = if let Err(failure) =
        predictor.predict_full_uncertainty(&input, &missing_fit, &options)
    {
        failure.to_string()
    } else {
        panic!("required corrected covariance must error when unavailable")
    };
    assert!(
        err.contains("smoothing-corrected covariance"),
        "unexpected required-covariance error: {err}"
    );
}
/// Joint covariance diag(9, 16) over [threshold | log-sigma]: the reported η standard
/// error must be sqrt(9) = 3.
#[test]
fn survival_eta_se_pads_log_sigma_block() {
    let predictor = SurvivalPredictor {
        beta_threshold: array![0.5],
        beta_log_sigma: array![0.0],
        inverse_link: InverseLink::Standard(StandardLink::Probit),
        covariance: Some(array![[9.0, 0.0], [0.0, 16.0]]),
    };
    let input = PredictInput {
        design: DesignMatrix::from(array![[1.0]]),
        offset: array![0.0],
        design_noise: Some(DesignMatrix::from(array![[1.0]])),
        offset_noise: Some(array![0.0]),
        auxiliary_scalar: None,
        auxiliary_matrix: None,
    };
    let uncertainty = predictor
        .predict_with_uncertainty(&input)
        .expect("survival uncertainty");
    let eta_se = uncertainty.eta_se.expect("eta_se should be present");
    let gap = (eta_se[0] - 3.0).abs();
    assert!(gap <= 1e-12);
}
/// Cloglog survival predictor with threshold -1 and unit sigma: the mean must equal
/// `exp(-exp(q0))` at `q0 = 1`, and the mean SE must equal
/// `2 · exp(q0 - exp(q0))`, where 2 is the square root of the threshold variance.
#[test]
fn survival_predictor_cloglog_point_and_se_use_upper_tail_at_q0() {
    let predictor = SurvivalPredictor {
        beta_threshold: array![-1.0],
        beta_log_sigma: array![0.0],
        inverse_link: InverseLink::Standard(StandardLink::CLogLog),
        covariance: Some(array![[4.0, 0.0], [0.0, 0.0]]),
    };
    let input = PredictInput {
        design: DesignMatrix::from(array![[1.0]]),
        offset: array![0.0],
        design_noise: Some(DesignMatrix::from(array![[1.0]])),
        offset_noise: Some(array![0.0]),
        auxiliary_scalar: None,
        auxiliary_matrix: None,
    };
    let prediction = predictor
        .predict_with_uncertainty(&input)
        .expect("cloglog survival prediction");

    let q0 = 1.0_f64;
    let expected_survival = (-(q0.exp())).exp();
    let expected_mean_se = 2.0 * (q0 - q0.exp()).exp();

    assert!((prediction.mean[0] - expected_survival).abs() <= 1e-12);
    let mean_se = prediction.mean_se.expect("mean_se should be present");
    assert!((mean_se[0] - expected_mean_se).abs() <= 1e-12);
}
/// With an all-zero covariance, the posterior mean must coincide with the plug-in
/// response.
#[test]
fn survival_predictor_cloglog_posterior_mean_zero_covariance_matches_point_prediction() {
    let predictor = SurvivalPredictor {
        beta_threshold: array![-1.0],
        beta_log_sigma: array![0.0],
        inverse_link: InverseLink::Standard(StandardLink::CLogLog),
        covariance: Some(Array2::zeros((2, 2))),
    };
    let fit = survival_fit_with_covariance(array![-1.0], array![0.0], Array2::zeros((2, 2)));
    let input = PredictInput {
        design: DesignMatrix::from(array![[1.0]]),
        offset: array![0.0],
        design_noise: Some(DesignMatrix::from(array![[1.0]])),
        offset_noise: Some(array![0.0]),
        auxiliary_scalar: None,
        auxiliary_matrix: None,
    };

    let plug_in = predictor
        .predict_plugin_response(&input)
        .expect("cloglog survival point prediction");
    let integrated = predictor
        .predict_posterior_mean(&input, &fit, &PosteriorMeanOptions::point_only())
        .expect("cloglog survival posterior mean");

    let gap = (integrated.mean[0] - plug_in.mean[0]).abs();
    assert!(gap <= 1e-12);
}
/// A zero threshold combined with a noise offset of -1000 must still give a finite
/// plug-in mean, equal to `exp(-1)`.
#[test]
fn survival_predictor_zero_threshold_with_tiny_sigma_stays_finite() {
    let predictor = SurvivalPredictor {
        beta_threshold: array![0.0],
        beta_log_sigma: array![0.0],
        inverse_link: InverseLink::Standard(StandardLink::CLogLog),
        covariance: None,
    };
    let input = PredictInput {
        design: DesignMatrix::from(array![[1.0]]),
        offset: array![0.0],
        design_noise: Some(DesignMatrix::from(array![[1.0]])),
        // Extreme negative noise offset, the "tiny sigma" of the test name.
        offset_noise: Some(array![-1000.0]),
        auxiliary_scalar: None,
        auxiliary_matrix: None,
    };
    let plug_in = predictor
        .predict_plugin_response(&input)
        .expect("cloglog survival point prediction");

    let expected = (-1.0_f64).exp();
    let got = plug_in.mean[0];
    assert!(got.is_finite());
    assert!((got - expected).abs() <= 1e-12);
}
/// Builds a minimal single-block Gaussian-identity `UnifiedFitResult` for tests.
///
/// * `beta` - coefficients of the single `BlockRole::Mean` block.
/// * `covariance` - stored both as the inference `beta_covariance` and as the fit's
///   `covariance_conditional`.
/// * `bias_correction_beta` - forwarded unchanged into `FitInference`.
///
/// All other fields are fixed placeholders (identity penalized Hessian, zero scores,
/// converged status, no corrected covariance).
fn test_fit_with_bias_correction(
    beta: Array1<f64>,
    covariance: Array2<f64>,
    bias_correction_beta: Option<Array1<f64>>,
) -> UnifiedFitResult {
    use crate::estimate::FitInference;

    let p = beta.len();
    let total_edf = p as f64;

    let inference = FitInference {
        edf_by_block: vec![],
        penalty_block_trace: vec![],
        edf_total: total_edf,
        smoothing_correction: None,
        penalized_hessian: Array2::<f64>::eye(p).into(),
        working_weights: Array1::zeros(0),
        working_response: Array1::zeros(0),
        reparam_qs: None,
        dispersion: crate::estimate::Dispersion::Known(1.0),
        beta_covariance: Some(covariance.clone().into()),
        beta_standard_errors: None,
        beta_covariance_corrected: None,
        beta_standard_errors_corrected: None,
        beta_covariance_frequentist: None,
        coefficient_influence: None,
        weighted_gram: None,
        bias_correction_beta,
    };

    let mean_block = FittedBlock {
        beta: beta.clone(),
        role: BlockRole::Mean,
        edf: total_edf,
        lambdas: Array1::zeros(0),
    };

    let parts = UnifiedFitResultParts {
        blocks: vec![mean_block],
        log_lambdas: Array1::zeros(0),
        lambdas: Array1::zeros(0),
        likelihood_family: Some(crate::types::LikelihoodSpec::gaussian_identity()),
        likelihood_scale: crate::types::LikelihoodScaleMetadata::ProfiledGaussian,
        log_likelihood_normalization: crate::types::LogLikelihoodNormalization::Full,
        log_likelihood: 0.0,
        deviance: 0.0,
        reml_score: 0.0,
        stable_penalty_term: 0.0,
        penalized_objective: 0.0,
        used_device: false,
        outer_iterations: 0,
        outer_converged: true,
        outer_gradient_norm: None,
        standard_deviation: 1.0,
        covariance_conditional: Some(covariance),
        covariance_corrected: None,
        inference: Some(inference),
        fitted_link: FittedLinkState::Standard(Some(StandardLink::Identity)),
        geometry: None,
        block_states: Vec::new(),
        pirls_status: PirlsStatus::Converged,
        max_abs_eta: 0.0,
        constraint_kkt: None,
        artifacts: FitArtifacts {
            pirls: None,
            ..Default::default()
        },
        inner_cycles: 0,
    };
    UnifiedFitResult::new_for_test_unchecked(parts)
}
/// Uncertainty options for the bias-correction tests: 95% conditional-covariance
/// intervals via `TransformEta`, with every optional adjustment disabled except bias
/// correction, which follows `apply`.
fn bc_options(apply: bool) -> PredictUncertaintyOptions {
    PredictUncertaintyOptions {
        // The one switch under test.
        apply_bias_correction: apply,
        // Interval construction.
        confidence_level: 0.95,
        covariance_mode: InferenceCovarianceMode::Conditional,
        mean_interval_method: MeanIntervalMethod::TransformEta,
        // Everything else off so it cannot interfere.
        includeobservation_interval: false,
        edgeworth_one_sided: false,
        boundary_correction: false,
        ood_inflation: false,
        multi_point_joint: false,
        ..PredictUncertaintyOptions::default()
    }
}
/// With the flag off, η is the raw `x·β`. With it on, η shifts by `x·b` where `b` is the
/// stored bias correction. The η standard error is identical either way.
#[test]
fn test_bias_correction_idempotent_with_flag() {
    let design = array![[1.0, 0.5]];
    let beta = array![1.0, 2.0];
    let bias = array![0.1, -0.05];
    let fit = test_fit_with_bias_correction(beta.clone(), Array2::<f64>::eye(2), Some(bias.clone()));
    let offset = array![0.0];

    let predict = |apply: bool, label: &str| {
        predict_gamwith_uncertainty(
            design.clone(),
            beta.view(),
            offset.view(),
            crate::types::LikelihoodSpec::gaussian_identity(),
            &fit,
            &bc_options(apply),
        )
        .expect(label)
    };
    let pred_off = predict(false, "predict no-bc");
    let pred_on = predict(true, "predict bc");

    // 1.0·1.0 + 0.5·2.0 = 2.0
    assert!((pred_off.eta[0] - 2.0).abs() < 1e-12);
    let expected_delta = 1.0 * 0.1 + 0.5 * (-0.05);
    assert!((pred_on.eta[0] - (2.0 + expected_delta)).abs() < 1e-12);

    let se_gap = (pred_off.eta_standard_error[0] - pred_on.eta_standard_error[0]).abs();
    assert!(
        se_gap < 1e-14,
        "bias correction must not affect eta standard error"
    );
}
/// Requesting bias correction on a fit that stores none must leave η at the raw `x·β`.
#[test]
fn test_bias_correction_zero_when_unset() {
    let design = array![[1.0, 0.5]];
    let beta = array![1.0, 2.0];
    let identity_cov = Array2::<f64>::eye(2);
    let fit = test_fit_with_bias_correction(beta.clone(), identity_cov, None);
    let offset = array![0.0];

    let options = bc_options(true);
    let prediction = predict_gamwith_uncertainty(
        design,
        beta.view(),
        offset.view(),
        crate::types::LikelihoodSpec::gaussian_identity(),
        &fit,
        &options,
    )
    .expect("predict");

    // 1.0·1.0 + 0.5·2.0 = 2.0, with nothing added.
    let gap = (prediction.eta[0] - 2.0).abs();
    assert!(gap < 1e-12);
}
/// Two fits that differ only in whether a bias correction is stored must report the
/// same η standard error for every row.
#[test]
fn test_bias_correction_does_not_affect_posterior_se() {
    let design = array![[1.0, 0.5], [0.7, -0.3]];
    let beta = array![0.4, 0.9];
    let bias = array![0.2, -0.1];
    let cov = array![[1.0, 0.1], [0.1, 0.5]];
    let fit_with = test_fit_with_bias_correction(beta.clone(), cov.clone(), Some(bias));
    let fit_without = test_fit_with_bias_correction(beta.clone(), cov, None);
    let offset = array![0.0, 0.0];

    let pred_with = predict_gamwith_uncertainty(
        design.clone(),
        beta.view(),
        offset.view(),
        crate::types::LikelihoodSpec::gaussian_identity(),
        &fit_with,
        &bc_options(true),
    )
    .expect("predict with bc");
    let pred_without = predict_gamwith_uncertainty(
        design,
        beta.view(),
        offset.view(),
        crate::types::LikelihoodSpec::gaussian_identity(),
        &fit_without,
        &bc_options(true),
    )
    .expect("predict without bc");

    for i in 0..2 {
        let se_gap =
            (pred_with.eta_standard_error[i] - pred_without.eta_standard_error[i]).abs();
        assert!(se_gap < 1e-14, "BC must not perturb eta SE at index {i}");
    }
}
/// The bias correction stored at construction must come back unchanged through
/// `bias_correction_beta()`.
#[test]
fn test_bias_correction_accessor_propagates() {
    let stored = array![0.3, -0.2];
    let fit = test_fit_with_bias_correction(
        array![1.0, 2.0],
        Array2::<f64>::eye(2),
        Some(stored.clone()),
    );

    let recovered = fit
        .bias_correction_beta()
        .expect("bias correction should be present");

    assert_eq!(recovered.len(), stored.len());
    for idx in 0..stored.len() {
        let gap = (recovered[idx] - stored[idx]).abs();
        assert!(gap < 1e-15);
    }
}
/// Solves the 3×3 linear system `h · y = r` with the explicit cofactor formula
/// `y = cofᵀ · r / det(h)`.
///
/// # Panics
///
/// Panics if `h` is not 3×3 or if `|det(h)| <= 1e-12`. Symmetry and positive
/// definiteness are not checked, despite the name.
fn solve_3x3_spd(h: &Array2<f64>, r: &Array1<f64>) -> Array1<f64> {
    assert_eq!(h.nrows(), 3);
    assert_eq!(h.ncols(), 3);

    // 2×2 determinant taken from rows (ra, rb) and columns (ca, cb), in that order.
    let minor = |ra: usize, rb: usize, ca: usize, cb: usize| {
        h[[ra, ca]] * h[[rb, cb]] - h[[ra, cb]] * h[[rb, ca]]
    };

    // Laplace expansion along the first row.
    let det = h[[0, 0]] * minor(1, 2, 1, 2) - h[[0, 1]] * minor(1, 2, 0, 2)
        + h[[0, 2]] * minor(1, 2, 0, 1);
    assert!(det.abs() > 1e-12, "singular matrix in solve_3x3_spd");

    // Signed cofactors: cof[i][j] is the cofactor of entry (i, j).
    let cof = [
        [minor(1, 2, 1, 2), -minor(1, 2, 0, 2), minor(1, 2, 0, 1)],
        [-minor(0, 2, 1, 2), minor(0, 2, 0, 2), -minor(0, 2, 0, 1)],
        [minor(0, 1, 1, 2), -minor(0, 1, 0, 2), minor(0, 1, 0, 1)],
    ];

    // inverse[i][j] = cof[j][i] / det, hence the transposed index below. The fold
    // starts from 0.0 and accumulates in index order.
    Array1::from_iter(
        (0..3).map(|i| (0..3).fold(0.0, |acc, j| acc + cof[j][i] * r[j]) / det),
    )
}
/// Small deterministic 64-bit linear congruential generator for reproducible test data.
/// The wrapped value is the current state.
struct Lcg(u64);

impl Lcg {
    /// Multiplier of the recurrence `state ← state · MULTIPLIER + INCREMENT (mod 2⁶⁴)`.
    const MULTIPLIER: u64 = 6364136223846793005;
    /// Additive constant of the recurrence.
    const INCREMENT: u64 = 1442695040888963407;

    /// One step of the recurrence, with wrapping arithmetic.
    fn advance(state: u64) -> u64 {
        state
            .wrapping_mul(Self::MULTIPLIER)
            .wrapping_add(Self::INCREMENT)
    }

    /// Creates a generator whose initial state is one recurrence step away from `seed`.
    fn new(seed: u64) -> Self {
        Self(Self::advance(seed))
    }

    /// Steps the generator and returns the new state.
    fn next_u64(&mut self) -> u64 {
        let next = Self::advance(self.0);
        self.0 = next;
        next
    }

    /// Returns a value in `[0, 1)` built from the top 53 bits of the next state.
    fn unif(&mut self) -> f64 {
        let top_bits = self.next_u64() >> 11;
        let denominator = (1u64 << 53) as f64;
        (top_bits as f64) / denominator
    }

    /// Returns one Box-Muller draw (the cosine branch), consuming two uniforms.
    fn normal(&mut self) -> f64 {
        // Clamp away from zero so the logarithm stays finite.
        let u1 = self.unif().max(1e-300);
        let u2 = self.unif();
        let radius = (-2.0 * u1.ln()).sqrt();
        let angle = 2.0 * std::f64::consts::PI * u2;
        radius * angle.cos()
    }
}
/// With an identity design matrix, the uncorrected linear predictor must
/// equal `β` and the bias-corrected one must equal `β + b̂`, where `b̂` is the
/// solution of `H b̂ = S β` obtained from the explicit 3×3 solver.
#[test]
fn test_bias_correction_matches_explicit_formula() {
    let hessian = array![[4.0_f64, 0.5, 0.2], [0.5, 3.0, 0.1], [0.2, 0.1, 2.0]];
    let penalty = array![[1.0_f64, 0.0, 0.0], [0.0, 0.5, 0.0], [0.0, 0.0, 2.0]];
    let beta = array![0.7_f64, -1.3, 0.4];
    let b_hat = solve_3x3_spd(&hessian, &penalty.dot(&beta));
    let fit = test_fit_with_bias_correction(
        beta.clone(),
        Array2::<f64>::eye(3),
        Some(b_hat.clone()),
    );
    let identity_design = array![[1.0, 0.0, 0.0], [0.0, 1.0, 0.0], [0.0, 0.0, 1.0]];
    let offset = array![0.0, 0.0, 0.0];
    // Same prediction call for both passes; only the bias-correction switch differs.
    let run = |apply_bc: bool, failure: &str| {
        predict_gamwith_uncertainty(
            identity_design.clone(),
            beta.view(),
            offset.view(),
            crate::types::LikelihoodSpec::gaussian_identity(),
            &fit,
            &bc_options(apply_bc),
        )
        .expect(failure)
    };
    let pred_raw = run(false, "raw predict");
    let pred_bc = run(true, "bc predict");
    for i in 0..3 {
        let raw_gap = (pred_raw.eta[i] - beta[i]).abs();
        assert!(
            raw_gap < 1e-12,
            "raw eta[{i}] = {} expected {}",
            pred_raw.eta[i],
            beta[i]
        );
        let expected = beta[i] + b_hat[i];
        let bc_gap = (pred_bc.eta[i] - expected).abs();
        assert!(
            bc_gap < 1e-12,
            "BC eta[{i}] = {} expected β+b̂ = {} (b̂[{i}] = {})",
            pred_bc.eta[i],
            expected,
            b_hat[i]
        );
    }
}
/// An all-zero bias-correction vector must make the bias-corrected linear
/// predictor coincide with the uncorrected one (difference below 1e-15).
#[test]
fn test_bias_correction_zero_for_zero_penalty() {
    let beta = array![0.5_f64, -0.4, 1.7];
    let fit = test_fit_with_bias_correction(
        beta.clone(),
        Array2::<f64>::eye(3),
        Some(Array1::<f64>::zeros(3)),
    );
    let design = array![[1.0, 2.0, -0.5], [0.3, -0.7, 1.2], [2.0, 0.1, 0.0]];
    let offset = array![0.0, 0.0, 0.0];
    let run = |apply_bc: bool, failure: &str| {
        predict_gamwith_uncertainty(
            design.clone(),
            beta.view(),
            offset.view(),
            crate::types::LikelihoodSpec::gaussian_identity(),
            &fit,
            &bc_options(apply_bc),
        )
        .expect(failure)
    };
    let pred_raw = run(false, "raw predict");
    let pred_bc = run(true, "bc predict");
    for i in 0..3 {
        let shift = pred_bc.eta[i] - pred_raw.eta[i];
        assert!(
            shift.abs() < 1e-15,
            "S=0 ⇒ BC must be a no-op; got Δ={} at i={i}",
            shift
        );
    }
}
/// For λ ∈ {0.1, 1, 10}, builds `b̂ = (H_base + λI)⁻¹ (λ β)` and checks that the
/// Euclidean norm of `η_BC − η_raw` over three probe rows is strictly
/// increasing in λ and grows by more than 10× between the smallest and
/// largest λ.
#[test]
fn test_bias_correction_increases_with_penalty_strength() {
    let h_base = array![[3.0_f64, 0.4, 0.1], [0.4, 2.5, 0.2], [0.1, 0.2, 4.0]];
    let beta = array![1.2_f64, -0.8, 0.5];
    let x = array![[1.0, 0.5, -0.2], [0.3, -0.4, 0.9], [0.7, 0.7, 0.7]];
    let offset = array![0.0, 0.0, 0.0];
    let lambdas = [0.1_f64, 1.0, 10.0];
    let deltas: Vec<f64> = lambdas
        .iter()
        .map(|&lam| {
            // Penalized Hessian: base curvature plus λ on the diagonal.
            let mut h = h_base.clone();
            for k in 0..3 {
                h[[k, k]] += lam;
            }
            let b_hat = solve_3x3_spd(&h, &beta.mapv(|v| lam * v));
            let fit =
                test_fit_with_bias_correction(beta.clone(), Array2::<f64>::eye(3), Some(b_hat));
            let run = |apply_bc: bool, failure: &str| {
                predict_gamwith_uncertainty(
                    x.clone(),
                    beta.view(),
                    offset.view(),
                    crate::types::LikelihoodSpec::gaussian_identity(),
                    &fit,
                    &bc_options(apply_bc),
                )
                .expect(failure)
            };
            let pred_raw = run(false, "raw predict");
            let pred_bc = run(true, "bc predict");
            let sumsq = (0..3).fold(0.0_f64, |sumsq, i| {
                let d = pred_bc.eta[i] - pred_raw.eta[i];
                sumsq + d * d
            });
            sumsq.sqrt()
        })
        .collect();
    // Strict growth between each pair of consecutive λ values.
    for k in 0..2 {
        assert!(
            deltas[k] < deltas[k + 1],
            "‖η_BC − η_raw‖ must grow with λ: λ={} gave {}, λ={} gave {}",
            lambdas[k],
            deltas[k],
            lambdas[k + 1],
            deltas[k + 1]
        );
    }
    assert!(
        deltas[2] > 10.0 * deltas[0],
        "expected order-of-magnitude growth in BC magnitude across λ ∈ {{0.1,1,10}}; got {:?}",
        deltas
    );
}
/// Simulated Gaussian regression (n = 200, p = 5, intercept plus normal
/// covariates). The "fitted" coefficients are the OLS solution shrunk by 0.6,
/// and `b̂ = (XᵀX + 100·I)⁻¹ (100·β̂)`. On 50 fresh probe rows, the
/// bias-corrected predictor must be closer to the OLS predictor than the raw
/// one at no fewer than 90% of the rows.
#[test]
fn test_bias_correction_recovers_unpenalized_in_simulation() {
    /// Row-major design whose first column is 1.0 and whose other entries are
    /// successive `rng.normal()` draws.
    fn intercept_design(rng: &mut Lcg, rows: usize, cols: usize, shape_msg: &str) -> Array2<f64> {
        let mut data = Vec::with_capacity(rows * cols);
        for _ in 0..rows {
            data.push(1.0_f64);
            for _ in 1..cols {
                data.push(rng.normal());
            }
        }
        Array2::from_shape_vec((rows, cols), data).expect(shape_msg)
    }

    let n = 200usize;
    let p = 5usize;
    let mut rng = Lcg::new(0xC0FFEE_u64);
    let x = intercept_design(&mut rng, n, p, "X shape");
    let beta_true = array![0.5_f64, 1.0, -0.7, 0.3, 0.8];
    let mut y = Array1::<f64>::zeros(n);
    for i in 0..n {
        let signal = (0..p).fold(0.0_f64, |acc, j| acc + x[[i, j]] * beta_true[j]);
        y[i] = signal + 0.3 * rng.normal();
    }
    let xtx = x.t().dot(&x);
    let xty = x.t().dot(&y);
    let beta_ols = solve_dense_spd(&xtx, &xty);
    let shrink = 0.6_f64;
    let beta_hat = beta_ols.mapv(|v| shrink * v);
    let lambda = 100.0_f64;
    let mut h = xtx.clone();
    for k in 0..p {
        h[[k, k]] += lambda;
    }
    let b_hat = solve_dense_spd(&h, &beta_hat.mapv(|v| lambda * v));
    let fit = test_fit_with_bias_correction(beta_hat.clone(), Array2::<f64>::eye(p), Some(b_hat));
    let m = 50usize;
    // Probe rows continue the same random stream used for the training data.
    let xt = intercept_design(&mut rng, m, p, "Xtest shape");
    let offset = Array1::<f64>::zeros(m);
    let run = |apply_bc: bool, failure: &str| {
        predict_gamwith_uncertainty(
            xt.clone(),
            beta_hat.view(),
            offset.view(),
            crate::types::LikelihoodSpec::gaussian_identity(),
            &fit,
            &bc_options(apply_bc),
        )
        .expect(failure)
    };
    let pred_raw = run(false, "raw predict");
    let pred_bc = run(true, "bc predict");
    let eta_ols = xt.dot(&beta_ols);
    let closer = (0..m)
        .filter(|&i| {
            let raw_gap = (eta_ols[i] - pred_raw.eta[i]).abs();
            let bc_gap = (eta_ols[i] - pred_bc.eta[i]).abs();
            bc_gap < raw_gap
        })
        .count();
    let frac = closer as f64 / m as f64;
    assert!(
        frac >= 0.9,
        "BC must close the OLS gap at ≥90% of test points; got {}/{} = {:.2}",
        closer,
        m,
        frac
    );
}
/// For n ∈ {200, 1000, 5000} (evaluated in parallel), forms
/// `β_mean = (XᵀX + λI)⁻¹ XᵀX β_true` and `b̂ = (XᵀX + λI)⁻¹ (λ β_mean)` with
/// λ = 5, then compares mean absolute deviation from `X_probe β_true` with
/// and without bias correction on 32 fixed probe rows. Asserts that the raw
/// deviation shrinks from the smallest to the largest n, that correction
/// reduces it by more than half at the largest n, and that the
/// corrected/raw ratio does not grow with n.
#[test]
fn test_bias_correction_bias_drops_with_n_simulation() {
    let p = 4usize;
    let beta_true = array![0.4_f64, 0.9, -0.5, 0.6];
    let lambda = 5.0_f64;
    let ns = [200usize, 1000, 5000];
    let m = 32usize;

    // Fixed probe design: intercept column plus normal draws, row-major.
    let mut probe_rng = Lcg::new(424242);
    let mut probe_data = Vec::with_capacity(m * p);
    for _ in 0..m {
        probe_data.push(1.0_f64);
        for _ in 1..p {
            probe_data.push(probe_rng.normal());
        }
    }
    let xt = Array2::from_shape_vec((m, p), probe_data).expect("Xtest shape");
    let eta_true = xt.dot(&beta_true);
    let offset = Array1::<f64>::zeros(m);

    let bias_by_n: Vec<(usize, f64, f64)> = (0..ns.len())
        .into_par_iter()
        .map(|kn| {
            let n = ns[kn];
            // Every sample size restarts the generator from the same seed.
            let mut rng = Lcg::new(0xBEEFu64);
            let mut train_data = Vec::with_capacity(n * p);
            for _ in 0..n {
                train_data.push(1.0_f64);
                for _ in 1..p {
                    train_data.push(rng.normal());
                }
            }
            let x = Array2::from_shape_vec((n, p), train_data).expect("X shape");
            let xtx = x.t().dot(&x);
            let mut h = xtx.clone();
            for k in 0..p {
                h[[k, k]] += lambda;
            }
            let beta_mean = solve_dense_spd(&h, &xtx.dot(&beta_true));
            let b_hat = solve_dense_spd(&h, &beta_mean.mapv(|v| lambda * v));
            let fit =
                test_fit_with_bias_correction(beta_mean.clone(), Array2::<f64>::eye(p), Some(b_hat));
            let run = |apply_bc: bool, failure: &str| {
                predict_gamwith_uncertainty(
                    xt.clone(),
                    beta_mean.view(),
                    offset.view(),
                    crate::types::LikelihoodSpec::gaussian_identity(),
                    &fit,
                    &bc_options(apply_bc),
                )
                .expect(failure)
            };
            let pred_raw = run(false, "raw predict");
            let pred_bc = run(true, "bc predict");
            let (sum_raw, sum_bc) = (0..m).fold((0.0_f64, 0.0_f64), |(raw, bc), i| {
                (
                    raw + (pred_raw.eta[i] - eta_true[i]).abs(),
                    bc + (pred_bc.eta[i] - eta_true[i]).abs(),
                )
            });
            (kn, sum_raw / m as f64, sum_bc / m as f64)
        })
        .collect();

    // Scatter the per-n results back by their index in `ns`.
    let mut mean_abs_raw_bias = [0.0_f64; 3];
    let mut mean_abs_bc_bias = [0.0_f64; 3];
    for (kn, raw, bc) in bias_by_n {
        mean_abs_raw_bias[kn] = raw;
        mean_abs_bc_bias[kn] = bc;
    }

    assert!(
        mean_abs_raw_bias[2] < mean_abs_raw_bias[0],
        "raw penalized conditional bias should shrink with n: got {:?}",
        mean_abs_raw_bias
    );
    // The 1e-300 floor keeps the ratios defined if a raw bias is exactly zero.
    let ratio_large = mean_abs_bc_bias[2] / mean_abs_raw_bias[2].max(1e-300);
    assert!(
        ratio_large < 0.5,
        "BC must reduce conditional bias by >2× at n={}; raw={}, bc={}, ratio={}",
        ns[2],
        mean_abs_raw_bias[2],
        mean_abs_bc_bias[2],
        ratio_large
    );
    let ratio_small = mean_abs_bc_bias[0] / mean_abs_raw_bias[0].max(1e-300);
    assert!(
        ratio_large <= ratio_small + 1e-6,
        "BC/raw ratio should not grow with n: small-n ratio={}, large-n ratio={}",
        ratio_small,
        ratio_large
    );
}
/// Reparameterizes coefficients with an upper-triangular `Q`
/// (`θ = Q⁻¹ β`, `b̃ = Q⁻¹ b̂`, `x̃ = x Q`) and checks that the bias-corrected
/// linear predictor for one probe row is the same in both parameterizations
/// to within 1e-12.
#[test]
fn test_bias_correction_identity_in_basis_change() {
    let hessian = array![[4.0_f64, 0.5, 0.2], [0.5, 3.0, 0.1], [0.2, 0.1, 2.5]];
    let penalty = array![[0.7_f64, 0.1, 0.0], [0.1, 0.5, 0.05], [0.0, 0.05, 1.2]];
    let beta = array![0.6_f64, -0.4, 1.1];
    let b_hat = solve_3x3_spd(&hessian, &penalty.dot(&beta));

    let q = array![[1.0_f64, 0.3, -0.2], [0.0, 1.0, 0.5], [0.0, 0.0, 1.0]];
    let q_inv = invert_upper_triangular_3(&q);
    let theta = q_inv.dot(&beta);
    let b_tilde = q_inv.dot(&b_hat);

    let x_row = array![[0.4_f64, -0.7, 0.9]];
    // x̃ = x·Q, accumulated entry by entry.
    let mut x_tilde = Array2::<f64>::zeros((1, 3));
    for j in 0..3 {
        x_tilde[[0, j]] = (0..3).fold(0.0_f64, |acc, i| acc + q[[i, j]] * x_row[[0, i]]);
    }

    let offset = array![0.0_f64];
    let cov = Array2::<f64>::eye(3);
    let fit_orig = test_fit_with_bias_correction(beta.clone(), cov.clone(), Some(b_hat));
    let fit_repar = test_fit_with_bias_correction(theta.clone(), cov, Some(b_tilde));
    let pred_orig = predict_gamwith_uncertainty(
        x_row,
        beta.view(),
        offset.view(),
        crate::types::LikelihoodSpec::gaussian_identity(),
        &fit_orig,
        &bc_options(true),
    )
    .expect("orig predict");
    let pred_repar = predict_gamwith_uncertainty(
        x_tilde,
        theta.view(),
        offset.view(),
        crate::types::LikelihoodSpec::gaussian_identity(),
        &fit_repar,
        &bc_options(true),
    )
    .expect("repar predict");
    let gap = (pred_orig.eta[0] - pred_repar.eta[0]).abs();
    assert!(
        gap < 1e-12,
        "BC must be invariant under reparameterization: orig η={} repar η={} Δ={}",
        pred_orig.eta[0],
        pred_repar.eta[0],
        gap
    );
}
/// On 100 random probe rows with a non-diagonal coefficient covariance, the
/// standard error of the linear predictor must be the same (relative
/// difference below 1e-14) whether or not bias correction is applied.
#[test]
fn test_bias_correction_does_not_inflate_se() {
    let p = 4usize;
    let m = 100usize;
    let beta = array![0.5_f64, -0.7, 1.1, 0.3];
    let cov = array![
        [2.0_f64, 0.3, 0.1, 0.0],
        [0.3, 1.5, 0.2, 0.05],
        [0.1, 0.2, 1.8, 0.1],
        [0.0, 0.05, 0.1, 2.2]
    ];
    let fit = test_fit_with_bias_correction(
        beta.clone(),
        cov,
        Some(array![0.2_f64, -0.15, 0.05, 0.1]),
    );
    let mut rng = Lcg::new(0xBEEFCAFE_u64);
    // Row-major fill: one normal draw per design entry.
    let x_data: Vec<f64> = (0..m * p).map(|_| rng.normal()).collect();
    let x = Array2::from_shape_vec((m, p), x_data).expect("X shape");
    let offset = Array1::<f64>::zeros(m);
    let run = |apply_bc: bool, failure: &str| {
        predict_gamwith_uncertainty(
            x.clone(),
            beta.view(),
            offset.view(),
            crate::types::LikelihoodSpec::gaussian_identity(),
            &fit,
            &bc_options(apply_bc),
        )
        .expect(failure)
    };
    let pred_off = run(false, "predict no-bc");
    let pred_on = run(true, "predict bc");
    for i in 0..m {
        let a = pred_off.eta_standard_error[i];
        let b = pred_on.eta_standard_error[i];
        // Relative difference, with a tiny floor so the division is defined.
        let denom = a.abs().max(b.abs()).max(1e-300);
        let rel = (a - b).abs() / denom;
        assert!(
            rel < 1e-14,
            "SE leakage detected at i={}: off={}, on={}, relΔ={}",
            i,
            a,
            b,
            rel
        );
    }
}
/// Feeds a NaN coefficient and an infinite bias-correction entry through a
/// bias-corrected prediction. Despite the "finite" in the name, the assertion
/// is that the call returns `Ok` and the resulting linear predictor is
/// non-finite, i.e. the bad values propagate instead of raising an error.
#[test]
fn test_bias_correction_finite_for_pathological_inputs() {
    let beta = array![1.0_f64, f64::NAN, 0.5];
    let fit = test_fit_with_bias_correction(
        beta.clone(),
        Array2::<f64>::eye(3),
        Some(array![0.1_f64, 0.2, f64::INFINITY]),
    );
    let all_ones_row = array![[1.0_f64, 1.0, 1.0]];
    let offset = array![0.0_f64];
    let pred = predict_gamwith_uncertainty(
        all_ones_row,
        beta.view(),
        offset.view(),
        crate::types::LikelihoodSpec::gaussian_identity(),
        &fit,
        &bc_options(true),
    )
    .expect("pathological predict should not error, only propagate NaN/Inf");
    let eta0 = pred.eta[0];
    assert!(
        !eta0.is_finite(),
        "expected non-finite η to propagate; got η = {}",
        eta0
    );
}
/// Even when the fit carries a bias-correction vector, predicting with
/// `bc_options(false)` must return the plain `X·β` linear predictor
/// (to within 1e-15).
#[test]
fn test_bias_correction_disabled_via_options_returns_raw() {
    let beta = array![1.5_f64, -0.7];
    let fit = test_fit_with_bias_correction(
        beta.clone(),
        Array2::<f64>::eye(2),
        Some(array![0.4_f64, -0.3]),
    );
    let x = array![[1.0_f64, 0.5], [0.7, -0.3]];
    let offset = array![0.0_f64, 0.0];
    let expected = x.dot(&beta);
    let pred = predict_gamwith_uncertainty(
        x.clone(),
        beta.view(),
        offset.view(),
        crate::types::LikelihoodSpec::gaussian_identity(),
        &fit,
        &bc_options(false),
    )
    .expect("predict no-bc");
    for i in 0..2 {
        let d = (pred.eta[i] - expected[i]).abs();
        assert!(
            d < 1e-15,
            "apply_bias_correction=false must return raw plug-in: η[{i}]={} expected={} Δ={}",
            pred.eta[i],
            expected[i],
            d
        );
    }
}
/// Stores a bias-correction vector computed from `H_true` alongside a
/// covariance that is deliberately far from `H_true⁻¹`, and checks that the
/// bias-corrected predictor (identity design) equals `β + b̂_correct`, i.e.
/// the stored vector is applied as given.
#[test]
fn test_bias_correction_with_nonidentity_covariance_uses_correct_h() {
    let h_true = array![[5.0_f64, 0.7, 0.2], [0.7, 4.0, 0.3], [0.2, 0.3, 3.5]];
    let s_pen = array![[0.8_f64, 0.0, 0.0], [0.0, 1.2, 0.0], [0.0, 0.0, 0.6]];
    let beta = array![0.9_f64, -1.1, 0.4];
    let b_hat_correct = solve_3x3_spd(&h_true, &s_pen.dot(&beta));
    let cov_wrong = array![[2.0_f64, 0.4, 0.0], [0.4, 1.5, 0.3], [0.0, 0.3, 1.8]];

    // Sanity check on the fixture: entrywise L1 distance between H_true⁻¹ and
    // the stored covariance must be large.
    let h_inv = invert_3x3_spd(&h_true);
    let diff = (0..3)
        .flat_map(|i| (0..3).map(move |j| (i, j)))
        .fold(0.0_f64, |acc, (i, j)| {
            acc + (h_inv[[i, j]] - cov_wrong[[i, j]]).abs()
        });
    assert!(
        diff > 0.5,
        "test setup error: cov_wrong should be far from H_true⁻¹ (diff={})",
        diff
    );

    let fit =
        test_fit_with_bias_correction(beta.clone(), cov_wrong, Some(b_hat_correct.clone()));
    let identity_design = array![[1.0_f64, 0.0, 0.0], [0.0, 1.0, 0.0], [0.0, 0.0, 1.0]];
    let offset = array![0.0_f64, 0.0, 0.0];
    let pred = predict_gamwith_uncertainty(
        identity_design,
        beta.view(),
        offset.view(),
        crate::types::LikelihoodSpec::gaussian_identity(),
        &fit,
        &bc_options(true),
    )
    .expect("predict bc");
    for i in 0..3 {
        let expected = beta[i] + b_hat_correct[i];
        let gap = (pred.eta[i] - expected).abs();
        assert!(
            gap < 1e-12,
            "prediction must use the supplied bias_correction_beta verbatim: \
             η[{i}]={} expected={} (β+b̂_correct[{i}]={})",
            pred.eta[i],
            expected,
            b_hat_correct[i]
        );
    }
}
/// Serializes a fit carrying a bias-correction vector to JSON, deserializes
/// it as `UnifiedFitResult`, and checks that `bias_correction_beta()` on the
/// decoded value has the same length and entries (each within 1e-15).
#[test]
fn test_bias_correction_propagates_through_unified_fit_result() {
    let bc = array![0.123456789_f64, -0.987654321, 0.5];
    let fit = test_fit_with_bias_correction(
        array![0.7_f64, -0.4, 1.2],
        Array2::<f64>::eye(3),
        Some(bc.clone()),
    );
    let encoded = serde_json::to_string(&fit).expect("serialize unified fit");
    let decoded: UnifiedFitResult =
        serde_json::from_str(&encoded).expect("deserialize unified fit");
    let recovered = decoded
        .bias_correction_beta()
        .expect("bias_correction_beta must survive JSON round-trip");
    assert_eq!(
        recovered.len(),
        bc.len(),
        "bc length changed across round-trip"
    );
    for i in 0..bc.len() {
        let drift = (recovered[i] - bc[i]).abs();
        assert!(
            drift < 1e-15,
            "bc[{i}] drifted across JSON round-trip: in={}, out={}",
            bc[i],
            recovered[i]
        );
    }
}
/// Solves the dense square system `h · y = r` by Gaussian elimination with
/// partial (row) pivoting on an augmented matrix, followed by back
/// substitution.
///
/// Symmetry and positive-definiteness of `h` are not checked.
///
/// # Panics
/// Panics if `h` is not square, if `r.len()` differs from the dimension of
/// `h`, or if the largest available pivot magnitude in some column is not
/// above `1e-14`.
fn solve_dense_spd(h: &Array2<f64>, r: &Array1<f64>) -> Array1<f64> {
    let dim = h.nrows();
    assert_eq!(h.ncols(), dim);
    assert_eq!(r.len(), dim);

    // Augmented matrix [h | r]; the right-hand side sits in column `rhs_col`.
    let rhs_col = dim;
    let mut aug = Array2::<f64>::zeros((dim, dim + 1));
    for row in 0..dim {
        for col in 0..dim {
            aug[[row, col]] = h[[row, col]];
        }
        aug[[row, rhs_col]] = r[row];
    }

    for step in 0..dim {
        // Pick the first row at or below the diagonal with the largest
        // magnitude in this column.
        let mut pivot_row = step;
        let mut pivot_mag = aug[[step, step]].abs();
        for row in (step + 1)..dim {
            let mag = aug[[row, step]].abs();
            if mag > pivot_mag {
                pivot_mag = mag;
                pivot_row = row;
            }
        }
        assert!(pivot_mag > 1e-14, "near-singular system in solve_dense_spd");

        if pivot_row != step {
            for col in 0..=rhs_col {
                let held = aug[[step, col]];
                aug[[step, col]] = aug[[pivot_row, col]];
                aug[[pivot_row, col]] = held;
            }
        }

        // Row `step` is not modified below, so its pivot can be read once.
        let pivot = aug[[step, step]];
        for row in (step + 1)..dim {
            let factor = aug[[row, step]] / pivot;
            for col in step..=rhs_col {
                let delta = factor * aug[[step, col]];
                aug[[row, col]] -= delta;
            }
        }
    }

    // Back substitution on the upper-triangular system.
    let mut solution = Array1::<f64>::zeros(dim);
    for row in (0..dim).rev() {
        let residual = ((row + 1)..dim).fold(aug[[row, rhs_col]], |acc, col| {
            acc - aug[[row, col]] * solution[col]
        });
        solution[row] = residual / aug[[row, row]];
    }
    solution
}
/// Builds the inverse of a 3×3 matrix column by column: column `c` of the
/// result is `solve_3x3_spd(h, e_c)` for the `c`-th unit vector `e_c`.
fn invert_3x3_spd(h: &Array2<f64>) -> Array2<f64> {
    let mut inverse = Array2::<f64>::zeros((3, 3));
    for col in 0..3 {
        let mut unit = Array1::<f64>::zeros(3);
        unit[col] = 1.0;
        let solved = solve_3x3_spd(h, &unit);
        (0..3).for_each(|row| inverse[[row, col]] = solved[row]);
    }
    inverse
}
/// Returns the closed-form inverse of a 3×3 upper-triangular matrix
///
/// ```text
/// [ d0  a   b ]
/// [ 0   d1  c ]
/// [ 0   0   d2 ]
/// ```
///
/// For a unit diagonal this reduces to
/// `[[1, -a, a·c - b], [0, 1, -c], [0, 0, 1]]`.
///
/// # Panics
/// Panics if `q` is not 3×3, if any entry below the diagonal is nonzero, or
/// if any diagonal entry is zero. These inputs would otherwise yield a matrix
/// that is not the inverse of `q`.
fn invert_upper_triangular_3(q: &Array2<f64>) -> Array2<f64> {
    assert_eq!(q.nrows(), 3);
    assert_eq!(q.ncols(), 3);
    assert!(
        q[[1, 0]] == 0.0 && q[[2, 0]] == 0.0 && q[[2, 1]] == 0.0,
        "invert_upper_triangular_3 requires zeros below the diagonal"
    );
    let (d0, d1, d2) = (q[[0, 0]], q[[1, 1]], q[[2, 2]]);
    assert!(
        d0 != 0.0 && d1 != 0.0 && d2 != 0.0,
        "invert_upper_triangular_3 requires a nonzero diagonal"
    );
    let a = q[[0, 1]];
    let b = q[[0, 2]];
    let c = q[[1, 2]];
    // Each off-diagonal entry follows from requiring q · q⁻¹ = I row by row.
    array![
        [1.0 / d0, -a / (d0 * d1), (a * c - b * d1) / (d0 * d1 * d2)],
        [0.0, 1.0 / d1, -c / (d1 * d2)],
        [0.0, 0.0, 1.0 / d2]
    ]
}
/// One-coefficient fixture for the coverage-correction tests.
///
/// Returns `(fit, x, beta, offset)` with `beta = [1.0]`, coefficient
/// covariance `[[0.25]]`, no bias-correction vector, a single design row
/// `[[1.0]]` and a zero offset.
fn coverage_correction_fixture() -> (UnifiedFitResult, Array2<f64>, Array1<f64>, Array1<f64>) {
    let beta = array![1.0];
    let fit = test_fit_with_bias_correction(beta.clone(), array![[0.25_f64]], None);
    (fit, array![[1.0_f64]], beta, array![0.0_f64])
}
/// Baseline options for the coverage-correction tests: 95% confidence,
/// conditional covariance, eta-transform mean intervals, and every optional
/// correction flag switched off. Fields not listed keep their `Default` value.
fn corrections_baseline_options() -> PredictUncertaintyOptions {
    PredictUncertaintyOptions {
        // Optional corrections: all disabled so each test can enable one.
        apply_bias_correction: false,
        edgeworth_one_sided: false,
        boundary_correction: false,
        ood_inflation: false,
        multi_point_joint: false,
        includeobservation_interval: false,
        // Interval construction settings.
        confidence_level: 0.95,
        covariance_mode: InferenceCovarianceMode::Conditional,
        mean_interval_method: MeanIntervalMethod::TransformEta,
        ..PredictUncertaintyOptions::default()
    }
}
/// With every correction disabled, the eta standard error must be
/// `sqrt(0.25)` and the interval must be `1 ± z·SE`, where `z` is the
/// two-sided 95% standard-normal quantile.
#[test]
fn coverage_corrections_all_off_matches_legacy() {
    let (fit, x, beta, offset) = coverage_correction_fixture();
    let pred = predict_gamwith_uncertainty(
        x.view(),
        beta.view(),
        offset.view(),
        crate::types::LikelihoodSpec::gaussian_identity(),
        &fit,
        &corrections_baseline_options(),
    )
    .expect("prediction baseline");
    let z = standard_normal_quantile(0.5 + 0.5 * 0.95).unwrap();
    let expected_se = (0.25_f64).sqrt();
    let se_gap = (pred.eta_standard_error[0] - expected_se).abs();
    assert!(se_gap <= 1e-12);
    let half_width = z * expected_se;
    let expected_lower = 1.0 - half_width;
    let expected_upper = 1.0 + half_width;
    let lower_gap = (pred.eta_lower[0] - expected_lower).abs();
    assert!(
        lower_gap <= 1e-12,
        "baseline lower drifted: got {}, expected {}",
        pred.eta_lower[0],
        expected_lower
    );
    let upper_gap = (pred.eta_upper[0] - expected_upper).abs();
    assert!(
        upper_gap <= 1e-12,
        "baseline upper drifted: got {}, expected {}",
        pred.eta_upper[0],
        expected_upper
    );
}
/// With the Edgeworth one-sided correction enabled, a supplied skewness of
/// 0.6 must place the upper bound further from the point estimate (1.0) than
/// the lower bound; a supplied skewness of 0.0 must give a symmetric interval.
#[test]
fn edgeworth_one_sided_makes_interval_asymmetric_with_positive_skew() {
    let (fit, x, beta, offset) = coverage_correction_fixture();
    let run = |opts: &PredictUncertaintyOptions, failure: &str| {
        predict_gamwith_uncertainty(
            x.view(),
            beta.view(),
            offset.view(),
            crate::types::LikelihoodSpec::gaussian_identity(),
            &fit,
            opts,
        )
        .expect(failure)
    };

    let mut opts = corrections_baseline_options();
    opts.edgeworth_one_sided = true;
    opts.eta_skewness_for_corrections = Some(array![0.6_f64]);
    let pred = run(&opts, "edgeworth prediction");
    let dist_upper = pred.eta_upper[0] - 1.0;
    let dist_lower = 1.0 - pred.eta_lower[0];
    assert!(
        dist_upper > dist_lower + 1e-9,
        "positive skew should push upper tail further than lower: \
         upper-dist={dist_upper}, lower-dist={dist_lower}"
    );

    // Same options, but with zero skewness supplied.
    opts.eta_skewness_for_corrections = Some(array![0.0_f64]);
    let pred_sym = run(&opts, "edgeworth zero-skew prediction");
    let sym_upper = pred_sym.eta_upper[0] - 1.0;
    let sym_lower = 1.0 - pred_sym.eta_lower[0];
    let asymmetry = (sym_upper - sym_lower).abs();
    assert!(asymmetry <= 1e-12);
}
/// With boundary correction enabled and a training support of `[0, 10]`, a
/// row whose predictor is 5.0 must keep the baseline standard error while a
/// row at 9.9 must get a larger standard error and a wider eta interval.
#[test]
fn boundary_correction_widens_interval_near_edge() {
    let beta = array![1.0_f64];
    let fit = test_fit_with_bias_correction(beta.clone(), array![[0.25_f64]], None);
    let x = array![[1.0_f64], [1.0_f64]];
    let offset = array![0.0_f64, 0.0_f64];

    let mut opts = corrections_baseline_options();
    opts.boundary_correction = true;
    // Row 0 sits mid-range, row 1 close to the upper edge of the support.
    opts.predictor_x_for_corrections = Some(array![[5.0_f64], [9.9_f64]]);
    opts.training_support = Some(TrainingSupport {
        axis_min: array![0.0_f64],
        axis_max: array![10.0_f64],
    });

    let pred = predict_gamwith_uncertainty(
        x.view(),
        beta.view(),
        offset.view(),
        crate::types::LikelihoodSpec::gaussian_identity(),
        &fit,
        &opts,
    )
    .expect("boundary-corrected prediction");

    let baseline_se = (0.25_f64).sqrt();
    let interior_se = pred.eta_standard_error[0];
    let edge_se = pred.eta_standard_error[1];
    assert!(
        (interior_se - baseline_se).abs() <= 1e-12,
        "interior row must not be inflated: {} vs {}",
        interior_se,
        baseline_se
    );
    assert!(
        edge_se > baseline_se + 1e-9,
        "near-edge row must be inflated: got {}, baseline {}",
        edge_se,
        baseline_se
    );
    let width0 = pred.eta_upper[0] - pred.eta_lower[0];
    let width1 = pred.eta_upper[1] - pred.eta_lower[1];
    assert!(
        width1 > width0 + 1e-9,
        "near-edge interval not wider: width0={width0}, width1={width1}"
    );
}
/// With out-of-distribution inflation enabled and a training support of
/// `[0, 10]`, a row at 5.0 must keep the baseline standard error while a row
/// at 15.0 must report `sqrt(0.25 · 1.25)`.
#[test]
fn ood_inflation_widens_interval_outside_support() {
    let beta = array![1.0_f64];
    let fit = test_fit_with_bias_correction(beta.clone(), array![[0.25_f64]], None);
    let x = array![[1.0_f64], [1.0_f64]];
    let offset = array![0.0_f64, 0.0_f64];

    let mut opts = corrections_baseline_options();
    opts.ood_inflation = true;
    // Row 0 lies inside the support box, row 1 outside it.
    opts.predictor_x_for_corrections = Some(array![[5.0_f64], [15.0_f64]]);
    opts.training_support = Some(TrainingSupport {
        axis_min: array![0.0_f64],
        axis_max: array![10.0_f64],
    });

    let pred = predict_gamwith_uncertainty(
        x.view(),
        beta.view(),
        offset.view(),
        crate::types::LikelihoodSpec::gaussian_identity(),
        &fit,
        &opts,
    )
    .expect("ood-inflated prediction");

    let baseline_se = (0.25_f64).sqrt();
    let inside_gap = (pred.eta_standard_error[0] - baseline_se).abs();
    assert!(inside_gap <= 1e-12);
    let expected = (0.25_f64 * 1.25).sqrt();
    let outside_se = pred.eta_standard_error[1];
    assert!(
        (outside_se - expected).abs() <= 1e-12,
        "ood inflation factor wrong: got {}, expected {}",
        outside_se,
        expected
    );
    assert!(outside_se > baseline_se);
}
/// With the multi-point joint adjustment enabled on five identical rows, each
/// eta interval must have width `2 · z_joint · SE`, where `z_joint` is the
/// standard-normal quantile at `0.5 + 0.5·(1 − 0.05/5)` and exceeds the
/// per-row 95% quantile.
#[test]
fn multi_point_joint_widens_interval_relative_to_per_row() {
    let beta = array![1.0_f64];
    let fit = test_fit_with_bias_correction(beta.clone(), array![[0.25_f64]], None);
    let x = Array2::<f64>::from_elem((5, 1), 1.0_f64);
    let offset = Array1::zeros(5);

    let mut opts = corrections_baseline_options();
    opts.multi_point_joint = true;
    let pred = predict_gamwith_uncertainty(
        x.view(),
        beta.view(),
        offset.view(),
        crate::types::LikelihoodSpec::gaussian_identity(),
        &fit,
        &opts,
    )
    .expect("joint-adjusted prediction");

    let z_per_row = standard_normal_quantile(0.5 + 0.5 * 0.95).unwrap();
    let z_joint = standard_normal_quantile(0.5 + 0.5 * (1.0 - 0.05_f64 / 5.0)).unwrap();
    assert!(
        z_joint > z_per_row + 1e-6,
        "Bonferroni z must exceed per-row z: joint={z_joint}, per-row={z_per_row}"
    );

    let baseline_se = (0.25_f64).sqrt();
    // The target width is the same for every row.
    let expected = 2.0 * z_joint * baseline_se;
    for i in 0..5 {
        let width = pred.eta_upper[i] - pred.eta_lower[i];
        assert!(
            (width - expected).abs() <= 1e-12,
            "joint row {i} width mismatch: got {width}, expected {expected}"
        );
    }
}
/// `edgeworth_one_sided_quantile` with zero skewness must return the input
/// `z` for both the lower and the upper quantile.
#[test]
fn edgeworth_helper_zero_skew_returns_central_z() {
    let central = 1.96_f64;
    let adjusted = edgeworth_one_sided_quantile(central, 0.0);
    let lower_gap = (adjusted.z_lower - central).abs();
    let upper_gap = (adjusted.z_upper - central).abs();
    assert!(lower_gap <= 1e-12);
    assert!(upper_gap <= 1e-12);
}
/// `boundary_variance_inflation_factor` must return 1.0 for a point at 5.0
/// with axis range `[0, 10]` (trailing arguments 0.25 and 0.05).
#[test]
fn boundary_helper_returns_one_in_interior() {
    let point = array![5.0_f64];
    let axis_min = array![0.0_f64];
    let axis_max = array![10.0_f64];
    let factor = boundary_variance_inflation_factor(
        point.view(),
        axis_min.view(),
        axis_max.view(),
        0.25,
        0.05,
    );
    let gap = (factor - 1.0).abs();
    assert!(gap <= 1e-12);
}
/// `ood_variance_inflation_factor` must return 1.0 for a point at 5.0 with
/// axis range `[0, 10]` (trailing argument 1.0).
#[test]
fn ood_helper_returns_one_inside_box() {
    let point = array![5.0_f64];
    let axis_min = array![0.0_f64];
    let axis_max = array![10.0_f64];
    let factor =
        ood_variance_inflation_factor(point.view(), axis_min.view(), axis_max.view(), 1.0);
    let gap = (factor - 1.0).abs();
    assert!(gap <= 1e-12);
}
/// For a single point, `multi_point_joint_z(0.95, 1)` must equal the ordinary
/// two-sided 95% standard-normal quantile.
#[test]
fn multi_point_joint_z_passthrough_at_m_one() {
    let z_baseline = standard_normal_quantile(0.5 + 0.5 * 0.95).unwrap();
    let z1 = multi_point_joint_z(0.95, 1).unwrap();
    let gap = (z1 - z_baseline).abs();
    assert!(gap <= 1e-12);
}
}