use self::reml::{DirectionalHyperParam, RemlState};
use std::fmt;
use std::time::Instant;
use crate::construction::ReparamInvariant;
use crate::diagnostics::should_emit_h_min_eig_diag;
use crate::inference::predict::se_from_covariance;
use crate::linalg::utils::{
KahanSum, add_relative_diag_ridge, enforce_symmetry, matrix_inversewith_regularization,
row_mismatch_message,
};
use crate::matrix::DesignMatrix;
use crate::mixture_link::{state_from_beta_logisticspec, state_from_sasspec, state_fromspec};
use crate::pirls::{self, PirlsResult};
use crate::seeding::{SeedConfig, SeedRiskProfile};
use crate::terms::smooth::BlockwisePenalty;
use crate::types::{
Coefficients, GlmLikelihoodFamily, GlmLikelihoodSpec, InverseLink, LatentCLogLogState,
LikelihoodFamily, LikelihoodScaleMetadata, LinkFunction, LogLikelihoodNormalization,
LogSmoothingParamsView, MixtureLinkState, RidgePassport, SasLinkState,
};
use crate::types::{MixtureLinkSpec, SasLinkSpec};
use ndarray::{Array1, Array2, ArrayView1, Axis, s};
use crate::faer_ndarray::{
FaerArrayView, FaerCholesky, FaerEigh, FaerLinalgError, fast_ab, fast_atb,
};
use faer::{MatRef, Side};
use rayon::prelude::*;
use serde::{Deserialize, Serialize};
use std::ops::Range;
/// Specification of one quadratic penalty attached to a design matrix.
///
/// A penalty is either a small `local` matrix acting on a contiguous
/// `col_range` of coefficients (`Block`), or a full dense matrix over all
/// coefficients (`Dense`). Shape invariants are checked by
/// `validate_penalty_specs`.
#[derive(Clone)]
pub enum PenaltySpec {
    /// Penalty restricted to a contiguous block of coefficient columns.
    Block {
        // Square local penalty matrix with side `col_range.len()`.
        local: Array2<f64>,
        // Coefficient columns this penalty acts on.
        col_range: Range<usize>,
        // Optional structural hint consumed by downstream solvers.
        structure_hint: Option<crate::terms::smooth::PenaltyStructureHint>,
        // Optional matrix-free operator implementing the same penalty.
        op: Option<std::sync::Arc<dyn crate::terms::penalty_op::PenaltyOp>>,
    },
    /// Full dense penalty over all coefficients (p x p).
    Dense(Array2<f64>),
}
impl std::fmt::Debug for PenaltySpec {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
match self {
PenaltySpec::Block {
local,
col_range,
structure_hint,
op,
} => f
.debug_struct("Block")
.field(
"local",
&format_args!("{}×{}", local.nrows(), local.ncols()),
)
.field("col_range", col_range)
.field("structure_hint", structure_hint)
.field("op", &op.as_ref().map(|o| o.dim()))
.finish(),
PenaltySpec::Dense(m) => f
.debug_tuple("Dense")
.field(&format_args!("{}×{}", m.nrows(), m.ncols()))
.finish(),
}
}
}
impl PenaltySpec {
    /// Returns the global column range this penalty covers given `p` total
    /// coefficients; a `Dense` penalty spans all `p` columns.
    pub fn col_range(&self, p: usize) -> Range<usize> {
        match self {
            PenaltySpec::Block { col_range, .. } => col_range.clone(),
            PenaltySpec::Dense(m) => {
                debug_assert_eq!(m.ncols(), p);
                0..p
            }
        }
    }
    /// Optional matrix-free operator; only `Block` penalties can carry one.
    pub fn op(&self) -> Option<&std::sync::Arc<dyn crate::terms::penalty_op::PenaltyOp>> {
        match self {
            PenaltySpec::Block { op, .. } => op.as_ref(),
            PenaltySpec::Dense(_) => None,
        }
    }
    /// Converts a `BlockwisePenalty` by value, moving its fields.
    pub fn from_blockwise(bp: crate::terms::smooth::BlockwisePenalty) -> Self {
        PenaltySpec::Block {
            local: bp.local,
            col_range: bp.col_range,
            structure_hint: bp.structure_hint,
            op: bp.op,
        }
    }
    /// Converts a `BlockwisePenalty` by reference, cloning its fields.
    pub fn from_blockwise_ref(bp: &crate::terms::smooth::BlockwisePenalty) -> Self {
        PenaltySpec::Block {
            local: bp.local.clone(),
            col_range: bp.col_range.clone(),
            structure_hint: bp.structure_hint.clone(),
            op: bp.op.clone(),
        }
    }
    /// Embeds the penalty into a dense square matrix.
    ///
    /// For `Block`, the overall size is guessed as
    /// `max(col_range.end, local.nrows())` because the true total `p` is not
    /// known here; prefer `to_global` when the coefficient count is available.
    pub fn to_dense(&self) -> Array2<f64> {
        match self {
            PenaltySpec::Dense(m) => m.clone(),
            PenaltySpec::Block {
                local, col_range, ..
            } => {
                let p = col_range.end.max(local.nrows());
                let mut out = Array2::zeros((p, p));
                out.slice_mut(s![col_range.clone(), col_range.clone()])
                    .assign(local);
                out
            }
        }
    }
    /// Embeds the penalty into a dense `p_total x p_total` matrix, placing a
    /// `Block` penalty at its recorded column range.
    pub fn to_global(&self, p_total: usize) -> Array2<f64> {
        match self {
            PenaltySpec::Dense(m) => {
                debug_assert_eq!(m.nrows(), p_total);
                m.clone()
            }
            PenaltySpec::Block {
                local, col_range, ..
            } => {
                let mut out = Array2::zeros((p_total, p_total));
                out.slice_mut(s![col_range.clone(), col_range.clone()])
                    .assign(local);
                out
            }
        }
    }
}
/// Element-count threshold above which `faer_frob_inner` switches from a
/// naive accumulator to Kahan compensated summation.
const KAHAN_SWITCH_ELEMS: usize = 10_000;
/// Frobenius inner product `<A, B> = sum_ij A_ij * B_ij` of two equally
/// shaped matrices.
///
/// Small products use a plain accumulator; large ones (>= KAHAN_SWITCH_ELEMS
/// elements) use Kahan compensated summation to limit round-off. Both paths
/// traverse elements in the same column-major order, so results are
/// deterministic.
fn faer_frob_inner(a: MatRef<'_, f64>, b: MatRef<'_, f64>) -> f64 {
    let rows = a.nrows();
    let cols = a.ncols();
    let total = rows.saturating_mul(cols);
    if total < KAHAN_SWITCH_ELEMS {
        let mut acc = 0.0_f64;
        for col in 0..cols {
            for row in 0..rows {
                acc += a[(row, col)] * b[(row, col)];
            }
        }
        acc
    } else {
        let mut acc = KahanSum::default();
        for col in 0..cols {
            for row in 0..rows {
                acc.add(a[(row, col)] * b[(row, col)]);
            }
        }
        acc.sum()
    }
}
/// Sums an iterator of `f64` values with Kahan compensated summation.
fn kahan_sum<I>(iter: I) -> f64
where
    I: IntoIterator<Item = f64>,
{
    iter.into_iter()
        .fold(KahanSum::default(), |mut acc, value| {
            acc.add(value);
            acc
        })
        .sum()
}
/// Centering/scaling recorded for unpenalized ("parametric") design columns,
/// used to improve conditioning of the inner linear algebra.
#[derive(Clone, Debug)]
struct ParametricColumnConditioning {
    // Index of the detected all-ones intercept column among the unpenalized
    // columns, if any; centering is applied only when an intercept exists.
    intercept_idx: Option<usize>,
    // (column index, mean, scale) triples. Based on `backtransform_beta`,
    // each recorded column j is presumably replaced by (x_j - mean) / scale in
    // the conditioned design — confirm against `ConditionedDesign::new`.
    columns: Vec<(usize, f64, f64)>,
}
impl ParametricColumnConditioning {
    /// Scans the given unpenalized columns of `x` and records, for each
    /// non-constant column with usable variance, its mean and standard
    /// deviation for standardization.
    ///
    /// A constant column equal to 1 (within 1e-12) becomes the intercept;
    /// other constant columns are skipped. If no intercept is found, all
    /// recorded means are reset to zero so that only scaling is applied
    /// (centering without an intercept would change the model space).
    fn from_column_indices(x: &DesignMatrix, unpenalized_cols: &[usize]) -> Self {
        const SCALE_EPS: f64 = 1e-12;
        let n = x.nrows();
        if n == 0 {
            return Self {
                intercept_idx: None,
                columns: Vec::new(),
            };
        }
        let mut intercept_idx = None;
        let mut columns = Vec::new();
        for &j in unpenalized_cols {
            let col = x.extract_column(j);
            let first = col[0];
            let is_constant = col.iter().all(|&v| (v - first).abs() <= 1e-12);
            if is_constant {
                if (first - 1.0).abs() <= 1e-12 && intercept_idx.is_none() {
                    intercept_idx = Some(j);
                }
                continue;
            }
            let mean = col.iter().copied().sum::<f64>() / n as f64;
            // Population variance (divide by n): only the scale magnitude
            // matters here, not an unbiased estimate.
            let var = col
                .iter()
                .map(|&v| {
                    let d = v - mean;
                    d * d
                })
                .sum::<f64>()
                / n as f64;
            // Skip numerically degenerate columns rather than dividing by a
            // near-zero standard deviation.
            if !var.is_finite() || var <= SCALE_EPS * SCALE_EPS {
                continue;
            }
            columns.push((j, mean, var.sqrt()));
        }
        if intercept_idx.is_none() {
            for (_, mean, _) in &mut columns {
                *mean = 0.0;
            }
        }
        Self {
            intercept_idx,
            columns,
        }
    }
    /// Derives conditioning from penalty specs: any design column not covered
    /// by some penalty's `col_range` is treated as unpenalized/parametric.
    fn infer_from_penalty_specs(x: &DesignMatrix, specs: &[PenaltySpec]) -> Self {
        let p = x.ncols();
        let mut penalized = vec![false; p];
        for spec in specs {
            let range = spec.col_range(p);
            for j in range {
                penalized[j] = true;
            }
        }
        let unpenalized: Vec<usize> = (0..p).filter(|&j| !penalized[j]).collect();
        Self::from_column_indices(x, &unpenalized)
    }
    /// True when at least one column is actually conditioned.
    fn is_active(&self) -> bool {
        !self.columns.is_empty()
    }
    /// Wraps `x` in a `ConditionedDesign` applying the recorded column
    /// transforms; returns `x` unchanged when no conditioning is needed.
    fn apply_to_design(&self, x: &DesignMatrix) -> DesignMatrix {
        if !self.is_active() {
            return x.clone();
        }
        DesignMatrix::Dense(crate::matrix::DenseDesignMatrix::from(Arc::new(
            crate::matrix::ConditionedDesign::new(x.clone(), self.columns.clone()),
        )))
    }
    /// Rewrites a linear-constraint matrix from the original to the internal
    /// (conditioned) coefficient basis.
    ///
    /// NOTE(review): this uses the "+mean, *scale" (inverse-map) column
    /// operation, whereas `A beta_orig = A C beta_internal` with the back-map
    /// C of `backtransform_beta` would suggest "-mean, /scale". Confirm
    /// against the pirls constraint convention and `ConditionedDesign`.
    fn transform_constraint_matrix_to_internal(&self, a_original: &Array2<f64>) -> Array2<f64> {
        let mut out = a_original.clone();
        for &(j, mean, scale) in &self.columns {
            // Snapshot the intercept column; it is never itself conditioned,
            // so taking it inside the loop reads unmodified values.
            let intercept_col = self.intercept_idx.map(|idx| out.column(idx).to_owned());
            let mut target = out.column_mut(j);
            if mean != 0.0
                && let Some(intercept_col) = intercept_col
            {
                target += &(intercept_col * mean);
            }
            if scale != 1.0 {
                target.mapv_inplace(|v| v * scale);
            }
        }
        out
    }
    /// Applies `transform_constraint_matrix_to_internal` to the `a` matrix of
    /// optional linear inequality constraints, leaving `b` untouched.
    fn transform_linear_constraints_to_internal(
        &self,
        constraints: Option<crate::pirls::LinearInequalityConstraints>,
    ) -> Option<crate::pirls::LinearInequalityConstraints> {
        constraints.map(|constraints| crate::pirls::LinearInequalityConstraints {
            a: self.transform_constraint_matrix_to_internal(&constraints.a),
            b: constraints.b,
        })
    }
    /// Maps internal coefficients back to the original basis:
    /// beta_j = beta'_j / scale, with the intercept absorbing the centering
    /// term (beta_0 -= beta'_j * mean / scale); i.e. beta = C beta_internal
    /// for the standardization back-map C.
    fn backtransform_beta(&self, beta_internal: &Array1<f64>) -> Array1<f64> {
        let mut beta = beta_internal.clone();
        for &(j, mean, scale) in &self.columns {
            if let Some(intercept_idx) = self.intercept_idx {
                beta[intercept_idx] -= beta_internal[j] * mean / scale;
            }
            beta[j] = beta_internal[j] / scale;
        }
        beta
    }
    /// Out-of-place variant of `transform_matrix_columnswith_a_inplace`.
    fn transform_matrix_columnswith_a(&self, mat: &Array2<f64>) -> Array2<f64> {
        let mut out = mat.clone();
        self.transform_matrix_columnswith_a_inplace(&mut out);
        out
    }
    /// Applies the forward conditioning map to columns:
    /// col_j <- (col_j - mean * intercept_col) / scale, i.e. right-
    /// multiplication by the conditioning matrix. Used e.g. for design
    /// derivatives (X_tau) so they match the conditioned design.
    fn transform_matrix_columnswith_a_inplace(&self, mat: &mut Array2<f64>) {
        if !self.is_active() {
            return;
        }
        // Snapshot once: the intercept column is never among `self.columns`,
        // so it stays unmodified throughout the loop.
        let intercept_col = self.intercept_idx.map(|idx| mat.column(idx).to_owned());
        for &(j, mean, scale) in &self.columns {
            let mut target = mat.column_mut(j);
            if mean != 0.0
                && let Some(intercept_col) = intercept_col.as_ref()
            {
                target -= &(intercept_col * mean);
            }
            if scale != 1.0 {
                target.mapv_inplace(|v| v / scale);
            }
        }
    }
    /// Same "-mean, /scale" operation applied to rows (left-multiplication by
    /// the transpose of the conditioning matrix).
    fn transform_matrixrowswith_a_transpose(&self, mat: &Array2<f64>) -> Array2<f64> {
        let mut out = mat.clone();
        for &(j, mean, scale) in &self.columns {
            let interceptrow = self.intercept_idx.map(|idx| out.row(idx).to_owned());
            let mut target = out.row_mut(j);
            if mean != 0.0
                && let Some(interceptrow) = interceptrow
            {
                target -= &(interceptrow * mean);
            }
            if scale != 1.0 {
                target.mapv_inplace(|v| v / scale);
            }
        }
        out
    }
    /// Inverse map on columns: col_j <- (col_j + mean * intercept_col) * scale.
    fn transform_matrix_columnswith_b(&self, mat: &Array2<f64>) -> Array2<f64> {
        let mut out = mat.clone();
        for &(j, mean, scale) in &self.columns {
            let intercept_col = self.intercept_idx.map(|idx| out.column(idx).to_owned());
            let mut target = out.column_mut(j);
            if mean != 0.0
                && let Some(intercept_col) = intercept_col
            {
                target += &(intercept_col * mean);
            }
            if scale != 1.0 {
                target.mapv_inplace(|v| v * scale);
            }
        }
        out
    }
    /// Inverse map on rows: row_j <- (row_j + mean * intercept_row) * scale.
    fn transform_matrixrowswith_b_transpose(&self, mat: &Array2<f64>) -> Array2<f64> {
        let mut out = mat.clone();
        for &(j, mean, scale) in &self.columns {
            let interceptrow = self.intercept_idx.map(|idx| out.row(idx).to_owned());
            let mut target = out.row_mut(j);
            if mean != 0.0
                && let Some(interceptrow) = interceptrow
            {
                target += &(interceptrow * mean);
            }
            if scale != 1.0 {
                target.mapv_inplace(|v| v * scale);
            }
        }
        out
    }
    /// Maps a coefficient covariance from the internal to the original basis.
    ///
    /// NOTE(review): `backtransform_beta` implements beta = C beta_internal,
    /// which implies Cov_orig = C Cov_internal C^T; the column-then-row "a"
    /// operations here compose to C^T Cov C instead, which places the
    /// mean/scale cross-terms on the conditioned columns rather than on the
    /// intercept entry. Confirm the intended convention against
    /// `ConditionedDesign` before relying on intercept standard errors.
    fn backtransform_covariance(&self, cov_internal: &Array2<f64>) -> Array2<f64> {
        let right = self.transform_matrix_columnswith_a(cov_internal);
        self.transform_matrixrowswith_a_transpose(&right)
    }
    /// Maps the penalized Hessian back with the inverse ("b") congruence,
    /// the dual of the covariance map.
    fn backtransform_penalized_hessian(&self, h_internal: &Array2<f64>) -> Array2<f64> {
        let right = self.transform_matrix_columnswith_b(h_internal);
        self.transform_matrixrowswith_b_transpose(&right)
    }
    /// Rewrites an `ExternalOptimResult` produced in the internal basis into
    /// the caller's original basis: coefficients, Hessian, covariances,
    /// standard errors, bias correction, and smoothing correction.
    ///
    /// Internal-basis artifacts that have no meaning in the original basis
    /// (reparameterization Qs, constraint KKT diagnostics, PIRLS artifacts)
    /// are dropped rather than transformed.
    fn backtransform_external_result(
        &self,
        mut result: ExternalOptimResult,
    ) -> ExternalOptimResult {
        if !self.is_active() {
            return result;
        }
        result.beta = self.backtransform_beta(&result.beta);
        if let Some(inf) = result.inference.as_mut() {
            inf.penalized_hessian = self.backtransform_penalized_hessian(&inf.penalized_hessian);
            inf.beta_covariance = inf
                .beta_covariance
                .take()
                .map(|cov| self.backtransform_covariance(&cov));
            // Standard errors are re-derived from the transformed covariance.
            inf.beta_standard_errors = inf.beta_covariance.as_ref().map(se_from_covariance);
            inf.beta_covariance_corrected = inf
                .beta_covariance_corrected
                .take()
                .map(|cov| self.backtransform_covariance(&cov));
            inf.beta_standard_errors_corrected = inf
                .beta_covariance_corrected
                .as_ref()
                .map(se_from_covariance);
            inf.bias_correction_beta = inf
                .bias_correction_beta
                .take()
                .map(|b| self.backtransform_beta(&b));
            inf.smoothing_correction = inf
                .smoothing_correction
                .take()
                .map(|cov| self.backtransform_covariance(&cov));
            inf.reparam_qs = None;
        }
        result.constraint_kkt = None;
        result.artifacts = FitArtifacts {
            pirls: None,
            ..Default::default()
        };
        result
    }
}
/// Maps the transformed-basis penalized Hessian of a PIRLS fit back to the
/// original coefficient basis, H = Qs * H_t * Qs^T, then symmetrizes to wash
/// out round-off asymmetry.
fn map_hessian_to_original_basis(
    pirls: &crate::pirls::PirlsResult,
) -> Result<Array2<f64>, EstimationError> {
    let qs = &pirls.reparam_result.qs;
    let h_t = &pirls.penalized_hessian_transformed;
    // `left_dot_matrix(qs)` presumably computes Qs · H_t for the (possibly
    // structured) transformed Hessian — TODO confirm at its definition.
    let tmp = h_t.left_dot_matrix(qs);
    let mut h = tmp.dot(&qs.t());
    crate::families::custom_family::symmetrize_dense_in_place(&mut h);
    Ok(h)
}
/// Cap on the inner P-IRLS convergence tolerance: `RemlConfig::external`
/// takes `min(reml_tol, this)`, so the inner solve is always at least this
/// tight regardless of the outer REML tolerance.
pub(crate) const PIRLS_INNER_TOLERANCE_FLOOR: f64 = 1e-6;
/// Configuration shared by the outer REML optimization and the inner P-IRLS
/// solves.
#[derive(Clone)]
pub(crate) struct RemlConfig {
    // GLM likelihood specification (family plus link bookkeeping).
    likelihood: GlmLikelihoodSpec,
    // Inverse link actually evaluated (standard, latent cLogLog, mixture, SAS).
    link_kind: InverseLink,
    // Convergence tolerance for the inner P-IRLS loop.
    pirls_convergence_tolerance: f64,
    // Iteration cap for the inner P-IRLS loop.
    max_iterations: usize,
    // Convergence tolerance for the outer REML/LAML optimization.
    reml_convergence_tolerance: f64,
    // Whether Firth bias reduction is applied in the inner solves.
    firth_bias_reduction: bool,
}
impl RemlConfig {
    /// Builds the configuration used for external-design optimization.
    ///
    /// The inner P-IRLS tolerance is the tighter of `reml_tol` and
    /// `PIRLS_INNER_TOLERANCE_FLOOR`, so the inner solve is always at least
    /// as accurate as the outer criterion; the iteration cap defaults to 300.
    fn external(likelihood: GlmLikelihoodSpec, reml_tol: f64, firth_bias_reduction: bool) -> Self {
        let link_kind = InverseLink::Standard(likelihood.link_function());
        let base = Self {
            likelihood,
            link_kind,
            pirls_convergence_tolerance: reml_tol.min(PIRLS_INNER_TOLERANCE_FLOOR),
            max_iterations: 0,
            reml_convergence_tolerance: reml_tol,
            firth_bias_reduction,
        };
        base.with_max_iterations(300)
    }
    /// Builder-style override of the inner P-IRLS iteration cap.
    pub(crate) fn with_max_iterations(mut self, max_iterations: usize) -> Self {
        self.max_iterations = max_iterations;
        self
    }
    /// Standard link function underlying the configured inverse link.
    fn link_function(&self) -> LinkFunction {
        self.link_kind.link_function()
    }
    /// Projects this configuration onto the subset consumed by P-IRLS.
    fn as_pirls_config(&self) -> pirls::PirlsConfig {
        pirls::PirlsConfig {
            likelihood: self.likelihood,
            link_kind: self.link_kind.clone(),
            max_iterations: self.max_iterations,
            convergence_tolerance: self.pirls_convergence_tolerance,
            firth_bias_reduction: self.firth_bias_reduction,
            initial_lm_lambda: None,
        }
    }
}
// Retry budget for matrix factorizations; consumers are outside this chunk —
// presumably used with an escalating ridge. TODO confirm at use sites.
const MAX_FACTORIZATION_ATTEMPTS: usize = 4;
use std::collections::HashMap;
use std::sync::atomic::{AtomicUsize, Ordering};
use std::sync::{Arc, RwLock};
use thiserror::Error;
/// Ridge added (both relatively and absolutely) to the LAML Hessian in rho
/// before inversion (see `compute_smoothing_correction`).
const LAML_RIDGE: f64 = 1e-8;
/// Hard floor applied to `dp` values by `smooth_floor_dp`.
pub(crate) const DP_FLOOR: f64 = 1e-12;
/// Width of the softplus transition used when flooring `dp`.
const DP_FLOOR_SMOOTH_WIDTH: f64 = 1e-8;
/// Byte budget for the PIRLS result cache (256 MiB).
pub(crate) const PIRLS_CACHE_BYTE_BUDGET: usize = 256 * 1024 * 1024;
/// Box bound on log smoothing parameters rho, i.e. lambda in [e^-30, e^30].
pub(crate) const RHO_BOUND: f64 = 30.0;
// Soft prior nudging rho away from the RHO_BOUND box edges; consumers live
// outside this chunk — TODO confirm exact form at use sites.
const RHO_SOFT_PRIOR_WEIGHT: f64 = 1e-6;
const RHO_SOFT_PRIOR_SHARPNESS: f64 = 4.0;
// Eligibility limits for automatic cubature over the rho posterior
// (dimension caps, eigenvector count, target captured-variance fraction,
// coefficient-dimension cap, boundary margin). Consumers are outside this
// chunk — TODO confirm semantics at use sites.
const AUTO_CUBATURE_MAX_RHO_DIM: usize = 12;
const AUTO_CUBATURE_MAX_EIGENVECTORS: usize = 4;
const AUTO_CUBATURE_TARGET_VAR_FRAC: f64 = 0.95;
const AUTO_CUBATURE_MAX_BETA_DIM: usize = 1600;
const AUTO_CUBATURE_BOUNDARY_MARGIN: f64 = 2.0;
/// Smoothly floors `dp` at `DP_FLOOR` using a softplus of width
/// `DP_FLOOR_SMOOTH_WIDTH`.
///
/// Returns `(dp_c, d dp_c/d dp, d^2 dp_c/d dp^2)` where
/// `dp_c = DP_FLOOR + tau * softplus((dp - DP_FLOOR) / tau)`; the first
/// derivative is the logistic sigmoid of the scaled argument and the second
/// is `sigma * (1 - sigma) / tau`. Both softplus and sigmoid use
/// overflow-safe asymptotic forms for |argument| > 20.
pub(crate) fn smooth_floor_dp(dp: f64) -> (f64, f64, f64) {
    let tau = DP_FLOOR_SMOOTH_WIDTH.max(f64::EPSILON);
    let z = (dp - DP_FLOOR) / tau;
    // softplus(z) = ln(1 + e^z) with stable tails.
    let softplus = if z > 20.0 {
        z + (-z).exp()
    } else if z < -20.0 {
        z.exp()
    } else {
        (1.0 + z.exp()).ln()
    };
    // Numerically stable logistic sigmoid sigma(z).
    let sigma = if z >= 0.0 {
        1.0 / (1.0 + (-z).exp())
    } else {
        let e = z.exp();
        e / (1.0 + e)
    };
    let dp_c = DP_FLOOR + tau * softplus;
    let curvature = sigma * (1.0 - sigma) / tau;
    (dp_c, sigma, curvature)
}
/// Output of `compute_smoothing_correction`.
pub(crate) struct SmoothingCorrectionComputation {
    // Additive covariance term J H_rho^{-1} J^T (mapped to the original
    // basis) accounting for smoothing-parameter uncertainty; `None` when any
    // required factorization failed.
    pub correction: Option<Array2<f64>>,
    // The (symmetrized, ridged) LAML Hessian in rho, when it was computed.
    pub hessian_rho: Option<Array2<f64>>,
}
/// Inverts the (symmetrized, ridged) LAML Hessian in rho.
///
/// Fast path: Cholesky, solving against identity columns. If the matrix is
/// not positive definite, falls back to an eigendecomposition with
/// eigenvalues floored at `max(1e-10 * spectral_radius, LAML_RIDGE)`
/// ("spectral repair").
///
/// Returns `(inverse, repaired)` where `repaired` is true when the eigen
/// fallback was used; returns `None` if the eigendecomposition fails or
/// yields non-finite values.
fn invert_regularized_rho_hessian(hessian_rho: &Array2<f64>) -> Option<(Array2<f64>, bool)> {
    if let Ok(chol) = hessian_rho.cholesky(faer::Side::Lower) {
        let n = hessian_rho.nrows();
        let mut inverse = Array2::<f64>::eye(n);
        // Solve H x = e_col for each column of the identity.
        for col in 0..n {
            let colvec = inverse.column(col).to_owned();
            let solved = chol.solvevec(&colvec);
            inverse.column_mut(col).assign(&solved);
        }
        return Some((inverse, false));
    }
    let (eigenvalues, eigenvectors) = hessian_rho.eigh(faer::Side::Lower).ok()?;
    if eigenvalues.iter().any(|v| !v.is_finite()) || !eigenvectors.iter().all(|v| v.is_finite()) {
        return None;
    }
    let n = hessian_rho.nrows();
    // Spectral radius estimate (at least 1) scales the eigenvalue floor.
    let spectral_scale = eigenvalues
        .iter()
        .copied()
        .map(f64::abs)
        .fold(0.0_f64, f64::max)
        .max(1.0);
    let floor = (spectral_scale * 1e-10).max(LAML_RIDGE);
    // Rebuild the inverse as sum_i (1/max(lambda_i, floor)) v_i v_i^T.
    let mut inverse = Array2::<f64>::zeros((n, n));
    for i in 0..n {
        let lambda = eigenvalues[i].max(floor);
        let inv_lambda = 1.0 / lambda;
        let v = eigenvectors.column(i);
        for row in 0..n {
            for col in 0..n {
                inverse[[row, col]] += inv_lambda * v[row] * v[col];
            }
        }
    }
    Some((inverse, true))
}
/// Computes the smoothing-parameter-uncertainty correction to the coefficient
/// covariance, mapped to the original basis:
/// `V_corr = Qs (J H_rho^{-1} J^T) Qs^T`, where `J = d beta_trans / d rho`
/// comes from implicit differentiation of the penalized score
/// (`H_trans * d beta/d rho_k = -lambda_k S_k beta`) and `H_rho` is the
/// (symmetrized, ridged) LAML Hessian in rho.
///
/// Any factorization failure degrades gracefully to `correction: None`; a
/// correction with noticeably negative eigenvalues is projected onto the PSD
/// cone before being returned.
fn compute_smoothing_correction(
    reml_state: &RemlState<'_>,
    final_rho: &Array1<f64>,
    final_fit: &pirls::PirlsResult,
) -> SmoothingCorrectionComputation {
    use crate::faer_ndarray::{FaerCholesky, FaerEigh};
    let n_rho = final_rho.len();
    // No smoothing parameters means no smoothing uncertainty.
    if n_rho == 0 {
        return SmoothingCorrectionComputation {
            correction: None,
            hessian_rho: None,
        };
    }
    let n_coeffs_trans = final_fit.beta_transformed.len();
    let n_coeffs_orig = final_fit.reparam_result.qs.nrows();
    let lambdas: Array1<f64> = final_rho.mapv(f64::exp);
    // Inner Hessian in the transformed basis; fall back to the stabilized
    // PIRLS Hessian when the exact inner Hessian is unavailable.
    let h_trans = reml_state
        .objective_innerhessian(final_rho)
        .unwrap_or_else(|_| final_fit.stabilizedhessian_transformed.to_dense());
    let h_chol = match h_trans.cholesky(faer::Side::Lower) {
        Ok(c) => c,
        Err(_) => {
            log::warn!("Cholesky decomposition failed for smoothing correction; skipping.");
            return SmoothingCorrectionComputation {
                correction: None,
                hessian_rho: None,
            };
        }
    };
    let beta_trans = final_fit.beta_transformed.as_ref();
    let ct = &final_fit.reparam_result.canonical_transformed;
    // Build J column by column: solve H_trans * delta = -lambda_k S_k beta.
    let mut jacobian_trans = Array2::<f64>::zeros((n_coeffs_trans, n_rho));
    for k in 0..n_rho {
        // Penalties without a canonical transform contribute a zero column.
        if k >= ct.len() {
            continue;
        }
        let cp = &ct[k];
        if cp.rank() == 0 {
            continue;
        }
        let r = &cp.col_range;
        let beta_block = beta_trans.slice(s![r.start..r.end]);
        // S_k = root^T root, so S_k beta = root^T (root * beta_block),
        // nonzero only inside the penalty's column block.
        let r_beta = cp.root.dot(&beta_block);
        let mut s_k_beta = Array1::<f64>::zeros(n_coeffs_trans);
        for a in 0..cp.block_dim() {
            s_k_beta[r.start + a] = (0..cp.rank())
                .map(|row| cp.root[[row, a]] * r_beta[row])
                .sum::<f64>();
        }
        let rhs = s_k_beta.mapv(|v| -lambdas[k] * v);
        let delta = h_chol.solvevec(&rhs);
        jacobian_trans.column_mut(k).assign(&delta);
    }
    let mut hessian_rho = match reml_state.compute_lamlhessian_consistent(final_rho) {
        Ok(h) => h,
        Err(err) => {
            log::warn!(
                "LAML Hessian unavailable ({}); skipping smoothing correction.",
                err
            );
            return SmoothingCorrectionComputation {
                correction: None,
                hessian_rho: None,
            };
        }
    };
    // Symmetrize and ridge before inversion to tame round-off and mild
    // indefiniteness.
    enforce_symmetry(&mut hessian_rho);
    add_relative_diag_ridge(&mut hessian_rho, LAML_RIDGE, LAML_RIDGE);
    let (v_rho, repaired_hessian) = match invert_regularized_rho_hessian(&hessian_rho) {
        Some(inverse) => inverse,
        None => {
            log::warn!(
                "Failed to invert LAML Hessian for smoothing correction after spectral repair; skipping."
            );
            return SmoothingCorrectionComputation {
                correction: None,
                hessian_rho: Some(hessian_rho),
            };
        }
    };
    if repaired_hessian {
        log::debug!(
            "Projected indefinite LAML Hessian onto a positive spectrum before smoothing correction inversion."
        );
    }
    // Sandwich: V_corr_trans = J H_rho^{-1} J^T.
    let jv_rho = jacobian_trans.dot(&v_rho);
    let v_corr_trans = jv_rho.dot(&jacobian_trans.t());
    // Map to the original basis: Qs V Qs^T.
    let qs = &final_fit.reparam_result.qs;
    let qsv = qs.dot(&v_corr_trans);
    let v_corr_orig = qsv.dot(&qs.t());
    if !v_corr_orig.iter().all(|v| v.is_finite()) {
        log::warn!("Non-finite values in smoothing correction matrix; skipping.");
        return SmoothingCorrectionComputation {
            correction: None,
            hessian_rho: Some(hessian_rho),
        };
    }
    // Validate positive semidefiniteness; clamp negative eigenvalues to zero
    // when they are beyond round-off (-1e-10).
    match v_corr_orig.eigh(faer::Side::Lower) {
        Ok((eigenvalues, eigenvectors)) => {
            let min_eig = eigenvalues.iter().fold(f64::INFINITY, |a, &b| a.min(b));
            if min_eig < -1e-10 {
                log::debug!(
                    "Smoothing correction has negative eigenvalue {:.3e}; clamping to zero.",
                    min_eig
                );
                // Rebuild as sum_i max(lambda_i, 0) v_i v_i^T.
                let mut result = Array2::<f64>::zeros((n_coeffs_orig, n_coeffs_orig));
                for i in 0..n_coeffs_orig {
                    let eig = eigenvalues[i].max(0.0);
                    let v = eigenvectors.column(i);
                    for j in 0..n_coeffs_orig {
                        for k in 0..n_coeffs_orig {
                            result[[j, k]] += eig * v[j] * v[k];
                        }
                    }
                }
                return SmoothingCorrectionComputation {
                    correction: Some(result),
                    hessian_rho: Some(hessian_rho),
                };
            }
        }
        Err(_) => {
            // Validation failure only: keep the unvalidated correction below.
            log::warn!("Eigendecomposition failed for smoothing correction validation.");
        }
    }
    SmoothingCorrectionComputation {
        correction: Some(v_corr_orig),
        hessian_rho: Some(hessian_rho),
    }
}
/// Errors produced during model estimation (basis construction, inner P-IRLS,
/// outer REML optimization, and inference post-processing).
///
/// Display text comes from the `thiserror` `#[error]` attributes; `Debug` is
/// forwarded to `Display` in the impl below.
#[derive(Error)]
pub enum EstimationError {
    #[error("Underlying basis function generation failed: {0}")]
    BasisError(#[from] crate::basis::BasisError),
    #[error("A linear system solve failed. The penalized Hessian may be singular. Error: {0}")]
    LinearSystemSolveFailed(FaerLinalgError),
    #[error("Eigendecomposition failed: {0}")]
    EigendecompositionFailed(FaerLinalgError),
    #[error("Parameter constraint violation: {0}")]
    ParameterConstraintViolation(String),
    #[error(
        "The P-IRLS inner loop did not converge within {max_iterations} iterations. Last gradient norm was {last_change:.6e}."
    )]
    PirlsDidNotConverge {
        max_iterations: usize,
        last_change: f64,
    },
    #[error(
        "Perfect or quasi-perfect separation detected during model fitting at iteration {iteration}. \
        The model cannot converge because a predictor perfectly separates the binary outcomes. \
        (Diagnostic: max|eta| = {max_abs_eta:.2e})."
    )]
    PerfectSeparationDetected { iteration: usize, max_abs_eta: f64 },
    #[error(
        "Hessian matrix is not positive definite (minimum eigenvalue: {min_eigenvalue:.4e}). This indicates a numerical instability."
    )]
    HessianNotPositiveDefinite { min_eigenvalue: f64 },
    #[error("REML smoothing optimization failed to converge: {0}")]
    RemlOptimizationFailed(String),
    #[error("An internal error occurred during model layout or coefficient mapping: {0}")]
    LayoutError(String),
    // Carries a per-term coefficient breakdown so users can see where the
    // parameter budget went.
    #[error(
        "Model is over-parameterized: {num_coeffs} coefficients for {num_samples} samples.\n\n\
        Coefficient Breakdown:\n\
        - Intercept: {intercept_coeffs}\n\
        - Binary Main Effects: {binary_main_coeffs}\n\
        - Primary Smooth Effects: {primary_smooth_coeffs}\n\
        - Binary×Primary Interactions: {binary_primary_interaction_coeffs}\n\
        - Auxiliary Main Effects: {aux_main_coeffs}\n\
        - Auxiliary Interactions: {aux_interaction_coeffs}"
    )]
    ModelOverparameterized {
        num_coeffs: usize,
        num_samples: usize,
        intercept_coeffs: usize,
        binary_main_coeffs: usize,
        primary_smooth_coeffs: usize,
        aux_main_coeffs: usize,
        binary_primary_interaction_coeffs: usize,
        aux_interaction_coeffs: usize,
    },
    #[error(
        "Model is ill-conditioned with condition number {condition_number:.2e}. This typically occurs when the model is over-parameterized (too many knots relative to data points). Consider reducing the number of knots or increasing regularization."
    )]
    ModelIsIllConditioned { condition_number: f64 },
    #[error("Invalid input: {0}")]
    InvalidInput(String),
    #[error("Calibrator training failed: {0}")]
    CalibratorTrainingFailed(String),
    #[error("Invalid specification: {0}")]
    InvalidSpecification(String),
    #[error("Prediction error")]
    PredictionError,
}
/// `Debug` mirrors `Display` so `{:?}`, `unwrap`, and test output show the
/// human-readable `thiserror` message instead of a structural dump.
impl core::fmt::Debug for EstimationError {
    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
        // Delegate directly to the Display impl; this is the idiomatic
        // forwarding form and avoids re-formatting through `write!`.
        core::fmt::Display::fmt(self, f)
    }
}
/// Result of fitting a model over an externally supplied design matrix.
pub struct ExternalOptimResult {
    /// Coefficient estimates in the caller's original column basis.
    pub beta: Array1<f64>,
    /// Optimized smoothing parameters (lambda = exp(rho)).
    pub lambdas: Array1<f64>,
    /// Likelihood family the model was fitted under.
    pub likelihood_family: LikelihoodFamily,
    /// Scale metadata for the likelihood (e.g. dispersion handling).
    pub likelihood_scale: LikelihoodScaleMetadata,
    /// Normalization convention under which `log_likelihood` is reported.
    pub log_likelihood_normalization: LogLikelihoodNormalization,
    /// Log-likelihood at the optimum.
    pub log_likelihood: f64,
    // Scale/standard-deviation estimate — exact semantics depend on the
    // family; TODO confirm at the producer.
    pub standard_deviation: f64,
    /// Outer optimization iteration count.
    pub iterations: usize,
    /// Gradient norm of the outer objective at termination.
    pub finalgrad_norm: f64,
    /// Status reported by the final inner P-IRLS solve.
    pub pirls_status: crate::pirls::PirlsStatus,
    /// Model deviance at the optimum.
    pub deviance: f64,
    /// Stabilized penalty term beta^T S_lambda beta contribution.
    pub stable_penalty_term: f64,
    /// Largest |eta| seen (separation / overflow diagnostic).
    pub max_abs_eta: f64,
    /// KKT diagnostics when linear inequality constraints were active;
    /// cleared by `backtransform_external_result` (internal-basis quantity).
    pub constraint_kkt: Option<crate::pirls::ConstraintKktDiagnostics>,
    /// Auxiliary fit artifacts (e.g. PIRLS internals); cleared on
    /// backtransform.
    pub artifacts: FitArtifacts,
    /// Optional inference bundle (Hessian, covariances, standard errors).
    pub inference: Option<FitInference>,
    /// Value of the REML/LAML criterion at the optimum.
    pub reml_score: f64,
    /// Fitted link state (for adaptive links such as mixture/SAS).
    pub fitted_link: FittedLinkState,
}
/// Options controlling `optimize_external_design`-style fits.
///
/// Link-related fields are mutually constrained (validated by
/// `resolved_external_config`): `latent_cloglog` excludes `mixture_link` and
/// `sas_link`, and `mixture_link`/`sas_link` are mutually exclusive.
#[derive(Clone)]
pub struct ExternalOptimOptions {
    /// Requested likelihood family (must be expressible as a GLM).
    pub family: crate::types::LikelihoodFamily,
    /// Latent cLogLog link state; required iff family is BinomialLatentCLogLog.
    pub latent_cloglog: Option<LatentCLogLogState>,
    /// Blended (mixture) inverse-link specification.
    pub mixture_link: Option<MixtureLinkSpec>,
    /// Whether mixture-link parameters are optimized jointly.
    pub optimize_mixture: bool,
    /// SAS / Beta-Logistic link specification (defaulted for SAS families).
    pub sas_link: Option<SasLinkSpec>,
    /// Whether SAS-link parameters are optimized jointly.
    pub optimize_sas: bool,
    /// Whether to compute the post-fit inference bundle.
    pub compute_inference: bool,
    /// Outer iteration budget.
    pub max_iter: usize,
    /// Outer (REML) convergence tolerance; also bounds the inner tolerance.
    pub tol: f64,
    /// Penalty null-space dimensions, one per penalty, used for
    /// canonicalization.
    pub nullspace_dims: Vec<usize>,
    /// Optional linear inequality constraints in the original basis.
    pub linear_constraints: Option<crate::pirls::LinearInequalityConstraints>,
    /// Explicit Firth bias-reduction override; `None` means family default
    /// (off).
    pub firth_bias_reduction: Option<bool>,
    /// Optional lower bound on penalty shrinkage.
    pub penalty_shrinkage_floor: Option<f64>,
    /// Prior placed on log smoothing parameters rho.
    pub rho_prior: crate::types::RhoPrior,
    /// Optional Kronecker-structured penalty system.
    pub kronecker_penalty_system: Option<crate::smooth::KroneckerPenaltySystem>,
    /// Optional Kronecker-factored basis accompanying the penalty system.
    pub kronecker_factored: Option<crate::basis::KroneckerFactoredBasis>,
}
/// Resolves the requested likelihood family into a concrete GLM spec plus the
/// effective Firth flag.
///
/// Rejects `RoystonParmar` (survival models use dedicated training APIs), an
/// explicit Firth request on a family without Firth support, and any family
/// that is not expressible as a GLM.
fn resolve_external_family(
    family: crate::types::LikelihoodFamily,
    firth_override: Option<bool>,
) -> Result<(GlmLikelihoodSpec, bool), EstimationError> {
    if matches!(family, crate::types::LikelihoodFamily::RoystonParmar) {
        return Err(EstimationError::InvalidInput(
            "optimize_external_design does not support RoystonParmar; use survival training APIs"
                .to_string(),
        ));
    }
    let firth_supported = family.supports_firth();
    if firth_override == Some(true) && !firth_supported {
        return Err(EstimationError::InvalidInput(format!(
            "firth_bias_reduction is currently implemented only for {}; {} does not support it",
            crate::types::LikelihoodFamily::BinomialLogit.pretty_name(),
            family.pretty_name()
        )));
    }
    let glm_family = GlmLikelihoodFamily::try_from(family).map_err(|msg| {
        EstimationError::InvalidInput(format!(
            "optimize_external_design requires a GLM family; {msg}"
        ))
    })?;
    // Firth is active only when explicitly requested AND supported.
    let firth_active = firth_override.unwrap_or(false) && firth_supported;
    Ok((GlmLikelihoodSpec::canonical(glm_family), firth_active))
}
/// Supplies a neutral default SAS link spec (epsilon = 0, log-delta = 0) for
/// the SAS-style families (`BinomialSas`, `BinomialBetaLogistic`) when the
/// caller did not provide one; all other inputs pass through unchanged.
#[inline]
fn effective_sas_link_for_family(
    family: crate::types::LikelihoodFamily,
    sas_link: Option<SasLinkSpec>,
) -> Option<SasLinkSpec> {
    let family_needs_sas = matches!(
        family,
        crate::types::LikelihoodFamily::BinomialSas
            | crate::types::LikelihoodFamily::BinomialBetaLogistic
    );
    if family_needs_sas && sas_link.is_none() {
        Some(SasLinkSpec {
            initial_epsilon: 0.0,
            initial_log_delta: 0.0,
        })
    } else {
        sas_link
    }
}
/// Chooses the inverse-link implementation for an external fit, in priority
/// order: latent cLogLog state, blended mixture link, SAS (or Beta-Logistic)
/// link, and finally the family's standard link.
///
/// Mutual exclusion of these options is validated upstream by
/// `resolved_external_config`.
#[inline]
fn resolved_external_inverse_link(
    link: LinkFunction,
    latent_cloglog: Option<LatentCLogLogState>,
    mixture_link: Option<&MixtureLinkSpec>,
    sas_link: Option<SasLinkSpec>,
) -> Result<InverseLink, EstimationError> {
    if let Some(state) = latent_cloglog {
        return Ok(InverseLink::LatentCLogLog(state));
    }
    if let Some(spec) = mixture_link {
        return Ok(InverseLink::Mixture(state_fromspec(spec).map_err(|e| {
            EstimationError::InvalidInput(format!("invalid blended inverse link: {e}"))
        })?));
    }
    if let Some(spec) = sas_link {
        // A SAS spec under a Beta-Logistic link builds the Beta-Logistic
        // state; otherwise a plain SAS state.
        return Ok(match link {
            LinkFunction::BetaLogistic => {
                InverseLink::BetaLogistic(state_from_beta_logisticspec(spec).map_err(|e| {
                    EstimationError::InvalidInput(format!("invalid Beta-Logistic link: {e}"))
                })?)
            }
            _ => InverseLink::Sas(
                state_from_sasspec(spec)
                    .map_err(|e| EstimationError::InvalidInput(format!("invalid SAS link: {e}")))?,
            ),
        });
    }
    Ok(InverseLink::Standard(link))
}
/// Validates link-related options and builds the `RemlConfig` for an external
/// fit, returning it together with the effective SAS link spec (possibly
/// defaulted by `effective_sas_link_for_family`).
///
/// Enforced invariants:
/// - `latent_cloglog` excludes both `mixture_link` and `sas_link`;
/// - `mixture_link` and `sas_link` are mutually exclusive;
/// - `BinomialLatentCLogLog` requires `latent_cloglog`, and vice versa.
#[inline]
fn resolved_external_config(
    opts: &ExternalOptimOptions,
) -> Result<(RemlConfig, Option<SasLinkSpec>), EstimationError> {
    if opts.latent_cloglog.is_some() && (opts.mixture_link.is_some() || opts.sas_link.is_some()) {
        return Err(EstimationError::InvalidInput(
            "latent_cloglog cannot be combined with mixture_link or sas_link".to_string(),
        ));
    }
    if opts.mixture_link.is_some() && opts.sas_link.is_some() {
        return Err(EstimationError::InvalidInput(
            "mixture_link and sas_link are mutually exclusive".to_string(),
        ));
    }
    if matches!(
        opts.family,
        crate::types::LikelihoodFamily::BinomialLatentCLogLog
    ) && opts.latent_cloglog.is_none()
    {
        return Err(EstimationError::InvalidInput(
            "BinomialLatentCLogLog requires latent_cloglog state".to_string(),
        ));
    }
    if opts.latent_cloglog.is_some()
        && !matches!(
            opts.family,
            crate::types::LikelihoodFamily::BinomialLatentCLogLog
        )
    {
        return Err(EstimationError::InvalidInput(
            "latent_cloglog is only supported with BinomialLatentCLogLog".to_string(),
        ));
    }
    let effective_sas_link = effective_sas_link_for_family(opts.family, opts.sas_link);
    let (likelihood, firth_active) =
        resolve_external_family(opts.family, opts.firth_bias_reduction)?;
    let mut cfg = RemlConfig::external(likelihood, opts.tol, firth_active);
    // Replace the default standard link with the fully resolved inverse link.
    let link = likelihood.link_function();
    cfg.link_kind = resolved_external_inverse_link(
        link,
        opts.latent_cloglog,
        opts.mixture_link.as_ref(),
        effective_sas_link,
    )?;
    Ok((cfg, effective_sas_link))
}
/// Gate for exact directional hyper-parameter derivatives.
///
/// Currently a permissive no-op: every (link, firth, design-drift, context)
/// combination is accepted. The unused parameters keep call sites stable so
/// support restrictions can be reintroduced without signature churn.
#[inline]
fn ensure_exact_directional_hyper_supported(
    _: LinkFunction,
    _: bool,
    _: bool,
    _: &str,
) -> Result<(), EstimationError> {
    Ok(())
}
/// Checks shape consistency of each penalty spec against `p` design columns,
/// reporting the first violation with `context` prefixed to the message.
///
/// `Block` penalties must have a square `local` matrix with side
/// `col_range.len()` and a range contained in `0..p`; `Dense` penalties must
/// be exactly `p x p`.
fn validate_penalty_specs(
    specs: &[PenaltySpec],
    p: usize,
    context: &str,
) -> Result<(), EstimationError> {
    specs.iter().enumerate().try_for_each(|(idx, spec)| {
        match spec {
            PenaltySpec::Block {
                local, col_range, ..
            } => {
                let bd = col_range.len();
                if local.nrows() != bd || local.ncols() != bd {
                    return Err(EstimationError::InvalidInput(format!(
                        "{context}: block penalty {idx} local matrix must be {bd}x{bd}, got {}x{}",
                        local.nrows(),
                        local.ncols()
                    )));
                }
                if col_range.end > p {
                    return Err(EstimationError::InvalidInput(format!(
                        "{context}: block penalty {idx} col_range {}..{} exceeds p={p}",
                        col_range.start, col_range.end
                    )));
                }
            }
            PenaltySpec::Dense(m) => {
                if m.nrows() != p || m.ncols() != p {
                    return Err(EstimationError::InvalidInput(format!(
                        "{context}: dense penalty {idx} must be {p}x{p}, got {}x{}",
                        m.nrows(),
                        m.ncols()
                    )));
                }
            }
        }
        Ok(())
    })
}
/// Validates the shapes of directional hyper-parameter derivatives against
/// the design `x`, the canonical penalty count, and the joint parameter
/// vector `theta` (first `rho_dim` entries are rho; the remaining `psi_dim`
/// are hyper-parameters, one `DirectionalHyperParam` each).
///
/// Checks, per direction: penalty indices within bounds, first design
/// derivative X_tau of shape n x p, penalty first-derivative components
/// well-shaped, and — when present — second derivatives X_tau_tau /
/// S_tau_tau with one (optional) entry per hyper-parameter, each well-shaped.
fn validate_joint_hyper_direction_shapes(
    x: &DesignMatrix,
    canonical_len: usize,
    theta: &Array1<f64>,
    rho_dim: usize,
    hyper_dirs: &[DirectionalHyperParam],
) -> Result<(), EstimationError> {
    if rho_dim > theta.len() {
        return Err(EstimationError::InvalidInput(format!(
            "rho_dim {} exceeds theta dimension {}",
            rho_dim,
            theta.len()
        )));
    }
    let p = x.ncols();
    // Hyper-parameter count implied by theta's layout.
    let psi_dim = theta.len() - rho_dim;
    if hyper_dirs.len() != psi_dim {
        return Err(EstimationError::InvalidInput(format!(
            "joint hyper-gradient derivative count mismatch: psi_dim={}, hyper_dirs={}",
            psi_dim,
            hyper_dirs.len()
        )));
    }
    for (idx, hyper_dir) in hyper_dirs.iter().enumerate() {
        // Penalty-derivative components must reference existing penalties.
        for component in hyper_dir.penalty_first_components() {
            if component.penalty_index >= canonical_len {
                return Err(EstimationError::InvalidInput(format!(
                    "penalty_index for dir {idx} out of bounds: {} >= {}",
                    component.penalty_index, canonical_len
                )));
            }
        }
        if hyper_dir.x_tau_original.nrows() != x.nrows() || hyper_dir.x_tau_original.ncols() != p {
            return Err(EstimationError::InvalidInput(format!(
                "X_tau[{idx}] must be {}x{}, got {}x{}",
                x.nrows(),
                p,
                hyper_dir.x_tau_original.nrows(),
                hyper_dir.x_tau_original.ncols()
            )));
        }
        RemlState::validate_penalty_component_shapes(
            hyper_dir.penalty_first_components(),
            p,
            &format!("S_tau[{idx}]"),
        )?;
        // Optional second design derivatives: one (optional) matrix per
        // hyper-parameter, each n x p.
        if let Some(x2) = hyper_dir.x_tau_tau_original.as_ref() {
            if x2.len() != psi_dim {
                return Err(EstimationError::InvalidInput(format!(
                    "X_tau_tau[{idx}] length mismatch: expected {}, got {}",
                    psi_dim,
                    x2.len()
                )));
            }
            for (j, x_ij) in x2.iter().enumerate() {
                let Some(x_ij) = x_ij.as_ref() else {
                    continue;
                };
                if x_ij.nrows() != x.nrows() || x_ij.ncols() != p {
                    return Err(EstimationError::InvalidInput(format!(
                        "X_tau_tau[{idx}][{j}] must be {}x{}, got {}x{}",
                        x.nrows(),
                        p,
                        x_ij.nrows(),
                        x_ij.ncols()
                    )));
                }
            }
        }
        // Optional second penalty derivatives, validated per hyper-parameter.
        if let Some(s2) = hyper_dir.penaltysecond_componentrows() {
            if s2.len() != psi_dim {
                return Err(EstimationError::InvalidInput(format!(
                    "S_tau_tau[{idx}] length mismatch: expected {}, got {}",
                    psi_dim,
                    s2.len()
                )));
            }
            for (j, components) in s2.iter().enumerate() {
                let Some(components) = components.as_ref() else {
                    continue;
                };
                RemlState::validate_penalty_component_shapes(
                    components,
                    p,
                    &format!("S_tau_tau[{idx}][{j}]"),
                )?;
            }
        }
    }
    Ok(())
}
/// Evaluator for joint (rho, hyper-parameter) optimization over an external
/// design, owning the shared `RemlState` and the conditioning/configuration
/// needed to re-target it between evaluations.
pub(crate) struct ExternalJointHyperEvaluator<'a> {
    // Column conditioning inferred at construction; reused for every
    // subsequent design/derivative transform.
    conditioning: ParametricColumnConditioning,
    // Shared REML configuration (likelihood, link, tolerances).
    config: Arc<RemlConfig>,
    // Optional shrinkage floor re-applied after each surface reset.
    penalty_shrinkage_floor: Option<f64>,
    // Optional Kronecker penalty structure carried across resets.
    kronecker_penalty_system: Option<crate::smooth::KroneckerPenaltySystem>,
    kronecker_factored: Option<crate::basis::KroneckerFactoredBasis>,
    // The REML objective state evaluated at each (rho, psi) point.
    reml_state: RemlState<'a>,
}
impl<'a> ExternalJointHyperEvaluator<'a> {
/// Builds an evaluator for joint (rho, hyper-parameter) optimization over an
/// externally supplied design.
///
/// Validates row-count agreement of `y`/`w`/`x`/`offset` and penalty shapes,
/// canonicalizes the penalties, infers parametric-column conditioning from
/// penalty coverage, applies it to the design and any linear constraints,
/// and constructs the shared `RemlState` with the resolved link states and
/// optional Kronecker structure.
pub(crate) fn new(
    y: ArrayView1<'a, f64>,
    w: ArrayView1<'a, f64>,
    x: &DesignMatrix,
    offset: ArrayView1<'_, f64>,
    s_list: &[BlockwisePenalty],
    opts: &ExternalOptimOptions,
    context: &str,
) -> Result<Self, EstimationError> {
    if let Some(message) = row_mismatch_message(y.len(), w.len(), x.nrows(), offset.len()) {
        return Err(EstimationError::InvalidInput(message));
    }
    let p = x.ncols();
    let specs: Vec<PenaltySpec> = s_list.iter().map(PenaltySpec::from_blockwise_ref).collect();
    validate_penalty_specs(&specs, p, context)?;
    let (canonical, active_nullspace_dims) = crate::construction::canonicalize_penalty_specs(
        &specs,
        &opts.nullspace_dims,
        p,
        context,
    )?;
    // Columns not covered by any penalty get standardized for conditioning.
    let conditioning = ParametricColumnConditioning::infer_from_penalty_specs(x, &specs);
    let x_fit = conditioning.apply_to_design(x);
    // Constraints must live in the same (conditioned) basis as the design.
    let fit_linear_constraints =
        conditioning.transform_linear_constraints_to_internal(opts.linear_constraints.clone());
    let (config, _) = resolved_external_config(opts)?;
    let config = Arc::new(config);
    let mut reml_state = RemlState::newwith_offset_shared(
        y,
        x_fit,
        w,
        offset,
        Arc::new(canonical),
        p,
        Arc::clone(&config),
        Some(active_nullspace_dims.clone()),
        None,
        fit_linear_constraints.clone(),
    )?;
    reml_state.set_penalty_shrinkage_floor(opts.penalty_shrinkage_floor);
    reml_state.set_rho_prior(opts.rho_prior.clone());
    reml_state.set_link_states(
        config.link_kind.mixture_state().cloned(),
        config.link_kind.sas_state().copied(),
    );
    if let Some(kron) = opts.kronecker_penalty_system.clone() {
        reml_state.set_kronecker_penalty_system(kron);
    }
    if let Some(kf) = opts.kronecker_factored.clone() {
        reml_state.set_kronecker_factored(kf);
    }
    Ok(Self {
        conditioning,
        config,
        penalty_shrinkage_floor: opts.penalty_shrinkage_floor,
        kronecker_penalty_system: opts.kronecker_penalty_system.clone(),
        kronecker_factored: opts.kronecker_factored.clone(),
        reml_state,
    })
}
/// Re-targets the evaluator at a (possibly new) design/penalty configuration
/// for one evaluation of the joint objective, returning the conditioned
/// hyper directions.
///
/// Steps: validate penalty and direction shapes, canonicalize the penalties,
/// apply the stored column conditioning to the design, constraints, and the
/// hyper-direction design derivatives (X_tau and, when present, X_tau_tau),
/// gate exact directional support, reset the REML surface, and install the
/// warm-start coefficients.
fn prepare_eval_state(
    &mut self,
    x: &DesignMatrix,
    s_list: &[BlockwisePenalty],
    nullspace_dims: &[usize],
    linear_constraints: Option<crate::pirls::LinearInequalityConstraints>,
    theta: &Array1<f64>,
    rho_dim: usize,
    mut hyper_dirs: Vec<DirectionalHyperParam>,
    warm_start_beta: Option<ArrayView1<'_, f64>>,
    context: &str,
) -> Result<Vec<DirectionalHyperParam>, EstimationError> {
    let p = x.ncols();
    let specs: Vec<PenaltySpec> = s_list.iter().map(PenaltySpec::from_blockwise_ref).collect();
    validate_penalty_specs(&specs, p, context)?;
    let (canonical, active_nullspace_dims) =
        crate::construction::canonicalize_penalty_specs(&specs, nullspace_dims, p, context)?;
    validate_joint_hyper_direction_shapes(x, canonical.len(), theta, rho_dim, &hyper_dirs)?;
    let x_fit = self.conditioning.apply_to_design(x);
    let fit_linear_constraints = self
        .conditioning
        .transform_linear_constraints_to_internal(linear_constraints);
    // Design derivatives must be conditioned with the same column transform
    // as the design itself so d(eta)/d(tau) stays consistent.
    for dir in &mut hyper_dirs {
        let mut x_tau = dir.x_tau_dense();
        self.conditioning
            .transform_matrix_columnswith_a_inplace(&mut x_tau);
        dir.x_tau_original = crate::estimate::reml::HyperDesignDerivative::from(x_tau);
        if let Some(rows) = dir.x_tau_tau_original.as_mut() {
            for mat in rows.iter_mut().flatten() {
                let mut dense = mat.materialize();
                self.conditioning
                    .transform_matrix_columnswith_a_inplace(&mut dense);
                *mat = crate::estimate::reml::HyperDesignDerivative::from(dense);
            }
        }
    }
    // "Design drift": at least one direction perturbs the design matrix.
    let has_design_drift = hyper_dirs
        .iter()
        .any(|dir| dir.x_tau_original.any_nonzero());
    ensure_exact_directional_hyper_supported(
        self.config.link_function(),
        self.config.firth_bias_reduction,
        has_design_drift,
        context,
    )?;
    self.reml_state.reset_surface(
        x_fit,
        Arc::new(canonical),
        p,
        active_nullspace_dims,
        None,
        fit_linear_constraints,
        self.kronecker_penalty_system.clone(),
        self.kronecker_factored.clone(),
    )?;
    // The shrinkage floor is surface state and must be re-applied per reset.
    self.reml_state
        .set_penalty_shrinkage_floor(self.penalty_shrinkage_floor);
    self.reml_state.setwarm_start_original_beta(warm_start_beta);
    Ok(hyper_dirs)
}
/// Runs one joint (rho + hyper) outer evaluation at `theta` with the requested
/// derivative order, returning `(cost, gradient, hessian)`.
///
/// Before preparing any state, a memory/feasibility policy may demote a
/// `ValueGradientHessian` request to `ValueAndGradient` when the exact
/// tau-tau Hessian plan would exceed its byte budget (or relies on implicit
/// terms the exact path cannot serve); the demotion is logged with the
/// policy's size estimates.
pub(crate) fn evaluate_with_order(
    &mut self,
    x: &DesignMatrix,
    s_list: &[BlockwisePenalty],
    nullspace_dims: &[usize],
    linear_constraints: Option<crate::pirls::LinearInequalityConstraints>,
    theta: &Array1<f64>,
    rho_dim: usize,
    hyper_dirs: Vec<DirectionalHyperParam>,
    warm_start_beta: Option<ArrayView1<'_, f64>>,
    context: &str,
    order: crate::solver::outer_strategy::OuterEvalOrder,
) -> Result<
    (
        f64,
        Array1<f64>,
        crate::solver::outer_strategy::HessianResult,
    ),
    EstimationError,
> {
    // Possibly downgrade the requested order; only the full-Hessian request
    // is subject to the tau-tau policy check.
    let order = if matches!(
        order,
        crate::solver::outer_strategy::OuterEvalOrder::ValueGradientHessian
    ) {
        // Firth pair terms are always considered available on this path.
        let firth_pair_terms_unavailable = false;
        let tau_tau_policy = crate::estimate::reml::exact_tau_tau_hessian_policy_with_firth(
            x.nrows(),
            x.ncols(),
            &hyper_dirs,
            firth_pair_terms_unavailable,
        );
        if tau_tau_policy.prefer_gradient_only() {
            log::warn!(
                "[OUTER] disabling exact tau Hessian before conditioning; using gradient-only outer eval \
(n={}, p={}, psi_dim={}, implicit_tau={}, implicit_multidim_duchon={}, firth_pair_gap={}, dense_tau_cache={:.1} MiB, gradient_plan={:.1} MiB, exact_hessian_plan={:.1} MiB, budget={:.1} MiB)",
                x.nrows(),
                x.ncols(),
                hyper_dirs.len(),
                tau_tau_policy.any_has_implicit,
                tau_tau_policy.implicit_multidim_duchon,
                tau_tau_policy.firth_pair_terms_unavailable,
                tau_tau_policy.estimated_dense_tau_cache_bytes as f64 / (1024.0 * 1024.0),
                tau_tau_policy.gradient_plan.total_bytes() as f64 / (1024.0 * 1024.0),
                tau_tau_policy.hessian_plan.total_bytes() as f64 / (1024.0 * 1024.0),
                tau_tau_policy.budget_bytes as f64 / (1024.0 * 1024.0),
            );
            crate::solver::outer_strategy::OuterEvalOrder::ValueAndGradient
        } else {
            order
        }
    } else {
        order
    };
    // Stage the REML surface and conditioned hyper-directions, then evaluate.
    let hyper_dirs = self.prepare_eval_state(
        x,
        s_list,
        nullspace_dims,
        linear_constraints,
        theta,
        rho_dim,
        hyper_dirs,
        warm_start_beta,
        context,
    )?;
    self.reml_state
        .compute_joint_hyper_eval_with_order(theta, rho_dim, &hyper_dirs, order)
}
/// Computes EFS (extended Fellner–Schall) steps at `theta`, including the
/// extended psi directions carried by `hyper_dirs`.
///
/// The evaluation surface is staged exactly as in `evaluate_with_order`
/// (canonicalized penalties, conditioned design and derivatives), after
/// which only the leading `rho_dim` entries of `theta` — the smoothing
/// log-parameters — are forwarded to the EFS step computation.
pub(crate) fn evaluate_efs(
    &mut self,
    x: &DesignMatrix,
    s_list: &[BlockwisePenalty],
    nullspace_dims: &[usize],
    linear_constraints: Option<crate::pirls::LinearInequalityConstraints>,
    theta: &Array1<f64>,
    rho_dim: usize,
    hyper_dirs: Vec<DirectionalHyperParam>,
    warm_start_beta: Option<ArrayView1<'_, f64>>,
    context: &str,
) -> Result<crate::solver::outer_strategy::EfsEval, EstimationError> {
    // Stage the surface; this also maps the hyper-direction derivatives
    // into the conditioned basis.
    let prepared_dirs = self.prepare_eval_state(
        x,
        s_list,
        nullspace_dims,
        linear_constraints,
        theta,
        rho_dim,
        hyper_dirs,
        warm_start_beta,
        context,
    )?;
    // Smoothing block of theta only; trailing entries are psi/link params.
    let smoothing_rho = theta.slice(s![..rho_dim]).to_owned();
    self.reml_state
        .compute_efs_steps_with_psi_ext(&smoothing_rho, &prepared_dirs)
}
/// Cost-only variant of `prepare_eval_state`: stages the REML surface for a
/// pure objective evaluation, with no hyper-direction handling.
///
/// Performs the same penalty validation/canonicalization, design
/// conditioning, constraint transformation, and surface reset, then restores
/// the shrinkage floor and warm-start coefficients (both are cleared by
/// `reset_surface`).
fn prepare_eval_state_cost_only(
    &mut self,
    x: &DesignMatrix,
    s_list: &[BlockwisePenalty],
    nullspace_dims: &[usize],
    linear_constraints: Option<crate::pirls::LinearInequalityConstraints>,
    warm_start_beta: Option<ArrayView1<'_, f64>>,
    context: &str,
) -> Result<(), EstimationError> {
    let p = x.ncols();
    let specs: Vec<PenaltySpec> = s_list.iter().map(PenaltySpec::from_blockwise_ref).collect();
    validate_penalty_specs(&specs, p, context)?;
    let (canonical, active_nullspace_dims) =
        crate::construction::canonicalize_penalty_specs(&specs, nullspace_dims, p, context)?;
    let x_fit = self.conditioning.apply_to_design(x);
    let fit_linear_constraints = self
        .conditioning
        .transform_linear_constraints_to_internal(linear_constraints);
    self.reml_state.reset_surface(
        x_fit,
        Arc::new(canonical),
        p,
        active_nullspace_dims,
        None,
        fit_linear_constraints,
        self.kronecker_penalty_system.clone(),
        self.kronecker_factored.clone(),
    )?;
    // Re-apply per-fit knobs cleared by the reset.
    self.reml_state
        .set_penalty_shrinkage_floor(self.penalty_shrinkage_floor);
    self.reml_state.setwarm_start_original_beta(warm_start_beta);
    Ok(())
}
/// Evaluates only the outer objective (cost) at `theta`, without gradients.
///
/// Guards against `rho_dim` exceeding `theta`'s length, stages the
/// cost-only evaluation surface, and computes the cost on the leading
/// `rho_dim` smoothing log-parameters.
pub(crate) fn evaluate_cost_only(
    &mut self,
    x: &DesignMatrix,
    s_list: &[BlockwisePenalty],
    nullspace_dims: &[usize],
    linear_constraints: Option<crate::pirls::LinearInequalityConstraints>,
    theta: &Array1<f64>,
    rho_dim: usize,
    warm_start_beta: Option<ArrayView1<'_, f64>>,
    context: &str,
) -> Result<f64, EstimationError> {
    // The slice below would panic if rho_dim overran theta; fail loudly first.
    let theta_len = theta.len();
    if theta_len < rho_dim {
        return Err(EstimationError::InvalidInput(format!(
            "rho_dim {} exceeds theta dimension {}",
            rho_dim, theta_len
        )));
    }
    self.prepare_eval_state_cost_only(
        x,
        s_list,
        nullspace_dims,
        linear_constraints,
        warm_start_beta,
        context,
    )?;
    let rho_head = theta.slice(s![..rho_dim]).to_owned();
    self.reml_state.compute_cost(&rho_head)
}
}
/// Fits a model for an externally supplied design matrix and penalty list.
///
/// Thin convenience wrapper: forwards to
/// `optimize_external_designwith_heuristic_lambdas` with no heuristic
/// smoothing-parameter seeds. See that function for parameter semantics.
pub fn optimize_external_design<X>(
    y: ArrayView1<'_, f64>,
    w: ArrayView1<'_, f64>,
    x: X,
    offset: ArrayView1<'_, f64>,
    s_list: Vec<BlockwisePenalty>,
    opts: &ExternalOptimOptions,
) -> Result<ExternalOptimResult, EstimationError>
where
    X: Into<DesignMatrix>,
{
    optimize_external_designwith_heuristic_lambdas(y, w, x, offset, s_list, None, opts)
}
/// Fits a model for an external design, optionally seeding the outer
/// optimization with heuristic lambda values.
///
/// Lowers the blockwise penalties into generic `PenaltySpec`s and delegates
/// to the full entry point with no warm-start coefficients.
pub fn optimize_external_designwith_heuristic_lambdas<X>(
    y: ArrayView1<'_, f64>,
    w: ArrayView1<'_, f64>,
    x: X,
    offset: ArrayView1<'_, f64>,
    s_list: Vec<BlockwisePenalty>,
    heuristic_lambdas: Option<&[f64]>,
    opts: &ExternalOptimOptions,
) -> Result<ExternalOptimResult, EstimationError>
where
    X: Into<DesignMatrix>,
{
    let mut specs = Vec::with_capacity(s_list.len());
    for block in s_list {
        specs.push(PenaltySpec::from_blockwise(block));
    }
    optimize_external_designwith_heuristic_lambdas_andwarm_start(
        y,
        w,
        x,
        offset,
        specs,
        heuristic_lambdas,
        None,
        opts,
    )
}
/// Core driver for fitting an externally supplied design.
///
/// Pipeline:
/// 1. Validate inputs and canonicalize penalties; condition the design.
/// 2. Build a shared `RemlState` and seed configuration.
/// 3. Run the outer optimization — either the standard REML arm (no
///    auxiliary link parameters) or the mixture/SAS flexible-link arm,
///    which appends auxiliary link parameters to theta.
/// 4. Refit PIRLS once at the converged smoothing parameters.
/// 5. Optionally compute inference quantities (EDF, posterior covariance,
///    smoothing correction, bias correction) and assemble the result.
fn optimize_external_designwith_heuristic_lambdas_andwarm_start<X>(
    y: ArrayView1<'_, f64>,
    w: ArrayView1<'_, f64>,
    x: X,
    offset: ArrayView1<'_, f64>,
    s_list: Vec<PenaltySpec>,
    heuristic_lambdas: Option<&[f64]>,
    warm_start_beta: Option<ArrayView1<'_, f64>>,
    opts: &ExternalOptimOptions,
) -> Result<ExternalOptimResult, EstimationError>
where
    X: Into<DesignMatrix>,
{
    // A BinomialMixture family is meaningless without a mixture link spec.
    if matches!(opts.family, crate::types::LikelihoodFamily::BinomialMixture)
        && opts.mixture_link.is_none()
    {
        return Err(EstimationError::InvalidInput(
            "BinomialMixture requires mixture_link specification".to_string(),
        ));
    }
    // --- Shape validation: all row-shaped inputs must agree with X. ---
    let x = x.into();
    if let Some(message) = row_mismatch_message(y.len(), w.len(), x.nrows(), offset.len()) {
        return Err(EstimationError::InvalidInput(message));
    }
    let p = x.ncols();
    validate_penalty_specs(&s_list, p, "optimize_external_design")?;
    // Canonicalize penalties and recover per-penalty nullspace dimensions.
    let (canonical, active_nullspace_dims) = crate::construction::canonicalize_penalty_specs(
        &s_list,
        &opts.nullspace_dims,
        p,
        "optimize_external_design",
    )?;
    // Column conditioning of the design; constraints must be expressed in the
    // same internal (conditioned) basis as the design used for fitting.
    let conditioning = ParametricColumnConditioning::infer_from_penalty_specs(&x, &s_list);
    let x_fit = conditioning.apply_to_design(&x);
    let fit_linear_constraints =
        conditioning.transform_linear_constraints_to_internal(opts.linear_constraints.clone());
    let k = canonical.len();
    if active_nullspace_dims.len() != k {
        return Err(EstimationError::InvalidInput(format!(
            "nullspace_dims length mismatch: expected {k} entries for active penalties, got {}",
            active_nullspace_dims.len()
        )));
    }
    let (cfg, effective_sas_link) = resolved_external_config(opts)?;
    let design_kind = match &x {
        DesignMatrix::Dense(_) => "dense",
        DesignMatrix::Sparse(_) => "sparse",
    };
    log::info!(
        "[GAM fit] n={} p={} k={} fam={:?} link={:?} X={} reml_iter={} firth={}",
        y.len(),
        p,
        k,
        opts.family,
        cfg.link_function(),
        design_kind,
        opts.max_iter,
        cfg.firth_bias_reduction
    );
    // Own the inputs so the REML state can hold views with a stable lifetime.
    let y_o = y.to_owned();
    let w_o = w.to_owned();
    let x_o = x;
    let offset_o = offset.to_owned();
    let canonical_shared = Arc::new(canonical);
    let cfg_shared = Arc::new(cfg.clone());
    let mut reml_state = RemlState::newwith_offset_shared(
        y_o.view(),
        x_fit,
        w_o.view(),
        offset_o.view(),
        Arc::clone(&canonical_shared),
        p,
        Arc::clone(&cfg_shared),
        Some(active_nullspace_dims.clone()),
        None,
        fit_linear_constraints.clone(),
    )?;
    reml_state.set_penalty_shrinkage_floor(opts.penalty_shrinkage_floor);
    reml_state.set_rho_prior(opts.rho_prior.clone());
    if let Some(kron) = opts.kronecker_penalty_system.clone() {
        reml_state.set_kronecker_penalty_system(kron);
    }
    if let Some(kf) = opts.kronecker_factored.clone() {
        reml_state.set_kronecker_factored(kf);
    }
    reml_state.setwarm_start_original_beta(warm_start_beta);
    // Seed budget scales with the number of penalties; risk profile follows
    // the link (identity => Gaussian seeding heuristics).
    let reml_seed_config = SeedConfig {
        bounds: (-12.0, 12.0),
        max_seeds: if k <= 4 {
            6
        } else if k <= 12 {
            8
        } else {
            10
        },
        seed_budget: if k <= 6 { 1 } else { 2 },
        risk_profile: if matches!(cfg.link_function(), LinkFunction::Identity) {
            SeedRiskProfile::Gaussian
        } else {
            SeedRiskProfile::GeneralizedLinear
        },
        screen_max_inner_iterations: SeedConfig::default().screen_max_inner_iterations,
        num_auxiliary_trailing: 0,
    };
    let reml_tol = cfg.reml_convergence_tolerance;
    let reml_max_iter = opts.max_iter;
    let outer_eval_idx = AtomicUsize::new(0usize);
    // Auxiliary link-parameter dimensions: mixture rho block and/or the two
    // SAS parameters (epsilon, log-delta). Simultaneous use is rejected below.
    let mixture_optspec = if opts.optimize_mixture {
        opts.mixture_link.clone()
    } else {
        None
    };
    let sas_optspec = if opts.optimize_sas {
        effective_sas_link
    } else {
        None
    };
    let mixture_dim = mixture_optspec
        .as_ref()
        .map(|s| s.initial_rho.len())
        .unwrap_or(0);
    let sas_dim = if sas_optspec.is_some() { 2 } else { 0 };
    let sasridgeweight = if sas_dim > 0 {
        sas_log_deltaridgeweight()
    } else {
        0.0
    };
    // --- Outer optimization: one of two arms, yielding the converged rho,
    // final link states, optional auxiliary covariances, and the strategy
    // result. ---
    let (
        final_rho,
        final_mixture_state,
        final_sas_state,
        final_mixture_param_covariance,
        final_sas_param_covariance,
        outer_result,
    ) = if mixture_dim > 0 && sas_dim > 0 {
        return Err(EstimationError::InvalidInput(
            "simultaneous mixture and SAS optimization is not supported".to_string(),
        ));
    } else if mixture_dim == 0 && sas_dim == 0 {
        // === Standard REML arm: theta == rho, no auxiliary parameters. ===
        use crate::solver::outer_strategy::{
            DeclaredHessianForm, Derivative, InnerProgressFeedback, OuterEvalOrder, OuterProblem,
        };
        let analytic_outer_hessian_available = reml_state.analytic_outer_hessian_enabled();
        let problem = OuterProblem::new(k)
            .with_gradient(Derivative::Analytic)
            .with_hessian(if analytic_outer_hessian_available {
                DeclaredHessianForm::Either
            } else {
                DeclaredHessianForm::Unavailable
            })
            .with_barrier(self::reml::unified::BarrierConfig::from_constraints(
                fit_linear_constraints.as_ref(),
            ))
            .with_tolerance(reml_tol)
            .with_max_iter(reml_max_iter)
            .with_seed_config(reml_seed_config.clone())
            .with_screening_cap(Arc::clone(&reml_state.screening_max_inner_iterations))
            .with_outer_inner_cap(InnerProgressFeedback {
                cap: Arc::clone(&reml_state.outer_inner_cap),
                accepted_iter: Arc::new(AtomicUsize::new(0)),
                last_iters: Arc::clone(&reml_state.last_inner_iters),
                last_converged: Arc::clone(&reml_state.last_inner_converged),
                ift_residual: Arc::clone(&reml_state.last_ift_prediction_residual),
                accept_rho: Arc::clone(&reml_state.last_pirls_accept_rho),
            })
            .with_rho_bound(crate::estimate::RHO_BOUND);
        let problem = if let Some(ref h) = heuristic_lambdas {
            problem.with_heuristic_lambdas(h.to_vec())
        } else {
            problem
        };
        // Objective-grid pre-pass: start from heuristic log-lambdas (risk-
        // shifted and clamped to bounds) and refine on a coarse grid; only a
        // seed that actually moved is passed to the problem.
        let prepass_seed: Option<Array1<f64>> = {
            let bnds = reml_seed_config.bounds;
            let (lo, hi) = if bnds.0 <= bnds.1 {
                bnds
            } else {
                (bnds.1, bnds.0)
            };
            let risk_shift = match reml_seed_config.risk_profile {
                SeedRiskProfile::Gaussian => 0.0,
                SeedRiskProfile::GeneralizedLinear => 1.0,
                SeedRiskProfile::Survival => 2.0,
            };
            let base = if let Some(h) = heuristic_lambdas.as_ref().filter(|h| h.len() == k) {
                Array1::from_iter(h.iter().map(|&v| {
                    // Floor lambdas before taking logs to avoid ln(0)/-inf.
                    let r = v.max(1e-12).ln();
                    (r + risk_shift).clamp(lo, hi)
                }))
            } else {
                Array1::from_elem(k, risk_shift.clamp(lo, hi))
            };
            let refined = crate::seeding::select_objective_seed_on_log_lambda_grid(
                &base,
                (lo, hi),
                k,
                |rho| reml_state.compute_cost(rho).ok().filter(|c| c.is_finite()),
            );
            if refined
                .iter()
                .zip(base.iter())
                .any(|(&a, &b)| (a - b).abs() > 1e-12)
            {
                log::info!(
                    "[OUTER] standard REML objective-grid selected seed: {:?} -> {:?}",
                    base.as_slice().unwrap_or(&[]),
                    refined.as_slice().unwrap_or(&[])
                );
                Some(refined)
            } else {
                None
            }
        };
        let problem = if let Some(seed) = prepass_seed {
            problem.with_initial_rho(seed)
        } else {
            problem
        };
        // Closures: cost-only, full-order eval, order-selected eval, seed
        // reset, EFS steps, and a cheap screening proxy.
        let mut obj = problem.build_objective_with_screening_proxy(
            &mut reml_state,
            |state: &mut &mut self::reml::RemlState<'_>, rho: &Array1<f64>| state.compute_cost(rho),
            |state: &mut &mut self::reml::RemlState<'_>, rho: &Array1<f64>| {
                outer_eval_idx.fetch_add(1, Ordering::Relaxed);
                state.compute_outer_eval_with_order(
                    rho,
                    if analytic_outer_hessian_available {
                        OuterEvalOrder::ValueGradientHessian
                    } else {
                        OuterEvalOrder::ValueAndGradient
                    },
                )
            },
            |state: &mut &mut self::reml::RemlState<'_>,
             rho: &Array1<f64>,
             order: OuterEvalOrder| {
                outer_eval_idx.fetch_add(1, Ordering::Relaxed);
                state.compute_outer_eval_with_order(rho, order)
            },
            Some(|state: &mut &mut self::reml::RemlState<'_>| state.reset_outer_seed_state()),
            Some(
                |state: &mut &mut self::reml::RemlState<'_>, rho: &Array1<f64>| {
                    state.compute_efs_steps(rho)
                },
            ),
            |state: &mut &mut self::reml::RemlState<'_>, rho: &Array1<f64>| {
                state.compute_screening_proxy(rho)
            },
        );
        let strategy_result = problem.run(&mut obj, "standard REML")?;
        // Convergence guard: if an inner-iteration cap was ever in force,
        // re-evaluate once at the converged rho with the cap lifted (swap(0)).
        let prev_cap = reml_state
            .outer_inner_cap
            .swap(0, std::sync::atomic::Ordering::Relaxed);
        if prev_cap != 0 {
            let guard_start = std::time::Instant::now();
            let _ = reml_state.compute_cost(&strategy_result.rho);
            log::info!(
                "[OUTER guard] convergence-guard re-eval at converged ρ done (prev_cap={prev_cap}, elapsed={:.3}s)",
                guard_start.elapsed().as_secs_f64()
            );
        } else {
            log::debug!("[OUTER guard] schedule never lifted (prev_cap=0); skipping refit");
        }
        (
            strategy_result.rho.clone(),
            cfg.link_kind.mixture_state().cloned(),
            cfg.link_kind.sas_state().copied(),
            None,
            None,
            strategy_result,
        )
    } else {
        // === Mixture/SAS flexible-link arm: theta = [rho | aux params]. ===
        let use_mixture = mixture_dim > 0;
        let use_sas = sas_dim > 0;
        let use_beta_logistic =
            use_sas && matches!(cfg.link_function(), LinkFunction::BetaLogistic);
        let theta_dim = k + mixture_dim + sas_dim;
        let sasspec = sas_optspec;
        // When not optimizing a mixture, an empty placeholder spec is used so
        // the closures below can capture a single `mixspec` unconditionally.
        // NOTE(review): the `use_mixture => None` branch of `or_else` is
        // unreachable in practice (mixture_optspec is Some when use_mixture),
        // so the "missing mixture spec" error should never fire.
        let mixspec = mixture_optspec
            .clone()
            .or_else(|| {
                if use_mixture {
                    None
                } else {
                    Some(MixtureLinkSpec {
                        components: Vec::new(),
                        initial_rho: Array1::zeros(0),
                    })
                }
            })
            .ok_or_else(|| EstimationError::InvalidInput("missing mixture spec".to_string()))?;
        // Extend heuristic lambdas with the auxiliary parameters' initial
        // values so the heuristic covers the full theta vector.
        let mut heuristic_theta = Vec::new();
        if let Some(hvals) = heuristic_lambdas {
            if hvals.len() == k {
                heuristic_theta.extend_from_slice(hvals);
                if use_mixture {
                    heuristic_theta
                        .extend_from_slice(mixspec.initial_rho.as_slice().unwrap_or(&[]));
                }
                if let Some(spec) = sasspec {
                    heuristic_theta.push(spec.initial_epsilon);
                    heuristic_theta.push(spec.initial_log_delta);
                }
            }
        }
        let heuristic_theta_ref = if heuristic_theta.len() == theta_dim {
            Some(heuristic_theta.as_slice())
        } else {
            None
        };
        let aux_dim_outer = if use_mixture { mixture_dim } else { sas_dim };
        let mut reml_seed_config_mix = reml_seed_config.clone();
        reml_seed_config_mix.num_auxiliary_trailing = aux_dim_outer;
        use crate::solver::outer_strategy::{
            DeclaredHessianForm, Derivative, HessianResult, InnerProgressFeedback, OuterEval,
            OuterProblem,
        };
        // Kept so the reset closure can restore the pre-optimization link.
        let initial_link_kind = cfg.link_kind.clone();
        let problem = OuterProblem::new(theta_dim)
            .with_gradient(Derivative::Analytic)
            .with_hessian(DeclaredHessianForm::Either)
            .with_psi_dim(mixture_dim + sas_dim)
            .with_barrier(self::reml::unified::BarrierConfig::from_constraints(
                fit_linear_constraints.as_ref(),
            ))
            .with_tolerance(reml_tol)
            .with_max_iter(reml_max_iter)
            .with_seed_config(reml_seed_config_mix.clone())
            .with_screening_cap(Arc::clone(&reml_state.screening_max_inner_iterations))
            .with_outer_inner_cap(InnerProgressFeedback {
                cap: Arc::clone(&reml_state.outer_inner_cap),
                accepted_iter: Arc::new(AtomicUsize::new(0)),
                last_iters: Arc::clone(&reml_state.last_inner_iters),
                last_converged: Arc::clone(&reml_state.last_inner_converged),
                ift_residual: Arc::clone(&reml_state.last_ift_prediction_residual),
                accept_rho: Arc::clone(&reml_state.last_pirls_accept_rho),
            })
            .with_rho_bound(crate::estimate::RHO_BOUND);
        let problem = if let Some(h) = heuristic_theta_ref {
            problem.with_heuristic_lambdas(h.to_vec())
        } else {
            problem
        };
        // Splits theta into rho + auxiliary link params, installs the
        // corresponding link states on the REML state, and returns rho.
        let apply_link_theta = |state: &mut &mut self::reml::RemlState<'_>,
                                theta: &Array1<f64>|
         -> Result<Array1<f64>, EstimationError> {
            let rho = theta.slice(s![..k]).to_owned();
            let mut cfg_eval = cfg.clone();
            if use_mixture {
                let mix_rho = theta.slice(s![k..(k + mixture_dim)]).to_owned();
                cfg_eval.link_kind = InverseLink::Mixture(
                    state_fromspec(&MixtureLinkSpec {
                        components: mixspec.components.clone(),
                        initial_rho: mix_rho,
                    })
                    .map_err(|e| {
                        EstimationError::InvalidInput(format!("invalid blended inverse link: {e}"))
                    })?,
                );
            }
            if use_sas {
                // Beta-Logistic optimizes epsilon directly; plain SAS maps the
                // raw parameter through sas_effective_epsilon.
                let epsilon = if use_beta_logistic {
                    theta[k]
                } else {
                    let (v, _) = sas_effective_epsilon(theta[k]);
                    v
                };
                let delta_like = theta[k + 1];
                cfg_eval.link_kind = if use_beta_logistic {
                    InverseLink::BetaLogistic(
                        state_from_beta_logisticspec(SasLinkSpec {
                            initial_epsilon: epsilon,
                            initial_log_delta: delta_like,
                        })
                        .map_err(|e| {
                            EstimationError::InvalidInput(format!(
                                "invalid Beta-Logistic link: {e}"
                            ))
                        })?,
                    )
                } else {
                    InverseLink::Sas(
                        state_from_sasspec(SasLinkSpec {
                            initial_epsilon: epsilon,
                            initial_log_delta: delta_like,
                        })
                        .map_err(|e| {
                            EstimationError::InvalidInput(format!("invalid SAS link: {e}"))
                        })?,
                    )
                };
            }
            state.set_link_states(
                cfg_eval.link_kind.mixture_state().cloned(),
                cfg_eval.link_kind.sas_state().copied(),
            );
            Ok(rho)
        };
        // Quadratic ridge + edge barrier on the SAS log-delta parameter
        // (plain SAS only; zero for Beta-Logistic).
        let sas_ridge_cost = |theta: &Array1<f64>| -> f64 {
            let sasridge = if use_sas && !use_beta_logistic {
                sasridgeweight
            } else {
                0.0
            };
            if use_sas && sasridge > 0.0 {
                let log_delta = theta[k + 1];
                let mut extra = 0.5 * sasridge * log_delta * log_delta;
                if !use_beta_logistic {
                    let (barriercost, _) = sas_log_delta_edge_barriercostgrad(log_delta);
                    extra += barriercost;
                }
                extra
            } else {
                0.0
            }
        };
        let mut obj = problem.build_objective(
            &mut reml_state,
            |state: &mut &mut self::reml::RemlState<'_>, theta: &Array1<f64>| {
                let rho = apply_link_theta(state, theta)?;
                let cost = state.compute_cost(&rho)? + sas_ridge_cost(theta);
                Ok(cost)
            },
            |state: &mut &mut self::reml::RemlState<'_>, theta: &Array1<f64>| {
                let eval_idx = outer_eval_idx.fetch_add(1, Ordering::Relaxed) + 1;
                let rho = apply_link_theta(state, theta)?;
                let tcost = Instant::now();
                let eval_mode = self::reml::unified::EvalMode::ValueGradientHessian;
                let result = state.evaluate_unified_with_link_ext(&rho, eval_mode)?;
                let cost = result.cost + sas_ridge_cost(theta);
                let mut grad = result.gradient.ok_or_else(|| {
                    EstimationError::InvalidInput(
                        "unified evaluator returned no gradient in ValueGradientHessian mode"
                            .to_string(),
                    )
                })?;
                debug_assert_eq!(
                    grad.len(),
                    theta_dim,
                    "unified evaluator gradient length {} != theta_dim {}",
                    grad.len(),
                    theta_dim
                );
                // Pre-chain-rule gradient, needed for the Hessian correction.
                let grad_effective = grad.clone();
                let mut hessian = materialize_link_outer_hessian(result.hessian, theta_dim)?;
                if use_sas && !use_beta_logistic {
                    // Chain rule from effective epsilon to the raw parameter:
                    // scale row/column k and add the curvature term on (k,k).
                    let (_, d_eps_d_raw, d2_eps_d_raw2) = sas_effective_epsilon_second(theta[k]);
                    for j in 0..theta_dim {
                        hessian[[k, j]] *= d_eps_d_raw;
                        hessian[[j, k]] *= d_eps_d_raw;
                    }
                    hessian[[k, k]] += grad_effective[k] * d2_eps_d_raw2;
                    grad[k] *= d_eps_d_raw;
                }
                if use_sas && !use_beta_logistic && sasridgeweight > 0.0 {
                    // Derivatives of the ridge + edge barrier on log-delta.
                    let log_delta = theta[k + 1];
                    grad[k + 1] += sasridgeweight * log_delta;
                    hessian[[k + 1, k + 1]] += sasridgeweight;
                    let (_, barriergrad, barrierhess) =
                        sas_log_delta_edge_barriercostgradhess(log_delta);
                    grad[k + 1] += barriergrad;
                    hessian[[k + 1, k + 1]] += barrierhess;
                }
                let cost_sec = tcost.elapsed().as_secs_f64();
                let aux_dim = if use_mixture { mixture_dim } else { sas_dim };
                log::debug!(
                    "[outer-eval {eval_idx}] theta_dim={} aux_dim={} unified_link_ext time_sec={:.3}",
                    theta_dim,
                    aux_dim,
                    cost_sec,
                );
                Ok(OuterEval {
                    cost,
                    gradient: grad,
                    hessian: HessianResult::Analytic(hessian),
                })
            },
            Some(|state: &mut &mut self::reml::RemlState<'_>| {
                state.reset_outer_seed_state();
                // Restore the pre-optimization link when reseeding.
                state.set_link_states(
                    initial_link_kind.mixture_state().cloned(),
                    initial_link_kind.sas_state().copied(),
                );
            }),
            Some(
                |state: &mut &mut self::reml::RemlState<'_>, theta: &Array1<f64>| {
                    let rho = apply_link_theta(state, theta)?;
                    let mut efs_eval = state.compute_efs_steps_with_link_ext(&rho)?;
                    if use_sas && !use_beta_logistic {
                        // Same epsilon chain rule applied to EFS directions.
                        let (_, d_eps_d_raw) = sas_effective_epsilon(theta[k]);
                        if efs_eval.steps.len() > k {
                            efs_eval.steps[k] *= d_eps_d_raw;
                        }
                        if let Some(ref mut pg) = efs_eval.psi_gradient {
                            if !pg.is_empty() {
                                pg[0] *= d_eps_d_raw;
                            }
                        }
                    }
                    efs_eval.cost += sas_ridge_cost(theta);
                    Ok(efs_eval)
                },
            ),
        );
        let outer_result = problem.run(&mut obj, "mixture/SAS flexible link")?;
        // Same convergence guard as the standard arm.
        let prev_cap_mix = reml_state
            .outer_inner_cap
            .swap(0, std::sync::atomic::Ordering::Relaxed);
        if prev_cap_mix != 0 {
            let guard_start_mix = std::time::Instant::now();
            let _ = reml_state.compute_cost(&outer_result.rho);
            log::info!(
                "[OUTER guard] convergence-guard re-eval at converged ρ done (mixture/SAS arm; prev_cap={prev_cap_mix}, elapsed={:.3}s)",
                guard_start_mix.elapsed().as_secs_f64()
            );
        }
        // Extract the converged rho and rebuild the final link states from
        // the trailing theta entries.
        let final_rho = outer_result.rho.slice(s![..k]).to_owned();
        let final_mix_state = if use_mixture {
            let final_mix_rho = outer_result.rho.slice(s![k..(k + mixture_dim)]).to_owned();
            Some(
                state_fromspec(&MixtureLinkSpec {
                    components: mixspec.components.clone(),
                    initial_rho: final_mix_rho,
                })
                .map_err(|e| {
                    EstimationError::InvalidInput(format!("invalid blended inverse link: {e}"))
                })?,
            )
        } else {
            None
        };
        let final_sas_state = if use_sas {
            let epsilon_eff = if use_beta_logistic {
                outer_result.rho[k]
            } else {
                let (v, _) = sas_effective_epsilon(outer_result.rho[k]);
                v
            };
            Some(if use_beta_logistic {
                state_from_beta_logisticspec(SasLinkSpec {
                    initial_epsilon: epsilon_eff,
                    initial_log_delta: outer_result.rho[k + 1],
                })
                .map_err(|e| {
                    EstimationError::InvalidInput(format!("invalid Beta-Logistic link: {e}"))
                })?
            } else {
                state_from_sasspec(SasLinkSpec {
                    initial_epsilon: epsilon_eff,
                    initial_log_delta: outer_result.rho[k + 1],
                })
                .map_err(|e| EstimationError::InvalidInput(format!("invalid SAS link: {e}")))?
            })
        } else {
            cfg.link_kind.sas_state().copied()
        };
        // Auxiliary-parameter covariances are not computed on this path.
        let aux_param_covariance = None;
        let (mix_cov, sas_cov) = if use_mixture {
            (aux_param_covariance, None)
        } else if use_sas {
            (None, aux_param_covariance)
        } else {
            (None, None)
        };
        (
            final_rho,
            final_mix_state,
            final_sas_state,
            mix_cov,
            sas_cov,
            outer_result,
        )
    };
    let iters = std::cmp::max(1, outer_result.iterations);
    // --- Final PIRLS refit at the converged rho with the fitted link. ---
    let (pirls_res, _) = pirls::fit_model_for_fixed_rho(
        LogSmoothingParamsView::new(final_rho.view()),
        pirls::PirlsProblem {
            x: reml_state.x(),
            offset: offset_o.view(),
            y: y_o.view(),
            priorweights: w_o.view(),
            covariate_se: None,
        },
        pirls::PenaltyConfig {
            canonical_penalties: reml_state.canonical_penalties(),
            balanced_penalty_root: Some(reml_state.balanced_penalty_root()),
            reparam_invariant: None,
            p,
            coefficient_lower_bounds: None,
            linear_constraints_original: fit_linear_constraints.as_ref(),
            penalty_shrinkage_floor: opts.penalty_shrinkage_floor,
            kronecker_factored: None,
        },
        &pirls::PirlsConfig {
            link_kind: if let Some(state) = final_mixture_state.clone() {
                InverseLink::Mixture(state)
            } else if let Some(state) = final_sas_state {
                if matches!(cfg.link_function(), LinkFunction::BetaLogistic) {
                    InverseLink::BetaLogistic(state)
                } else {
                    InverseLink::Sas(state)
                }
            } else {
                cfg.link_kind.clone()
            },
            ..cfg.as_pirls_config()
        },
        None,
    )?;
    // Map coefficients back from the transformed basis, then undo the
    // column conditioning to reach the caller's original basis.
    let beta_orig_internal = pirls_res
        .reparam_result
        .qs
        .dot(pirls_res.beta_transformed.as_ref());
    let beta_orig = conditioning.backtransform_beta(&beta_orig_internal);
    let n = y_o.len() as f64;
    // Weighted RSS is only meaningful (and only used) for the identity link.
    let weighted_rss = if matches!(cfg.link_function(), LinkFunction::Identity) {
        let fitted = {
            let mut eta = offset_o.clone();
            eta += &x_o.matrixvectormultiply(&beta_orig);
            eta
        };
        let resid = y_o.to_owned() - &fitted;
        w_o.iter()
            .zip(resid.iter())
            .map(|(&wi, &ri)| wi * ri * ri)
            .sum()
    } else {
        0.0
    };
    // NOTE(review): the rebinding below is a no-op, and `beta_orig_internal`
    // is recomputed identically to the value above — looks like leftover from
    // a refactor; candidate for removal.
    let (final_rho, pirls_res) = (final_rho, pirls_res);
    let beta_orig_internal = pirls_res
        .reparam_result
        .qs
        .dot(pirls_res.beta_transformed.as_ref());
    let lambdas = final_rho.mapv(f64::exp);
    let p_dim = pirls_res.beta_transformed.len();
    let penalty_rank_total = pirls_res.reparam_result.e_transformed.nrows();
    // mp: dimension of the unpenalized (nullspace) part of the model.
    let mp = (p_dim as f64 - penalty_rank_total as f64).max(0.0);
    let mut edf_by_block = vec![0.0; k];
    let mut edf_total = 0.0;
    let mut smoothing_correction = None;
    let mut penalized_hessian = Array2::<f64>::zeros((0, 0));
    let mut beta_covariance = None;
    let mut beta_standard_errors = None;
    let mut beta_covariance_corrected = None;
    let mut beta_standard_errors_corrected = None;
    let mut bias_correction_beta = None;
    if opts.compute_inference {
        // --- EDF and bias correction via a Cholesky-style factorization of
        // the stabilized penalized Hessian (with escalating ridge retries). ---
        let h = &pirls_res.stabilizedhessian_transformed;
        let p_dim = h.nrows();
        let factor = {
            let scale = h.max_abs_diag();
            let min_step = scale * 1e-10;
            let mut ridge = 0.0_f64;
            let mut attempts = 0_usize;
            loop {
                let candidate = if ridge > 0.0 {
                    match h.addridge(ridge) {
                        Ok(c) => c,
                        Err(_) => h.clone(),
                    }
                } else {
                    h.clone()
                };
                if let Ok(f) = candidate.factorize() {
                    if ridge > 0.0 {
                        log::warn!("Stabilized Hessian factorized with ridge {:.3e}", ridge,);
                    }
                    break f;
                }
                attempts += 1;
                if attempts >= MAX_FACTORIZATION_ATTEMPTS {
                    return Err(EstimationError::ModelIsIllConditioned {
                        condition_number: f64::INFINITY,
                    });
                }
                // Escalate: start at min_step, then grow by 10x per attempt.
                ridge = if ridge <= 0.0 { min_step } else { ridge * 10.0 };
            }
        };
        // Per-block traces tr(lambda_k * H^{-1} S_k) via the penalty roots.
        let mut traces = vec![0.0f64; k];
        for (kk, cp) in pirls_res
            .reparam_result
            .canonical_transformed
            .iter()
            .enumerate()
        {
            let r = &cp.col_range;
            let rank = cp.rank();
            let mut rhs = Array2::<f64>::zeros((p_dim, rank));
            for col in 0..rank {
                for row in 0..cp.block_dim() {
                    rhs[[r.start + row, col]] = cp.root[[col, row]];
                }
            }
            let sol =
                factor
                    .solvemulti(&rhs)
                    .map_err(|_| EstimationError::ModelIsIllConditioned {
                        condition_number: f64::INFINITY,
                    })?;
            let mut frob = 0.0f64;
            for col in 0..rank {
                for row in 0..cp.block_dim() {
                    frob += sol[[r.start + row, col]] * rhs[[r.start + row, col]];
                }
            }
            traces[kk] = lambdas[kk] * frob;
        }
        // Total EDF = p - sum of traces, clamped to [mp, p]; Kahan summation
        // guards against cancellation across many blocks.
        edf_total = (p_dim as f64 - kahan_sum(traces.iter().copied())).clamp(mp, p_dim as f64);
        for (kk, cp) in pirls_res
            .reparam_result
            .canonical_transformed
            .iter()
            .enumerate()
        {
            let p_k = cp.rank() as f64;
            let edf_k = (p_k - traces[kk]).clamp(0.0, p_k);
            edf_by_block[kk] = edf_k;
        }
        // Bias-correction vector: H^{-1} S_lambda beta, mapped back via qs.
        let beta_t = pirls_res.beta_transformed.as_ref();
        let mut s_beta_t = Array1::<f64>::zeros(p_dim);
        for (kk, cp) in pirls_res
            .reparam_result
            .canonical_transformed
            .iter()
            .enumerate()
        {
            let r = &cp.col_range;
            let local = cp.local_ref();
            let beta_block = beta_t.slice(ndarray::s![r.clone()]);
            let local_beta = local.dot(&beta_block);
            let lam_k = lambdas[kk];
            let mut acc = s_beta_t.slice_mut(ndarray::s![r.clone()]);
            acc.scaled_add(lam_k, &local_beta);
        }
        match factor.solve(&s_beta_t) {
            Ok(b_t) => {
                let qs = &pirls_res.reparam_result.qs;
                let b_orig = qs.dot(&b_t);
                if b_orig.iter().all(|v| v.is_finite()) {
                    bias_correction_beta = Some(b_orig);
                } else {
                    log::warn!("bias-correction vector contained non-finite entries; skipping");
                }
            }
            Err(e) => {
                log::warn!("bias-correction solve failed: {e}");
            }
        }
    }
    // Scale/dispersion per family: Gaussian uses (weighted) RSS over residual
    // df, Gamma uses the fitted shape, discrete families are fixed at 1.
    let standard_deviation = match pirls_res.likelihood.family {
        GlmLikelihoodFamily::GaussianIdentity => {
            let denom = if opts.compute_inference {
                (n - mp).max(1.0)
            } else {
                n.max(1.0)
            };
            (weighted_rss / denom).sqrt()
        }
        GlmLikelihoodFamily::GammaLog => pirls_res.likelihood.gamma_shape().unwrap_or(1.0),
        GlmLikelihoodFamily::BinomialLogit
        | GlmLikelihoodFamily::BinomialProbit
        | GlmLikelihoodFamily::BinomialCLogLog
        | GlmLikelihoodFamily::BinomialSas
        | GlmLikelihoodFamily::BinomialBetaLogistic
        | GlmLikelihoodFamily::BinomialMixture
        | GlmLikelihoodFamily::PoissonLog => 1.0,
    };
    // Final gradient norm in rho; fall back to the strategy's own value when
    // the recomputation produced non-finite entries.
    let finalgrad = reml_state
        .compute_gradient(&final_rho)
        .unwrap_or_else(|_| Array1::from_elem(final_rho.len(), f64::NAN));
    let finalgrad_norm_rho = finalgrad.dot(&finalgrad).sqrt();
    let finalgrad_norm = if finalgrad_norm_rho.is_finite() {
        finalgrad_norm_rho
    } else {
        outer_result.final_grad_norm
    };
    if opts.compute_inference {
        // --- Posterior covariance of beta in the original basis; a full
        // inversion is skipped/degraded to diagonal-only beyond COV_MAX_P
        // columns or on inversion failure. ---
        penalized_hessian = map_hessian_to_original_basis(&pirls_res)?;
        const COV_MAX_P: usize = 5_000;
        let p_cov = penalized_hessian.nrows();
        let diag_fallback = || {
            let mut diag_inv = Array2::<f64>::zeros(penalized_hessian.dim());
            for i in 0..p_cov {
                let d = penalized_hessian[[i, i]];
                if d > 0.0 {
                    diag_inv[[i, i]] = 1.0 / d;
                }
            }
            diag_inv
        };
        beta_covariance = if p_cov > COV_MAX_P {
            log::warn!(
                "skipping full posterior covariance inversion (p={p_cov} > {COV_MAX_P}): \
using diagonal-only standard errors"
            );
            Some(diag_fallback())
        } else {
            match matrix_inversewith_regularization(&penalized_hessian, "posterior covariance") {
                Some(cov) => Some(cov),
                None => {
                    log::warn!(
                        "full posterior covariance inversion failed (p={p_cov}): \
falling back to diagonal-only standard errors"
                    );
                    Some(diag_fallback())
                }
            }
        };
        // Smoothing-parameter uncertainty correction and corrected SEs.
        smoothing_correction = reml_state.compute_smoothing_correction_auto(
            &final_rho,
            &pirls_res,
            beta_covariance.as_ref(),
            finalgrad_norm,
        );
        beta_standard_errors = beta_covariance.as_ref().map(se_from_covariance);
        beta_covariance_corrected = match (&beta_covariance, &smoothing_correction) {
            (Some(base_cov), Some(corr)) if base_cov.dim() == corr.dim() => {
                let mut corrected = base_cov.clone();
                corrected += corr;
                enforce_symmetry(&mut corrected);
                Some(corrected)
            }
            (Some(_), Some(corr)) => {
                log::warn!(
                    "Skipping corrected covariance: dimension mismatch (base {:?}, corr {:?})",
                    beta_covariance.as_ref().map(Array2::dim),
                    Some(corr.dim())
                );
                None
            }
            _ => None,
        };
        beta_standard_errors_corrected = beta_covariance_corrected.as_ref().map(se_from_covariance);
    }
    let inference = opts.compute_inference.then(|| FitInference {
        edf_by_block,
        edf_total,
        smoothing_correction,
        penalized_hessian,
        working_weights: pirls_res.solveweights.clone(),
        working_response: pirls_res.solveworking_response.clone(),
        reparam_qs: Some(pirls_res.reparam_result.qs.clone()),
        beta_covariance,
        beta_standard_errors,
        beta_covariance_corrected,
        beta_standard_errors_corrected,
        bias_correction_beta,
    });
    let pirls_status = pirls_res.status;
    let likelihood_spec = pirls_res.likelihood;
    let log_likelihood = crate::pirls::calculate_loglikelihood_omitting_constants(
        y_o.view(),
        &pirls_res.finalmu,
        likelihood_spec,
        w_o.view(),
    );
    // Assemble the result in the internal basis; back-transformed to the
    // caller's basis on the way out.
    let result = ExternalOptimResult {
        beta: beta_orig_internal,
        lambdas: lambdas.to_owned(),
        likelihood_family: likelihood_spec.response_family(),
        likelihood_scale: likelihood_spec.scale,
        log_likelihood_normalization: LogLikelihoodNormalization::OmittingResponseConstants,
        log_likelihood,
        standard_deviation,
        iterations: iters,
        finalgrad_norm,
        pirls_status,
        deviance: pirls_res.deviance,
        stable_penalty_term: pirls_res.stable_penalty_term,
        max_abs_eta: pirls_res.max_abs_eta,
        constraint_kkt: pirls_res.constraint_kkt.clone(),
        artifacts: FitArtifacts {
            pirls: Some(pirls_res),
            ..Default::default()
        },
        inference,
        reml_score: outer_result.final_value,
        fitted_link: if let Some(state) = final_mixture_state {
            FittedLinkState::Mixture {
                state,
                covariance: final_mixture_param_covariance,
            }
        } else if let Some(state) = opts.latent_cloglog {
            FittedLinkState::LatentCLogLog { state }
        } else if let Some(state) = final_sas_state {
            match opts.family {
                crate::types::LikelihoodFamily::BinomialSas => FittedLinkState::Sas {
                    state,
                    covariance: final_sas_param_covariance,
                },
                crate::types::LikelihoodFamily::BinomialBetaLogistic => {
                    FittedLinkState::BetaLogistic {
                        state,
                        covariance: final_sas_param_covariance,
                    }
                }
                _ => FittedLinkState::Standard(None),
            }
        } else {
            FittedLinkState::Standard(None)
        },
    };
    Ok(conditioning.backtransform_external_result(result))
}
/// User-facing options controlling a single model fit.
#[derive(Clone)]
pub struct FitOptions {
    // Optional latent complementary-log-log link state.
    pub latent_cloglog: Option<LatentCLogLogState>,
    // Mixture (blended inverse) link specification, if any.
    pub mixture_link: Option<MixtureLinkSpec>,
    // When true, mixture link parameters are optimized jointly with rho.
    pub optimize_mixture: bool,
    // SAS link specification (epsilon, log-delta), if any.
    pub sas_link: Option<SasLinkSpec>,
    // When true, SAS link parameters are optimized jointly with rho.
    pub optimize_sas: bool,
    // Compute EDF, covariances, and bias correction after fitting.
    pub compute_inference: bool,
    // Maximum outer (REML) iterations.
    pub max_iter: usize,
    // Convergence tolerance.
    pub tol: f64,
    // Nullspace dimension per penalty block.
    pub nullspace_dims: Vec<usize>,
    // Optional linear inequality constraints on the coefficients.
    pub linear_constraints: Option<crate::pirls::LinearInequalityConstraints>,
    // Enable Firth bias reduction in the inner fits.
    pub firth_bias_reduction: bool,
    // Optional adaptive-regularization (MM) settings.
    pub adaptive_regularization: Option<AdaptiveRegularizationOptions>,
    // Optional floor applied to penalty shrinkage.
    pub penalty_shrinkage_floor: Option<f64>,
    // Prior on the log smoothing parameters.
    pub rho_prior: crate::types::RhoPrior,
    // Optional Kronecker-structured penalty system.
    pub kronecker_penalty_system: Option<crate::smooth::KroneckerPenaltySystem>,
    // Optional Kronecker-factored basis.
    pub kronecker_factored: Option<crate::basis::KroneckerFactoredBasis>,
}
/// Tuning knobs for the adaptive-regularization (MM) inner loop.
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct AdaptiveRegularizationOptions {
    // Master switch; when false the other fields are ignored.
    pub enabled: bool,
    // Maximum majorize-minimize iterations per epsilon level.
    pub max_mm_iter: usize,
    // Relative tolerance on beta for MM convergence.
    pub beta_rel_tol: f64,
    // Maximum outer iterations over decreasing epsilon.
    pub max_epsilon_outer_iter: usize,
    // Step taken in log-epsilon per outer iteration.
    pub epsilon_log_step: f64,
    // Lower bound for epsilon.
    pub min_epsilon: f64,
    // Clamp bounds applied to the adaptive weights.
    pub weight_floor: f64,
    pub weight_ceiling: f64,
}
impl Default for AdaptiveRegularizationOptions {
fn default() -> Self {
Self {
enabled: false,
max_mm_iter: 10,
beta_rel_tol: 1e-3,
max_epsilon_outer_iter: 4,
epsilon_log_step: std::f64::consts::LN_2,
min_epsilon: 1e-8,
weight_floor: 1e-8,
weight_ceiling: 1e8,
}
}
}
/// Auxiliary fit outputs that travel with a `UnifiedFitResult`.
#[derive(Clone, Default, Serialize, Deserialize)]
pub struct FitArtifacts {
    /// Full PIRLS result; never serialized (skipped in both directions).
    #[serde(default, skip_serializing, skip_deserializing)]
    pub pirls: Option<crate::pirls::PirlsResult>,
    /// Knot vector for the survival link-wiggle spline, if one was fitted.
    #[serde(default)]
    pub survival_link_wiggle_knots: Option<Array1<f64>>,
    /// Spline degree paired with the knots above.
    #[serde(default)]
    pub survival_link_wiggle_degree: Option<usize>,
}
impl std::fmt::Debug for FitArtifacts {
    /// Manual `Debug`: the embedded PIRLS result is large, so it is elided to
    /// `"..."` and the knot vector is shown only by its length.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let pirls_marker = self.pirls.as_ref().map(|_| "...");
        let knot_count = self
            .survival_link_wiggle_knots
            .as_ref()
            .map(|knots| knots.len());
        let mut dbg = f.debug_struct("FitArtifacts");
        dbg.field("pirls", &pirls_marker);
        dbg.field("survival_link_wiggle_knots", &knot_count);
        dbg.field(
            "survival_link_wiggle_degree",
            &self.survival_link_wiggle_degree,
        );
        dbg.finish()
    }
}
/// Post-fit inference quantities (EDF, covariances, standard errors).
///
/// All numeric fields are checked for finiteness by
/// `validate_numeric_finiteness`; shapes are cross-checked against the
/// coefficient vector in `UnifiedFitResult::try_from_parts`.
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct FitInference {
    /// Effective degrees of freedom per smoothing parameter (may be empty;
    /// when non-empty its length must equal the number of lambdas).
    pub edf_by_block: Vec<f64>,
    /// Total effective degrees of freedom.
    pub edf_total: f64,
    /// Optional smoothing-parameter-uncertainty correction matrix (p x p).
    pub smoothing_correction: Option<Array2<f64>>,
    /// Penalized Hessian at convergence (p x p, dense, symmetric).
    pub penalized_hessian: Array2<f64>,
    /// PIRLS working weights — presumably from the final iteration; confirm upstream.
    pub working_weights: Array1<f64>,
    /// PIRLS working response, same length as `working_weights`.
    pub working_response: Array1<f64>,
    /// Optional reparameterization basis (p x p) — TODO confirm orientation.
    pub reparam_qs: Option<Array2<f64>>,
    /// Conditional covariance of beta; must equal the top-level copy when both exist.
    pub beta_covariance: Option<Array2<f64>>,
    /// Standard errors derived from `beta_covariance` (length p).
    pub beta_standard_errors: Option<Array1<f64>>,
    /// Smoothing-corrected covariance; must equal the top-level copy when both exist.
    pub beta_covariance_corrected: Option<Array2<f64>>,
    /// Standard errors derived from the corrected covariance (length p).
    pub beta_standard_errors_corrected: Option<Array1<f64>>,
    /// Optional bias-correction term for beta (e.g. from Firth-type reduction — confirm).
    #[serde(default)]
    pub bias_correction_beta: Option<Array1<f64>>,
}
/// Link state captured at the end of a fit, keyed by link family.
#[derive(Clone, Debug, Serialize, Deserialize)]
pub enum FittedLinkState {
    /// A standard link (identity/logit/probit/...); `None` when unspecified.
    Standard(Option<LinkFunction>),
    /// Fixed latent complementary log-log state (not optimized alongside beta).
    LatentCLogLog {
        state: LatentCLogLogState,
    },
    /// Fitted sinh-arcsinh link with optional parameter covariance.
    Sas {
        state: SasLinkState,
        covariance: Option<Array2<f64>>,
    },
    /// Beta-logistic link; reuses `SasLinkState` for its parameters.
    BetaLogistic {
        state: SasLinkState,
        covariance: Option<Array2<f64>>,
    },
    /// Fitted mixture link with optional parameter covariance.
    Mixture {
        state: MixtureLinkState,
        covariance: Option<Array2<f64>>,
    },
}
/// Extract the fitted mixture-link state, if the fit used a mixture link.
pub fn saved_mixture_state_from_fit(fit: &UnifiedFitResult) -> Option<MixtureLinkState> {
    if let FittedLinkState::Mixture { state, .. } = &fit.fitted_link {
        Some(state.clone())
    } else {
        None
    }
}
/// Extract the latent cloglog state, if the fit used that link.
pub fn saved_latent_cloglog_state_from_fit(fit: &UnifiedFitResult) -> Option<LatentCLogLogState> {
    if let FittedLinkState::LatentCLogLog { state } = &fit.fitted_link {
        return Some(*state);
    }
    None
}
/// Extract the SAS link state; beta-logistic fits share the same state type.
pub fn saved_sas_state_from_fit(fit: &UnifiedFitResult) -> Option<SasLinkState> {
    if let FittedLinkState::Sas { state, .. } | FittedLinkState::BetaLogistic { state, .. } =
        &fit.fitted_link
    {
        return Some(*state);
    }
    None
}
/// Check that every numeric parameter stored in a `FittedLinkState` is finite.
///
/// Returns the first violation found (checks run in declaration order), so the
/// reported field name identifies the earliest offending quantity.
fn validate_fitted_link_estimation(fitted_link: &FittedLinkState) -> Result<(), EstimationError> {
    match fitted_link {
        // A standard link carries no numeric parameters of its own.
        FittedLinkState::Standard(_) => Ok(()),
        FittedLinkState::LatentCLogLog { state } => {
            ensure_finite_scalar_estimation("fit_result.latent_cloglog.latent_sd", state.latent_sd)
        }
        FittedLinkState::Mixture { state, covariance } => {
            validate_all_finite_estimation(
                "fit_result.mixture_link_rho",
                state.rho.iter().copied(),
            )?;
            validate_all_finite_estimation(
                "fit_result.mixture_linkweights",
                state.pi.iter().copied(),
            )?;
            // Covariance is optional; only validated when present.
            if let Some(v) = covariance.as_ref() {
                validate_all_finite_estimation(
                    "fit_result.mixture_link_param_covariance",
                    v.iter().copied(),
                )?;
            }
            Ok(())
        }
        // SAS and beta-logistic share the same state type and the same checks.
        FittedLinkState::Sas { state, covariance }
        | FittedLinkState::BetaLogistic { state, covariance } => {
            ensure_finite_scalar_estimation("fit_result.sas_epsilon", state.epsilon)?;
            ensure_finite_scalar_estimation("fit_result.sas_log_delta", state.log_delta)?;
            ensure_finite_scalar_estimation("fit_result.sas_delta", state.delta)?;
            if let Some(v) = covariance.as_ref() {
                validate_all_finite_estimation(
                    "fit_result.sas_param_covariance",
                    v.iter().copied(),
                )?;
            }
            Ok(())
        }
    }
}
/// Semantic role of a coefficient block within a fitted model.
#[derive(Clone, Debug, PartialEq, Eq, Serialize, Deserialize)]
pub enum BlockRole {
    /// Mean / main linear predictor.
    Mean,
    /// Location parameter block.
    Location,
    /// Scale (e.g. log-sigma) parameter block.
    Scale,
    /// Time block (used by survival-style models — confirm upstream).
    Time,
    /// Threshold block.
    Threshold,
    /// Link-wiggle spline block.
    LinkWiggle,
}
impl BlockRole {
#[inline]
pub fn name(&self) -> &'static str {
match self {
Self::Mean => "mean",
Self::Location => "location",
Self::Scale => "scale",
Self::Time => "time",
Self::Threshold => "threshold",
Self::LinkWiggle => "link-wiggle",
}
}
}
/// One fitted coefficient block: coefficients plus its role, EDF, and lambdas.
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct FittedBlock {
    /// Coefficients for this block.
    pub beta: Array1<f64>,
    /// Semantic role of the block (mean, scale, time, ...).
    pub role: BlockRole,
    /// Effective degrees of freedom attributed to this block.
    pub edf: f64,
    /// Smoothing parameters owned by this block (concatenation across blocks
    /// must equal the top-level lambdas; see `try_from_parts`).
    pub lambdas: Array1<f64>,
}
/// Geometry of the converged fit (Hessian and PIRLS working quantities).
///
/// When `FitInference` is also present, `try_from_parts` requires these fields
/// to match the corresponding inference fields exactly.
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct FitGeometry {
    /// Penalized Hessian at convergence (p x p, dense, symmetric).
    pub penalized_hessian: Array2<f64>,
    /// Working weights; must have the same length as `working_response`.
    pub working_weights: Array1<f64>,
    /// Working response vector.
    pub working_response: Array1<f64>,
}
/// Unvalidated field bundle consumed by `UnifiedFitResult::try_from_parts`.
///
/// Field meanings mirror `UnifiedFitResult`; this struct exists so that all
/// cross-field invariants are checked in one place at construction time.
pub struct UnifiedFitResultParts {
    pub blocks: Vec<FittedBlock>,
    pub log_lambdas: Array1<f64>,
    pub lambdas: Array1<f64>,
    pub likelihood_family: Option<LikelihoodFamily>,
    pub likelihood_scale: LikelihoodScaleMetadata,
    pub log_likelihood_normalization: LogLikelihoodNormalization,
    pub log_likelihood: f64,
    pub deviance: f64,
    pub reml_score: f64,
    pub stable_penalty_term: f64,
    pub penalized_objective: f64,
    pub outer_iterations: usize,
    pub outer_converged: bool,
    pub outer_gradient_norm: f64,
    pub standard_deviation: f64,
    pub covariance_conditional: Option<Array2<f64>>,
    pub covariance_corrected: Option<Array2<f64>>,
    pub inference: Option<FitInference>,
    pub fitted_link: FittedLinkState,
    pub geometry: Option<FitGeometry>,
    pub block_states: Vec<crate::families::custom_family::ParameterBlockState>,
    // The #[doc(hidden)] fields below are internal diagnostics carried through
    // to the result but not part of the documented public surface.
    #[doc(hidden)]
    pub pirls_status: crate::pirls::PirlsStatus,
    #[doc(hidden)]
    pub max_abs_eta: f64,
    #[doc(hidden)]
    pub constraint_kkt: Option<crate::pirls::ConstraintKktDiagnostics>,
    #[doc(hidden)]
    pub artifacts: FitArtifacts,
    #[doc(hidden)]
    pub inner_cycles: usize,
}
/// Validated result of a unified fit.
///
/// Construct via `UnifiedFitResult::try_from_parts`, which enforces all
/// cross-field invariants (shapes, finiteness, lambda/EDF consistency) and
/// derives `beta` from the blocks.
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct UnifiedFitResult {
    /// Per-role coefficient blocks; always non-empty after validation.
    pub blocks: Vec<FittedBlock>,
    /// Log smoothing parameters; elementwise `ln(lambdas)` (validated).
    pub log_lambdas: Array1<f64>,
    /// Smoothing parameters; equal to the block lambdas concatenated in order.
    pub lambdas: Array1<f64>,
    pub likelihood_family: Option<LikelihoodFamily>,
    pub likelihood_scale: LikelihoodScaleMetadata,
    pub log_likelihood_normalization: LogLikelihoodNormalization,
    pub log_likelihood: f64,
    pub deviance: f64,
    pub reml_score: f64,
    pub stable_penalty_term: f64,
    pub penalized_objective: f64,
    pub outer_iterations: usize,
    pub outer_converged: bool,
    pub outer_gradient_norm: f64,
    pub standard_deviation: f64,
    /// Conditional covariance of `beta` (p x p when present).
    pub covariance_conditional: Option<Array2<f64>>,
    /// Smoothing-corrected covariance of `beta` (p x p when present).
    pub covariance_corrected: Option<Array2<f64>>,
    pub inference: Option<FitInference>,
    pub fitted_link: FittedLinkState,
    pub geometry: Option<FitGeometry>,
    /// Not serialized; rebuilt by callers that need it — TODO confirm.
    #[serde(skip)]
    pub block_states: Vec<crate::families::custom_family::ParameterBlockState>,
    /// Flattened coefficients: block betas concatenated in block order.
    #[serde(default)]
    pub beta: Array1<f64>,
    pub pirls_status: crate::pirls::PirlsStatus,
    #[serde(default)]
    pub max_abs_eta: f64,
    #[serde(default)]
    pub constraint_kkt: Option<crate::pirls::ConstraintKktDiagnostics>,
    #[serde(default)]
    pub artifacts: FitArtifacts,
    #[serde(default)]
    pub inner_cycles: usize,
}
impl Default for FittedLinkState {
fn default() -> Self {
FittedLinkState::Standard(None)
}
}
/// Reject NaN and ±inf with a labelled `InvalidInput` error.
pub(crate) fn ensure_finite_scalar_estimation(
    name: &str,
    value: f64,
) -> Result<(), EstimationError> {
    value.is_finite().then_some(()).ok_or_else(|| {
        EstimationError::InvalidInput(format!("{name} must be finite, got {value}"))
    })
}
/// Validate the likelihood-scale metadata attached to a fit result.
///
/// Profiled/unspecified scales carry no parameter and always pass; fixed
/// dispersion and (fixed or estimated) Gamma shape must be finite and
/// strictly positive.
fn validate_likelihood_scale_estimation(
    scale: LikelihoodScaleMetadata,
) -> Result<(), EstimationError> {
    // Shared check for both dispersion and shape: finite AND > 0.
    // Error messages are kept identical to the previous per-arm versions.
    fn ensure_positive(name: &str, value: f64) -> Result<(), EstimationError> {
        ensure_finite_scalar_estimation(name, value)?;
        if value > 0.0 {
            Ok(())
        } else {
            Err(EstimationError::InvalidInput(format!(
                "{name} must be > 0, got {value}"
            )))
        }
    }
    match scale {
        LikelihoodScaleMetadata::ProfiledGaussian | LikelihoodScaleMetadata::Unspecified => Ok(()),
        LikelihoodScaleMetadata::FixedDispersion { phi } => {
            ensure_positive("fit_result.likelihood_scale.phi", phi)
        }
        LikelihoodScaleMetadata::FixedGammaShape { shape }
        | LikelihoodScaleMetadata::EstimatedGammaShape { shape } => {
            ensure_positive("fit_result.likelihood_scale.shape", shape)
        }
    }
}
/// Scan `values` for the first non-finite entry; report it with its index.
pub(crate) fn validate_all_finite_estimation<I>(
    label: &str,
    values: I,
) -> Result<(), EstimationError>
where
    I: IntoIterator<Item = f64>,
{
    let offender = values
        .into_iter()
        .enumerate()
        .find(|(_, value)| !value.is_finite());
    match offender {
        Some((idx, value)) => Err(EstimationError::InvalidInput(format!(
            "{label}[{idx}] must be finite, got {value}"
        ))),
        None => Ok(()),
    }
}
/// String-error wrapper around `ensure_finite_scalar_estimation`.
pub fn ensure_finite_scalar(name: &str, value: f64) -> Result<(), String> {
    match ensure_finite_scalar_estimation(name, value) {
        Ok(()) => Ok(()),
        Err(e) => Err(e.to_string()),
    }
}
/// String-error wrapper around `validate_all_finite_estimation`.
pub fn validate_all_finite<I: IntoIterator<Item = f64>>(
    label: &str,
    values: I,
) -> Result<(), String> {
    match validate_all_finite_estimation(label, values) {
        Ok(()) => Ok(()),
        Err(e) => Err(e.to_string()),
    }
}
impl FitGeometry {
    /// Check every numeric field for NaN/inf; the first offender short-circuits,
    /// so the reported label identifies it (hessian, then weights, then response).
    pub fn validate_numeric_finiteness(&self) -> Result<(), EstimationError> {
        validate_all_finite_estimation(
            "fit_result.geometry.penalized_hessian",
            self.penalized_hessian.iter().copied(),
        )
        .and_then(|_| {
            validate_all_finite_estimation(
                "fit_result.geometry.working_weights",
                self.working_weights.iter().copied(),
            )
        })
        .and_then(|_| {
            validate_all_finite_estimation(
                "fit_result.geometry.working_response",
                self.working_response.iter().copied(),
            )
        })
    }
}
impl FitInference {
pub fn validate_numeric_finiteness(&self) -> Result<(), EstimationError> {
ensure_finite_scalar_estimation("fit_result.edf_total", self.edf_total)?;
validate_all_finite_estimation(
"fit_result.edf_by_block",
self.edf_by_block.iter().copied(),
)?;
validate_all_finite_estimation(
"fit_result.working_weights",
self.working_weights.iter().copied(),
)?;
validate_all_finite_estimation(
"fit_result.working_response",
self.working_response.iter().copied(),
)?;
validate_all_finite_estimation(
"fit_result.penalized_hessian",
self.penalized_hessian.iter().copied(),
)?;
if let Some(v) = self.beta_covariance.as_ref() {
validate_all_finite_estimation("fit_result.beta_covariance", v.iter().copied())?;
}
if let Some(v) = self.beta_covariance_corrected.as_ref() {
validate_all_finite_estimation(
"fit_result.beta_covariance_corrected",
v.iter().copied(),
)?;
}
if let Some(v) = self.beta_standard_errors.as_ref() {
validate_all_finite_estimation("fit_result.beta_standard_errors", v.iter().copied())?;
}
if let Some(v) = self.bias_correction_beta.as_ref() {
validate_all_finite_estimation("fit_result.bias_correction_beta", v.iter().copied())?;
}
if let Some(v) = self.beta_standard_errors_corrected.as_ref() {
validate_all_finite_estimation(
"fit_result.beta_standard_errors_corrected",
v.iter().copied(),
)?;
}
if let Some(v) = self.smoothing_correction.as_ref() {
validate_all_finite_estimation("fit_result.smoothing_correction", v.iter().copied())?;
}
if let Some(v) = self.reparam_qs.as_ref() {
validate_all_finite_estimation("fit_result.reparam_qs", v.iter().copied())?;
}
Ok(())
}
}
/// Validate a dense Hessian at fit-export time: exact shape, all-finite
/// entries, not all-zero (placeholder), and symmetric to a relative tolerance.
pub fn validate_dense_hessian_export(
    label: &str,
    hessian: &Array2<f64>,
    expected_dim: usize,
) -> Result<(), EstimationError> {
    let (nr, nc) = (hessian.nrows(), hessian.ncols());
    if nr != expected_dim || nc != expected_dim {
        return Err(EstimationError::InvalidInput(format!(
            "{label} shape mismatch: got {nr}x{nc}, expected {expected_dim}x{expected_dim}"
        )));
    }
    // An empty (0x0) Hessian is vacuously valid.
    if expected_dim == 0 {
        return Ok(());
    }
    validate_all_finite_estimation(label, hessian.iter().copied())?;
    // All entries exactly zero indicates an unfilled placeholder, not a real
    // Hessian (entries are finite here, so abs()==0 ⇔ not abs()>0).
    if hessian.iter().all(|value| value.abs() == 0.0) {
        return Err(EstimationError::InvalidInput(format!(
            "{label} must be an explicit dense Hessian; zero placeholders are not allowed at fit export"
        )));
    }
    // Relative symmetry check over the strict lower triangle.
    const SYMMETRY_TOL: f64 = 1e-10;
    for i in 1..expected_dim {
        for j in 0..i {
            let (a, b) = (hessian[[i, j]], hessian[[j, i]]);
            let scale = a.abs().max(b.abs()).max(1.0);
            if (a - b).abs() > SYMMETRY_TOL * scale {
                return Err(EstimationError::InvalidInput(format!(
                    "{label} must be symmetric at fit export; entries ({i},{j})={a} and ({j},{i})={b} differ"
                )));
            }
        }
    }
    Ok(())
}
/// Like `validate_dense_hessian_export`, but additionally requires positive
/// definiteness (certified by a successful Cholesky factorization) so the
/// Hessian can be used for HMC/NUTS whitening.
pub fn validate_explicit_dense_hessian_for_whitening(
    label: &str,
    hessian: &Array2<f64>,
    expected_dim: usize,
) -> Result<(), EstimationError> {
    validate_dense_hessian_export(label, hessian, expected_dim)?;
    if expected_dim == 0 {
        return Ok(());
    }
    match hessian.to_owned().cholesky(Side::Lower) {
        Ok(_) => Ok(()),
        Err(err) => Err(EstimationError::InvalidInput(format!(
            "{label} must be positive definite for HMC/NUTS whitening; Cholesky failed: {err:?}"
        ))),
    }
}
/// Exact elementwise equality via `==` (so NaN never matches itself).
fn array1_values_equal(lhs: &Array1<f64>, rhs: &Array1<f64>) -> bool {
    lhs.len() == rhs.len() && lhs.iter().eq(rhs.iter())
}
/// Exact elementwise equality for matrices (same shape, then `==` per entry).
fn array2_values_equal(lhs: &Array2<f64>, rhs: &Array2<f64>) -> bool {
    lhs.dim() == rhs.dim() && lhs.iter().eq(rhs.iter())
}
/// Check that `log_lambdas` equals `ln(lambdas)` elementwise to a relative tolerance.
fn log_lambdas_match_lambdas(log_lambdas: &Array1<f64>, lambdas: &Array1<f64>) -> bool {
    if log_lambdas.len() != lambdas.len() {
        return false;
    }
    for (&log_lam, &lam) in log_lambdas.iter().zip(lambdas.iter()) {
        // Clamp before ln() so lambda == 0 maps to a large negative canonical
        // value instead of -inf. NOTE(review): negative lambdas are clamped the
        // same way — confirm they are excluded upstream.
        let canonical = lam.max(1e-300).ln();
        let tol = 1e-12 * (1.0 + canonical.abs());
        if (log_lam - canonical).abs() > tol {
            return false;
        }
    }
    true
}
/// Concatenate the per-block coefficient vectors in block order.
fn flatten_block_betas(blocks: &[FittedBlock]) -> Array1<f64> {
    let total: usize = blocks.iter().map(|b| b.beta.len()).sum();
    let mut values = Vec::with_capacity(total);
    for block in blocks {
        values.extend(block.beta.iter().copied());
    }
    Array1::from_vec(values)
}
/// Concatenate the per-block smoothing parameters in block order.
fn flatten_block_lambdas(blocks: &[FittedBlock]) -> Array1<f64> {
    let total: usize = blocks.iter().map(|b| b.lambdas.len()).sum();
    let mut values = Vec::with_capacity(total);
    for block in blocks {
        values.extend(block.lambdas.iter().copied());
    }
    Array1::from_vec(values)
}
impl UnifiedFitResult {
    /// Validating constructor: checks every cross-field invariant, derives the
    /// flattened `beta` from the blocks, and only then builds the result.
    ///
    /// Checks run in a fixed order, so the first violated invariant determines
    /// the returned `EstimationError::InvalidInput` message.
    pub fn try_from_parts(parts: UnifiedFitResultParts) -> Result<Self, EstimationError> {
        let UnifiedFitResultParts {
            blocks,
            log_lambdas,
            lambdas,
            likelihood_family,
            likelihood_scale,
            log_likelihood_normalization,
            log_likelihood,
            deviance,
            reml_score,
            stable_penalty_term,
            penalized_objective,
            outer_iterations,
            outer_converged,
            outer_gradient_norm,
            standard_deviation,
            covariance_conditional,
            covariance_corrected,
            inference,
            fitted_link,
            geometry,
            block_states,
            pirls_status,
            max_abs_eta,
            constraint_kkt,
            artifacts,
            inner_cycles,
        } = parts;
        // --- Structural invariants on blocks and smoothing parameters ---
        if blocks.is_empty() {
            return Err(EstimationError::InvalidInput(
                "UnifiedFitResult requires at least one coefficient block".to_string(),
            ));
        }
        if log_lambdas.len() != lambdas.len() {
            return Err(EstimationError::InvalidInput(format!(
                "UnifiedFitResult lambda mismatch: log_lambdas={}, lambdas={}",
                log_lambdas.len(),
                lambdas.len()
            )));
        }
        // Per-block finiteness (beta, EDF, lambdas).
        for (idx, block) in blocks.iter().enumerate() {
            validate_all_finite_estimation(
                &format!("fit_result.blocks[{idx}].beta"),
                block.beta.iter().copied(),
            )?;
            ensure_finite_scalar_estimation(&format!("fit_result.blocks[{idx}].edf"), block.edf)?;
            validate_all_finite_estimation(
                &format!("fit_result.blocks[{idx}].lambdas"),
                block.lambdas.iter().copied(),
            )?;
        }
        // The top-level beta/lambdas are derived, not supplied: they must be
        // exactly the block vectors concatenated in block order.
        let beta = flatten_block_betas(&blocks);
        let block_lambdas = flatten_block_lambdas(&blocks);
        if !array1_values_equal(&block_lambdas, &lambdas) {
            return Err(EstimationError::InvalidInput(
                "UnifiedFitResult top-level lambdas must match block lambdas concatenated in block order"
                    .to_string(),
            ));
        }
        validate_all_finite_estimation("fit_result.log_lambdas", log_lambdas.iter().copied())?;
        validate_all_finite_estimation("fit_result.lambdas", lambdas.iter().copied())?;
        // log_lambdas must be the canonical elementwise ln(lambdas).
        if !log_lambdas_match_lambdas(&log_lambdas, &lambdas) {
            return Err(EstimationError::InvalidInput(
                "UnifiedFitResult log_lambdas must equal ln(lambdas) elementwise".to_string(),
            ));
        }
        // --- Scalar finiteness and scale metadata ---
        validate_likelihood_scale_estimation(likelihood_scale)?;
        ensure_finite_scalar_estimation("fit_result.log_likelihood", log_likelihood)?;
        ensure_finite_scalar_estimation("fit_result.deviance", deviance)?;
        ensure_finite_scalar_estimation("fit_result.reml_score", reml_score)?;
        ensure_finite_scalar_estimation("fit_result.stable_penalty_term", stable_penalty_term)?;
        ensure_finite_scalar_estimation("fit_result.penalized_objective", penalized_objective)?;
        ensure_finite_scalar_estimation("fit_result.outer_gradient_norm", outer_gradient_norm)?;
        ensure_finite_scalar_estimation("fit_result.standard_deviation", standard_deviation)?;
        // --- Finiteness of optional top-level covariances and nested results ---
        if let Some(v) = covariance_conditional.as_ref() {
            validate_all_finite_estimation("fit_result.beta_covariance", v.iter().copied())?;
        }
        if let Some(v) = covariance_corrected.as_ref() {
            validate_all_finite_estimation(
                "fit_result.beta_covariance_corrected",
                v.iter().copied(),
            )?;
        }
        if let Some(inf) = inference.as_ref() {
            inf.validate_numeric_finiteness()?;
        }
        if let Some(geom) = geometry.as_ref() {
            geom.validate_numeric_finiteness()?;
        }
        for (idx, state) in block_states.iter().enumerate() {
            validate_all_finite_estimation(
                &format!("fit_result.block_states[{idx}].beta"),
                state.beta.iter().copied(),
            )?;
            validate_all_finite_estimation(
                &format!("fit_result.block_states[{idx}].eta"),
                state.eta.iter().copied(),
            )?;
        }
        validate_fitted_link_estimation(&fitted_link)?;
        // --- Shape checks against the flattened coefficient dimension p ---
        let p = beta.len();
        if let Some(cov) = covariance_conditional.as_ref() {
            if cov.nrows() != p || cov.ncols() != p {
                return Err(EstimationError::InvalidInput(format!(
                    "UnifiedFitResult conditional covariance shape mismatch: got {}x{}, expected {}x{}",
                    cov.nrows(),
                    cov.ncols(),
                    p,
                    p
                )));
            }
        }
        if let Some(cov) = covariance_corrected.as_ref() {
            if cov.nrows() != p || cov.ncols() != p {
                return Err(EstimationError::InvalidInput(format!(
                    "UnifiedFitResult corrected covariance shape mismatch: got {}x{}, expected {}x{}",
                    cov.nrows(),
                    cov.ncols(),
                    p,
                    p
                )));
            }
        }
        // --- Inference block: internal shape checks and consistency with the
        //     top-level covariance fields ---
        if let Some(inf) = inference.as_ref() {
            // An empty edf_by_block is allowed (not computed); a non-empty one
            // must have one entry per smoothing parameter.
            if !inf.edf_by_block.is_empty() && inf.edf_by_block.len() != lambdas.len() {
                return Err(EstimationError::InvalidInput(format!(
                    "UnifiedFitResult EDF smoothing-parameter count mismatch: edf_by_block={}, lambdas={}",
                    inf.edf_by_block.len(),
                    lambdas.len()
                )));
            }
            if inf.working_weights.len() != inf.working_response.len() {
                return Err(EstimationError::InvalidInput(format!(
                    "UnifiedFitResult working vector length mismatch: working_weights={}, working_response={}",
                    inf.working_weights.len(),
                    inf.working_response.len()
                )));
            }
            if inf.penalized_hessian.nrows() != p || inf.penalized_hessian.ncols() != p {
                return Err(EstimationError::InvalidInput(format!(
                    "UnifiedFitResult penalized Hessian shape mismatch: got {}x{}, expected {}x{}",
                    inf.penalized_hessian.nrows(),
                    inf.penalized_hessian.ncols(),
                    p,
                    p
                )));
            }
            // Full export-grade validation: non-placeholder, symmetric.
            validate_dense_hessian_export(
                "UnifiedFitResult inference penalized Hessian",
                &inf.penalized_hessian,
                p,
            )?;
            // The inference-level covariance must exist together with, and be
            // exactly equal to, the top-level covariance_conditional.
            if let Some(cov) = inf.beta_covariance.as_ref() {
                if cov.nrows() != p || cov.ncols() != p {
                    return Err(EstimationError::InvalidInput(format!(
                        "UnifiedFitResult inference conditional covariance shape mismatch: got {}x{}, expected {}x{}",
                        cov.nrows(),
                        cov.ncols(),
                        p,
                        p
                    )));
                }
                match covariance_conditional.as_ref() {
                    Some(top) if array2_values_equal(cov, top) => {}
                    Some(_) => {
                        return Err(EstimationError::InvalidInput(
                            "UnifiedFitResult inference conditional covariance must match top-level covariance_conditional"
                                .to_string(),
                        ));
                    }
                    None => {
                        return Err(EstimationError::InvalidInput(
                            "UnifiedFitResult inference conditional covariance requires top-level covariance_conditional"
                                .to_string(),
                        ));
                    }
                }
            }
            if let Some(se) = inf.beta_standard_errors.as_ref()
                && se.len() != p
            {
                return Err(EstimationError::InvalidInput(format!(
                    "UnifiedFitResult beta standard error length mismatch: got {}, expected {}",
                    se.len(),
                    p
                )));
            }
            // Same pairing rule for the smoothing-corrected covariance.
            if let Some(cov) = inf.beta_covariance_corrected.as_ref() {
                if cov.nrows() != p || cov.ncols() != p {
                    return Err(EstimationError::InvalidInput(format!(
                        "UnifiedFitResult inference corrected covariance shape mismatch: got {}x{}, expected {}x{}",
                        cov.nrows(),
                        cov.ncols(),
                        p,
                        p
                    )));
                }
                match covariance_corrected.as_ref() {
                    Some(top) if array2_values_equal(cov, top) => {}
                    Some(_) => {
                        return Err(EstimationError::InvalidInput(
                            "UnifiedFitResult inference corrected covariance must match top-level covariance_corrected"
                                .to_string(),
                        ));
                    }
                    None => {
                        return Err(EstimationError::InvalidInput(
                            "UnifiedFitResult inference corrected covariance requires top-level covariance_corrected"
                                .to_string(),
                        ));
                    }
                }
            }
            if let Some(se) = inf.beta_standard_errors_corrected.as_ref()
                && se.len() != p
            {
                return Err(EstimationError::InvalidInput(format!(
                    "UnifiedFitResult corrected beta standard error length mismatch: got {}, expected {}",
                    se.len(),
                    p
                )));
            }
            if let Some(corr) = inf.smoothing_correction.as_ref()
                && (corr.nrows() != p || corr.ncols() != p)
            {
                return Err(EstimationError::InvalidInput(format!(
                    "UnifiedFitResult smoothing correction shape mismatch: got {}x{}, expected {}x{}",
                    corr.nrows(),
                    corr.ncols(),
                    p,
                    p
                )));
            }
            if let Some(qs) = inf.reparam_qs.as_ref()
                && (qs.nrows() != p || qs.ncols() != p)
            {
                return Err(EstimationError::InvalidInput(format!(
                    "UnifiedFitResult reparam_qs shape mismatch: got {}x{}, expected {}x{}",
                    qs.nrows(),
                    qs.ncols(),
                    p,
                    p
                )));
            }
        }
        // --- Geometry block: shape checks plus exact agreement with inference
        //     when both are present ---
        if let Some(geom) = geometry.as_ref() {
            if geom.penalized_hessian.nrows() != p || geom.penalized_hessian.ncols() != p {
                return Err(EstimationError::InvalidInput(format!(
                    "UnifiedFitResult geometry penalized Hessian shape mismatch: got {}x{}, expected {}x{}",
                    geom.penalized_hessian.nrows(),
                    geom.penalized_hessian.ncols(),
                    p,
                    p
                )));
            }
            validate_dense_hessian_export(
                "UnifiedFitResult geometry penalized Hessian",
                &geom.penalized_hessian,
                p,
            )?;
            if geom.working_weights.len() != geom.working_response.len() {
                return Err(EstimationError::InvalidInput(format!(
                    "UnifiedFitResult geometry working vector length mismatch: working_weights={}, working_response={}",
                    geom.working_weights.len(),
                    geom.working_response.len()
                )));
            }
            if let Some(inf) = inference.as_ref() {
                if !array2_values_equal(&geom.penalized_hessian, &inf.penalized_hessian) {
                    return Err(EstimationError::InvalidInput(
                        "UnifiedFitResult geometry penalized Hessian must match inference.penalized_hessian"
                            .to_string(),
                    ));
                }
                if !array1_values_equal(&geom.working_weights, &inf.working_weights) {
                    return Err(EstimationError::InvalidInput(
                        "UnifiedFitResult geometry working_weights must match inference.working_weights"
                            .to_string(),
                    ));
                }
                if !array1_values_equal(&geom.working_response, &inf.working_response) {
                    return Err(EstimationError::InvalidInput(
                        "UnifiedFitResult geometry working_response must match inference.working_response"
                            .to_string(),
                    ));
                }
            }
        }
        // Empty block_states is allowed; otherwise one state per block.
        if !block_states.is_empty() && block_states.len() != blocks.len() {
            return Err(EstimationError::InvalidInput(format!(
                "UnifiedFitResult block state count mismatch: blocks={}, block_states={}",
                blocks.len(),
                block_states.len()
            )));
        }
        Ok(Self {
            blocks,
            log_lambdas,
            lambdas,
            likelihood_family,
            likelihood_scale,
            log_likelihood_normalization,
            log_likelihood,
            deviance,
            reml_score,
            stable_penalty_term,
            penalized_objective,
            outer_iterations,
            outer_converged,
            outer_gradient_norm,
            standard_deviation,
            covariance_conditional,
            covariance_corrected,
            inference,
            fitted_link,
            geometry,
            block_states,
            beta,
            pirls_status,
            max_abs_eta,
            constraint_kkt,
            artifacts,
            inner_cycles,
        })
    }
    /// Test-only constructor that bypasses all invariant checks (still derives
    /// `beta` from the blocks so accessors behave consistently).
    #[cfg(test)]
    pub(crate) fn new_for_test_unchecked(parts: UnifiedFitResultParts) -> Self {
        let beta = flatten_block_betas(&parts.blocks);
        Self {
            blocks: parts.blocks,
            log_lambdas: parts.log_lambdas,
            lambdas: parts.lambdas,
            likelihood_family: parts.likelihood_family,
            likelihood_scale: parts.likelihood_scale,
            log_likelihood_normalization: parts.log_likelihood_normalization,
            log_likelihood: parts.log_likelihood,
            deviance: parts.deviance,
            reml_score: parts.reml_score,
            stable_penalty_term: parts.stable_penalty_term,
            penalized_objective: parts.penalized_objective,
            outer_iterations: parts.outer_iterations,
            outer_converged: parts.outer_converged,
            outer_gradient_norm: parts.outer_gradient_norm,
            standard_deviation: parts.standard_deviation,
            covariance_conditional: parts.covariance_conditional,
            covariance_corrected: parts.covariance_corrected,
            inference: parts.inference,
            fitted_link: parts.fitted_link,
            geometry: parts.geometry,
            block_states: parts.block_states,
            beta,
            pirls_status: parts.pirls_status,
            max_abs_eta: parts.max_abs_eta,
            constraint_kkt: parts.constraint_kkt,
            artifacts: parts.artifacts,
            inner_cycles: parts.inner_cycles,
        }
    }
    /// Re-run the full `try_from_parts` validation on an existing result
    /// (e.g. after deserialization), plus the derived-beta consistency check.
    pub fn validate_numeric_finiteness(&self) -> Result<(), EstimationError> {
        // `beta` is serialized with #[serde(default)], so a decoded value may
        // disagree with the blocks; reject that before the deeper checks.
        let expected_beta = flatten_block_betas(&self.blocks);
        if !array1_values_equal(&self.beta, &expected_beta) {
            return Err(EstimationError::InvalidInput(
                "UnifiedFitResult decoded beta must match coefficient blocks concatenated in block order"
                    .to_string(),
            ));
        }
        // Clone into a parts bundle and discard the reconstructed value: only
        // the validation side effect is wanted here.
        Self::try_from_parts(UnifiedFitResultParts {
            blocks: self.blocks.clone(),
            log_lambdas: self.log_lambdas.clone(),
            lambdas: self.lambdas.clone(),
            likelihood_family: self.likelihood_family,
            likelihood_scale: self.likelihood_scale,
            log_likelihood_normalization: self.log_likelihood_normalization,
            log_likelihood: self.log_likelihood,
            deviance: self.deviance,
            reml_score: self.reml_score,
            stable_penalty_term: self.stable_penalty_term,
            penalized_objective: self.penalized_objective,
            outer_iterations: self.outer_iterations,
            outer_converged: self.outer_converged,
            outer_gradient_norm: self.outer_gradient_norm,
            standard_deviation: self.standard_deviation,
            covariance_conditional: self.covariance_conditional.clone(),
            covariance_corrected: self.covariance_corrected.clone(),
            inference: self.inference.clone(),
            fitted_link: self.fitted_link.clone(),
            geometry: self.geometry.clone(),
            block_states: self.block_states.clone(),
            pirls_status: self.pirls_status,
            max_abs_eta: self.max_abs_eta,
            constraint_kkt: self.constraint_kkt.clone(),
            artifacts: self.artifacts.clone(),
            inner_cycles: self.inner_cycles,
        })
        .map(|_| ())
    }
}
impl UnifiedFitResult {
pub fn beta_covariance(&self) -> Option<&Array2<f64>> {
self.covariance_conditional.as_ref()
}
pub fn beta_covariance_corrected(&self) -> Option<&Array2<f64>> {
self.covariance_corrected.as_ref().or_else(|| {
self.inference
.as_ref()
.and_then(|inf| inf.beta_covariance_corrected.as_ref())
})
}
pub fn beta_standard_errors(&self) -> Option<&Array1<f64>> {
self.inference
.as_ref()
.and_then(|inf| inf.beta_standard_errors.as_ref())
}
pub fn beta_standard_errors_corrected(&self) -> Option<&Array1<f64>> {
self.inference
.as_ref()
.and_then(|inf| inf.beta_standard_errors_corrected.as_ref())
}
pub fn bias_correction_beta(&self) -> Option<&Array1<f64>> {
self.inference
.as_ref()
.and_then(|inf| inf.bias_correction_beta.as_ref())
}
pub fn penalized_hessian(&self) -> Option<&Array2<f64>> {
self.inference
.as_ref()
.map(|inf| &inf.penalized_hessian)
.or_else(|| self.geometry.as_ref().map(|geom| &geom.penalized_hessian))
}
pub fn working_weights(&self) -> Option<&Array1<f64>> {
self.inference.as_ref().map(|inf| &inf.working_weights)
}
pub fn working_response(&self) -> Option<&Array1<f64>> {
self.inference.as_ref().map(|inf| &inf.working_response)
}
pub fn edf_total(&self) -> Option<f64> {
self.inference.as_ref().map(|inf| inf.edf_total)
}
pub fn edf_by_block(&self) -> &[f64] {
self.inference
.as_ref()
.map(|inf| inf.edf_by_block.as_slice())
.unwrap_or(&[])
}
pub fn block_by_role(&self, role: BlockRole) -> Option<&FittedBlock> {
self.blocks.iter().find(|b| b.role == role)
}
pub fn beta_flat(&self) -> Array1<f64> {
self.beta.clone()
}
pub fn beta_time(&self) -> Array1<f64> {
self.block_by_role(BlockRole::Time)
.map(|b| b.beta.clone())
.unwrap_or_else(|| Array1::zeros(0))
}
pub fn beta_threshold(&self) -> Array1<f64> {
self.block_by_role(BlockRole::Threshold)
.map(|b| b.beta.clone())
.unwrap_or_else(|| Array1::zeros(0))
}
pub fn beta_log_sigma(&self) -> Array1<f64> {
self.block_by_role(BlockRole::Scale)
.map(|b| b.beta.clone())
.unwrap_or_else(|| Array1::zeros(0))
}
pub fn beta_link_wiggle(&self) -> Option<Array1<f64>> {
self.block_by_role(BlockRole::LinkWiggle)
.map(|b| b.beta.clone())
}
pub fn lambdas_time(&self) -> Array1<f64> {
self.block_by_role(BlockRole::Time)
.map(|b| b.lambdas.clone())
.unwrap_or_else(|| Array1::zeros(0))
}
pub fn lambdas_threshold(&self) -> Array1<f64> {
self.block_by_role(BlockRole::Threshold)
.map(|b| b.lambdas.clone())
.unwrap_or_else(|| Array1::zeros(0))
}
pub fn lambdas_log_sigma(&self) -> Array1<f64> {
self.block_by_role(BlockRole::Scale)
.map(|b| b.lambdas.clone())
.unwrap_or_else(|| Array1::zeros(0))
}
pub fn lambdas_linkwiggle(&self) -> Option<Array1<f64>> {
self.block_by_role(BlockRole::LinkWiggle)
.map(|b| b.lambdas.clone())
}
pub fn n_blocks(&self) -> usize {
self.blocks.len()
}
pub fn block_roles(&self) -> Vec<BlockRole> {
self.blocks.iter().map(|b| b.role.clone()).collect()
}
pub fn fitted_link_state(
&self,
family: crate::types::LikelihoodFamily,
) -> Result<FittedLinkState, EstimationError> {
match family {
crate::types::LikelihoodFamily::GaussianIdentity => {
Ok(FittedLinkState::Standard(Some(LinkFunction::Identity)))
}
crate::types::LikelihoodFamily::BinomialLogit => {
Ok(FittedLinkState::Standard(Some(LinkFunction::Logit)))
}
crate::types::LikelihoodFamily::BinomialProbit => {
Ok(FittedLinkState::Standard(Some(LinkFunction::Probit)))
}
crate::types::LikelihoodFamily::BinomialCLogLog => {
Ok(FittedLinkState::Standard(Some(LinkFunction::CLogLog)))
}
crate::types::LikelihoodFamily::BinomialLatentCLogLog => match &self.fitted_link {
FittedLinkState::LatentCLogLog { state } => {
Ok(FittedLinkState::LatentCLogLog { state: *state })
}
_ => Err(EstimationError::InvalidInput(
"BinomialLatentCLogLog requires fixed latent cloglog state".to_string(),
)),
},
crate::types::LikelihoodFamily::BinomialSas => match &self.fitted_link {
FittedLinkState::Sas { state, covariance } => Ok(FittedLinkState::Sas {
state: state.clone(),
covariance: covariance.clone(),
}),
_ => Err(EstimationError::InvalidInput(
"BinomialSas requires fitted SAS link parameters".to_string(),
)),
},
crate::types::LikelihoodFamily::BinomialBetaLogistic => match &self.fitted_link {
FittedLinkState::BetaLogistic { state, covariance } => {
Ok(FittedLinkState::BetaLogistic {
state: state.clone(),
covariance: covariance.clone(),
})
}
_ => Err(EstimationError::InvalidInput(
"BinomialBetaLogistic requires fitted beta-logistic link parameters"
.to_string(),
)),
},
crate::types::LikelihoodFamily::BinomialMixture => match &self.fitted_link {
FittedLinkState::Mixture { state, covariance } => Ok(FittedLinkState::Mixture {
state: state.clone(),
covariance: covariance.clone(),
}),
_ => Err(EstimationError::InvalidInput(
"BinomialMixture requires fitted mixture link parameters".to_string(),
)),
},
crate::types::LikelihoodFamily::PoissonLog
| crate::types::LikelihoodFamily::GammaLog => {
Ok(FittedLinkState::Standard(Some(LinkFunction::Log)))
}
crate::types::LikelihoodFamily::RoystonParmar => Ok(FittedLinkState::Standard(None)),
}
}
}
/// Summary row for one parametric (unpenalized) model term.
#[derive(Clone, Debug)]
pub struct ParametricTermSummary {
    /// Term name as shown in the model summary.
    pub name: String,
    /// Point estimate of the coefficient.
    pub estimate: f64,
    /// Standard error, when a covariance was available.
    pub std_error: Option<f64>,
    /// z statistic (estimate / std_error) — presumably; confirm in the builder.
    pub zvalue: Option<f64>,
    /// Two-sided p-value — presumably; confirm in the builder.
    pub pvalue: Option<f64>,
}
/// Summary row for one smooth (penalized) model term.
#[derive(Clone, Debug)]
pub struct SmoothTermSummary {
    /// Term name as shown in the model summary.
    pub name: String,
    /// Effective degrees of freedom of the smooth.
    pub edf: f64,
    /// Reference degrees of freedom used for the test — confirm in the builder.
    pub ref_df: f64,
    /// Chi-square test statistic, if computed.
    pub chi_sq: Option<f64>,
    /// Approximate p-value, if computed.
    pub pvalue: Option<f64>,
    /// Continuous smoothness-order diagnostic, if applicable to this term.
    pub continuous_order: Option<ContinuousSmoothnessOrder>,
}
/// Outcome classification of `compute_continuous_smoothness_order`.
#[derive(Clone, Debug, PartialEq, Eq)]
pub enum ContinuousSmoothnessOrderStatus {
    /// Well-defined Matern-type regime.
    Ok,
    /// Discriminant lambda1^2 - 4*lambda0*lambda2 is significantly negative.
    NonMaternRegime,
    /// lambda2 is (numerically) zero with lambda1 positive: first-order limit.
    FirstOrderLimit,
    /// lambda0 is (numerically) zero with lambda1, lambda2 positive:
    /// intrinsic limit (nu = 1, kappa2 = 0).
    IntrinsicLimit,
    /// Lambdas are degenerate (zero/invalid) so the order is undefined.
    UndefinedZeroLambda,
}
/// Continuous smoothness-order diagnostic derived from three penalty lambdas.
#[derive(Clone, Debug)]
pub struct ContinuousSmoothnessOrder {
    /// Physical (unnormalized) zeroth-order lambda.
    pub lambda0: f64,
    /// Physical first-order lambda.
    pub lambda1: f64,
    /// Physical second-order lambda.
    pub lambda2: f64,
    /// Ratio lambda1^2 / (lambda0 * lambda2), when defined.
    pub r_ratio: Option<f64>,
    /// Implied smoothness order nu, when the regime admits one.
    pub nu: Option<f64>,
    /// Implied kappa^2 parameter, when the regime admits one.
    pub kappa2: Option<f64>,
    /// Which regime/limit the lambdas fell into.
    pub status: ContinuousSmoothnessOrderStatus,
}
/// Human-readable summary of a fitted model.
#[derive(Clone, Debug)]
pub struct ModelSummary {
    /// Likelihood family name, as display text.
    pub family: String,
    /// Proportion of deviance explained, if computable.
    pub deviance_explained: Option<f64>,
    /// REML criterion value, if available.
    pub reml_score: Option<f64>,
    /// One row per parametric term.
    pub parametric_terms: Vec<ParametricTermSummary>,
    /// One row per smooth term.
    pub smooth_terms: Vec<SmoothTermSummary>,
}
/// Convert normalized lambdas back to the physical scale by dividing each by
/// its normalization factor; returns `None` if any factor is not a finite,
/// strictly positive number.
fn unscale_to_physical_lambdas(
    lambda_tilde: [f64; 3],
    normalization_scale: [f64; 3],
) -> Option<[f64; 3]> {
    let mut physical = [0.0_f64; 3];
    for ((slot, &tilde), &scale) in physical
        .iter_mut()
        .zip(lambda_tilde.iter())
        .zip(normalization_scale.iter())
    {
        // Equivalent to !(scale.is_finite() && scale > 0.0): NaN is caught by
        // the finiteness test before the sign comparison.
        if !scale.is_finite() || scale <= 0.0 {
            return None;
        }
        *slot = tilde / scale;
    }
    Some(physical)
}
pub fn compute_continuous_smoothness_order(
lambda_tilde: [f64; 3],
normalization_scale: [f64; 3],
eps: f64,
) -> ContinuousSmoothnessOrder {
let Some(lambda) = unscale_to_physical_lambdas(lambda_tilde, normalization_scale) else {
return ContinuousSmoothnessOrder {
lambda0: f64::NAN,
lambda1: f64::NAN,
lambda2: f64::NAN,
r_ratio: None,
nu: None,
kappa2: None,
status: ContinuousSmoothnessOrderStatus::UndefinedZeroLambda,
};
};
let [lambda0, lambda1, lambda2] = lambda;
if !lambda0.is_finite() || !lambda1.is_finite() || !lambda2.is_finite() {
return ContinuousSmoothnessOrder {
lambda0,
lambda1,
lambda2,
r_ratio: None,
nu: None,
kappa2: None,
status: ContinuousSmoothnessOrderStatus::UndefinedZeroLambda,
};
}
let lambda_scale = lambda0.abs().max(lambda1.abs()).max(lambda2.abs()).max(1.0);
let lambda_floor = eps * lambda_scale;
if lambda0 <= lambda_floor {
if lambda1 > lambda_floor && lambda2 > lambda_floor {
return ContinuousSmoothnessOrder {
lambda0,
lambda1,
lambda2,
r_ratio: None,
nu: Some(1.0),
kappa2: Some(0.0),
status: ContinuousSmoothnessOrderStatus::IntrinsicLimit,
};
}
return ContinuousSmoothnessOrder {
lambda0,
lambda1,
lambda2,
r_ratio: None,
nu: None,
kappa2: None,
status: ContinuousSmoothnessOrderStatus::UndefinedZeroLambda,
};
}
if lambda2 <= lambda_floor {
if lambda1 > lambda_floor && lambda1.is_finite() {
return ContinuousSmoothnessOrder {
lambda0,
lambda1,
lambda2,
r_ratio: None,
nu: Some(1.0),
kappa2: Some(lambda0 / lambda1),
status: ContinuousSmoothnessOrderStatus::FirstOrderLimit,
};
}
return ContinuousSmoothnessOrder {
lambda0,
lambda1,
lambda2,
r_ratio: None,
nu: None,
kappa2: None,
status: ContinuousSmoothnessOrderStatus::UndefinedZeroLambda,
};
}
let r_ratio = (lambda1 * lambda1) / (lambda0 * lambda2);
if !r_ratio.is_finite() {
return ContinuousSmoothnessOrder {
lambda0,
lambda1,
lambda2,
r_ratio: None,
nu: None,
kappa2: None,
status: ContinuousSmoothnessOrderStatus::UndefinedZeroLambda,
};
}
let discriminant = lambda1 * lambda1 - 4.0 * lambda0 * lambda2;
let disc_tol = eps * lambda_scale * lambda_scale;
let status = if discriminant < -disc_tol {
ContinuousSmoothnessOrderStatus::NonMaternRegime
} else {
ContinuousSmoothnessOrderStatus::Ok
};
if r_ratio <= 2.0 + eps {
return ContinuousSmoothnessOrder {
lambda0,
lambda1,
lambda2,
r_ratio: Some(r_ratio),
nu: None,
kappa2: None,
status,
};
}
let nu = r_ratio / (r_ratio - 2.0);
let kappa2 = lambda1 / ((r_ratio - 2.0) * lambda2);
if !nu.is_finite() || !kappa2.is_finite() {
return ContinuousSmoothnessOrder {
lambda0,
lambda1,
lambda2,
r_ratio: Some(r_ratio),
nu: None,
kappa2: None,
status: ContinuousSmoothnessOrderStatus::UndefinedZeroLambda,
};
}
ContinuousSmoothnessOrder {
lambda0,
lambda1,
lambda2,
r_ratio: Some(r_ratio),
nu: Some(nu),
kappa2: Some(kappa2),
status,
}
}
/// Test-only adapter: accepts slices, validates that both have exactly three
/// entries, and delegates to [`compute_continuous_smoothness_order`].
#[cfg(test)]
pub(crate) fn try_compute_continuous_smoothness_order(
    lambda_tilde: &[f64],
    normalization_scale: &[f64],
    eps: f64,
) -> Option<ContinuousSmoothnessOrder> {
    match (lambda_tilde, normalization_scale) {
        (&[t0, t1, t2], &[c0, c1, c2]) => Some(compute_continuous_smoothness_order(
            [t0, t1, t2],
            [c0, c1, c2],
            eps,
        )),
        _ => None,
    }
}
/// Maps a p-value to an R-style significance code:
/// `***` (< 0.001), `**` (< 0.01), `*` (< 0.05), `.` (< 0.1), else empty.
/// Missing or non-finite p-values yield the empty string.
fn significance_stars(p: Option<f64>) -> &'static str {
    let Some(value) = p else { return "" };
    if !value.is_finite() {
        return "";
    }
    if value < 0.001 {
        "***"
    } else if value < 0.01 {
        "**"
    } else if value < 0.05 {
        "*"
    } else if value < 0.1 {
        "."
    } else {
        ""
    }
}
/// Formats a p-value for display: `"NA"` when missing or non-finite,
/// `"< 2e-16"` below machine-level significance, scientific notation with
/// two decimals below 1e-4, and four fixed decimals otherwise.
fn format_pvalue(p: Option<f64>) -> String {
    match p {
        Some(value) if value.is_finite() => {
            if value < 2e-16 {
                "< 2e-16".to_string()
            } else if value < 1e-4 {
                format!("{value:.2e}")
            } else {
                format!("{value:.4}")
            }
        }
        _ => "NA".to_string(),
    }
}
// Renders the summary as plain-text tables: header (family / deviance /
// REML), parametric terms, smooth terms, optional continuous
// smoothness-order diagnostics, then the R-style significance legend.
//
// NOTE(review): several header labels are wider than the column widths
// used for them (e.g. "Standard Error" (14 chars) in a `{:>12}` column,
// "Effective Degrees of Freedom" (28) in `{:>26}`), so header and data
// columns do not align exactly — confirm whether intended.
impl fmt::Display for ModelSummary {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        // Term-name column widths: widest name, at least as wide as "Term",
        // with a fallback of 10 when the table is empty.
        let paramnamew = self
            .parametric_terms
            .iter()
            .map(|t| t.name.len())
            .max()
            .unwrap_or(10)
            .max("Term".len());
        let smoothnamew = self
            .smooth_terms
            .iter()
            .map(|t| t.name.len())
            .max()
            .unwrap_or(10)
            .max("Term".len());
        writeln!(f, "Family: {}", self.family)?;
        // Clamp keeps pathological deviance ratios from blowing up the layout.
        let dev_txt = self
            .deviance_explained
            .map(|d| format!("{:.1}%", (100.0 * d).clamp(-9999.0, 9999.0)))
            .unwrap_or_else(|| "NA".to_string());
        let reml_txt = self
            .reml_score
            .map(|v| format!("{v:.4}"))
            .unwrap_or_else(|| "NA".to_string());
        writeln!(f, "Deviance Explained: {dev_txt} | REML Score: {reml_txt}")?;
        writeln!(f)?;
        // --- Parametric-term table ---
        writeln!(f, "Parametric Terms:")?;
        writeln!(f, "{:-<1$}", "", paramnamew + 59)?;
        writeln!(
            f,
            "{:<namew$} {:>10} {:>12} {:>10} {:>19}",
            "Term",
            "Estimate",
            "Standard Error",
            "Z Statistic",
            "Two-Sided P-Value",
            namew = paramnamew
        )?;
        writeln!(f, "{:-<1$}", "", paramnamew + 59)?;
        for term in &self.parametric_terms {
            let estimate = format!("{:.4}", term.estimate);
            // Non-finite statistics render as "NA" rather than e.g. "NaN".
            let se = term
                .std_error
                .filter(|v| v.is_finite())
                .map(|v| format!("{v:.4}"))
                .unwrap_or_else(|| "NA".to_string());
            let z = term
                .zvalue
                .filter(|v| v.is_finite())
                .map(|v| format!("{v:.2}"))
                .unwrap_or_else(|| "NA".to_string());
            let p = format_pvalue(term.pvalue);
            let stars = significance_stars(term.pvalue);
            writeln!(
                f,
                "{:<namew$} {:>10} {:>12} {:>10} {:>19} {}",
                term.name,
                estimate,
                se,
                z,
                p,
                stars,
                namew = paramnamew
            )?;
        }
        writeln!(f)?;
        // --- Smooth-term table ---
        writeln!(f, "Smooth Terms:")?;
        writeln!(f, "{:-<1$}", "", smoothnamew + 86)?;
        writeln!(
            f,
            "{:<namew$} {:>26} {:>30} {:>12} {:>10}",
            "Term",
            "Effective Degrees of Freedom",
            "Reference Degrees of Freedom",
            "Chi-Square",
            "P-Value",
            namew = smoothnamew
        )?;
        writeln!(f, "{:-<1$}", "", smoothnamew + 86)?;
        for term in &self.smooth_terms {
            let chisq = term
                .chi_sq
                .filter(|v| v.is_finite())
                .map(|v| format!("{v:.3}"))
                .unwrap_or_else(|| "NA".to_string());
            let p = format_pvalue(term.pvalue);
            let stars = significance_stars(term.pvalue);
            writeln!(
                f,
                "{:<namew$} {:>26.2} {:>30.2} {:>12} {:>10} {}",
                term.name,
                term.edf,
                term.ref_df,
                chisq,
                p,
                stars,
                namew = smoothnamew
            )?;
        }
        writeln!(f)?;
        // --- Optional continuous smoothness-order table (only when at
        // least one smooth term carries diagnostics) ---
        let order_terms = self
            .smooth_terms
            .iter()
            .filter_map(|t| t.continuous_order.as_ref().map(|o| (&t.name, o)))
            .collect::<Vec<_>>();
        if !order_terms.is_empty() {
            writeln!(f, "Continuous Smoothness Order:")?;
            writeln!(
                f,
                "{:<namew$} {:>10} {:>10} {:>10} {:>10} {:>10} {:>10} {:>20}",
                "Term",
                "lambda0",
                "lambda1",
                "lambda2",
                "R",
                "nu",
                "kappa^2",
                "status",
                namew = smoothnamew
            )?;
            for (name, o) in order_terms {
                let r_txt = o
                    .r_ratio
                    .filter(|v| v.is_finite())
                    .map(|v| format!("{v:.4}"))
                    .unwrap_or_else(|| "NA".to_string());
                let nu_txt =
                    o.nu.filter(|v| v.is_finite())
                        .map(|v| format!("{v:.4}"))
                        .unwrap_or_else(|| "NA".to_string());
                let kappa_txt = o
                    .kappa2
                    .filter(|v| v.is_finite())
                    .map(|v| format!("{v:.4}"))
                    .unwrap_or_else(|| "NA".to_string());
                let status_txt = match o.status {
                    ContinuousSmoothnessOrderStatus::Ok => "Ok",
                    ContinuousSmoothnessOrderStatus::NonMaternRegime => "NonMaternRegime",
                    ContinuousSmoothnessOrderStatus::FirstOrderLimit => "FirstOrderLimit",
                    ContinuousSmoothnessOrderStatus::IntrinsicLimit => "IntrinsicLimit",
                    ContinuousSmoothnessOrderStatus::UndefinedZeroLambda => "UndefinedZeroLambda",
                };
                writeln!(
                    f,
                    "{:<namew$} {:>10.3e} {:>10.3e} {:>10.3e} {:>10} {:>10} {:>10} {:>20}",
                    name,
                    o.lambda0,
                    o.lambda1,
                    o.lambda2,
                    r_txt,
                    nu_txt,
                    kappa_txt,
                    status_txt,
                    namew = smoothnamew
                )?;
            }
            writeln!(f)?;
        }
        // Final line intentionally uses write! (no trailing newline).
        write!(
            f,
            "Signif. codes: 0 '***' 0.001 '**' 0.01 '*' 0.05 '.' 0.1 ' ' 1"
        )?;
        Ok(())
    }
}
pub use crate::inference::predict::{
CoefficientUncertaintyResult, InferenceCovarianceMode, MeanIntervalMethod, PredictInput,
PredictPosteriorMeanResult, PredictResult, PredictUncertaintyOptions, PredictUncertaintyResult,
PredictableModel, coefficient_uncertainty, coefficient_uncertaintywith_mode,
enrich_posterior_mean_bounds, predict_gam, predict_gam_posterior_mean,
predict_gam_posterior_meanwith_backend, predict_gam_posterior_meanwith_fit,
predict_gamwith_uncertainty,
};
/// Fits a GAM on an external design, optionally seeding the smoothing
/// parameters from `heuristic_lambdas`, without any coefficient warm start.
///
/// Thin wrapper over `fit_gamwith_heuristic_lambdas_andwarm_start` with the
/// warm-start coefficients left unset.
pub fn fit_gamwith_heuristic_lambdas<X>(
    x: X,
    y: ArrayView1<'_, f64>,
    weights: ArrayView1<'_, f64>,
    offset: ArrayView1<'_, f64>,
    s_list: &[BlockwisePenalty],
    heuristic_lambdas: Option<&[f64]>,
    family: crate::types::LikelihoodFamily,
    opts: &FitOptions,
) -> Result<UnifiedFitResult, EstimationError>
where
    X: Into<DesignMatrix>,
{
    // No warm-start coefficients for this entry point.
    let warm_start_beta: Option<ArrayView1<'_, f64>> = None;
    fit_gamwith_heuristic_lambdas_andwarm_start(
        x,
        y,
        weights,
        offset,
        s_list,
        heuristic_lambdas,
        warm_start_beta,
        family,
        opts,
    )
}
/// Fits a GAM on an external design matrix, optionally seeding the smoothing
/// parameters from `heuristic_lambdas` and the coefficients from
/// `warm_start_beta`.
///
/// Validates the family/link-option combination, resolves the effective
/// likelihood family, runs the external-design optimizer, and repackages the
/// optimizer output into a `UnifiedFitResult`.
pub(crate) fn fit_gamwith_heuristic_lambdas_andwarm_start<X>(
    x: X,
    y: ArrayView1<'_, f64>,
    weights: ArrayView1<'_, f64>,
    offset: ArrayView1<'_, f64>,
    s_list: &[BlockwisePenalty],
    heuristic_lambdas: Option<&[f64]>,
    warm_start_beta: Option<ArrayView1<'_, f64>>,
    family: crate::types::LikelihoodFamily,
    opts: &FitOptions,
) -> Result<UnifiedFitResult, EstimationError>
where
    X: Into<DesignMatrix>,
{
    let specs: Vec<PenaltySpec> = s_list.iter().map(PenaltySpec::from_blockwise_ref).collect();
    let x = x.into();
    // BinomialMixture has no default link parameters; an explicit spec is required.
    if matches!(family, crate::types::LikelihoodFamily::BinomialMixture)
        && opts.mixture_link.is_none()
    {
        return Err(EstimationError::InvalidInput(
            "BinomialMixture requires mixture_link specification".to_string(),
        ));
    }
    let effective_sas_link = effective_sas_link_for_family(family, opts.sas_link);
    // mixture_link and sas_link are mutually exclusive link customizations.
    if opts.mixture_link.is_some() && opts.sas_link.is_some() {
        return Err(EstimationError::InvalidInput(
            "mixture_link and sas_link cannot both be set".to_string(),
        ));
    }
    // Map the requested family plus link options onto the family actually
    // fitted: mixture_link forces BinomialMixture, sas_link forces
    // BinomialSas (except BetaLogistic, which keeps its own family).
    let resolved_family = if opts.mixture_link.is_some() {
        match family {
            crate::types::LikelihoodFamily::BinomialLogit
            | crate::types::LikelihoodFamily::BinomialProbit
            | crate::types::LikelihoodFamily::BinomialCLogLog
            | crate::types::LikelihoodFamily::BinomialMixture => {
                crate::types::LikelihoodFamily::BinomialMixture
            }
            _ => {
                return Err(EstimationError::InvalidInput(
                    "mixture_link is only supported for binomial families".to_string(),
                ));
            }
        }
    } else if effective_sas_link.is_some() {
        match family {
            crate::types::LikelihoodFamily::BinomialLogit
            | crate::types::LikelihoodFamily::BinomialProbit
            | crate::types::LikelihoodFamily::BinomialCLogLog
            | crate::types::LikelihoodFamily::BinomialSas
            | crate::types::LikelihoodFamily::BinomialBetaLogistic => {
                if matches!(family, crate::types::LikelihoodFamily::BinomialBetaLogistic) {
                    crate::types::LikelihoodFamily::BinomialBetaLogistic
                } else {
                    crate::types::LikelihoodFamily::BinomialSas
                }
            }
            _ => {
                return Err(EstimationError::InvalidInput(
                    "sas_link is only supported for binomial families".to_string(),
                ));
            }
        }
    } else {
        family
    };
    // Survival models go through dedicated training APIs, not this path.
    if matches!(
        resolved_family,
        crate::types::LikelihoodFamily::RoystonParmar
    ) {
        return Err(EstimationError::InvalidInput(
            "fit_gam external design path does not support RoystonParmar; use survival training APIs".to_string(),
        ));
    }
    validate_penalty_specs(&specs, x.ncols(), "fit_gam")?;
    let ext_opts = ExternalOptimOptions {
        family: resolved_family,
        latent_cloglog: opts.latent_cloglog,
        mixture_link: opts.mixture_link.clone(),
        optimize_mixture: opts.optimize_mixture,
        sas_link: effective_sas_link,
        optimize_sas: opts.optimize_sas,
        compute_inference: opts.compute_inference,
        max_iter: opts.max_iter,
        tol: opts.tol,
        nullspace_dims: opts.nullspace_dims.clone(),
        linear_constraints: opts.linear_constraints.clone(),
        firth_bias_reduction: Some(opts.firth_bias_reduction),
        penalty_shrinkage_floor: opts.penalty_shrinkage_floor,
        rho_prior: Default::default(),
        kronecker_penalty_system: opts.kronecker_penalty_system.clone(),
        kronecker_factored: opts.kronecker_factored.clone(),
    };
    let result = optimize_external_designwith_heuristic_lambdas_andwarm_start(
        y,
        weights,
        &x,
        offset,
        specs.clone(),
        heuristic_lambdas,
        warm_start_beta,
        &ext_opts,
    )?;
    // Floor before ln() guards against zero lambdas producing -inf.
    let log_lambdas = result.lambdas.mapv(|v| v.max(1e-300).ln());
    let edf = result
        .inference
        .as_ref()
        .map(|inf| inf.edf_total)
        .unwrap_or(0.0);
    let geometry = result.inference.as_ref().map(|inf| FitGeometry {
        penalized_hessian: inf.penalized_hessian.clone(),
        working_weights: inf.working_weights.clone(),
        working_response: inf.working_response.clone(),
    });
    let covariance_conditional = result
        .inference
        .as_ref()
        .and_then(|inf| inf.beta_covariance.clone());
    let covariance_corrected = result
        .inference
        .as_ref()
        .and_then(|inf| inf.beta_covariance_corrected.clone());
    let penalized_objective = result.reml_score;
    // Package the optimizer output into the unified result; the whole fit
    // is reported as a single mean block.
    UnifiedFitResult::try_from_parts(UnifiedFitResultParts {
        blocks: vec![FittedBlock {
            beta: result.beta.clone(),
            role: BlockRole::Mean,
            edf,
            lambdas: result.lambdas.clone(),
        }],
        log_lambdas,
        lambdas: result.lambdas,
        likelihood_family: Some(result.likelihood_family),
        likelihood_scale: result.likelihood_scale,
        log_likelihood_normalization: result.log_likelihood_normalization,
        log_likelihood: result.log_likelihood,
        deviance: result.deviance,
        reml_score: result.reml_score,
        stable_penalty_term: result.stable_penalty_term,
        penalized_objective,
        outer_iterations: result.iterations,
        // NOTE(review): convergence is hardcoded here — presumably the
        // external optimizer returns Err instead of an unconverged result;
        // confirm.
        outer_converged: true,
        outer_gradient_norm: result.finalgrad_norm,
        standard_deviation: result.standard_deviation,
        covariance_conditional,
        covariance_corrected,
        inference: result.inference,
        fitted_link: result.fitted_link,
        geometry,
        block_states: Vec::new(),
        pirls_status: result.pirls_status,
        max_abs_eta: result.max_abs_eta,
        constraint_kkt: result.constraint_kkt,
        artifacts: result.artifacts,
        inner_cycles: 0,
    })
}
/// Fits a GAM on an external design matrix with the default smoothing
/// parameter initialization (no heuristic lambdas, no warm start).
pub fn fit_gam<X>(
    x: X,
    y: ArrayView1<'_, f64>,
    weights: ArrayView1<'_, f64>,
    offset: ArrayView1<'_, f64>,
    s_list: &[BlockwisePenalty],
    family: crate::types::LikelihoodFamily,
    opts: &FitOptions,
) -> Result<UnifiedFitResult, EstimationError>
where
    X: Into<DesignMatrix>,
{
    // No heuristic smoothing parameters for the plain entry point.
    let no_heuristics: Option<&[f64]> = None;
    fit_gamwith_heuristic_lambdas(x, y, weights, offset, s_list, no_heuristics, family, opts)
}
/// Fixed ridge weight for the SAS `log_delta` hyperparameter.
// NOTE(review): consumers are outside this view; confirm how the weight is
// applied before changing the value.
#[inline]
fn sas_log_deltaridgeweight() -> f64 {
    1e-4
}
/// Weight of the edge barrier that keeps the raw SAS `log_delta` away from
/// the boundary of its tanh-squashed range (see
/// `sas_log_delta_edge_barriercostgrad`). A value <= 0 disables the barrier.
#[inline]
fn sas_log_delta_edge_barrierweight() -> f64 {
    1e-2
}
/// Scale of the tanh squashing used by the SAS `log_delta` edge barrier.
#[inline]
fn sas_log_delta_bound() -> f64 {
    12.0
}
/// Barrier cost and gradient discouraging the raw SAS `log_delta` from
/// approaching the edges of its tanh-squashed range.
///
/// With `t = tanh(raw / bound)` the cost is `-w * ln(1 - t^2)`, which grows
/// without bound as `|raw|` increases, and the gradient is `(2w/bound) * t`.
/// Returns `(0, 0)` when the barrier is disabled or the input is non-finite.
#[inline]
fn sas_log_delta_edge_barriercostgrad(raw_log_delta: f64) -> (f64, f64) {
    let weight = sas_log_delta_edge_barrierweight();
    if weight <= 0.0 || !raw_log_delta.is_finite() {
        return (0.0, 0.0);
    }
    let bound = sas_log_delta_bound().max(f64::EPSILON);
    let squashed = (raw_log_delta / bound).tanh();
    // Clamp keeps the logarithm's argument positive at the numerical edge.
    let margin = (1.0 - squashed * squashed).max(1e-12);
    let cost = -weight * margin.ln();
    let grad = (2.0 * weight / bound) * squashed;
    (cost, grad)
}
/// Bound for the effective SAS epsilon: the raw value is squashed into
/// `(-bound, bound)` via tanh (see `sas_effective_epsilon`).
#[inline]
fn sas_epsilon_bound() -> f64 {
    8.0
}
/// Maps the unconstrained raw SAS epsilon onto the open interval
/// `(-bound, bound)` via `bound * tanh(raw / bound)`.
///
/// Returns the effective epsilon and its first derivative with respect to
/// the raw parameter, `1 - tanh^2(raw / bound)`, for chain-rule use.
#[inline]
fn sas_effective_epsilon(raw_epsilon: f64) -> (f64, f64) {
    let bound = sas_epsilon_bound().max(f64::EPSILON);
    let squashed = (raw_epsilon / bound).tanh();
    (bound * squashed, 1.0 - squashed * squashed)
}
/// As [`sas_effective_epsilon`], additionally returning the second
/// derivative of the squashed epsilon, `-2 * t * (1 - t^2) / bound` with
/// `t = tanh(raw / bound)`.
#[inline]
fn sas_effective_epsilon_second(raw_epsilon: f64) -> (f64, f64, f64) {
    let bound = sas_epsilon_bound().max(f64::EPSILON);
    let squashed = (raw_epsilon / bound).tanh();
    let first = 1.0 - squashed * squashed;
    let second = -2.0 * squashed * first / bound;
    (bound * squashed, first, second)
}
/// Same barrier as [`sas_log_delta_edge_barriercostgrad`], additionally
/// returning the analytic second derivative `(2w/bound^2) * (1 - t^2)`.
#[inline]
fn sas_log_delta_edge_barriercostgradhess(raw_log_delta: f64) -> (f64, f64, f64) {
    let weight = sas_log_delta_edge_barrierweight();
    if weight <= 0.0 || !raw_log_delta.is_finite() {
        return (0.0, 0.0, 0.0);
    }
    let bound = sas_log_delta_bound().max(f64::EPSILON);
    let squashed = (raw_log_delta / bound).tanh();
    // Clamp keeps the logarithm's argument positive at the numerical edge.
    let margin = (1.0 - squashed * squashed).max(1e-12);
    let cost = -weight * margin.ln();
    let grad = (2.0 * weight / bound) * squashed;
    let hess = (2.0 * weight / (bound * bound)) * margin;
    (cost, grad, hess)
}
/// Materializes the analytic link-hyperparameter Hessian as a dense matrix
/// and validates that it is `theta_dim x theta_dim`.
///
/// Errors when materialization fails, when no analytic Hessian is available,
/// or when the materialized shape does not match `theta_dim`.
fn materialize_link_outer_hessian(
    hessian: crate::solver::outer_strategy::HessianResult,
    theta_dim: usize,
) -> Result<Array2<f64>, EstimationError> {
    let dense = hessian.materialize_dense().map_err(|err| {
        EstimationError::InvalidInput(format!(
            "failed to materialize analytic link Hessian: {err}"
        ))
    })?;
    let h = dense.ok_or_else(|| {
        EstimationError::InvalidInput(
            "unified evaluator returned no analytic Hessian in ValueGradientHessian mode"
                .to_string(),
        )
    })?;
    if h.nrows() == theta_dim && h.ncols() == theta_dim {
        Ok(h)
    } else {
        Err(EstimationError::InvalidInput(format!(
            "unified evaluator Hessian shape {}x{} != theta_dim {}",
            h.nrows(),
            h.ncols(),
            theta_dim
        )))
    }
}
/// Evaluates the outer-criterion gradient with respect to `rho` (log
/// smoothing parameters) for an external design, without running the outer
/// optimization loop.
///
/// Builds the same `RemlState` the fitting path would use (validation,
/// penalty canonicalization, column conditioning, link states) and calls
/// `compute_gradient` once at the supplied `rho`.
pub fn evaluate_externalgradient<X>(
    y: ArrayView1<'_, f64>,
    w: ArrayView1<'_, f64>,
    x: X,
    offset: ArrayView1<'_, f64>,
    s_list: &[BlockwisePenalty],
    opts: &ExternalOptimOptions,
    rho: &Array1<f64>,
) -> Result<Array1<f64>, EstimationError>
where
    X: Into<DesignMatrix>,
{
    let specs: Vec<PenaltySpec> = s_list.iter().map(PenaltySpec::from_blockwise_ref).collect();
    let x = x.into();
    // Row-count consistency across y, w, X, and offset.
    if let Some(message) = row_mismatch_message(y.len(), w.len(), x.nrows(), offset.len()) {
        return Err(EstimationError::InvalidInput(message));
    }
    let p = x.ncols();
    validate_penalty_specs(&specs, p, "evaluate_externalgradient")?;
    let (canonical, active_nullspace_dims) = crate::construction::canonicalize_penalty_specs(
        &specs,
        &opts.nullspace_dims,
        p,
        "evaluate_externalgradient",
    )?;
    // One rho entry per active penalty after canonicalization.
    if rho.len() != active_nullspace_dims.len() {
        return Err(EstimationError::InvalidInput(format!(
            "rho dimension mismatch: rho_dim={}, active_penalties={}",
            rho.len(),
            active_nullspace_dims.len()
        )));
    }
    let (cfg, _) = resolved_external_config(opts)?;
    let y_o = y.to_owned();
    let w_o = w.to_owned();
    let offset_o = offset.to_owned();
    // Conditioning is applied to both the design and any linear constraints
    // so both are expressed in the same internal basis.
    let conditioning = ParametricColumnConditioning::infer_from_penalty_specs(&x, &specs);
    let x_fit = conditioning.apply_to_design(&x);
    let fit_linear_constraints =
        conditioning.transform_linear_constraints_to_internal(opts.linear_constraints.clone());
    let mut reml_state = RemlState::newwith_offset(
        y_o.view(),
        x_fit,
        w_o.view(),
        offset_o.view(),
        canonical,
        p,
        &cfg,
        Some(active_nullspace_dims),
        None,
        fit_linear_constraints,
    )?;
    reml_state.set_penalty_shrinkage_floor(opts.penalty_shrinkage_floor);
    reml_state.set_rho_prior(opts.rho_prior.clone());
    reml_state.set_link_states(
        cfg.link_kind.mixture_state().cloned(),
        cfg.link_kind.sas_state().copied(),
    );
    reml_state.compute_gradient(rho)
}
/// Evaluates the outer cost at `rho` for an external design and reports the
/// ridge magnitude used by the last inner solve (0.0 if none was recorded).
///
/// Mirrors `evaluate_externalgradient` exactly in its setup (validation,
/// penalty canonicalization, column conditioning, link states); only the
/// final call differs (`compute_cost` + `last_ridge_used`).
// NOTE(review): the setup duplicates evaluate_externalgradient — a shared
// RemlState builder would remove the duplication.
pub fn evaluate_externalcost_andridge<X>(
    y: ArrayView1<'_, f64>,
    w: ArrayView1<'_, f64>,
    x: X,
    offset: ArrayView1<'_, f64>,
    s_list: &[BlockwisePenalty],
    opts: &ExternalOptimOptions,
    rho: &Array1<f64>,
) -> Result<(f64, f64), EstimationError>
where
    X: Into<DesignMatrix>,
{
    let specs: Vec<PenaltySpec> = s_list.iter().map(PenaltySpec::from_blockwise_ref).collect();
    let x = x.into();
    // Row-count consistency across y, w, X, and offset.
    if let Some(message) = row_mismatch_message(y.len(), w.len(), x.nrows(), offset.len()) {
        return Err(EstimationError::InvalidInput(message));
    }
    let p = x.ncols();
    validate_penalty_specs(&specs, p, "evaluate_externalcost_andridge")?;
    let (canonical, active_nullspace_dims) = crate::construction::canonicalize_penalty_specs(
        &specs,
        &opts.nullspace_dims,
        p,
        "evaluate_externalcost_andridge",
    )?;
    // One rho entry per active penalty after canonicalization.
    if rho.len() != active_nullspace_dims.len() {
        return Err(EstimationError::InvalidInput(format!(
            "rho dimension mismatch: rho_dim={}, active_penalties={}",
            rho.len(),
            active_nullspace_dims.len()
        )));
    }
    let (cfg, _) = resolved_external_config(opts)?;
    let y_o = y.to_owned();
    let w_o = w.to_owned();
    let offset_o = offset.to_owned();
    let conditioning = ParametricColumnConditioning::infer_from_penalty_specs(&x, &specs);
    let x_fit = conditioning.apply_to_design(&x);
    let fit_linear_constraints =
        conditioning.transform_linear_constraints_to_internal(opts.linear_constraints.clone());
    let mut reml_state = RemlState::newwith_offset(
        y_o.view(),
        x_fit,
        w_o.view(),
        offset_o.view(),
        canonical,
        p,
        &cfg,
        Some(active_nullspace_dims),
        None,
        fit_linear_constraints,
    )?;
    reml_state.set_penalty_shrinkage_floor(opts.penalty_shrinkage_floor);
    reml_state.set_rho_prior(opts.rho_prior.clone());
    reml_state.set_link_states(
        cfg.link_kind.mixture_state().cloned(),
        cfg.link_kind.sas_state().copied(),
    );
    let cost = reml_state.compute_cost(rho)?;
    let ridge = reml_state.last_ridge_used().unwrap_or(0.0);
    Ok((cost, ridge))
}
#[cfg(test)]
mod estimate_policy_tests {
use super::reml::hyper::link_binomial_aux;
use super::*;
use crate::linalg::utils::{StableSolver, max_abs_diag};
use crate::mixture_link::{sas_inverse_link_jet, sas_inverse_link_jetwith_param_partials};
use crate::types::LikelihoodFamily;
use ndarray::{Array1, Array2, array};
use rand::rngs::StdRng;
use rand::{RngExt, SeedableRng};
#[test]
fn sas_raw_epsilon_hessian_chain_rule_matches_chained_gradient_slope() {
    // Chain rule for the raw-epsilon reparameterization:
    //   d2f/draw2 = H_eff[0,0] * (d eps/d raw)^2 + g_eff[0] * d2 eps/d raw2.
    // The analytic value is checked against a central finite difference of
    // the chained gradient on a synthetic gradient/Hessian pair.
    let raw0 = 1.3_f64;
    let (eps0, d1, d2) = sas_effective_epsilon_second(raw0);
    let g0 = array![0.4, -0.7, 0.2];
    let h_eff = array![[2.0, 0.3, -0.1], [0.3, 1.5, 0.25], [-0.1, 0.25, 0.8]];
    let analytic = h_eff[[0, 0]] * d1 * d1 + g0[0] * d2;
    // First-order model of the epsilon-direction gradient around eps0,
    // mapped back to the raw parameter via d eps/d raw.
    let chained_grad = |raw: f64| {
        let (eps, deps_draw) = sas_effective_epsilon(raw);
        let delta = array![eps - eps0, 0.0, 0.0];
        let g_eff = &g0 + &h_eff.dot(&delta);
        g_eff[0] * deps_draw
    };
    let h = 1e-6;
    let fd = (chained_grad(raw0 + h) - chained_grad(raw0 - h)) / (2.0 * h);
    assert!(
        (analytic - fd).abs() < 2e-8,
        "SAS raw epsilon Hessian chain rule mismatch: analytic={analytic:.12e} fd={fd:.12e}"
    );
}
#[test]
fn sas_log_delta_barrier_hessian_matches_gradient_slope() {
    // The analytic Hessian of the log-delta barrier must agree with a
    // central finite difference of the barrier gradient.
    let raw = 2.25_f64;
    let h = 1e-6;
    let (_, _, analytic_hess) = sas_log_delta_edge_barriercostgradhess(raw);
    let (_, grad_hi) = sas_log_delta_edge_barriercostgrad(raw + h);
    let (_, grad_lo) = sas_log_delta_edge_barriercostgrad(raw - h);
    let fd = (grad_hi - grad_lo) / (2.0 * h);
    assert!(
        (analytic_hess - fd).abs() < 2e-9,
        "SAS log-delta barrier Hessian mismatch: analytic={analytic_hess:.12e} fd={fd:.12e}"
    );
}
/// Builds a small, fully-populated `UnifiedFitResult` fixture whose
/// cross-field invariants hold (beta matches the block, log-lambdas match
/// lambdas, geometry duplicates the inference Hessian/weights/response),
/// for use by the decode/validation tests below.
fn decode_invariant_test_fit() -> UnifiedFitResult {
    UnifiedFitResult::try_from_parts(UnifiedFitResultParts {
        blocks: vec![FittedBlock {
            beta: array![0.25, -0.5],
            role: BlockRole::Mean,
            edf: 1.5,
            lambdas: array![0.2, 0.8],
        }],
        // Mirrors the production mapping: ln(max(lambda, 1e-300)).
        log_lambdas: array![0.2_f64.max(1e-300).ln(), 0.8_f64.max(1e-300).ln()],
        lambdas: array![0.2, 0.8],
        likelihood_family: Some(LikelihoodFamily::GaussianIdentity),
        likelihood_scale: LikelihoodScaleMetadata::ProfiledGaussian,
        log_likelihood_normalization: LogLikelihoodNormalization::Full,
        log_likelihood: -1.2,
        deviance: 2.4,
        reml_score: 0.7,
        stable_penalty_term: 0.3,
        penalized_objective: 2.2,
        outer_iterations: 3,
        outer_converged: true,
        outer_gradient_norm: 0.05,
        standard_deviation: 1.1,
        covariance_conditional: Some(array![[1.0, 0.1], [0.1, 2.0]]),
        covariance_corrected: Some(array![[1.2, 0.1], [0.1, 2.2]]),
        inference: Some(FitInference {
            edf_by_block: vec![0.6, 0.9],
            edf_total: 1.5,
            smoothing_correction: Some(array![[0.2, 0.0], [0.0, 0.2]]),
            penalized_hessian: array![[2.0, 0.1], [0.1, 3.0]],
            working_weights: array![1.0, 0.5, 0.75],
            working_response: array![0.1, 0.2, 0.3],
            reparam_qs: Some(array![[1.0, 0.0], [0.0, 1.0]]),
            beta_covariance: Some(array![[1.0, 0.1], [0.1, 2.0]]),
            beta_standard_errors: Some(array![1.0, 2.0_f64.sqrt()]),
            beta_covariance_corrected: Some(array![[1.2, 0.1], [0.1, 2.2]]),
            beta_standard_errors_corrected: Some(array![1.2_f64.sqrt(), 2.2_f64.sqrt()]),
            bias_correction_beta: None,
        }),
        fitted_link: FittedLinkState::Standard(None),
        // Geometry intentionally duplicates the inference fields; the
        // validator checks they stay in sync after (de)serialization.
        geometry: Some(FitGeometry {
            penalized_hessian: array![[2.0, 0.1], [0.1, 3.0]],
            working_weights: array![1.0, 0.5, 0.75],
            working_response: array![0.1, 0.2, 0.3],
        }),
        block_states: Vec::new(),
        pirls_status: crate::pirls::PirlsStatus::Converged,
        max_abs_eta: 1.25,
        constraint_kkt: None,
        artifacts: FitArtifacts::default(),
        inner_cycles: 0,
    })
    .expect("construct decode invariant test fit")
}
#[test]
fn resolve_external_family_rejects_unsupported_firth_request() {
    // Requesting Firth bias reduction for Poisson must surface an explicit
    // error rather than being silently ignored.
    let outcome = resolve_external_family(LikelihoodFamily::PoissonLog, Some(true));
    let err = outcome
        .expect_err("Poisson fitting should reject unsupported Firth requests explicitly");
    let message = err.to_string();
    assert!(
        message.contains("firth_bias_reduction is currently implemented only for"),
        "unexpected error: {err}"
    );
}
#[test]
fn unified_fit_decode_validation_rejects_beta_drift_from_blocks() {
    // Corrupt the serialized top-level beta so it no longer matches the
    // coefficient blocks; post-decode validation must reject the fit.
    let fit = decode_invariant_test_fit();
    let mut payload = serde_json::to_value(&fit).expect("serialize fit");
    payload["beta"] = serde_json::to_value(&Array1::from(vec![9.0_f64, 8.0_f64]))
        .expect("serialize drifted beta");
    let decoded: UnifiedFitResult =
        serde_json::from_value(payload).expect("deserialize corrupted fit");
    let err = decoded
        .validate_numeric_finiteness()
        .expect_err("beta drift should fail validation");
    assert!(
        err.to_string()
            .contains("decoded beta must match coefficient blocks"),
        "unexpected error: {err}"
    );
}
#[test]
fn unified_fit_validation_rejects_edf_smoothing_parameter_drift() {
    // Shrinking edf_by_block to a single entry desynchronizes it from the
    // two smoothing parameters; validation must catch the count mismatch.
    let mut fit = decode_invariant_test_fit();
    fit.inference
        .as_mut()
        .expect("test fit has inference")
        .edf_by_block = vec![1.5];
    let err = fit
        .validate_numeric_finiteness()
        .expect_err("EDF entries should align with smoothing parameters");
    assert!(
        err.to_string()
            .contains("EDF smoothing-parameter count mismatch"),
        "unexpected error: {err}"
    );
}
#[test]
fn unified_fit_validation_accepts_persisted_log_lambda_roundoff() {
    // A sub-ulp perturbation of a persisted log-lambda (serialization
    // roundoff) must still pass validation.
    let mut fit = decode_invariant_test_fit();
    fit.log_lambdas[0] += 5e-14;
    fit.validate_numeric_finiteness()
        .expect("sub-ulp persisted log-lambda roundoff should remain valid");
}
#[test]
fn unified_fit_validation_rejects_material_log_lambda_drift() {
    // A materially large perturbation (1e-4, far above roundoff) must be
    // rejected as log-lambda / lambda inconsistency.
    let mut fit = decode_invariant_test_fit();
    fit.log_lambdas[0] += 1e-4;
    let err = fit
        .validate_numeric_finiteness()
        .expect_err("material log-lambda drift should fail validation");
    assert!(
        err.to_string().contains("log_lambdas must equal"),
        "unexpected error: {err}"
    );
}
#[test]
fn unified_fit_decode_validation_rejects_geometry_drift_from_inference() {
    // Corrupt the serialized geometry Hessian so it diverges from
    // inference.penalized_hessian; validation must reject the decoded fit.
    let fit = decode_invariant_test_fit();
    let mut payload = serde_json::to_value(&fit).expect("serialize fit");
    let drifted_hessian: Array2<f64> = array![[4.0, 0.0], [0.0, 5.0]];
    payload["geometry"]["penalized_hessian"] =
        serde_json::to_value(&drifted_hessian).expect("serialize drifted penalized Hessian");
    let decoded: UnifiedFitResult =
        serde_json::from_value(payload).expect("deserialize corrupted fit");
    let err = decoded
        .validate_numeric_finiteness()
        .expect_err("geometry drift should fail validation");
    assert!(
        err.to_string()
            .contains("geometry penalized Hessian must match inference.penalized_hessian"),
        "unexpected error: {err}"
    );
}
/// Builds an `n x 3` test design: intercept, a linear covariate sweeping
/// [-1.5, 1.5) at cell midpoints, and a sinusoidal transform of it.
fn build_tiny_design(n: usize) -> Array2<f64> {
    Array2::from_shape_fn((n, 3), |(i, j)| {
        let t = (i as f64 + 0.5) / n as f64;
        let x1 = -1.5 + 3.0 * t;
        match j {
            0 => 1.0,
            1 => x1,
            _ => (2.1 * x1).sin(),
        }
    })
}
/// Single identity penalty over all coefficients except the intercept
/// (column 0 stays unpenalized).
fn one_penalty_non_intercept(p: usize) -> Vec<Array2<f64>> {
    let penalty =
        Array2::from_shape_fn((p, p), |(r, c)| if r == c && r > 0 { 1.0 } else { 0.0 });
    vec![penalty]
}
/// Wraps dense penalty matrices in `PenaltySpec::Dense` and canonicalizes
/// them (assuming nullspace dimension 1 for each), returning the specs, the
/// canonical penalties, and the active nullspace dimensions for reuse in
/// the SAS sensitivity tests.
fn dense_penalty_test_inputs(
    s_list: &[Array2<f64>],
    p: usize,
    context: &str,
) -> (
    Vec<PenaltySpec>,
    Vec<crate::construction::CanonicalPenalty>,
    Vec<usize>,
) {
    let penalty_specs = s_list
        .iter()
        .cloned()
        .map(PenaltySpec::Dense)
        .collect::<Vec<_>>();
    let (canonical_penalties, active_nullspace_dims) =
        crate::construction::canonicalize_penalty_specs(
            &penalty_specs,
            &vec![1; penalty_specs.len()],
            p,
            context,
        )
        .expect("canonicalize dense penalties");
    (penalty_specs, canonical_penalties, active_nullspace_dims)
}
#[test]
fn sas_beta_raw_epsilon_sensitivity_matchesfd_at_seed19() {
    // End-to-end sensitivity check for the SAS link's raw-epsilon
    // parameter: (1) the per-observation score derivative d u / d eps at
    // fixed eta matches a finite difference, and (2) the implicit-function
    // solve for d beta / d raw_eps through the observed Jacobian matches a
    // finite difference of the converged PIRLS coefficients.
    let seed = 19_u64;
    let n = 20usize;
    let x = build_tiny_design(n);
    let w = Array1::<f64>::ones(n);
    let offset = Array1::<f64>::zeros(n);
    let s_list = one_penalty_non_intercept(x.ncols());
    // Synthetic Bernoulli data generated from a known SAS-link truth.
    let true_beta = array![-0.2, 0.9, -0.4];
    let eta_true = x.dot(&true_beta);
    let eps_true = 0.25;
    let ld_true = -0.20;
    let p = eta_true.mapv(|e| sas_inverse_link_jet(e, eps_true, ld_true).mu);
    let mut rng = StdRng::seed_from_u64(seed);
    let y = p.mapv(|pi| if rng.random::<f64>() < pi { 1.0 } else { 0.0 });
    let opts = ExternalOptimOptions {
        family: LikelihoodFamily::BinomialSas,
        latent_cloglog: None,
        mixture_link: None,
        optimize_mixture: false,
        sas_link: Some(SasLinkSpec {
            initial_epsilon: 0.0,
            initial_log_delta: 0.0,
        }),
        optimize_sas: true,
        compute_inference: true,
        max_iter: 80,
        tol: 1e-7,
        nullspace_dims: vec![1],
        linear_constraints: None,
        firth_bias_reduction: None,
        penalty_shrinkage_floor: None,
        rho_prior: Default::default(),
        kronecker_penalty_system: None,
        kronecker_factored: None,
    };
    // theta = [rho, raw_epsilon, log_delta].
    let theta = array![0.10, 0.12, -0.18];
    let (cfg, effective_sas_link) = resolved_external_config(&opts).expect("cfg");
    assert!(effective_sas_link.is_some());
    let (penalty_specs, canonical_penalties, active_nullspace_dims) = dense_penalty_test_inputs(
        &s_list,
        x.ncols(),
        "sas_beta_raw_epsilon_sensitivity_matchesfd_at_seed19",
    );
    let conditioning = ParametricColumnConditioning::infer_from_penalty_specs(
        &DesignMatrix::Dense(crate::matrix::DenseDesignMatrix::from(x.clone())),
        &penalty_specs,
    );
    let x_fit = conditioning.apply_to_design(&DesignMatrix::Dense(
        crate::matrix::DenseDesignMatrix::from(x.clone()),
    ));
    let mut reml_state = RemlState::newwith_offset(
        y.view(),
        x_fit,
        w.view(),
        offset.view(),
        canonical_penalties.clone(),
        x.ncols(),
        &cfg,
        Some(active_nullspace_dims.clone()),
        None,
        None,
    )
    .expect("reml_state");
    let rho = theta.slice(s![..1]).to_owned();
    let (epsilon_eff, d_eps_d_raw) = sas_effective_epsilon(theta[1]);
    let sas_state = state_from_sasspec(SasLinkSpec {
        initial_epsilon: epsilon_eff,
        initial_log_delta: theta[2],
    })
    .expect("sas state");
    reml_state.set_link_states(None, Some(sas_state));
    // Converged PIRLS solve at the fixed theta; eta and the transformed
    // design are the linearization point for all derivatives below.
    let pirls_result = reml_state
        .obtain_eval_bundle(&rho)
        .map(|b| b.pirls_result.clone())
        .expect("pirls_result");
    let eta = &pirls_result.final_eta;
    let x_t = &pirls_result.x_transformed;
    use rayon::iter::{IntoParallelIterator, ParallelIterator};
    // Analytic per-observation d u_i / d epsilon at fixed eta:
    // a2 * dmu/deps * d1 + a1 * d(d1)/deps.
    let du_vec: Vec<f64> = (0..eta.len())
        .into_par_iter()
        .map(|i| {
            let jets = sas_inverse_link_jetwith_param_partials(
                eta[i],
                sas_state.epsilon,
                sas_state.log_delta,
            );
            let mu = jets.jet.mu;
            let aux = link_binomial_aux(y[i], w[i].max(0.0), mu);
            let d1 = jets.jet.d1;
            let dmu = jets.djet_depsilon.mu;
            let dd1 = jets.djet_depsilon.d1;
            aux.a2 * dmu * d1 + aux.a1 * dd1
        })
        .collect();
    let du_by_eps = Array1::from_vec(du_vec);
    // Score vector u(raw_eps) at fixed eta, used for the finite-difference
    // comparison against du_by_eps.
    let score_at = |raw_eps: f64| -> Array1<f64> {
        let (eps_eff, _) = sas_effective_epsilon(raw_eps);
        let sas_state = state_from_sasspec(SasLinkSpec {
            initial_epsilon: eps_eff,
            initial_log_delta: theta[2],
        })
        .expect("score sas state");
        let out_vec: Vec<f64> = (0..eta.len())
            .into_par_iter()
            .map(|i| {
                let jets = sas_inverse_link_jetwith_param_partials(
                    eta[i],
                    sas_state.epsilon,
                    sas_state.log_delta,
                );
                let mu = jets.jet.mu;
                let d1 = jets.jet.d1;
                let aux = link_binomial_aux(y[i], w[i].max(0.0), mu);
                aux.a1 * d1
            })
            .collect();
        Array1::from_vec(out_vec)
    };
    let score_p = score_at(theta[1] + 1e-4 * (1.0 + theta[1].abs()));
    let score_m = score_at(theta[1] - 1e-4 * (1.0 + theta[1].abs()));
    let fd_du_raw = (&score_p - &score_m).mapv(|v| v / (2.0 * 1e-4 * (1.0 + theta[1].abs())));
    let du_raw = du_by_eps.mapv(|v| v * d_eps_d_raw);
    crate::testing::assert_matrix_derivativefd(
        &fd_du_raw.insert_axis(Axis(1)),
        &du_raw.insert_axis(Axis(1)),
        2e-3,
        "sas du / d raw epsilon at fixed eta",
    );
    // Implicit-function step: solve J * dbeta = X^T du with the observed
    // (negative) score-beta Jacobian J = X^T V X + S + ridge*I.
    let rhs = x_t.transpose_vector_multiply(&du_by_eps);
    let neg_du_deta_vec: Vec<f64> = (0..eta.len())
        .into_par_iter()
        .map(|i| {
            let jets = sas_inverse_link_jetwith_param_partials(
                eta[i].clamp(-30.0, 30.0),
                sas_state.epsilon,
                sas_state.log_delta,
            );
            let mu = jets.jet.mu;
            let d1 = jets.jet.d1;
            let d2 = jets.jet.d2;
            let aux = link_binomial_aux(y[i], w[i].max(0.0), mu);
            -(aux.a2 * d1 * d1 + aux.a1 * d2)
        })
        .collect();
    let neg_du_deta = Array1::from_vec(neg_du_deta_vec);
    let score_beta_jacobian = {
        let x_dense = x_t.to_dense();
        let diag_v = Array2::from_diag(&neg_du_deta);
        let mut j = x_dense.t().dot(&diag_v).dot(&x_dense);
        for ((r, c), v) in pirls_result.reparam_result.s_transformed.indexed_iter() {
            j[[r, c]] += v;
        }
        // Match the ridge actually used by the converged PIRLS solve.
        if pirls_result.ridge_used > 0.0 {
            for d in 0..j.nrows() {
                j[[d, d]] += pirls_result.ridge_used;
            }
        }
        j
    };
    let stable_solver = StableSolver::new("sas dbeta exact test");
    let mut dbeta_exact = stable_solver
        .solvevectorwithridge_retries(
            &score_beta_jacobian,
            &rhs,
            max_abs_diag(&score_beta_jacobian) * 1e-12,
        )
        .expect("observed-jacobian solve for dbeta");
    dbeta_exact *= d_eps_d_raw;
    // Finite-difference reference: re-fit PIRLS at perturbed raw epsilon
    // and difference the converged coefficients.
    let fd_h = 1e-4 * (1.0 + theta[1].abs());
    let beta_at = |raw_eps: f64| -> Array1<f64> {
        let mut state = RemlState::newwith_offset(
            y.view(),
            conditioning.apply_to_design(&DesignMatrix::Dense(
                crate::matrix::DenseDesignMatrix::from(x.clone()),
            )),
            w.view(),
            offset.view(),
            canonical_penalties.clone(),
            x.ncols(),
            &cfg,
            Some(active_nullspace_dims.clone()),
            None,
            None,
        )
        .expect("fd state");
        let (eps_eff, _) = sas_effective_epsilon(raw_eps);
        let sas_state = state_from_sasspec(SasLinkSpec {
            initial_epsilon: eps_eff,
            initial_log_delta: theta[2],
        })
        .expect("fd sas state");
        state.set_link_states(None, Some(sas_state));
        let pirls = state
            .obtain_eval_bundle(&rho)
            .map(|b| b.pirls_result.clone())
            .expect("fd pirls");
        pirls.beta_transformed.as_ref().clone()
    };
    let beta_p = beta_at(theta[1] + fd_h);
    let beta_m = beta_at(theta[1] - fd_h);
    let fd_beta = (&beta_p - &beta_m).mapv(|v| v / (2.0 * fd_h));
    crate::testing::assert_matrix_derivativefd(
        &fd_beta.insert_axis(Axis(1)),
        &dbeta_exact.insert_axis(Axis(1)),
        2e-3,
        "sas observed-jacobian dbeta / d raw epsilon",
    );
}
/// Verifies that the analytic Jacobian of the penalized SAS-binomial score
/// with respect to beta (Xᵀ diag(-du/dη) X + S, plus any PIRLS ridge) matches
/// a central finite-difference Jacobian of the penalized score at seed 19.
#[test]
fn sas_true_score_beta_jacobian_matchesfd_at_seed19() {
    let seed = 19_u64;
    let n = 20usize;
    let x = build_tiny_design(n);
    let w = Array1::<f64>::ones(n);
    let offset = Array1::<f64>::zeros(n);
    let s_list = one_penalty_non_intercept(x.ncols());
    // Simulate Bernoulli outcomes through a known SAS inverse link.
    let true_beta = array![-0.2, 0.9, -0.4];
    let eta_true = x.dot(&true_beta);
    let eps_true = 0.25;
    let ld_true = -0.20;
    let p = eta_true.mapv(|e| sas_inverse_link_jet(e, eps_true, ld_true).mu);
    let mut rng = StdRng::seed_from_u64(seed);
    let y = p.mapv(|pi| if rng.random::<f64>() < pi { 1.0 } else { 0.0 });
    let opts = ExternalOptimOptions {
        family: LikelihoodFamily::BinomialSas,
        latent_cloglog: None,
        mixture_link: None,
        optimize_mixture: false,
        sas_link: Some(SasLinkSpec {
            initial_epsilon: 0.0,
            initial_log_delta: 0.0,
        }),
        optimize_sas: true,
        compute_inference: true,
        max_iter: 80,
        tol: 1e-7,
        nullspace_dims: vec![1],
        linear_constraints: None,
        firth_bias_reduction: None,
        penalty_shrinkage_floor: None,
        rho_prior: Default::default(),
        kronecker_penalty_system: None,
        kronecker_factored: None,
    };
    // theta = [rho, raw epsilon, log delta].
    let theta = array![0.10, 0.12, -0.18];
    let (cfg, effective_sas_link) = resolved_external_config(&opts).expect("cfg");
    assert!(effective_sas_link.is_some());
    let (penalty_specs, canonical_penalties, active_nullspace_dims) = dense_penalty_test_inputs(
        &s_list,
        x.ncols(),
        "sas_true_score_beta_jacobian_matchesfd_at_seed19",
    );
    let conditioning = ParametricColumnConditioning::infer_from_penalty_specs(
        &DesignMatrix::Dense(crate::matrix::DenseDesignMatrix::from(x.clone())),
        &penalty_specs,
    );
    let x_fit = conditioning.apply_to_design(&DesignMatrix::Dense(
        crate::matrix::DenseDesignMatrix::from(x.clone()),
    ));
    let mut reml_state = RemlState::newwith_offset(
        y.view(),
        x_fit,
        w.view(),
        offset.view(),
        canonical_penalties,
        x.ncols(),
        &cfg,
        Some(active_nullspace_dims),
        None,
        None,
    )
    .expect("reml_state");
    let rho = theta.slice(s![..1]).to_owned();
    let (epsilon_eff, _) = sas_effective_epsilon(theta[1]);
    let sas_state = state_from_sasspec(SasLinkSpec {
        initial_epsilon: epsilon_eff,
        initial_log_delta: theta[2],
    })
    .expect("sas state");
    reml_state.set_link_states(None, Some(sas_state));
    let pirls_result = reml_state
        .obtain_eval_bundle(&rho)
        .map(|b| b.pirls_result.clone())
        .expect("pirls_result");
    let beta0 = pirls_result.beta_transformed.as_ref().clone();
    let s_transformed = pirls_result.reparam_result.s_transformed.clone();
    let ridge = pirls_result.ridge_used;
    let x_dense = match &pirls_result.x_transformed {
        DesignMatrix::Dense(x_dense) => x_dense.to_dense(),
        DesignMatrix::Sparse(_) => {
            panic!("expected dense transformed design in seed-19 SAS test")
        }
    };
    // Penalized negative score: -Xᵀu + S·beta (+ ridge·beta), with the same
    // eta clamp the production code applies.
    let gradient_at = |beta: &Array1<f64>| -> Array1<f64> {
        let mut eta = offset.clone();
        eta += &x_dense.dot(beta);
        let mut u = Array1::<f64>::zeros(eta.len());
        for i in 0..eta.len() {
            let jets = sas_inverse_link_jetwith_param_partials(
                eta[i].clamp(-30.0, 30.0),
                sas_state.epsilon,
                sas_state.log_delta,
            );
            let mu = jets.jet.mu;
            let d1 = jets.jet.d1;
            let aux = link_binomial_aux(y[i], w[i].max(0.0), mu);
            u[i] = aux.a1 * d1;
        }
        let mut g = -x_dense.t().dot(&u);
        g += &s_transformed.dot(beta);
        if ridge > 0.0 {
            g += &beta.mapv(|v| ridge * v);
        }
        g
    };
    // Observed curvature -du/dη at the PIRLS solution.
    let mut eta0 = offset.clone();
    eta0 += &x_dense.dot(&beta0);
    let mut neg_du_deta = Array1::<f64>::zeros(eta0.len());
    for i in 0..eta0.len() {
        let jets = sas_inverse_link_jetwith_param_partials(
            eta0[i].clamp(-30.0, 30.0),
            sas_state.epsilon,
            sas_state.log_delta,
        );
        let mu = jets.jet.mu;
        let d1 = jets.jet.d1;
        let d2 = jets.jet.d2;
        let aux = link_binomial_aux(y[i], w[i].max(0.0), mu);
        neg_du_deta[i] = -(aux.a2 * d1 * d1 + aux.a1 * d2);
    }
    // Build the analytic Jacobian directly from the weighted cross-product
    // (no intermediate zero matrix + assign), matching the sibling seed-19
    // Hessian test.
    let weighted_x = &x_dense * &neg_du_deta.insert_axis(Axis(1));
    let mut analytic_j = x_dense.t().dot(&weighted_x);
    analytic_j += &s_transformed;
    if ridge > 0.0 {
        for j in 0..analytic_j.nrows() {
            analytic_j[[j, j]] += ridge;
        }
    }
    // Central finite differences of the score, column by column.
    let mut fd_j = Array2::<f64>::zeros((beta0.len(), beta0.len()));
    for j in 0..beta0.len() {
        let h = 1e-5 * (1.0 + beta0[j].abs());
        let mut beta_p = beta0.clone();
        let mut beta_m = beta0.clone();
        beta_p[j] += h;
        beta_m[j] -= h;
        let g_p = gradient_at(&beta_p);
        let g_m = gradient_at(&beta_m);
        let fd_col = (&g_p - &g_m).mapv(|v| v / (2.0 * h));
        fd_j.column_mut(j).assign(&fd_col);
    }
    crate::testing::assert_matrix_derivativefd(
        &fd_j,
        &analytic_j,
        2e-3,
        "sas true beta-score jacobian at seed-19",
    );
}
/// The penalized Hessian reported by PIRLS must agree, to FD-level tolerance,
/// with the true observed score Jacobian of the SAS-binomial model at seed 19.
#[test]
fn sas_pirlshessian_matches_true_score_jacobian_at_seed19() {
    let seed = 19_u64;
    let n = 20usize;
    let x = build_tiny_design(n);
    let w = Array1::<f64>::ones(n);
    let offset = Array1::<f64>::zeros(n);
    let s_list = one_penalty_non_intercept(x.ncols());
    // Simulate Bernoulli responses through a known SAS inverse link.
    let beta_truth = array![-0.2, 0.9, -0.4];
    let lin_pred = x.dot(&beta_truth);
    let eps_truth = 0.25;
    let ld_truth = -0.20;
    let probs = lin_pred.mapv(|e| sas_inverse_link_jet(e, eps_truth, ld_truth).mu);
    let mut rng = StdRng::seed_from_u64(seed);
    let y = probs.mapv(|pi| if rng.random::<f64>() < pi { 1.0 } else { 0.0 });
    let opts = ExternalOptimOptions {
        family: LikelihoodFamily::BinomialSas,
        latent_cloglog: None,
        mixture_link: None,
        optimize_mixture: false,
        sas_link: Some(SasLinkSpec {
            initial_epsilon: 0.0,
            initial_log_delta: 0.0,
        }),
        optimize_sas: true,
        compute_inference: true,
        max_iter: 80,
        tol: 1e-7,
        nullspace_dims: vec![1],
        linear_constraints: None,
        firth_bias_reduction: None,
        penalty_shrinkage_floor: None,
        rho_prior: Default::default(),
        kronecker_penalty_system: None,
        kronecker_factored: None,
    };
    // theta = [rho, raw epsilon, log delta].
    let theta = array![0.10, 0.12, -0.18];
    let (cfg, effective_sas_link) = resolved_external_config(&opts).expect("cfg");
    assert!(effective_sas_link.is_some());
    let (penalty_specs, canonical_penalties, active_nullspace_dims) = dense_penalty_test_inputs(
        &s_list,
        x.ncols(),
        "sas_pirlshessian_matches_true_score_jacobian_at_seed19",
    );
    let conditioning = ParametricColumnConditioning::infer_from_penalty_specs(
        &DesignMatrix::Dense(crate::matrix::DenseDesignMatrix::from(x.clone())),
        &penalty_specs,
    );
    let x_fit = conditioning.apply_to_design(&DesignMatrix::Dense(
        crate::matrix::DenseDesignMatrix::from(x.clone()),
    ));
    let mut reml_state = RemlState::newwith_offset(
        y.view(),
        x_fit,
        w.view(),
        offset.view(),
        canonical_penalties,
        x.ncols(),
        &cfg,
        Some(active_nullspace_dims),
        None,
        None,
    )
    .expect("reml_state");
    let rho = theta.slice(s![..1]).to_owned();
    let (epsilon_eff, _) = sas_effective_epsilon(theta[1]);
    let sas_state = state_from_sasspec(SasLinkSpec {
        initial_epsilon: epsilon_eff,
        initial_log_delta: theta[2],
    })
    .expect("sas state");
    reml_state.set_link_states(None, Some(sas_state));
    let pirls = reml_state
        .obtain_eval_bundle(&rho)
        .map(|b| b.pirls_result.clone())
        .expect("pirls_result");
    let beta_hat = pirls.beta_transformed.as_ref().clone();
    let s_transformed = pirls.reparam_result.s_transformed.clone();
    let ridge = pirls.ridge_used;
    let x_dense = match &pirls.x_transformed {
        DesignMatrix::Dense(dense) => dense.to_dense(),
        DesignMatrix::Sparse(_) => {
            panic!("expected dense transformed design in seed-19 SAS test")
        }
    };
    // Per-observation curvature -du/dη at the PIRLS solution, with the same
    // eta clamp the production code applies.
    let eta_hat = &offset + &x_dense.dot(&beta_hat);
    let curvature_vec: Vec<f64> = (0..eta_hat.len())
        .map(|i| {
            let jets = sas_inverse_link_jetwith_param_partials(
                eta_hat[i].clamp(-30.0, 30.0),
                sas_state.epsilon,
                sas_state.log_delta,
            );
            let aux = link_binomial_aux(y[i], w[i].max(0.0), jets.jet.mu);
            -(aux.a2 * jets.jet.d1 * jets.jet.d1 + aux.a1 * jets.jet.d2)
        })
        .collect();
    let curvature = Array1::from_vec(curvature_vec);
    // True Jacobian: Xᵀ diag(curvature) X + S (+ ridge on the diagonal).
    let weighted_x = &x_dense * &curvature.insert_axis(Axis(1));
    let mut true_jacobian = x_dense.t().dot(&weighted_x);
    true_jacobian += &s_transformed;
    if ridge > 0.0 {
        for d in 0..true_jacobian.nrows() {
            true_jacobian[[d, d]] += ridge;
        }
    }
    // Elementwise comparison against the Hessian PIRLS reported.
    let pht_dense = pirls.penalized_hessian_transformed.to_dense();
    let max_abs_diff = true_jacobian
        .iter()
        .zip(pht_dense.iter())
        .map(|(a, b)| (a - b).abs())
        .fold(0.0_f64, f64::max);
    assert!(
        max_abs_diff <= 2e-3,
        "expected PIRLS Hessian to match the true SAS score Jacobian, got max_abs_diff={max_abs_diff:.3e}"
    );
}
/// Even at saturated linear predictors (η = ±30 with a large log-delta), the
/// binomial auxiliary quantities derived from the SAS mean must stay finite,
/// with a strictly positive variance.
#[test]
fn link_binomial_aux_stay_finite_for_saturated_sas_probabilities() {
    for (yi, eta) in [(0.0, -30.0), (1.0, 30.0)] {
        let mu = sas_inverse_link_jetwith_param_partials(eta, 0.0, 12.0)
            .jet
            .mu;
        let aux = link_binomial_aux(yi, 1.0, mu);
        assert!(aux.a1.is_finite(), "a1 must be finite for yi={yi} mu={mu}");
        assert!(aux.a2.is_finite(), "a2 must be finite for yi={yi} mu={mu}");
        assert!(
            aux.variance.is_finite() && aux.variance > 0.0,
            "variance must be finite and positive for yi={yi} mu={mu}"
        );
    }
}
}
#[cfg(test)]
mod continuous_order_tests {
    use super::*;
    /// Closed-form R, nu, and kappa^2 at a hand-checked input.
    #[test]
    fn continuous_order_formula_matches_closed_form() {
        let result =
            compute_continuous_smoothness_order([2.0, 10.0, 3.0], [1.0, 1.0, 1.0], 1e-12);
        assert_eq!(result.status, ContinuousSmoothnessOrderStatus::Ok);
        let ratio = result.r_ratio.expect("R");
        let nu = result.nu.expect("nu");
        let kappa2 = result.kappa2.expect("kappa2");
        assert!((ratio - (100.0 / 6.0)).abs() < 1e-12);
        assert!((nu - (ratio / (ratio - 2.0))).abs() < 1e-12);
        assert!((kappa2 - (10.0 / ((ratio - 2.0) * 3.0))).abs() < 1e-12);
    }
    /// Dividing each lambda by its c_k must recover the unscaled values exactly.
    #[test]
    fn continuous_order_unscales_lambdas_exactly_by_ck() {
        let result =
            compute_continuous_smoothness_order([6.0, 15.0, 9.0], [3.0, 5.0, 9.0], 1e-12);
        let unscaled = [result.lambda0, result.lambda1, result.lambda2];
        for (got, want) in unscaled.into_iter().zip([2.0, 3.0, 1.0]) {
            assert!((got - want).abs() < 1e-12);
        }
    }
    /// A zero normalization constant is rejected rather than divided by.
    #[test]
    fn continuous_order_invalid_ck_is_guarded() {
        let result =
            compute_continuous_smoothness_order([1.0, 1.0, 1.0], [1.0, 0.0, 1.0], 1e-12);
        assert_eq!(
            result.status,
            ContinuousSmoothnessOrderStatus::UndefinedZeroLambda
        );
        assert!(result.r_ratio.is_none());
    }
    /// Jointly rescaling lambdas and c_k leaves R, nu, and kappa^2 unchanged.
    #[test]
    fn continuous_order_is_invariant_to_penalty_normalization_reversal() {
        let reference =
            compute_continuous_smoothness_order([2.0, 10.0, 3.0], [1.0, 1.0, 1.0], 1e-12);
        let rescaled = compute_continuous_smoothness_order(
            [2.0 * 4.0, 10.0 * 0.5, 3.0 * 8.0],
            [4.0, 0.5, 8.0],
            1e-12,
        );
        assert_eq!(reference.status, ContinuousSmoothnessOrderStatus::Ok);
        assert_eq!(rescaled.status, ContinuousSmoothnessOrderStatus::Ok);
        assert!((reference.r_ratio.unwrap() - rescaled.r_ratio.unwrap()).abs() < 1e-12);
        assert!((reference.nu.unwrap() - rescaled.nu.unwrap()).abs() < 1e-12);
        assert!((reference.kappa2.unwrap() - rescaled.kappa2.unwrap()).abs() < 1e-12);
    }
    /// R <= 2 has no Matern interpretation: nu and kappa^2 are withheld.
    #[test]
    fn continuous_order_flags_non_matern_regimewhen_r_le_4() {
        let result =
            compute_continuous_smoothness_order([1.0, 1.0, 1.0], [1.0, 1.0, 1.0], 1e-12);
        assert_eq!(result.status, ContinuousSmoothnessOrderStatus::NonMaternRegime);
        assert!(result.nu.is_none());
        assert!(result.kappa2.is_none());
    }
    /// For 2 < R < 4 the regime is still non-Matern, but effective nu and
    /// kappa^2 are reported.
    #[test]
    fn continuous_order_reports_effective_nu_kappa_in_non_matern_bandwhen_r_gt_2() {
        let result =
            compute_continuous_smoothness_order([1.0, 3.0, 3.0], [1.0, 1.0, 1.0], 1e-12);
        assert_eq!(result.status, ContinuousSmoothnessOrderStatus::NonMaternRegime);
        let ratio = result.r_ratio.expect("R");
        assert!(ratio > 2.0 && ratio < 4.0);
        assert!(result.nu.is_some());
        assert!(result.kappa2.is_some());
    }
    /// R = 4 exactly is the Matern-square case with nu = 2.
    #[test]
    fn continuous_order_boundary_r_equals_four_is_matern_square_case() {
        let result =
            compute_continuous_smoothness_order([1.0, 2.0, 1.0], [1.0, 1.0, 1.0], 1e-12);
        assert_eq!(result.status, ContinuousSmoothnessOrderStatus::Ok);
        assert!((result.nu.expect("nu") - 2.0).abs() < 1e-12);
    }
    /// An exactly-zero lambda0 triggers the intrinsic-limit guard.
    #[test]
    fn continuous_order_guardszero_or_nearzero_lambda() {
        let result =
            compute_continuous_smoothness_order([0.0, 1.0, 1.0], [1.0, 1.0, 1.0], 1e-12);
        assert_eq!(result.status, ContinuousSmoothnessOrderStatus::IntrinsicLimit);
        assert!(result.r_ratio.is_none());
    }
    /// When lambda2 collapses toward zero the first-order limit applies:
    /// nu = 1 and kappa^2 = lambda0 / (2 lambda1) = 0.5 here.
    #[test]
    fn continuous_order_first_order_limitwhen_lambda2_collapses() {
        let result =
            compute_continuous_smoothness_order([2.0, 4.0, 1e-20], [1.0, 1.0, 1.0], 1e-12);
        assert_eq!(result.status, ContinuousSmoothnessOrderStatus::FirstOrderLimit);
        assert_eq!(result.nu, Some(1.0));
        let k2 = result.kappa2.expect("kappa2");
        assert!((k2 - 0.5).abs() < 1e-12);
    }
    /// When lambda0 collapses toward zero the intrinsic limit applies:
    /// nu = 1 and kappa^2 = 0.
    #[test]
    fn continuous_order_intrinsic_limitwhen_lambda0_collapses() {
        let result =
            compute_continuous_smoothness_order([1e-20, 4.0, 2.0], [1.0, 1.0, 1.0], 1e-12);
        assert_eq!(result.status, ContinuousSmoothnessOrderStatus::IntrinsicLimit);
        assert_eq!(result.nu, Some(1.0));
        assert_eq!(result.kappa2, Some(0.0));
    }
    /// The try_ variant only accepts exactly three penalties per term.
    #[test]
    fn continuous_order_is_only_defined_for_three_penalties_per_term() {
        let three =
            try_compute_continuous_smoothness_order(&[2.0, 10.0, 3.0], &[1.0, 1.0, 1.0], 1e-12);
        let two = try_compute_continuous_smoothness_order(&[2.0, 10.0], &[1.0, 1.0], 1e-12);
        let four = try_compute_continuous_smoothness_order(
            &[2.0, 10.0, 3.0, 7.0],
            &[1.0, 1.0, 1.0, 1.0],
            1e-12,
        );
        assert!(three.is_some());
        assert!(two.is_none());
        assert!(four.is_none());
    }
}
#[path = "reml/mod.rs"]
pub(crate) mod reml;