use crate::construction::{KroneckerReparamResult, ReparamResult};
use crate::estimate::EstimationError;
use crate::estimate::reml::FirthDenseOperator;
use crate::faer_ndarray::{
FaerArrayView, FaerCholesky, FaerEigh, FaerLinalgError, array1_to_col_matmut, array2_to_matmut,
fast_ab, fast_atb, fast_atv, fast_av_into,
};
use crate::linalg::sparse_exact::{
factorize_sparse_spd, solve_sparse_spd, sparse_symmetric_upper_matvec_public,
};
use crate::linalg::utils::{StableSolver, boundary_hit_step_fraction};
use crate::matrix::{DesignMatrix, LinearOperator, ReparamOperator, SymmetricMatrix};
use crate::mixture_link::{InverseLinkJet as MixtureInverseLinkJet, logit_inverse_link_jet5};
use crate::probability::standard_normal_quantile;
use crate::solver::active_set;
use crate::types::{Coefficients, LinearPredictor, LogSmoothingParamsView};
use crate::types::{
GlmLikelihoodFamily, GlmLikelihoodSpec, InverseLink, LinkFunction, MixtureLinkState,
RidgePassport, RidgePolicy, SasLinkState,
};
use dyn_stack::{MemBuffer, MemStack};
use faer::linalg::matmul::matmul;
use faer::sparse::linalg::matmul::{
SparseMatMulInfo, sparse_sparse_matmul_numeric, sparse_sparse_matmul_numeric_scratch,
sparse_sparse_matmul_symbolic,
};
use faer::sparse::{SparseColMat, Triplet};
use faer::sparse::{
SparseColMatMut, SparseColMatRef, SparseRowMat, SymbolicSparseColMat, SymbolicSparseColMatRef,
};
use faer::{Accum, Par, Side, Unbind, get_global_parallelism};
use log;
use ndarray::{
Array1, Array2, ArrayBase, ArrayView1, ArrayView2, Data, Ix1, Ix2, ShapeBuilder, Zip, s,
};
use rayon::iter::{
IndexedParallelIterator, IntoParallelIterator, IntoParallelRefMutIterator, ParallelIterator,
};
use serde::{Deserialize, Serialize};
use statrs::function::gamma::{digamma, ln_gamma};
use faer::linalg::cholesky::llt::factor::LltParams;
use faer::{Auto, Spec};
use std::borrow::Cow;
use std::collections::BTreeSet;
use std::collections::HashMap;
use std::sync::{Arc, Mutex, OnceLock};
pub use crate::solver::active_set::{ConstraintKktDiagnostics, LinearInequalityConstraints};
#[inline]
/// Returns true when every element of `values` is finite (no NaN/±inf).
fn array1_is_finite(values: &Array1<f64>) -> bool {
    !values.iter().any(|v| !v.is_finite())
}
#[inline]
/// Returns true when every element of `values` is finite (no NaN/±inf).
fn array2_is_finite(values: &Array2<f64>) -> bool {
    values.iter().copied().all(f64::is_finite)
}
// Clamp range for the Gamma shape parameter estimated in `estimate_gamma_shape_from_eta`.
const GAMMA_SHAPE_MIN: f64 = 1e-8;
const GAMMA_SHAPE_MAX: f64 = 1e12;
// Bisection tolerance (relative) on the shape-score root.
const GAMMA_SHAPE_TARGET_TOL: f64 = 1e-12;
// Absolute cap on |eta| applied elsewhere in PIRLS to keep exp(eta) finite.
const PIRLS_ETA_ABS_CAP: f64 = 40.0;
#[inline]
/// Score function for the Gamma shape MLE: `ln(k) - digamma(k) - target`.
/// Its root in `k` is the maximum-likelihood shape for a given moment `target`.
fn gamma_shape_score(shape: f64, target: f64) -> f64 {
    shape.ln() - digamma(shape) - target
}
/// Estimate the Gamma shape parameter from responses `y` and a log-link linear
/// predictor `eta`, using prior weights.
///
/// Computes the weighted mean of `(y/mu - ln(y/mu) - 1)` (with `mu = exp(eta)`),
/// which is the moment the shape MLE score targets, then solves
/// `ln(k) - digamma(k) = target` by bracketing + bisection. Returns 1.0 when all
/// weights are zero, and `GAMMA_SHAPE_MAX` when the target is (numerically) zero,
/// i.e. the fit is essentially exact.
fn estimate_gamma_shape_from_eta(
    y: ArrayView1<'_, f64>,
    eta: &Array1<f64>,
    priorweights: ArrayView1<'_, f64>,
) -> f64 {
    const EPS: f64 = 1e-12;
    use rayon::iter::{IntoParallelIterator, ParallelIterator};
    // Weighted parallel reduction of the target moment and total weight.
    let (weighted_target, total_weight) = (0..eta.len())
        .into_par_iter()
        .map(|i| {
            let wi = priorweights[i].max(0.0);
            if wi == 0.0 {
                return (0.0_f64, 0.0_f64);
            }
            // Floor y and mu at EPS so ln/division are safe; clamp eta so exp is finite.
            let yi = y[i].max(EPS);
            let mui = eta[i].clamp(-700.0, 700.0).exp().max(EPS);
            let ratio = yi / mui;
            // ratio - ln(ratio) - 1 >= 0, with equality iff ratio == 1.
            (wi * (ratio - ratio.ln() - 1.0), wi)
        })
        .reduce(
            || (0.0_f64, 0.0_f64),
            |(t1, w1), (t2, w2)| (t1 + t2, w1 + w2),
        );
    if total_weight <= 0.0 {
        return 1.0;
    }
    let target = (weighted_target / total_weight).max(0.0);
    if target <= GAMMA_SHAPE_TARGET_TOL {
        // Essentially perfect fit: shape is unbounded, return the cap.
        return GAMMA_SHAPE_MAX;
    }
    // Closed-form initial guess for the shape (standard approximation to the
    // root of ln(k) - digamma(k) = target) — used only to seed the bracket.
    let discriminant = (target - 3.0) * (target - 3.0) + 24.0 * target;
    let approx = ((3.0 - target) + discriminant.sqrt()) / (12.0 * target);
    // Bracket the root: score is positive at small shapes and decreasing,
    // so double `hi` until the score goes non-positive (or the cap is hit).
    let mut lo = GAMMA_SHAPE_MIN;
    let mut hi = approx.max(1.0);
    while hi < GAMMA_SHAPE_MAX && gamma_shape_score(hi, target) > 0.0 {
        hi = (hi * 2.0).min(GAMMA_SHAPE_MAX);
    }
    if gamma_shape_score(hi, target) > 0.0 {
        // No sign change before the cap: return the cap.
        return GAMMA_SHAPE_MAX;
    }
    // Bisection to relative tolerance (80 iterations is more than enough for f64).
    for _ in 0..80 {
        let mid = 0.5 * (lo + hi);
        if gamma_shape_score(mid, target) > 0.0 {
            lo = mid;
        } else {
            hi = mid;
        }
        if (hi - lo) <= GAMMA_SHAPE_TARGET_TOL * hi.max(1.0) {
            break;
        }
    }
    0.5 * (lo + hi)
}
#[inline]
/// Weighted Gamma log-likelihood at mean `mu` and shape `shape`:
/// `sum_i w_i * [k ln k - ln Γ(k) - k ln mu_i + (k-1) ln y_i - k y_i / mu_i]`.
/// `y` and `mu` are floored at EPS so logs/divisions are safe; the shape is
/// clamped to the global [GAMMA_SHAPE_MIN, GAMMA_SHAPE_MAX] range.
fn gamma_loglikelihood_with_shape(
    y: ArrayView1<'_, f64>,
    mu: &Array1<f64>,
    priorweights: ArrayView1<'_, f64>,
    shape: f64,
) -> f64 {
    const EPS: f64 = 1e-12;
    use rayon::iter::{IntoParallelIterator, ParallelIterator};
    let shape_c = shape.clamp(GAMMA_SHAPE_MIN, GAMMA_SHAPE_MAX);
    // Hoist the shape-only terms out of the per-observation loop.
    let shape_ln = shape_c.ln();
    let ln_gamma_shape = ln_gamma(shape_c);
    (0..y.len())
        .into_par_iter()
        .map(|i| {
            let yi_c = y[i].max(EPS);
            let mui_c = mu[i].max(EPS);
            priorweights[i]
                * (shape_c * shape_ln - ln_gamma_shape - shape_c * mui_c.ln()
                    + (shape_c - 1.0) * yi_c.ln()
                    - shape_c * yi_c / mui_c)
        })
        .sum()
}
/// Which linear-solve strategy PIRLS uses for the penalized normal equations.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum PirlsLinearSolvePath {
    /// Dense solve in the reparameterized (transformed) coordinate frame.
    DenseTransformed,
    /// Sparse Cholesky solve in the original coordinate frame.
    SparseNative,
}
/// Coordinate frame in which the PIRLS coefficients are iterated.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum PirlsCoordinateFrame {
    /// Stable reparameterized basis (Qs transform applied).
    TransformedQs,
    /// Original basis, solved with sparse native factorizations.
    OriginalSparseNative,
}
/// Record of the sparse-vs-dense path decision, plus the size/sparsity
/// statistics the decision was based on (for logging/diagnostics).
#[derive(Clone, Debug)]
pub struct SparsePirlsDecision {
    pub path: PirlsLinearSolvePath,
    /// Short static explanation of why this path was chosen.
    pub reason: &'static str,
    /// Number of coefficients.
    pub p: usize,
    /// Nonzeros in the design matrix X.
    pub nnz_x: usize,
    /// Symbolic nonzeros of XtWX, when computed.
    pub nnz_xtwx_symbolic: Option<usize>,
    /// Nonzeros in the (upper) penalty S_lambda.
    pub nnz_s_lambda: usize,
    /// Estimated nonzeros of the penalized Hessian, when computed.
    pub nnz_h_est: Option<usize>,
    /// Estimated density of the penalized Hessian, when computed.
    pub density_h_est: Option<f64>,
}
/// Render an optional count for log lines; `None` becomes `"na"`.
fn fmt_opt_usize(v: Option<usize>) -> String {
    match v {
        Some(value) => value.to_string(),
        None => "na".to_string(),
    }
}
/// Render an optional float with four decimals for log lines; `None` becomes `"na"`.
fn fmt_opt_f64(v: Option<f64>) -> String {
    match v {
        Some(value) => format!("{value:.4}"),
        None => "na".to_string(),
    }
}
impl SparsePirlsDecision {
    /// Human-readable tag for the chosen solve path.
    fn path_str(&self) -> &'static str {
        match self.path {
            PirlsLinearSolvePath::SparseNative => "sparse_native",
            PirlsLinearSolvePath::DenseTransformed => "dense_transformed",
        }
    }
    /// Render every decision field into the canonical log line; this string also
    /// serves as the dedup key for repeated identical decisions.
    fn format_fields(&self, path: &str) -> String {
        format!(
            "path={path} reason={} p={} nnz_x={} nnz_xtwx_symbolic={} nnz_s_lambda={} nnz_h_est={} density_h_est={}",
            self.reason,
            self.p,
            self.nnz_x,
            fmt_opt_usize(self.nnz_xtwx_symbolic),
            self.nnz_s_lambda,
            fmt_opt_usize(self.nnz_h_est),
            fmt_opt_f64(self.density_h_est),
        )
    }
    /// Log the decision once; identical repeats are collapsed to occasional
    /// power-of-two summary lines to keep logs readable.
    fn log_once(&self) {
        let path = self.path_str();
        let key = self.format_fields(path);
        let repetition_count = pirls_decision_repetition_count(key.clone());
        match repetition_count {
            1 => log::debug!("[pirls-path] {key}"),
            n if should_log_pirls_decision_summary(n) => log::debug!(
                "[pirls-path] repeated path={} reason={} count={} (suppressing identical decisions)",
                path,
                self.reason,
                n,
            ),
            _ => {}
        }
    }
}
/// Increment and return the process-wide repetition count for a decision log key.
/// Backed by a lazily-initialized global map keyed by the formatted log line.
fn pirls_decision_repetition_count(log_key: String) -> usize {
    static PIRLS_DECISION_LOG_COUNTS: OnceLock<Mutex<HashMap<String, usize>>> = OnceLock::new();
    let mut counts = PIRLS_DECISION_LOG_COUNTS
        .get_or_init(Default::default)
        .lock()
        .expect("pirls decision log counter poisoned");
    let entry = counts.entry(log_key).or_default();
    *entry += 1;
    *entry
}
/// Emit a summary line only on power-of-two repetition counts past the first,
/// giving exponentially-spaced reminders instead of log spam.
fn should_log_pirls_decision_summary(repetition_count: usize) -> bool {
    match repetition_count {
        0 | 1 => false,
        n => n.is_power_of_two(),
    }
}
// Maximum estimated density of the penalized Hessian for which the sparse-native
// path is still considered worthwhile.
const SPARSE_NATIVE_MAX_H_DENSITY: f64 = 0.30;
/// Upper-triangular sparsity pattern (with values) of the penalty matrix S_lambda.
#[derive(Clone, Debug)]
struct SparsePenaltyPattern {
    /// `(row, col, value)` with `row <= col`, sorted by column then row.
    upper_triplets: Vec<(usize, usize, f64)>,
    /// Cached count of `upper_triplets` for cheap cache-equality checks.
    nnz_upper: usize,
}
impl SparsePenaltyPattern {
    /// Extract the upper-triangular entries of a dense matrix whose magnitude
    /// exceeds `tol`, in column-major (col outer, row inner) order.
    /// For non-square input only the leading min(nrows, ncols) block is scanned.
    fn from_dense_upper(matrix: &Array2<f64>, tol: f64) -> Self {
        let p = matrix.nrows().min(matrix.ncols());
        let upper_triplets: Vec<(usize, usize, f64)> = (0..p)
            .flat_map(|col| (0..=col).map(move |row| (row, col)))
            .filter_map(|(row, col)| {
                let value = matrix[[row, col]];
                (value.abs() > tol).then_some((row, col, value))
            })
            .collect();
        Self {
            nnz_upper: upper_triplets.len(),
            upper_triplets,
        }
    }
}
/// Sparsity statistics of the assembled penalized system H = XtWX + S_lambda.
#[derive(Clone, Debug)]
pub(crate) struct SparsePenalizedSystemStats {
    /// Symbolic nonzeros of XtWX (full pattern).
    pub(crate) nnz_xtwx_symbolic: usize,
    /// Nonzeros of the upper-triangular penalty pattern.
    pub(crate) nnz_s_lambda_upper: usize,
    /// Nonzeros of the upper-triangular union pattern of H.
    pub(crate) nnz_h_upper: usize,
    /// `nnz_h_upper` divided by the total upper-triangle size p(p+1)/2.
    pub(crate) density_upper: f64,
}
/// Reusable symbolic + numeric buffers for assembling the upper-triangular
/// penalized Hessian H = XtWX + S_lambda (+ ridge·I) across PIRLS iterations.
struct SparsePenalizedSystemCache {
    // Cached symbolic/numeric machinery for XtWX (defined elsewhere in this module).
    xtwx_cache: SparseXtWxCache,
    penalty_pattern: SparsePenaltyPattern,
    /// Union pattern of upper(XtWX), upper(S_lambda), and the diagonal.
    h_upper_symbolic: SymbolicSparseColMat<usize>,
    /// Numeric values aligned with `h_upper_symbolic`'s row indices.
    h_uppervalues: Vec<f64>,
    // Plain copies of the symbolic arrays, used by the cursor-based assembler.
    h_upper_col_ptr: Vec<usize>,
    h_upperrow_idx: Vec<usize>,
    /// Problem dimension (number of coefficients).
    p: usize,
}
impl SparsePenalizedSystemCache {
    /// Build the cache: compute the XtWX symbolic pattern, union it with the
    /// penalty pattern (plus the diagonal), and allocate the numeric buffer.
    fn new(
        x: &SparseColMat<usize, f64>,
        penalty_pattern: SparsePenaltyPattern,
    ) -> Result<Self, EstimationError> {
        let xtwx_cache = SparseXtWxCache::new(x)?;
        let p = x.ncols();
        let h_upper_symbolic = build_penalized_symbolic(
            p,
            xtwx_cache.xtwx_symbolic.col_ptr(),
            xtwx_cache.xtwx_symbolic.row_idx(),
            &penalty_pattern.upper_triplets,
        )?;
        // One numeric slot per stored upper-triangular entry of H.
        let h_uppervalues = vec![0.0; h_upper_symbolic.row_idx().len()];
        Ok(Self {
            xtwx_cache,
            penalty_pattern,
            // Plain copies so the assembler can index without borrowing the symbolic.
            h_upper_col_ptr: h_upper_symbolic.col_ptr().to_vec(),
            h_upperrow_idx: h_upper_symbolic.row_idx().to_vec(),
            h_upper_symbolic,
            h_uppervalues,
            p,
        })
    }
    /// True when this cache was built for the same design pattern and the same
    /// penalty pattern (count check first as a cheap early-out).
    fn matches(
        &self,
        x: &SparseColMat<usize, f64>,
        penalty_pattern: &SparsePenaltyPattern,
    ) -> bool {
        self.xtwx_cache.matches(x)
            && self.penalty_pattern.nnz_upper == penalty_pattern.nnz_upper
            && self.penalty_pattern.upper_triplets == penalty_pattern.upper_triplets
    }
    /// Sparsity statistics of the cached patterns (no numeric work).
    fn stats(&self) -> SparsePenalizedSystemStats {
        // Total entries in a p x p upper triangle, including the diagonal.
        let upper_total = self.p.saturating_mul(self.p + 1) / 2;
        SparsePenalizedSystemStats {
            nnz_xtwx_symbolic: self.xtwx_cache.xtwx_symbolic.row_idx().len(),
            nnz_s_lambda_upper: self.penalty_pattern.nnz_upper,
            nnz_h_upper: self.h_upper_symbolic.row_idx().len(),
            density_upper: if upper_total == 0 {
                0.0
            } else {
                self.h_upper_symbolic.row_idx().len() as f64 / upper_total as f64
            },
        }
    }
    /// Numerically assemble upper(H) = upper(XtWX) + upper(S_lambda) + ridge·I
    /// into the cached union pattern and return it as a sparse matrix.
    ///
    /// Each of the three scatter passes walks a per-column cursor forward through
    /// the union pattern; this relies on source entries being sorted by row
    /// within each column (guaranteed for CSC and for `upper_triplets`).
    fn assemble_upper(
        &mut self,
        x: &SparseColMat<usize, f64>,
        weights: &Array1<f64>,
        ridge: f64,
    ) -> Result<SparseColMat<usize, f64>, EstimationError> {
        if weights.len() != self.xtwx_cache.nrows {
            return Err(EstimationError::InvalidInput(format!(
                "weights length {} does not match design rows {}",
                weights.len(),
                self.xtwx_cache.nrows
            )));
        }
        self.xtwx_cache.compute_numeric(x, weights)?;
        self.h_uppervalues.fill(0.0);
        // Pass 1: scatter upper(XtWX) into the union pattern.
        let mut cursor = self.h_upper_col_ptr[..self.p].to_vec();
        let xtwx_col_ptr = self.xtwx_cache.xtwx_symbolic.col_ptr();
        let xtwxrow_idx = self.xtwx_cache.xtwx_symbolic.row_idx();
        for col in 0..self.p {
            let start = xtwx_col_ptr[col];
            let end = xtwx_col_ptr[col + 1];
            for idx in start..end {
                let row = xtwxrow_idx[idx];
                // Only the upper triangle is stored in H.
                if row <= col {
                    let cursor_idx = &mut cursor[col];
                    // Advance to the matching row slot in the union pattern.
                    while *cursor_idx < self.h_upper_col_ptr[col + 1]
                        && self.h_upperrow_idx[*cursor_idx] < row
                    {
                        *cursor_idx += 1;
                    }
                    if *cursor_idx >= self.h_upper_col_ptr[col + 1]
                        || self.h_upperrow_idx[*cursor_idx] != row
                    {
                        // Should be impossible: the union pattern was built from
                        // this XtWX pattern.
                        return Err(EstimationError::InvalidInput(
                            "penalized symbolic pattern missing XtWX entry".to_string(),
                        ));
                    }
                    self.h_uppervalues[*cursor_idx] += self.xtwx_cache.xtwxvalues[idx];
                }
            }
        }
        // Pass 2: scatter the penalty triplets (sorted by col then row).
        cursor.copy_from_slice(&self.h_upper_col_ptr[..self.p]);
        for &(row, col, value) in &self.penalty_pattern.upper_triplets {
            let cursor_idx = &mut cursor[col];
            while *cursor_idx < self.h_upper_col_ptr[col + 1]
                && self.h_upperrow_idx[*cursor_idx] < row
            {
                *cursor_idx += 1;
            }
            if *cursor_idx >= self.h_upper_col_ptr[col + 1]
                || self.h_upperrow_idx[*cursor_idx] != row
            {
                return Err(EstimationError::InvalidInput(
                    "penalized symbolic pattern missing penalty entry".to_string(),
                ));
            }
            self.h_uppervalues[*cursor_idx] += value;
        }
        // Pass 3: optional ridge on the diagonal (the union pattern always
        // includes every diagonal entry).
        if ridge > 0.0 {
            cursor.copy_from_slice(&self.h_upper_col_ptr[..self.p]);
            for col in 0..self.p {
                let cursor_idx = &mut cursor[col];
                while *cursor_idx < self.h_upper_col_ptr[col + 1]
                    && self.h_upperrow_idx[*cursor_idx] < col
                {
                    *cursor_idx += 1;
                }
                if *cursor_idx >= self.h_upper_col_ptr[col + 1]
                    || self.h_upperrow_idx[*cursor_idx] != col
                {
                    return Err(EstimationError::InvalidInput(
                        "penalized symbolic pattern missing diagonal entry".to_string(),
                    ));
                }
                self.h_uppervalues[*cursor_idx] += ridge;
            }
        }
        Ok(SparseColMat::new(
            self.h_upper_symbolic.clone(),
            self.h_uppervalues.clone(),
        ))
    }
}
/// Build the upper-triangular symbolic (CSC) pattern of H = XtWX + S_lambda,
/// always including the full diagonal so a ridge can be added later.
///
/// Errors if any penalty triplet is below the diagonal or out of bounds.
fn build_penalized_symbolic(
    p: usize,
    xtwx_col_ptr: &[usize],
    xtwxrow_idx: &[usize],
    penalty_triplets: &[(usize, usize, f64)],
) -> Result<SymbolicSparseColMat<usize>, EstimationError> {
    // BTreeSet per column keeps row indices sorted and deduplicated.
    let mut cols: Vec<BTreeSet<usize>> = (0..p).map(|_| BTreeSet::new()).collect();
    for col in 0..p {
        // Always include the diagonal entry.
        cols[col].insert(col);
        let start = xtwx_col_ptr[col];
        let end = xtwx_col_ptr[col + 1];
        for &row in &xtwxrow_idx[start..end] {
            if row <= col {
                cols[col].insert(row);
            }
        }
    }
    for &(row, col, _) in penalty_triplets {
        if row > col || col >= p {
            return Err(EstimationError::InvalidInput(
                "penalty sparse pattern must be upper-triangular within bounds".to_string(),
            ));
        }
        cols[col].insert(row);
    }
    // Flatten the per-column sets into CSC arrays.
    let mut col_ptr = Vec::with_capacity(p + 1);
    let mut row_idx = Vec::new();
    col_ptr.push(0);
    for rows in cols {
        row_idx.extend(rows.into_iter());
        col_ptr.push(row_idx.len());
    }
    // SAFETY: `col_ptr` starts at 0, is non-decreasing, and has length p + 1;
    // each column's row indices come from a BTreeSet of values < p, so they are
    // sorted, unique, and in bounds — exactly the CSC invariants required by
    // `new_unchecked`.
    Ok(unsafe { SymbolicSparseColMat::new_unchecked(p, p, col_ptr, None, row_idx) })
}
/// A model that can recompute its IRLS working quantities at a coefficient vector.
pub trait WorkingModel {
    /// Recompute eta, gradient, Hessian, and likelihood terms at `beta`.
    fn update(&mut self, beta: &Coefficients) -> Result<WorkingState, EstimationError>;
    /// Like [`WorkingModel::update`] but with a requested Hessian curvature kind;
    /// the default implementation ignores the request.
    fn update_with_curvature(
        &mut self,
        beta: &Coefficients,
        _: HessianCurvatureKind,
    ) -> Result<WorkingState, EstimationError> {
        self.update(beta)
    }
    /// Evaluate a trial `beta` (e.g. during step halving); defaults to a full update.
    fn update_candidate(
        &mut self,
        beta: &Coefficients,
        curvature: HessianCurvatureKind,
    ) -> Result<WorkingState, EstimationError> {
        self.update_with_curvature(beta, curvature)
    }
    /// Whether the implementation can provide observed-information curvature.
    fn supports_observed_information_curvature(&self) -> bool {
        false
    }
}
/// Inputs for quadrature-integrated IRLS updates (covariate measurement error).
#[derive(Clone, Copy)]
pub(crate) struct IntegratedWorkingInput<'a> {
    pub quadctx: &'a crate::quadrature::QuadratureContext,
    /// Per-observation standard errors of the noisy covariate.
    pub se: ArrayView1<'a, f64>,
    pub mixture_link_state: Option<&'a MixtureLinkState>,
    pub sas_link_state: Option<&'a SasLinkState>,
}
/// Mutable output buffers for higher-order link/likelihood derivatives,
/// filled by the IRLS update when the caller needs them.
pub struct WorkingDerivativeBuffersMut<'a> {
    c: &'a mut Array1<f64>,
    d: &'a mut Array1<f64>,
    dmu_deta: &'a mut Array1<f64>,
    d2mu_deta2: &'a mut Array1<f64>,
    d3mu_deta3: &'a mut Array1<f64>,
}
/// Per-observation working quantities for a Bernoulli-type update
/// (mean, IRLS weight, working response, and the c/d derivative terms).
#[derive(Clone, Copy)]
struct WorkingBernoulliGeometry {
    mu: f64,
    weight: f64,
    z: f64,
    c: f64,
    d: f64,
}
/// Likelihood-family hooks used by the PIRLS loop.
pub(crate) trait WorkingLikelihood {
    /// Fill `mu`, IRLS `weights`, and working response `z` from `eta`;
    /// optionally performs quadrature integration and/or fills derivative buffers.
    fn irls_update(
        &self,
        y: ArrayView1<f64>,
        eta: &Array1<f64>,
        priorweights: ArrayView1<f64>,
        mu: &mut Array1<f64>,
        weights: &mut Array1<f64>,
        z: &mut Array1<f64>,
        integrated: Option<IntegratedWorkingInput<'_>>,
        derivatives: Option<WorkingDerivativeBuffersMut<'_>>,
    ) -> Result<(), EstimationError>;
    /// Deviance (likelihood-based convergence criterion) at mean `mu`.
    fn loglik_deviance(
        &self,
        y: ArrayView1<f64>,
        mu: &Array1<f64>,
        priorweights: ArrayView1<f64>,
    ) -> Result<f64, EstimationError>;
}
impl WorkingLikelihood for GlmLikelihoodSpec {
    /// Dispatch the IRLS vector update by (family, integrated-input presence).
    fn irls_update(
        &self,
        y: ArrayView1<f64>,
        eta: &Array1<f64>,
        priorweights: ArrayView1<f64>,
        mu: &mut Array1<f64>,
        weights: &mut Array1<f64>,
        z: &mut Array1<f64>,
        integrated: Option<IntegratedWorkingInput<'_>>,
        derivatives: Option<WorkingDerivativeBuffersMut<'_>>,
    ) -> Result<(), EstimationError> {
        match (self.family, integrated) {
            // Binomial-type families with quadrature integration over covariate
            // measurement error: family-specific integrated update.
            (
                GlmLikelihoodFamily::BinomialLogit
                | GlmLikelihoodFamily::BinomialProbit
                | GlmLikelihoodFamily::BinomialCLogLog
                | GlmLikelihoodFamily::BinomialSas
                | GlmLikelihoodFamily::BinomialBetaLogistic
                | GlmLikelihoodFamily::BinomialMixture,
                Some(integ),
            ) => {
                update_glmvectors_integrated_by_family(
                    integ.quadctx,
                    y,
                    eta,
                    integ.se,
                    self.family,
                    priorweights,
                    mu,
                    weights,
                    z,
                    derivatives,
                    integ.mixture_link_state,
                    integ.sas_link_state,
                )?;
                Ok(())
            }
            // Binomial-type families without integration: standard link-based update.
            (
                GlmLikelihoodFamily::BinomialLogit
                | GlmLikelihoodFamily::BinomialProbit
                | GlmLikelihoodFamily::BinomialCLogLog
                | GlmLikelihoodFamily::BinomialSas
                | GlmLikelihoodFamily::BinomialBetaLogistic,
                None,
            ) => {
                update_glmvectors(
                    y,
                    eta,
                    &InverseLink::Standard(self.link_function()),
                    priorweights,
                    mu,
                    weights,
                    z,
                    derivatives,
                )?;
                Ok(())
            }
            // Mixture link needs explicit mixture state; there is no sensible default.
            (GlmLikelihoodFamily::BinomialMixture, None) => Err(EstimationError::InvalidInput(
                "BinomialMixture IRLS update requires explicit mixture link state".to_string(),
            )),
            // Gaussian identity ignores integration and never fills derivatives.
            (GlmLikelihoodFamily::GaussianIdentity, _) => {
                update_glmvectors(
                    y,
                    eta,
                    &InverseLink::Standard(LinkFunction::Identity),
                    priorweights,
                    mu,
                    weights,
                    z,
                    None,
                )?;
                Ok(())
            }
            (GlmLikelihoodFamily::PoissonLog, _) => {
                write_poisson_log_working_state(y, eta, priorweights, mu, weights, z, derivatives);
                Ok(())
            }
            (GlmLikelihoodFamily::GammaLog, _) => {
                write_gamma_log_working_state(
                    y,
                    eta,
                    priorweights,
                    // Shape defaults to 1.0 (exponential) when not estimated yet.
                    self.gamma_shape().unwrap_or(1.0),
                    mu,
                    weights,
                    z,
                    derivatives,
                );
                Ok(())
            }
        }
    }
    /// Deviance via the shared family dispatcher.
    fn loglik_deviance(
        &self,
        y: ArrayView1<f64>,
        mu: &Array1<f64>,
        priorweights: ArrayView1<f64>,
    ) -> Result<f64, EstimationError> {
        Ok(calculate_deviance(y, mu, *self, priorweights))
    }
}
/// Diagnostics produced when Firth bias reduction is applied.
#[derive(Debug, Clone)]
pub enum FirthDiagnostics {
    /// Firth adjustment was not used.
    Inactive,
    Active {
        /// log|I(beta)| term added to the penalized likelihood (Jeffreys prior).
        jeffreys_logdet: f64,
        /// Diagonal of the hat matrix used for the Firth score adjustment.
        hat_diag: Array1<f64>,
    },
}
impl Default for FirthDiagnostics {
fn default() -> Self {
Self::Inactive
}
}
impl FirthDiagnostics {
    /// Jeffreys-prior log-determinant when Firth is active, `None` otherwise.
    #[inline]
    pub fn jeffreys_logdet(&self) -> Option<f64> {
        if let Self::Active {
            jeffreys_logdet, ..
        } = self
        {
            Some(*jeffreys_logdet)
        } else {
            None
        }
    }
}
/// Which curvature (second-derivative) approximation the Hessian uses.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum HessianCurvatureKind {
    /// Expected information (Fisher scoring).
    Fisher,
    /// Observed information (full Newton).
    Observed,
}
/// Curvature actually exported for the Laplace approximation, including the
/// fallback cases when observed curvature is unusable.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub enum ExportedLaplaceCurvature {
    /// Exact observed information was positive definite and exported.
    ObservedExact,
    /// Fisher (expected) information exported as a surrogate.
    ExpectedInformationSurrogate,
    /// Observed curvature failed the positive-definiteness check; diagnostics kept.
    InvalidObservedCurvature {
        min_eigenvalue: f64,
        pd_tolerance: f64,
        gradient_norm: f64,
    },
}
/// Snapshot of all working quantities at one coefficient vector.
#[derive(Debug, Clone)]
pub struct WorkingState {
    pub eta: LinearPredictor,
    /// Gradient of the penalized objective.
    pub gradient: Array1<f64>,
    /// Penalized Hessian (curvature kind recorded in `hessian_curvature`).
    pub hessian: crate::linalg::matrix::SymmetricMatrix,
    pub log_likelihood: f64,
    pub deviance: f64,
    /// Quadratic penalty beta' S_lambda beta contribution.
    pub penalty_term: f64,
    pub firth: FirthDiagnostics,
    /// Stabilization ridge actually applied to the Hessian.
    pub ridge_used: f64,
    pub hessian_curvature: HessianCurvatureKind,
    /// Natural scale used to normalize gradient norms in convergence tests.
    pub gradient_natural_scale: f64,
}
impl WorkingState {
    /// Forwarded Firth Jeffreys log-determinant, if Firth is active.
    #[inline]
    pub fn jeffreys_logdet(&self) -> Option<f64> {
        self.firth.jeffreys_logdet()
    }
    /// Gradient norm normalized by the problem's natural gradient scale.
    #[inline]
    pub fn relative_gradient_norm(&self, g_norm: f64) -> f64 {
        g_norm / (1.0 + self.gradient_natural_scale)
    }
    /// sqrt(n) * sqrt(p) scale so absolute KKT tolerances grow with problem size.
    #[inline]
    fn kkt_dimension_scale(&self) -> f64 {
        let n = self.eta.len().max(1) as f64;
        let p = (self.gradient.len() as f64).max(1.0);
        n.sqrt() * p.sqrt()
    }
    /// Strict stationarity test: passes on either the dimension-scaled absolute
    /// criterion or the relative criterion.
    #[inline]
    pub fn certifies_kkt(&self, g_norm: f64, tol: f64) -> bool {
        g_norm < tol * self.kkt_dimension_scale() || self.relative_gradient_norm(g_norm) < tol
    }
    /// Looser "close enough" stationarity test (10x a floored tolerance),
    /// used to accept near-converged iterates.
    #[inline]
    pub fn near_stationary_kkt(&self, g_norm: f64, tol: f64) -> bool {
        let near_tol = tol.max(1e-6) * 10.0;
        g_norm <= near_tol * self.kkt_dimension_scale()
            || self.relative_gradient_norm(g_norm) <= near_tol
    }
}
#[inline]
/// Euclidean (L2) norm of `v`.
pub(crate) fn array1_l2_norm(v: &Array1<f64>) -> f64 {
    v.iter().fold(0.0_f64, |acc, &x| acc + x * x).sqrt()
}
/// Preallocated buffers reused across PIRLS iterations to avoid per-iteration
/// allocation (n = observations, p = coefficients).
pub struct PirlsWorkspace {
    /// Element-wise sqrt of the IRLS weights.
    pub sqrtw: Array1<f64>,
    /// Weighted working response w*z.
    pub wz: Array1<f64>,
    pub eta_buf: Array1<f64>,
    // Dense least-squares scratch (lazily sized on first use).
    pub scaled_matrix: Array2<f64>,
    pub final_aug_matrix: Array2<f64>,
    pub rhs_full: Array1<f64>,
    pub working_residual: Array1<f64>,
    pub weighted_residual: Array1<f64>,
    pub delta_eta: Array1<f64>,
    /// Length-p scratch vector.
    pub vec_buf_p: Array1<f64>,
    /// Cached symbolic/numeric machinery for sparse H assembly.
    sparse_penalized_system_cache: Option<SparsePenalizedSystemCache>,
    /// Scratch memory for faer's in-place LLT factorization.
    pub factorization_scratch: MemBuffer,
    pub perm: Vec<usize>,
    pub perm_inv: Vec<usize>,
    pub factorization_matrix: Array2<f64>,
    pub weighted_xvalues: Vec<f64>,
    /// Row-chunk buffer for streamed dense XtWX accumulation.
    pub weighted_x_chunk: Array2<f64>,
    pub hessian_buf: Array2<f64>,
    pub matvec_buf: Array1<f64>,
}
impl PirlsWorkspace {
    /// Allocate all n- and p-sized buffers; dense least-squares scratch is left
    /// empty and sized lazily. The two trailing parameters are unused (kept for
    /// call-site compatibility).
    pub fn new(n: usize, p: usize, _: usize, _: usize) -> Self {
        PirlsWorkspace {
            sqrtw: Array1::zeros(n),
            wz: Array1::zeros(n),
            eta_buf: Array1::zeros(n),
            scaled_matrix: Array2::zeros((0, 0).f()),
            final_aug_matrix: Array2::zeros((0, 0).f()),
            rhs_full: Array1::zeros(0),
            working_residual: Array1::zeros(n),
            weighted_residual: Array1::zeros(n),
            delta_eta: Array1::zeros(n),
            vec_buf_p: Array1::zeros(p),
            sparse_penalized_system_cache: None,
            factorization_scratch: {
                // Size the scratch for a sequential 1-column LLT; resized by faer
                // internals if a later call needs more.
                let par = faer::Par::Seq;
                let req = faer::linalg::cholesky::llt::factor::cholesky_in_place_scratch::<f64>(
                    1,
                    par,
                    Spec::new(<LltParams as Auto<f64>>::auto()),
                );
                MemBuffer::new(req)
            },
            perm: vec![0; p],
            perm_inv: vec![0; p],
            factorization_matrix: Array2::zeros((0, 0)),
            weighted_xvalues: Vec::new(),
            weighted_x_chunk: Array2::zeros((0, 0).f()),
            hessian_buf: Array2::zeros((0, 0).f()),
            matvec_buf: Array1::zeros(n),
        }
    }
    /// Row-chunk size for streamed XtWX so each chunk stays near a 64 MiB target,
    /// clamped to [512, 131072] rows.
    #[inline]
    fn dense_xtwx_chunkrows(p: usize) -> usize {
        const MIN_ROWS: usize = 512;
        const MAX_ROWS: usize = 131_072;
        const TARGET_BYTES: usize = 64 * 1024 * 1024;
        let bytes_perrow = p.max(1) * std::mem::size_of::<f64>();
        (TARGET_BYTES / bytes_perrow).clamp(MIN_ROWS, MAX_ROWS)
    }
    /// Accumulate `X' W X` into `out` by streaming row chunks of `sqrt(W) * X`
    /// and computing chunk' * chunk, avoiding a full weighted copy of X.
    ///
    /// Large problems use a rayon fold/reduce over chunks (each worker owning
    /// its own chunk buffer and p x p accumulator); small problems reuse the
    /// caller-provided `weighted_x_chunk` sequentially.
    fn add_dense_xtwx_streaming_from_sqrt<S>(
        sqrtw: &Array1<f64>,
        weighted_x_chunk: &mut Array2<f64>,
        x: &ArrayBase<S, Ix2>,
        out: &mut Array2<f64>,
        par: Par,
    ) where
        S: Data<Elem = f64> + Sync,
    {
        let n = x.nrows();
        let p = x.ncols();
        if n == 0 || p == 0 {
            return;
        }
        debug_assert_eq!(
            sqrtw.len(),
            n,
            "sqrtw length must match row count for streamed XtWX"
        );
        let chunkrows = Self::dense_xtwx_chunkrows(p).min(n);
        let num_chunks = (n + chunkrows - 1) / chunkrows;
        // Parallelize only when there are enough chunks and enough total work.
        let use_parallel = num_chunks >= 4 && (n as u64) * (p as u64) >= 200_000;
        if use_parallel {
            let combined = (0..num_chunks)
                .into_par_iter()
                .fold(
                    || {
                        // Per-worker state: a chunk buffer and a local accumulator.
                        (
                            Array2::<f64>::zeros((chunkrows, p).f()),
                            Array2::<f64>::zeros((p, p).f()),
                        )
                    },
                    |(mut chunk_buf, mut acc), ci| {
                        let start = ci * chunkrows;
                        let rows = (n - start).min(chunkrows);
                        {
                            // chunk = sqrt(W) * X for this row range.
                            let mut chunk = chunk_buf.slice_mut(s![0..rows, ..]);
                            let x_slice = x.slice(s![start..start + rows, ..]);
                            let w_slice = sqrtw.slice(s![start..start + rows]);
                            Zip::from(chunk.rows_mut())
                                .and(x_slice.rows())
                                .and(&w_slice)
                                .par_for_each(|mut dst, src, &w| {
                                    Zip::from(&mut dst).and(&src).for_each(|d, &s| *d = s * w);
                                });
                        }
                        let chunkrowsview = chunk_buf.slice(s![0..rows, ..]);
                        let chunkview = FaerArrayView::new(&chunkrowsview);
                        let mut accview = array2_to_matmut(&mut acc);
                        // acc += chunk' * chunk (sequential; parallelism is at chunk level).
                        matmul(
                            accview.as_mut(),
                            Accum::Add,
                            chunkview.as_ref().transpose(),
                            chunkview.as_ref(),
                            1.0,
                            Par::Seq,
                        );
                        (chunk_buf, acc)
                    },
                )
                .reduce(
                    || {
                        (
                            Array2::<f64>::zeros((0, 0)),
                            Array2::<f64>::zeros((p, p).f()),
                        )
                    },
                    // Combine accumulators; chunk buffers are discarded.
                    |(_, mut a), (_, b)| {
                        a += &b;
                        (Array2::zeros((0, 0)), a)
                    },
                );
            *out += &combined.1;
        } else {
            // Sequential path: reuse (and resize if needed) the shared chunk buffer.
            if weighted_x_chunk.ncols() != p || weighted_x_chunk.nrows() != chunkrows {
                *weighted_x_chunk = Array2::zeros((chunkrows, p).f());
            }
            let mut outview = array2_to_matmut(out);
            for start in (0..n).step_by(chunkrows) {
                let rows = (n - start).min(chunkrows);
                {
                    let mut chunk = weighted_x_chunk.slice_mut(s![0..rows, ..]);
                    let x_slice = x.slice(s![start..start + rows, ..]);
                    let w_slice = sqrtw.slice(s![start..start + rows]);
                    Zip::from(chunk.rows_mut())
                        .and(x_slice.rows())
                        .and(&w_slice)
                        .par_for_each(|mut dst, src, &w| {
                            Zip::from(&mut dst).and(&src).for_each(|d, &s| *d = s * w);
                        });
                }
                let chunkrowsview = weighted_x_chunk.slice(s![0..rows, ..]);
                let chunkview = FaerArrayView::new(&chunkrowsview);
                matmul(
                    outview.as_mut(),
                    Accum::Add,
                    chunkview.as_ref().transpose(),
                    chunkview.as_ref(),
                    1.0,
                    par,
                );
            }
        }
    }
    /// Fill `self.sqrtw` with sqrt(max(w, 0)); negative weights are clamped to 0.
    #[inline]
    fn fill_sqrtweights<S>(&mut self, weights: &ArrayBase<S, Ix1>)
    where
        S: Data<Elem = f64>,
    {
        if self.sqrtw.len() != weights.len() {
            self.sqrtw = Array1::zeros(weights.len());
        }
        Zip::from(&mut self.sqrtw)
            .and(weights)
            .par_for_each(|sqrtw, &w| *sqrtw = w.max(0.0).sqrt());
    }
    /// (Re)build the sparse penalized-system cache if the design or penalty
    /// pattern changed since it was last built.
    fn ensure_sparse_penalty_cache(
        &mut self,
        x: &SparseColMat<usize, f64>,
        s_lambda: &Array2<f64>,
    ) -> Result<(), EstimationError> {
        let penalty_pattern = SparsePenaltyPattern::from_dense_upper(s_lambda, 1e-12);
        let rebuild = match self.sparse_penalized_system_cache.as_ref() {
            Some(cache) => !cache.matches(x, &penalty_pattern),
            None => true,
        };
        if rebuild {
            self.sparse_penalized_system_cache =
                Some(SparsePenalizedSystemCache::new(x, penalty_pattern)?);
        }
        Ok(())
    }
    /// Sparsity statistics of the penalized system (builds the cache on demand).
    pub(crate) fn sparse_penalized_system_stats(
        &mut self,
        x: &SparseColMat<usize, f64>,
        s_lambda: &Array2<f64>,
    ) -> Result<SparsePenalizedSystemStats, EstimationError> {
        self.ensure_sparse_penalty_cache(x, s_lambda)?;
        Ok(self.sparse_penalized_system_cache.as_ref().unwrap().stats())
    }
    /// Assemble upper(XtWX + S_lambda + ridge*I) as a sparse matrix,
    /// building/reusing the symbolic cache as needed.
    fn assemble_sparse_penalized_hessian(
        &mut self,
        x: &SparseColMat<usize, f64>,
        weights: &Array1<f64>,
        s_lambda: &Array2<f64>,
        ridge: f64,
    ) -> Result<SparseColMat<usize, f64>, EstimationError> {
        self.ensure_sparse_penalty_cache(x, s_lambda)?;
        self.sparse_penalized_system_cache
            .as_mut()
            .unwrap()
            .assemble_upper(x, weights, ridge)
    }
}
/// Configuration for the working-model PIRLS loop.
#[derive(Clone, Debug)]
pub struct WorkingModelPirlsOptions {
    pub max_iterations: usize,
    pub convergence_tolerance: f64,
    /// Maximum number of step-halvings per iteration before giving up.
    pub max_step_halving: usize,
    pub min_step_size: f64,
    /// Whether to apply Firth bias reduction.
    pub firth_bias_reduction: bool,
    /// Optional per-coefficient lower bounds (box constraints).
    pub coefficient_lower_bounds: Option<Array1<f64>>,
    pub linear_constraints: Option<LinearInequalityConstraints>,
    /// Optional initial Levenberg–Marquardt damping; NOTE(review): semantics
    /// depend on the (unseen) solver loop — confirm against it.
    pub initial_lm_lambda: Option<f64>,
}
/// Per-iteration progress record for logging/tracing.
#[derive(Clone, Debug)]
pub struct WorkingModelIterationInfo {
    pub iteration: usize,
    pub deviance: f64,
    pub gradient_norm: f64,
    pub step_size: f64,
    /// Number of halvings taken during this iteration's line search.
    pub step_halving: usize,
}
/// Final outcome of the working-model PIRLS loop.
#[derive(Clone)]
pub struct WorkingModelPirlsResult {
    pub beta: Coefficients,
    /// Working state at the returned `beta`.
    pub state: WorkingState,
    pub status: PirlsStatus,
    pub iterations: usize,
    pub lastgradient_norm: f64,
    pub last_deviance_change: f64,
    pub last_step_size: f64,
    pub last_step_halving: usize,
    /// Largest |eta| seen at the solution (saturation diagnostic).
    pub max_abs_eta: f64,
    pub constraint_kkt: Option<ConstraintKktDiagnostics>,
    pub final_lm_lambda: f64,
    pub final_accept_rho: Option<f64>,
    pub min_penalized_deviance: f64,
    pub exported_laplace_curvature: ExportedLaplaceCurvature,
}
// Small fixed ridge added to the penalized Hessian for numerical stabilization.
const FIXED_STABILIZATION_RIDGE: f64 = 1e-8;
/// How the design matrix is represented in the chosen coordinate frame.
enum WorkingCoordinateDesign {
    /// Iterate in the original basis using the sparse design directly.
    OriginalSparseNative,
    /// Transformed design materialized explicitly (with optional CSR cache).
    TransformedExplicit {
        x_transformed: DesignMatrix,
        x_csr: Option<SparseRowMat<usize, f64>>,
    },
    /// Transform applied on the fly via the reparameterization operator.
    TransformedImplicit {
        transform: WorkingReparamTransform,
    },
}
/// The Qs reparameterization operator: either a dense matrix or a Kronecker
/// product of marginal factors applied mode-by-mode.
#[derive(Clone)]
enum WorkingReparamTransform {
    Dense(Arc<Array2<f64>>),
    Kronecker(Arc<KroneckerQsTransform>),
}
impl WorkingReparamTransform {
    /// Compute `Qs * v`.
    fn apply(&self, vector: &Array1<f64>) -> Array1<f64> {
        match self {
            Self::Dense(qs) => qs.dot(vector),
            Self::Kronecker(transform) => transform.apply(vector),
        }
    }
    /// Compute `Qs' * v`.
    fn apply_transpose(&self, vector: &Array1<f64>) -> Array1<f64> {
        match self {
            Self::Dense(qs) => fast_atv(qs, vector),
            Self::Kronecker(transform) => transform.apply_transpose(vector),
        }
    }
    /// Materialize the operator as a dense matrix (clones the dense case).
    fn materialize_dense(&self) -> Array2<f64> {
        match self {
            Self::Dense(qs) => qs.as_ref().clone(),
            Self::Kronecker(transform) => transform.materialize(),
        }
    }
    /// Compute the symmetrized congruence `Qs' * M * Qs`.
    fn conjugate_matrix(&self, matrix: &Array2<f64>) -> Array2<f64> {
        match self {
            Self::Dense(qs) => {
                let tmp = fast_atb(qs, matrix);
                symmetrize_dense_matrix(&fast_ab(&tmp, qs))
            }
            Self::Kronecker(transform) => transform.conjugate_matrix(matrix),
        }
    }
}
/// Penalty representation in the working coordinate frame.
#[derive(Clone)]
enum PirlsPenalty {
    /// General dense penalty S with square-root factor E (S = E'E).
    Dense {
        s_transformed: Array2<f64>,
        e_transformed: Array2<f64>,
    },
    /// Diagonal penalty; `positive_indices` lists strictly-positive entries.
    Diagonal {
        diag: Array1<f64>,
        positive_indices: Vec<usize>,
    },
}
impl PirlsPenalty {
    /// Dimension of the coefficient space the penalty acts on.
    fn dim(&self) -> usize {
        match self {
            Self::Dense { s_transformed, .. } => s_transformed.ncols(),
            Self::Diagonal { diag, .. } => diag.len(),
        }
    }
    /// Penalty rank: rows of the square-root factor, or the count of
    /// strictly-positive diagonal entries.
    fn rank(&self) -> usize {
        match self {
            Self::Dense { e_transformed, .. } => e_transformed.nrows(),
            Self::Diagonal { positive_indices, .. } => positive_indices.len(),
        }
    }
    /// Accumulate the penalty into a dense Hessian in place.
    fn add_to_hessian(&self, hessian: &mut Array2<f64>) {
        match self {
            Self::Dense { s_transformed, .. } => *hessian += s_transformed,
            Self::Diagonal { diag, .. } => {
                for (i, &d) in diag.iter().enumerate() {
                    hessian[[i, i]] += d;
                }
            }
        }
    }
    /// Compute `S * beta`, exploiting diagonality when possible.
    fn apply(&self, beta: &Array1<f64>) -> Array1<f64> {
        match self {
            Self::Dense { s_transformed, .. } => s_transformed.dot(beta),
            Self::Diagonal { diag, .. } => diag * beta,
        }
    }
}
/// Qs transform expressed as a Kronecker product of per-margin orthogonal
/// factors, applied mode-by-mode without materializing the full p x p matrix.
#[derive(Clone)]
struct KroneckerQsTransform {
    /// One factor per tensor margin, in application order.
    marginal_qs: Vec<Array2<f64>>,
    /// Size of each margin; their product equals `p`.
    dims: Vec<usize>,
    p: usize,
}
impl KroneckerQsTransform {
    /// Build the transform from a Kronecker reparameterization result.
    fn new(result: &KroneckerReparamResult) -> Self {
        let dims = result.marginal_dims.clone();
        let p = dims.iter().product();
        Self {
            marginal_qs: result.marginal_qs.clone(),
            dims,
            p,
        }
    }
    /// Compute `Qs * v` via successive mode products.
    fn apply(&self, vector: &Array1<f64>) -> Array1<f64> {
        self.apply_internal(vector, false)
    }
    /// Compute `Qs' * v` via successive transposed mode products.
    fn apply_transpose(&self, vector: &Array1<f64>) -> Array1<f64> {
        self.apply_internal(vector, true)
    }
    /// Apply each marginal factor along its tensor mode in order.
    fn apply_internal(&self, vector: &Array1<f64>, transpose: bool) -> Array1<f64> {
        debug_assert_eq!(vector.len(), self.p);
        let mut current = vector.to_vec();
        for (axis, q) in self.marginal_qs.iter().enumerate() {
            // Fixed: source contained mis-decoded `¤t` here (HTML-entity
            // corruption of `&current`), which does not compile.
            current = apply_kron_mode(&current, &self.dims, axis, q, transpose);
        }
        Array1::from_vec(current)
    }
    /// Materialize the operator as a dense matrix, one basis column at a time.
    fn materialize(&self) -> Array2<f64> {
        let mut qs = Array2::<f64>::zeros((self.p, self.p));
        for j in 0..self.p {
            // Reuse `column` rather than duplicating the unit-vector construction.
            qs.column_mut(j).assign(&self.column(j));
        }
        qs
    }
    /// Compute the symmetrized congruence `Qs' * M * Qs` column by column.
    fn conjugate_matrix(&self, matrix: &Array2<f64>) -> Array2<f64> {
        let p = self.p;
        // right = M * Qs, built one column at a time.
        let mut right = Array2::<f64>::zeros((p, p));
        for j in 0..p {
            let col = matrix.dot(&self.column(j));
            right.column_mut(j).assign(&col);
        }
        // out = Qs' * right.
        let mut out = Array2::<f64>::zeros((p, p));
        for j in 0..p {
            let transformed_col = self.apply_transpose(&right.column(j).to_owned());
            out.column_mut(j).assign(&transformed_col);
        }
        symmetrize_dense_matrix(&out)
    }
    /// The j-th column of the materialized operator (Qs applied to e_j).
    fn column(&self, j: usize) -> Array1<f64> {
        let mut e = Array1::<f64>::zeros(self.p);
        e[j] = 1.0;
        self.apply(&e)
    }
}
#[inline]
/// Return `(M + M') / 2`, forcing exact numerical symmetry.
fn symmetrize_dense_matrix(matrix: &Array2<f64>) -> Array2<f64> {
    let mut sym = matrix.to_owned();
    sym += &matrix.t();
    sym.mapv_into(|v| v * 0.5)
}
/// Multiply one tensor mode of `data` by `q` (or `q'` when `transpose`).
///
/// `data` is a flattened tensor with shape `dims` (row-major); the element at
/// multi-index (b, a, s) — b over the product of dims before `axis`, a over
/// dims[axis], s over the product after — lives at `(b * dim + a) * after + s`.
/// Returns the flattened result of contracting `q` against the `axis` mode.
fn apply_kron_mode(
    data: &[f64],
    dims: &[usize],
    axis: usize,
    q: &Array2<f64>,
    transpose: bool,
) -> Vec<f64> {
    let before: usize = dims[..axis].iter().product();
    let dim = dims[axis];
    let after: usize = dims[axis + 1..].iter().product();
    let mut out = vec![0.0_f64; data.len()];
    for b in 0..before {
        for s in 0..after {
            for i in 0..dim {
                let mut acc = 0.0;
                for a in 0..dim {
                    // q[i, a] for forward, q[a, i] for the transposed product.
                    let coeff = if transpose { q[[a, i]] } else { q[[i, a]] };
                    acc += coeff * data[(b * dim + a) * after + s];
                }
                out[(b * dim + i) * after + s] = acc;
            }
        }
    }
    out
}
/// Concrete working model for a GAM: owns the design representations, link and
/// likelihood specification, reusable workspace, and the last-evaluated
/// per-observation working quantities.
struct GamWorkingModel<'a> {
    x_original: DesignMatrix,
    /// How the design is viewed in the iteration coordinate frame.
    coordinate_design: WorkingCoordinateDesign,
    offset: Array1<f64>,
    y: ArrayView1<'a, f64>,
    priorweights: ArrayView1<'a, f64>,
    penalty: PirlsPenalty,
    workspace: PirlsWorkspace,
    likelihood: GlmLikelihoodSpec,
    link_kind: InverseLink,
    firth_bias_reduction: bool,
    // Last-computed working vectors (score-side weights in `lastweights`,
    // Hessian-side weights/derivatives in the `lasthessian_*` fields).
    lastmu: Array1<f64>,
    lastweights: Array1<f64>,
    lastz: Array1<f64>,
    last_c: Array1<f64>,
    last_d: Array1<f64>,
    lasthessian_weights: Array1<f64>,
    lasthessian_c: Array1<f64>,
    lasthessian_d: Array1<f64>,
    lasthessian_curvature: HessianCurvatureKind,
    last_dmu_deta: Array1<f64>,
    last_d2mu_deta2: Array1<f64>,
    last_d3mu_deta3: Array1<f64>,
    last_penalty_term: f64,
    /// CSR cache of the original design, when available.
    x_original_csr: Option<SparseRowMat<usize, f64>>,
    /// Optional covariate standard errors enabling integrated (quadrature) updates.
    covariate_se: Option<Array1<f64>>,
    quadctx: crate::quadrature::QuadratureContext,
}
/// Snapshot of the per-observation working quantities extracted from a
/// finished `GamWorkingModel` (see `into_final_state`).
struct GamModelFinalState {
    likelihood: GlmLikelihoodSpec,
    coordinate_frame: PirlsCoordinateFrame,
    finalmu: Array1<f64>,
    /// Hessian-side IRLS weights at the solution.
    finalweights: Array1<f64>,
    /// Score-side IRLS weights at the solution.
    scoreweights: Array1<f64>,
    finalz: Array1<f64>,
    final_c: Array1<f64>,
    final_d: Array1<f64>,
    final_dmu_deta: Array1<f64>,
    final_d2mu_deta2: Array1<f64>,
    final_d3mu_deta3: Array1<f64>,
    penalty_term: f64,
}
impl<'a> GamWorkingModel<'a> {
/// Construct the working model, choosing the concrete coordinate design from
/// the requested frame.
///
/// For `TransformedQs`, an explicit transformed design takes precedence;
/// otherwise an implicit transform must be supplied (panics if neither is).
/// All per-observation working buffers are zero-initialized at length n.
fn new(
    x_transformed: Option<DesignMatrix>,
    x_original: DesignMatrix,
    coordinate_frame: PirlsCoordinateFrame,
    offset: ArrayView1<f64>,
    y: ArrayView1<'a, f64>,
    priorweights: ArrayView1<'a, f64>,
    penalty: PirlsPenalty,
    workspace: PirlsWorkspace,
    likelihood: GlmLikelihoodSpec,
    link_kind: InverseLink,
    firth_bias_reduction: bool,
    transform: Option<WorkingReparamTransform>,
    quadctx: crate::quadrature::QuadratureContext,
) -> Self {
    let coordinate_design = match coordinate_frame {
        PirlsCoordinateFrame::OriginalSparseNative => {
            WorkingCoordinateDesign::OriginalSparseNative
        }
        PirlsCoordinateFrame::TransformedQs => {
            if let Some(x_transformed) = x_transformed {
                WorkingCoordinateDesign::TransformedExplicit {
                    // Cache a CSR view for fast row-major products, when available.
                    x_csr: x_transformed.to_csr_cache(),
                    x_transformed,
                }
            } else {
                WorkingCoordinateDesign::TransformedImplicit {
                    transform: transform.expect(
                        "TransformedQs PIRLS coordinate frame requires either x_transformed or qs",
                    ),
                }
            }
        }
    };
    let x_original_csr = x_original.to_csr_cache();
    // Row count comes from whichever design representation is active.
    let n = match &coordinate_design {
        WorkingCoordinateDesign::OriginalSparseNative => x_original.nrows(),
        WorkingCoordinateDesign::TransformedExplicit { x_transformed, .. } => {
            x_transformed.nrows()
        }
        WorkingCoordinateDesign::TransformedImplicit { .. } => x_original.nrows(),
    };
    GamWorkingModel {
        x_original,
        coordinate_design,
        offset: offset.to_owned(),
        y,
        priorweights,
        penalty,
        workspace,
        likelihood,
        link_kind,
        firth_bias_reduction,
        lastmu: Array1::zeros(n),
        lastweights: Array1::zeros(n),
        lastz: Array1::zeros(n),
        last_c: Array1::zeros(n),
        last_d: Array1::zeros(n),
        lasthessian_weights: Array1::zeros(n),
        lasthessian_c: Array1::zeros(n),
        lasthessian_d: Array1::zeros(n),
        lasthessian_curvature: HessianCurvatureKind::Fisher,
        last_dmu_deta: Array1::zeros(n),
        last_d2mu_deta2: Array1::zeros(n),
        last_d3mu_deta3: Array1::zeros(n),
        last_penalty_term: 0.0,
        x_original_csr,
        covariate_se: None,
        quadctx,
    }
}
/// Builder-style setter: attach covariate standard errors, enabling the
/// integrated (quadrature) IRLS update path.
fn with_covariate_se(mut self, se: Array1<f64>) -> Self {
    self.covariate_se = Some(se);
    self
}
/// Consume the model and extract the final per-observation working state.
///
/// Note the weight mapping: `finalweights` carries the Hessian-side weights
/// (`lasthessian_weights`) while `scoreweights` carries the score-side
/// weights (`lastweights`); the c/d buffers are likewise taken from the
/// Hessian-side fields.
fn into_final_state(self) -> GamModelFinalState {
    let GamWorkingModel {
        coordinate_design,
        lastmu,
        lastweights,
        lastz,
        last_c: _,
        last_d: _,
        lasthessian_weights,
        lasthessian_c,
        lasthessian_d,
        last_dmu_deta,
        last_d2mu_deta2,
        last_d3mu_deta3,
        last_penalty_term,
        ..
    } = self;
    // Both transformed variants report the TransformedQs frame.
    let coordinate_frame = match coordinate_design {
        WorkingCoordinateDesign::OriginalSparseNative => {
            PirlsCoordinateFrame::OriginalSparseNative
        }
        WorkingCoordinateDesign::TransformedExplicit { .. } => {
            PirlsCoordinateFrame::TransformedQs
        }
        WorkingCoordinateDesign::TransformedImplicit { .. } => {
            PirlsCoordinateFrame::TransformedQs
        }
    };
    GamModelFinalState {
        likelihood: self.likelihood,
        coordinate_frame,
        finalmu: lastmu,
        finalweights: lasthessian_weights,
        scoreweights: lastweights,
        finalz: lastz,
        final_c: lasthessian_c,
        final_d: lasthessian_d,
        final_dmu_deta: last_dmu_deta,
        final_d2mu_deta2: last_d2mu_deta2,
        final_d3mu_deta3: last_d3mu_deta3,
        penalty_term: last_penalty_term,
    }
}
/// Compute `X·beta` in the active coordinate frame and write it into `out`.
///
/// `beta` lives in the solver's working coordinates. Depending on the
/// coordinate design the product is taken against the explicit transformed
/// design, against the original design after mapping `beta` back through the
/// transform, or against the original design directly.
fn transformed_matvec_into(&self, beta: &Coefficients, out: &mut Array1<f64>) {
    match &self.coordinate_design {
        WorkingCoordinateDesign::TransformedExplicit { x_transformed, .. } => {
            // Fast path: dense matvec straight into the output buffer.
            if let Some(dense) = x_transformed.as_dense() {
                fast_av_into(dense, beta.as_ref(), out);
                return;
            }
            out.assign(&x_transformed.matrixvectormultiply(beta));
        }
        WorkingCoordinateDesign::TransformedImplicit { transform } => {
            // Map beta back to original coordinates, then multiply by X.
            let beta_orig = transform.apply(beta.as_ref());
            if let Some(dense) = self.x_original.as_dense() {
                fast_av_into(dense, &beta_orig, out);
            } else {
                out.assign(&self.x_original.apply(&beta_orig));
            }
        }
        WorkingCoordinateDesign::OriginalSparseNative => {
            out.assign(&self.x_original.matrixvectormultiply(beta));
        }
    }
}
/// Compute `Xᵀ·vec` in the active coordinate frame, returning a vector in
/// the solver's working coordinates (the implicit path additionally maps
/// the original-coordinate product through the transform's transpose).
fn transformed_transpose_matvec(&self, vec: &Array1<f64>) -> Array1<f64> {
    match &self.coordinate_design {
        WorkingCoordinateDesign::TransformedExplicit { x_transformed, .. } => {
            x_transformed.transpose_vector_multiply(vec)
        }
        WorkingCoordinateDesign::TransformedImplicit { transform } => {
            let projected = self.x_original.transpose_vector_multiply(vec);
            transform.apply_transpose(&projected)
        }
        WorkingCoordinateDesign::OriginalSparseNative => {
            self.x_original.transpose_vector_multiply(vec)
        }
    }
}
/// Form `XᵀWX` for diagonal weights `W`.
///
/// For a materialized dense design this streams `sqrt(W)·X` chunks through
/// workspace buffers into a reusable column-major accumulator; any other
/// design falls back to the generic symmetric `Xᵀ diag(w) X` helper and
/// densifies the result.
fn compute_xtwx_blas(
    workspace: &mut PirlsWorkspace,
    design: &DesignMatrix,
    weights: &Array1<f64>,
) -> Result<Array2<f64>, EstimationError> {
    match design {
        DesignMatrix::Dense(x) if x.is_materialized_dense() => {
            let p = x.ncols();
            let x_dense = x.to_dense_arc();
            workspace.fill_sqrtweights(weights);
            // Reuse the workspace Hessian buffer when shapes match;
            // `.f()` keeps it column-major for the streaming kernel.
            if workspace.hessian_buf.nrows() != p || workspace.hessian_buf.ncols() != p {
                workspace.hessian_buf = Array2::zeros((p, p).f());
            } else {
                workspace.hessian_buf.fill(0.0);
            }
            PirlsWorkspace::add_dense_xtwx_streaming_from_sqrt(
                &workspace.sqrtw,
                &mut workspace.weighted_x_chunk,
                x_dense.as_ref(),
                &mut workspace.hessian_buf,
                get_global_parallelism(),
            );
            // Hand the filled buffer to the caller; the workspace keeps an
            // empty placeholder until the next call re-sizes it.
            Ok(std::mem::take(&mut workspace.hessian_buf))
        }
        _ => crate::matrix::xt_diag_x_symmetric(design, weights)
            .map(|h| h.to_dense())
            .map_err(EstimationError::InvalidInput),
    }
}
/// Dense penalized Hessian `XᵀWX + S` in the working coordinate frame.
/// The implicit path conjugates the original-coordinate `XᵀWX` through the
/// transform before the penalty is added.
fn penalized_hessian(&mut self, weights: &Array1<f64>) -> Result<Array2<f64>, EstimationError> {
    let mut h = match &self.coordinate_design {
        WorkingCoordinateDesign::TransformedExplicit { x_transformed, .. } => {
            Self::compute_xtwx_blas(&mut self.workspace, x_transformed, weights)?
        }
        WorkingCoordinateDesign::TransformedImplicit { transform } => {
            let xtwx = Self::compute_xtwx_blas(&mut self.workspace, &self.x_original, weights)?;
            transform.conjugate_matrix(&xtwx)
        }
        WorkingCoordinateDesign::OriginalSparseNative => {
            Self::compute_xtwx_blas(&mut self.workspace, &self.x_original, weights)?
        }
    };
    self.penalty.add_to_hessian(&mut h);
    Ok(h)
}
/// Whether the current likelihood/link pair supports observed (rather than
/// Fisher) Hessian curvature; pure delegation to the free helper.
fn supports_observed_hessian_curvature(&self) -> bool {
    supports_observed_hessian_curvature_for_likelihood(self.likelihood, &self.link_kind)
}
/// Refresh the Hessian-side curvature arrays and report the curvature kind
/// actually used. Fisher curvature copies the score-side arrays verbatim;
/// it is also the fallback whenever observed curvature is unsupported for
/// the current likelihood/link pair.
fn update_hessian_curvature_arrays(
    &mut self,
    requested: HessianCurvatureKind,
) -> Result<HessianCurvatureKind, EstimationError> {
    let use_observed = requested != HessianCurvatureKind::Fisher
        && self.supports_observed_hessian_curvature();
    if !use_observed {
        self.lasthessian_weights.assign(&self.lastweights);
        self.lasthessian_c.assign(&self.last_c);
        self.lasthessian_d.assign(&self.last_d);
        return Ok(HessianCurvatureKind::Fisher);
    }
    compute_observed_hessian_curvature_arrays_into(
        self.likelihood,
        &self.link_kind,
        &self.workspace.eta_buf,
        self.y,
        &self.lastmu,
        &self.last_dmu_deta,
        &self.last_d2mu_deta2,
        &self.last_d3mu_deta3,
        &self.lastweights,
        self.priorweights,
        &mut self.lasthessian_weights,
        &mut self.lasthessian_c,
        &mut self.lasthessian_d,
    )?;
    Ok(HessianCurvatureKind::Observed)
}
/// Assemble the sparse penalized Hessian for the sparse-native path.
/// Requires both a sparse original design and a dense transformed penalty;
/// either missing piece is reported as invalid input.
fn sparse_penalized_hessian(
    &mut self,
    weights: &Array1<f64>,
    ridge: f64,
) -> Result<SparseColMat<usize, f64>, EstimationError> {
    let x_sparse = match self.x_original.as_sparse() {
        Some(sparse) => sparse,
        None => {
            return Err(EstimationError::InvalidInput(
                "sparse-native PIRLS requires a sparse original design".to_string(),
            ));
        }
    };
    let s_transformed = match &self.penalty {
        PirlsPenalty::Dense { s_transformed, .. } => s_transformed,
        _ => {
            return Err(EstimationError::InvalidInput(
                "sparse-native PIRLS requires a dense transformed penalty matrix".to_string(),
            ));
        }
    };
    self.workspace
        .assemble_sparse_penalized_hessian(x_sparse, weights, s_transformed, ridge)
}
}
impl<'a> WorkingModel for GamWorkingModel<'a> {
    /// One working-model refresh at `beta` using Fisher curvature.
    fn update(&mut self, beta: &Coefficients) -> Result<WorkingState, EstimationError> {
        self.update_with_curvature(beta, HessianCurvatureKind::Fisher)
    }
    /// Full PIRLS working-model refresh at `beta`: recompute eta, the GLM
    /// working vectors (mu, weights, z), the optional Firth adjustment, the
    /// penalized gradient, and the (possibly ridged) penalized Hessian.
    ///
    /// Buffers in `self.workspace` are taken and returned around mutable
    /// borrows of `self`; the statement order here is load-bearing.
    fn update_with_curvature(
        &mut self,
        beta: &Coefficients,
        requested_curvature: HessianCurvatureKind,
    ) -> Result<WorkingState, EstimationError> {
        let n = self.offset.len();
        if self.workspace.eta_buf.len() != n {
            self.workspace.eta_buf = Array1::zeros(n);
        }
        if self.workspace.matvec_buf.len() != n {
            self.workspace.matvec_buf = Array1::zeros(n);
        }
        // eta = offset + X·beta; the matvec buffer is taken out so the
        // &mut borrow does not alias `self` during the matvec.
        let mut matvec_tmp = std::mem::take(&mut self.workspace.matvec_buf);
        self.transformed_matvec_into(beta, &mut matvec_tmp);
        self.workspace.eta_buf.assign(&self.offset);
        self.workspace.eta_buf += &matvec_tmp;
        self.workspace.matvec_buf = matvec_tmp;
        // Re-estimate the Gamma shape from the current eta when the scale
        // specification asks for it.
        if self.likelihood.scale.gamma_shape_is_estimated() {
            let shape =
                estimate_gamma_shape_from_eta(self.y, &self.workspace.eta_buf, self.priorweights);
            self.likelihood = self.likelihood.with_gamma_shape(shape);
        }
        // Covariate standard errors (when attached) switch the update onto
        // the quadrature-integrated working-vector path.
        let integrated = self.covariate_se.as_ref().map(|se| IntegratedWorkingInput {
            quadctx: &self.quadctx,
            se: se.view(),
            mixture_link_state: self.link_kind.mixture_state(),
            sas_link_state: self.link_kind.sas_state(),
        });
        match &self.link_kind {
            InverseLink::Mixture(_) => {
                if let Some(integ) = integrated {
                    update_glmvectors_integrated_for_link(
                        integ.quadctx,
                        self.y,
                        &self.workspace.eta_buf,
                        integ.se,
                        &self.link_kind,
                        self.priorweights,
                        &mut self.lastmu,
                        &mut self.lastweights,
                        &mut self.lastz,
                        Some(WorkingDerivativeBuffersMut {
                            c: &mut self.last_c,
                            d: &mut self.last_d,
                            dmu_deta: &mut self.last_dmu_deta,
                            d2mu_deta2: &mut self.last_d2mu_deta2,
                            d3mu_deta3: &mut self.last_d3mu_deta3,
                        }),
                    )?;
                } else {
                    update_glmvectors(
                        self.y,
                        &self.workspace.eta_buf,
                        &self.link_kind,
                        self.priorweights,
                        &mut self.lastmu,
                        &mut self.lastweights,
                        &mut self.lastz,
                        Some(WorkingDerivativeBuffersMut {
                            c: &mut self.last_c,
                            d: &mut self.last_d,
                            dmu_deta: &mut self.last_dmu_deta,
                            d2mu_deta2: &mut self.last_d2mu_deta2,
                            d3mu_deta3: &mut self.last_d3mu_deta3,
                        }),
                    )?;
                }
            }
            // NOTE(review): this arm's body is identical to the Mixture arm
            // above; the two could be merged into one multi-pattern arm.
            InverseLink::LatentCLogLog(_) | InverseLink::Sas(_) | InverseLink::BetaLogistic(_) => {
                if let Some(integ) = integrated {
                    update_glmvectors_integrated_for_link(
                        integ.quadctx,
                        self.y,
                        &self.workspace.eta_buf,
                        integ.se,
                        &self.link_kind,
                        self.priorweights,
                        &mut self.lastmu,
                        &mut self.lastweights,
                        &mut self.lastz,
                        Some(WorkingDerivativeBuffersMut {
                            c: &mut self.last_c,
                            d: &mut self.last_d,
                            dmu_deta: &mut self.last_dmu_deta,
                            d2mu_deta2: &mut self.last_d2mu_deta2,
                            d3mu_deta3: &mut self.last_d3mu_deta3,
                        }),
                    )?;
                } else {
                    update_glmvectors(
                        self.y,
                        &self.workspace.eta_buf,
                        &self.link_kind,
                        self.priorweights,
                        &mut self.lastmu,
                        &mut self.lastweights,
                        &mut self.lastz,
                        Some(WorkingDerivativeBuffersMut {
                            c: &mut self.last_c,
                            d: &mut self.last_d,
                            dmu_deta: &mut self.last_dmu_deta,
                            d2mu_deta2: &mut self.last_d2mu_deta2,
                            d3mu_deta3: &mut self.last_d3mu_deta3,
                        }),
                    )?;
                }
            }
            // Standard links delegate to the likelihood's own IRLS update,
            // which receives the integrated input as an Option.
            InverseLink::Standard(_) => {
                self.likelihood.irls_update(
                    self.y,
                    &self.workspace.eta_buf,
                    self.priorweights,
                    &mut self.lastmu,
                    &mut self.lastweights,
                    &mut self.lastz,
                    integrated,
                    Some(WorkingDerivativeBuffersMut {
                        c: &mut self.last_c,
                        d: &mut self.last_d,
                        dmu_deta: &mut self.last_dmu_deta,
                        d2mu_deta2: &mut self.last_d2mu_deta2,
                        d3mu_deta3: &mut self.last_d3mu_deta3,
                    }),
                )?;
            }
        }
        // Firth bias reduction: compute Jeffreys-prior diagnostics for the
        // active design representation and shift the working response z.
        let mut firth = FirthDiagnostics::Inactive;
        if self.firth_bias_reduction {
            let (hat_diag, jeffreys_logdet) = match &self.coordinate_design {
                WorkingCoordinateDesign::TransformedExplicit {
                    x_transformed,
                    x_csr,
                } => {
                    if x_transformed.as_sparse().is_some() {
                        let csr = x_csr.as_ref().ok_or_else(|| {
                            EstimationError::InvalidInput(
                                "missing CSR cache for sparse transformed design".to_string(),
                            )
                        })?;
                        compute_jeffreys_pirls_diagnostics_sparse(
                            csr,
                            self.workspace.eta_buf.view(),
                            self.priorweights,
                        )?
                    } else {
                        let x_dense_cow = x_transformed.to_dense_cow();
                        compute_jeffreys_pirls_diagnostics(
                            x_dense_cow.view(),
                            self.workspace.eta_buf.view(),
                            self.priorweights,
                        )?
                    }
                }
                WorkingCoordinateDesign::TransformedImplicit { transform } => {
                    // Materializes X·T densely; only hit on the implicit path.
                    let x_t_dense =
                        fast_ab(&self.x_original.to_dense(), &transform.materialize_dense());
                    compute_jeffreys_pirls_diagnostics(
                        x_t_dense.view(),
                        self.workspace.eta_buf.view(),
                        self.priorweights,
                    )?
                }
                WorkingCoordinateDesign::OriginalSparseNative => {
                    if self.x_original.as_sparse().is_some() {
                        let csr = self.x_original_csr.as_ref().ok_or_else(|| {
                            EstimationError::InvalidInput(
                                "missing CSR cache for sparse original design".to_string(),
                            )
                        })?;
                        compute_jeffreys_pirls_diagnostics_sparse(
                            csr,
                            self.workspace.eta_buf.view(),
                            self.priorweights,
                        )?
                    } else {
                        let x_dense = self
                            .x_original
                            .try_to_dense_arc(
                                "Firth diagnostics require dense access to the original design",
                            )
                            .map_err(EstimationError::InvalidInput)?;
                        compute_jeffreys_pirls_diagnostics(
                            x_dense.view(),
                            self.workspace.eta_buf.view(),
                            self.priorweights,
                        )?
                    }
                }
            };
            firth = FirthDiagnostics::Active {
                jeffreys_logdet,
                hat_diag: hat_diag.clone(),
            };
            // z_i += h_ii (1/2 - mu_i) / w_i, skipping zero-weight rows.
            ndarray::Zip::from(&mut self.lastz)
                .and(&hat_diag)
                .and(&self.lastweights)
                .and(&self.lastmu)
                .par_for_each(|zi, &hii, &wi, &mui| {
                    if wi > 0.0 {
                        *zi += hii * (0.5 - mui) / wi;
                    }
                });
        }
        // Working residual r = eta - z and its weighted form w·r.
        let z = &self.lastz;
        ndarray::Zip::from(&mut self.workspace.weighted_residual)
            .and(&mut self.workspace.working_residual)
            .and(&self.workspace.eta_buf)
            .and(z)
            .and(&self.lastweights)
            .par_for_each(|wr, r, &eta, &zi, &wi| {
                let residual = eta - zi;
                *r = residual;
                *wr = residual * wi;
            });
        // Penalized gradient Xᵀ W (eta - z) + S·beta, plus its natural scale
        // (sum of the component norms) for convergence scaling.
        let mut gradient = self.transformed_transpose_matvec(&self.workspace.weighted_residual);
        let score_norm = array1_l2_norm(&gradient);
        let s_beta = self.penalty.apply(beta.as_ref());
        let s_beta_norm = array1_l2_norm(&s_beta);
        gradient += &s_beta;
        let hessian_curvature = self.update_hessian_curvature_arrays(requested_curvature)?;
        self.lasthessian_curvature = hessian_curvature;
        if self.workspace.matvec_buf.len() != n {
            self.workspace.matvec_buf = Array1::zeros(n);
        }
        // Combine Hessian-side and score-side weights into the solver
        // weights (written into the borrowed matvec buffer).
        solver_hessian_weights_into(
            &self.lasthessian_weights,
            &self.lastweights,
            &mut self.workspace.matvec_buf,
        );
        let solver_weights = std::mem::take(&mut self.workspace.matvec_buf);
        // Assemble the penalized Hessian on the sparse-native or dense path,
        // escalating a stabilization ridge until factorization succeeds.
        let (penalized_hessian, sparsehessian, ridge_used) = if matches!(
            self.coordinate_design,
            WorkingCoordinateDesign::OriginalSparseNative
        ) {
            let (h_sparse, _factor, ridge_used) =
                ensure_sparse_positive_definitewithridge(|ridge| {
                    self.sparse_penalized_hessian(&solver_weights, ridge)
                })?;
            (Array2::zeros((0, 0)), Some(h_sparse), ridge_used)
        } else {
            let mut penalized_hessian = self.penalized_hessian(&solver_weights)?;
            #[cfg(debug_assertions)]
            debug_assert_symmetric_tol(&penalized_hessian, "PIRLS penalized Hessian", 1e-8);
            let ridge_used = ensure_positive_definitewithridge(
                &mut penalized_hessian,
                "PIRLS penalized Hessian",
            )?;
            (penalized_hessian, None, ridge_used)
        };
        // Return the solver-weights buffer to the workspace.
        self.workspace.matvec_buf = solver_weights;
        let deviance = self
            .likelihood
            .loglik_deviance(self.y, &self.lastmu, self.priorweights)?;
        let log_likelihood = calculate_loglikelihood_omitting_constants(
            self.y,
            &self.lastmu,
            self.likelihood,
            self.priorweights,
        );
        // Fold any stabilization ridge into the penalty term and gradient so
        // the reported objective matches the factorized system.
        let mut penalty_term = beta.as_ref().dot(&s_beta);
        let mut ridge_grad_norm = 0.0;
        if ridge_used > 0.0 {
            let ridge_penalty = ridge_used * beta.as_ref().dot(beta.as_ref());
            penalty_term += ridge_penalty;
            gradient.zip_mut_with(beta.as_ref(), |g, &b| *g += ridge_used * b);
            ridge_grad_norm = ridge_used * array1_l2_norm(beta.as_ref());
        }
        self.last_penalty_term = penalty_term;
        let gradient_natural_scale = score_norm + s_beta_norm + ridge_grad_norm;
        Ok(WorkingState {
            // Hand eta to the caller; the workspace gets an empty buffer that
            // is re-sized at the top of the next update.
            eta: LinearPredictor::new(std::mem::replace(
                &mut self.workspace.eta_buf,
                Array1::zeros(0),
            )),
            gradient,
            hessian: match sparsehessian {
                Some(h_sparse) => crate::linalg::matrix::SymmetricMatrix::Sparse(h_sparse),
                None => crate::linalg::matrix::SymmetricMatrix::Dense(penalized_hessian),
            },
            log_likelihood,
            deviance,
            penalty_term,
            firth,
            ridge_used,
            hessian_curvature,
            gradient_natural_scale,
        })
    }
    /// Evaluate a candidate step with Firth bias reduction temporarily
    /// disabled, restoring the flag afterwards (also on error).
    fn update_candidate(
        &mut self,
        beta: &Coefficients,
        curvature: HessianCurvatureKind,
    ) -> Result<WorkingState, EstimationError> {
        if !self.firth_bias_reduction {
            return self.update_with_curvature(beta, curvature);
        }
        let firth_enabled = self.firth_bias_reduction;
        self.firth_bias_reduction = false;
        let result = self.update_with_curvature(beta, curvature);
        self.firth_bias_reduction = firth_enabled;
        result
    }
    /// Observed-information support mirrors the inherent helper.
    fn supports_observed_information_curvature(&self) -> bool {
        self.supports_observed_hessian_curvature()
    }
}
// Column cap below which XᵀWX is accumulated into a dense p×p buffer
// (`XtWxBackend::Dense`); wider designs use the sparse SpGEMM backend.
const DENSE_OUTER_MAX_P: usize = 1024;
// Work estimate (~nnz²/n) above which the dense outer-product accumulation
// is split across rayon threads in `DenseOuterState::compute`.
const DENSE_OUTER_PARALLEL_FLOP_THRESHOLD: u64 = 100_000;
/// Numeric strategy for forming XᵀWX inside `SparseXtWxCache`.
enum XtWxBackend {
    /// Dense p×p accumulation of weighted row outer products.
    Dense(DenseOuterState),
    /// Sparse-sparse product (Xᵀ·sqrt(W))·(sqrt(W)·X) via faer's SpGEMM.
    Sparse(SparseSpGemmState),
}
/// Scratch state for the dense outer-product XᵀWX backend.
struct DenseOuterState {
    // Dense p×p accumulator; downstream reads only entries with row <= col.
    xtwx_dense: Array2<f64>,
    // Per-thread partial accumulators, allocated lazily on the parallel path.
    thread_buffers: Vec<Array2<f64>>,
}
/// Scratch state for the sparse SpGEMM XᵀWX backend.
struct SparseSpGemmState {
    // Values of sqrt(W)·X, sharing X's symbolic structure.
    wxvalues: Vec<f64>,
    // Values of Xᵀ scaled by sqrt(W), sharing Xᵀ's symbolic structure.
    wx_tvalues: Vec<f64>,
    // Cached per-observation sqrt(max(w, 0)).
    sqrt_weights: Vec<f64>,
    // Symbolic multiply plan from `sparse_sparse_matmul_symbolic`.
    info: SparseMatMulInfo,
    // Scratch memory sized for the numeric SpGEMM.
    scratch: MemBuffer,
    // Parallelism setting captured at cache construction time.
    par: Par,
}
/// Caches the symbolic structure and scratch buffers needed to recompute
/// XᵀWX repeatedly for a fixed sparse design X with varying weights W.
pub(crate) struct SparseXtWxCache {
    // Structural (symbolic) pattern of XᵀX, reused across weight updates.
    xtwx_symbolic: SymbolicSparseColMat<usize>,
    // Numeric values aligned with `xtwx_symbolic`.
    xtwxvalues: Vec<f64>,
    // Shape/nnz of the X this cache was built for (checked by `matches`).
    nrows: usize,
    ncols: usize,
    nnz: usize,
    // Copies of X's structure used to validate cache reuse in `matches`.
    x_col_ptr: Vec<usize>,
    xrow_idx: Vec<usize>,
    // Xᵀ in column-major (CSC) form, used by both backends.
    x_t_csc: SparseColMat<usize, f64>,
    backend: XtWxBackend,
}
impl SparseXtWxCache {
    /// Build the cache for a fixed sparse design `x`: transpose X to CSC,
    /// plan the symbolic XᵀX product, and pick the numeric backend by width
    /// (dense outer products up to `DENSE_OUTER_MAX_P` columns, SpGEMM above).
    fn new(x: &SparseColMat<usize, f64>) -> Result<Self, EstimationError> {
        let x_t_csc =
            x.as_ref().transpose().to_col_major().map_err(|_| {
                EstimationError::InvalidInput("failed to transpose to CSC".to_string())
            })?;
        let (xtwx_symbolic, info) = sparse_sparse_matmul_symbolic(x_t_csc.symbolic(), x.symbolic())
            .map_err(|_| {
                EstimationError::InvalidInput("failed to build symbolic XtWX cache".to_string())
            })?;
        let xtwxvalues = vec![0.0; xtwx_symbolic.row_idx().len()];
        let backend = if x.ncols() <= DENSE_OUTER_MAX_P {
            XtWxBackend::Dense(DenseOuterState {
                xtwx_dense: Array2::<f64>::zeros((x.ncols(), x.ncols())),
                thread_buffers: Vec::new(),
            })
        } else {
            let par = get_global_parallelism();
            // Scratch is sized once for the symbolic plan and reused for
            // every numeric recomputation.
            let scratch = MemBuffer::new(sparse_sparse_matmul_numeric_scratch::<usize, f64>(
                xtwx_symbolic.as_ref(),
                par,
            ));
            XtWxBackend::Sparse(SparseSpGemmState {
                wxvalues: vec![0.0; x.val().len()],
                wx_tvalues: vec![0.0; x_t_csc.val().len()],
                sqrt_weights: vec![0.0; x.nrows()],
                info,
                scratch,
                par,
            })
        };
        Ok(Self {
            xtwx_symbolic,
            xtwxvalues,
            nrows: x.nrows(),
            ncols: x.ncols(),
            nnz: x.val().len(),
            // Keep copies of X's structure so `matches` can detect reuse
            // with a different design.
            x_col_ptr: x.symbolic().col_ptr().to_vec(),
            xrow_idx: x.symbolic().row_idx().to_vec(),
            x_t_csc,
            backend,
        })
    }
    /// Whether `x` has exactly the shape, nnz, and sparsity pattern this
    /// cache was built for (value changes are fine; structure must match).
    fn matches(&self, x: &SparseColMat<usize, f64>) -> bool {
        if self.nrows != x.nrows() || self.ncols != x.ncols() || self.nnz != x.val().len() {
            return false;
        }
        let sym = x.symbolic();
        self.x_col_ptr.as_slice() == sym.col_ptr() && self.xrow_idx.as_slice() == sym.row_idx()
    }
    /// Recompute the numeric values of XᵀWX for the given weights, writing
    /// into `self.xtwxvalues` (aligned with `self.xtwx_symbolic`).
    fn compute_numeric(
        &mut self,
        x: &SparseColMat<usize, f64>,
        weights: &Array1<f64>,
    ) -> Result<(), EstimationError> {
        if weights.len() != self.nrows {
            return Err(EstimationError::InvalidInput(format!(
                "weights length {} does not match design rows {}",
                weights.len(),
                self.nrows
            )));
        }
        match &mut self.backend {
            XtWxBackend::Dense(state) => {
                state.compute(self.x_t_csc.as_ref(), weights, self.nrows, self.ncols);
                // Scatter the dense result into the symbolic pattern; only
                // upper-triangle entries (row <= col) are copied.
                let col_ptr = self.xtwx_symbolic.col_ptr();
                let row_idx = self.xtwx_symbolic.row_idx();
                let dense = &state.xtwx_dense;
                for col in 0..self.ncols {
                    let start = col_ptr[col];
                    let end = col_ptr[col + 1];
                    for idx in start..end {
                        let row = row_idx[idx];
                        if row <= col {
                            self.xtwxvalues[idx] = dense[[row, col]];
                        }
                    }
                }
            }
            XtWxBackend::Sparse(state) => state.compute(
                x,
                self.x_t_csc.as_ref(),
                weights,
                self.ncols,
                self.xtwx_symbolic.as_ref(),
                &mut self.xtwxvalues,
            ),
        }
        Ok(())
    }
}
impl DenseOuterState {
    /// Accumulate XᵀWX into `self.xtwx_dense` from the CSC transpose `x_t`
    /// (each of its columns is one observation row of X).
    ///
    /// Runs serially for small work estimates; otherwise splits the rows
    /// across per-thread buffers and reduces them at the end.
    fn compute(
        &mut self,
        x_t: SparseColMatRef<'_, usize, f64>,
        weights: &Array1<f64>,
        n: usize,
        p: usize,
    ) {
        debug_assert_eq!(self.xtwx_dense.dim(), (p, p));
        self.xtwx_dense.fill(0.0);
        if n == 0 || p == 0 {
            return;
        }
        let xtwx_start = std::time::Instant::now();
        // Rough flop proxy: (nnz per row)² summed over rows ≈ nnz²/n.
        let nnz_total = x_t.symbolic().row_idx().len() as u64;
        let work = nnz_total
            .saturating_mul(nnz_total)
            .checked_div(n as u64)
            .unwrap_or(u64::MAX);
        let n_threads = rayon::current_num_threads();
        let parallelize = n_threads > 1 && work >= DENSE_OUTER_PARALLEL_FLOP_THRESHOLD;
        if !parallelize {
            accumulate_outer_upper(&mut self.xtwx_dense, x_t, weights, 0..n);
            log::info!(
                "[STAGE] PIRLS dense XᵀWX assembly (serial) n={} p={} flops~{} elapsed={:.3}s",
                n,
                p,
                (n as u64).saturating_mul((p as u64).saturating_mul(p as u64)),
                xtwx_start.elapsed().as_secs_f64(),
            );
            return;
        }
        // Lazily (re)allocate one p×p partial accumulator per thread.
        if self.thread_buffers.len() != n_threads {
            self.thread_buffers
                .resize_with(n_threads, || Array2::<f64>::zeros((p, p)));
        }
        let chunk = n.div_ceil(n_threads);
        self.thread_buffers
            .par_iter_mut()
            .enumerate()
            .for_each(|(t, buf)| {
                buf.fill(0.0);
                let start = t * chunk;
                let end = (start + chunk).min(n);
                if start < end {
                    accumulate_outer_upper(buf, x_t, weights, start..end);
                }
            });
        // Serial reduction of the per-thread partials.
        for buf in &self.thread_buffers {
            self.xtwx_dense += buf;
        }
        log::info!(
            "[STAGE] PIRLS dense XᵀWX assembly (parallel, threads={}) n={} p={} flops~{} elapsed={:.3}s",
            rayon::current_num_threads(),
            n,
            p,
            (n as u64).saturating_mul((p as u64).saturating_mul(p as u64)),
            xtwx_start.elapsed().as_secs_f64(),
        );
    }
}
impl SparseSpGemmState {
    /// Compute XᵀWX = (Xᵀ·sqrt(W))·(sqrt(W)·X) as a sparse-sparse product,
    /// writing the values into `xtwxvalues` (aligned with `xtwx_symbolic`).
    fn compute(
        &mut self,
        x: &SparseColMat<usize, f64>,
        x_t: SparseColMatRef<'_, usize, f64>,
        weights: &Array1<f64>,
        p: usize,
        xtwx_symbolic: SymbolicSparseColMatRef<'_, usize>,
        xtwxvalues: &mut [f64],
    ) {
        let n = x_t.ncols();
        debug_assert_eq!(weights.len(), n);
        debug_assert_eq!(self.sqrt_weights.len(), n);
        // Negative weights are clamped to zero before the square root.
        let sqrt_w = self.sqrt_weights.as_mut_slice();
        for (dst, &w) in sqrt_w.iter_mut().zip(weights.iter()) {
            *dst = w.max(0.0).sqrt();
        }
        let sqrt_w: &[f64] = sqrt_w;
        // Scale X's values row-wise by sqrt(w): entry (row, col) gets
        // sqrt_w[row], reusing X's symbolic structure.
        let x_ref = x.as_ref();
        for col in 0..p {
            let rows = x_ref.row_idx_of_col_raw(col);
            let xvals = x_ref.val_of_col(col);
            let range = x_ref.col_range(col);
            let dst = &mut self.wxvalues[range];
            for ((d, &s), row) in dst.iter_mut().zip(xvals.iter()).zip(rows.iter()) {
                *d = s * sqrt_w[row.unbound()];
            }
        }
        // Scale Xᵀ's values column-wise: in CSC of Xᵀ, column `col` is
        // observation `col`, so the whole column is scaled by sqrt_w[col].
        for col in 0..n {
            let w = sqrt_w[col];
            let xvals = x_t.val_of_col(col);
            let range = x_t.col_range(col);
            let dst = &mut self.wx_tvalues[range];
            for (d, &s) in dst.iter_mut().zip(xvals.iter()) {
                *d = s * w;
            }
        }
        let wx_ref = SparseColMatRef::new(x.symbolic(), &self.wxvalues[..]);
        let wx_t_ref = SparseColMatRef::new(x_t.symbolic(), &self.wx_tvalues[..]);
        let mut stack = MemStack::new(&mut self.scratch);
        let xtwxmut = SparseColMatMut::new(xtwx_symbolic, xtwxvalues);
        // Numeric SpGEMM following the precomputed symbolic plan in `info`.
        sparse_sparse_matmul_numeric(
            xtwxmut,
            Accum::Replace,
            wx_t_ref,
            wx_ref,
            1.0,
            &self.info,
            self.par,
            &mut stack,
        );
    }
}
/// Accumulate the weighted outer products `w_i · x_i x_iᵀ` for observation
/// indices in `rows` into `acc` (each column of `x_t` is one observation row
/// of X; with ascending per-column indices the writes land at k >= j, which
/// matches the `row <= col` read-back in `compute_numeric`).
#[inline]
fn accumulate_outer_upper(
    acc: &mut Array2<f64>,
    x_t: SparseColMatRef<'_, usize, f64>,
    weights: &Array1<f64>,
    rows: std::ops::Range<usize>,
) {
    debug_assert_eq!(acc.nrows(), acc.ncols());
    let p = acc.ncols();
    let acc_data = acc
        .as_slice_mut()
        .expect("dense XᵀWX accumulator is row-major and contiguous");
    for i in rows {
        // Negative weights clamp to zero, matching the SpGEMM backend.
        let w = weights[i].max(0.0);
        if w == 0.0 {
            continue;
        }
        let idx = x_t.row_idx_of_col_raw(i);
        let val = x_t.val_of_col(i);
        for (jj, (cj, &vj)) in idx.iter().zip(val.iter()).enumerate() {
            let j = cj.unbound();
            let scaled = w * vj;
            let dest = &mut acc_data[j * p..j * p + p];
            for (ck, &vk) in idx[jj..].iter().zip(val[jj..].iter()) {
                dest[ck.unbound()] += scaled * vk;
            }
        }
    }
}
/// Jeffreys/Firth diagnostics for a CSR design: densify the matrix and
/// delegate to the dense implementation. Returns the PIRLS hat diagonal
/// and the Jeffreys log-determinant.
fn compute_jeffreys_pirls_diagnostics_sparse(
    x_design_csr: &SparseRowMat<usize, f64>,
    eta: ArrayView1<f64>,
    observation_weights: ArrayView1<f64>,
) -> Result<(Array1<f64>, f64), EstimationError> {
    let n = x_design_csr.nrows();
    let p = x_design_csr.ncols();
    let mut x_dense = Array2::<f64>::zeros((n, p));
    let xview = x_design_csr.as_ref();
    for i in 0..n {
        let vals = xview.val_of_row(i);
        let cols = xview.col_idx_of_row_raw(i);
        // Defensive check against a malformed row slice pair.
        if cols.len() != vals.len() {
            return Err(EstimationError::InvalidInput(
                "sparse row structure mismatch: column/value lengths differ".to_string(),
            ));
        }
        for (col, &v) in cols.iter().zip(vals.iter()) {
            x_dense[[i, col.unbound()]] = v;
        }
    }
    compute_jeffreys_pirls_diagnostics(x_dense.view(), eta, observation_weights)
}
/// Jeffreys/Firth diagnostics for a dense design: build the Firth operator
/// and read off the PIRLS hat diagonal and the Jeffreys log-determinant.
fn compute_jeffreys_pirls_diagnostics(
    x_design: ArrayView2<f64>,
    eta: ArrayView1<f64>,
    observation_weights: ArrayView1<f64>,
) -> Result<(Array1<f64>, f64), EstimationError> {
    // The operator takes owned arrays, so materialize the views once.
    let x_owned = x_design.to_owned();
    let eta_owned = eta.to_owned();
    let op = FirthDenseOperator::build_with_observation_weights(
        &x_owned,
        &eta_owned,
        observation_weights,
    )?;
    let hat_diag = op.pirls_hat_diag();
    let jeffreys_logdet = op.jeffreys_logdet();
    Ok((hat_diag, jeffreys_logdet))
}
/// Verify `hess` is positive definite via Cholesky, adding the fixed
/// stabilization ridge to its diagonal if the first attempt fails.
/// Returns the ridge actually applied (0.0 when none was needed), or a
/// `HessianNotPositiveDefinite` error carrying the smallest eigenvalue.
fn ensure_positive_definitewithridge(
    hess: &mut Array2<f64>,
    label: &str,
) -> Result<f64, EstimationError> {
    // First try the Hessian as-is.
    if hess.cholesky(Side::Lower).is_ok() {
        return Ok(0.0);
    }
    let ridge = FIXED_STABILIZATION_RIDGE.max(0.0);
    if ridge > 0.0 {
        hess.diag_mut().mapv_inplace(|d| d + ridge);
        if hess.cholesky(Side::Lower).is_ok() {
            log::debug!("{} stabilized with fixed ridge {:.1e}.", label, ridge);
            return Ok(ridge);
        }
    }
    // Still indefinite: report the minimum eigenvalue for diagnostics
    // (NEG_INFINITY when even the eigendecomposition fails).
    match hess.eigh(Side::Lower) {
        Ok((evals, _)) => {
            let min_eig = evals.iter().copied().fold(f64::INFINITY, f64::min);
            Err(EstimationError::HessianNotPositiveDefinite {
                min_eigenvalue: min_eig,
            })
        }
        Err(_) => Err(EstimationError::HessianNotPositiveDefinite {
            min_eigenvalue: f64::NEG_INFINITY,
        }),
    }
}
/// Solve `H · d = -g` densely via the stable factorization, writing the
/// Newton direction into `direction_out`. Fails if the factorization fails
/// or the solution contains non-finite entries.
fn solve_newton_direction_dense(
    hessian: &Array2<f64>,
    gradient: &Array1<f64>,
    direction_out: &mut Array1<f64>,
) -> Result<(), EstimationError> {
    let started = std::time::Instant::now();
    let p = hessian.nrows();
    if direction_out.len() != gradient.len() {
        *direction_out = Array1::zeros(gradient.len());
    }
    let factor = StableSolver::new("pirls newton direction")
        .factorize(hessian)
        .map_err(EstimationError::LinearSystemSolveFailed)?;
    // Solve in place on the gradient copy, then negate for descent.
    direction_out.assign(gradient);
    {
        let mut rhs = array1_to_col_matmut(direction_out);
        factor.solve_in_place(rhs.as_mut());
    }
    direction_out.mapv_inplace(|v| -v);
    if !array1_is_finite(direction_out) {
        return Err(EstimationError::LinearSystemSolveFailed(
            FaerLinalgError::FactorizationFailed,
        ));
    }
    log::info!(
        "[STAGE] PIRLS dense newton solve p={} flops~{} elapsed={:.3}s",
        p,
        (p as u64).saturating_mul((p as u64).saturating_mul(p as u64)) / 3,
        started.elapsed().as_secs_f64(),
    );
    Ok(())
}
/// Solve `H · d = -g` matrix-free with preconditioned conjugate gradients,
/// where `H = XᵀWX + ridge·I + Σ λ_k S_k` is applied implicitly via
/// `apply_xtwx` plus the dense and operator penalties.
///
/// `xtwx_diag` supplies the diagonal of XᵀWX for the Jacobi preconditioner.
/// Writes the (negated) solution into `direction_out`; errors on dimension
/// mismatches, PCG failure, or non-finite output.
pub fn solve_newton_direction_implicit<F>(
    apply_xtwx: F,
    xtwx_diag: ArrayView1<'_, f64>,
    dense_penalties: &[(f64, &Array2<f64>)],
    op_penalties: &[(f64, &dyn crate::terms::penalty_op::PenaltyOp)],
    gradient: &Array1<f64>,
    direction_out: &mut Array1<f64>,
    ridge: f64,
    rel_tol: f64,
    max_iter: usize,
) -> Result<(), EstimationError>
where
    F: Fn(&Array1<f64>) -> Array1<f64>,
{
    let p = gradient.len();
    // Validate all operand dimensions up front so PCG cannot panic later.
    if xtwx_diag.len() != p {
        return Err(EstimationError::InvalidInput(format!(
            "solve_newton_direction_implicit: xtwx_diag length {} != gradient length {}",
            xtwx_diag.len(),
            p
        )));
    }
    for (_, s) in dense_penalties.iter() {
        if s.nrows() != p || s.ncols() != p {
            return Err(EstimationError::InvalidInput(format!(
                "solve_newton_direction_implicit: dense penalty dim {}×{} != p={}",
                s.nrows(),
                s.ncols(),
                p
            )));
        }
    }
    for (_, op) in op_penalties.iter() {
        if op.dim() != p {
            return Err(EstimationError::InvalidInput(format!(
                "solve_newton_direction_implicit: op penalty dim {} != p={}",
                op.dim(),
                p
            )));
        }
    }
    if direction_out.len() != p {
        *direction_out = Array1::zeros(p);
    }
    let pcg_start = std::time::Instant::now();
    // Jacobi preconditioner: diag(XᵀWX) + ridge + Σ λ_k diag(S_k).
    let mut precond_diag = xtwx_diag.to_owned();
    if ridge > 0.0 {
        precond_diag.mapv_inplace(|d| d + ridge);
    }
    for (lambda, s) in dense_penalties.iter() {
        if *lambda == 0.0 {
            continue;
        }
        for i in 0..p {
            precond_diag[i] += *lambda * s[[i, i]];
        }
    }
    for (lambda, op) in op_penalties.iter() {
        if *lambda == 0.0 {
            continue;
        }
        let d = op.diag();
        for i in 0..p {
            precond_diag[i] += *lambda * d[i];
        }
    }
    // Implicit application of the full penalized Hessian H·v.
    let apply_h = |v: &Array1<f64>| -> Array1<f64> {
        let mut hv = apply_xtwx(v);
        if ridge > 0.0 {
            hv.zip_mut_with(v, |h, &x| *h += ridge * x);
        }
        for (lambda, s) in dense_penalties.iter() {
            if *lambda == 0.0 {
                continue;
            }
            let sv = s.dot(v);
            hv.scaled_add(*lambda, &sv);
        }
        for (lambda, op) in op_penalties.iter() {
            if *lambda == 0.0 {
                continue;
            }
            let mut sv = Array1::<f64>::zeros(p);
            op.matvec(v.view(), sv.view_mut());
            hv.scaled_add(*lambda, &sv);
        }
        hv
    };
    let solution =
        crate::linalg::utils::solve_spd_pcg(apply_h, gradient, &precond_diag, rel_tol, max_iter)
            .ok_or(EstimationError::LinearSystemSolveFailed(
                FaerLinalgError::FactorizationFailed,
            ))?;
    // PCG solved H·x = g; the Newton direction is d = -x.
    direction_out.assign(&solution);
    direction_out.mapv_inplace(|v| -v);
    if !array1_is_finite(direction_out) {
        return Err(EstimationError::LinearSystemSolveFailed(
            FaerLinalgError::FactorizationFailed,
        ));
    }
    log::info!(
        "[STAGE] PIRLS implicit (PCG) newton solve p={} dense_pens={} op_pens={} elapsed={:.3}s",
        p,
        dense_penalties.len(),
        op_penalties.len(),
        pcg_start.elapsed().as_secs_f64(),
    );
    Ok(())
}
/// Clamp each coefficient up to its lower bound, in place; non-finite
/// bounds (e.g. -inf for unbounded coefficients) are skipped.
fn project_coefficients_to_lower_bounds(beta: &mut Array1<f64>, lower_bounds: &Array1<f64>) {
    for idx in 0..beta.len() {
        let bound = lower_bounds[idx];
        if !bound.is_finite() {
            continue;
        }
        if beta[idx] < bound {
            beta[idx] = bound;
        }
    }
}
/// Norm of the gradient projected onto the feasible directions of the
/// lower-bound constraints: components that would push an (approximately)
/// active coordinate further into its bound are dropped. Without bounds
/// this is the plain Euclidean norm.
fn projected_gradient_norm(
    gradient: &Array1<f64>,
    beta: &Array1<f64>,
    lower_bounds: Option<&Array1<f64>>,
) -> f64 {
    let Some(lb) = lower_bounds else {
        return gradient.dot(gradient).sqrt();
    };
    let mut total = 0.0;
    for i in 0..gradient.len() {
        let g = gradient[i];
        let bound = lb[i];
        // A positive gradient at an active lower bound means the descent
        // direction is infeasible; treat the component as projected out.
        let clamped_out = if bound.is_finite() && g > 0.0 {
            let magnitude = beta[i].abs().max(bound.abs()).max(1.0);
            let slack_tol = 1e-6 * magnitude + 1e-10;
            (beta[i] - bound) < slack_tol
        } else {
            false
        };
        if !clamped_out {
            total += g * g;
        }
    }
    total.sqrt()
}
/// Reasons a PIRLS iterate may be accepted without meeting the strict
/// convergence test (see `pirls_soft_acceptance`).
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
enum PirlsSoftAccept {
    // Near-stationary KKT conditions with a deviance change below tolerance.
    NearStationaryPlateau,
    // Linear predictor pinned at the eta cap while the deviance has flattened.
    BoundarySaturation,
    // Small projected gradient plus a tiny, non-negative deviance change.
    RelativeBandPlateau,
}
/// Progress measure fed to `pirls_soft_acceptance`: either the realized
/// deviance change of a completed step, or a candidate step's predicted
/// reduction together with the current penalized objective.
#[derive(Clone, Copy, Debug)]
enum SoftAcceptProgress {
    Realized { dev_change: f64 },
    Predicted {
        predicted_reduction: f64,
        current_penalized: f64,
    },
}
/// Decide whether the current PIRLS state qualifies for soft acceptance,
/// i.e. termination without strict convergence. Checks, in order:
/// a near-stationary plateau, eta saturation at the absolute cap, and a
/// relative-band plateau. Returns `None` when none applies.
#[inline]
fn pirls_soft_acceptance(
    state: &WorkingState,
    projected_grad: f64,
    progress: SoftAcceptProgress,
    max_abs_eta: f64,
    tol: f64,
) -> Option<PirlsSoftAccept> {
    // Scale tolerances by the objective magnitude (floored at 1).
    let objective_scale = state.deviance.abs().max(state.penalty_term.abs()).max(1.0);
    let scaled_dev_tol = tol * objective_scale;
    let near_stationary_plateau = match progress {
        SoftAcceptProgress::Realized { dev_change } => {
            state.near_stationary_kkt(projected_grad, tol) && dev_change.abs() < scaled_dev_tol
        }
        SoftAcceptProgress::Predicted {
            predicted_reduction,
            current_penalized,
        } => {
            // Predicted reductions below numerical noise count as a plateau.
            let reduction_noise_floor = current_penalized.abs().max(1.0) * 1e-12;
            state.near_stationary_kkt(projected_grad, tol)
                && predicted_reduction.abs() <= reduction_noise_floor
        }
    };
    if near_stationary_plateau {
        return Some(PirlsSoftAccept::NearStationaryPlateau);
    }
    // The remaining checks need a realized deviance change.
    let dev_change = match progress {
        SoftAcceptProgress::Realized { dev_change } => dev_change,
        SoftAcceptProgress::Predicted { .. } => return None,
    };
    if max_abs_eta >= PIRLS_ETA_ABS_CAP * (1.0 - 1e-12) && dev_change.abs() < scaled_dev_tol {
        return Some(PirlsSoftAccept::BoundarySaturation);
    }
    if projected_grad <= tol.max(1e-6) * objective_scale
        && dev_change.abs() < scaled_dev_tol * 0.1
        && dev_change >= 0.0
    {
        return Some(PirlsSoftAccept::RelativeBandPlateau);
    }
    None
}
/// Stationarity measure for the constrained problem: the worst KKT residual
/// when linear inequality constraints are present, otherwise the
/// lower-bound-projected gradient norm.
fn constrained_stationarity_norm(
    gradient: &Array1<f64>,
    beta: &Array1<f64>,
    lower_bounds: Option<&Array1<f64>>,
    linear_constraints: Option<&LinearInequalityConstraints>,
) -> f64 {
    match linear_constraints {
        Some(constraints) => {
            let kkt = compute_constraint_kkt_diagnostics(beta, gradient, constraints);
            [
                kkt.primal_feasibility,
                kkt.dual_feasibility,
                kkt.complementarity,
                kkt.stationarity,
            ]
            .into_iter()
            .fold(f64::NEG_INFINITY, f64::max)
        }
        None => projected_gradient_norm(gradient, beta, lower_bounds),
    }
}
/// Choose a row-chunk size for streaming nonzero counting so that one dense
/// chunk of `p` f64 columns stays near a fixed byte budget, clamped to a
/// sane range and never exceeding the row count (minimum 1).
fn chunk_rows_for_nnz_count(n: usize, p: usize) -> usize {
    // ~8 MiB of f64 entries per chunk.
    const TARGET_BYTES: usize = 8 * 1024 * 1024;
    const MIN_ROWS: usize = 256;
    const MAX_ROWS: usize = 65_536;
    if p == 0 {
        return n.max(1);
    }
    let budget_rows = TARGET_BYTES / (p * 8);
    budget_rows.clamp(MIN_ROWS, MAX_ROWS).min(n.max(1))
}
/// Count upper-triangle entries (diagonal included) with magnitude above
/// `tol`; only the leading square part of `matrix` is scanned.
fn count_dense_upper_nnz(matrix: &Array2<f64>, tol: f64) -> usize {
    let dim = matrix.nrows().min(matrix.ncols());
    (0..dim)
        .map(|col| {
            (0..=col)
                .filter(|&row| matrix[[row, col]].abs() > tol)
                .count()
        })
        .sum()
}
/// Decide whether PIRLS should run sparse-native in original coordinates or
/// fall back to the dense transformed path, along with diagnostics (nnz and
/// density estimates) explaining the choice.
///
/// Dense fallback triggers for: any finite lower bound or linear constraint,
/// a non-sparse design, a failure to estimate sparse system stats, or an
/// estimated penalized-Hessian density above `SPARSE_NATIVE_MAX_H_DENSITY`.
fn estimate_sparse_native_decision(
    workspace: &mut PirlsWorkspace,
    x_original: &DesignMatrix,
    s_lambda: &Array2<f64>,
    coefficient_lower_bounds: Option<&Array1<f64>>,
    linear_constraints_original: Option<&LinearInequalityConstraints>,
) -> SparsePirlsDecision {
    let p = x_original.ncols();
    let nnz_s_lambda = count_dense_upper_nnz(s_lambda, 1e-12);
    // Common constructor for a dense-path decision with a rejection reason.
    let dense_reject = |reason: &'static str, nnz_x: usize| SparsePirlsDecision {
        path: PirlsLinearSolvePath::DenseTransformed,
        reason,
        p,
        nnz_x,
        nnz_xtwx_symbolic: None,
        nnz_s_lambda,
        nnz_h_est: None,
        density_h_est: None,
    };
    let has_finite_lower_bounds = coefficient_lower_bounds
        .map(|lb| lb.iter().any(|bound| bound.is_finite()))
        .unwrap_or(false);
    if has_finite_lower_bounds || linear_constraints_original.is_some() {
        return dense_reject("constraints_present", 0);
    }
    let x_sparse = if let Some(sparse) = x_original.as_sparse() {
        sparse
    } else {
        // Not sparse: estimate nnz by streaming row chunks (counting every
        // entry above 1e-12; chunks we cannot read count as fully dense),
        // then reject with that estimate for the diagnostics.
        let row_chunk_start = std::time::Instant::now();
        let n = x_original.nrows();
        let chunk = chunk_rows_for_nnz_count(n, x_original.ncols());
        let mut nnz: usize = 0;
        let mut chunks_processed = 0usize;
        if chunk > 0 && n > 0 {
            let mut start = 0;
            while start < n {
                let end = (start + chunk).min(n);
                chunks_processed += 1;
                match x_original.try_row_chunk(start..end) {
                    Ok(rows) => {
                        nnz = nnz.saturating_add(rows.iter().filter(|v| v.abs() > 1e-12).count());
                    }
                    Err(_) => {
                        nnz = nnz.saturating_add((end - start).saturating_mul(x_original.ncols()));
                    }
                }
                start = end;
            }
        }
        log::info!(
            "[STAGE] PIRLS row-chunk generation chunks={} n={} p={} nnz={} elapsed={:.3}s",
            chunks_processed,
            n,
            x_original.ncols(),
            nnz,
            row_chunk_start.elapsed().as_secs_f64(),
        );
        return dense_reject("design_not_sparse", nnz);
    };
    let nnz_x = x_sparse.val().len();
    // Sparse design: estimate the penalized-system fill and gate on density.
    match workspace.sparse_penalized_system_stats(x_sparse, s_lambda) {
        Ok(stats) => {
            let decision = SparsePirlsDecision {
                path: if stats.density_upper <= SPARSE_NATIVE_MAX_H_DENSITY {
                    PirlsLinearSolvePath::SparseNative
                } else {
                    PirlsLinearSolvePath::DenseTransformed
                },
                reason: if stats.density_upper <= SPARSE_NATIVE_MAX_H_DENSITY {
                    "sparse_native_eligible"
                } else {
                    "penalized_hessian_too_dense"
                },
                p,
                nnz_x,
                nnz_xtwx_symbolic: Some(stats.nnz_xtwx_symbolic),
                nnz_s_lambda: stats.nnz_s_lambda_upper,
                nnz_h_est: Some(stats.nnz_h_upper),
                density_h_est: Some(stats.density_upper),
            };
            decision
        }
        Err(_) => dense_reject("sparse_stats_failed", nnz_x),
    }
}
/// Thin naming wrapper over `estimate_sparse_native_decision`; see that
/// function for the decision criteria and diagnostics.
fn should_use_sparse_native_pirls(
    workspace: &mut PirlsWorkspace,
    x_original: &DesignMatrix,
    s_lambda: &Array2<f64>,
    coefficient_lower_bounds: Option<&Array1<f64>>,
    linear_constraints_original: Option<&LinearInequalityConstraints>,
) -> SparsePirlsDecision {
    estimate_sparse_native_decision(
        workspace,
        x_original,
        s_lambda,
        coefficient_lower_bounds,
        linear_constraints_original,
    )
}
/// Crate-visible entry point for assembling the sparse penalized Hessian
/// `XᵀWX + S_lambda + ridge·I`; delegates to the workspace assembler.
pub(crate) fn sparse_reml_penalized_hessian(
    workspace: &mut PirlsWorkspace,
    x: &SparseColMat<usize, f64>,
    weights: &Array1<f64>,
    s_lambda: &Array2<f64>,
    ridge: f64,
) -> Result<SparseColMat<usize, f64>, EstimationError> {
    workspace.assemble_sparse_penalized_hessian(x, weights, s_lambda, ridge)
}
/// Assemble and factorize a sparse SPD matrix, escalating a diagonal ridge
/// until the Cholesky factorization succeeds: the first retry uses
/// `FIXED_STABILIZATION_RIDGE`, each later retry multiplies by 10, up to 16
/// attempts. Returns the matrix, its factor, and the ridge that worked.
fn ensure_sparse_positive_definitewithridge<F>(
    mut assemble: F,
) -> Result<
    (
        SparseColMat<usize, f64>,
        crate::linalg::sparse_exact::SparseExactFactor,
        f64,
    ),
    EstimationError,
>
where
    F: FnMut(f64) -> Result<SparseColMat<usize, f64>, EstimationError>,
{
    let mut ridge = 0.0_f64;
    for _attempt in 0..16 {
        let candidate = assemble(ridge)?;
        if let Ok(factor) = factorize_sparse_spd(&candidate) {
            return Ok((candidate, factor, ridge));
        }
        ridge = if ridge == 0.0 {
            FIXED_STABILIZATION_RIDGE
        } else {
            ridge * 10.0
        };
    }
    Err(EstimationError::HessianNotPositiveDefinite {
        min_eigenvalue: f64::NAN,
    })
}
/// Returns a copy of `matrix` with `diagonal` added to every diagonal
/// entry, inserting explicit diagonal entries where the sparsity pattern
/// lacks them.
///
/// The previous implementation had a dedicated branch for the case where
/// every column already stored a diagonal entry, but that branch performed
/// an extra full pattern scan plus a values copy and then rebuilt the
/// matrix through the very same triplet path used below — the general path
/// produces an identical result, so the special case was removed.
fn add_diagonal_to_upper_sparse(
    matrix: &SparseColMat<usize, f64>,
    diagonal: f64,
) -> Result<SparseColMat<usize, f64>, EstimationError> {
    // Zero shift is a no-op: keep the original pattern and values.
    if diagonal == 0.0 {
        return Ok(matrix.clone());
    }
    let (symbolic, values) = matrix.parts();
    let col_ptr = symbolic.col_ptr();
    let row_idx = symbolic.row_idx();
    // Rebuild from triplets: existing diagonal entries are shifted in
    // place; columns with no stored diagonal get one inserted.
    let mut triplets = Vec::with_capacity(values.len() + matrix.ncols());
    for col in 0..matrix.ncols() {
        let mut saw_diag = false;
        for idx in col_ptr[col]..col_ptr[col + 1] {
            let row = row_idx[idx];
            let mut value = values[idx];
            if row == col {
                value += diagonal;
                saw_diag = true;
            }
            triplets.push(Triplet::new(row, col, value));
        }
        if !saw_diag {
            triplets.push(Triplet::new(col, col, diagonal));
        }
    }
    SparseColMat::try_new_from_triplets(matrix.nrows(), matrix.ncols(), &triplets).map_err(|_| {
        EstimationError::InvalidInput("failed to add diagonal to sparse matrix".to_string())
    })
}
/// Computes the free-subsystem Newton direction `out = -H⁻¹ g`, with a
/// Tikhonov-ridge retry ladder and a scaled steepest-descent fallback so a
/// finite direction is always produced.
fn solve_subsystem_direction(
    h_sub: &Array2<f64>,
    g_sub: &Array1<f64>,
    out: &mut Array1<f64>,
) -> Result<(), EstimationError> {
    // Factorize `m` and write `-m⁻¹ g` into `out`; returns true iff the
    // factorization succeeded and the solution is entirely finite.
    fn attempt(
        m: &Array2<f64>,
        g: &Array1<f64>,
        out: &mut Array1<f64>,
        label: &'static str,
    ) -> bool {
        let Ok(factor) = StableSolver::new(label).factorize(m) else {
            return false;
        };
        out.assign(g);
        let mut rhs = array1_to_col_matmut(out);
        factor.solve_in_place(rhs.as_mut());
        out.mapv_inplace(|v| -v);
        array1_is_finite(out)
    }
    let n = g_sub.len();
    if out.len() != n {
        *out = Array1::zeros(n);
    }
    // First try the unmodified subsystem.
    if attempt(h_sub, g_sub, out, "pirls bounded subsystem") {
        return Ok(());
    }
    // Ridge ladder: start at 1e-8 of the dominant diagonal magnitude and
    // grow by 10x, for up to 12 attempts.
    let diag_scale = (0..n)
        .map(|i| h_sub[[i, i]].abs())
        .fold(0.0_f64, f64::max)
        .max(1.0);
    let mut tau = 1e-8 * diag_scale;
    let mut h_reg = h_sub.to_owned();
    for _ in 0..12 {
        for i in 0..n {
            h_reg[[i, i]] = h_sub[[i, i]] + tau;
        }
        if attempt(&h_reg, g_sub, out, "pirls bounded subsystem ridge") {
            return Ok(());
        }
        tau *= 10.0;
    }
    // Last resort: a conservatively scaled steepest-descent step, or the
    // zero direction when the gradient itself vanishes.
    let gnorm = g_sub.dot(g_sub).sqrt();
    if gnorm > 0.0 {
        let scale = 1.0 / gnorm.max(diag_scale);
        for i in 0..n {
            out[i] = -g_sub[i] * scale;
        }
        return Ok(());
    }
    out.fill(0.0);
    Ok(())
}
/// Converts finite elementwise lower bounds into an equivalent set of
/// linear inequality constraints (one identity row `e_iᵀ β ≥ lb_i` per
/// finite bound). Returns `None` when no bound is finite.
fn linear_constraints_from_lower_bounds(
    lower_bounds: &Array1<f64>,
) -> Option<LinearInequalityConstraints> {
    let p = lower_bounds.len();
    // Indices of coefficients that actually carry a finite bound.
    let finite_rows: Vec<usize> = lower_bounds
        .iter()
        .enumerate()
        .filter_map(|(i, lb)| lb.is_finite().then_some(i))
        .collect();
    if finite_rows.is_empty() {
        return None;
    }
    let mut a = Array2::<f64>::zeros((finite_rows.len(), p));
    let mut b = Array1::<f64>::zeros(finite_rows.len());
    for (row, &coeff_idx) in finite_rows.iter().enumerate() {
        a[[row, coeff_idx]] = 1.0;
        b[row] = lower_bounds[coeff_idx];
    }
    Some(LinearInequalityConstraints { a, b })
}
/// Thin wrapper around `active_set::compute_constraint_kkt_diagnostics`,
/// kept here so PIRLS code can call it without importing the module path.
fn compute_constraint_kkt_diagnostics(
    beta: &Array1<f64>,
    gradient: &Array1<f64>,
    constraints: &LinearInequalityConstraints,
) -> ConstraintKktDiagnostics {
    active_set::compute_constraint_kkt_diagnostics(beta, gradient, constraints)
}
/// Picks an active constraint to release using its Lagrange-multiplier
/// estimate `λ_i = g_i + (H d)_i`.
///
/// With `use_blands` set, the first (lowest-index) constraint whose
/// multiplier is clearly negative is released — Bland's anti-cycling rule.
/// Otherwise the most negative multiplier wins (Dantzig-style), with ties
/// keeping the earliest index. Returns `None` when every multiplier is
/// numerically non-negative, i.e. the active set is optimal.
fn select_active_set_release(
    gradient: &Array1<f64>,
    hd: &Array1<f64>,
    active_idx: &[usize],
    use_blands: bool,
) -> Option<usize> {
    if use_blands {
        // Bland's rule: first index with a multiplier below a small
        // negative, magnitude-relative tolerance.
        active_idx.iter().copied().find(|&i| {
            let multiplier = gradient[i] + hd[i];
            let scale = gradient[i].abs().max(hd[i].abs()).max(1.0);
            multiplier < -(64.0 * f64::EPSILON * scale)
        })
    } else {
        // Most-negative multiplier; strict `<` keeps the first of any ties.
        let mut best: Option<(usize, f64)> = None;
        for &i in active_idx {
            let multiplier = gradient[i] + hd[i];
            if multiplier < best.map_or(0.0, |(_, worst)| worst) {
                best = Some((i, multiplier));
            }
        }
        best.map(|(i, _)| i)
    }
}
/// Computes a Newton direction minimizing `½dᵀHd + gᵀd` subject to
/// elementwise lower bounds on `beta + d`, via a primal active-set loop.
///
/// The step is written into `direction_out`. `active_hint`, when provided,
/// warm-starts the active set and — on a normal (KKT) exit — is rewritten
/// with the final active set for the next call; on the fallback exit it is
/// cleared. If the active-set loop exhausts its iteration budget, a
/// conservatively scaled steepest-descent step clipped to the feasible
/// region is returned instead.
pub(crate) fn solve_newton_directionwith_lower_bounds(
    hessian: &Array2<f64>,
    gradient: &Array1<f64>,
    beta: &Array1<f64>,
    lower_bounds: &Array1<f64>,
    direction_out: &mut Array1<f64>,
    active_hint: Option<&mut Vec<usize>>,
) -> Result<(), EstimationError> {
    let p = gradient.len();
    if lower_bounds.len() != p || beta.len() != p {
        return Err(EstimationError::InvalidInput(format!(
            "lower-bound size mismatch: beta={}, gradient={}, bounds={}",
            beta.len(),
            gradient.len(),
            lower_bounds.len()
        )));
    }
    if direction_out.len() != p {
        *direction_out = Array1::zeros(p);
    }
    direction_out.fill(0.0);
    let has_active_hint = active_hint
        .as_ref()
        .map(|hint| !hint.is_empty())
        .unwrap_or(false);
    // Fast path: with no warm-start hint, try the unconstrained Newton step
    // and accept it outright when it violates no finite lower bound.
    if !has_active_hint && solve_newton_direction_dense(hessian, gradient, direction_out).is_ok() {
        let mut feasible = true;
        for i in 0..p {
            let lb = lower_bounds[i];
            if lb.is_finite() && beta[i] + direction_out[i] < lb {
                feasible = false;
                break;
            }
        }
        if feasible {
            return Ok(());
        }
    }
    // Seed the working active set from the hint...
    let mut active = vec![false; p];
    if let Some(hint) = active_hint.as_ref() {
        for &idx in hint.iter() {
            if idx < p {
                active[idx] = true;
            }
        }
    }
    // ...plus coordinates already (numerically) at their bound whose
    // positive gradient would push them further into the bound.
    for i in 0..p {
        let lb = lower_bounds[i];
        if lb.is_finite() && gradient[i] > 0.0 {
            let scale = beta[i].abs().max(lb.abs()).max(1.0);
            let tol = 1e-6 * scale + 1e-10;
            if beta[i] <= lb + tol {
                active[i] = true;
            }
        }
    }
    // After a grace period of normal pivoting, switch to Bland's rule in
    // `select_active_set_release` to guard against cycling.
    const BLANDS_RULE_GRACE: usize = 2;
    let blands_threshold = BLANDS_RULE_GRACE * (p + 1);
    let max_iters = 8 * (p + 1);
    let mut d_free = Array1::<f64>::zeros(p);
    for it in 0..max_iters {
        let use_blands = it >= blands_threshold;
        let free_idx: Vec<usize> = (0..p).filter(|&i| !active[i]).collect();
        let active_idx: Vec<usize> = (0..p).filter(|&i| active[i]).collect();
        // Active coordinates step exactly onto their bound.
        direction_out.fill(0.0);
        for &i in &active_idx {
            let lb = lower_bounds[i];
            if lb.is_finite() {
                direction_out[i] = lb - beta[i];
            }
        }
        if free_idx.is_empty() {
            // Everything is pinned: either release a constraint with a
            // negative multiplier estimate, or the current step is optimal.
            let hd = hessian.dot(direction_out);
            if let Some(idx) = select_active_set_release(gradient, &hd, &active_idx, use_blands) {
                active[idx] = false;
                continue;
            }
            if let Some(hint) = active_hint {
                hint.clear();
                hint.extend((0..p).filter(|&i| active[i]));
            }
            return Ok(());
        }
        // Reduced Newton system over the free coordinates; the gradient is
        // corrected by the Hessian coupling to the fixed active step.
        let n_free = free_idx.len();
        let mut h_ff = Array2::<f64>::zeros((n_free, n_free));
        let mut g_f = Array1::<f64>::zeros(n_free);
        for (ii, &i) in free_idx.iter().enumerate() {
            g_f[ii] = gradient[i];
            for &j in &active_idx {
                g_f[ii] += hessian[[i, j]] * direction_out[j];
            }
            for (jj, &j) in free_idx.iter().enumerate() {
                h_ff[[ii, jj]] = hessian[[i, j]];
            }
        }
        solve_subsystem_direction(&h_ff, &g_f, &mut d_free)?;
        for (ii, &i) in free_idx.iter().enumerate() {
            direction_out[i] = d_free[ii];
        }
        // Ratio test: largest step fraction (≤ 1) before a free coordinate
        // hits its lower bound.
        let mut hit_idx: Option<usize> = None;
        let mut best_alpha = 1.0_f64;
        for &i in &free_idx {
            let lb = lower_bounds[i];
            if !lb.is_finite() {
                continue;
            }
            let slack = beta[i] - lb;
            let di = direction_out[i];
            if let Some(alpha_i) = boundary_hit_step_fraction(slack, di, best_alpha) {
                best_alpha = alpha_i;
                hit_idx = Some(i);
            }
        }
        if let Some(i_hit) = hit_idx {
            // Truncate the whole step at the first blocking bound and add
            // that coordinate to the active set.
            for i in 0..p {
                direction_out[i] *= best_alpha;
            }
            active[i_hit] = true;
            continue;
        }
        // Full step possible: look for a releasable constraint; if none,
        // the KKT conditions hold and the direction is final.
        let hd = hessian.dot(direction_out);
        if let Some(idx) = select_active_set_release(gradient, &hd, &active_idx, use_blands) {
            active[idx] = false;
            continue;
        }
        if let Some(hint) = active_hint {
            hint.clear();
            hint.extend((0..p).filter(|&i| active[i]));
        }
        return Ok(());
    }
    // Iteration budget exhausted: fall back to steepest descent scaled by
    // the dominant Hessian diagonal, clipped to the feasible region.
    let gnorm = gradient.dot(gradient).sqrt();
    if gnorm > 0.0 {
        let diag_scale = (0..p)
            .map(|i| hessian[[i, i]].abs())
            .fold(0.0_f64, f64::max)
            .max(1.0);
        let step_scale = 1.0 / diag_scale;
        for i in 0..p {
            let di = -gradient[i] * step_scale;
            let lb = lower_bounds[i];
            if lb.is_finite() && beta[i] + di < lb {
                direction_out[i] = lb - beta[i];
            } else {
                direction_out[i] = di;
            }
        }
    } else {
        direction_out.fill(0.0);
    }
    if let Some(hint) = active_hint {
        hint.clear();
    }
    Ok(())
}
/// Constrained analogue of `solve_newton_directionwith_lower_bounds` for
/// general linear inequality constraints (rows of `constraints.a` paired
/// with `constraints.b`, as built e.g. by
/// `linear_constraints_from_lower_bounds`); delegates to the shared
/// active-set implementation.
fn solve_newton_directionwith_linear_constraints(
    hessian: &Array2<f64>,
    gradient: &Array1<f64>,
    beta: &Array1<f64>,
    constraints: &LinearInequalityConstraints,
    direction_out: &mut Array1<f64>,
    active_hint: Option<&mut Vec<usize>>,
) -> Result<(), EstimationError> {
    active_set::solve_newton_direction_with_linear_constraints(
        hessian,
        gradient,
        beta,
        constraints,
        direction_out,
        active_hint,
    )
}
/// Builds an initial coefficient vector of length `p` with only the
/// intercept set, chosen so the implied mean matches the weighted sample
/// mean of `y`.
///
/// Binomial-style links use the link-transformed, add-half-smoothed
/// prevalence; identity uses the weighted mean; log uses the log of the
/// (floored) weighted mean. The weighted-mean accumulation was previously
/// duplicated in every match arm and is now computed once up front; the
/// identical `Sas` and `BetaLogistic` arms were merged. Behavior is
/// unchanged, including the all-zero result when weights sum to ≤ 0.
fn default_beta_guess_external(
    p: usize,
    link_function: LinkFunction,
    y: ArrayView1<f64>,
    priorweights: ArrayView1<f64>,
    mixture_link_state: Option<&MixtureLinkState>,
    sas_link_state: Option<&SasLinkState>,
) -> Array1<f64> {
    let mut beta = Array1::<f64>::zeros(p);
    let intercept_col = 0usize;
    // Weighted response mean, shared by every link family.
    let mut weighted_sum = 0.0;
    let mut totalweight = 0.0;
    for (&yi, &wi) in y.iter().zip(priorweights.iter()) {
        weighted_sum += wi * yi;
        totalweight += wi;
    }
    if totalweight <= 0.0 {
        // Degenerate weights: keep the all-zero guess.
        return beta;
    }
    match link_function {
        LinkFunction::Logit
        | LinkFunction::Probit
        | LinkFunction::CLogLog
        | LinkFunction::Sas
        | LinkFunction::BetaLogistic => {
            // Add-half smoothing keeps the prevalence away from {0, 1}.
            let prevalence =
                ((weighted_sum + 0.5) / (totalweight + 1.0)).clamp(1e-6, 1.0 - 1e-6);
            beta[intercept_col] = match link_function {
                LinkFunction::Logit => (prevalence / (1.0 - prevalence)).ln(),
                LinkFunction::Probit => standard_normal_quantile(prevalence)
                    .unwrap_or_else(|_| (prevalence / (1.0 - prevalence)).ln()),
                LinkFunction::CLogLog => (-(1.0 - prevalence).ln()).ln(),
                // These links need a numeric solve; fall back to the probit
                // (then logit) transform if the solve fails.
                LinkFunction::Sas | LinkFunction::BetaLogistic => {
                    solve_intercept_for_prevalence(
                        link_function,
                        prevalence,
                        mixture_link_state,
                        sas_link_state,
                    )
                    .unwrap_or_else(|| {
                        standard_normal_quantile(prevalence)
                            .unwrap_or_else(|_| (prevalence / (1.0 - prevalence)).ln())
                    })
                }
                // Non-binomial links are handled by the outer match arms.
                LinkFunction::Log | LinkFunction::Identity => unreachable!(),
            };
            // A mixture link changes the mean function, so re-solve the
            // intercept against the mixture inverse link when present.
            if mixture_link_state.is_some() {
                beta[intercept_col] = solve_intercept_for_prevalence(
                    link_function,
                    prevalence,
                    mixture_link_state,
                    sas_link_state,
                )
                .unwrap_or(beta[intercept_col]);
            }
        }
        LinkFunction::Identity => {
            beta[intercept_col] = weighted_sum / totalweight;
        }
        LinkFunction::Log => {
            // Floor the mean away from zero before taking the log.
            beta[intercept_col] = (weighted_sum / totalweight).max(1e-10).ln();
        }
    }
    beta
}
/// Finds the intercept `η` whose inverse-link mean equals `prevalence`, by
/// bracketed bisection on the residual `μ(η) - prevalence`.
///
/// Returns `None` when any link evaluation is non-finite. When the bracket
/// — even after geometric expansion — still excludes the root, the nearest
/// bracket endpoint is returned as a clamped best effort.
///
/// The `InverseLink` value does not depend on `η`, so it is now built once
/// up front instead of being reconstructed (cloning the mixture state) on
/// every one of the up-to-98 residual evaluations; results are unchanged.
fn solve_intercept_for_prevalence(
    link_function: LinkFunction,
    prevalence: f64,
    mixture_link_state: Option<&MixtureLinkState>,
    sas_link_state: Option<&SasLinkState>,
) -> Option<f64> {
    // Resolve the inverse link once: mixture state wins, then SAS state
    // (mapped by link), then the plain standard link.
    let inverse_link = if let Some(state) = mixture_link_state {
        InverseLink::Mixture(state.clone())
    } else if let Some(state) = sas_link_state {
        match link_function {
            LinkFunction::BetaLogistic => InverseLink::BetaLogistic(*state),
            _ => InverseLink::Sas(*state),
        }
    } else {
        InverseLink::Standard(link_function)
    };
    // Residual of the mean equation; NaN flags an unusable evaluation.
    let f = |eta: f64| -> f64 {
        standard_inverse_link_jet(&inverse_link, eta)
            .map(|jet| jet.mu - prevalence)
            .unwrap_or(f64::NAN)
    };
    let mut lo = -40.0;
    let mut hi = 40.0;
    let mut f_lo = f(lo);
    let mut f_hi = f(hi);
    if !(f_lo.is_finite() && f_hi.is_finite()) {
        return None;
    }
    // Geometrically expand the bracket until it straddles the root
    // (monotone link assumed: f(lo) ≤ 0 ≤ f(hi)).
    for _ in 0..8 {
        if f_lo <= 0.0 && f_hi >= 0.0 {
            break;
        }
        lo *= 2.0;
        hi *= 2.0;
        f_lo = f(lo);
        f_hi = f(hi);
        if !(f_lo.is_finite() && f_hi.is_finite()) {
            return None;
        }
    }
    // Still no bracket: clamp to the endpoint nearest the root.
    if f_lo > 0.0 {
        return Some(lo);
    }
    if f_hi < 0.0 {
        return Some(hi);
    }
    // 80 bisection steps shrink the bracket far below f64 resolution.
    for _ in 0..80 {
        let mid = 0.5 * (lo + hi);
        let f_mid = f(mid);
        if !f_mid.is_finite() {
            return None;
        }
        if f_mid > 0.0 {
            hi = mid;
        } else {
            lo = mid;
        }
    }
    Some(0.5 * (lo + hi))
}
/// Madsen–Nielsen–Tingleff damping-update factor for an accepted
/// Levenberg–Marquardt step with gain ratio `rho`:
/// `max(1/3, min(2, 1 - (2ρ - 1)³))`. Good steps (ρ near 1) shrink the
/// damping toward the 1/3 floor; marginal steps grow it up to 2x.
#[inline]
fn madsen_lm_accept_factor(rho: f64) -> f64 {
    let shifted = 2.0 * rho - 1.0;
    (1.0 - shifted.powi(3)).clamp(1.0 / 3.0, 2.0)
}
pub fn runworking_model_pirls<M, F>(
model: &mut M,
mut beta: Coefficients,
options: &WorkingModelPirlsOptions,
mut iteration_callback: F,
) -> Result<WorkingModelPirlsResult, EstimationError>
where
M: WorkingModel + ?Sized,
F: FnMut(&WorkingModelIterationInfo),
{
const LM_MAX_LAMBDA: f64 = 1e12;
fn is_lm_retriable_candidate_error(err: &EstimationError) -> bool {
match err {
EstimationError::LinearSystemSolveFailed(_)
| EstimationError::HessianNotPositiveDefinite { .. } => true,
EstimationError::InvalidInput(message) => {
let message = message.to_ascii_lowercase();
message.contains("nan")
|| message.contains("non-finite")
|| message.contains("infinite")
|| message.contains("overflow")
|| message.contains("exceeds f64 range")
}
EstimationError::ParameterConstraintViolation(_) => true,
_ => false,
}
}
fn lm_can_retry(loop_lambda: f64) -> bool {
loop_lambda.is_finite() && loop_lambda < LM_MAX_LAMBDA
}
fn lm_retry_exhausted(loop_lambda: f64, attempts: usize, max_attempts: usize) -> bool {
attempts >= max_attempts || !loop_lambda.is_finite() || loop_lambda > LM_MAX_LAMBDA
}
fn lm_nonconvergence_error(
options: &WorkingModelPirlsOptions,
last_change: f64,
) -> EstimationError {
EstimationError::PirlsDidNotConverge {
max_iterations: options.max_iterations,
last_change,
}
}
if let Some(lb) = options.coefficient_lower_bounds.as_ref() {
project_coefficients_to_lower_bounds(&mut beta.0, lb);
}
let mut lastgradient_norm = f64::INFINITY;
let mut last_deviance_change = f64::INFINITY;
let mut last_step_size = 0.0;
let mut last_step_halving = 0usize;
let mut last_iter_accept_rho: Option<f64> = None;
let mut max_abs_eta = 0.0;
let mut status = PirlsStatus::MaxIterationsReached;
let mut iterations = 0usize;
let mut plateau_streak = 0usize;
let mut min_penalized_deviance = f64::INFINITY;
let mut final_state: Option<WorkingState> = None;
let mut initial_gradient_norm: Option<f64> = None;
let inner_solve_start = std::time::Instant::now();
let mut newton_direction = Array1::<f64>::zeros(beta.len());
let mut linear_active_hint: Option<Vec<usize>> =
options.linear_constraints.as_ref().map(|_| Vec::new());
let mut bound_active_hint: Option<Vec<usize>> = options
.coefficient_lower_bounds
.as_ref()
.map(|_| Vec::new());
let mut consecutive_fisher_fallbacks = 0usize;
let mut force_fisher_for_rest = false;
let mut regularized_buf: Option<crate::linalg::matrix::SymmetricMatrix> = None;
let penalizedobjective = |state: &WorkingState| {
let mut value = state.deviance + state.penalty_term;
if options.firth_bias_reduction {
if let Some(jeffreys_logdet) = state.jeffreys_logdet() {
value -= 2.0 * jeffreys_logdet;
}
}
value
};
let mut lambda = options
.initial_lm_lambda
.map(|v| v.clamp(1e-9, 1.0))
.unwrap_or(1e-6);
let lm_max_attempts = options.max_step_halving.max(1);
'pirls_loop: for iter in 1..=options.max_iterations {
iterations = iter;
let iter_start = std::time::Instant::now();
log::debug!(
"[PIRLS] start iter {:>3} | lm_lambda {:.2e} | last_halving {} | last_dev_change {:.3e}",
iter,
lambda,
last_step_halving,
last_deviance_change
);
let preferred_curvature =
if model.supports_observed_information_curvature() && !force_fisher_for_rest {
HessianCurvatureKind::Observed
} else {
HessianCurvatureKind::Fisher
};
let mut used_fisher_fallback_this_iter = false;
let curvature_start = std::time::Instant::now();
let cache_curvature_kind = final_state.as_ref().map(|s| s.hessian_curvature);
let cached_state_matches = iter > 1 && cache_curvature_kind == Some(preferred_curvature);
let mut state = if cached_state_matches {
final_state
.take()
.expect("cached_state_matches implies final_state.is_some()")
} else {
match model.update_with_curvature(&beta, preferred_curvature) {
Ok(state) => state,
Err(_) if preferred_curvature == HessianCurvatureKind::Observed => {
used_fisher_fallback_this_iter = true;
consecutive_fisher_fallbacks += 1;
if consecutive_fisher_fallbacks > 2 && !force_fisher_for_rest {
log::info!(
"[PIRLS] force_fisher_for_rest engaged at iter={} (consecutive_fisher_fallbacks={}) reason=iter_start",
iter,
consecutive_fisher_fallbacks,
);
force_fisher_for_rest = true;
}
model.update_with_curvature(&beta, HessianCurvatureKind::Fisher)?
}
Err(err) => return Err(err),
}
};
let mut curvature_total = curvature_start.elapsed();
log::info!(
"[STAGE] PIRLS update_with_curvature iter={} curvature={:?} elapsed={:.3}s source={}",
iter,
state.hessian_curvature,
curvature_total.as_secs_f64(),
if cached_state_matches {
"reused_prev_accept"
} else {
"rebuilt"
},
);
let mut lm_solve_total = std::time::Duration::ZERO;
let mut lm_candidate_total = std::time::Duration::ZERO;
let mut lm_predred_total = std::time::Duration::ZERO;
let mut lm_attempts_done = 0usize;
let current_penalized = penalizedobjective(&state);
if current_penalized.is_finite() && current_penalized < min_penalized_deviance {
min_penalized_deviance = current_penalized;
}
#[cfg(test)]
test_support::record_penalized_deviance(current_penalized);
if initial_gradient_norm.is_none() {
let g0_sq: f64 = state
.gradient
.iter()
.map(|g| if g.is_finite() { g * g } else { 0.0 })
.sum();
let g0 = g0_sq.sqrt();
if g0.is_finite() && g0 > 0.0 {
initial_gradient_norm = Some(g0);
}
}
let current_grad_finite = state.gradient.iter().all(|g| g.is_finite());
if !current_grad_finite {
lastgradient_norm = f64::INFINITY;
max_abs_eta = state.eta.iter().copied().map(f64::abs).fold(0.0, f64::max);
final_state = Some(state);
if last_deviance_change.abs() < options.convergence_tolerance {
status = PirlsStatus::StalledAtValidMinimum;
}
break 'pirls_loop;
}
let mut loop_lambda = lambda;
let mut attempts = 0;
let lm_start_lambda = lambda;
#[allow(unused_assignments)]
let mut lm_accept_rho: Option<f64> = None;
let mut madsen_reject_factor = 2.0_f64;
let mut regularized = match (regularized_buf.take(), state.hessian.as_dense()) {
(Some(crate::linalg::matrix::SymmetricMatrix::Dense(mut buf)), Some(src))
if buf.nrows() == src.nrows() && buf.ncols() == src.ncols() =>
{
buf.assign(src);
crate::linalg::matrix::SymmetricMatrix::Dense(buf)
}
_ => state.hessian.clone(),
};
let mut applied_lambda = 0.0_f64;
let mut cached_sparse_regularized: Option<SparseColMat<usize, f64>> = None;
loop {
attempts += 1;
lm_attempts_done += 1;
let attempt_solve_start = std::time::Instant::now();
if let crate::linalg::matrix::SymmetricMatrix::Dense(ref mut dense) = regularized {
let delta_lambda = loop_lambda - applied_lambda;
let dim = dense.nrows();
for i in 0..dim {
dense[[i, i]] += delta_lambda;
}
applied_lambda = loop_lambda;
}
let has_constraints =
options.linear_constraints.is_some() || options.coefficient_lower_bounds.is_some();
let direction = match if let Some(h_sparse) = state.hessian.as_sparse() {
if has_constraints {
Err(EstimationError::InvalidInput(
"sparse-native PIRLS does not support constrained solves".to_string(),
))
} else {
let sparse_reg = add_diagonal_to_upper_sparse(h_sparse, loop_lambda)?;
let factor = factorize_sparse_spd(&sparse_reg)?;
newton_direction.assign(&solve_sparse_spd(&factor, &state.gradient)?);
newton_direction.mapv_inplace(|g| -g);
cached_sparse_regularized = Some(sparse_reg);
Ok(())
}
} else {
let dense_reg = regularized.as_dense().ok_or_else(|| {
EstimationError::InvalidInput(
"PIRLS Newton step requires a dense Hessian but got a non-dense variant"
.to_string(),
)
})?;
if let Some(lin) = options.linear_constraints.as_ref() {
solve_newton_directionwith_linear_constraints(
dense_reg,
&state.gradient,
beta.as_ref(),
lin,
&mut newton_direction,
linear_active_hint.as_mut(),
)
} else if let Some(lb) = options.coefficient_lower_bounds.as_ref() {
solve_newton_directionwith_lower_bounds(
dense_reg,
&state.gradient,
beta.as_ref(),
lb,
&mut newton_direction,
bound_active_hint.as_mut(),
)
} else {
solve_newton_direction_dense(dense_reg, &state.gradient, &mut newton_direction)
}
} {
Ok(()) => &newton_direction,
Err(e) => {
if has_constraints {
return Err(EstimationError::ParameterConstraintViolation(format!(
"constrained PIRLS step solve failed at iteration {iter} with damping λ={loop_lambda:.3e}: {e}"
)));
}
if lm_can_retry(loop_lambda) {
lm_solve_total += attempt_solve_start.elapsed();
loop_lambda *= madsen_reject_factor;
madsen_reject_factor *= 2.0;
continue;
} else {
newton_direction.assign(&state.gradient);
newton_direction.mapv_inplace(|g| -g);
&newton_direction
}
}
};
lm_solve_total += attempt_solve_start.elapsed();
if !array1_is_finite(direction) {
if lm_can_retry(loop_lambda) {
loop_lambda *= madsen_reject_factor;
madsen_reject_factor *= 2.0;
continue;
}
let detail = if has_constraints {
"constrained PIRLS produced non-finite step direction"
} else {
"PIRLS produced non-finite step direction"
};
return Err(EstimationError::InvalidInput(format!(
"{detail} at iteration {iter} with damping λ={loop_lambda:.3e}"
)));
}
let predred_start = std::time::Instant::now();
let q_term = if let Some(sparse_reg) = cached_sparse_regularized.as_ref() {
sparse_symmetric_upper_matvec_public(sparse_reg, direction)
} else {
regularized.dot(direction)
};
let quad = 0.5 * direction.dot(&q_term);
let lin = state.gradient.dot(direction);
let predicted_reduction = -(lin + quad);
lm_predred_total += predred_start.elapsed();
let mut candidatevec = &*beta + direction;
if options.linear_constraints.is_none()
&& let Some(lb) = options.coefficient_lower_bounds.as_ref()
{
project_coefficients_to_lower_bounds(&mut candidatevec, lb);
}
let candidate_beta = Coefficients::new(candidatevec);
let candidate_eval_start = std::time::Instant::now();
let candidate_eval_result =
model.update_candidate(&candidate_beta, state.hessian_curvature);
lm_candidate_total += candidate_eval_start.elapsed();
match candidate_eval_result {
Ok(candidate_state) => {
let screening_penalized = penalizedobjective(&candidate_state);
let screening_reduction = current_penalized - screening_penalized;
let noise_floor = current_penalized.abs().max(1.0) * 1e-14;
let screening_rho = if predicted_reduction > noise_floor {
screening_reduction / predicted_reduction
} else if screening_reduction >= -noise_floor {
1.0
} else {
-1.0
};
let candidate_grad_finite =
candidate_state.gradient.iter().all(|g| g.is_finite());
if screening_rho > 0.0
&& screening_penalized.is_finite()
&& candidate_grad_finite
{
let accepted_state = if options.firth_bias_reduction {
let firth_curv_start = std::time::Instant::now();
let firth_curv_result = model
.update_with_curvature(&candidate_beta, state.hessian_curvature);
curvature_total += firth_curv_start.elapsed();
match firth_curv_result {
Ok(state) => state,
Err(err) => {
if !is_lm_retriable_candidate_error(&err) {
return Err(err);
}
if lm_retry_exhausted(loop_lambda, attempts, lm_max_attempts) {
return Err(lm_nonconvergence_error(
options,
constrained_stationarity_norm(
&state.gradient,
beta.as_ref(),
options.coefficient_lower_bounds.as_ref(),
options.linear_constraints.as_ref(),
),
));
}
if lm_can_retry(loop_lambda) {
loop_lambda *= madsen_reject_factor;
madsen_reject_factor *= 2.0;
continue;
}
loop_lambda *= madsen_reject_factor;
madsen_reject_factor *= 2.0;
continue;
}
}
} else {
candidate_state
};
let candidate_penalized = penalizedobjective(&accepted_state);
if candidate_penalized.is_finite()
&& candidate_penalized < min_penalized_deviance
{
min_penalized_deviance = candidate_penalized;
}
let actual_reduction = current_penalized - candidate_penalized;
let rho = if predicted_reduction > noise_floor {
actual_reduction / predicted_reduction
} else if actual_reduction >= -noise_floor {
1.0
} else {
-1.0
};
if !(rho > 0.0 && candidate_penalized.is_finite()) {
loop_lambda *= madsen_reject_factor;
madsen_reject_factor *= 2.0;
continue;
}
if preferred_curvature == HessianCurvatureKind::Observed {
if state.hessian_curvature == HessianCurvatureKind::Observed
&& !used_fisher_fallback_this_iter
{
consecutive_fisher_fallbacks = 0;
}
}
lm_accept_rho = Some(rho);
last_iter_accept_rho = Some(rho);
lambda = (loop_lambda * madsen_lm_accept_factor(rho)).max(1e-9);
beta = candidate_beta;
let candidategrad_norm = constrained_stationarity_norm(
&accepted_state.gradient,
beta.as_ref(),
options.coefficient_lower_bounds.as_ref(),
options.linear_constraints.as_ref(),
);
let deviance_change = actual_reduction;
iteration_callback(&WorkingModelIterationInfo {
iteration: iter,
deviance: accepted_state.deviance,
gradient_norm: candidategrad_norm,
step_size: 1.0,
step_halving: attempts, });
lastgradient_norm = candidategrad_norm;
last_deviance_change = deviance_change;
last_step_size = 1.0;
last_step_halving = attempts;
max_abs_eta = accepted_state
.eta
.iter()
.copied()
.map(f64::abs)
.fold(0.0, f64::max);
let convergence_grad_norm = constrained_stationarity_norm(
&accepted_state.gradient,
beta.as_ref(),
options.coefficient_lower_bounds.as_ref(),
options.linear_constraints.as_ref(),
);
final_state = Some(accepted_state);
let final_state_ref = final_state
.as_ref()
.expect("final_state set immediately above");
let f_scale = 1.0 + current_penalized.abs();
let lambda_floor = final_state_ref.ridge_used.max(1.0e-12);
let nd_correction = 1.0 + loop_lambda / lambda_floor;
let newton_decrement_sq_upper = (-lin).max(0.0) * nd_correction;
let nd_threshold =
options.convergence_tolerance * options.convergence_tolerance * f_scale;
let nd_pass = newton_decrement_sq_upper <= nd_threshold;
if final_state_ref
.certifies_kkt(convergence_grad_norm, options.convergence_tolerance)
|| nd_pass
{
status = PirlsStatus::Converged;
break 'pirls_loop;
}
match pirls_soft_acceptance(
final_state_ref,
convergence_grad_norm,
SoftAcceptProgress::Realized {
dev_change: deviance_change,
},
max_abs_eta,
options.convergence_tolerance,
) {
Some(reason) => {
plateau_streak += 1;
if plateau_streak >= 2 {
log::debug!(
"[PIRLS] iter {iter} early-exit on soft acceptance: \
{reason:?} (‖g‖={convergence_grad_norm:.3e}, \
Δdev={deviance_change:.3e})"
);
status = PirlsStatus::StalledAtValidMinimum;
break 'pirls_loop;
}
}
None => {
plateau_streak = 0;
}
}
break; } else {
if state.hessian_curvature == HessianCurvatureKind::Observed
&& !used_fisher_fallback_this_iter
{
used_fisher_fallback_this_iter = true;
consecutive_fisher_fallbacks += 1;
if consecutive_fisher_fallbacks > 2 && !force_fisher_for_rest {
log::info!(
"[PIRLS] force_fisher_for_rest engaged at iter={} (consecutive_fisher_fallbacks={}) reason=gain_rejection",
iter,
consecutive_fisher_fallbacks,
);
force_fisher_for_rest = true;
}
log::info!(
"[PIRLS] mid-iter Fisher fallback iter={} reason=gain_rejection",
iter,
);
let fisher_fallback_start = std::time::Instant::now();
state =
model.update_with_curvature(&beta, HessianCurvatureKind::Fisher)?;
curvature_total += fisher_fallback_start.elapsed();
regularized = state.hessian.clone();
applied_lambda = 0.0;
cached_sparse_regularized = None;
loop_lambda = lambda;
madsen_reject_factor = 2.0;
continue;
}
let stategrad_norm = constrained_stationarity_norm(
&state.gradient,
beta.as_ref(),
options.coefficient_lower_bounds.as_ref(),
options.linear_constraints.as_ref(),
);
let projected_grad = stategrad_norm;
let lm_rejection_soft = pirls_soft_acceptance(
&state,
projected_grad,
SoftAcceptProgress::Predicted {
predicted_reduction,
current_penalized,
},
state.eta.iter().copied().map(f64::abs).fold(0.0, f64::max),
options.convergence_tolerance,
);
let near_stationary_pass = state
.near_stationary_kkt(projected_grad, options.convergence_tolerance);
if let Some(reason) = lm_rejection_soft {
log::debug!(
"[PIRLS] LM-rejection soft acceptance: {reason:?} \
(‖g‖={projected_grad:.3e}, \
predicted_reduction={predicted_reduction:.3e})"
);
lastgradient_norm = stategrad_norm;
last_deviance_change = 0.0;
last_step_size = 0.0;
last_step_halving = attempts;
max_abs_eta =
state.eta.iter().copied().map(f64::abs).fold(0.0, f64::max);
final_state = Some(state.clone());
status = PirlsStatus::StalledAtValidMinimum;
break 'pirls_loop;
}
if lm_retry_exhausted(loop_lambda, attempts, lm_max_attempts) {
lastgradient_norm = stategrad_norm;
if near_stationary_pass {
status = PirlsStatus::StalledAtValidMinimum;
} else {
let ceiling =
!loop_lambda.is_finite() || loop_lambda > LM_MAX_LAMBDA;
let attempts_used = attempts >= lm_max_attempts;
let max_abs_eta_now = state
.eta
.iter()
.copied()
.map(f64::abs)
.fold(0.0_f64, f64::max);
let relative_grad = state.relative_gradient_norm(projected_grad);
log::debug!(
"[PIRLS] LM step search exhausted at iter={}: \
attempts={}/{} lambda={:.3e} (ceiling={}) \
projected_grad={:.3e} (relative={:.3e}) \
current_pen={:.6e} predicted_reduction={:.3e} \
max|eta|={:.1} attempts_exhausted={}",
iter,
attempts,
lm_max_attempts,
loop_lambda,
ceiling,
projected_grad,
relative_grad,
current_penalized,
predicted_reduction,
max_abs_eta_now,
attempts_used,
);
status = PirlsStatus::LmStepSearchExhausted;
}
final_state = Some(state.clone());
break 'pirls_loop;
}
loop_lambda *= madsen_reject_factor;
madsen_reject_factor *= 2.0;
}
}
Err(err) => {
if state.hessian_curvature == HessianCurvatureKind::Observed
&& !used_fisher_fallback_this_iter
{
used_fisher_fallback_this_iter = true;
consecutive_fisher_fallbacks += 1;
if consecutive_fisher_fallbacks > 2 && !force_fisher_for_rest {
log::info!(
"[PIRLS] force_fisher_for_rest engaged at iter={} (consecutive_fisher_fallbacks={}) reason=candidate_err",
iter,
consecutive_fisher_fallbacks,
);
force_fisher_for_rest = true;
}
log::info!(
"[PIRLS] mid-iter Fisher fallback iter={} reason=candidate_err",
iter,
);
let fisher_err_start = std::time::Instant::now();
state = model.update_with_curvature(&beta, HessianCurvatureKind::Fisher)?;
curvature_total += fisher_err_start.elapsed();
regularized = state.hessian.clone();
applied_lambda = 0.0;
cached_sparse_regularized = None;
loop_lambda = lambda;
madsen_reject_factor = 2.0;
continue;
}
if !is_lm_retriable_candidate_error(&err) {
return Err(err);
}
if lm_retry_exhausted(loop_lambda, attempts, lm_max_attempts) {
return Err(lm_nonconvergence_error(
options,
constrained_stationarity_norm(
&state.gradient,
beta.as_ref(),
options.coefficient_lower_bounds.as_ref(),
options.linear_constraints.as_ref(),
),
));
}
loop_lambda *= madsen_reject_factor;
madsen_reject_factor *= 2.0;
}
}
} regularized_buf = Some(regularized);
let iter_elapsed = iter_start.elapsed();
log::info!(
"[PIRLS iter-end] iter={:>3} elapsed={:.4}s lm_lambda={:.2e} g_norm={:.3e} last_dev_change={:.3e} last_halving={}",
iter,
iter_elapsed.as_secs_f64(),
lambda,
lastgradient_norm,
last_deviance_change,
last_step_halving,
);
let timed_total = curvature_total + lm_solve_total + lm_predred_total + lm_candidate_total;
let other_total = iter_elapsed.saturating_sub(timed_total);
log::info!(
"[PIRLS iter-breakdown] iter={:>3} attempts={} curvature={:.3}s solve={:.3}s predred={:.3}s candidate={:.3}s other={:.3}s",
iter,
lm_attempts_done,
curvature_total.as_secs_f64(),
lm_solve_total.as_secs_f64(),
lm_predred_total.as_secs_f64(),
lm_candidate_total.as_secs_f64(),
other_total.as_secs_f64(),
);
let lambda_ratio_log10 = if lm_start_lambda > 0.0 && lambda > 0.0 {
(lambda / lm_start_lambda).log10()
} else {
f64::NAN
};
log::info!(
"[PIRLS lm-trajectory] iter={:>3} start_lambda={:.3e} final_lambda={:.3e} \
log10_ratio={:.3} accept_rho={:.3} attempts={}",
iter,
lm_start_lambda,
lambda,
lambda_ratio_log10,
lm_accept_rho.unwrap_or(f64::NAN),
lm_attempts_done,
);
}
let total_iters = iterations.max(1) as f64;
let convergence_rate = match initial_gradient_norm {
Some(g0) if g0 > 0.0 && lastgradient_norm.is_finite() => {
let ratio = (lastgradient_norm / g0).max(1e-30);
ratio.powf(1.0 / total_iters)
}
_ => f64::NAN,
};
log::info!(
"[PIRLS solve-end] iters={} elapsed={:.4}s g_norm_initial={:.3e} g_norm_final={:.3e} convergence_rate={:.3e} status={:?}",
iterations,
inner_solve_start.elapsed().as_secs_f64(),
initial_gradient_norm.unwrap_or(f64::NAN),
lastgradient_norm,
convergence_rate,
status,
);
let mut state = final_state.ok_or(EstimationError::PirlsDidNotConverge {
max_iterations: options.max_iterations,
last_change: lastgradient_norm,
})?;
let final_projected_grad = constrained_stationarity_norm(
&state.gradient,
beta.as_ref(),
options.coefficient_lower_bounds.as_ref(),
options.linear_constraints.as_ref(),
);
if status.is_failed_max_iterations() {
let tol = options.convergence_tolerance;
if state.certifies_kkt(final_projected_grad, tol) {
log::debug!(
"[PIRLS] post-loop rescue: strict KKT after MaxIterations \
(‖g‖={final_projected_grad:.3e})"
);
status = PirlsStatus::StalledAtValidMinimum;
} else if let Some(reason) = pirls_soft_acceptance(
&state,
final_projected_grad,
SoftAcceptProgress::Realized {
dev_change: last_deviance_change,
},
max_abs_eta,
tol,
) {
log::debug!(
"[PIRLS] post-loop rescue on soft acceptance: {reason:?} \
(‖g‖={final_projected_grad:.3e}, \
Δdev={last_deviance_change:.3e})"
);
status = PirlsStatus::StalledAtValidMinimum;
}
}
let exported_laplace_curvature: ExportedLaplaceCurvature =
if model.supports_observed_information_curvature() {
match model.update_with_curvature(&beta, HessianCurvatureKind::Observed) {
Ok(observed_state) => {
let inertia = observed_state
.hessian
.as_dense()
.and_then(crate::linalg::utils::symmetric_extremes);
let (label, accept_observed) = match inertia {
Some((min_eig, max_eig)) => {
let pd_tolerance = max_eig.abs().max(1.0) * 1e-12;
if min_eig > -pd_tolerance {
(ExportedLaplaceCurvature::ObservedExact, true)
} else {
let g_norm = constrained_stationarity_norm(
&observed_state.gradient,
beta.as_ref(),
options.coefficient_lower_bounds.as_ref(),
options.linear_constraints.as_ref(),
);
log::warn!(
"[PIRLS] post-convergence observed Hessian indefinite: \
λ_min={min_eig:.3e}, pd_tol={pd_tolerance:.3e}, ‖g‖={g_norm:.3e}"
);
(
ExportedLaplaceCurvature::InvalidObservedCurvature {
min_eigenvalue: min_eig,
pd_tolerance,
gradient_norm: g_norm,
},
true,
)
}
}
None => {
(ExportedLaplaceCurvature::ObservedExact, true)
}
};
if accept_observed {
state = observed_state;
}
label
}
Err(err) => {
let g_norm = constrained_stationarity_norm(
&state.gradient,
beta.as_ref(),
options.coefficient_lower_bounds.as_ref(),
options.linear_constraints.as_ref(),
);
log::warn!(
"[PIRLS] post-convergence observed Hessian assembly failed: {err}; \
exporting InvalidObservedCurvature with ‖g‖={g_norm:.3e}"
);
ExportedLaplaceCurvature::InvalidObservedCurvature {
min_eigenvalue: f64::NAN,
pd_tolerance: f64::NAN,
gradient_norm: g_norm,
}
}
}
} else {
ExportedLaplaceCurvature::ExpectedInformationSurrogate
};
Ok(WorkingModelPirlsResult {
constraint_kkt: options
.linear_constraints
.as_ref()
.map(|lin| compute_constraint_kkt_diagnostics(beta.as_ref(), &state.gradient, lin))
.or_else(|| {
options.coefficient_lower_bounds.as_ref().and_then(|lb| {
linear_constraints_from_lower_bounds(lb).map(|lin| {
compute_constraint_kkt_diagnostics(beta.as_ref(), &state.gradient, &lin)
})
})
}),
beta,
state,
status,
iterations,
lastgradient_norm,
last_deviance_change,
last_step_size,
last_step_halving,
max_abs_eta,
min_penalized_deviance,
final_lm_lambda: lambda,
final_accept_rho: last_iter_accept_rho,
exported_laplace_curvature,
})
}
#[cfg(test)]
mod test_support {
    use std::cell::RefCell;

    // Thread-local trace buffer: `Some(vec)` while a capture is active on this
    // thread, `None` otherwise.
    thread_local! {
        static PIRLS_PENALIZED_DEVIANCE_TRACE: RefCell<Option<Vec<f64>>> =
            const { RefCell::new(None) };
    }

    /// Runs `run` while recording every penalized-deviance value reported via
    /// `record_penalized_deviance`, returning the closure's result together
    /// with the captured trace.
    pub(super) fn capture_pirls_penalized_deviance<F, R>(run: F) -> (R, Vec<f64>)
    where
        F: FnOnce() -> R,
    {
        // Install a fresh buffer, run the workload, then take the buffer back.
        PIRLS_PENALIZED_DEVIANCE_TRACE.with(|cell| {
            cell.borrow_mut().replace(Vec::new());
        });
        let outcome = run();
        let trace = PIRLS_PENALIZED_DEVIANCE_TRACE
            .with(|cell| cell.borrow_mut().take())
            .expect("trace buffer was installed above");
        (outcome, trace)
    }

    /// Appends `value` to the active trace; a no-op when no capture is in
    /// progress on the current thread.
    pub(super) fn record_penalized_deviance(value: f64) {
        PIRLS_PENALIZED_DEVIANCE_TRACE.with(|cell| {
            if let Some(buf) = cell.borrow_mut().as_mut() {
                buf.push(value);
            }
        });
    }
}
/// Terminal state of a PIRLS inner solve.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Serialize, Deserialize)]
pub enum PirlsStatus {
    /// Convergence criteria were met within tolerance.
    Converged,
    /// Iteration stopped making progress, but the point certifies as a valid
    /// (near-stationary / KKT-satisfying) minimum.
    StalledAtValidMinimum,
    /// Hit the iteration cap without certifying convergence.
    MaxIterationsReached,
    /// The Levenberg–Marquardt step search exhausted its retry budget.
    LmStepSearchExhausted,
    /// Numerical instability detected (e.g. logit quasi-separation).
    Unstable,
}
impl PirlsStatus {
    /// True when the solve terminated by exhausting its iteration budget or
    /// its LM step-search budget without certifying a minimum.
    #[inline]
    pub fn is_failed_max_iterations(self) -> bool {
        // Exhaustive match so adding a variant forces a decision here.
        match self {
            PirlsStatus::MaxIterationsReached | PirlsStatus::LmStepSearchExhausted => true,
            PirlsStatus::Converged
            | PirlsStatus::StalledAtValidMinimum
            | PirlsStatus::Unstable => false,
        }
    }
}
/// Complete output of a PIRLS inner fit, expressed in the transformed
/// (reparameterized) coordinate basis described by `coordinate_frame`.
#[derive(Clone)]
pub struct PirlsResult {
    pub likelihood: GlmLikelihoodSpec,
    /// Fitted coefficients in transformed coordinates.
    pub beta_transformed: Coefficients,
    /// Penalized Hessian at the solution (transformed basis).
    pub penalized_hessian_transformed: SymmetricMatrix,
    /// Penalized Hessian plus the stabilization ridge (`ridge_used`), when any.
    pub stabilizedhessian_transformed: SymmetricMatrix,
    /// Provenance record for the stabilization ridge.
    pub ridge_passport: RidgePassport,
    /// Scalar ridge magnitude actually applied for stabilization.
    pub ridge_used: f64,
    pub deviance: f64,
    /// Effective degrees of freedom.
    pub edf: f64,
    /// Penalty value βᵀSβ at the solution (plus ridge contribution when active).
    pub stable_penalty_term: f64,
    /// Firth bias-reduction diagnostics (`Inactive` when Firth is off).
    pub firth: FirthDiagnostics,
    // The `final*` arrays describe the converged fit; the `solve*` arrays are
    // the working quantities from the last linear solve. Several of these are
    // emptied by `compact_for_reml_cache` and rebuilt by
    // `rehydrate_after_reml_cache`.
    pub finalweights: Array1<f64>,
    pub final_offset: Array1<f64>,
    pub final_eta: Array1<f64>,
    pub finalmu: Array1<f64>,
    pub solveweights: Array1<f64>,
    pub solveworking_response: Array1<f64>,
    pub solvemu: Array1<f64>,
    /// First derivative of μ w.r.t. η, per observation.
    pub solve_dmu_deta: Array1<f64>,
    /// Second derivative of μ w.r.t. η, per observation.
    pub solve_d2mu_deta2: Array1<f64>,
    /// Third derivative of μ w.r.t. η, per observation.
    pub solve_d3mu_deta3: Array1<f64>,
    pub solve_c_array: Array1<f64>,
    pub solve_d_array: Array1<f64>,
    pub status: PirlsStatus,
    /// Number of PIRLS iterations performed.
    pub iteration: usize,
    pub max_abs_eta: f64,
    pub lastgradient_norm: f64,
    /// Natural scale of the gradient; used to normalize `lastgradient_norm`
    /// in `relative_gradient_norm`.
    pub gradient_natural_scale: f64,
    pub last_deviance_change: f64,
    pub last_step_halving: usize,
    /// Which curvature (Fisher/Observed) the final Hessian was assembled with.
    pub hessian_curvature: HessianCurvatureKind,
    pub exported_laplace_curvature: ExportedLaplaceCurvature,
    /// Final Levenberg–Marquardt damping parameter.
    pub final_lm_lambda: f64,
    pub final_accept_rho: Option<f64>,
    pub constraint_kkt: Option<ConstraintKktDiagnostics>,
    pub linear_constraints_transformed: Option<LinearInequalityConstraints>,
    pub reparam_result: ReparamResult,
    /// Design matrix in transformed coordinates (replaced by an empty
    /// placeholder when compacted).
    pub x_transformed: DesignMatrix,
    pub coordinate_frame: PirlsCoordinateFrame,
    /// True after `compact_for_reml_cache` dropped the bulky arrays.
    pub cache_compacted: bool,
    /// Smallest penalized deviance observed during the solve — presumably
    /// tracked across iterations; confirm against the inner loop.
    pub min_penalized_deviance: f64,
}
impl PirlsResult {
    /// Densifies the stabilized transformed Hessian, mapping conversion
    /// failures to `EstimationError::InvalidInput` tagged with `context`.
    pub fn dense_stabilizedhessian_transformed(
        &self,
        context: &str,
    ) -> Result<Array2<f64>, EstimationError> {
        self.stabilizedhessian_transformed
            .try_to_dense_exact(context)
            .map_err(EstimationError::InvalidInput)
    }
    /// Log-determinant of the Jeffreys-prior curvature, when the Firth
    /// diagnostics carry one.
    #[inline]
    pub fn jeffreys_logdet(&self) -> Option<f64> {
        self.firth.jeffreys_logdet()
    }
    /// Final gradient norm normalized by `1 + gradient_natural_scale`, so
    /// convergence tolerances behave consistently across problem scalings.
    #[inline]
    pub fn relative_gradient_norm(&self) -> f64 {
        self.lastgradient_norm / (1.0 + self.gradient_natural_scale)
    }
    /// Returns a copy with the bulky per-observation arrays and the design
    /// matrix replaced by empty placeholders, keeping only what the REML
    /// cache needs. `final_eta`, the solve-time weights/response, and the
    /// c/d arrays are retained so `rehydrate_after_reml_cache` can rebuild
    /// the rest. Sets `cache_compacted = true`.
    pub(crate) fn compact_for_reml_cache(&self) -> Self {
        Self {
            likelihood: self.likelihood,
            beta_transformed: self.beta_transformed.clone(),
            penalized_hessian_transformed: self.penalized_hessian_transformed.clone(),
            stabilizedhessian_transformed: self.stabilizedhessian_transformed.clone(),
            ridge_passport: self.ridge_passport,
            ridge_used: self.ridge_used,
            deviance: self.deviance,
            edf: self.edf,
            stable_penalty_term: self.stable_penalty_term,
            firth: self.firth.clone(),
            // Dropped: recomputable from final_eta + data at rehydration time.
            finalweights: Array1::zeros(0),
            final_offset: Array1::zeros(0),
            final_eta: self.final_eta.clone(),
            finalmu: Array1::zeros(0),
            solveweights: self.solveweights.clone(),
            solveworking_response: self.solveworking_response.clone(),
            solvemu: self.solvemu.clone(),
            solve_dmu_deta: Array1::zeros(0),
            solve_d2mu_deta2: Array1::zeros(0),
            solve_d3mu_deta3: Array1::zeros(0),
            solve_c_array: self.solve_c_array.clone(),
            solve_d_array: self.solve_d_array.clone(),
            status: self.status,
            iteration: self.iteration,
            max_abs_eta: self.max_abs_eta,
            lastgradient_norm: self.lastgradient_norm,
            gradient_natural_scale: self.gradient_natural_scale,
            last_deviance_change: self.last_deviance_change,
            last_step_halving: self.last_step_halving,
            hessian_curvature: self.hessian_curvature,
            exported_laplace_curvature: self.exported_laplace_curvature.clone(),
            final_lm_lambda: self.final_lm_lambda,
            final_accept_rho: self.final_accept_rho,
            constraint_kkt: self.constraint_kkt.clone(),
            linear_constraints_transformed: self.linear_constraints_transformed.clone(),
            reparam_result: self.reparam_result.clone(),
            // Empty placeholder; the real matrix is reattached on rehydration.
            x_transformed: DesignMatrix::Dense(crate::matrix::DenseDesignMatrix::from(
                Array2::zeros((0, 0)),
            )),
            coordinate_frame: self.coordinate_frame.clone(),
            cache_compacted: true,
            min_penalized_deviance: self.min_penalized_deviance,
        }
    }
    /// Reconstructs the per-observation arrays dropped by
    /// `compact_for_reml_cache` from `final_eta` and the supplied data, and
    /// reattaches the design matrix as a reparameterization operator over
    /// `x_original`. Plain clone when the result was never compacted.
    pub(crate) fn rehydrate_after_reml_cache(
        &self,
        x_original: &DesignMatrix,
        y: ArrayView1<'_, f64>,
        priorweights: ArrayView1<'_, f64>,
        offset: ArrayView1<'_, f64>,
        inverse_link: &InverseLink,
    ) -> Result<Self, EstimationError> {
        if !self.cache_compacted {
            return Ok(self.clone());
        }
        // Score-curvature arrays and link derivatives, recomputed from eta.
        let (score_c_array, score_d_array, solve_dmu_deta, solve_d2mu_deta2, solve_d3mu_deta3) =
            computeworkingweight_derivatives_from_eta(
                self.likelihood,
                inverse_link,
                &self.final_eta,
                priorweights,
            )?;
        // When the fit used observed-information curvature, rebuild the
        // observed-Hessian weight/c/d arrays; otherwise reuse the score forms.
        let (finalweights, solve_c_array, solve_d_array) =
            if self.hessian_curvature == HessianCurvatureKind::Observed {
                compute_observed_hessian_curvature_arrays(
                    self.likelihood,
                    inverse_link,
                    &self.final_eta,
                    y,
                    &self.solvemu,
                    &solve_dmu_deta,
                    &solve_d2mu_deta2,
                    &solve_d3mu_deta3,
                    &self.solveweights,
                    priorweights,
                )?
            } else {
                (
                    self.solveweights.clone(),
                    score_c_array.clone(),
                    score_d_array.clone(),
                )
            };
        let qs_arc = Arc::new(self.reparam_result.qs.clone());
        Ok(Self {
            likelihood: self.likelihood,
            beta_transformed: self.beta_transformed.clone(),
            penalized_hessian_transformed: self.penalized_hessian_transformed.clone(),
            stabilizedhessian_transformed: self.stabilizedhessian_transformed.clone(),
            ridge_passport: self.ridge_passport,
            ridge_used: self.ridge_used,
            deviance: self.deviance,
            edf: self.edf,
            stable_penalty_term: self.stable_penalty_term,
            firth: self.firth.clone(),
            finalweights,
            final_offset: offset.to_owned(),
            final_eta: self.final_eta.clone(),
            finalmu: self.solvemu.clone(),
            solveweights: self.solveweights.clone(),
            solveworking_response: self.solveworking_response.clone(),
            solvemu: self.solvemu.clone(),
            solve_dmu_deta,
            solve_d2mu_deta2,
            solve_d3mu_deta3,
            solve_c_array,
            solve_d_array,
            status: self.status,
            iteration: self.iteration,
            max_abs_eta: self.max_abs_eta,
            lastgradient_norm: self.lastgradient_norm,
            gradient_natural_scale: self.gradient_natural_scale,
            last_deviance_change: self.last_deviance_change,
            last_step_halving: self.last_step_halving,
            hessian_curvature: self.hessian_curvature,
            exported_laplace_curvature: self.exported_laplace_curvature.clone(),
            final_lm_lambda: self.final_lm_lambda,
            final_accept_rho: self.final_accept_rho,
            constraint_kkt: self.constraint_kkt.clone(),
            linear_constraints_transformed: self.linear_constraints_transformed.clone(),
            reparam_result: self.reparam_result.clone(),
            // Lazily-applied X·Qs operator instead of a materialized product.
            x_transformed: DesignMatrix::Dense(crate::matrix::DenseDesignMatrix::from(Arc::new(
                ReparamOperator::new(x_original.clone(), qs_arc),
            ))),
            coordinate_frame: self.coordinate_frame.clone(),
            cache_compacted: false,
            min_penalized_deviance: self.min_penalized_deviance,
        })
    }
}
/// Packs the working-model solve summary plus the final per-observation
/// arrays into a `PirlsResult`. Pure assembly: clones the inputs into the
/// result struct without further computation.
fn assemble_pirls_result(
    working_summary: &WorkingModelPirlsResult,
    likelihood: GlmLikelihoodSpec,
    offset: ArrayView1<'_, f64>,
    penalized_hessian_transformed: SymmetricMatrix,
    stabilizedhessian_transformed: SymmetricMatrix,
    edf: f64,
    penalty_term: f64,
    finalmu: &Array1<f64>,
    finalweights: &Array1<f64>,
    scoreweights: &Array1<f64>,
    finalz: &Array1<f64>,
    final_c: &Array1<f64>,
    final_d: &Array1<f64>,
    final_dmu_deta: &Array1<f64>,
    final_d2mu_deta2: &Array1<f64>,
    final_d3mu_deta3: &Array1<f64>,
    status: PirlsStatus,
    reparam_result: ReparamResult,
    x_transformed: DesignMatrix,
    coordinate_frame: PirlsCoordinateFrame,
    linear_constraints_transformed: Option<LinearInequalityConstraints>,
) -> PirlsResult {
    let final_eta_arr = working_summary.state.eta.as_ref().clone();
    PirlsResult {
        likelihood,
        beta_transformed: working_summary.beta.clone(),
        penalized_hessian_transformed,
        stabilizedhessian_transformed,
        // Ridge provenance: a scaled-identity ridge of the amount the working
        // state actually applied.
        ridge_passport: RidgePassport::scaled_identity(
            working_summary.state.ridge_used,
            RidgePolicy::explicit_stabilization_full(),
        ),
        ridge_used: working_summary.state.ridge_used,
        deviance: working_summary.state.deviance,
        edf,
        stable_penalty_term: penalty_term,
        firth: working_summary.state.firth.clone(),
        finalweights: finalweights.clone(),
        final_offset: offset.to_owned(),
        final_eta: final_eta_arr,
        finalmu: finalmu.clone(),
        solveweights: scoreweights.clone(),
        solveworking_response: finalz.clone(),
        solvemu: finalmu.clone(),
        solve_dmu_deta: final_dmu_deta.clone(),
        solve_d2mu_deta2: final_d2mu_deta2.clone(),
        solve_d3mu_deta3: final_d3mu_deta3.clone(),
        solve_c_array: final_c.clone(),
        solve_d_array: final_d.clone(),
        status,
        iteration: working_summary.iterations,
        max_abs_eta: working_summary.max_abs_eta,
        lastgradient_norm: working_summary.lastgradient_norm,
        gradient_natural_scale: working_summary.state.gradient_natural_scale,
        last_deviance_change: working_summary.last_deviance_change,
        last_step_halving: working_summary.last_step_halving,
        hessian_curvature: working_summary.state.hessian_curvature,
        exported_laplace_curvature: working_summary.exported_laplace_curvature.clone(),
        final_lm_lambda: working_summary.final_lm_lambda,
        final_accept_rho: working_summary.final_accept_rho,
        constraint_kkt: working_summary.constraint_kkt.clone(),
        linear_constraints_transformed,
        reparam_result,
        x_transformed,
        coordinate_frame,
        cache_compacted: false,
        min_penalized_deviance: working_summary.min_penalized_deviance,
    }
}
/// Heuristic screen for a numerically unstable (quasi-separated / saturated)
/// logistic fit. Returns `true` when the fit should be flagged `Unstable`.
///
/// Only applies to the plain logit link without Firth bias reduction; Firth
/// already regularizes against separation. Penalized fits use stricter
/// thresholds, since the penalty tolerates milder symptoms.
fn detect_logit_instability(
    link: LinkFunction,
    has_penalty: bool,
    firth_active: bool,
    summary: &WorkingModelPirlsResult,
    finalmu: &Array1<f64>,
    finalweights: &Array1<f64>,
    y: ArrayView1<'_, f64>,
) -> bool {
    if link != LinkFunction::Logit || firth_active {
        return false;
    }
    let n = y.len() as f64;
    if n == 0.0 {
        return false;
    }
    let max_abs_eta = summary.max_abs_eta;

    // Fraction of fitted probabilities pinned to {0, 1} within SAT_EPS.
    const SAT_EPS: f64 = 1e-3;
    let sat_fraction = finalmu
        .iter()
        .filter(|&&m| m <= SAT_EPS || m >= 1.0 - SAT_EPS)
        .count() as f64
        / n;

    // Fraction of working weights that collapsed to ~0 or went non-finite.
    const WEIGHT_EPS: f64 = 1e-8;
    let weight_collapse_fraction = finalweights
        .iter()
        .filter(|&&w| !w.is_finite() || w <= WEIGHT_EPS)
        .count() as f64
        / n;

    let beta = summary.beta.as_ref();
    let beta_norm = beta.dot(beta).sqrt();
    let dev_per_sample = summary.state.deviance / n;

    // Smallest eta among positives vs. largest among negatives: a strictly
    // positive gap means the linear predictor perfectly orders the classes.
    let mut min_eta_pos = f64::INFINITY;
    let mut max_eta_neg = f64::NEG_INFINITY;
    let mut has_pos = false;
    let mut has_neg = false;
    for (&eta_i, &yi) in summary.state.eta.iter().zip(y.iter()) {
        if yi > 0.5 {
            has_pos = true;
            min_eta_pos = min_eta_pos.min(eta_i);
        } else {
            has_neg = true;
            max_eta_neg = max_eta_neg.max(eta_i);
        }
    }
    let order_separated = has_pos && has_neg && (min_eta_pos - max_eta_neg) > 1e-3;

    let classic_signals =
        max_abs_eta > 30.0 || sat_fraction > 0.98 || dev_per_sample < 1e-3 || beta_norm > 1e4;
    if !has_penalty {
        return classic_signals || order_separated;
    }
    // Penalized path: only flag severe pathologies.
    order_separated
        || (sat_fraction > 0.995 && max_abs_eta > 30.0)
        || weight_collapse_fraction > 0.98
        || dev_per_sample < 1e-6
}
fn stack_lambdaweighted_penalty_root_canonical(
penalties: &[crate::construction::CanonicalPenalty],
lambdas: &[f64],
p: usize,
) -> Array2<f64> {
let totalrows: usize = penalties.iter().map(|cp| cp.rank()).sum();
if totalrows == 0 {
return Array2::zeros((0, p));
}
let mut e = Array2::<f64>::zeros((totalrows, p));
let mut row_start = 0usize;
for (k, cp) in penalties.iter().enumerate() {
let rows = cp.rank();
if rows == 0 {
continue;
}
let scale = lambdas.get(k).copied().unwrap_or(0.0).max(0.0).sqrt();
if scale != 0.0 {
let r = &cp.col_range;
for row in 0..rows {
for col in 0..cp.block_dim() {
e[[row_start + row, r.start + col]] = scale * cp.root[[row, col]];
}
}
}
row_start += rows;
}
e
}
/// Wraps a dense reparameterization result so the sparse-native PIRLS path
/// can keep working in the ORIGINAL coordinate basis (identity `qs`).
///
/// `s_transformed` is rebuilt as the lambda-weighted penalty sum in original
/// coordinates, `e_transformed` as the stacked scaled penalty roots, and
/// `u_truncated` is rotated back through the base transform when its row
/// count matches `p` (otherwise an identity placeholder is used).
fn build_sparse_native_reparam_result(
    base: ReparamResult,
    penalties: &[crate::construction::CanonicalPenalty],
    lambdas: &[f64],
    p: usize,
) -> ReparamResult {
    // S(lambda) = sum_k lambda_k * S_k, accumulated in the original basis.
    let mut s_original = Array2::<f64>::zeros((p, p));
    for (k, cp) in penalties.iter().enumerate() {
        let lambda_k = lambdas.get(k).copied().unwrap_or(0.0);
        if lambda_k != 0.0 {
            cp.accumulate_weighted(&mut s_original, lambda_k);
        }
    }
    // Rotate the truncated range-space basis back to original coordinates;
    // fall back to identity when the dimensions do not line up.
    let u_original = if base.u_truncated.nrows() == p {
        base.qs.dot(&base.u_truncated)
    } else {
        Array2::<f64>::eye(p)
    };
    ReparamResult {
        penalty_shrinkage_ridge: base.penalty_shrinkage_ridge,
        s_transformed: s_original,
        log_det: base.log_det,
        det1: base.det1,
        // Identity transform: the sparse-native solve stays in original coords.
        qs: Array2::<f64>::eye(p),
        // Idiomatic slice copy (was `iter().cloned().collect()`).
        canonical_transformed: penalties.to_vec(),
        e_transformed: stack_lambdaweighted_penalty_root_canonical(penalties, lambdas, p),
        u_truncated: u_original,
    }
}
/// Builds the diagonal penalty for the Kronecker-factored path.
///
/// For each tensor-product coefficient — flat index enumerated with the LAST
/// marginal dimension varying fastest, matching the original odometer order —
/// the penalty eigenvalue is `ridge + sum_k lambda_k * eig_k[idx_k]`, plus an
/// extra `lambdas[d]` when a double-penalty term is configured. Indices with
/// strictly positive values are collected into `positive_indices`.
///
/// Improvement over the previous manual odometer loop: flat indices are
/// decomposed via precomputed strides (simpler, no carry bookkeeping), and a
/// degenerate zero-sized marginal dimension now yields an empty penalty
/// instead of panicking on an out-of-bounds write.
fn build_diagonal_penalty_from_kronecker(
    kron_result: &KroneckerReparamResult,
    lambdas: &[f64],
) -> PirlsPenalty {
    let d = kron_result.marginal_dims.len();
    let p: usize = kron_result.marginal_dims.iter().copied().product();
    // Row-major strides, last dimension fastest: stride[d-1] = 1,
    // stride[k] = stride[k+1] * dims[k+1].
    let mut strides = vec![1usize; d];
    for dim in (0..d.saturating_sub(1)).rev() {
        strides[dim] = strides[dim + 1] * kron_result.marginal_dims[dim + 1];
    }
    let mut diag = Array1::<f64>::zeros(p);
    let mut positive_indices = Vec::new();
    for flat in 0..p {
        let mut sigma = kron_result.penalty_shrinkage_ridge;
        for k in 0..d {
            let idx_k = (flat / strides[k]) % kron_result.marginal_dims[k];
            sigma += lambdas[k] * kron_result.marginal_eigenvalues[k][idx_k];
        }
        // Optional double-penalty (null-space shrinkage) term.
        if kron_result.has_double_penalty && lambdas.len() > d {
            sigma += lambdas[d];
        }
        diag[flat] = sigma;
        if sigma > 0.0 {
            positive_indices.push(flat);
        }
    }
    PirlsPenalty::Diagonal {
        diag,
        positive_indices,
    }
}
/// Data-side inputs for a single PIRLS fit.
pub struct PirlsProblem<'a, X> {
    /// Design matrix (anything convertible into a `DesignMatrix`).
    pub x: X,
    /// Per-observation linear-predictor offset.
    pub offset: ArrayView1<'a, f64>,
    /// Response vector.
    pub y: ArrayView1<'a, f64>,
    /// Prior (case) weights.
    pub priorweights: ArrayView1<'a, f64>,
    /// Optional per-observation covariate standard errors, forwarded to the
    /// working model via `with_covariate_se` — presumably a measurement-error
    /// adjustment; confirm intended semantics with the working model.
    pub covariate_se: Option<ArrayView1<'a, f64>>,
}
/// Penalty-side configuration for a PIRLS fit.
pub struct PenaltyConfig<'a> {
    /// Canonical penalty blocks S_k, each with its root and column range.
    pub canonical_penalties: &'a [crate::construction::CanonicalPenalty],
    /// Precomputed balanced penalty root; rebuilt from the canonical
    /// penalties when absent.
    pub balanced_penalty_root: Option<&'a Array2<f64>>,
    /// Cached invariant for the stable reparameterization engine.
    pub reparam_invariant: Option<&'a crate::construction::ReparamInvariant>,
    /// Number of coefficients (columns of the design matrix).
    pub p: usize,
    /// Optional elementwise lower bounds on the coefficients.
    pub coefficient_lower_bounds: Option<&'a Array1<f64>>,
    /// Optional linear inequality constraints in original coordinates.
    pub linear_constraints_original: Option<&'a LinearInequalityConstraints>,
    /// Optional floor for the penalty-shrinkage ridge.
    pub penalty_shrinkage_floor: Option<f64>,
    /// Kronecker-factored basis; enables the diagonal Kronecker penalty path.
    pub kronecker_factored: Option<&'a crate::basis::KroneckerFactoredBasis>,
}
pub fn fit_model_for_fixed_rho<'a, X: Into<DesignMatrix> + Clone>(
rho: LogSmoothingParamsView<'_>,
problem: PirlsProblem<'a, X>,
penalty: PenaltyConfig<'_>,
config: &PirlsConfig,
warm_start_beta: Option<&Coefficients>,
) -> Result<(PirlsResult, WorkingModelPirlsResult), EstimationError> {
let PirlsProblem {
x,
offset,
y,
priorweights,
covariate_se,
} = problem;
let quadctx = crate::quadrature::QuadratureContext::new();
let lambdas = rho.exp();
let lambdas_slice = lambdas.as_slice_memory_order().ok_or_else(|| {
EstimationError::InvalidInput("non-contiguous lambda storage".to_string())
})?;
let likelihood = config.likelihood;
let link_function = config.link_function();
use crate::construction::{
EngineDims, create_balanced_penalty_root_from_canonical,
stable_reparameterization_engine_canonical,
};
let eb_cow: Cow<'_, Array2<f64>> = if let Some(precomputed) = penalty.balanced_penalty_root {
Cow::Borrowed(precomputed)
} else {
Cow::Owned(create_balanced_penalty_root_from_canonical(
penalty.canonical_penalties,
penalty.p,
)?)
};
let eb: &Array2<f64> = eb_cow.as_ref();
let cheap_s_lambda: Option<Array2<f64>> = if penalty.kronecker_factored.is_none() {
let mut s = Array2::<f64>::zeros((penalty.p, penalty.p));
for (k, cp) in penalty.canonical_penalties.iter().enumerate() {
let lam = lambdas_slice.get(k).copied().unwrap_or(0.0);
if lam != 0.0 {
cp.accumulate_weighted(&mut s, lam);
}
}
Some(s)
} else {
None
};
let kronecker_runtime = if let Some(kron) = penalty.kronecker_factored {
let kron_result = crate::construction::kronecker_reparameterization_engine(
&kron.marginal_designs,
&kron.marginal_penalties,
&kron.marginal_dims,
lambdas_slice,
kron.has_double_penalty,
penalty.penalty_shrinkage_floor,
)?;
let transform = Arc::new(KroneckerQsTransform::new(&kron_result));
let penalty_diag = build_diagonal_penalty_from_kronecker(&kron_result, lambdas_slice);
Some((kron_result, transform, penalty_diag))
} else {
None
};
let kronecker_constraints = if let Some((_, transform, _)) = kronecker_runtime.as_ref() {
let tb = build_transformed_lower_bound_constraints_with_transform(
&WorkingReparamTransform::Kronecker(Arc::clone(transform)),
penalty.coefficient_lower_bounds,
);
let tl = build_transformed_linear_constraints_with_transform(
&WorkingReparamTransform::Kronecker(Arc::clone(transform)),
penalty.linear_constraints_original,
);
Some(merge_linear_constraints(tb, tl))
} else {
None
};
let x_original: DesignMatrix = x.into();
let x_original = {
let auto_sparse = x_original
.as_dense()
.and_then(|dense| sparse_from_denseview(dense.view()));
auto_sparse.unwrap_or(x_original)
};
let ebrows = eb.nrows();
let erows = if let Some((_, _, penalty_diag)) = kronecker_runtime.as_ref() {
penalty_diag.rank()
} else {
penalty
.canonical_penalties
.iter()
.map(|cp| cp.rank())
.sum::<usize>()
};
let mut workspace = PirlsWorkspace::new(x_original.nrows(), x_original.ncols(), ebrows, erows);
let solver_decision = if let Some((_, _, _)) = kronecker_runtime.as_ref() {
SparsePirlsDecision {
path: PirlsLinearSolvePath::DenseTransformed,
reason: "kronecker_runtime",
p: x_original.ncols(),
nnz_x: 0,
nnz_xtwx_symbolic: None,
nnz_s_lambda: 0,
nnz_h_est: None,
density_h_est: None,
}
} else {
should_use_sparse_native_pirls(
&mut workspace,
&x_original,
cheap_s_lambda
.as_ref()
.expect("cheap_s_lambda should be present outside Kronecker path"),
penalty.coefficient_lower_bounds,
penalty.linear_constraints_original,
)
};
solver_decision.log_once();
let use_sparse_native = matches!(solver_decision.path, PirlsLinearSolvePath::SparseNative);
let dense_reparam_result = if !use_sparse_native && penalty.kronecker_factored.is_none() {
Some(stable_reparameterization_engine_canonical(
penalty.canonical_penalties,
lambdas_slice,
EngineDims::new(penalty.p, penalty.canonical_penalties.len()),
penalty.reparam_invariant,
penalty.penalty_shrinkage_floor,
)?)
} else {
None
};
let qs_arc = dense_reparam_result
.as_ref()
.map(|reparam_result| Arc::new(reparam_result.qs.clone()));
let transform_active = if let Some((_, transform, _)) = kronecker_runtime.as_ref() {
Some(WorkingReparamTransform::Kronecker(Arc::clone(transform)))
} else if use_sparse_native {
None
} else {
Some(WorkingReparamTransform::Dense(Arc::clone(
qs_arc
.as_ref()
.expect("dense Qs should exist for non-Kronecker transformed path"),
)))
};
let penalty_active = if let Some((_, _, penalty_diag)) = kronecker_runtime.as_ref() {
penalty_diag.clone()
} else if use_sparse_native {
let s_lambda = cheap_s_lambda
.as_ref()
.expect("cheap_s_lambda should be present for sparse-native path")
.clone();
let e_root = stack_lambdaweighted_penalty_root_canonical(
penalty.canonical_penalties,
lambdas_slice,
penalty.p,
);
PirlsPenalty::Dense {
s_transformed: s_lambda,
e_transformed: e_root,
}
} else {
let dense = dense_reparam_result
.as_ref()
.expect("dense reparam result should be present outside Kronecker path");
PirlsPenalty::Dense {
s_transformed: dense.s_transformed.clone(),
e_transformed: dense.e_transformed.clone(),
}
};
let linear_constraints = if let Some(kc) = kronecker_constraints {
kc
} else if let Some(reparam) = dense_reparam_result.as_ref() {
let tb = build_transformed_lower_bound_constraints(
&reparam.qs,
penalty.coefficient_lower_bounds,
);
let tl =
build_transformed_linear_constraints(&reparam.qs, penalty.linear_constraints_original);
merge_linear_constraints(tb, tl)
} else {
let p = penalty.p;
let qs_identity = Array2::<f64>::eye(p);
let tb = build_transformed_lower_bound_constraints(
&qs_identity,
penalty.coefficient_lower_bounds,
);
let tl =
build_transformed_linear_constraints(&qs_identity, penalty.linear_constraints_original);
merge_linear_constraints(tb, tl)
};
let coordinate_frame = if use_sparse_native {
PirlsCoordinateFrame::OriginalSparseNative
} else {
PirlsCoordinateFrame::TransformedQs
};
let materialize_final_reparam_result = || -> Result<ReparamResult, EstimationError> {
if let Some((kron_result, _, _)) = kronecker_runtime.as_ref() {
let rs_list: Vec<Array2<f64>> = penalty
.canonical_penalties
.iter()
.map(|cp| cp.full_width_root())
.collect();
kron_result.materialize_dense_artifact_result(&rs_list, lambdas_slice, penalty.p)
} else if use_sparse_native {
let base = stable_reparameterization_engine_canonical(
penalty.canonical_penalties,
lambdas_slice,
EngineDims::new(penalty.p, penalty.canonical_penalties.len()),
penalty.reparam_invariant,
penalty.penalty_shrinkage_floor,
)?;
Ok(build_sparse_native_reparam_result(
base,
penalty.canonical_penalties,
lambdas_slice,
penalty.p,
))
} else {
Ok(dense_reparam_result
.as_ref()
.expect("dense reparam result should be present outside Kronecker path")
.clone())
}
};
if matches!(link_function, LinkFunction::Identity) {
let (pls_result, _) = solve_penalized_least_squares_implicit(
&x_original,
transform_active.as_ref(),
y,
priorweights,
offset,
&penalty_active,
&mut workspace,
y,
link_function,
)?;
let beta_transformed = pls_result.beta;
let penalized_hessian = pls_result.penalized_hessian;
let edf = pls_result.edf;
let baseridge = pls_result.ridge_used;
let priorweights_owned = priorweights.to_owned();
let qbeta = transform_active
.as_ref()
.map(|transform| transform.apply(beta_transformed.as_ref()))
.unwrap_or_else(|| beta_transformed.as_ref().clone());
let mut eta = offset.to_owned();
eta += &x_original.apply(&qbeta);
let final_eta = eta.clone();
let finalmu = eta.clone();
let finalz = y.to_owned();
let mut weighted_residual = finalmu.clone();
weighted_residual -= &finalz;
weighted_residual *= &priorweights_owned;
let xt_wr = x_original.apply_transpose(&weighted_residual);
let gradient_data = transform_active
.as_ref()
.map(|transform| transform.apply_transpose(&xt_wr))
.unwrap_or(xt_wr);
let score_norm = array1_l2_norm(&gradient_data);
let s_beta = penalty_active.apply(beta_transformed.as_ref());
let s_beta_norm = array1_l2_norm(&s_beta);
let mut gradient = gradient_data;
gradient += &s_beta;
let mut penalty_term = beta_transformed.as_ref().dot(&s_beta);
let deviance = calculate_deviance(y, &finalmu, likelihood, priorweights);
let ridge_used = baseridge;
let stabilizedhessian = if ridge_used > 0.0 {
penalized_hessian
.addridge(ridge_used)
.map_err(|e| EstimationError::InvalidInput(format!("ridge addition failed: {e}")))?
} else {
penalized_hessian.clone()
};
let mut ridge_grad_norm = 0.0;
if ridge_used > 0.0 {
let ridge_penalty =
ridge_used * beta_transformed.as_ref().dot(beta_transformed.as_ref());
penalty_term += ridge_penalty;
gradient += &beta_transformed.as_ref().mapv(|v| ridge_used * v);
ridge_grad_norm = ridge_used * array1_l2_norm(beta_transformed.as_ref());
}
let gradient_norm = array1_l2_norm(&gradient);
let max_abs_eta = finalmu.iter().copied().map(f64::abs).fold(0.0, f64::max);
let log_likelihood =
calculate_loglikelihood_omitting_constants(y, &finalmu, likelihood, priorweights);
let working_state = WorkingState {
eta: LinearPredictor::new(finalmu.clone()),
gradient: gradient.clone(),
hessian: penalized_hessian.clone(),
log_likelihood,
deviance,
penalty_term,
firth: FirthDiagnostics::Inactive,
ridge_used,
hessian_curvature: HessianCurvatureKind::Fisher,
gradient_natural_scale: score_norm + s_beta_norm + ridge_grad_norm,
};
let zero_iter_penalized = deviance + penalty_term;
let working_summary = WorkingModelPirlsResult {
beta: beta_transformed.clone(),
state: working_state,
status: PirlsStatus::Converged,
iterations: 1,
lastgradient_norm: gradient_norm,
last_deviance_change: 0.0,
last_step_size: 1.0,
last_step_halving: 0,
max_abs_eta,
constraint_kkt: linear_constraints.as_ref().map(|lin| {
compute_constraint_kkt_diagnostics(beta_transformed.as_ref(), &gradient, lin)
}),
min_penalized_deviance: if zero_iter_penalized.is_finite() {
zero_iter_penalized
} else {
f64::INFINITY
},
final_lm_lambda: 1e-6,
final_accept_rho: None,
exported_laplace_curvature: ExportedLaplaceCurvature::ExpectedInformationSurrogate,
};
let (solve_c_array, solve_d_array, solve_dmu_deta, solve_d2mu_deta2, solve_d3mu_deta3) =
computeworkingweight_derivatives_from_eta(
config.likelihood,
&config.link_kind,
&final_eta,
priorweights_owned.view(),
)?;
let reparam_result = materialize_final_reparam_result()?;
let qs_arc_final = Arc::new(reparam_result.qs.clone());
let pirls_result = PirlsResult {
likelihood: config.likelihood,
beta_transformed,
penalized_hessian_transformed: penalized_hessian,
stabilizedhessian_transformed: stabilizedhessian,
ridge_passport: RidgePassport::scaled_identity(
ridge_used,
RidgePolicy::explicit_stabilization_full(),
),
ridge_used,
deviance,
edf,
stable_penalty_term: penalty_term,
firth: FirthDiagnostics::Inactive,
finalweights: priorweights_owned.clone(),
final_offset: offset.to_owned(),
final_eta: final_eta.clone(),
finalmu: finalmu.clone(),
solveweights: priorweights_owned,
solveworking_response: finalz.clone(),
solvemu: finalmu.clone(),
solve_dmu_deta,
solve_d2mu_deta2,
solve_d3mu_deta3,
solve_c_array,
solve_d_array,
status: PirlsStatus::Converged,
iteration: 1,
max_abs_eta,
lastgradient_norm: gradient_norm,
gradient_natural_scale: score_norm + s_beta_norm + ridge_grad_norm,
last_deviance_change: 0.0,
last_step_halving: 0,
hessian_curvature: HessianCurvatureKind::Fisher,
exported_laplace_curvature: working_summary.exported_laplace_curvature.clone(),
final_lm_lambda: working_summary.final_lm_lambda,
final_accept_rho: working_summary.final_accept_rho,
constraint_kkt: working_summary.constraint_kkt.clone(),
linear_constraints_transformed: linear_constraints.clone(),
reparam_result,
x_transformed: make_reparam_operator(&x_original, &qs_arc_final, use_sparse_native),
coordinate_frame: coordinate_frame.clone(),
cache_compacted: false,
min_penalized_deviance: working_summary.min_penalized_deviance,
};
return Ok((pirls_result, working_summary));
}
let x_original_for_result = x_original.clone();
let mut working_model = GamWorkingModel::new(
None, x_original.clone(),
coordinate_frame.clone(),
offset,
y,
priorweights,
penalty_active.clone(),
workspace,
likelihood,
config.link_kind.clone(),
config.firth_bias_reduction
&& matches!(
&config.link_kind,
InverseLink::Standard(LinkFunction::Logit)
),
transform_active.clone(),
quadctx,
);
if let Some(se) = covariate_se {
working_model = working_model.with_covariate_se(se.to_owned());
}
let mut beta_guess_original = warm_start_beta
.filter(|beta| beta.len() == penalty.p)
.map(|beta| beta.to_owned())
.unwrap_or_else(|| {
Coefficients::new(default_beta_guess_external(
penalty.p,
link_function,
y,
priorweights,
config.link_kind.mixture_state(),
config.link_kind.sas_state(),
))
});
if let Some(lb) = penalty.coefficient_lower_bounds {
project_coefficients_to_lower_bounds(&mut beta_guess_original.0, lb);
}
let initial_beta = transform_active
.as_ref()
.map(|transform| transform.apply_transpose(beta_guess_original.as_ref()))
.unwrap_or_else(|| beta_guess_original.as_ref().clone());
let firth_active = config.firth_bias_reduction && matches!(link_function, LinkFunction::Logit);
let base_max_step_halving = if firth_active { 60 } else { 30 };
let options = WorkingModelPirlsOptions {
max_iterations: if firth_active {
config.max_iterations.max(200)
} else {
config.max_iterations
},
convergence_tolerance: config.convergence_tolerance,
max_step_halving: base_max_step_halving,
min_step_size: if firth_active { 1e-12 } else { 1e-10 },
firth_bias_reduction: firth_active,
coefficient_lower_bounds: None,
linear_constraints: linear_constraints.clone(),
initial_lm_lambda: config.initial_lm_lambda,
};
let mut iteration_logger = |info: &WorkingModelIterationInfo| {
log::debug!(
"[PIRLS] iter {:>3} | deviance {:.6e} | |grad| {:.3e} | step {:.3e} (halving {})",
info.iteration,
info.deviance,
info.gradient_norm,
info.step_size,
info.step_halving
);
};
let mut working_summary = runworking_model_pirls(
&mut working_model,
Coefficients::new(initial_beta),
&options,
&mut iteration_logger,
)?;
let final_state = working_model.into_final_state();
let GamModelFinalState {
likelihood: final_likelihood,
coordinate_frame,
finalmu,
finalweights,
scoreweights,
finalz,
final_c,
final_d,
final_dmu_deta,
final_d2mu_deta2,
final_d3mu_deta3,
penalty_term,
..
} = final_state;
let penalized_hessian_transformed = working_summary.state.hessian.clone();
let stabilizedhessian_transformed = penalized_hessian_transformed.clone();
let mut edf = calculate_edf_with_penalty(&penalized_hessian_transformed, &penalty_active)?;
if !edf.is_finite() || edf.is_nan() {
let p = penalized_hessian_transformed.ncols() as f64;
let r = penalty_active.rank() as f64;
edf = (p - r).max(0.0);
}
let stalled_at_valid_minimum = |summary: &WorkingModelPirlsResult| -> bool {
let dev_scale = summary.state.deviance.abs().max(1.0);
let dev_tol = options.convergence_tolerance * dev_scale;
let step_floor = options.min_step_size * 2.0;
let progress_stopped =
summary.last_deviance_change.abs() <= dev_tol || summary.last_step_size <= step_floor;
let near_stationary = summary
.state
.near_stationary_kkt(summary.lastgradient_norm, options.convergence_tolerance);
progress_stopped && near_stationary
};
let mut status = working_summary.status.clone();
if status.is_failed_max_iterations() && stalled_at_valid_minimum(&working_summary) {
status = PirlsStatus::StalledAtValidMinimum;
working_summary.status = status.clone();
}
if status.is_failed_max_iterations()
&& firth_active
&& stalled_at_valid_minimum(&working_summary)
{
status = PirlsStatus::StalledAtValidMinimum;
working_summary.status = status.clone();
}
let has_penalty = penalty_active.rank() > 0;
let firth_active = options.firth_bias_reduction;
if detect_logit_instability(
link_function,
has_penalty,
firth_active,
&working_summary,
&finalmu,
&finalweights,
y,
) {
status = PirlsStatus::Unstable;
working_summary.status = status.clone();
}
let reparam_result_final = materialize_final_reparam_result()?;
let qs_arc_final = Arc::new(reparam_result_final.qs.clone());
let x_transformed_final =
make_reparam_operator(&x_original_for_result, &qs_arc_final, use_sparse_native);
let pirls_result = assemble_pirls_result(
&working_summary,
final_likelihood,
offset,
penalized_hessian_transformed,
stabilizedhessian_transformed,
edf,
penalty_term,
&finalmu,
&finalweights,
&scoreweights,
&finalz,
&final_c,
&final_d,
&final_dmu_deta,
&final_d2mu_deta2,
&final_d3mu_deta3,
status,
reparam_result_final,
x_transformed_final,
coordinate_frame,
linear_constraints,
);
Ok((pirls_result, working_summary))
}
/// Configuration for a single PIRLS (penalized iteratively reweighted least
/// squares) solve.
#[derive(Clone)]
pub struct PirlsConfig {
    /// Likelihood family specification for the GLM being fitted.
    pub likelihood: GlmLikelihoodSpec,
    /// Inverse-link description (standard, mixture, SAS, beta-logistic, ...).
    pub link_kind: InverseLink,
    /// Maximum number of PIRLS iterations before the solver gives up.
    pub max_iterations: usize,
    /// Relative tolerance used to declare convergence of the working model.
    pub convergence_tolerance: f64,
    // Firth bias reduction is only activated elsewhere in this module when the
    // link is a standard logit.
    pub firth_bias_reduction: bool,
    /// Optional initial `lm_lambda` forwarded to the working-model options
    /// (presumably a Levenberg–Marquardt damping seed — confirm at call site).
    pub initial_lm_lambda: Option<f64>,
}
impl PirlsConfig {
    /// Convenience accessor: the plain `LinkFunction` underlying `link_kind`.
    #[inline]
    pub fn link_function(&self) -> LinkFunction {
        self.link_kind.link_function()
    }
}
#[inline]
#[cfg(debug_assertions)]
/// Largest absolute difference `|m[i,j] - m[j,i]|` over the strictly lower
/// triangle of the leading square part of `matrix` (debug builds only).
/// NaN differences are ignored, matching a strict `>` comparison.
fn max_symmetric_asymmetry(matrix: &Array2<f64>) -> f64 {
    let n = matrix.nrows().min(matrix.ncols());
    (0..n)
        .flat_map(|i| (0..i).map(move |j| (matrix[[i, j]] - matrix[[j, i]]).abs()))
        .fold(0.0_f64, f64::max)
}
#[inline]
#[cfg(debug_assertions)]
/// Panics when `matrix` deviates from symmetry by more than `tol`
/// (debug builds only).
fn debug_assert_symmetric_tol(matrix: &Array2<f64>, label: &str, tol: f64) {
    let observed = max_symmetric_asymmetry(matrix);
    if observed > tol {
        panic!(
            "{} asymmetry too large: {:.3e} (tol {:.3e})",
            label, observed, tol
        );
    }
}
/// Builds the design matrix used after reparameterization. On the
/// sparse-native path the original matrix is returned unchanged; otherwise the
/// product with `Qs` is represented lazily by a `ReparamOperator` wrapped as a
/// dense design matrix.
fn make_reparam_operator(
    x_original: &DesignMatrix,
    qs_arc: &Arc<Array2<f64>>,
    use_sparse_native: bool,
) -> DesignMatrix {
    if use_sparse_native {
        return x_original.clone();
    }
    let operator = ReparamOperator::new(x_original.clone(), Arc::clone(qs_arc));
    DesignMatrix::Dense(crate::matrix::DenseDesignMatrix::from(Arc::new(operator)))
}
/// Solves one penalized weighted least-squares subproblem of PIRLS
/// "implicitly", i.e. without materializing the transformed design matrix:
/// the normal-equations matrix XᵀWX is assembled in the original basis and
/// conjugated by the reparameterization `transform` when one is present.
///
/// Two paths:
/// * sparse-native fast path (no transform, sparse design, dense transformed
///   penalty): assembles and factorizes a sparse SPD penalized Hessian;
/// * general dense path: streams XᵀWX, applies the transform, adds the
///   penalty, stabilizes with a fixed ridge, and solves via `StableSolver`.
///
/// Returns the solved `StablePLSResult` and the coefficient dimension.
///
/// # Errors
/// Propagates factorization/solve failures, non-finite solutions, and invalid
/// inputs (e.g. a sparse design paired with a non-dense penalty).
fn solve_penalized_least_squares_implicit(
    x_original: &DesignMatrix,
    transform: Option<&WorkingReparamTransform>,
    z: ArrayView1<f64>,
    weights: ArrayView1<f64>,
    offset: ArrayView1<f64>,
    penalty: &PirlsPenalty,
    workspace: &mut PirlsWorkspace,
    y: ArrayView1<f64>,
    link_function: LinkFunction,
) -> Result<(StablePLSResult, usize), EstimationError> {
    let p_dim = penalty.dim();
    // Sparse-native fast path: only valid when no basis transform is active.
    if transform.is_none() {
        if let Some(x_sparse) = x_original.as_sparse() {
            let PirlsPenalty::Dense { s_transformed, .. } = penalty else {
                return Err(EstimationError::InvalidInput(
                    "sparse-native PIRLS requires a dense transformed penalty matrix".to_string(),
                ));
            };
            let weights_owned = weights.to_owned();
            // Retry the sparse assembly with increasing ridge until the
            // penalized Hessian factorizes as SPD; ridge 0 is replaced by the
            // fixed stabilization floor.
            let (h_sparse, factor, ridge_used) =
                ensure_sparse_positive_definitewithridge(|ridge| {
                    let ridge = if ridge == 0.0 {
                        FIXED_STABILIZATION_RIDGE
                    } else {
                        ridge
                    };
                    workspace.assemble_sparse_penalized_hessian(
                        x_sparse,
                        &weights_owned,
                        s_transformed,
                        ridge,
                    )
                })?;
            // Right-hand side Xᵀ W (z - offset).
            let mut wz = z.to_owned();
            wz -= &offset;
            wz *= &weights_owned;
            let rhs = x_original.transpose_vector_multiply(&wz);
            let betavec = solve_sparse_spd(&factor, &rhs)?;
            let h_sym = SymmetricMatrix::Sparse(h_sparse);
            let edf = calculate_edf_with_penalty(&h_sym, penalty)?;
            let fitted_vals = {
                let xb = x_original.apply(&betavec);
                let mut f = xb;
                f += &offset;
                f
            };
            // Gaussian/identity fits estimate the scale from the weighted RSS;
            // other families fix the dispersion at 1.
            let standard_deviation = match link_function {
                LinkFunction::Identity => {
                    let residuals = &y - &fitted_vals;
                    let weighted_rss: f64 = weights
                        .iter()
                        .zip(residuals.iter())
                        .map(|(&w, &r)| w * r * r)
                        .sum();
                    let effective_n = y.len() as f64;
                    (weighted_rss / (effective_n - edf).max(1.0)).sqrt()
                }
                _ => 1.0,
            };
            return Ok((
                StablePLSResult {
                    beta: Coefficients::new(betavec),
                    penalized_hessian: h_sym,
                    edf,
                    standard_deviation,
                    ridge_used,
                },
                p_dim,
            ));
        }
    }
    // General dense path. Working response is weighted in place: wz = W (z - offset).
    workspace.fill_sqrtweights(&weights);
    if workspace.wz.len() != z.len() {
        workspace.wz = Array1::zeros(z.len());
    }
    workspace.wz.assign(&z);
    workspace.wz -= &offset;
    workspace.wz *= &weights;
    let weights_owned = weights.to_owned();
    // XᵀWX in the original basis. A materialized dense design gets the
    // streaming sqrt(W)-chunk accumulation; everything else goes through the
    // generic diag_xtw_x path.
    let xtwx_orig = match x_original {
        DesignMatrix::Dense(x_dense) if x_dense.is_materialized_dense() => {
            let p = x_dense.ncols();
            let x_dense = x_dense.to_dense_arc();
            if workspace.hessian_buf.nrows() != p || workspace.hessian_buf.ncols() != p {
                workspace.hessian_buf = Array2::zeros((p, p).f());
            } else {
                workspace.hessian_buf.fill(0.0);
            }
            PirlsWorkspace::add_dense_xtwx_streaming_from_sqrt(
                &workspace.sqrtw,
                &mut workspace.weighted_x_chunk,
                x_dense.as_ref(),
                &mut workspace.hessian_buf,
                get_global_parallelism(),
            );
            std::mem::take(&mut workspace.hessian_buf)
        }
        _ => x_original
            .diag_xtw_x(&weights_owned)
            .map_err(EstimationError::InvalidInput)?,
    };
    #[cfg(debug_assertions)]
    let xtwx_orig_asym = max_symmetric_asymmetry(&xtwx_orig);
    // Conjugate into the transformed basis: Qᵀ (XᵀWX) Q.
    let xtwx_transformed = if let Some(transform) = transform {
        transform.conjugate_matrix(&xtwx_orig)
    } else {
        xtwx_orig
    };
    let mut penalized_hessian = xtwx_transformed.clone();
    penalty.add_to_hessian(&mut penalized_hessian);
    // Transformed right-hand side Qᵀ Xᵀ W (z - offset).
    let xtwy_orig = x_original.transpose_vector_multiply(&workspace.wz);
    if workspace.vec_buf_p.len() != p_dim {
        workspace.vec_buf_p = Array1::zeros(p_dim);
    }
    if let Some(transform) = transform {
        workspace
            .vec_buf_p
            .assign(&transform.apply_transpose(&xtwy_orig));
    } else {
        workspace.vec_buf_p.assign(&xtwy_orig);
    }
    // Debug-only symmetry audit of every ingredient of the penalized Hessian.
    #[cfg(debug_assertions)]
    {
        let xtwx_asym = max_symmetric_asymmetry(&xtwx_transformed);
        let penalty_asym = match penalty {
            PirlsPenalty::Dense { s_transformed, .. } => max_symmetric_asymmetry(s_transformed),
            PirlsPenalty::Diagonal { .. } => 0.0,
        };
        let total_asym = max_symmetric_asymmetry(&penalized_hessian);
        assert!(
            total_asym <= 1e-8,
            "implicit PLS penalized Hessian asymmetry too large: total={total_asym:.3e}, xtwx_orig={xtwx_orig_asym:.3e}, xtwx={xtwx_asym:.3e}, penalty={penalty_asym:.3e}, tol={:.3e}",
            1e-8
        );
    }
    // Fixed diagonal nugget for numerical stabilization of the factorization;
    // the unstabilized penalized Hessian is what gets reported back.
    let nugget = FIXED_STABILIZATION_RIDGE;
    let mut regularizedhessian = penalized_hessian.clone();
    if nugget > 0.0 {
        for i in 0..p_dim {
            regularizedhessian[[i, i]] += nugget;
        }
    }
    let ridge_used = nugget;
    if workspace.rhs_full.len() != p_dim {
        workspace.rhs_full = Array1::zeros(p_dim);
    }
    workspace.rhs_full.assign(&workspace.vec_buf_p);
    let factor = StableSolver::new("pirls implicit pls")
        .factorize(&regularizedhessian)
        .map_err(EstimationError::LinearSystemSolveFailed)?;
    let mut rhsview = array1_to_col_matmut(&mut workspace.rhs_full);
    factor.solve_in_place(rhsview.as_mut());
    if !array1_is_finite(&workspace.rhs_full) {
        return Err(EstimationError::LinearSystemSolveFailed(
            FaerLinalgError::FactorizationFailed,
        ));
    }
    let betavec = workspace.rhs_full.clone();
    let edf = calculate_edfwithworkspace_with_penalty(&regularizedhessian, penalty, workspace)?;
    // Map the solution back to the original basis before computing fits.
    let qbeta = if let Some(transform) = transform {
        transform.apply(&betavec)
    } else {
        betavec.clone()
    };
    let xqbeta = x_original.apply(&qbeta);
    let mut fitted = xqbeta;
    fitted += &offset;
    let standard_deviation = match link_function {
        LinkFunction::Identity => {
            let residuals = &y - &fitted;
            let weighted_rss: f64 = weights
                .iter()
                .zip(residuals.iter())
                .map(|(&w, &r)| w * r * r)
                .sum();
            let effective_n = y.len() as f64;
            (weighted_rss / (effective_n - edf).max(1.0)).sqrt()
        }
        _ => 1.0,
    };
    Ok((
        StablePLSResult {
            beta: Coefficients::new(betavec),
            penalized_hessian: SymmetricMatrix::Dense(penalized_hessian),
            edf,
            standard_deviation,
            ridge_used,
        },
        p_dim,
    ))
}
/// Turns per-coefficient lower bounds into linear constraint rows in the
/// `Qs`-transformed basis: each finite bound contributes the corresponding row
/// of `Qs` and its bound value. Returns `None` when no finite bound exists or
/// the bound vector length does not match `qs`.
fn build_transformed_lower_bound_constraints(
    qs: &Array2<f64>,
    coefficient_lower_bounds: Option<&Array1<f64>>,
) -> Option<LinearInequalityConstraints> {
    let lb = coefficient_lower_bounds?;
    if lb.len() != qs.nrows() {
        return None;
    }
    let finite_idx: Vec<usize> = lb
        .iter()
        .enumerate()
        .filter_map(|(i, v)| v.is_finite().then_some(i))
        .collect();
    if finite_idx.is_empty() {
        return None;
    }
    let mut a = Array2::<f64>::zeros((finite_idx.len(), qs.ncols()));
    let mut b = Array1::<f64>::zeros(finite_idx.len());
    for (out_row, &coef) in finite_idx.iter().enumerate() {
        a.row_mut(out_row).assign(&qs.row(coef));
        b[out_row] = lb[coef];
    }
    Some(LinearInequalityConstraints { a, b })
}
/// Like [`build_transformed_lower_bound_constraints`], but maps each
/// coefficient bound through an implicit `WorkingReparamTransform` rather than
/// an explicit `Qs` matrix: the constraint row for a finitely-bounded
/// coefficient `idx` is `transform.apply_transpose(e_idx)` where `e_idx` is
/// the corresponding unit vector.
///
/// Returns `None` when no finite bound exists or when the bound vector length
/// does not match the transform's coefficient dimension.
fn build_transformed_lower_bound_constraints_with_transform(
    transform: &WorkingReparamTransform,
    coefficient_lower_bounds: Option<&Array1<f64>>,
) -> Option<LinearInequalityConstraints> {
    let lb = coefficient_lower_bounds?;
    let p = match transform {
        WorkingReparamTransform::Dense(qs) => qs.nrows(),
        WorkingReparamTransform::Kronecker(kron) => kron.p,
    };
    if lb.len() != p {
        return None;
    }
    let activerows: Vec<usize> = (0..lb.len()).filter(|&i| lb[i].is_finite()).collect();
    if activerows.is_empty() {
        return None;
    }
    let mut a = Array2::<f64>::zeros((activerows.len(), p));
    let mut b = Array1::<f64>::zeros(activerows.len());
    // Reuse a single unit-vector buffer across rows instead of allocating a
    // fresh p-length array per active bound: clear the previously set entry,
    // then set the current one. `activerows` is strictly increasing, so the
    // previous index never equals the current one.
    let mut basis = Array1::<f64>::zeros(p);
    let mut prev_idx: Option<usize> = None;
    for (r, &idx) in activerows.iter().enumerate() {
        if let Some(prev) = prev_idx {
            basis[prev] = 0.0;
        }
        basis[idx] = 1.0;
        prev_idx = Some(idx);
        let row = transform.apply_transpose(&basis);
        a.row_mut(r).assign(&row);
        b[r] = lb[idx];
    }
    Some(LinearInequalityConstraints { a, b })
}
/// Rewrites user-supplied linear constraints from the original coefficient
/// basis into the `Qs`-transformed basis by right-multiplying the constraint
/// matrix (`A' = A · Qs`); the bound vector is unchanged. Returns `None` on a
/// dimension mismatch.
fn build_transformed_linear_constraints(
    qs: &Array2<f64>,
    linear_constraints: Option<&LinearInequalityConstraints>,
) -> Option<LinearInequalityConstraints> {
    let lc = linear_constraints?;
    if lc.a.ncols() != qs.nrows() {
        return None;
    }
    let transformed_a = lc.a.dot(qs);
    Some(LinearInequalityConstraints {
        a: transformed_a,
        b: lc.b.clone(),
    })
}
/// Transform-operator variant of [`build_transformed_linear_constraints`]:
/// each constraint row is pushed through `transform.apply_transpose`. Returns
/// `None` when the constraint width does not match the transform dimension.
/// NOTE(review): the output row width is the transform's input dimension `p`,
/// mirroring the lower-bound builder — confirm the transform is square where
/// this is used.
fn build_transformed_linear_constraints_with_transform(
    transform: &WorkingReparamTransform,
    linear_constraints: Option<&LinearInequalityConstraints>,
) -> Option<LinearInequalityConstraints> {
    let lc = linear_constraints?;
    let p = match transform {
        WorkingReparamTransform::Dense(qs) => qs.nrows(),
        WorkingReparamTransform::Kronecker(kron) => kron.p,
    };
    if lc.a.ncols() != p {
        return None;
    }
    let mut a = Array2::<f64>::zeros((lc.a.nrows(), p));
    for (row_idx, source_row) in lc.a.outer_iter().enumerate() {
        let mapped = transform.apply_transpose(&source_row.to_owned());
        a.row_mut(row_idx).assign(&mapped);
    }
    Some(LinearInequalityConstraints { a, b: lc.b.clone() })
}
/// Stacks two optional constraint sets into one. A single present set is
/// passed through unchanged; with both present the blocks are concatenated
/// row-wise. A column-count mismatch yields `None`.
fn merge_linear_constraints(
    first: Option<LinearInequalityConstraints>,
    second: Option<LinearInequalityConstraints>,
) -> Option<LinearInequalityConstraints> {
    let (c1, c2) = match (first, second) {
        (None, None) => return None,
        (Some(only), None) | (None, Some(only)) => return Some(only),
        (Some(c1), Some(c2)) => (c1, c2),
    };
    if c1.a.ncols() != c2.a.ncols() {
        return None;
    }
    let rows = c1.a.nrows() + c2.a.nrows();
    let cols = c1.a.ncols();
    let mut stacked_a = Array2::<f64>::zeros((rows, cols));
    stacked_a.slice_mut(s![..c1.a.nrows(), ..]).assign(&c1.a);
    stacked_a.slice_mut(s![c1.a.nrows().., ..]).assign(&c2.a);
    let mut stacked_b = Array1::<f64>::zeros(rows);
    stacked_b.slice_mut(s![..c1.b.len()]).assign(&c1.b);
    stacked_b.slice_mut(s![c1.b.len()..]).assign(&c2.b);
    Some(LinearInequalityConstraints {
        a: stacked_a,
        b: stacked_b,
    })
}
/// Attempts to convert a dense design view into a sparse `DesignMatrix`.
/// Only worthwhile for matrices wider than 32 columns whose nonzero fraction
/// is at most 20%; otherwise returns `None` and the caller keeps the dense
/// representation.
fn sparse_from_denseview(x: ArrayView2<f64>) -> Option<DesignMatrix> {
    const ZERO_EPS: f64 = 1e-12;
    let nrows = x.nrows();
    let ncols = x.ncols();
    if nrows == 0 || ncols == 0 || ncols <= 32 {
        return None;
    }
    let total = nrows.saturating_mul(ncols);
    if total == 0 {
        return None;
    }
    let sparse_nnz_limit = ((total as f64) * 0.20).floor() as usize;
    // Count nonzeros, bailing out early once the density budget is exceeded.
    let nnz = x.iter().try_fold(0usize, |count, &val| {
        let count = if val.abs() > ZERO_EPS { count + 1 } else { count };
        if count > sparse_nnz_limit {
            None
        } else {
            Some(count)
        }
    })?;
    let mut triplets = Vec::with_capacity(nnz);
    triplets.extend(
        x.indexed_iter()
            .filter(|&(_, &val)| val.abs() > ZERO_EPS)
            .map(|((row_idx, col_idx), &val)| Triplet::new(row_idx, col_idx, val)),
    );
    SparseColMat::try_new_from_triplets(nrows, ncols, &triplets)
        .ok()
        .map(DesignMatrix::from)
}
/// When true, the logit family zeroes the higher-order weight-derivative
/// terms (`c`, `d`) at non-smooth observations (clamped eta or a
/// degenerate Fisher term) instead of propagating them.
const LOGIT_ZERO_HIGHER_DERIVATIVES_ON_NONSMOOTH: bool = true;
#[inline]
/// Thin wrapper evaluating the inverse-link "jet" (mean plus eta-derivatives)
/// for an arbitrary `InverseLink` at linear predictor `eta`.
fn standard_inverse_link_jet(
    inverse_link: &InverseLink,
    eta: f64,
) -> Result<MixtureInverseLinkJet, EstimationError> {
    crate::mixture_link::inverse_link_jet_for_inverse_link(inverse_link, eta)
}
#[inline]
/// Assembles the Bernoulli/logit working geometry (mean, IRLS weight, working
/// response, and weight derivatives `c`/`d`) from a precomputed logit jet.
/// For the canonical logit link the Fisher term equals `dmu/deta`.
///
/// When the point is non-smooth (the linear predictor was clamped, or the
/// Fisher term is non-finite/negative) and `zero_on_nonsmooth` is set, the
/// higher-order derivative terms are suppressed to zero.
fn bernoulli_logit_geometry_from_jet(
    eta_raw: f64,
    eta_used: f64,
    y: f64,
    priorweight: f64,
    jet: crate::mixture_link::LogitJet5,
    zero_on_nonsmooth: bool,
) -> WorkingBernoulliGeometry {
    let fisher = jet.d1;
    let clamped = eta_raw != eta_used;
    let degenerate_fisher = !fisher.is_finite() || fisher < 0.0;
    let suppress_higher = zero_on_nonsmooth && (clamped || degenerate_fisher);
    let (c, d) = match suppress_higher {
        true => (0.0, 0.0),
        false => (priorweight * jet.d2, priorweight * jet.d3),
    };
    WorkingBernoulliGeometry {
        mu: jet.mu,
        weight: priorweight * fisher,
        z: bernoulli_exact_working_response(eta_used, y, jet.mu, jet.d1),
        c,
        d,
    }
}
#[inline]
/// Generic Bernoulli working geometry from an inverse-link jet, for links
/// where the Fisher weight is `(dmu/deta)^2 / V` with variance `V = mu(1-mu)`.
///
/// `c` and `d` are the first and second eta-derivatives of the Fisher term,
/// obtained by quotient-rule expansion of `n/V` where `n = (mu')^2`; they are
/// zeroed at non-smooth points (clamped eta, degenerate variance, or a
/// non-finite/negative weight).
fn bernoulli_geometry_from_jet(
    eta_raw: f64,
    eta_used: f64,
    y: f64,
    priorweight: f64,
    jet: MixtureInverseLinkJet,
) -> WorkingBernoulliGeometry {
    let mu = jet.mu;
    let v = mu * (1.0 - mu);
    // n = (dmu/deta)^2, the numerator of the Fisher weight.
    let n0 = jet.d1 * jet.d1;
    let fisher = if v.is_finite() && v > 0.0 {
        n0 / v
    } else {
        0.0
    };
    let nonsmooth =
        eta_raw != eta_used || !v.is_finite() || v <= 0.0 || !fisher.is_finite() || fisher < 0.0;
    let (c, d) = if nonsmooth {
        (0.0, 0.0)
    } else {
        // Chain-rule derivatives of V = mu(1-mu) with respect to eta.
        let v1 = jet.d1 * (1.0 - 2.0 * mu);
        let v2 = jet.d2 * (1.0 - 2.0 * mu) - 2.0 * jet.d1 * jet.d1;
        // Derivatives of n = (mu')^2.
        let n1 = 2.0 * jet.d1 * jet.d2;
        let n2 = 2.0 * (jet.d2 * jet.d2 + jet.d1 * jet.d3);
        // Quotient rule: (n/V)' = (n'V - nV') / V^2, and its derivative.
        let numer1 = n1 * v - n0 * v1;
        let c = priorweight * numer1 / (v * v);
        let d = priorweight * ((n2 * v - n0 * v2) / (v * v) - 2.0 * numer1 * v1 / (v * v * v));
        (c, d)
    };
    WorkingBernoulliGeometry {
        mu,
        weight: priorweight * fisher,
        z: bernoulli_exact_working_response(eta_used, y, mu, jet.d1),
        c,
        d,
    }
}
#[inline]
/// IRLS working response: `eta + (y - mu) / mu'(eta)` when the slope is
/// finite and strictly positive and the correction itself is finite;
/// otherwise falls back to plain `eta`.
fn bernoulli_exact_working_response(eta: f64, y: f64, mu: f64, dmu_deta: f64) -> f64 {
    if !(dmu_deta.is_finite() && dmu_deta > 0.0) {
        return eta;
    }
    let correction = (y - mu) / dmu_deta;
    if correction.is_finite() {
        eta + correction
    } else {
        eta
    }
}
#[inline]
/// Fills the working-model buffers for the Gaussian/identity link: the mean
/// equals the linear predictor, the IRLS weights equal the prior weights, and
/// the working response is the raw observation. All curvature buffers are
/// trivial for the identity link.
fn write_identityworking_state(
    y: ArrayView1<f64>,
    eta: &Array1<f64>,
    priorweights: ArrayView1<f64>,
    mu: &mut Array1<f64>,
    weights: &mut Array1<f64>,
    z: &mut Array1<f64>,
    derivatives: Option<WorkingDerivativeBuffersMut<'_>>,
) {
    mu.assign(eta);
    weights.assign(&priorweights);
    z.assign(&y);
    let Some(derivs) = derivatives else {
        return;
    };
    // Identity link: dmu/deta == 1 and every higher derivative vanishes.
    derivs.dmu_deta.fill(1.0);
    derivs.d2mu_deta2.fill(0.0);
    derivs.d3mu_deta3.fill(0.0);
    derivs.c.fill(0.0);
    derivs.d.fill(0.0);
}
#[inline]
/// Poisson/log-link IRLS writer: `mu = exp(eta)` (eta clamped to ±700, mu
/// floored), working weight `w = prior * mu` (floored away from zero when
/// positive), working response `z = eta + (y - mu)/mu`. For the log link all
/// eta-derivatives of mu equal mu itself, so `c = d = w` wherever the
/// geometry is smooth.
fn write_poisson_log_working_state(
    y: ArrayView1<f64>,
    eta: &Array1<f64>,
    priorweights: ArrayView1<f64>,
    mu: &mut Array1<f64>,
    weights: &mut Array1<f64>,
    z: &mut Array1<f64>,
    derivatives: Option<WorkingDerivativeBuffersMut<'_>>,
) {
    const MIN_MU: f64 = 1e-10;
    const MIN_WEIGHT: f64 = 1e-12;
    if let Some(derivs) = derivatives {
        // Full update including weight-derivative buffers.
        let mu_s = mu.as_slice_mut().expect("mu must be contiguous");
        let weights_s = weights.as_slice_mut().expect("weights must be contiguous");
        let z_s = z.as_slice_mut().expect("z must be contiguous");
        let dmu_s = derivs
            .dmu_deta
            .as_slice_mut()
            .expect("dmu_deta must be contiguous");
        let d2_s = derivs
            .d2mu_deta2
            .as_slice_mut()
            .expect("d2mu_deta2 must be contiguous");
        let d3_s = derivs
            .d3mu_deta3
            .as_slice_mut()
            .expect("d3mu_deta3 must be contiguous");
        let c_s = derivs.c.as_slice_mut().expect("c must be contiguous");
        let d_s = derivs.d.as_slice_mut().expect("d must be contiguous");
        mu_s.par_iter_mut()
            .zip(weights_s.par_iter_mut())
            .zip(z_s.par_iter_mut())
            .zip(dmu_s.par_iter_mut())
            .zip(d2_s.par_iter_mut())
            .zip(d3_s.par_iter_mut())
            .zip(c_s.par_iter_mut())
            .zip(d_s.par_iter_mut())
            .enumerate()
            .for_each(
                |(i, (((((((mu_o, w_o), z_o), dmu_o), d2_o), d3_o), c_o), d_o))| {
                    let eta_raw = eta[i];
                    let eta_i = eta_raw.clamp(-700.0, 700.0);
                    let mu_i = eta_i.exp().max(MIN_MU);
                    *mu_o = mu_i;
                    let raw_weight = priorweights[i].max(0.0) * mu_i;
                    let floor_active = raw_weight > 0.0 && raw_weight <= MIN_WEIGHT;
                    *w_o = if raw_weight > 0.0 {
                        raw_weight.max(MIN_WEIGHT)
                    } else {
                        0.0
                    };
                    *z_o = eta_i + (y[i] - mu_i) / mu_i;
                    // Log link: every derivative of mu w.r.t. eta equals mu.
                    *dmu_o = mu_i;
                    *d2_o = mu_i;
                    *d3_o = mu_i;
                    // Zero the weight derivatives where the eta clamp or the
                    // weight floor kicked in — the geometry is non-smooth there.
                    if floor_active || eta_raw != eta_i {
                        *c_o = 0.0;
                        *d_o = 0.0;
                    } else {
                        *c_o = raw_weight;
                        *d_o = raw_weight;
                    }
                },
            );
    } else {
        // Lightweight update: mean, weight, and working response only.
        let mu_s = mu.as_slice_mut().expect("mu must be contiguous");
        let weights_s = weights.as_slice_mut().expect("weights must be contiguous");
        let z_s = z.as_slice_mut().expect("z must be contiguous");
        mu_s.par_iter_mut()
            .zip(weights_s.par_iter_mut())
            .zip(z_s.par_iter_mut())
            .enumerate()
            .for_each(|(i, ((mu_o, w_o), z_o))| {
                let eta_i = eta[i].clamp(-700.0, 700.0);
                let mu_i = eta_i.exp().max(MIN_MU);
                *mu_o = mu_i;
                let raw_weight = priorweights[i].max(0.0) * mu_i;
                *w_o = if raw_weight > 0.0 {
                    raw_weight.max(MIN_WEIGHT)
                } else {
                    0.0
                };
                *z_o = eta_i + (y[i] - mu_i) / mu_i;
            });
    }
}
#[inline]
/// Gamma/log-link IRLS writer: `mu = exp(eta)` (eta clamped to ±700, mu
/// floored), constant working weight `w = prior * shape`, working response
/// `z = eta + (y - mu)/mu`. The weight does not depend on eta, so the
/// weight-derivative buffers `c`/`d` are identically zero.
fn write_gamma_log_working_state(
    y: ArrayView1<f64>,
    eta: &Array1<f64>,
    priorweights: ArrayView1<f64>,
    shape: f64,
    mu: &mut Array1<f64>,
    weights: &mut Array1<f64>,
    z: &mut Array1<f64>,
    derivatives: Option<WorkingDerivativeBuffersMut<'_>>,
) {
    const MIN_MU: f64 = 1e-10;
    if let Some(derivs) = derivatives {
        // Full update including derivative buffers.
        let mu_s = mu.as_slice_mut().expect("mu must be contiguous");
        let weights_s = weights.as_slice_mut().expect("weights must be contiguous");
        let z_s = z.as_slice_mut().expect("z must be contiguous");
        let dmu_s = derivs
            .dmu_deta
            .as_slice_mut()
            .expect("dmu_deta must be contiguous");
        let d2_s = derivs
            .d2mu_deta2
            .as_slice_mut()
            .expect("d2mu_deta2 must be contiguous");
        let d3_s = derivs
            .d3mu_deta3
            .as_slice_mut()
            .expect("d3mu_deta3 must be contiguous");
        let c_s = derivs.c.as_slice_mut().expect("c must be contiguous");
        let d_s = derivs.d.as_slice_mut().expect("d must be contiguous");
        mu_s.par_iter_mut()
            .zip(weights_s.par_iter_mut())
            .zip(z_s.par_iter_mut())
            .zip(dmu_s.par_iter_mut())
            .zip(d2_s.par_iter_mut())
            .zip(d3_s.par_iter_mut())
            .zip(c_s.par_iter_mut())
            .zip(d_s.par_iter_mut())
            .enumerate()
            .for_each(
                |(i, (((((((mu_o, w_o), z_o), dmu_o), d2_o), d3_o), c_o), d_o))| {
                    let eta_i = eta[i].clamp(-700.0, 700.0);
                    let mu_i = eta_i.exp().max(MIN_MU);
                    *mu_o = mu_i;
                    *w_o = priorweights[i].max(0.0) * shape;
                    *z_o = eta_i + (y[i] - mu_i) / mu_i;
                    // Log link: every derivative of mu w.r.t. eta equals mu.
                    *dmu_o = mu_i;
                    *d2_o = mu_i;
                    *d3_o = mu_i;
                    // Constant weight => zero weight derivatives.
                    *c_o = 0.0;
                    *d_o = 0.0;
                },
            );
    } else {
        // Lightweight update: mean, weight, and working response only.
        let mu_s = mu.as_slice_mut().expect("mu must be contiguous");
        let weights_s = weights.as_slice_mut().expect("weights must be contiguous");
        let z_s = z.as_slice_mut().expect("z must be contiguous");
        mu_s.par_iter_mut()
            .zip(weights_s.par_iter_mut())
            .zip(z_s.par_iter_mut())
            .enumerate()
            .for_each(|(i, ((mu_o, w_o), z_o))| {
                let eta_i = eta[i].clamp(-700.0, 700.0);
                let mu_i = eta_i.exp().max(MIN_MU);
                *mu_o = mu_i;
                *w_o = priorweights[i].max(0.0) * shape;
                *z_o = eta_i + (y[i] - mu_i) / mu_i;
            });
    }
}
#[inline]
/// Core IRLS state update for standard GLM links: given the current linear
/// predictor `eta`, writes the fitted mean `mu`, working weights, and working
/// response `z` for every observation — plus, when `derivatives` is supplied,
/// the weight-derivative buffers `c`/`d` and dmu/deta up to third order.
///
/// A dedicated fast path handles the plain Bernoulli/logit link (no
/// mixture/SAS state); other binomial-type links go through a generic
/// jet-based path; identity and log links delegate to their specialized
/// writers.
///
/// # Errors
/// Propagates failures from the inverse-link jet evaluation.
pub fn update_glmvectors(
    y: ArrayView1<f64>,
    eta: &Array1<f64>,
    inverse_link: &InverseLink,
    priorweights: ArrayView1<f64>,
    mu: &mut Array1<f64>,
    weights: &mut Array1<f64>,
    z: &mut Array1<f64>,
    derivatives: Option<WorkingDerivativeBuffersMut<'_>>,
) -> Result<(), EstimationError> {
    let link = inverse_link.link_function();
    // Fast path: plain logit with no mixture/SAS state uses the dedicated
    // closed-form logit jet (infallible, so plain for_each instead of try_*).
    if matches!(link, LinkFunction::Logit)
        && inverse_link.mixture_state().is_none()
        && inverse_link.sas_state().is_none()
    {
        if let Some(derivs) = derivatives {
            let mu_s = mu.as_slice_mut().expect("mu must be contiguous");
            let weights_s = weights.as_slice_mut().expect("weights must be contiguous");
            let z_s = z.as_slice_mut().expect("z must be contiguous");
            let c_s = derivs.c.as_slice_mut().expect("c must be contiguous");
            let d_s = derivs.d.as_slice_mut().expect("d must be contiguous");
            let dmu_s = derivs
                .dmu_deta
                .as_slice_mut()
                .expect("dmu_deta must be contiguous");
            let d2_s = derivs
                .d2mu_deta2
                .as_slice_mut()
                .expect("d2mu_deta2 must be contiguous");
            let d3_s = derivs
                .d3mu_deta3
                .as_slice_mut()
                .expect("d3mu_deta3 must be contiguous");
            mu_s.par_iter_mut()
                .zip(weights_s.par_iter_mut())
                .zip(z_s.par_iter_mut())
                .zip(c_s.par_iter_mut())
                .zip(d_s.par_iter_mut())
                .zip(dmu_s.par_iter_mut())
                .zip(d2_s.par_iter_mut())
                .zip(d3_s.par_iter_mut())
                .enumerate()
                .for_each(
                    |(i, (((((((mu_o, w_o), z_o), c_o), d_o), dmu_o), d2_o), d3_o))| {
                        let eta_raw = eta[i];
                        // Clamp keeps exp(eta) finite; the raw value is still
                        // passed along so the geometry can detect clamping.
                        let eta_c = eta_raw.clamp(-700.0, 700.0);
                        let jet = logit_inverse_link_jet5(eta_c);
                        let geom = bernoulli_logit_geometry_from_jet(
                            eta_raw,
                            eta_c,
                            y[i],
                            priorweights[i],
                            jet,
                            LOGIT_ZERO_HIGHER_DERIVATIVES_ON_NONSMOOTH,
                        );
                        *mu_o = geom.mu;
                        *w_o = geom.weight;
                        *z_o = geom.z;
                        *c_o = geom.c;
                        *d_o = geom.d;
                        *dmu_o = jet.d1;
                        *d2_o = jet.d2;
                        *d3_o = jet.d3;
                    },
                );
        } else {
            let mu_s = mu.as_slice_mut().expect("mu must be contiguous");
            let weights_s = weights.as_slice_mut().expect("weights must be contiguous");
            let z_s = z.as_slice_mut().expect("z must be contiguous");
            mu_s.par_iter_mut()
                .zip(weights_s.par_iter_mut())
                .zip(z_s.par_iter_mut())
                .enumerate()
                .for_each(|(i, ((mu_o, w_o), z_o))| {
                    let eta_raw = eta[i];
                    let eta_c = eta_raw.clamp(-700.0, 700.0);
                    let jet = logit_inverse_link_jet5(eta_c);
                    let geom = bernoulli_logit_geometry_from_jet(
                        eta_raw,
                        eta_c,
                        y[i],
                        priorweights[i],
                        jet,
                        LOGIT_ZERO_HIGHER_DERIVATIVES_ON_NONSMOOTH,
                    );
                    *mu_o = geom.mu;
                    *w_o = geom.weight;
                    *z_o = geom.z;
                });
        }
        return Ok(());
    }
    match link {
        // Generic binomial-type path: jet evaluation may fail, hence try_for_each.
        LinkFunction::Logit
        | LinkFunction::Probit
        | LinkFunction::CLogLog
        | LinkFunction::Sas
        | LinkFunction::BetaLogistic => {
            let zero_on_nonsmooth =
                matches!(link, LinkFunction::Logit) && LOGIT_ZERO_HIGHER_DERIVATIVES_ON_NONSMOOTH;
            if let Some(derivs) = derivatives {
                let mu_s = mu.as_slice_mut().expect("mu must be contiguous");
                let weights_s = weights.as_slice_mut().expect("weights must be contiguous");
                let z_s = z.as_slice_mut().expect("z must be contiguous");
                let c_s = derivs.c.as_slice_mut().expect("c must be contiguous");
                let d_s = derivs.d.as_slice_mut().expect("d must be contiguous");
                let dmu_s = derivs
                    .dmu_deta
                    .as_slice_mut()
                    .expect("dmu_deta must be contiguous");
                let d2_s = derivs
                    .d2mu_deta2
                    .as_slice_mut()
                    .expect("d2mu_deta2 must be contiguous");
                let d3_s = derivs
                    .d3mu_deta3
                    .as_slice_mut()
                    .expect("d3mu_deta3 must be contiguous");
                mu_s.par_iter_mut()
                    .zip(weights_s.par_iter_mut())
                    .zip(z_s.par_iter_mut())
                    .zip(c_s.par_iter_mut())
                    .zip(d_s.par_iter_mut())
                    .zip(dmu_s.par_iter_mut())
                    .zip(d2_s.par_iter_mut())
                    .zip(d3_s.par_iter_mut())
                    .enumerate()
                    .try_for_each(
                        |(
                            i,
                            (((((((mu_o, w_o), z_o), c_o), d_o), dmu_o), d2_o), d3_o),
                        )|
                         -> Result<(), EstimationError> {
                            // Per-link eta clamp keeps the link evaluation in a
                            // numerically safe range. The Log/Identity arms are
                            // unreachable inside this outer match arm but keep
                            // the inner match exhaustive.
                            let eta_used = match link {
                                LinkFunction::Logit => eta[i].clamp(-700.0, 700.0),
                                LinkFunction::Probit
                                | LinkFunction::CLogLog
                                | LinkFunction::Sas
                                | LinkFunction::BetaLogistic => eta[i].clamp(-30.0, 30.0),
                                LinkFunction::Log => eta[i].clamp(-700.0, 700.0),
                                LinkFunction::Identity => eta[i],
                            };
                            if matches!(link, LinkFunction::Logit) {
                                let jet = logit_inverse_link_jet5(eta_used);
                                let geom = bernoulli_logit_geometry_from_jet(
                                    eta[i],
                                    eta_used,
                                    y[i],
                                    priorweights[i],
                                    jet,
                                    zero_on_nonsmooth,
                                );
                                *mu_o = geom.mu;
                                *w_o = geom.weight;
                                *z_o = geom.z;
                                *c_o = geom.c;
                                *d_o = geom.d;
                                *dmu_o = jet.d1;
                                *d2_o = jet.d2;
                                *d3_o = jet.d3;
                            } else {
                                let jet = standard_inverse_link_jet(inverse_link, eta_used)?;
                                let geom = bernoulli_geometry_from_jet(
                                    eta[i],
                                    eta_used,
                                    y[i],
                                    priorweights[i],
                                    jet,
                                );
                                *mu_o = geom.mu;
                                *w_o = geom.weight;
                                *z_o = geom.z;
                                *c_o = geom.c;
                                *d_o = geom.d;
                                *dmu_o = jet.d1;
                                *d2_o = jet.d2;
                                *d3_o = jet.d3;
                            }
                            Ok(())
                        },
                    )?;
            } else {
                let mu_s = mu.as_slice_mut().expect("mu must be contiguous");
                let weights_s = weights.as_slice_mut().expect("weights must be contiguous");
                let z_s = z.as_slice_mut().expect("z must be contiguous");
                mu_s.par_iter_mut()
                    .zip(weights_s.par_iter_mut())
                    .zip(z_s.par_iter_mut())
                    .enumerate()
                    .try_for_each(|(i, ((mu_o, w_o), z_o))| -> Result<(), EstimationError> {
                        let eta_used = match link {
                            LinkFunction::Logit => eta[i].clamp(-700.0, 700.0),
                            LinkFunction::Probit
                            | LinkFunction::CLogLog
                            | LinkFunction::Sas
                            | LinkFunction::BetaLogistic => eta[i].clamp(-30.0, 30.0),
                            LinkFunction::Log => eta[i].clamp(-700.0, 700.0),
                            LinkFunction::Identity => eta[i],
                        };
                        if matches!(link, LinkFunction::Logit) {
                            let jet = logit_inverse_link_jet5(eta_used);
                            let geom = bernoulli_logit_geometry_from_jet(
                                eta[i],
                                eta_used,
                                y[i],
                                priorweights[i],
                                jet,
                                zero_on_nonsmooth,
                            );
                            *mu_o = geom.mu;
                            *w_o = geom.weight;
                            *z_o = geom.z;
                        } else {
                            let jet = standard_inverse_link_jet(inverse_link, eta_used)?;
                            let geom = bernoulli_geometry_from_jet(
                                eta[i],
                                eta_used,
                                y[i],
                                priorweights[i],
                                jet,
                            );
                            *mu_o = geom.mu;
                            *w_o = geom.weight;
                            *z_o = geom.z;
                        }
                        Ok(())
                    })?;
            }
            Ok(())
        }
        LinkFunction::Identity => {
            write_identityworking_state(y, eta, priorweights, mu, weights, z, derivatives);
            Ok(())
        }
        LinkFunction::Log => {
            write_poisson_log_working_state(y, eta, priorweights, mu, weights, z, derivatives);
            Ok(())
        }
    }
}
#[inline]
/// Family-dispatch variant of the IRLS vector update: delegates the
/// mean/weight/working-response computation to the likelihood spec, without
/// requesting derivative buffers.
pub fn update_glmvectors_by_family(
    y: ArrayView1<f64>,
    eta: &Array1<f64>,
    likelihood: GlmLikelihoodSpec,
    priorweights: ArrayView1<f64>,
    mu: &mut Array1<f64>,
    weights: &mut Array1<f64>,
    z: &mut Array1<f64>,
) -> Result<(), EstimationError> {
    likelihood.irls_update(y, eta, priorweights, mu, weights, z, None, None)
}
/// Resolves the `InverseLink` to use for integrated updates of a binomial
/// family, validating that families which require extra link state (SAS,
/// beta-logistic, mixture) actually received it.
///
/// # Errors
/// `InvalidInput` when the required state is missing or the family has no
/// integrated link-runtime path.
fn integrated_inverse_link_from_family(
    family: GlmLikelihoodFamily,
    mixture_link_state: Option<&MixtureLinkState>,
    sas_link_state: Option<&SasLinkState>,
) -> Result<InverseLink, EstimationError> {
    match family {
        // Plain binomial links carry no extra state.
        GlmLikelihoodFamily::BinomialLogit
        | GlmLikelihoodFamily::BinomialProbit
        | GlmLikelihoodFamily::BinomialCLogLog => Ok(InverseLink::Standard(family.link_function())),
        GlmLikelihoodFamily::BinomialSas => sas_link_state
            .map(|state| InverseLink::Sas(*state))
            .ok_or_else(|| {
                EstimationError::InvalidInput(
                    "Integrated BinomialSas update requires explicit SasLinkState".to_string(),
                )
            }),
        GlmLikelihoodFamily::BinomialBetaLogistic => sas_link_state
            .map(|state| InverseLink::BetaLogistic(*state))
            .ok_or_else(|| {
                EstimationError::InvalidInput(
                    "Integrated BinomialBetaLogistic update requires explicit SasLinkState"
                        .to_string(),
                )
            }),
        GlmLikelihoodFamily::BinomialMixture => mixture_link_state
            .map(|state| InverseLink::Mixture(state.clone()))
            .ok_or_else(|| {
                EstimationError::InvalidInput(
                    "Integrated BinomialMixture update requires explicit MixtureLinkState"
                        .to_string(),
                )
            }),
        // Non-binomial families have no integrated link-runtime path.
        GlmLikelihoodFamily::GaussianIdentity
        | GlmLikelihoodFamily::PoissonLog
        | GlmLikelihoodFamily::GammaLog => Err(EstimationError::InvalidInput(format!(
            "Integrated link-runtime update is not supported for family {:?}",
            family
        ))),
    }
}
#[inline]
/// IRLS state update for the "integrated" link runtime: the inverse link is
/// averaged over per-observation Gaussian uncertainty `se` via quadrature
/// before the Bernoulli working geometry is assembled. Only binomial-type
/// inverse links (logit, probit, cloglog, latent-cloglog, SAS, beta-logistic,
/// mixture) are supported.
///
/// # Errors
/// `InvalidInput` for unsupported inverse links; otherwise propagates
/// quadrature/jet evaluation failures.
pub fn update_glmvectors_integrated_for_link(
    quadctx: &crate::quadrature::QuadratureContext,
    y: ArrayView1<f64>,
    eta: &Array1<f64>,
    se: ArrayView1<f64>,
    inverse_link: &InverseLink,
    priorweights: ArrayView1<f64>,
    mu: &mut Array1<f64>,
    weights: &mut Array1<f64>,
    z: &mut Array1<f64>,
    derivatives: Option<WorkingDerivativeBuffersMut<'_>>,
) -> Result<(), EstimationError> {
    let link = inverse_link.link_function();
    // Reject links with no integrated path before touching any buffer.
    if !matches!(
        inverse_link,
        InverseLink::Standard(LinkFunction::Logit)
            | InverseLink::Standard(LinkFunction::Probit)
            | InverseLink::Standard(LinkFunction::CLogLog)
            | InverseLink::LatentCLogLog(_)
            | InverseLink::Sas(_)
            | InverseLink::BetaLogistic(_)
            | InverseLink::Mixture(_)
    ) {
        return Err(EstimationError::InvalidInput(format!(
            "Integrated link-runtime update is not supported for inverse link {:?}",
            inverse_link
        )));
    }
    if let Some(derivs) = derivatives {
        // Full update including derivative buffers.
        let mu_s = mu.as_slice_mut().expect("mu must be contiguous");
        let weights_s = weights.as_slice_mut().expect("weights must be contiguous");
        let z_s = z.as_slice_mut().expect("z must be contiguous");
        let c_s = derivs.c.as_slice_mut().expect("c must be contiguous");
        let d_s = derivs.d.as_slice_mut().expect("d must be contiguous");
        let dmu_s = derivs
            .dmu_deta
            .as_slice_mut()
            .expect("dmu_deta must be contiguous");
        let d2_s = derivs
            .d2mu_deta2
            .as_slice_mut()
            .expect("d2mu_deta2 must be contiguous");
        let d3_s = derivs
            .d3mu_deta3
            .as_slice_mut()
            .expect("d3mu_deta3 must be contiguous");
        mu_s.par_iter_mut()
            .zip(weights_s.par_iter_mut())
            .zip(z_s.par_iter_mut())
            .zip(c_s.par_iter_mut())
            .zip(d_s.par_iter_mut())
            .zip(dmu_s.par_iter_mut())
            .zip(d2_s.par_iter_mut())
            .zip(d3_s.par_iter_mut())
            .enumerate()
            .try_for_each(
                |(i, (((((((mu_o, w_o), z_o), c_o), d_o), dmu_o), d2_o), d3_o))|
                -> Result<(), EstimationError> {
                    // Three quadrature backends: latent cloglog folds the
                    // latent SD into the observation SE; plain logit has a
                    // dedicated fast integrator; everything else goes through
                    // the generic stateful integrator.
                    let jet = if let InverseLink::LatentCLogLog(state) = inverse_link {
                        crate::families::lognormal_kernel::latent_cloglog_inverse_link_jet(
                            quadctx,
                            eta[i],
                            se[i].hypot(state.latent_sd),
                        )?
                    } else if matches!(inverse_link, InverseLink::Standard(LinkFunction::Logit)) {
                        crate::quadrature::integrated_logit_inverse_link_jet_pirls(
                            quadctx, eta[i], se[i],
                        )?
                    } else {
                        crate::quadrature::integrated_inverse_link_jetwith_state(
                            quadctx,
                            link,
                            eta[i],
                            se[i],
                            inverse_link.mixture_state(),
                            inverse_link.sas_state(),
                        )?
                    };
                    let local_jet = MixtureInverseLinkJet {
                        mu: jet.mean,
                        d1: jet.d1,
                        d2: jet.d2,
                        d3: jet.d3,
                    };
                    let e = eta[i].clamp(-700.0, 700.0);
                    let geom = bernoulli_geometry_from_jet(
                        eta[i],
                        e,
                        y[i],
                        priorweights[i],
                        local_jet,
                    );
                    *mu_o = geom.mu;
                    *w_o = geom.weight;
                    *z_o = geom.z;
                    *c_o = geom.c;
                    *d_o = geom.d;
                    *dmu_o = local_jet.d1;
                    *d2_o = local_jet.d2;
                    *d3_o = local_jet.d3;
                    Ok(())
                },
            )?;
    } else {
        // Lightweight update: mean, weight, and working response only.
        let mu_s = mu.as_slice_mut().expect("mu must be contiguous");
        let weights_s = weights.as_slice_mut().expect("weights must be contiguous");
        let z_s = z.as_slice_mut().expect("z must be contiguous");
        mu_s.par_iter_mut()
            .zip(weights_s.par_iter_mut())
            .zip(z_s.par_iter_mut())
            .enumerate()
            .try_for_each(|(i, ((mu_o, w_o), z_o))| -> Result<(), EstimationError> {
                let jet = if let InverseLink::LatentCLogLog(state) = inverse_link {
                    crate::families::lognormal_kernel::latent_cloglog_inverse_link_jet(
                        quadctx,
                        eta[i],
                        se[i].hypot(state.latent_sd),
                    )?
                } else if matches!(inverse_link, InverseLink::Standard(LinkFunction::Logit)) {
                    crate::quadrature::integrated_logit_inverse_link_jet_pirls(
                        quadctx, eta[i], se[i],
                    )?
                } else {
                    crate::quadrature::integrated_inverse_link_jetwith_state(
                        quadctx,
                        link,
                        eta[i],
                        se[i],
                        inverse_link.mixture_state(),
                        inverse_link.sas_state(),
                    )?
                };
                let local_jet = MixtureInverseLinkJet {
                    mu: jet.mean,
                    d1: jet.d1,
                    d2: jet.d2,
                    d3: jet.d3,
                };
                let e = eta[i].clamp(-700.0, 700.0);
                let geom = bernoulli_geometry_from_jet(eta[i], e, y[i], priorweights[i], local_jet);
                *mu_o = geom.mu;
                *w_o = geom.weight;
                *z_o = geom.z;
                Ok(())
            })?;
    }
    Ok(())
}
#[inline]
/// Family-level wrapper around the integrated IRLS update: resolves the
/// `InverseLink` from the family (validating required mixture/SAS state) and
/// then delegates to the link-level integrated update.
///
/// # Errors
/// `InvalidInput` when required link state is missing or the family has no
/// integrated path; otherwise propagates the underlying update's errors.
pub fn update_glmvectors_integrated_by_family(
    quadctx: &crate::quadrature::QuadratureContext,
    y: ArrayView1<f64>,
    eta: &Array1<f64>,
    se: ArrayView1<f64>,
    family: GlmLikelihoodFamily,
    priorweights: ArrayView1<f64>,
    mu: &mut Array1<f64>,
    weights: &mut Array1<f64>,
    z: &mut Array1<f64>,
    derivatives: Option<WorkingDerivativeBuffersMut<'_>>,
    mixture_link_state: Option<&MixtureLinkState>,
    sas_link_state: Option<&SasLinkState>,
) -> Result<(), EstimationError> {
    let inverse_link =
        integrated_inverse_link_from_family(family, mixture_link_state, sas_link_state)?;
    update_glmvectors_integrated_for_link(
        quadctx,
        y,
        eta,
        se,
        &inverse_link,
        priorweights,
        mu,
        weights,
        z,
        derivatives,
    )
}
/// Computes the eta-derivatives of the working weights used by the
/// higher-order PIRLS corrections.
///
/// Returns `(c, d, dmu_deta, d2mu_deta2, d3mu_deta3)`:
/// * `c`, `d` — first and second derivatives of the working weight w.r.t. eta
///   (left at zero for families where they are not produced here);
/// * `dmu_deta`, `d2mu_deta2`, `d3mu_deta3` — first three derivatives of the
///   inverse link, evaluated at (possibly clamped) eta.
fn computeworkingweight_derivatives_from_eta(
    likelihood: GlmLikelihoodSpec,
    inverse_link: &InverseLink,
    eta: &Array1<f64>,
    priorweights: ArrayView1<f64>,
) -> Result<
    (
        Array1<f64>,
        Array1<f64>,
        Array1<f64>,
        Array1<f64>,
        Array1<f64>,
    ),
    EstimationError,
> {
    let n = eta.len();
    let mut c = Array1::<f64>::zeros(n);
    let mut d = Array1::<f64>::zeros(n);
    let mut dmu_deta = Array1::<f64>::zeros(n);
    let mut d2mu_deta2 = Array1::<f64>::zeros(n);
    let mut d3mu_deta3 = Array1::<f64>::zeros(n);
    match likelihood.family {
        GlmLikelihoodFamily::GaussianIdentity => {
            // Identity link: dmu/deta = 1 everywhere; weight derivatives stay zero.
            dmu_deta.fill(1.0);
        }
        GlmLikelihoodFamily::PoissonLog => {
            const MIN_WEIGHT: f64 = 1e-12;
            let c_s = c.as_slice_mut().expect("c must be contiguous");
            let d_s = d.as_slice_mut().expect("d must be contiguous");
            let dmu_s = dmu_deta
                .as_slice_mut()
                .expect("dmu_deta must be contiguous");
            let d2_s = d2mu_deta2
                .as_slice_mut()
                .expect("d2mu_deta2 must be contiguous");
            let d3_s = d3mu_deta3
                .as_slice_mut()
                .expect("d3mu_deta3 must be contiguous");
            c_s.par_iter_mut()
                .zip(d_s.par_iter_mut())
                .zip(dmu_s.par_iter_mut())
                .zip(d2_s.par_iter_mut())
                .zip(d3_s.par_iter_mut())
                .enumerate()
                .try_for_each(
                    |(i, ((((c_o, d_o), dmu_o), d2_o), d3_o))| -> Result<(), EstimationError> {
                        let eta_used = eta[i].clamp(-700.0, 700.0);
                        let jet = standard_inverse_link_jet(inverse_link, eta_used)?;
                        // For the log link the weight pw*mu and all its
                        // eta-derivatives coincide (d/deta e^eta = e^eta).
                        let raw_weight = priorweights[i].max(0.0) * jet.mu;
                        let floor_active = raw_weight > 0.0 && raw_weight <= MIN_WEIGHT;
                        // Where eta was clamped or the weight sits at/below the
                        // floor, the weight is effectively held constant, so its
                        // derivatives are zeroed.
                        if eta[i] != eta_used || floor_active {
                            *c_o = 0.0;
                            *d_o = 0.0;
                        } else {
                            *c_o = raw_weight;
                            *d_o = raw_weight;
                        }
                        *dmu_o = jet.d1;
                        *d2_o = jet.d2;
                        *d3_o = jet.d3;
                        Ok(())
                    },
                )?;
        }
        GlmLikelihoodFamily::GammaLog => {
            // Only the link derivatives are computed here; c and d remain zero.
            let dmu_s = dmu_deta
                .as_slice_mut()
                .expect("dmu_deta must be contiguous");
            let d2_s = d2mu_deta2
                .as_slice_mut()
                .expect("d2mu_deta2 must be contiguous");
            let d3_s = d3mu_deta3
                .as_slice_mut()
                .expect("d3mu_deta3 must be contiguous");
            dmu_s
                .par_iter_mut()
                .zip(d2_s.par_iter_mut())
                .zip(d3_s.par_iter_mut())
                .enumerate()
                .try_for_each(
                    |(i, ((dmu_o, d2_o), d3_o))| -> Result<(), EstimationError> {
                        let jet =
                            standard_inverse_link_jet(inverse_link, eta[i].clamp(-700.0, 700.0))?;
                        *dmu_o = jet.d1;
                        *d2_o = jet.d2;
                        *d3_o = jet.d3;
                        Ok(())
                    },
                )?;
        }
        GlmLikelihoodFamily::BinomialLogit
        | GlmLikelihoodFamily::BinomialProbit
        | GlmLikelihoodFamily::BinomialCLogLog
        | GlmLikelihoodFamily::BinomialSas
        | GlmLikelihoodFamily::BinomialBetaLogistic
        | GlmLikelihoodFamily::BinomialMixture => {
            let link = inverse_link.link_function();
            let zero_on_nonsmooth =
                matches!(link, LinkFunction::Logit) && LOGIT_ZERO_HIGHER_DERIVATIVES_ON_NONSMOOTH;
            let c_s = c.as_slice_mut().expect("c must be contiguous");
            let d_s = d.as_slice_mut().expect("d must be contiguous");
            let dmu_s = dmu_deta
                .as_slice_mut()
                .expect("dmu_deta must be contiguous");
            let d2_s = d2mu_deta2
                .as_slice_mut()
                .expect("d2mu_deta2 must be contiguous");
            let d3_s = d3mu_deta3
                .as_slice_mut()
                .expect("d3mu_deta3 must be contiguous");
            c_s.par_iter_mut()
                .zip(d_s.par_iter_mut())
                .zip(dmu_s.par_iter_mut())
                .zip(d2_s.par_iter_mut())
                .zip(d3_s.par_iter_mut())
                .enumerate()
                .try_for_each(
                    |(i, ((((c_o, d_o), dmu_o), d2_o), d3_o))| -> Result<(), EstimationError> {
                        // Link-specific clamp: the logit/log links tolerate a much
                        // wider eta range than probit/cloglog/SAS/beta-logistic.
                        let eta_used = match link {
                            LinkFunction::Logit => eta[i].clamp(-700.0, 700.0),
                            LinkFunction::Probit
                            | LinkFunction::CLogLog
                            | LinkFunction::Sas
                            | LinkFunction::BetaLogistic => eta[i].clamp(-30.0, 30.0),
                            LinkFunction::Log => eta[i].clamp(-700.0, 700.0),
                            LinkFunction::Identity => eta[i],
                        };
                        if matches!(link, LinkFunction::Logit) {
                            // Fast analytic path for the canonical logit link.
                            let jet = logit_inverse_link_jet5(eta_used);
                            let geom = bernoulli_logit_geometry_from_jet(
                                eta[i],
                                eta_used,
                                jet.mu,
                                priorweights[i],
                                jet,
                                zero_on_nonsmooth,
                            );
                            *c_o = geom.c;
                            *d_o = geom.d;
                            *dmu_o = jet.d1;
                            *d2_o = jet.d2;
                            *d3_o = jet.d3;
                        } else {
                            let jet = standard_inverse_link_jet(inverse_link, eta_used)?;
                            let geom = bernoulli_geometry_from_jet(
                                eta[i],
                                eta_used,
                                jet.mu,
                                priorweights[i],
                                jet,
                            );
                            *c_o = geom.c;
                            *d_o = geom.d;
                            *dmu_o = jet.d1;
                            *d2_o = jet.d2;
                            *d3_o = jet.d3;
                        }
                        Ok(())
                    },
                )?;
        }
    }
    Ok((c, d, dmu_deta, d2mu_deta2, d3mu_deta3))
}
/// Variance function V(mu) of an exponential family together with its first
/// four derivatives with respect to mu.
#[derive(Clone, Copy, Debug)]
pub struct VarianceJet {
    /// V(mu).
    pub v: f64,
    /// dV/dmu.
    pub v1: f64,
    /// d²V/dmu².
    pub v2: f64,
    /// d³V/dmu³.
    pub v3: f64,
    /// d⁴V/dmu⁴.
    pub v4: f64,
}
/// Constructors for the per-family variance jets V(mu), V'(mu), ..., V''''(mu).
impl VarianceJet {
    /// Bernoulli: V(mu) = mu(1 - mu); quadratic, so derivatives past the
    /// second vanish.
    #[inline]
    pub fn bernoulli(mu: f64) -> Self {
        let v = mu * (1.0 - mu);
        let v1 = 1.0 - 2.0 * mu;
        Self {
            v,
            v1,
            v2: -2.0,
            v3: 0.0,
            v4: 0.0,
        }
    }
    /// Poisson: V(mu) = mu.
    #[inline]
    pub fn poisson(mu: f64) -> Self {
        Self {
            v: mu,
            v1: 1.0,
            v2: 0.0,
            v3: 0.0,
            v4: 0.0,
        }
    }
    /// Gamma: V(mu) = mu².
    #[inline]
    pub fn gamma(mu: f64) -> Self {
        let v = mu * mu;
        Self {
            v,
            v1: 2.0 * mu,
            v2: 2.0,
            v3: 0.0,
            v4: 0.0,
        }
    }
    /// Gaussian: constant V = 1.
    #[inline]
    pub fn gaussian() -> Self {
        Self {
            v: 1.0,
            v1: 0.0,
            v2: 0.0,
            v3: 0.0,
            v4: 0.0,
        }
    }
    /// Per-trial binomial variance jet; identical to the Bernoulli jet.
    #[inline]
    pub fn binomial_n(mu: f64) -> Self {
        Self::bernoulli(mu)
    }
}
// Floors applied when observed-Hessian weights are substituted into the
// solver: an observed weight is accepted only if it exceeds
// max(fisher_weight * FRAC, ABS_FLOOR); see `solver_hessian_weight_floor`.
const OBSERVED_HESSIAN_WEIGHT_FLOOR_FRAC: f64 = 1e-6;
const OBSERVED_HESSIAN_WEIGHT_ABS_FLOOR: f64 = 1e-12;
/// Fixed dispersion parameter phi for the likelihood; falls back to 1.0 when
/// the spec does not pin one.
#[inline]
fn fixed_glm_dispersion(likelihood: GlmLikelihoodSpec) -> f64 {
    match likelihood.fixed_phi() {
        Some(phi) => phi,
        None => 1.0,
    }
}
#[inline]
pub fn weight_family_for_glm_likelihood(likelihood: GlmLikelihoodSpec) -> WeightFamily {
match likelihood.family {
GlmLikelihoodFamily::GaussianIdentity => WeightFamily::Gaussian,
GlmLikelihoodFamily::PoissonLog => WeightFamily::Poisson,
GlmLikelihoodFamily::GammaLog => WeightFamily::Gamma,
GlmLikelihoodFamily::BinomialLogit
| GlmLikelihoodFamily::BinomialProbit
| GlmLikelihoodFamily::BinomialCLogLog
| GlmLikelihoodFamily::BinomialSas
| GlmLikelihoodFamily::BinomialBetaLogistic
| GlmLikelihoodFamily::BinomialMixture => WeightFamily::Binomial,
}
}
/// Classifies an inverse link into the weight-link taxonomy. Only the three
/// standard links with closed-form weight formulas get dedicated variants;
/// everything else (including all non-standard inverse links) is `Other`.
#[inline]
fn weight_link_for_inverse_link(inverse_link: &InverseLink) -> WeightLink {
    if let InverseLink::Standard(link) = inverse_link {
        // Exhaustive over LinkFunction so new links force a decision here.
        match link {
            LinkFunction::Identity => return WeightLink::Identity,
            LinkFunction::Log => return WeightLink::Log,
            LinkFunction::Logit => return WeightLink::Logit,
            LinkFunction::Probit
            | LinkFunction::CLogLog
            | LinkFunction::Sas
            | LinkFunction::BetaLogistic => return WeightLink::Other,
        }
    }
    WeightLink::Other
}
#[inline]
fn supports_observed_hessian_curvature_for_likelihood(
likelihood: GlmLikelihoodSpec,
_: &InverseLink,
) -> bool {
matches!(
likelihood.family,
GlmLikelihoodFamily::GammaLog
| GlmLikelihoodFamily::BinomialProbit
| GlmLikelihoodFamily::BinomialCLogLog
| GlmLikelihoodFamily::BinomialSas
| GlmLikelihoodFamily::BinomialBetaLogistic
| GlmLikelihoodFamily::BinomialMixture
)
}
/// Clamps eta before evaluating observed-Hessian link jets: identity passes
/// through, logit/log use the wide ±700 range, and all remaining links use the
/// tight ±30 range.
#[inline]
fn eta_for_observed_hessian_jet(inverse_link: &InverseLink, eta: f64) -> f64 {
    match inverse_link {
        InverseLink::Standard(LinkFunction::Identity) => eta,
        InverseLink::Standard(LinkFunction::Logit)
        | InverseLink::Standard(LinkFunction::Log) => eta.clamp(-700.0, 700.0),
        _ => eta.clamp(-30.0, 30.0),
    }
}
/// Minimum admissible observed-Hessian weight for a row: a small fraction of
/// the (non-negative) Fisher weight, never below the absolute floor.
#[inline]
fn solver_hessian_weight_floor(fisher_weight: f64) -> f64 {
    let relative = fisher_weight.max(0.0) * OBSERVED_HESSIAN_WEIGHT_FLOOR_FRAC;
    relative.max(OBSERVED_HESSIAN_WEIGHT_ABS_FLOOR)
}
/// Writes floored solver weights into `out`: each observed Hessian weight is
/// kept when finite and above its Fisher-derived floor, otherwise replaced by
/// the floor. Reallocates `out` only on a length mismatch.
fn solver_hessian_weights_into(
    hessian_weights: &Array1<f64>,
    fisher_weights: &Array1<f64>,
    out: &mut Array1<f64>,
) {
    let n = hessian_weights.len();
    if out.len() != n {
        *out = Array1::<f64>::zeros(n);
    }
    ndarray::Zip::from(out)
        .and(hessian_weights)
        .and(fisher_weights)
        .par_for_each(|dst, &hw, &fw| {
            let floor = solver_hessian_weight_floor(fw);
            *dst = if hw.is_finite() && hw > floor { hw } else { floor };
        });
}
/// Fills `hessian_weights`/`hessian_c`/`hessian_d` with per-observation
/// observed-Hessian curvature terms, reallocating the output buffers only when
/// their length differs from `eta.len()`.
///
/// Errors with `InvalidInput` if any observed weight is not positive finite or
/// any derivative term is non-finite.
fn compute_observed_hessian_curvature_arrays_into(
    likelihood: GlmLikelihoodSpec,
    inverse_link: &InverseLink,
    eta: &Array1<f64>,
    y: ArrayView1<'_, f64>,
    mu: &Array1<f64>,
    dmu_deta: &Array1<f64>,
    d2mu_deta2: &Array1<f64>,
    d3mu_deta3: &Array1<f64>,
    fisher_weights: &Array1<f64>,
    priorweights: ArrayView1<'_, f64>,
    hessian_weights: &mut Array1<f64>,
    hessian_c: &mut Array1<f64>,
    hessian_d: &mut Array1<f64>,
) -> Result<(), EstimationError> {
    // Caller contract: only invoked for families/links with supported
    // observed curvature.
    assert!(supports_observed_hessian_curvature_for_likelihood(
        likelihood,
        inverse_link
    ));
    let n = eta.len();
    // Lazily (re)allocate so callers can reuse buffers across iterations.
    if hessian_weights.len() != n {
        *hessian_weights = Array1::<f64>::zeros(n);
    }
    if hessian_c.len() != n {
        *hessian_c = Array1::<f64>::zeros(n);
    }
    if hessian_d.len() != n {
        *hessian_d = Array1::<f64>::zeros(n);
    }
    let weight_family = weight_family_for_glm_likelihood(likelihood);
    let weight_link = weight_link_for_inverse_link(inverse_link);
    let phi = fixed_glm_dispersion(likelihood);
    hessian_weights
        .as_slice_mut()
        .expect("hessian weights must be contiguous")
        .par_iter_mut()
        .zip(
            hessian_c
                .as_slice_mut()
                .expect("hessian c must be contiguous")
                .par_iter_mut(),
        )
        .zip(
            hessian_d
                .as_slice_mut()
                .expect("hessian d must be contiguous")
                .par_iter_mut(),
        )
        .enumerate()
        .try_for_each(|(i, ((w_out, c_out), d_out))| -> Result<(), EstimationError> {
            let eta_used = eta_for_observed_hessian_jet(inverse_link, eta[i]);
            // Higher-order link derivative used as the h4 term downstream.
            // NOTE(review): the helper is named "pdfthird_derivative"; confirm
            // the intended order against the mixture_link module.
            let h4 = crate::mixture_link::inverse_link_pdfthird_derivative_for_inverse_link(
                inverse_link, eta_used,
            )?;
            let jet = MixtureInverseLinkJet {
                mu: mu[i],
                d1: dmu_deta[i],
                d2: d2mu_deta2[i],
                d3: d3mu_deta3[i],
            };
            let (w_obs, c_obs, d_obs) = observed_weight_dispatch(
                weight_family,
                weight_link,
                eta_used,
                y[i],
                mu[i],
                phi,
                priorweights[i].max(0.0),
                jet,
                h4,
            );
            let fisher_weight = fisher_weights[i].max(0.0);
            if !(w_obs.is_finite() && w_obs > 0.0) {
                return Err(EstimationError::InvalidInput(format!(
                    "observed Hessian curvature is not positive finite at row {i}: observed={w_obs}, fisher={fisher_weight}"
                )));
            }
            if !c_obs.is_finite() || !d_obs.is_finite() {
                return Err(EstimationError::InvalidInput(format!(
                    "observed Hessian curvature derivatives are non-finite at row {i}: c={c_obs}, d={d_obs}"
                )));
            }
            *w_out = w_obs;
            *c_out = c_obs;
            *d_out = d_obs;
            Ok(())
        })
}
/// Allocating convenience wrapper around
/// `compute_observed_hessian_curvature_arrays_into`: returns freshly built
/// `(weights, c, d)` arrays of length `eta.len()`.
fn compute_observed_hessian_curvature_arrays(
    likelihood: GlmLikelihoodSpec,
    inverse_link: &InverseLink,
    eta: &Array1<f64>,
    y: ArrayView1<'_, f64>,
    mu: &Array1<f64>,
    dmu_deta: &Array1<f64>,
    d2mu_deta2: &Array1<f64>,
    d3mu_deta3: &Array1<f64>,
    fisher_weights: &Array1<f64>,
    priorweights: ArrayView1<'_, f64>,
) -> Result<(Array1<f64>, Array1<f64>, Array1<f64>), EstimationError> {
    let len = eta.len();
    let (mut weights, mut c, mut d) = (
        Array1::<f64>::zeros(len),
        Array1::<f64>::zeros(len),
        Array1::<f64>::zeros(len),
    );
    compute_observed_hessian_curvature_arrays_into(
        likelihood,
        inverse_link,
        eta,
        y,
        mu,
        dmu_deta,
        d2mu_deta2,
        d3mu_deta3,
        fisher_weights,
        priorweights,
        &mut weights,
        &mut c,
        &mut d,
    )?;
    Ok((weights, c, d))
}
/// Observed-information working weight and its first two eta-derivatives for a
/// noncanonical link, from the link jet (`h1..h4` = dmu/deta and higher), the
/// variance jet `vj`, dispersion `phi`, and prior weight `pw`.
///
/// Returns `(w_obs, c_obs, d_obs)`, each pre-scaled by `pw`:
/// * `w_f = h1² / (phi V)` is the Fisher weight; `b` is the correction whose
///   product with the residual `y - mu` converts Fisher into observed
///   information: `w_obs = w_f - (y - mu) b`.
/// * `c_obs` and `d_obs` are the first and second eta-derivatives of `w_obs`.
///
/// NOTE(review): the derivative algebra below is dense and taken as shipped;
/// verify against the project's derivation notes before editing.
#[inline]
pub fn observed_weight_noncanonical(
    y: f64,
    mu: f64,
    h1: f64,
    h2: f64,
    h3: f64,
    h4: f64,
    vj: VarianceJet,
    phi: f64,
    pw: f64,
) -> (f64, f64, f64) {
    let VarianceJet {
        v,
        v1,
        v2,
        v3,
        v4: _,
    } = vj;
    let phi_v = phi * v;
    let phi_v2 = phi * v * v;
    let phi_v3 = phi * v * v * v;
    let h1_sq = h1 * h1;
    // Fisher weight w_f = h1² / (phi V).
    let w_f = h1_sq / phi_v;
    // n0..n2: h1² and its eta-derivatives (chain rule with h2 = dh1/deta);
    // vd1/vd2: eta-derivatives of V along the fitted curve (dV/deta = V' h1).
    let n0 = h1_sq; let n1 = 2.0 * h1 * h2; let n2 = 2.0 * (h2 * h2 + h1 * h3); let vd1 = h1 * v1; let vd2 = h2 * v1 + h1_sq * v2;
    // c_f/d_f: first/second eta-derivatives of w_f via the quotient rule.
    let c_f = (n1 * v - n0 * vd1) / phi_v2;
    let numer_cf = n1 * v - n0 * vd1;
    let dnumer_cf = n2 * v - n0 * vd2;
    let d_f = (dnumer_cf * v - 2.0 * numer_cf * vd1) / (phi_v3);
    // b: residual-correction term, with its first two eta-derivatives.
    let b_num = h2 * v - h1_sq * v1;
    let b = b_num / phi_v2;
    let b_eta_num =
        h3 * v * v - 3.0 * h1 * h2 * v * v1 - h1_sq * h1 * v * v2 + 2.0 * h1_sq * h1 * v1 * v1;
    let b_eta = b_eta_num / phi_v3;
    let h1_cu = h1_sq * h1;
    let h1_qu = h1_sq * h1_sq;
    let db_eta_num = h4 * v * v + 2.0 * h3 * v * h1 * v1
        - 3.0 * (h2 * h2 + h1 * h3) * v * v1
        - 3.0 * h1 * h2 * (h1 * v1 * v1 + v * h1 * v2)
        - 3.0 * h1_sq * h2 * v * v2
        - h1_cu * (h1 * v1 * v2 + v * h1 * v3)
        + 6.0 * h1_sq * h2 * v1 * v1
        + 4.0 * h1_qu * v1 * v2;
    let phi_v4 = phi_v3 * v;
    let b_etaeta = (db_eta_num * v - 3.0 * b_eta_num * h1 * v1) / phi_v4;
    let resid = y - mu;
    // Observed = Fisher minus residual-weighted correction (and derivatives).
    let w_obs = w_f - resid * b;
    let c_obs = c_f + h1 * b - resid * b_eta;
    let d_obs = d_f + h2 * b + 2.0 * h1 * b_eta - resid * b_etaeta;
    (pw * w_obs, pw * c_obs, pw * d_obs)
}
/// Third eta-derivative ("e" term) of the observed working weight, assembled
/// from the inverse-link jet `h1..h5` and the variance jet.
///
/// `q = phi * V(mu(eta))` with eta-derivatives `q1..q4` (chain rule), `t_k`
/// are the derivatives of `h1 / q` from the quotient recurrence, and the
/// result combines these with the residual term, scaled by `pw`.
///
/// NOTE(review): the final combination is dense derivative algebra taken as
/// shipped; verify against the derivation notes before modifying.
#[inline]
pub fn e_obs_from_jets(
    y: f64,
    mu: f64,
    h1: f64,
    h2: f64,
    h3: f64,
    h4: f64,
    h5: f64,
    vj: VarianceJet,
    phi: f64,
    pw: f64,
) -> f64 {
    let VarianceJet { v, v1, v2, v3, v4 } = vj;
    let q = phi * v;
    let h1_sq = h1 * h1;
    let h1_cu = h1_sq * h1;
    let h1_qu = h1_sq * h1_sq;
    // q_k = d^k/deta^k [phi * V(mu(eta))] via Faà di Bruno.
    let q1 = phi * v1 * h1;
    let q2 = phi * (v1 * h2 + v2 * h1_sq);
    let q3 = phi * (v1 * h3 + 3.0 * v2 * h1 * h2 + v3 * h1_cu);
    let q4 = phi
        * (v1 * h4 + 4.0 * v2 * h1 * h3 + 3.0 * v2 * h2 * h2 + 6.0 * v3 * h1_sq * h2 + v4 * h1_qu);
    // t_k = d^k/deta^k [h1 / q] by the quotient recurrence.
    let t0 = h1 / q;
    let t1 = (h2 - t0 * q1) / q;
    let t2 = (h3 - 2.0 * t1 * q1 - t0 * q2) / q;
    let t3 = (h4 - 3.0 * t2 * q1 - 3.0 * t1 * q2 - t0 * q3) / q;
    let t4 = (h5 - 4.0 * t3 * q1 - 6.0 * t2 * q2 - 4.0 * t1 * q3 - t0 * q4) / q;
    // Leibniz expansion of the third derivative of the Fisher weight term.
    let w_f3 = h1 * t3 + 3.0 * h2 * t2 + 3.0 * h3 * t1 + h4 * t0;
    let resid = y - mu;
    let e_obs = w_f3 + h3 * t1 + 3.0 * h2 * t2 + 3.0 * h1 * t3 - resid * t4;
    pw * e_obs
}
pub fn compute_noncanonical_observed_weights(
eta: &Array1<f64>,
y: ArrayView1<f64>,
jets: &[MixtureInverseLinkJet],
h4: &[f64],
var_jet_fn: impl Fn(f64) -> VarianceJet + Sync,
phi: f64,
prior_weights: ArrayView1<f64>,
) -> (Array1<f64>, Array1<f64>, Array1<f64>) {
let n = eta.len();
let mut w = Array1::<f64>::zeros(n);
let mut c = Array1::<f64>::zeros(n);
let mut d = Array1::<f64>::zeros(n);
let w_s = w.as_slice_mut().expect("w must be contiguous");
let c_s = c.as_slice_mut().expect("c must be contiguous");
let d_s = d.as_slice_mut().expect("d must be contiguous");
w_s.par_iter_mut()
.zip(c_s.par_iter_mut())
.zip(d_s.par_iter_mut())
.enumerate()
.for_each(|(i, ((w_o, c_o), d_o))| {
let jet = &jets[i];
let vj = var_jet_fn(jet.mu);
let (wi, ci, di) = observed_weight_noncanonical(
y[i],
jet.mu,
jet.d1,
jet.d2,
jet.d3,
h4[i],
vj,
phi,
prior_weights[i],
);
*w_o = wi;
*c_o = ci;
*d_o = di;
});
(w, c, d)
}
/// Closed-form observed-information weight triple for a Gaussian response
/// with log link: `(w, c, d)` are the negative second log-likelihood
/// derivative w.r.t. eta and its next two eta-derivatives, scaled by `pw/phi`.
#[inline]
pub fn observed_weight_gaussian_log(y: f64, mu: f64, phi: f64, pw: f64) -> (f64, f64, f64) {
    let scale = pw / phi;
    (
        scale * mu * (2.0 * mu - y),
        scale * mu * (4.0 * mu - y),
        scale * mu * (8.0 * mu - y),
    )
}
/// Fisher (expected-information) weight triple for a Gaussian response with
/// log link: `w = mu²/phi` with eta-derivatives `2w` and `4w`, scaled by `pw`.
#[inline]
pub fn fisher_weight_gaussian_log(mu: f64, phi: f64, pw: f64) -> (f64, f64, f64) {
    let w = (pw / phi) * (mu * mu);
    // Doubling is exact in binary floating point, so these match the
    // term-by-term products bit-for-bit.
    (w, 2.0 * w, 4.0 * w)
}
/// Closed-form observed-information weight triple for a Gaussian response
/// with inverse link (mu = 1/eta): `w = (3 - 2 eta y)/(phi eta^4)` with its
/// first two eta-derivatives, scaled by `pw`.
#[inline]
pub fn observed_weight_gaussian_inverse(y: f64, eta: f64, phi: f64, pw: f64) -> (f64, f64, f64) {
    let e2 = eta * eta;
    let e4 = e2 * e2;
    let prod = eta * y;
    let scale = pw / phi;
    let w = scale * (3.0 - 2.0 * prod) / e4;
    let c = scale * 6.0 * (prod - 2.0) / (e4 * eta);
    let d = scale * 12.0 * (5.0 - 2.0 * prod) / (e4 * e2);
    (w, c, d)
}
/// Fisher weight triple for a Gaussian response with inverse link:
/// `w = 1/(phi eta^4)`, `c = -4/(phi eta^5)`, `d = 20/(phi eta^6)`, each
/// scaled by the prior weight.
#[inline]
pub fn fisher_weight_gaussian_inverse(eta: f64, phi: f64, pw: f64) -> (f64, f64, f64) {
    let e2 = eta * eta;
    let e4 = e2 * e2;
    let scale = pw / phi;
    (
        scale / e4,
        -4.0 * scale / (e4 * eta),
        20.0 * scale / (e4 * e2),
    )
}
/// Binomial-logit weight triple read directly off the inverse-link jet
/// (canonical link: the same triple serves both observed and Fisher
/// dispatchers), scaled by `pw * n_trials`.
#[inline]
fn observed_weight_binomial_logit_from_jet(
    n_trials: f64,
    jet: MixtureInverseLinkJet,
    pw: f64,
) -> (f64, f64, f64) {
    let s = pw * n_trials;
    (s * jet.d1, s * jet.d2, s * jet.d3)
}
/// Coarse exponential-family taxonomy used by the weight dispatchers; see
/// `weight_family_for_glm_likelihood` for the mapping from likelihood specs.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum WeightFamily {
    Gaussian,
    Binomial,
    Poisson,
    Gamma,
}
/// Link taxonomy for weight dispatch. Only links with dedicated closed-form
/// weight formulas get their own variant; everything else is `Other` and goes
/// through the generic noncanonical path.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum WeightLink {
    Identity,
    Log,
    Logit,
    Inverse,
    Other,
}
/// Selects the variance-function jet matching a weight family, evaluated at
/// the mean `mu` (ignored for the Gaussian family, whose variance is constant).
#[inline]
pub fn variance_jet_for_weight_family(family: WeightFamily, mu: f64) -> VarianceJet {
    match family {
        WeightFamily::Gaussian => VarianceJet::gaussian(),
        WeightFamily::Poisson => VarianceJet::poisson(mu),
        WeightFamily::Gamma => VarianceJet::gamma(mu),
        WeightFamily::Binomial => VarianceJet::binomial_n(mu),
    }
}
/// Routes to a closed-form observed-weight formula where one exists
/// (Gaussian/log, Gaussian/inverse, Binomial/logit) and falls back to the
/// generic noncanonical formula with the family's variance jet otherwise.
pub fn observed_weight_dispatch(
    family: WeightFamily,
    link: WeightLink,
    eta: f64,
    y: f64,
    mu: f64,
    phi: f64,
    prior_weight: f64,
    jet: MixtureInverseLinkJet,
    h4: f64,
) -> (f64, f64, f64) {
    match (family, link) {
        (WeightFamily::Gaussian, WeightLink::Log) => {
            observed_weight_gaussian_log(y, mu, phi, prior_weight)
        }
        (WeightFamily::Gaussian, WeightLink::Inverse) => {
            observed_weight_gaussian_inverse(y, eta, phi, prior_weight)
        }
        (WeightFamily::Binomial, WeightLink::Logit) => {
            // Canonical link: the jet's derivative triple is the weight triple
            // (the Fisher dispatcher routes here too).
            observed_weight_binomial_logit_from_jet(1.0, jet, prior_weight)
        }
        _ => observed_weight_noncanonical(
            y,
            mu,
            jet.d1,
            jet.d2,
            jet.d3,
            h4,
            variance_jet_for_weight_family(family, mu),
            phi,
            prior_weight,
        ),
    }
}
/// Fisher-information analogue of `observed_weight_dispatch`: identical
/// routing, but the generic fallback substitutes `y = mu` so the residual
/// correction vanishes and the noncanonical formula collapses to the
/// expected-information weight.
pub fn fisher_weight_dispatch(
    family: WeightFamily,
    link: WeightLink,
    eta: f64,
    mu: f64,
    phi: f64,
    prior_weight: f64,
    jet: MixtureInverseLinkJet,
) -> (f64, f64, f64) {
    match (family, link) {
        (WeightFamily::Gaussian, WeightLink::Log) => {
            fisher_weight_gaussian_log(mu, phi, prior_weight)
        }
        (WeightFamily::Gaussian, WeightLink::Inverse) => {
            fisher_weight_gaussian_inverse(eta, phi, prior_weight)
        }
        (WeightFamily::Binomial, WeightLink::Logit) => {
            observed_weight_binomial_logit_from_jet(1.0, jet, prior_weight)
        }
        // With y = mu the residual is zero, so h4 = 0.0 is harmless: it only
        // enters terms multiplied by the residual.
        _ => observed_weight_noncanonical(
            mu,
            mu,
            jet.d1,
            jet.d2,
            jet.d3,
            0.0,
            variance_jet_for_weight_family(family, mu),
            phi,
            prior_weight,
        ),
    }
}
/// Computes observed-information weight arrays `(w, c, d)` for all
/// observations, dispatching on `(family, link)`.
///
/// Closed-form pairs (Gaussian/log, Gaussian/inverse, Binomial/logit) go
/// through the scalar `observed_weight_dispatch` in parallel; everything else
/// uses the generic noncanonical path with the family's variance jet.
///
/// The previous implementation routed the generic path through a
/// `fn() -> VarianceJet` pointer that could only ever be assigned in the
/// Gaussian arm (every other arm returned early), then wrapped it in a closure
/// ignoring `mu`. The dispatch below is direct and behaviorally identical,
/// including the Gaussian path's finite-mean assertion.
pub fn compute_observed_weights_dispatched(
    family: WeightFamily,
    link: WeightLink,
    eta: &Array1<f64>,
    y: ArrayView1<f64>,
    jets: &[MixtureInverseLinkJet],
    h4: &[f64],
    phi: f64,
    prior_weights: ArrayView1<f64>,
) -> (Array1<f64>, Array1<f64>, Array1<f64>) {
    let n = eta.len();
    match (family, link) {
        (WeightFamily::Gaussian, WeightLink::Log)
        | (WeightFamily::Gaussian, WeightLink::Inverse)
        | (WeightFamily::Binomial, WeightLink::Logit) => {
            let mut w = Array1::<f64>::zeros(n);
            let mut c = Array1::<f64>::zeros(n);
            let mut d = Array1::<f64>::zeros(n);
            ndarray::Zip::indexed(&mut w)
                .and(&mut c)
                .and(&mut d)
                .par_for_each(|i, wo, co, doo| {
                    let (wi, ci, di) = observed_weight_dispatch(
                        family,
                        link,
                        eta[i],
                        y[i],
                        jets[i].mu,
                        phi,
                        prior_weights[i],
                        // NOTE(review): clone kept in case the jet type is not
                        // Copy; drop it if `MixtureInverseLinkJet: Copy`.
                        jets[i].clone(),
                        h4[i],
                    );
                    *wo = wi;
                    *co = ci;
                    *doo = di;
                });
            (w, c, d)
        }
        _ => match family {
            // Gaussian keeps the original finite-mean assertion on the
            // generic path.
            WeightFamily::Gaussian => compute_noncanonical_observed_weights(
                eta,
                y,
                jets,
                h4,
                |mu| {
                    assert!(mu.is_finite());
                    VarianceJet::gaussian()
                },
                phi,
                prior_weights,
            ),
            WeightFamily::Binomial => compute_noncanonical_observed_weights(
                eta,
                y,
                jets,
                h4,
                VarianceJet::binomial_n,
                phi,
                prior_weights,
            ),
            WeightFamily::Poisson => compute_noncanonical_observed_weights(
                eta,
                y,
                jets,
                h4,
                VarianceJet::poisson,
                phi,
                prior_weights,
            ),
            WeightFamily::Gamma => compute_noncanonical_observed_weights(
                eta,
                y,
                jets,
                h4,
                VarianceJet::gamma,
                phi,
                prior_weights,
            ),
        },
    }
}
/// Curvature of the working weights along a search direction. Currently only
/// a diagonal (one entry per observation) representation exists.
#[derive(Clone)]
pub enum DirectionalWorkingCurvature {
    /// Per-observation entries `c_i * (delta eta)_i`.
    Diagonal(Array1<f64>),
}
/// Builds the diagonal directional curvature `c ⊙ Δη`, zeroing any entry whose
/// Hessian weight is non-positive or whose product is non-finite.
pub fn directionalworking_curvature_from_c_array(
    c_array: &Array1<f64>,
    hessian_weights: &Array1<f64>,
    eta_direction: &Array1<f64>,
) -> DirectionalWorkingCurvature {
    let mut scaled = c_array * eta_direction;
    for (i, entry) in scaled.iter_mut().enumerate() {
        // Condition kept exactly as before (a NaN weight does NOT zero the
        // entry on its own; only a non-finite product does).
        if hessian_weights[i] <= 0.0 || !entry.is_finite() {
            *entry = 0.0;
        }
    }
    DirectionalWorkingCurvature::Diagonal(scaled)
}
/// Total deviance of fitted means `mu` against responses `y` under the given
/// likelihood, weighted by the prior weights.
#[inline]
pub fn calculate_deviance(
    y: ArrayView1<f64>,
    mu: &Array1<f64>,
    likelihood: GlmLikelihoodSpec,
    priorweights: ArrayView1<f64>,
) -> f64 {
    const EPS: f64 = 1e-8;
    match likelihood.family {
        // Binomial deviance 2 Σ w [ y ln(y/mu) + (1-y) ln((1-y)/(1-mu)) ],
        // with the y→0 / y→1 limits handled explicitly (0·ln 0 := 0).
        GlmLikelihoodFamily::BinomialLogit
        | GlmLikelihoodFamily::BinomialProbit
        | GlmLikelihoodFamily::BinomialCLogLog
        | GlmLikelihoodFamily::BinomialSas
        | GlmLikelihoodFamily::BinomialBetaLogistic
        | GlmLikelihoodFamily::BinomialMixture => {
            use rayon::iter::{IntoParallelIterator, ParallelIterator};
            let total_residual: f64 = (0..y.len())
                .into_par_iter()
                .map(|i| {
                    let yi = y[i];
                    let mui_c = mu[i];
                    let wi = priorweights[i];
                    let term1 = if yi > EPS {
                        yi * (yi.ln() - mui_c.ln())
                    } else {
                        0.0
                    };
                    let term2 = if yi < 1.0 - EPS {
                        (1.0 - yi) * ((1.0 - yi).ln() - (1.0 - mui_c).ln())
                    } else {
                        0.0
                    };
                    wi * (term1 + term2)
                })
                .sum();
            2.0 * total_residual
        }
        // Gaussian/identity deviance is the weighted residual sum of squares.
        GlmLikelihoodFamily::GaussianIdentity => ndarray::Zip::from(y)
            .and(mu)
            .and(priorweights)
            .map_collect(|&yi, &mui, &wi| wi * (yi - mui) * (yi - mui))
            .sum(),
        // Poisson deviance 2 Σ w [ y ln(y/mu) - (y - mu) ]; the y→0 limit of
        // the bracket is mu.
        GlmLikelihoodFamily::PoissonLog => {
            use rayon::iter::{IntoParallelIterator, ParallelIterator};
            let total: f64 = (0..y.len())
                .into_par_iter()
                .map(|i| {
                    let yi = y[i];
                    let mui_c = mu[i].max(EPS);
                    let term = if yi > EPS {
                        yi * (yi / mui_c).ln() - (yi - mui_c)
                    } else {
                        mui_c
                    };
                    priorweights[i] * term
                })
                .sum();
            2.0 * total
        }
        // Gamma deviance scaled by shape: 2·shape Σ w [ y/mu - 1 - ln(y/mu) ],
        // with both y and mu floored away from zero.
        GlmLikelihoodFamily::GammaLog => {
            let shape = likelihood.gamma_shape().unwrap_or(1.0);
            use rayon::iter::{IntoParallelIterator, ParallelIterator};
            let total: f64 = (0..y.len())
                .into_par_iter()
                .map(|i| {
                    let yi_c = y[i].max(EPS);
                    let mui_c = mu[i].max(EPS);
                    let ratio = yi_c / mui_c;
                    priorweights[i] * shape * (ratio - 1.0 - ratio.ln())
                })
                .sum();
            2.0 * total
        }
    }
}
/// Log-likelihood up to additive constants in `y` — sufficient for comparing
/// fits at fixed data and prior weights.
#[inline]
pub(crate) fn calculate_loglikelihood_omitting_constants(
    y: ArrayView1<f64>,
    mu: &Array1<f64>,
    likelihood: GlmLikelihoodSpec,
    priorweights: ArrayView1<f64>,
) -> f64 {
    const EPS: f64 = 1e-8;
    use rayon::iter::{IntoParallelIterator, ParallelIterator};
    let n = y.len();
    match likelihood.family {
        // Gaussian/identity kernel: -w (y - mu)² / 2.
        GlmLikelihoodFamily::GaussianIdentity => (0..n)
            .into_par_iter()
            .map(|i| {
                let resid = y[i] - mu[i];
                -0.5 * priorweights[i] * resid * resid
            })
            .sum(),
        // All Bernoulli-type families share the binomial kernel; mu is clamped
        // away from {0, 1} to keep the logs finite.
        GlmLikelihoodFamily::BinomialLogit
        | GlmLikelihoodFamily::BinomialProbit
        | GlmLikelihoodFamily::BinomialCLogLog
        | GlmLikelihoodFamily::BinomialSas
        | GlmLikelihoodFamily::BinomialBetaLogistic
        | GlmLikelihoodFamily::BinomialMixture => (0..n)
            .into_par_iter()
            .map(|i| {
                let mui_c = mu[i].clamp(EPS, 1.0 - EPS);
                priorweights[i] * (y[i] * mui_c.ln() + (1.0 - y[i]) * (1.0 - mui_c).ln())
            })
            .sum(),
        // Poisson kernel y ln(mu) - mu (the ln(y!) constant is omitted).
        GlmLikelihoodFamily::PoissonLog => (0..n)
            .into_par_iter()
            .map(|i| {
                let mui_c = mu[i].max(EPS);
                let log_term = if y[i] > 0.0 { y[i] * mui_c.ln() } else { 0.0 };
                priorweights[i] * (log_term - mui_c)
            })
            .sum(),
        // Gamma kernel delegated; shape defaults to 1.0 when unspecified.
        GlmLikelihoodFamily::GammaLog => gamma_loglikelihood_with_shape(
            y,
            mu,
            priorweights,
            likelihood.gamma_shape().unwrap_or(1.0),
        ),
    }
}
/// Result of a stable penalized least squares solve.
#[derive(Clone)]
pub struct StablePLSResult {
    /// Fitted coefficient vector.
    pub beta: Coefficients,
    /// Penalized Hessian at the solution.
    pub penalized_hessian: SymmetricMatrix,
    /// Effective degrees of freedom of the fit.
    pub edf: f64,
    /// Scale estimate reported with the fit — presumably the residual standard
    /// deviation; confirm against the producing solver.
    pub standard_deviation: f64,
    /// Ridge value actually applied during the solve.
    pub ridge_used: f64,
}
/// Effective degrees of freedom from the penalized Hessian H and the dense
/// penalty square-root factor E (r x p): edf = p - tr(E H⁻¹ Eᵀ), clamped to
/// [max(p - r, 0), p]. Any factorization/solve failure or non-finite solution
/// is reported as an ill-conditioned model.
///
/// NOTE(review): `p - r` is a usize subtraction and would panic in debug
/// builds if r > p; r <= p is assumed here.
fn calculate_edf(
    penalized_hessian: &SymmetricMatrix,
    e_transformed: &Array2<f64>,
) -> Result<f64, EstimationError> {
    let p = penalized_hessian.ncols();
    let r = e_transformed.nrows();
    let mp = ((p - r) as f64).max(0.0);
    if r == 0 {
        // No penalty rows: every coefficient is a free degree of freedom.
        return Ok(p as f64);
    }
    let rhs_arr = e_transformed.t().to_owned();
    let factor =
        penalized_hessian
            .factorize()
            .map_err(|_| EstimationError::ModelIsIllConditioned {
                condition_number: f64::INFINITY,
            })?;
    let sol = factor
        .solvemulti(&rhs_arr)
        .map_err(|_| EstimationError::ModelIsIllConditioned {
            condition_number: f64::INFINITY,
        })?;
    // Accept the solve only if it has the expected shape and is fully finite.
    if sol.nrows() == p && sol.ncols() == r && sol.iter().all(|v| v.is_finite()) {
        return Ok(edf_from_solution(p, r, mp, e_transformed, |i, j| {
            sol[[i, j]]
        }));
    }
    Err(EstimationError::ModelIsIllConditioned {
        condition_number: f64::INFINITY,
    })
}
/// Dispatches the EDF computation on the penalty representation (dense
/// square-root factor vs. diagonal).
fn calculate_edf_with_penalty(
    penalized_hessian: &SymmetricMatrix,
    penalty: &PirlsPenalty,
) -> Result<f64, EstimationError> {
    match penalty {
        PirlsPenalty::Diagonal {
            diag,
            positive_indices,
            ..
        } => calculate_edf_from_diagonal_penalty(penalized_hessian, diag, positive_indices),
        PirlsPenalty::Dense { e_transformed, .. } => {
            calculate_edf(penalized_hessian, e_transformed)
        }
    }
}
/// Workspace-reusing variant of `calculate_edf`: stages Eᵀ into
/// `workspace.final_aug_matrix`, solves H X = Eᵀ in place, and reads the trace
/// off the solution.
///
/// NOTE(review): `p - r` is a usize subtraction; r <= p is assumed here.
fn calculate_edfwithworkspace(
    penalized_hessian: &Array2<f64>,
    e_transformed: &Array2<f64>,
    workspace: &mut PirlsWorkspace,
) -> Result<f64, EstimationError> {
    let p = penalized_hessian.ncols();
    let r = e_transformed.nrows();
    let mp = ((p - r) as f64).max(0.0);
    if r == 0 {
        return Ok(p as f64);
    }
    // Reallocate the staging buffer only on a shape mismatch.
    if workspace.final_aug_matrix.nrows() != p || workspace.final_aug_matrix.ncols() != r {
        workspace.final_aug_matrix = Array2::zeros((p, r));
    }
    // Manual transpose copy: final_aug_matrix = Eᵀ.
    for j in 0..r {
        for i in 0..p {
            workspace.final_aug_matrix[[i, j]] = e_transformed[[j, i]];
        }
    }
    let factor = StableSolver::new("pirls edf workspace")
        .factorize(penalized_hessian)
        .map_err(|_| EstimationError::ModelIsIllConditioned {
            condition_number: f64::INFINITY,
        })?;
    {
        // Overwrites final_aug_matrix with H⁻¹ Eᵀ.
        let mut rhsview = array2_to_matmut(&mut workspace.final_aug_matrix);
        factor.solve_in_place(rhsview.as_mut());
    }
    if workspace.final_aug_matrix.nrows() == p
        && workspace.final_aug_matrix.ncols() == r
        && array2_is_finite(&workspace.final_aug_matrix)
    {
        return Ok(edf_from_solution(p, r, mp, e_transformed, |i, j| {
            workspace.final_aug_matrix[(i, j)]
        }));
    }
    Err(EstimationError::ModelIsIllConditioned {
        condition_number: f64::INFINITY,
    })
}
/// Workspace-based counterpart of `calculate_edf_with_penalty`: dispatches on
/// the penalty representation while reusing the PIRLS workspace buffers.
fn calculate_edfwithworkspace_with_penalty(
    penalized_hessian: &Array2<f64>,
    penalty: &PirlsPenalty,
    workspace: &mut PirlsWorkspace,
) -> Result<f64, EstimationError> {
    match penalty {
        PirlsPenalty::Diagonal {
            diag,
            positive_indices,
            ..
        } => calculate_edfwithworkspace_from_diagonal_penalty(
            penalized_hessian,
            diag,
            positive_indices,
            workspace,
        ),
        PirlsPenalty::Dense { e_transformed, .. } => {
            calculate_edfwithworkspace(penalized_hessian, e_transformed, workspace)
        }
    }
}
/// EDF for a diagonal penalty: p - Σₖ diag[k]·(H⁻¹)[k,k] over the strictly
/// positive penalty positions, clamped to [max(p - r, 0), p].
fn calculate_edf_from_diagonal_penalty(
    penalized_hessian: &SymmetricMatrix,
    diag: &Array1<f64>,
    positive_indices: &[usize],
) -> Result<f64, EstimationError> {
    let p = penalized_hessian.ncols();
    let r = positive_indices.len();
    let mp = ((p - r) as f64).max(0.0);
    if r == 0 {
        return Ok(p as f64);
    }
    // RHS: one indicator column per positive penalty index, so the solve
    // yields the corresponding columns of H⁻¹.
    let mut rhs_arr = Array2::<f64>::zeros((p, r));
    for (col, &idx) in positive_indices.iter().enumerate() {
        rhs_arr[[idx, col]] = 1.0;
    }
    let factor =
        penalized_hessian
            .factorize()
            .map_err(|_| EstimationError::ModelIsIllConditioned {
                condition_number: f64::INFINITY,
            })?;
    let sol = factor
        .solvemulti(&rhs_arr)
        .map_err(|_| EstimationError::ModelIsIllConditioned {
            condition_number: f64::INFINITY,
        })?;
    // tr(S H⁻¹) restricted to the positive diagonal penalty entries.
    let mut tr = 0.0;
    for (col, &idx) in positive_indices.iter().enumerate() {
        tr += diag[idx] * sol[[idx, col]];
    }
    Ok((p as f64 - tr).clamp(mp, p as f64))
}
/// Workspace-reusing diagonal-penalty EDF: stages indicator columns for the
/// positive penalty indices, solves in place, then accumulates
/// Σₖ diag[k]·(H⁻¹)[k,k]. Result clamped to [max(p - r, 0), p].
fn calculate_edfwithworkspace_from_diagonal_penalty(
    penalized_hessian: &Array2<f64>,
    diag: &Array1<f64>,
    positive_indices: &[usize],
    workspace: &mut PirlsWorkspace,
) -> Result<f64, EstimationError> {
    let p = penalized_hessian.ncols();
    let r = positive_indices.len();
    let mp = ((p - r) as f64).max(0.0);
    if r == 0 {
        return Ok(p as f64);
    }
    // Reuse the workspace buffer when it already has the right shape
    // (zeroing it first); otherwise reallocate.
    if workspace.final_aug_matrix.nrows() != p || workspace.final_aug_matrix.ncols() != r {
        workspace.final_aug_matrix = Array2::zeros((p, r));
    } else {
        workspace.final_aug_matrix.fill(0.0);
    }
    for (col, &idx) in positive_indices.iter().enumerate() {
        workspace.final_aug_matrix[[idx, col]] = 1.0;
    }
    let factor = StableSolver::new("pirls diagonal edf workspace")
        .factorize(penalized_hessian)
        .map_err(|_| EstimationError::ModelIsIllConditioned {
            condition_number: f64::INFINITY,
        })?;
    {
        // Overwrites the indicator columns with the matching columns of H⁻¹.
        let mut rhsview = array2_to_matmut(&mut workspace.final_aug_matrix);
        factor.solve_in_place(rhsview.as_mut());
    }
    let mut tr = 0.0;
    for (col, &idx) in positive_indices.iter().enumerate() {
        tr += diag[idx] * workspace.final_aug_matrix[[idx, col]];
    }
    Ok((p as f64 - tr).clamp(mp, p as f64))
}
/// Folds a solved system X = H⁻¹ Eᵀ (accessed through `solved_at`) against E
/// to produce the EDF: p - tr(E H⁻¹ Eᵀ), clamped to [mp, p].
#[inline]
fn edf_from_solution<F>(
    p: usize,
    r: usize,
    mp: f64,
    e_transformed: &Array2<f64>,
    solved_at: F,
) -> f64
where
    F: Fn(usize, usize) -> f64,
{
    // Accumulate in (column-outer, row-inner) order so the floating-point sum
    // matches the previous nested-loop formulation exactly.
    let trace: f64 = (0..r)
        .flat_map(|j| (0..p).map(move |i| (i, j)))
        .map(|(i, j)| solved_at(i, j) * e_transformed[(j, i)])
        .sum();
    (p as f64 - trace).clamp(mp, p as f64)
}
#[cfg(test)]
mod tests {
use super::{
LinearInequalityConstraints, PenaltyConfig, PirlsConfig, PirlsLinearSolvePath,
PirlsProblem, PirlsWorkspace, bernoulli_geometry_from_jet, calculate_deviance,
compute_constraint_kkt_diagnostics, compute_observed_hessian_curvature_arrays,
default_beta_guess_external, fit_model_for_fixed_rho, madsen_lm_accept_factor,
select_active_set_release, should_log_pirls_decision_summary,
should_use_sparse_native_pirls, solve_newton_directionwith_linear_constraints,
solve_newton_directionwith_lower_bounds, update_glmvectors,
};
use crate::matrix::DesignMatrix;
use crate::mixture_link::InverseLinkJet as MixtureInverseLinkJet;
use crate::probability::standard_normal_quantile;
use crate::solver::active_set;
use crate::types::{
Coefficients, GlmLikelihoodFamily, GlmLikelihoodSpec, InverseLink, LinkFunction,
LogSmoothingParamsView,
};
use approx::assert_relative_eq;
use faer::sparse::{SparseColMat, Triplet};
use ndarray::{Array1, Array2, ArrayView1, ArrayView2, array};
/// Local helper for the scale tests below: returns 1.0 for all non-identity
/// links, and the weighted residual standard deviation (offset included,
/// EDF-corrected denominator) for the identity link.
fn calculate_scale(
    beta: &Array1<f64>,
    x: ArrayView2<f64>,
    y: ArrayView1<f64>,
    weights: ArrayView1<f64>,
    offset: ArrayView1<f64>,
    edf: f64,
    link_function: LinkFunction,
) -> f64 {
    match link_function {
        LinkFunction::Logit
        | LinkFunction::Probit
        | LinkFunction::CLogLog
        | LinkFunction::Sas
        | LinkFunction::BetaLogistic
        | LinkFunction::Log => 1.0,
        LinkFunction::Identity => {
            let mut fitted = x.dot(beta);
            fitted += &offset;
            let residuals = &y - &fitted;
            let weighted_rss: f64 = weights
                .iter()
                .zip(residuals.iter())
                .map(|(&w, &r)| w * r * r)
                .sum();
            let effective_n = y.len() as f64;
            // Denominator floored at 1 to avoid blow-up when edf ≈ n.
            (weighted_rss / (effective_n - edf).max(1.0)).sqrt()
        }
    }
}
/// Each LM rejection multiplies lambda by a factor that itself doubles, so six
/// consecutive rejections produce the trajectory 2, 8, 64, ....
#[test]
fn madsen_lm_reject_trajectory_doubles_per_rejection() {
    let mut lambda = 1.0_f64;
    let mut factor = 2.0_f64;
    let mut trajectory = Vec::with_capacity(6);
    for _ in 0..6 {
        lambda *= factor;
        factor *= 2.0;
        trajectory.push(lambda);
    }
    assert_eq!(
        trajectory,
        vec![2.0, 8.0, 64.0, 1024.0, 32_768.0, 2_097_152.0],
        "Madsen rejection trajectory must double the multiplier each time"
    );
}
/// Pins `madsen_lm_accept_factor` to the canonical gain-ratio update values,
/// including the 1/3 floor and 2.0 cap at the extremes.
#[test]
fn madsen_lm_accept_factor_matches_canonical_textbook_values() {
    let cases: &[(f64, f64, &str)] = &[
        (1.0, 1.0 / 3.0, "rho=1: floored at 1/3 (cube=1, 1-cube=0)"),
        (0.75, 0.875, "rho=0.75: 1 - (0.5)^3 = 0.875 (slight shrink)"),
        (0.5, 1.0, "rho=0.5: 1 - 0 = 1.0 (no change)"),
        (
            0.25,
            1.125,
            "rho=0.25: 1 - (-0.5)^3 = 1.125 (slight expand)",
        ),
    ];
    for (rho, expected, why) in cases {
        let got = madsen_lm_accept_factor(*rho);
        assert!(
            (got - expected).abs() < 1e-12,
            "madsen_lm_accept_factor({rho}) = {got:.6}, expected {expected:.6} — {why}"
        );
    }
    // Limits: cap at 2.0 for rho near/below zero, floor at 1/3 for large rho.
    let small_positive = madsen_lm_accept_factor(1e-9);
    assert!(
        (small_positive - 2.0).abs() < 1e-6,
        "rho ≈ 0⁺ must approach the 2.0 cap; got {small_positive:.6}"
    );
    assert_eq!(madsen_lm_accept_factor(-100.0), 2.0);
    assert_eq!(madsen_lm_accept_factor(100.0), 1.0 / 3.0);
    assert!(madsen_lm_accept_factor(0.99).is_finite());
    assert!(madsen_lm_accept_factor(0.01) <= 2.0 + 1e-15);
    assert!(madsen_lm_accept_factor(0.99) >= 1.0 / 3.0 - 1e-15);
}
/// With y constructed as exactly X·beta + offset, the identity-link scale must
/// be numerically zero — proving the offset participates in the residuals.
#[test]
fn gaussian_scale_uses_offset_in_residuals() {
    let x = array![[1.0], [2.0], [3.0]];
    let beta = array![2.0];
    let offset = array![10.0, 20.0, 30.0];
    let y = array![12.0, 24.0, 36.0];
    let w = Array1::ones(3);
    let scale = calculate_scale(
        &beta,
        x.view(),
        y.view(),
        w.view(),
        offset.view(),
        0.0,
        LinkFunction::Identity,
    );
    assert!(
        scale.abs() < 1e-12,
        "scale must be ~0 for exact fit with offset; got {}",
        scale
    );
}
/// Cross-checks the identity-link scale against a manual
/// sqrt(weighted RSS / max(n - edf, 1)) computation with a nonzero offset and
/// non-unit weights.
#[test]
fn gaussian_scale_matchesweighted_sdwith_offset() {
    let x = array![[1.0], [2.0], [4.0]];
    let beta = array![1.5];
    let offset = array![0.5, -1.0, 2.0];
    let y = array![2.2, 2.0, 7.5];
    let w = array![1.0, 2.0, 0.5];
    let edf = 1.25;
    let scale = calculate_scale(
        &beta,
        x.view(),
        y.view(),
        w.view(),
        offset.view(),
        edf,
        LinkFunction::Identity,
    );
    // Recompute the expected value independently.
    let mut fitted = x.dot(&beta);
    fitted += &offset;
    let rss: f64 = w
        .iter()
        .zip(y.iter().zip(fitted.iter()))
        .map(|(&wi, (&yi, &fi))| wi * (yi - fi).powi(2))
        .sum();
    let expected = (rss / ((y.len() as f64 - edf).max(1.0))).sqrt();
    assert!(
        (scale - expected).abs() < 1e-12,
        "scale mismatch: got {}, expected {}",
        scale,
        expected
    );
}
/// A strictly feasible point with zero gradient must report all four KKT
/// residuals (primal, dual, complementarity, stationarity) as numerically zero.
#[test]
fn kkt_diagnosticszero_for_strictly_feasible_stationary_point() {
    let constraints = LinearInequalityConstraints {
        a: array![[1.0, 0.0], [0.0, 1.0]],
        b: array![0.0, 0.0],
    };
    let beta = array![1.0, 2.0];
    let grad = array![0.0, 0.0];
    let diag = compute_constraint_kkt_diagnostics(&beta, &grad, &constraints);
    assert!(diag.primal_feasibility <= 1e-12);
    assert!(diag.dual_feasibility <= 1e-12);
    assert!(diag.complementarity <= 1e-12);
    assert!(diag.stationarity <= 1e-12);
}
/// An active lower bound (beta[0] = 0) with a gradient pushing into it must be
/// reported as exactly one active constraint with all KKT conditions met.
#[test]
fn kkt_diagnostics_capture_active_lower_bound_solution() {
    let constraints = LinearInequalityConstraints {
        a: array![[1.0, 0.0], [0.0, 1.0]],
        b: array![0.0, 0.0],
    };
    let beta = array![0.0, 1.5];
    let grad = array![2.0, 0.0];
    let diag = compute_constraint_kkt_diagnostics(&beta, &grad, &constraints);
    assert_eq!(diag.n_constraints, 2);
    assert_eq!(diag.n_active, 1);
    assert!(diag.primal_feasibility <= 1e-12);
    assert!(diag.dual_feasibility <= 1e-12);
    assert!(diag.complementarity <= 1e-12);
    assert!(diag.stationarity <= 1e-10);
}
/// The active-set solver must release a constraint whose KKT multiplier turns
/// positive: starting on the lower bound with a descent direction away from
/// it, the step should travel to the opposite (upper) bound at 0.1.
#[test]
fn linear_constraint_active_set_releases_positive_kkt_systemmultiplier() {
    let hessian = array![[1.0]];
    let gradient = array![-1.0];
    let beta = array![0.0];
    let constraints = LinearInequalityConstraints {
        a: array![[1.0], [-1.0]],
        b: array![0.0, -0.1],
    };
    let mut direction = Array1::zeros(1);
    solve_newton_directionwith_linear_constraints(
        &hessian,
        &gradient,
        &beta,
        &constraints,
        &mut direction,
        None,
    )
    .expect("constrained Newton direction should solve");
    assert!(
        (direction[0] - 0.1).abs() <= 1e-10,
        "expected step to upper bound (0.1), got {}",
        direction[0]
    );
}
// A constraint row that is numerically tangential to the step (coefficient
// ~1e-16) and far from its bound must not block the unconstrained Newton step.
// (Renamed: original name had garbled snake_case `inactiverows`.)
#[test]
fn linear_constraint_active_set_ignores_near_tangential_inactive_rows() {
    let hessian = array![[1.0, 0.0], [0.0, 1.0]];
    let gradient = array![-1.0, 0.0];
    let beta = array![0.0, 0.0];
    let constraints = LinearInequalityConstraints {
        a: array![[-1e-16, 1.0]],
        b: array![-1.0],
    };
    let mut direction = Array1::zeros(2);
    solve_newton_directionwith_linear_constraints(
        &hessian,
        &gradient,
        &beta,
        &constraints,
        &mut direction,
        None,
    )
    .expect("near-tangential inactive row should not block the Newton step");
    assert!(
        (direction[0] - 1.0).abs() <= 1e-12,
        "expected unconstrained x-step of 1.0, got {}",
        direction[0]
    );
    assert!(
        direction[1].abs() <= 1e-12,
        "expected zero y-step, got {}",
        direction[1]
    );
}
// The logit intercept guess should be the log-odds of the half-count
// corrected prevalence; every other coefficient starts at zero.
#[test]
fn default_beta_guess_logit_uses_log_odds_prevalence() {
    let y = array![0.0, 1.0, 1.0, 1.0];
    let w = Array1::ones(4);
    let beta =
        default_beta_guess_external(3, LinkFunction::Logit, y.view(), w.view(), None, None);
    // Half-count correction: (successes + 0.5) / (n + 1), clamped away from {0, 1}.
    let raw = (3.0 + 0.5) / (4.0 + 1.0);
    let prevalence = raw.clamp(1e-6_f64, 1.0_f64 - 1e-6_f64);
    let expected = (prevalence / (1.0 - prevalence)).ln();
    assert!((beta[0] - expected).abs() < 1e-12);
    for j in 1..=2 {
        assert_eq!(beta[j], 0.0);
    }
}
// The probit intercept guess must be the standard-normal quantile of the
// corrected prevalence — demonstrably different from the logit log-odds —
// with the remaining coefficients zero.
#[test]
fn default_beta_guess_probit_uses_standard_normal_quantile() {
    let y = array![0.0, 1.0, 1.0, 1.0];
    let w = Array1::ones(4);
    let beta =
        default_beta_guess_external(3, LinkFunction::Probit, y.view(), w.view(), None, None);
    let raw = (3.0 + 0.5) / (4.0 + 1.0);
    let prevalence = raw.clamp(1e-6_f64, 1.0_f64 - 1e-6_f64);
    let log_odds = (prevalence / (1.0 - prevalence)).ln();
    let expected =
        standard_normal_quantile(prevalence).expect("clamped prevalence must be valid");
    // Guard that the test actually distinguishes probit from logit.
    assert!((expected - log_odds).abs() > 1e-3);
    assert!((beta[0] - expected).abs() < 1e-12);
    for j in 1..=2 {
        assert_eq!(beta[j], 0.0);
    }
}
// A dense design matrix must force the dense-transformed PIRLS path with the
// explicit "design_not_sparse" reason.
#[test]
fn sparse_native_decision_rejects_dense_design() {
    let x = DesignMatrix::Dense(crate::matrix::DenseDesignMatrix::from(array![
        [1.0, 0.0],
        [0.0, 1.0]
    ]));
    let s = array![[1.0, 0.0], [0.0, 1.0]];
    let mut workspace = PirlsWorkspace::new(2, 2, 0, 0);
    let decision = should_use_sparse_native_pirls(&mut workspace, &x, &s, None, None);
    assert_eq!(decision.path, PirlsLinearSolvePath::DenseTransformed);
    assert_eq!(decision.reason, "design_not_sparse");
}
// Summary logging fires exactly on repetition counts that are powers of two
// greater than one (2, 4, 8, ...), and on nothing else.
#[test]
fn pirls_decision_summary_logs_on_power_of_two_repetitions() {
    let cases = [
        (1, false),
        (2, true),
        (3, false),
        (4, true),
        (6, false),
        (8, true),
    ];
    for (count, expected) in cases {
        assert_eq!(
            should_log_pirls_decision_summary(count),
            expected,
            "count={}",
            count
        );
    }
}
// A large (300x300) diagonal sparse design should take the sparse-native path
// and the decision should report nnz / symbolic-fill / density statistics.
#[test]
fn sparse_native_decision_collects_sparse_stats_for_large_sparse_design() {
    let triplets: Vec<_> = (0..300).map(|i| Triplet::new(i, i, 1.0)).collect();
    let x = SparseColMat::try_new_from_triplets(300, 300, &triplets)
        .expect("sparse identity should build");
    let x = DesignMatrix::from(x);
    let s = Array2::from_diag(&Array1::ones(300));
    let mut workspace = PirlsWorkspace::new(300, 300, 0, 0);
    let decision = should_use_sparse_native_pirls(&mut workspace, &x, &s, None, None);
    assert_eq!(decision.path, PirlsLinearSolvePath::SparseNative);
    assert_eq!(decision.reason, "sparse_native_eligible");
    // A diagonal design keeps every count at exactly n = 300.
    assert_eq!(decision.nnz_x, 300);
    assert_eq!(decision.nnz_xtwx_symbolic, Some(300));
    assert_eq!(decision.nnz_h_est, Some(300));
    assert!(decision.density_h_est.expect("density") < 0.01);
}
// A moderate 64-wide sparse design must still qualify for the sparse-native
// path (regression guard against a former hard width gate).
#[test]
fn sparse_native_decision_allows_moderate_sparse_designs_below_old_width_gate() {
    let triplets: Vec<_> = (0..64).map(|i| Triplet::new(i, i, 1.0)).collect();
    let x = SparseColMat::try_new_from_triplets(64, 64, &triplets)
        .expect("sparse identity should build");
    let x = DesignMatrix::from(x);
    let s = Array2::from_diag(&Array1::ones(64));
    let mut workspace = PirlsWorkspace::new(64, 64, 0, 0);
    let decision = should_use_sparse_native_pirls(&mut workspace, &x, &s, None, None);
    assert_eq!(decision.path, PirlsLinearSolvePath::SparseNative);
    assert_eq!(decision.reason, "sparse_native_eligible");
    assert_eq!(decision.nnz_x, 64);
    assert_eq!(decision.nnz_xtwx_symbolic, Some(64));
    assert_eq!(decision.nnz_h_est, Some(64));
    assert!(decision.density_h_est.expect("density") < 0.05);
}
// Any finite coefficient lower bound disables the sparse-native path: the
// decision must fall back to dense with reason "constraints_present".
#[test]
fn sparse_native_decision_rejects_finite_lower_bounds() {
    let triplets: Vec<_> = (0..64).map(|i| Triplet::new(i, i, 1.0)).collect();
    let x = SparseColMat::try_new_from_triplets(64, 64, &triplets)
        .expect("sparse identity should build");
    let x = DesignMatrix::from(x);
    let s = Array2::from_diag(&Array1::ones(64));
    // All bounds unbounded except one finite entry — that alone must trigger
    // the constraint rejection.
    let mut lower_bounds = Array1::from_elem(64, f64::NEG_INFINITY);
    lower_bounds[0] = 0.0;
    let mut workspace = PirlsWorkspace::new(64, 64, 0, 0);
    let decision =
        should_use_sparse_native_pirls(&mut workspace, &x, &s, Some(&lower_bounds), None);
    assert_eq!(decision.path, PirlsLinearSolvePath::DenseTransformed);
    assert_eq!(decision.reason, "constraints_present");
}
// The sparse penalized-Hessian assembly (X'WX + S_lambda + ridge*I, stored as
// upper triangle) must match the dense computation entry-by-entry on a
// diagonal design.
#[test]
fn sparse_penalized_assembly_matches_dense_diagonal_case() {
    let triplets = vec![
        Triplet::new(0, 0, 1.0),
        Triplet::new(1, 1, 2.0),
        Triplet::new(2, 2, 3.0),
    ];
    let x = SparseColMat::try_new_from_triplets(3, 3, &triplets)
        .expect("diagonal sparse matrix should build");
    let weights = array![2.0, 3.0, 5.0];
    let s_lambda = array![[4.0, 0.0, 0.0], [0.0, 6.0, 0.0], [0.0, 0.0, 8.0]];
    let ridge = 1e-8;
    let mut workspace = PirlsWorkspace::new(3, 3, 0, 0);
    let assembled =
        super::sparse_reml_penalized_hessian(&mut workspace, &x, &weights, &s_lambda, ridge)
            .expect("sparse penalized assembly should succeed");
    // Dense reference: X' diag(w) X + S_lambda + ridge on the diagonal.
    let dense = DesignMatrix::from(x.clone()).to_dense();
    let mut expected = dense.t().dot(&Array2::from_diag(&weights)).dot(&dense);
    expected += &s_lambda;
    for i in 0..3 {
        expected[[i, i]] += ridge;
    }
    let actual = DesignMatrix::from(assembled).to_dense();
    for i in 0..3 {
        for j in 0..3 {
            // The assembled matrix stores only the upper triangle; strictly
            // lower entries must be exactly zero.
            let target = if i <= j { expected[[i, j]] } else { 0.0 };
            assert!(
                (actual[[i, j]] - target).abs() < 1e-10,
                "mismatch at ({}, {}): {} vs {}",
                i,
                j,
                actual[[i, j]],
                target
            );
        }
    }
}
// End-to-end check that a PIRLS fit with per-observation covariate standard
// errors stores the *integrated* logit inverse-link jet (d1..d3) and the
// derived c/d curvature arrays, matching a direct quadrature evaluation at
// the final eta.
#[test]
fn pirls_result_stores_integrated_logit_derivative_jet() {
    let x = array![[1.0], [1.0], [1.0], [1.0], [1.0]];
    let y = array![0.0, 1.0, 0.0, 1.0, 1.0];
    let w = Array1::ones(5);
    let offset = Array1::zeros(5);
    let rho = Array1::<f64>::zeros(1);
    let covariate_se = array![0.9, 0.7, 0.8, 0.6, 0.75];
    let rs = vec![array![[1.0]]];
    // Wrap the single 1x1 penalty root into the canonical-penalty form the
    // fitter expects.
    let canonical: Vec<crate::construction::CanonicalPenalty> = rs
        .iter()
        .map(|r| {
            let local = r.t().dot(r);
            crate::construction::CanonicalPenalty {
                root: r.clone(),
                col_range: 0..r.ncols(),
                total_dim: r.ncols(),
                nullity: 0,
                local,
                positive_eigenvalues: Vec::new(),
                op: None,
            }
        })
        .collect();
    let config = PirlsConfig {
        likelihood: GlmLikelihoodSpec::canonical(GlmLikelihoodFamily::BinomialLogit),
        link_kind: InverseLink::Standard(LinkFunction::Logit),
        max_iterations: 100,
        convergence_tolerance: 1e-8,
        firth_bias_reduction: false,
        initial_lm_lambda: None,
    };
    let (fit, _) = fit_model_for_fixed_rho(
        LogSmoothingParamsView::new(rho.view()),
        PirlsProblem {
            x: x.view(),
            offset: offset.view(),
            y: y.view(),
            priorweights: w.view(),
            covariate_se: Some(covariate_se.view()),
        },
        PenaltyConfig {
            canonical_penalties: &canonical,
            balanced_penalty_root: None,
            reparam_invariant: None,
            p: 1,
            coefficient_lower_bounds: None,
            linear_constraints_original: None,
            penalty_shrinkage_floor: None,
            kronecker_factored: None,
        },
        &config,
        Some(&Coefficients::new(array![0.0])),
    )
    .expect("integrated logit PIRLS fit");
    let ctx = crate::quadrature::QuadratureContext::new();
    for i in 0..y.len() {
        // Independent reference: evaluate the integrated inverse-link jet via
        // quadrature at the clamped final eta.
        let jet = crate::quadrature::integrated_inverse_link_jet(
            &ctx,
            LinkFunction::Logit,
            fit.final_eta[i].clamp(-700.0, 700.0),
            covariate_se[i],
        )
        .expect("logit integrated inverse-link jet should evaluate");
        let expected = bernoulli_geometry_from_jet(
            fit.final_eta[i],
            fit.final_eta[i].clamp(-700.0, 700.0),
            y[i],
            w[i],
            MixtureInverseLinkJet {
                mu: jet.mean,
                d1: jet.d1,
                d2: jet.d2,
                d3: jet.d3,
            },
        );
        // Tolerances loosen with derivative order (d1 tightest, d3 loosest).
        assert_relative_eq!(
            fit.solve_dmu_deta[i],
            jet.d1,
            epsilon = 1e-9,
            max_relative = 1e-9
        );
        assert_relative_eq!(
            fit.solve_d2mu_deta2[i],
            jet.d2,
            epsilon = 1e-9,
            max_relative = 1e-8
        );
        assert_relative_eq!(
            fit.solve_d3mu_deta3[i],
            jet.d3,
            epsilon = 1e-8,
            max_relative = 1e-7
        );
        assert_relative_eq!(
            fit.solve_c_array[i],
            expected.c,
            epsilon = 1e-9,
            max_relative = 1e-8
        );
        assert_relative_eq!(
            fit.solve_d_array[i],
            expected.d,
            epsilon = 1e-8,
            max_relative = 1e-7
        );
    }
}
// Deep in the logit tail (eta = 50) the naive mu*(1-mu) weight underflows;
// the working-state update must instead use the stable jet formula so the
// Fisher weight, mu, working response z, and the score carrier w*(z - eta)
// all remain exact.
#[test]
fn pure_logit_working_state_preserves_tail_fisher_mass() {
    let y = array![1.0];
    let eta = array![50.0];
    let priorweights = array![1.0];
    let inverse_link = InverseLink::Standard(LinkFunction::Logit);
    let mut mu = Array1::zeros(1);
    let mut weights = Array1::zeros(1);
    let mut z = Array1::zeros(1);
    update_glmvectors(
        y.view(),
        &eta,
        &inverse_link,
        priorweights.view(),
        &mut mu,
        &mut weights,
        &mut z,
        None,
    )
    .expect("pure logit working state");
    // Reference values from the order-5 stable logit jet.
    let jet = crate::mixture_link::logit_inverse_link_jet5(eta[0]);
    assert!(jet.d1 > 0.0);
    assert!(
        (weights[0] - jet.d1).abs() < 1e-30,
        "pure logit PIRLS weight should equal the stable tail formula at eta={}; got {} vs {}",
        eta[0],
        weights[0],
        jet.d1
    );
    assert!(
        (mu[0] - jet.mu).abs() < 1e-30,
        "pure logit PIRLS mu mismatch at eta={}; got {} vs {}",
        eta[0],
        mu[0],
        jet.mu
    );
    let expected_z = eta[0] + (y[0] - jet.mu) / jet.d1;
    assert!(
        (z[0] - expected_z).abs() < 1e-12,
        "pure logit PIRLS z should preserve the exact working response at eta={}; got {} vs {}",
        eta[0],
        z[0],
        expected_z
    );
    // The product w*(z - eta) is the score contribution; it must reproduce
    // y - mu without cancellation even in the tail.
    assert!(
        (weights[0] * (z[0] - eta[0]) - (y[0] - jet.mu)).abs() < 1e-30,
        "pure logit PIRLS score carrier should preserve y-mu at eta={}; got {} vs {}",
        eta[0],
        weights[0] * (z[0] - eta[0]),
        y[0] - jet.mu
    );
}
// Gamma deviance must follow 2 * sum_i w_i * (y_i/mu_i - 1 - ln(y_i/mu_i)).
#[test]
fn gamma_log_deviance_uses_gamma_formula() {
    let y = array![2.0, 5.0];
    let mu = array![1.0, 4.0];
    let w = array![1.5, 0.75];
    let dev = calculate_deviance(
        y.view(),
        &mu,
        GlmLikelihoodSpec::canonical(GlmLikelihoodFamily::GammaLog),
        w.view(),
    );
    // Per-observation gamma deviance contribution.
    let per_obs = |wi: f64, yi: f64, mi: f64| {
        let ratio = yi / mi;
        wi * (ratio - 1.0 - ratio.ln())
    };
    let expected = 2.0 * (per_obs(1.5, 2.0, 1.0) + per_obs(0.75, 5.0, 4.0));
    assert_relative_eq!(dev, expected, epsilon = 1e-12, max_relative = 1e-12);
}
// For gamma with log link (and shape one), the observed-information curvature
// arrays have the closed form w*y/mu for the weight, -w*y/mu for c, and
// w*y/mu for d; check all three against compute_observed_hessian_curvature_arrays.
#[test]
fn gamma_log_observed_curvature_matches_shape_one_closed_form() {
    let eta = array![0.2, -0.4];
    let mu = eta.mapv(f64::exp);
    let y = array![1.8, 0.7];
    let w = array![2.0, 0.5];
    // Under the log link all eta-derivatives of mu equal mu itself.
    let dmu = mu.clone();
    let d2mu = mu.clone();
    let d3mu = mu.clone();
    let fisher = w.clone();
    let (w_obs, c_obs, d_obs) = compute_observed_hessian_curvature_arrays(
        GlmLikelihoodSpec::canonical(GlmLikelihoodFamily::GammaLog),
        &InverseLink::Standard(LinkFunction::Log),
        &eta,
        y.view(),
        &mu,
        &dmu,
        &d2mu,
        &d3mu,
        &fisher,
        w.view(),
    )
    .expect("gamma-log observed curvature should evaluate");
    for i in 0..eta.len() {
        let expected_w = w[i] * y[i] / mu[i];
        assert_relative_eq!(w_obs[i], expected_w, epsilon = 1e-12, max_relative = 1e-12);
        assert_relative_eq!(c_obs[i], -expected_w, epsilon = 1e-12, max_relative = 1e-12);
        assert_relative_eq!(d_obs[i], expected_w, epsilon = 1e-12, max_relative = 1e-12);
    }
}
// A gamma-log fit must profile the gamma shape parameter from the residuals
// rather than leaving it fixed at one; the stored shape should agree with an
// independent call to estimate_gamma_shape_from_eta at the fitted eta.
#[test]
fn gamma_log_fit_profiles_shape_instead_of_fixing_one() {
    let x = array![[1.0], [1.0], [1.0], [1.0], [1.0], [1.0]];
    let y = array![0.8, 1.1, 1.7, 2.0, 2.6, 3.1];
    let w = Array1::ones(y.len());
    let offset = Array1::zeros(y.len());
    let rho = array![0.0];
    // Zero penalty root: the single coefficient is effectively unpenalized.
    let rs = vec![array![[0.0]]];
    let canonical: Vec<crate::construction::CanonicalPenalty> = rs
        .iter()
        .map(|r| {
            let local = r.t().dot(r);
            crate::construction::CanonicalPenalty {
                root: r.clone(),
                col_range: 0..r.ncols(),
                total_dim: r.ncols(),
                nullity: 0,
                local,
                positive_eigenvalues: Vec::new(),
                op: None,
            }
        })
        .collect();
    let config = PirlsConfig {
        likelihood: GlmLikelihoodSpec::canonical(GlmLikelihoodFamily::GammaLog),
        link_kind: InverseLink::Standard(LinkFunction::Log),
        max_iterations: 100,
        convergence_tolerance: 1e-8,
        firth_bias_reduction: false,
        initial_lm_lambda: None,
    };
    let (result, _) = fit_model_for_fixed_rho(
        LogSmoothingParamsView::new(rho.view()),
        PirlsProblem {
            x: x.view(),
            offset: offset.view(),
            y: y.view(),
            priorweights: w.view(),
            covariate_se: None,
        },
        PenaltyConfig {
            canonical_penalties: &canonical,
            balanced_penalty_root: None,
            reparam_invariant: None,
            p: 1,
            coefficient_lower_bounds: None,
            linear_constraints_original: None,
            penalty_shrinkage_floor: None,
            kronecker_factored: None,
        },
        &config,
        None,
    )
    .expect("gamma PIRLS fit");
    let fitted_shape = result
        .likelihood
        .gamma_shape()
        .expect("gamma fit should expose fitted shape");
    let profiled_shape =
        super::estimate_gamma_shape_from_eta(y.view(), &result.final_eta, w.view());
    assert!(fitted_shape > 1.0, "shape should not stay fixed at one");
    assert_relative_eq!(
        fitted_shape,
        profiled_shape,
        epsilon = 1e-10,
        max_relative = 1e-10
    );
}
// Round-trip a Poisson-log fit through the REML-cache compaction and
// rehydration and verify the c/d curvature arrays are reproduced to 1e-12.
#[test]
fn poisson_cache_rehydration_preserves_log_derivatives() {
    let x = array![[1.0], [1.0], [1.0], [1.0]];
    let y = array![1.0, 2.0, 4.0, 8.0];
    let w = Array1::ones(4);
    let offset = Array1::zeros(4);
    let rho = array![0.0];
    let rs = vec![array![[1.0]]];
    let canonical: Vec<crate::construction::CanonicalPenalty> = rs
        .iter()
        .map(|r| {
            let local = r.t().dot(r);
            crate::construction::CanonicalPenalty {
                root: r.clone(),
                col_range: 0..r.ncols(),
                total_dim: r.ncols(),
                nullity: 0,
                local,
                positive_eigenvalues: Vec::new(),
                op: None,
            }
        })
        .collect();
    let config = PirlsConfig {
        likelihood: GlmLikelihoodSpec::canonical(GlmLikelihoodFamily::PoissonLog),
        link_kind: InverseLink::Standard(LinkFunction::Log),
        max_iterations: 100,
        convergence_tolerance: 1e-8,
        firth_bias_reduction: false,
        initial_lm_lambda: None,
    };
    let (fit, _) = fit_model_for_fixed_rho(
        LogSmoothingParamsView::new(rho.view()),
        PirlsProblem {
            x: x.view(),
            offset: offset.view(),
            y: y.view(),
            priorweights: w.view(),
            covariate_se: None,
        },
        PenaltyConfig {
            canonical_penalties: &canonical,
            balanced_penalty_root: None,
            reparam_invariant: None,
            p: 1,
            coefficient_lower_bounds: None,
            linear_constraints_original: None,
            penalty_shrinkage_floor: None,
            kronecker_factored: None,
        },
        &config,
        None,
    )
    .expect("poisson PIRLS fit");
    // Compact (drop derived arrays) then rebuild them from the design data.
    let compacted = fit.compact_for_reml_cache();
    let rehydrated = compacted
        .rehydrate_after_reml_cache(
            &DesignMatrix::from(x.clone()),
            y.view(),
            w.view(),
            offset.view(),
            &InverseLink::Standard(LinkFunction::Log),
        )
        .expect("rehydration should succeed");
    assert_eq!(fit.solve_c_array.len(), rehydrated.solve_c_array.len());
    for i in 0..fit.solve_c_array.len() {
        assert_relative_eq!(
            fit.solve_c_array[i],
            rehydrated.solve_c_array[i],
            epsilon = 1e-12,
            max_relative = 1e-12
        );
        assert_relative_eq!(
            fit.solve_d_array[i],
            rehydrated.solve_d_array[i],
            epsilon = 1e-12,
            max_relative = 1e-12
        );
    }
}
// A warm-start hint marking the boundary constraint active must be released
// when the unconstrained optimum already sits (numerically) on the bound, so
// the zero step is accepted and the hint list empties.
// (Renamed: original name had garbled snake_case `stalewarm`.)
#[test]
fn linear_constraint_active_set_releases_stale_warm_boundary_hint() {
    let hessian = array![[2.0]];
    let gradient = array![0.0];
    let beta = array![1e-9];
    let constraints = LinearInequalityConstraints {
        a: array![[1.0]],
        b: array![0.0],
    };
    let mut direction = Array1::zeros(1);
    let mut active_hint = vec![0];
    solve_newton_directionwith_linear_constraints(
        &hessian,
        &gradient,
        &beta,
        &constraints,
        &mut direction,
        Some(&mut active_hint),
    )
    .expect("active-set solve should succeed");
    assert_relative_eq!(direction[0], 0.0, epsilon = 1e-14);
    let projected = &beta + &direction;
    assert_relative_eq!(projected[0], beta[0], epsilon = 1e-14);
    assert!(active_hint.is_empty());
}
// A stale warm hint on the lower bound must be swapped for the genuinely
// active upper bound (row 1), leaving a 0.1 step and an updated hint list.
// (Renamed: original name had garbled snake_case `stalewarm`.)
#[test]
fn linear_constraint_active_set_releases_stale_warm_hint() {
    let hessian = array![[1.0]];
    let gradient = array![-1.0];
    let beta = array![0.0];
    let constraints = LinearInequalityConstraints {
        a: array![[1.0], [-1.0]],
        b: array![0.0, -0.1],
    };
    let mut direction = Array1::zeros(1);
    let mut active_hint = vec![0];
    solve_newton_directionwith_linear_constraints(
        &hessian,
        &gradient,
        &beta,
        &constraints,
        &mut direction,
        Some(&mut active_hint),
    )
    .expect("stale warm active-set hint should be releasable");
    assert!(
        (direction[0] - 0.1).abs() <= 1e-10,
        "expected step to upper bound (0.1), got {}",
        direction[0]
    );
    assert_eq!(active_hint, vec![1]);
}
// With gradient = A' * lambda for known non-negative multipliers, the
// working-set KKT diagnostics must report zero residuals and count all three
// constraints active.
// (Renamed: original name had garbled snake_case `active_setmultipliers`.)
#[test]
fn working_set_kkt_diagnostics_use_active_set_multipliers() {
    let working_constraints = LinearInequalityConstraints {
        a: array![[1.0, 0.0], [2.0, 0.0], [0.0, 1.0]],
        b: array![0.0, 0.0, 0.0],
    };
    let x = array![0.0, 0.0];
    let lambda_true = array![1.0, 0.5, 2.0];
    let gradient = working_constraints.a.t().dot(&lambda_true);
    let kkt = active_set::working_set_kkt_diagnostics_from_multipliers(
        &x,
        &gradient,
        &working_constraints,
        &lambda_true,
        3,
    )
    .expect("working-set KKT diagnostics");
    assert!(kkt.primal_feasibility <= 1e-12);
    assert!(kkt.dual_feasibility <= 1e-12);
    assert!(kkt.complementarity <= 1e-12);
    assert!(kkt.stationarity <= 1e-12);
    assert_eq!(kkt.n_active, 3);
}
// Two near-collinear active rows (relative difference ~1e-13) must be merged
// into one group by the working-set compression, leaving two representative
// constraints.
// (Renamed: original name had garbled snake_case `activeworking`/`collinearrows`.)
#[test]
fn compress_active_working_set_groups_near_collinear_rows() {
    let constraints = LinearInequalityConstraints {
        a: array![
            [0.0, 0.5, 0.0],
            [0.0, 0.50000000000003, 0.0],
            [1.0, 0.0, 0.0]
        ],
        b: array![1e-8, 1.00000000000005e-8, 0.2],
    };
    let x = array![0.0, 0.0, 0.0];
    let active = vec![0, 1, 2];
    let compressed = active_set::compress_active_working_set(&x, &constraints, &active)
        .expect("compress working set");
    assert_eq!(compressed.constraints.a.nrows(), 2);
    assert_eq!(compressed.groups.len(), 2);
    assert!(
        compressed.groups.iter().any(|g| g == &vec![0, 1]),
        "near-collinear rows should be grouped together: {:?}",
        compressed.groups
    );
}
// Lower-bound analogue of the linear-constraint case: a stale warm hint on a
// coefficient that already sits on its bound must be released and the zero
// step accepted.
// (Renamed: original name had garbled snake_case `stalewarm`.)
#[test]
fn lower_bound_active_set_releases_stale_warm_boundary_hint() {
    let hessian = array![[2.0]];
    let gradient = array![0.0];
    let beta = array![1e-9];
    let lower_bounds = array![0.0];
    let mut direction = Array1::zeros(1);
    let mut active_hint = vec![0];
    solve_newton_directionwith_lower_bounds(
        &hessian,
        &gradient,
        &beta,
        &lower_bounds,
        &mut direction,
        Some(&mut active_hint),
    )
    .expect("lower-bound active-set solve should succeed");
    assert_relative_eq!(direction[0], 0.0, epsilon = 1e-14);
    let projected = &beta + &direction;
    assert_relative_eq!(projected[0], beta[0], epsilon = 1e-14);
    assert!(active_hint.is_empty());
}
// Worst-violation mode (bland's rule off) must release the constraint whose
// multiplier estimate is most negative — index 1 here.
#[test]
fn select_active_set_release_worst_violation_picks_most_negative() {
    let gradient = array![-0.1, -0.5, -0.2];
    let hd = array![0.0, 0.0, 0.0];
    let active = vec![0, 1, 2];
    let released = select_active_set_release(&gradient, &hd, &active, false);
    assert_eq!(released, Some(1));
}
// Bland's rule must release the lowest-indexed constraint with a negative
// multiplier (anti-cycling), regardless of magnitude.
#[test]
fn select_active_set_release_blands_picks_lowest_index_with_negative_multiplier() {
    let gradient = array![-0.1, -0.5, -0.2];
    let hd = array![0.0, 0.0, 0.0];
    let active = vec![0, 1, 2];
    let released = select_active_set_release(&gradient, &hd, &active, true);
    assert_eq!(released, Some(0));
}
// Bland's rule must have a round-off deadband: a multiplier only a few ulps
// negative (here -32*eps*|g|) must not trigger a release, while a clearly
// negative one (-128*eps*|g|) must.
#[test]
fn select_active_set_release_blands_deadband_ignores_round_off() {
    let g = 1.0_f64;
    let gradient = array![g];
    let active = vec![0];
    // Multiplier at round-off level: hd = lambda_noise - g.
    let noise_hd = array![-32.0 * f64::EPSILON * g - g];
    assert_eq!(
        select_active_set_release(&gradient, &noise_hd, &active, true),
        None,
        "round-off-level multiplier must not trigger Bland's release"
    );
    // Genuinely negative multiplier: hd = lambda_real - g.
    let real_hd = array![-128.0 * f64::EPSILON * g - g];
    assert_eq!(
        select_active_set_release(&gradient, &real_hd, &active, true),
        Some(0)
    );
}
// When every multiplier is non-negative the KKT conditions hold, so neither
// selection rule may release anything.
#[test]
fn select_active_set_release_returns_none_when_kkt_satisfied() {
    let gradient = array![0.5, 1.0, 0.0];
    let hd = array![0.0, 0.0, 0.0];
    let active = vec![0, 1, 2];
    for blands_rule in [false, true] {
        assert_eq!(
            select_active_set_release(&gradient, &hd, &active, blands_rule),
            None
        );
    }
}
// A stale warm hint pinning the coefficient to its bound must be released so
// the solver takes the full unconstrained step of 1.0 and empties the hint.
// (Renamed: original name had garbled snake_case `stalewarm`.)
#[test]
fn lower_bound_active_set_releases_stale_warm_hint() {
    let hessian = array![[1.0]];
    let gradient = array![-1.0];
    let beta = array![0.0];
    let lower_bounds = array![0.0];
    let mut direction = Array1::zeros(1);
    let mut active_hint = vec![0];
    solve_newton_directionwith_lower_bounds(
        &hessian,
        &gradient,
        &beta,
        &lower_bounds,
        &mut direction,
        Some(&mut active_hint),
    )
    .expect("stale warm lower-bound hint should be releasable");
    assert!(
        (direction[0] - 1.0).abs() <= 1e-12,
        "expected unconstrained step of 1.0 after releasing stale bound, got {}",
        direction[0]
    );
    assert!(active_hint.is_empty());
}
}
#[cfg(test)]
mod root_cause_tests {
use super::*;
use ndarray::{Array1, Array2, array};
// Builds a one-dimensional WorkingState for the scalar toy models below:
// eta holds the lone coefficient value, the Hessian is the 1x1 identity,
// and gradient/deviance are caller-supplied.
fn scalar_working_state(
    beta: &Coefficients,
    curvature: HessianCurvatureKind,
    gradient: f64,
    deviance: f64,
) -> WorkingState {
    WorkingState {
        eta: LinearPredictor::new(array![beta.as_ref()[0]]),
        gradient: array![gradient],
        hessian: crate::linalg::matrix::SymmetricMatrix::Dense(array![[1.0]]),
        log_likelihood: 0.0,
        deviance,
        penalty_term: 0.0,
        firth: FirthDiagnostics::Inactive,
        ridge_used: 0.0,
        hessian_curvature: curvature,
        gradient_natural_scale: 0.0,
    }
}
// Convenience wrapper: scalar state with unit gradient and unit deviance.
fn test_working_state(beta: &Coefficients, curvature: HessianCurvatureKind) -> WorkingState {
    scalar_working_state(beta, curvature, 1.0, 1.0)
}
// Toy model whose candidate evaluations always fail; counters record how many
// update / candidate calls happened under each curvature kind so tests can
// verify the observed->Fisher fallback and the LM retry budget.
#[derive(Default)]
struct CandidateEvalFailureModel {
    observed_updates: usize,
    fisher_updates: usize,
    observed_candidate_calls: usize,
    fisher_candidate_calls: usize,
}
impl CandidateEvalFailureModel {
    // Shared state builder (unit gradient/deviance scalar state).
    fn state(beta: &Coefficients, curvature: HessianCurvatureKind) -> WorkingState {
        test_working_state(beta, curvature)
    }
}
impl WorkingModel for CandidateEvalFailureModel {
    fn update(&mut self, beta: &Coefficients) -> Result<WorkingState, EstimationError> {
        self.update_with_curvature(beta, HessianCurvatureKind::Fisher)
    }
    // Accepted-state updates succeed; only count which curvature was used.
    fn update_with_curvature(
        &mut self,
        beta: &Coefficients,
        curvature: HessianCurvatureKind,
    ) -> Result<WorkingState, EstimationError> {
        match curvature {
            HessianCurvatureKind::Observed => self.observed_updates += 1,
            HessianCurvatureKind::Fisher => self.fisher_updates += 1,
        }
        Ok(Self::state(beta, curvature))
    }
    // Candidate screening always fails with a retryable InvalidInput error,
    // after bumping the per-curvature call counter.
    fn update_candidate(
        &mut self,
        beta: &Coefficients,
        curvature: HessianCurvatureKind,
    ) -> Result<WorkingState, EstimationError> {
        match curvature {
            HessianCurvatureKind::Observed => self.observed_candidate_calls += 1,
            HessianCurvatureKind::Fisher => self.fisher_candidate_calls += 1,
        }
        Err(EstimationError::InvalidInput(format!(
            "non-finite candidate evaluation under {curvature:?} curvature at beta={:.3e}",
            beta.as_ref()[0],
        )))
    }
    // Opt in to observed-information curvature so the fallback path is exercised.
    fn supports_observed_information_curvature(&self) -> bool {
        true
    }
}
// Toy model whose candidate evaluation fails with a *permanent*
// (InvalidSpecification) error; the counter records how often it was tried.
#[derive(Default)]
struct PermanentCandidateErrorModel {
    candidate_calls: usize,
}
impl WorkingModel for PermanentCandidateErrorModel {
    fn update(&mut self, beta: &Coefficients) -> Result<WorkingState, EstimationError> {
        self.update_with_curvature(beta, HessianCurvatureKind::Fisher)
    }
    // Accepted-state updates always succeed.
    fn update_with_curvature(
        &mut self,
        beta: &Coefficients,
        curvature: HessianCurvatureKind,
    ) -> Result<WorkingState, EstimationError> {
        Ok(test_working_state(beta, curvature))
    }
    // Candidate screening always fails with InvalidSpecification — a class of
    // error the caller should NOT retry.
    fn update_candidate(
        &mut self,
        beta: &Coefficients,
        curvature: HessianCurvatureKind,
    ) -> Result<WorkingState, EstimationError> {
        self.candidate_calls += 1;
        Err(EstimationError::InvalidSpecification(format!(
            "permanent candidate failure under {curvature:?} curvature at beta={:.3e}",
            beta.as_ref()[0],
        )))
    }
}
// Toy model where candidate *screening* succeeds but re-evaluating the
// accepted candidate (beta != 0) fails; counters track each code path.
#[derive(Default)]
struct FirthAcceptedStateFailureModel {
    current_state_calls: usize,
    candidate_state_calls: usize,
    candidate_screen_calls: usize,
}
impl WorkingModel for FirthAcceptedStateFailureModel {
    fn update(&mut self, beta: &Coefficients) -> Result<WorkingState, EstimationError> {
        self.update_with_curvature(beta, HessianCurvatureKind::Fisher)
    }
    // Succeeds only at the starting point beta ~ 0; any moved-to beta fails,
    // simulating overflow when re-evaluating an accepted candidate.
    fn update_with_curvature(
        &mut self,
        beta: &Coefficients,
        curvature: HessianCurvatureKind,
    ) -> Result<WorkingState, EstimationError> {
        if beta.as_ref()[0].abs() < 1e-12 {
            self.current_state_calls += 1;
            Ok(test_working_state(beta, curvature))
        } else {
            self.candidate_state_calls += 1;
            Err(EstimationError::InvalidInput(format!(
                "overflow while re-evaluating accepted candidate under {curvature:?} curvature at beta={:.3e}",
                beta.as_ref()[0],
            )))
        }
    }
    // Screening reports an attractive candidate (lower deviance, smaller
    // gradient) so the caller will try to accept it.
    fn update_candidate(
        &mut self,
        beta: &Coefficients,
        curvature: HessianCurvatureKind,
    ) -> Result<WorkingState, EstimationError> {
        self.candidate_screen_calls += 1;
        let mut state = test_working_state(beta, curvature);
        state.deviance = 0.5;
        state.gradient = array![0.5];
        Ok(state)
    }
}
// Stateless model pinned at a constrained optimum: constant unit gradient and
// zero deviance, used to check KKT-aware convergence at an active constraint.
#[derive(Default)]
struct ActiveConstraintKktModel;
impl WorkingModel for ActiveConstraintKktModel {
    fn update(&mut self, beta: &Coefficients) -> Result<WorkingState, EstimationError> {
        self.update_with_curvature(beta, HessianCurvatureKind::Fisher)
    }
    // Every state has gradient 1.0 and deviance 0.0 regardless of beta.
    fn update_with_curvature(
        &mut self,
        beta: &Coefficients,
        curvature: HessianCurvatureKind,
    ) -> Result<WorkingState, EstimationError> {
        Ok(scalar_working_state(beta, curvature, 1.0, 0.0))
    }
    fn update_candidate(
        &mut self,
        beta: &Coefficients,
        curvature: HessianCurvatureKind,
    ) -> Result<WorkingState, EstimationError> {
        Ok(scalar_working_state(beta, curvature, 1.0, 0.0))
    }
}
// Configurable model for plateau-detection tests: a fixed gradient plus
// separate deviances for the current state and for screened candidates.
struct PlateauStatusModel {
    gradient: f64,
    current_deviance: f64,
    candidate_deviance: f64,
}
impl PlateauStatusModel {
    // Forwards to the shared scalar state builder.
    fn state(
        beta: &Coefficients,
        curvature: HessianCurvatureKind,
        gradient: f64,
        deviance: f64,
    ) -> WorkingState {
        scalar_working_state(beta, curvature, gradient, deviance)
    }
}
impl WorkingModel for PlateauStatusModel {
    fn update(&mut self, beta: &Coefficients) -> Result<WorkingState, EstimationError> {
        self.update_with_curvature(beta, HessianCurvatureKind::Fisher)
    }
    // Accepted states report the configured current deviance.
    fn update_with_curvature(
        &mut self,
        beta: &Coefficients,
        curvature: HessianCurvatureKind,
    ) -> Result<WorkingState, EstimationError> {
        Ok(Self::state(
            beta,
            curvature,
            self.gradient,
            self.current_deviance,
        ))
    }
    // Candidates report the configured candidate deviance instead.
    fn update_candidate(
        &mut self,
        beta: &Coefficients,
        curvature: HessianCurvatureKind,
    ) -> Result<WorkingState, EstimationError> {
        Ok(Self::state(
            beta,
            curvature,
            self.gradient,
            self.candidate_deviance,
        ))
    }
}
// A large gradient component pushing a near-bound coefficient into its bound
// is a KKT force, not a violation: the projected gradient norm must exclude
// it and stay small.
#[test]
fn projected_gradient_excludes_near_bound_kkt_forces() {
    let gradient = array![0.5, 1e-4];
    let beta = array![1e-6, 2.0];
    let lower_bounds = array![0.0, f64::NEG_INFINITY];
    let norm = projected_gradient_norm(&gradient, &beta, Some(&lower_bounds));
    assert!(
        norm < 0.01,
        "projected gradient should exclude near-bound KKT force (beta=1e-6, lb=0), got {:.6e}",
        norm
    );
}
// A coefficient epsilon above its bound with a positive gradient must be
// classified as active, and the returned direction must snap it exactly onto
// the bound (step = lb - beta = -1e-6).
#[test]
fn bound_solver_treats_near_bound_positive_grad_as_active() {
    let hessian = array![[2.0, 0.0], [0.0, 2.0]];
    let gradient = array![1.0, 0.0];
    let beta = array![1e-6, 5.0];
    let lower_bounds = array![0.0, f64::NEG_INFINITY];
    let mut direction = Array1::zeros(2);
    let mut active_hint = vec![];
    solve_newton_directionwith_lower_bounds(
        &hessian,
        &gradient,
        &beta,
        &lower_bounds,
        &mut direction,
        Some(&mut active_hint),
    )
    .expect("solve should succeed");
    assert!(
        active_hint.contains(&0),
        "near-bound coeff with positive gradient should be in active set, got {:?}",
        active_hint
    );
    assert!(
        (direction[0] - (-1e-6)).abs() < 1e-14,
        "direction should snap to bound (lb - beta = -1e-6), got {:.6e}",
        direction[0]
    );
}
// At a constrained optimum (gradient fully absorbed by the active constraint
// multiplier) PIRLS must report Converged with a vanishing KKT-aware
// stationarity norm and clean constraint diagnostics.
#[test]
fn pirls_converges_at_active_linear_constraint_kkt_point() {
    let mut model = ActiveConstraintKktModel;
    let options = WorkingModelPirlsOptions {
        max_iterations: 3,
        convergence_tolerance: 1e-8,
        max_step_halving: 3,
        min_step_size: 0.0,
        firth_bias_reduction: false,
        coefficient_lower_bounds: None,
        linear_constraints: Some(LinearInequalityConstraints {
            a: array![[1.0]],
            b: array![0.0],
        }),
        initial_lm_lambda: None,
    };
    let summary =
        runworking_model_pirls(&mut model, Coefficients::new(array![0.0]), &options, |_| {})
            .expect("active-constraint KKT point should be accepted as converged");
    assert_eq!(summary.status, PirlsStatus::Converged);
    assert!(
        summary.lastgradient_norm <= 1e-12,
        "KKT-aware stationarity norm should vanish at the constrained optimum, got {:.6e}",
        summary.lastgradient_norm
    );
    let kkt = summary
        .constraint_kkt
        .expect("linear constraint run should report KKT diagnostics");
    assert!(kkt.primal_feasibility <= 1e-12);
    assert!(kkt.dual_feasibility <= 1e-12);
    assert!(kkt.complementarity <= 1e-12);
    assert!(kkt.stationarity <= 1e-12);
}
// Regression for a large-scale case: with a huge gradient natural scale the
// relative KKT certificate must accept a gradient norm that the old absolute
// test (g_norm < tol) would reject; the second assert witnesses that the old
// test indeed fails here.
#[test]
fn certifies_kkt_accepts_biobank_pathological_case() {
    let n = 320_000usize;
    let p = 20usize;
    let g_norm = 1.465e-5;
    let tol = 1e-6;
    let state = WorkingState {
        eta: LinearPredictor::new(Array1::zeros(n)),
        gradient: Array1::zeros(p),
        hessian: crate::linalg::matrix::SymmetricMatrix::Dense(Array2::zeros((p, p))),
        log_likelihood: 0.0,
        deviance: 1.0,
        penalty_term: 0.0,
        firth: FirthDiagnostics::Inactive,
        ridge_used: 0.0,
        hessian_curvature: HessianCurvatureKind::Fisher,
        gradient_natural_scale: 1.0e3,
    };
    assert!(
        state.certifies_kkt(g_norm, tol),
        "scale-invariant certificate should accept biobank pathological case"
    );
    assert!(
        !(g_norm < tol),
        "this test must witness the failure of the old absolute test; \
otherwise it does not prove the fix"
    );
}
// Multiplying both the gradient norm and the natural scale by the same factor
// must not change the KKT classification (invariance under F -> c*F).
#[test]
fn certifies_kkt_is_scale_invariant() {
    let n = 1000usize;
    let p = 10usize;
    let tol = 1e-6;
    let g_norm = 1.0;
    let natural_scale = 5.0e6;
    // State factory parameterized by gradient vector and natural scale.
    let mk_state = |g: Array1<f64>, ns: f64| WorkingState {
        eta: LinearPredictor::new(Array1::zeros(n)),
        gradient: g,
        hessian: crate::linalg::matrix::SymmetricMatrix::Dense(Array2::zeros((p, p))),
        log_likelihood: 0.0,
        deviance: 0.0,
        penalty_term: 0.0,
        firth: FirthDiagnostics::Inactive,
        ridge_used: 0.0,
        hessian_curvature: HessianCurvatureKind::Fisher,
        gradient_natural_scale: ns,
    };
    let base = mk_state(Array1::zeros(p), natural_scale);
    let scaled = mk_state(Array1::zeros(p), natural_scale * 1000.0);
    assert_eq!(
        base.certifies_kkt(g_norm, tol),
        scaled.certifies_kkt(g_norm * 1000.0, tol),
        "KKT classification must be invariant under uniform F → c·F"
    );
}
// The certificate must accept via either criterion: the scale-relative bound
// (g_norm just under tol*(1 + natural_scale)) or, with zero natural scale,
// a looser absolute band.
#[test]
fn certifies_kkt_accepts_under_either_bound() {
    let n = 100usize;
    let p = 5usize;
    let tol = 1e-6;
    let state_well_scaled = WorkingState {
        eta: LinearPredictor::new(Array1::zeros(n)),
        gradient: Array1::zeros(p),
        hessian: crate::linalg::matrix::SymmetricMatrix::Dense(Array2::zeros((p, p))),
        log_likelihood: 0.0,
        deviance: 0.0,
        penalty_term: 0.0,
        firth: FirthDiagnostics::Inactive,
        ridge_used: 0.0,
        hessian_curvature: HessianCurvatureKind::Fisher,
        gradient_natural_scale: 1.0e6,
    };
    assert!(state_well_scaled.certifies_kkt(0.99e-6 * (1.0 + 1.0e6), tol));
    let state_unscaled = WorkingState {
        eta: LinearPredictor::new(Array1::zeros(n)),
        gradient: Array1::zeros(p),
        hessian: crate::linalg::matrix::SymmetricMatrix::Dense(Array2::zeros((p, p))),
        log_likelihood: 0.0,
        deviance: 0.0,
        penalty_term: 0.0,
        firth: FirthDiagnostics::Inactive,
        ridge_used: 0.0,
        hessian_curvature: HessianCurvatureKind::Fisher,
        gradient_natural_scale: 0.0,
    };
    assert!(state_unscaled.certifies_kkt(2.0e-6, tol));
}
// With natural scale 99 the certified threshold is tol*(1+99) = 1e-4; the
// near-stationary band is 10x that (1e-3), so 9.9e-4 is "near" but not
// certified, and 2e-3 is neither.
#[test]
fn near_stationary_kkt_uses_ten_times_band() {
    let n = 100usize;
    let p = 4usize;
    let tol = 1e-6;
    let state = WorkingState {
        eta: LinearPredictor::new(Array1::zeros(n)),
        gradient: Array1::zeros(p),
        hessian: crate::linalg::matrix::SymmetricMatrix::Dense(Array2::zeros((p, p))),
        log_likelihood: 0.0,
        deviance: 0.0,
        penalty_term: 0.0,
        firth: FirthDiagnostics::Inactive,
        ridge_used: 0.0,
        hessian_curvature: HessianCurvatureKind::Fisher,
        gradient_natural_scale: 99.0,
    };
    assert!(state.near_stationary_kkt(9.9e-4, tol));
    assert!(!state.near_stationary_kkt(2.0e-3, tol));
    assert!(!state.certifies_kkt(9.9e-4, tol));
}
// With LM damping lambda_lm added to each Hessian eigenvalue, scaling the
// damped Newton decrement by (1 + lambda_lm/lambda_min) must upper-bound the
// true decrement without being more than 2x loose.
#[test]
fn newton_decrement_correction_upper_bounds_true_decrement() {
    let lambda_min = 0.5_f64;
    let lambda_lm = 0.25_f64;
    // Gradient components paired with Hessian eigenvalues {2.0, 0.5}.
    let g = [1.0_f64, 1.0];
    let eigenvalues = [2.0_f64, 0.5];
    let true_decrement_sq: f64 = g
        .iter()
        .zip(eigenvalues.iter())
        .map(|(&gi, &ev)| gi.powi(2) / ev)
        .sum();
    let damped_decrement_sq: f64 = g
        .iter()
        .zip(eigenvalues.iter())
        .map(|(&gi, &ev)| gi.powi(2) / (ev + lambda_lm))
        .sum();
    let correction = 1.0 + lambda_lm / lambda_min;
    let upper_bound = damped_decrement_sq * correction;
    assert!(
        upper_bound >= true_decrement_sq,
        "(1 + λ_lm/λ_min)·damped must upper-bound true decrement: \
upper={:.6} true={:.6}",
        upper_bound,
        true_decrement_sq,
    );
    assert!(
        upper_bound <= 2.0 * true_decrement_sq,
        "correction should not be wildly loose: upper={:.6} true={:.6}",
        upper_bound,
        true_decrement_sq,
    );
}
// At stationarity both the predicted and actual reductions fall below the
// relative noise floor; the LM gain ratio must then resolve to an acceptance
// (rho = 1) rather than a hard rejection (rho = -1).
#[test]
fn lm_gain_ratio_accepts_zero_step_at_stationarity() {
    let current_penalized: f64 = 9e5;
    let predicted_reduction: f64 = 5e-16;
    let actual_reduction: f64 = -1e-14;
    // Noise floor is relative to the current penalized objective.
    let noise_floor = current_penalized.abs().max(1.0) * 1e-14;
    let rho = match () {
        _ if predicted_reduction > noise_floor => actual_reduction / predicted_reduction,
        _ if actual_reduction >= -noise_floor => 1.0,
        _ => -1.0,
    };
    assert!(
        rho > 0.0,
        "near-zero reductions should not hard-reject; rho={:.1}, pred={:.2e}, actual={:.2e}, noise={:.2e}",
        rho,
        predicted_reduction,
        actual_reduction,
        noise_floor
    );
}
#[test]
fn candidate_evaluation_errors_respect_lm_exhaustion_budget() {
// Retriable candidate-evaluation failures must consume the LM retry budget
// (max_step_halving) and then surface as PIRLS non-convergence, with the
// instrumented model recording the exact curvature-fallback sequence.
let mut model = CandidateEvalFailureModel::default();
let options = WorkingModelPirlsOptions {
max_iterations: 1,
convergence_tolerance: 1e-8,
max_step_halving: 5,
min_step_size: 0.0,
firth_bias_reduction: false,
coefficient_lower_bounds: None,
linear_constraints: None,
initial_lm_lambda: None,
};
let err = match runworking_model_pirls(
&mut model,
Coefficients::new(array![0.0]),
&options,
|_| {},
) {
Ok(_) => panic!("candidate evaluation failures should exhaust LM retries and surface"),
Err(err) => err,
};
// The exhausted budget must be reported as non-convergence at the
// configured iteration cap, with a finite positive last change.
match err {
EstimationError::PirlsDidNotConverge {
max_iterations,
last_change,
} => {
assert!(
max_iterations == options.max_iterations,
"expected LM exhaustion to surface as PIRLS non-convergence with screening cap"
);
assert!(last_change.is_finite() && last_change > 0.0);
}
other => {
panic!("expected PirlsDidNotConverge from candidate evaluation, got {other:?}")
}
}
// Call accounting on the instrumented model: one observed-curvature start,
// exactly one observed->Fisher fallback after the observed candidate fails
// once, then Fisher retries up to (but not beyond) the LM budget.
assert_eq!(
model.observed_updates, 1,
"the PIRLS iteration should start on observed curvature once"
);
assert_eq!(
model.fisher_updates, 1,
"candidate failure should trigger exactly one observed->Fisher fallback"
);
assert_eq!(
model.observed_candidate_calls, 1,
"observed candidate evaluation should fail once before the Fisher fallback"
);
assert_eq!(
model.fisher_candidate_calls,
options.max_step_halving - 1,
"Fisher candidate evaluation must stop at the configured LM retry budget"
);
}
#[test]
fn permanent_candidate_errors_do_not_trigger_lm_retries() {
    // A candidate evaluation that fails with a non-retriable error must
    // surface immediately instead of being retried under stronger damping.
    let mut model = PermanentCandidateErrorModel::default();
    let options = WorkingModelPirlsOptions {
        max_iterations: 1,
        convergence_tolerance: 1e-8,
        max_step_halving: 5,
        min_step_size: 0.0,
        firth_bias_reduction: false,
        coefficient_lower_bounds: None,
        linear_constraints: None,
        initial_lm_lambda: None,
    };
    let outcome = runworking_model_pirls(
        &mut model,
        Coefficients::new(array![0.0]),
        &options,
        |_| {},
    );
    let Err(err) = outcome else {
        panic!("permanent candidate failures should surface immediately")
    };
    // The surfaced error must be the permanent failure itself.
    match err {
        EstimationError::InvalidSpecification(message) => {
            assert!(
                message.contains("permanent candidate failure"),
                "expected permanent candidate failure, got {message}"
            );
        }
        other => panic!("expected InvalidSpecification, got {other:?}"),
    }
    // Exactly one candidate evaluation: no damped re-evaluation may occur.
    assert_eq!(
        model.candidate_calls, 1,
        "non-retriable candidate failures should not be re-evaluated under stronger damping"
    );
}
#[test]
fn firth_candidate_reevaluation_respects_lm_retry_budget() {
// With Firth bias reduction on, accepted-state reevaluation failures must
// retry only up to the LM budget (max_step_halving) and then surface as
// PIRLS non-convergence rather than looping indefinitely.
let mut model = FirthAcceptedStateFailureModel::default();
let options = WorkingModelPirlsOptions {
max_iterations: 1,
convergence_tolerance: 1e-8,
max_step_halving: 4,
min_step_size: 0.0,
firth_bias_reduction: true,
coefficient_lower_bounds: None,
linear_constraints: None,
initial_lm_lambda: None,
};
let err = match runworking_model_pirls(
&mut model,
Coefficients::new(array![0.0]),
&options,
|_| {},
) {
Ok(_) => panic!("Firth candidate reevaluation failures should not loop indefinitely"),
Err(err) => err,
};
// The exhausted retry budget surfaces as non-convergence at the cap.
match err {
EstimationError::PirlsDidNotConverge {
max_iterations,
last_change,
} => {
assert_eq!(max_iterations, options.max_iterations);
assert!(last_change.is_finite() && last_change > 0.0);
}
other => panic!("expected PirlsDidNotConverge, got {other:?}"),
}
// Call accounting: one current-state evaluation, then both the screening
// pass and the Firth accepted-state reevaluation retry exactly as many
// times as the LM budget allows.
assert_eq!(model.current_state_calls, 1);
assert_eq!(
model.candidate_screen_calls, options.max_step_halving,
"screening pass should retry until the LM budget is exhausted"
);
assert_eq!(
model.candidate_state_calls, options.max_step_halving,
"Firth accepted-state reevaluation must stop at the configured LM retry budget"
);
}
#[test]
fn plateaued_accepted_step_does_not_report_converged_with_large_projected_gradient() {
    // An accepted step that barely improves the deviance while the projected
    // gradient (5e-5) is still far above the near-stationary band must run out
    // of iterations as MaxIterationsReached, not be promoted to a terminal
    // Converged/Stalled status.
    let mut model = PlateauStatusModel {
        gradient: 5e-5,
        current_deviance: 1.0,
        candidate_deviance: 1.0 - 1.25e-9,
    };
    let options = WorkingModelPirlsOptions {
        max_iterations: 1,
        convergence_tolerance: 1e-6,
        max_step_halving: 4,
        min_step_size: 0.0,
        firth_bias_reduction: false,
        coefficient_lower_bounds: None,
        linear_constraints: None,
        initial_lm_lambda: None,
    };
    let summary = runworking_model_pirls(
        &mut model,
        Coefficients::new(array![0.0]),
        &options,
        |_| {},
    )
    .expect("plateaued accepted step should still return a final state");
    assert_eq!(
        summary.status,
        PirlsStatus::MaxIterationsReached,
        "projected gradient 5e-5 is well above the near-stationary band and must not be promoted to Converged/Stalled — the candidate step is accepted but the outer iteration counter must run out as MaxIterationsReached, not be silently re-classified"
    );
}
#[test]
fn rejected_noise_scale_step_requires_near_stationary_projected_gradient() {
    // A candidate rejected at noise scale while the projected gradient (2e-5)
    // exceeds the near-stationary band must exit via LmStepSearchExhausted,
    // preserving the current state rather than accepting or running out the
    // iteration counter.
    let mut model = PlateauStatusModel {
        gradient: 2e-5,
        current_deviance: 1.0e6,
        candidate_deviance: 1.0e6 + 1.0,
    };
    let options = WorkingModelPirlsOptions {
        max_iterations: 1,
        convergence_tolerance: 1e-6,
        max_step_halving: 1,
        min_step_size: 0.0,
        firth_bias_reduction: false,
        coefficient_lower_bounds: None,
        linear_constraints: None,
        initial_lm_lambda: None,
    };
    let outcome = runworking_model_pirls(
        &mut model,
        Coefficients::new(array![0.0]),
        &options,
        |_| {},
    )
    .expect("noise-scale rejected step should still preserve the current state");
    assert_eq!(
        outcome.status,
        PirlsStatus::LmStepSearchExhausted,
        "projected gradient 2e-5 exceeds the near-stationary band and must hit the LM-exhaust exit, not be accepted after a noise-scale rejection or fall through to MaxIterationsReached"
    );
}
// Panics unless `trace` has at least two entries and each successive deviance
// is no larger than its predecessor, up to a small relative+absolute
// floating-point tolerance. `label` prefixes every failure message.
fn assert_deviance_monotone(trace: &[f64], label: &str) {
    assert!(
        trace.len() >= 2,
        "{}: expected at least 2 deviance recordings, got {}",
        label,
        trace.len()
    );
    // Walk consecutive pairs; allow only tolerance-sized increases.
    for (i, pair) in trace.windows(2).enumerate() {
        let (prev, curr) = (pair[0], pair[1]);
        // Mixed tolerance: 1e-8 relative to the previous value plus a tiny
        // absolute floor for values near zero.
        let tol = 1e-8 * prev.abs() + 1e-12;
        assert!(
            curr <= prev + tol,
            "{}: deviance increased at iteration {} -> {}: {:.12e} -> {:.12e} (delta = {:.3e})",
            label,
            i,
            i + 1,
            prev,
            curr,
            curr - prev,
        );
    }
}
#[test]
fn test_deviance_monotonicity_gaussian() {
// Fit a tiny Gaussian/identity model and assert the captured penalized
// deviance trace never increases across P-IRLS iterations.
let n = 20;
let mut x_data = Array2::<f64>::zeros((n, 2));
let mut y = Array1::<f64>::zeros(n);
for i in 0..n {
let t = i as f64 / (n - 1) as f64;
// Intercept + linear trend with a small deterministic pseudo-noise term.
x_data[[i, 0]] = 1.0; x_data[[i, 1]] = t; y[i] = 3.0 + 2.0 * t + 0.3 * (((i * 17 + 5) % 11) as f64 / 11.0 - 0.5);
}
let w = Array1::ones(n);
let offset = Array1::zeros(n);
// One penalty at rho = 0 whose root shrinks only the slope coefficient.
let rho = array![0.0]; let rs = vec![array![[0.0, 0.0], [0.0, 1.0]]];
// Wrap each penalty root R as a CanonicalPenalty with local S = R'R.
let canonical: Vec<crate::construction::CanonicalPenalty> = rs
.iter()
.map(|r| {
let local = r.t().dot(r);
crate::construction::CanonicalPenalty {
root: r.clone(),
col_range: 0..r.ncols(),
total_dim: r.ncols(),
nullity: 0,
local,
positive_eigenvalues: Vec::new(),
op: None,
}
})
.collect();
let config = PirlsConfig {
likelihood: GlmLikelihoodSpec::canonical(GlmLikelihoodFamily::GaussianIdentity),
link_kind: InverseLink::Standard(LinkFunction::Identity),
max_iterations: 100,
convergence_tolerance: 1e-8,
firth_bias_reduction: false,
initial_lm_lambda: None,
};
// Run the fit while test_support records the penalized deviance per iteration.
let (result, trace) = super::test_support::capture_pirls_penalized_deviance(|| {
fit_model_for_fixed_rho(
LogSmoothingParamsView::new(rho.view()),
PirlsProblem {
x: x_data.view(),
offset: offset.view(),
y: y.view(),
priorweights: w.view(),
covariate_se: None,
},
PenaltyConfig {
canonical_penalties: &canonical,
balanced_penalty_root: None,
reparam_invariant: None,
p: 2,
coefficient_lower_bounds: None,
linear_constraints_original: None,
penalty_shrinkage_floor: None,
kronecker_factored: None,
},
&config,
None,
)
});
result.expect("Gaussian P-IRLS fit should succeed");
// A Gaussian/identity fit may finish in a single step, leaving fewer than
// two recordings; monotonicity is vacuous in that case, so skip the check.
if trace.len() < 2 {
return;
}
assert_deviance_monotone(&trace, "Gaussian");
}
#[test]
fn test_deviance_monotonicity_logistic() {
// Fit a small logistic model on deterministic pseudo-random binary data and
// assert the captured penalized deviance trace is monotonically decreasing.
let n = 30;
let mut x_data = Array2::<f64>::zeros((n, 2));
let mut y = Array1::<f64>::zeros(n);
for i in 0..n {
// Covariate t sweeps [-2, 2]; design is intercept + t.
let t = (i as f64 / (n - 1) as f64) * 4.0 - 2.0; x_data[[i, 0]] = 1.0;
x_data[[i, 1]] = t;
let eta = 0.5 + 1.5 * t;
let p = 1.0 / (1.0 + (-eta).exp());
// Deterministic pseudo-uniform thresholded at p generates the response.
let pseudo_random = ((i * 31 + 7) % 17) as f64 / 17.0;
y[i] = if pseudo_random < p { 1.0 } else { 0.0 };
}
let w = Array1::ones(n);
let offset = Array1::zeros(n);
// One penalty at rho = 0 whose root shrinks only the slope coefficient.
let rho = array![0.0];
let rs = vec![array![[0.0, 0.0], [0.0, 1.0]]];
// Wrap each penalty root R as a CanonicalPenalty with local S = R'R.
let canonical: Vec<crate::construction::CanonicalPenalty> = rs
.iter()
.map(|r| {
let local = r.t().dot(r);
crate::construction::CanonicalPenalty {
root: r.clone(),
col_range: 0..r.ncols(),
total_dim: r.ncols(),
nullity: 0,
local,
positive_eigenvalues: Vec::new(),
op: None,
}
})
.collect();
let config = PirlsConfig {
likelihood: GlmLikelihoodSpec::canonical(GlmLikelihoodFamily::BinomialLogit),
link_kind: InverseLink::Standard(LinkFunction::Logit),
max_iterations: 100,
convergence_tolerance: 1e-8,
firth_bias_reduction: false,
initial_lm_lambda: None,
};
// Run the fit while test_support records the penalized deviance per iteration.
let (result, trace) = super::test_support::capture_pirls_penalized_deviance(|| {
fit_model_for_fixed_rho(
LogSmoothingParamsView::new(rho.view()),
PirlsProblem {
x: x_data.view(),
offset: offset.view(),
y: y.view(),
priorweights: w.view(),
covariate_se: None,
},
PenaltyConfig {
canonical_penalties: &canonical,
balanced_penalty_root: None,
reparam_invariant: None,
p: 2,
coefficient_lower_bounds: None,
linear_constraints_original: None,
penalty_shrinkage_floor: None,
kronecker_factored: None,
},
&config,
None,
)
});
result.expect("Logistic P-IRLS fit should succeed");
assert_deviance_monotone(&trace, "Logistic");
}
#[test]
fn test_deviance_monotonicity_logistic_multiseed() {
// Repeat the logistic monotonicity check over several deterministic seeds
// and two penalties so that different data/penalty interactions are covered.
let seeds: &[u64] = &[42, 137, 271, 314, 997];
let n = 25;
for &seed in seeds {
let mut x_data = Array2::<f64>::zeros((n, 3));
let mut y = Array1::<f64>::zeros(n);
for i in 0..n {
// t1 sweeps [-3, 3]; t2 is a seed-dependent deterministic covariate.
let t1 = (i as f64 / (n - 1) as f64) * 6.0 - 3.0;
let t2 =
((i as u64).wrapping_mul(seed).wrapping_add(13) % 100) as f64 / 100.0 - 0.5;
x_data[[i, 0]] = 1.0;
x_data[[i, 1]] = t1;
x_data[[i, 2]] = t2;
let eta = -0.3 + 1.0 * t1 + 0.8 * t2;
let p = 1.0 / (1.0 + (-eta).exp());
// Knuth-style multiplicative hash gives a deterministic pseudo-uniform.
let hash = (i as u64)
.wrapping_mul(seed)
.wrapping_add(seed >> 2)
.wrapping_mul(2654435761);
let pseudo_uniform = (hash % 10000) as f64 / 10000.0;
y[i] = if pseudo_uniform < p { 1.0 } else { 0.0 };
}
// Guard against a degenerate all-zero or all-one response (both checks use
// the pre-fix count, so at most one branch fires).
let ones: f64 = y.iter().sum();
if ones < 1.0 {
y[0] = 1.0;
}
if ones > (n as f64 - 1.0) {
y[n - 1] = 0.0;
}
let w = Array1::ones(n);
let offset = Array1::zeros(n);
// Two penalties at rho = 0, each shrinking one of the two slopes.
let rho = array![0.0, 0.0];
let rs = vec![
array![[0.0, 0.0, 0.0], [0.0, 1.0, 0.0], [0.0, 0.0, 0.0]],
array![[0.0, 0.0, 0.0], [0.0, 0.0, 0.0], [0.0, 0.0, 1.0]],
];
// Wrap each penalty root R as a CanonicalPenalty with local S = R'R.
let canonical: Vec<crate::construction::CanonicalPenalty> = rs
.iter()
.map(|r| {
let local = r.t().dot(r);
crate::construction::CanonicalPenalty {
root: r.clone(),
col_range: 0..r.ncols(),
total_dim: r.ncols(),
nullity: 0,
local,
positive_eigenvalues: Vec::new(),
op: None,
}
})
.collect();
let config = PirlsConfig {
likelihood: GlmLikelihoodSpec::canonical(GlmLikelihoodFamily::BinomialLogit),
link_kind: InverseLink::Standard(LinkFunction::Logit),
max_iterations: 100,
convergence_tolerance: 1e-8,
firth_bias_reduction: false,
initial_lm_lambda: None,
};
// Run the fit while test_support records the penalized deviance per iteration.
let (result, trace) = super::test_support::capture_pirls_penalized_deviance(|| {
fit_model_for_fixed_rho(
LogSmoothingParamsView::new(rho.view()),
PirlsProblem {
x: x_data.view(),
offset: offset.view(),
y: y.view(),
priorweights: w.view(),
covariate_se: None,
},
PenaltyConfig {
canonical_penalties: &canonical,
balanced_penalty_root: None,
reparam_invariant: None,
p: 3,
coefficient_lower_bounds: None,
linear_constraints_original: None,
penalty_shrinkage_floor: None,
kronecker_factored: None,
},
&config,
None,
)
});
result.unwrap_or_else(|e| {
panic!("Logistic P-IRLS fit failed for seed {}: {:?}", seed, e)
});
assert_deviance_monotone(&trace, &format!("Logistic(seed={})", seed));
}
}
#[test]
fn solve_newton_direction_implicit_matches_dense_at_k500() {
// Cross-check: the matrix-free (implicit/PCG) Newton-direction solver must
// agree with the dense Cholesky solver on a K=500 synthetic problem built
// around a closed-form penalty operator.
use crate::terms::closed_form_operator::ClosedFormPenaltyOperator;
use crate::terms::penalty_op::PenaltyOp;
const K: usize = 500;
const D: usize = 4;
// Deterministic 64-bit LCG; the top 53 bits of the state map to [0, 1).
let mut state: u64 = 0xDEADBEEF_CAFEBABE;
let mut next = || {
state = state
.wrapping_mul(6364136223846793005)
.wrapping_add(1442695040888963407);
((state >> 11) as f64) / ((1u64 << 53) as f64)
};
// Pseudo-random centers in [0,1)^D feeding the penalty operator.
let mut centers = Array2::<f64>::zeros((K, D));
for i in 0..K {
for j in 0..D {
centers[[i, j]] = next();
}
}
let op = std::sync::Arc::new(ClosedFormPenaltyOperator::new(
centers.view(),
2,
2,
1,
1.0,
None,
None,
0,
None,
));
let p = op.dim();
assert_eq!(p, K);
let s_dense = op.as_dense();
// Synthetic symmetric X'WX: diagonal near 2 with a small oscillation, plus
// weak decaying off-diagonal coupling (chosen so the expected curvature
// stays SPD for both solve paths — see the expect messages below).
let mut xtwx = Array2::<f64>::zeros((p, p));
for i in 0..p {
for j in 0..=i {
let v = if i == j {
2.0 + ((i as f64) * 0.07).sin() * 0.3
} else {
(((i as f64 - j as f64) * 0.13).cos()) * 0.02 / (((i + 1) as f64).sqrt())
};
xtwx[[i, j]] = v;
xtwx[[j, i]] = v;
}
}
// Diagonal of X'WX, used as the implicit solver's preconditioner input.
let xtwx_diag: Array1<f64> = (0..p).map(|i| xtwx[[i, i]]).collect();
let lambda = 0.1_f64;
let ridge = 0.0_f64;
// Smooth, nonzero gradient so the Newton direction is nontrivial.
let gradient = Array1::<f64>::from_shape_fn(p, |i| ((i as f64) * 0.31).sin());
// Reference path: materialize H = X'WX + lambda * S and solve densely.
let mut h_dense = xtwx.clone();
for i in 0..p {
for j in 0..p {
h_dense[[i, j]] += lambda * s_dense[[i, j]];
}
}
let mut dense_dir = Array1::<f64>::zeros(p);
super::solve_newton_direction_dense(&h_dense, &gradient, &mut dense_dir)
.expect("dense Newton solve should succeed on synthetic SPD");
// Implicit path: provide X'WX only as a matvec closure plus the penalty op.
let xtwx_for_closure = xtwx.clone();
let apply_xtwx = move |v: &Array1<f64>| -> Array1<f64> { xtwx_for_closure.dot(v) };
let op_pen: &dyn PenaltyOp = op.as_ref();
let mut implicit_dir = Array1::<f64>::zeros(p);
super::solve_newton_direction_implicit(
apply_xtwx,
xtwx_diag.view(),
&[],
&[(lambda, op_pen)],
&gradient,
&mut implicit_dir,
ridge,
1e-12,
4 * p,
)
.expect("implicit Newton solve should succeed on synthetic SPD");
// Compare directions in relative l2 norm; the tight 1e-9 bound requires the
// implicit solve to converge essentially to the dense solution.
let dense_norm: f64 = dense_dir.iter().map(|v| v * v).sum::<f64>().sqrt();
let mut diff_sq = 0.0_f64;
for i in 0..p {
let d = implicit_dir[i] - dense_dir[i];
diff_sq += d * d;
}
let rel = diff_sq.sqrt() / dense_norm.max(1e-300);
assert!(
rel < 1e-9,
"implicit-PCG vs dense-Cholesky Newton direction relative diff {} exceeds 1e-9",
rel
);
}
// Test double: iterates like a Fisher-curvature model but, when queried with
// Observed curvature, returns a trivial (SPD-at-mode) state while counting the
// request — lets the test assert that finalization queried observed curvature.
#[derive(Default)]
struct InnerFisherButObservedSpdAtMode {
// Number of update_with_curvature calls made with Observed curvature.
observed_post_calls: usize,
}
impl WorkingModel for InnerFisherButObservedSpdAtMode {
// Default update path delegates to the Fisher-curvature variant.
fn update(&mut self, beta: &Coefficients) -> Result<WorkingState, EstimationError> {
self.update_with_curvature(beta, HessianCurvatureKind::Fisher)
}
// Returns a trivial scalar working state for any curvature kind, but counts
// every Observed-curvature request so the test can verify post-convergence
// finalization asked for observed information.
fn update_with_curvature(
&mut self,
beta: &Coefficients,
curvature: HessianCurvatureKind,
) -> Result<WorkingState, EstimationError> {
if curvature == HessianCurvatureKind::Observed {
self.observed_post_calls += 1;
}
Ok(scalar_working_state(beta, curvature, 0.0, 0.0))
}
// Candidate evaluation always succeeds with the same trivial state.
fn update_candidate(
&mut self,
beta: &Coefficients,
curvature: HessianCurvatureKind,
) -> Result<WorkingState, EstimationError> {
Ok(scalar_working_state(beta, curvature, 0.0, 0.0))
}
// Opt in to observed-information curvature so the driver exercises the
// Observed finalization path.
fn supports_observed_information_curvature(&self) -> bool {
true
}
}
#[test]
fn exported_laplace_observed_exact_when_post_finalization_spd() {
    // When the model reports an SPD observed-information state at the
    // converged mode, the exported Laplace curvature must be ObservedExact.
    let mut model = InnerFisherButObservedSpdAtMode::default();
    let options = WorkingModelPirlsOptions {
        max_iterations: 2,
        convergence_tolerance: 1e-8,
        max_step_halving: 3,
        min_step_size: 0.0,
        firth_bias_reduction: false,
        coefficient_lower_bounds: None,
        linear_constraints: None,
        initial_lm_lambda: None,
    };
    let summary = runworking_model_pirls(
        &mut model,
        Coefficients::new(array![0.0]),
        &options,
        |_| {},
    )
    .expect("converged scalar model should produce a result");
    let is_observed_exact = matches!(
        summary.exported_laplace_curvature,
        ExportedLaplaceCurvature::ObservedExact
    );
    assert!(
        is_observed_exact,
        "post-convergence Observed-SPD must export ObservedExact, got {:?}",
        summary.exported_laplace_curvature
    );
    // Finalization must actually have queried observed curvature.
    assert!(
        model.observed_post_calls >= 1,
        "post-convergence finalization must call update_with_curvature(Observed) \
         at least once to assert SPD inertia"
    );
}
// Test double for a model that does not override
// supports_observed_information_curvature — presumably the trait default is
// false (TODO confirm at the trait definition), so the driver must fall back
// to the expected-information surrogate.
#[derive(Default)]
struct CanonicalSurrogateModel;
impl WorkingModel for CanonicalSurrogateModel {
// Default update path delegates to the Fisher-curvature variant.
fn update(&mut self, beta: &Coefficients) -> Result<WorkingState, EstimationError> {
self.update_with_curvature(beta, HessianCurvatureKind::Fisher)
}
// Always returns a trivial scalar working state for the requested curvature.
fn update_with_curvature(
&mut self,
beta: &Coefficients,
curvature: HessianCurvatureKind,
) -> Result<WorkingState, EstimationError> {
Ok(scalar_working_state(beta, curvature, 0.0, 0.0))
}
// Candidate evaluation always succeeds with the same trivial state.
// Note: supports_observed_information_curvature is deliberately NOT
// overridden here, so the trait default governs the export path.
fn update_candidate(
&mut self,
beta: &Coefficients,
curvature: HessianCurvatureKind,
) -> Result<WorkingState, EstimationError> {
Ok(scalar_working_state(beta, curvature, 0.0, 0.0))
}
}
#[test]
fn exported_laplace_surrogate_when_observed_unsupported() {
// A model that never opts into observed information must have its exported
// Laplace curvature labeled ExpectedInformationSurrogate; the driver may
// not silently relabel the Fisher surrogate as ObservedExact.
let mut model = CanonicalSurrogateModel;
let options = WorkingModelPirlsOptions {
max_iterations: 2,
convergence_tolerance: 1e-8,
max_step_halving: 3,
min_step_size: 0.0,
firth_bias_reduction: false,
coefficient_lower_bounds: None,
linear_constraints: None,
initial_lm_lambda: None,
};
let summary =
runworking_model_pirls(&mut model, Coefficients::new(array![0.0]), &options, |_| {})
.expect("canonical surrogate model should converge");
assert!(
matches!(
summary.exported_laplace_curvature,
ExportedLaplaceCurvature::ExpectedInformationSurrogate
),
"model that doesn't support observed information must export \
ExpectedInformationSurrogate (no silent ObservedExact relabel), \
got {:?}",
summary.exported_laplace_curvature
);
}
}