use convergence::effective_kkt_tolerance;
use damping::{
add_scaled_diagonal_to_upper_sparse, compute_lm_d2, update_scaled_diagonal_in_place,
};
pub use edf::StablePLSResult;
use edf::{
calculate_edf_from_sparse_factor, calculate_edf_with_penalty,
calculate_edfwithworkspace_from_factor, calculate_edfwithworkspace_with_penalty,
};
use log_link_working_state::ETA_CLAMP;
pub(crate) use log_link_working_state::MIN_WEIGHT;
use penalty::{
KroneckerQsTransform, PirlsPenalty, WorkingCoordinateDesign, WorkingReparamTransform,
attach_penalty_shift,
};
use pls_solver::solve_penalized_least_squares_implicit;
pub use pls_solver::{GaussianFixedCache, SparseXtwxPrecomputed};
pub use reweight::runworking_model_pirls;
pub(crate) use state::array1_l2_norm;
pub use state::{
AdaptiveKktTolerance, ExportedLaplaceCurvature, FirthDiagnostics, HessianCurvatureKind,
PirlsCoordinateFrame, PirlsLinearSolvePath, PirlsResult, PirlsStatus,
WorkingModelIterationInfo, WorkingModelPirlsResult, WorkingState,
};
/// Lower bound of the bisection bracket used by `estimate_gamma_shape_from_eta`.
const GAMMA_SHAPE_MIN: f64 = 1e-8;
/// Upper cap for the Gamma shape estimate. It is also returned when the data show
/// (numerically) no dispersion, or when the root lies above the cap.
const GAMMA_SHAPE_MAX: f64 = 1e12;
/// Tolerance for two things: treating the shape-equation target as zero, and the
/// relative stopping rule of the shape bisection.
const GAMMA_SHAPE_TARGET_TOL: f64 = 1e-12;
/// Absolute cap on the PIRLS linear predictor magnitude. It is not used in this
/// chunk; presumably it clamps |eta| elsewhere (TODO confirm at use sites).
pub(super) const PIRLS_ETA_ABS_CAP: f64 = 40.0;
/// Residual of the Gamma shape equation `ln(k) - ψ(k) = target`.
///
/// `ln(k) - ψ(k)` decreases in `k`. The residual is therefore positive when `shape`
/// lies below the root and negative above it, which is what the bisection in
/// `estimate_gamma_shape_from_eta` relies on.
#[inline]
fn gamma_shape_score(shape: f64, target: f64) -> f64 {
    let log_shape = shape.ln();
    let psi = digamma(shape);
    log_shape - psi - target
}
/// Maximum-likelihood estimate of the Gamma shape `k` (log link), with the means
/// held fixed at `mu = exp(eta)`.
///
/// With fixed means the shape score equation reduces to `ln(k) - ψ(k) = target`.
/// Here `target` is the prior-weighted mean of `y/mu - ln(y/mu) - 1`, i.e. half the
/// weighted mean Gamma unit deviance. The root is bracketed starting from the
/// closed-form approximation `k ≈ (3 - s + sqrt((s - 3)² + 24 s)) / (12 s)` and
/// then refined by bisection (at most 80 steps).
///
/// Returns:
/// * `1.0` when no row carries positive prior weight;
/// * `GAMMA_SHAPE_MAX` when `target` is (numerically) zero or the root lies above
///   the cap;
/// * otherwise the bisection midpoint.
fn estimate_gamma_shape_from_eta(
    y: ArrayView1<'_, f64>,
    eta: &Array1<f64>,
    priorweights: ArrayView1<'_, f64>,
) -> f64 {
    // Floor for y and mu so that ln(y/mu) stays finite.
    const EPS: f64 = 1e-12;
    use rayon::iter::{IntoParallelIterator, ParallelIterator};
    let (weighted_target, total_weight) = (0..eta.len())
        .into_par_iter()
        .map(|i| {
            // Negative weights are clamped to 0. `f64::max` also maps a NaN weight to 0.
            let wi = priorweights[i].max(0.0);
            if wi == 0.0 {
                return (0.0_f64, 0.0_f64);
            }
            let yi = y[i].max(EPS);
            let mui = eta[i].clamp(-ETA_CLAMP, ETA_CLAMP).exp().max(EPS);
            let ratio = yi / mui;
            (wi * (ratio - ratio.ln() - 1.0), wi)
        })
        .reduce(
            || (0.0_f64, 0.0_f64),
            |(t1, w1), (t2, w2)| (t1 + t2, w1 + w2),
        );
    if total_weight <= 0.0 {
        return 1.0;
    }
    // `.max(0.0)` also turns a NaN mean into 0, which then takes the
    // "no dispersion" exit below.
    let target = (weighted_target / total_weight).max(0.0);
    if target <= GAMMA_SHAPE_TARGET_TOL {
        return GAMMA_SHAPE_MAX;
    }
    let discriminant = (target - 3.0) * (target - 3.0) + 24.0 * target;
    let approx = ((3.0 - target) + discriminant.sqrt()) / (12.0 * target);
    let mut lo = GAMMA_SHAPE_MIN;
    let mut hi = approx.max(1.0);
    // Grow `hi` until the residual turns non-positive (root bracketed) or the cap is hit.
    while hi < GAMMA_SHAPE_MAX && gamma_shape_score(hi, target) > 0.0 {
        hi = (hi * 2.0).min(GAMMA_SHAPE_MAX);
    }
    if gamma_shape_score(hi, target) > 0.0 {
        return GAMMA_SHAPE_MAX;
    }
    for _ in 0..80 {
        let mid = 0.5 * (lo + hi);
        // Positive residual: the root lies above `mid`.
        if gamma_shape_score(mid, target) > 0.0 {
            lo = mid;
        } else {
            hi = mid;
        }
        // Relative stopping rule (absolute when hi < 1).
        if (hi - lo) <= GAMMA_SHAPE_TARGET_TOL * hi.max(1.0) {
            break;
        }
    }
    0.5 * (lo + hi)
}
/// Moment (Pearson) estimate of the Beta-regression precision `phi` (logit link),
/// with the means held fixed at `mu = logistic(eta)`.
///
/// Under `Var(y) = mu (1 - mu) / (1 + phi)`, the weighted Pearson sum
/// `Σ w (y - mu)² / (mu (1 - mu))` is approximately `Σ w / (1 + phi)`. Hence
/// `1 + phi ≈ Σ w / Pearson`. The result is clamped to `[PHI_MIN, PHI_MAX]`.
///
/// Returns the neutral fallback `1.0` in three cases:
/// * no row carries positive prior weight;
/// * the Pearson sum is non-positive;
/// * the Pearson sum is non-finite, e.g. because of a NaN or infinite response.
///
/// The non-finite guard mirrors `estimate_tweedie_phi_from_eta`. Without it, a NaN
/// sum would slip through `f64::max` and silently produce `PHI_MIN`.
fn estimate_beta_phi_from_eta(
    y: ArrayView1<'_, f64>,
    eta: &Array1<f64>,
    priorweights: ArrayView1<'_, f64>,
) -> f64 {
    const PHI_MIN: f64 = 1e-3;
    const PHI_MAX: f64 = 1e6;
    // Keeps mu strictly inside (0, 1) so that the unit variance mu(1 - mu) is positive.
    const MU_EPS: f64 = 1e-9;
    use rayon::iter::{IntoParallelIterator, ParallelIterator};
    let (weighted_pearson, total_weight) = (0..eta.len())
        .into_par_iter()
        .map(|i| {
            let wi = priorweights[i].max(0.0);
            if wi == 0.0 {
                return (0.0_f64, 0.0_f64);
            }
            let mui = (1.0 / (1.0 + (-eta[i].clamp(-ETA_CLAMP, ETA_CLAMP)).exp()))
                .clamp(MU_EPS, 1.0 - MU_EPS);
            let var_unit = mui * (1.0 - mui);
            let resid = y[i] - mui;
            (wi * resid * resid / var_unit, wi)
        })
        .reduce(
            || (0.0_f64, 0.0_f64),
            |(p1, w1), (p2, w2)| (p1 + p2, w1 + w2),
        );
    if total_weight <= 0.0 || !weighted_pearson.is_finite() || weighted_pearson <= 0.0 {
        return 1.0;
    }
    // phi must be positive, so 1 + phi is floored just above 1 before subtracting.
    let one_plus_phi = (total_weight / weighted_pearson).max(1.0 + PHI_MIN);
    (one_plus_phi - 1.0).clamp(PHI_MIN, PHI_MAX)
}
/// Pearson moment estimate of the Tweedie dispersion `phi` (log link, power `p`),
/// with the means held fixed at `mu = exp(eta)`.
///
/// Computes `Σ w (y - mu)² / mu^p` divided by `Σ w` and clamps the result to
/// `[PHI_MIN, PHI_MAX]`. It falls back to `1.0` when there is no positive weight or
/// when the Pearson sum is non-finite or non-positive.
fn estimate_tweedie_phi_from_eta(
    y: ArrayView1<'_, f64>,
    eta: &Array1<f64>,
    priorweights: ArrayView1<'_, f64>,
    p: f64,
) -> f64 {
    use rayon::iter::{IntoParallelIterator, ParallelIterator};
    const PHI_MIN: f64 = 1e-6;
    const PHI_MAX: f64 = 1e12;
    // Smallest mean / unit variance allowed, to avoid dividing by zero.
    const MU_EPS: f64 = 1e-300;

    // Per-row (weighted Pearson contribution, weight).
    let row_contribution = |i: usize| -> (f64, f64) {
        let weight = priorweights[i].max(0.0);
        if weight == 0.0 {
            return (0.0_f64, 0.0_f64);
        }
        let mean = eta[i].clamp(-ETA_CLAMP, ETA_CLAMP).exp().max(MU_EPS);
        let residual = y[i] - mean;
        let variance = mean.powf(p).max(MU_EPS);
        (weight * residual * residual / variance, weight)
    };

    let (pearson_sum, weight_sum) = (0..eta.len())
        .into_par_iter()
        .map(row_contribution)
        .reduce(
            || (0.0_f64, 0.0_f64),
            |(pa, wa), (pb, wb)| (pa + pb, wa + wb),
        );

    let degenerate = weight_sum <= 0.0 || !pearson_sum.is_finite() || pearson_sum <= 0.0;
    if degenerate {
        1.0
    } else {
        (pearson_sum / weight_sum).clamp(PHI_MIN, PHI_MAX)
    }
}
/// Lower bound for the negative-binomial size parameter `theta`.
const NEGBIN_THETA_MIN: f64 = 1e-3;
/// Upper bound for `theta`. It is returned when the likelihood is still increasing
/// at this value.
const NEGBIN_THETA_MAX: f64 = 1e6;
/// Prior-weighted score and observed information of the NB2 log-likelihood with
/// respect to `theta`, with `mu = exp(eta)` held fixed.
///
/// Per row, with `l = lnΓ(y+θ) - lnΓ(θ) + θ ln θ - (y+θ) ln(θ+μ) + const`:
/// * score: `ψ(y+θ) - ψ(θ) + ln θ + 1 - ln(θ+μ) - (y+θ)/(θ+μ)`
/// * info:  `-ψ'(y+θ) + ψ'(θ) - 1/θ + 2/(θ+μ) - (y+θ)/(θ+μ)²`
///
/// `info` is the negated second derivative, so a Newton step is
/// `θ + score / info`. Returns `(Σ w·score, Σ w·info)`.
fn negbin_theta_score_and_info(
    y: ArrayView1<'_, f64>,
    eta: &Array1<f64>,
    priorweights: ArrayView1<'_, f64>,
    theta: f64,
) -> (f64, f64) {
    use rayon::iter::{IntoParallelIterator, ParallelIterator};
    // Row-independent terms, hoisted out of the parallel map.
    let psi_theta = digamma(theta);
    let trigamma_theta = trigamma(theta);
    let ln_theta = theta.ln();
    let inv_theta = 1.0 / theta;
    let (score, info) = (0..eta.len())
        .into_par_iter()
        .map(|i| {
            let wi = priorweights[i].max(0.0);
            if wi == 0.0 {
                return (0.0_f64, 0.0_f64);
            }
            let yi = y[i];
            let mui = eta[i].clamp(-ETA_CLAMP, ETA_CLAMP).exp().max(1e-300);
            let theta_plus_mu = theta + mui;
            let theta_plus_y = theta + yi;
            let s = digamma(yi + theta) - psi_theta + ln_theta + 1.0
                - theta_plus_mu.ln()
                - theta_plus_y / theta_plus_mu;
            let info_row = -trigamma(yi + theta) + trigamma_theta - inv_theta + 2.0 / theta_plus_mu
                - theta_plus_y / (theta_plus_mu * theta_plus_mu);
            (wi * s, wi * info_row)
        })
        .reduce(
            || (0.0_f64, 0.0_f64),
            |(s1, i1), (s2, i2)| (s1 + s2, i1 + i2),
        );
    (score, info)
}
/// Maximum-likelihood estimate of the NB2 size parameter `theta`, with the means held
/// fixed at `mu = exp(eta)`.
///
/// Procedure:
/// 1. Moment start. With Poisson-scaled Pearson residuals, `E[(y-μ)²/μ] ≈ 1 + μ/θ`,
///    so `θ₀ = μ̄ / (ratio - 1)` when the ratio exceeds 1 (otherwise `NEGBIN_THETA_MAX`).
/// 2. Endpoint checks on `[NEGBIN_THETA_MIN, NEGBIN_THETA_MAX]`:
///    * non-finite score at the top: return `1.0`;
///    * non-negative score at the top (likelihood still increasing): return the max;
///    * non-positive or non-finite score at the bottom: return the min.
/// 3. Safeguarded Newton. A Newton step is taken when it stays strictly inside the
///    current sign bracket; otherwise the step bisects. Stops on a relative step
///    below `REL_TOL` or after `MAX_NEWTON_ITERS`.
fn estimate_negbin_theta_from_eta(
    y: ArrayView1<'_, f64>,
    eta: &Array1<f64>,
    priorweights: ArrayView1<'_, f64>,
) -> f64 {
    use rayon::iter::{IntoParallelIterator, ParallelIterator};
    // (Σ w, Σ w·μ, Σ w·(y-μ)²/μ) for the moment-based starting value.
    let (wsum, wmu, wpearson) = (0..eta.len())
        .into_par_iter()
        .map(|i| {
            let wi = priorweights[i].max(0.0);
            if wi == 0.0 {
                return (0.0_f64, 0.0_f64, 0.0_f64);
            }
            let mui = eta[i].clamp(-ETA_CLAMP, ETA_CLAMP).exp().max(1e-300);
            let resid = y[i] - mui;
            (wi, wi * mui, wi * resid * resid / mui)
        })
        .reduce(
            || (0.0_f64, 0.0_f64, 0.0_f64),
            |(a1, b1, c1), (a2, b2, c2)| (a1 + a2, b1 + b2, c1 + c2),
        );
    if wsum <= 0.0 {
        return 1.0;
    }
    let mu_bar = wmu / wsum;
    let pearson_ratio = wpearson / wsum;
    let mut theta = if pearson_ratio > 1.0 + 1e-6 {
        (mu_bar / (pearson_ratio - 1.0)).clamp(NEGBIN_THETA_MIN, NEGBIN_THETA_MAX)
    } else {
        // No detectable overdispersion: start at the Poisson-like end.
        NEGBIN_THETA_MAX
    };
    let (score_hi, _) = negbin_theta_score_and_info(y, eta, priorweights, NEGBIN_THETA_MAX);
    if !score_hi.is_finite() {
        return 1.0;
    }
    if score_hi >= 0.0 {
        return NEGBIN_THETA_MAX;
    }
    let (score_lo, _) = negbin_theta_score_and_info(y, eta, priorweights, NEGBIN_THETA_MIN);
    if !score_lo.is_finite() || score_lo <= 0.0 {
        return NEGBIN_THETA_MIN;
    }
    // Score is positive at `lo` and negative at `hi`: the root is bracketed.
    let mut lo = NEGBIN_THETA_MIN;
    let mut hi = NEGBIN_THETA_MAX;
    theta = theta.clamp(lo, hi);
    const MAX_NEWTON_ITERS: usize = 100;
    const REL_TOL: f64 = 1e-10;
    for _ in 0..MAX_NEWTON_ITERS {
        let (score, info) = negbin_theta_score_and_info(y, eta, priorweights, theta);
        if !score.is_finite() {
            break;
        }
        // Shrink the bracket using the sign of the score.
        if score > 0.0 {
            lo = theta;
        } else {
            hi = theta;
        }
        // Newton step (info = -l''), used only when it lands strictly inside the bracket.
        let next = if info.is_finite() && info > 0.0 {
            let candidate = theta + score / info;
            if candidate > lo && candidate < hi {
                candidate
            } else {
                0.5 * (lo + hi)
            }
        } else {
            0.5 * (lo + hi)
        };
        if (next - theta).abs() <= REL_TOL * theta.max(1.0) {
            theta = next;
            break;
        }
        theta = next;
    }
    theta.clamp(NEGBIN_THETA_MIN, NEGBIN_THETA_MAX)
}
/// Record of which linear-solve path PIRLS chose, together with the sparsity
/// figures that went into the choice. In this chunk it is only used for debug
/// logging (see `log_once`).
#[derive(Clone, Debug)]
pub struct SparsePirlsDecision {
    /// Selected solve path.
    pub path: PirlsLinearSolvePath,
    /// Short static explanation included in the log line.
    pub reason: &'static str,
    /// Presumably the number of coefficients (TODO confirm at construction site).
    pub p: usize,
    /// Presumably the number of stored non-zeros of the design matrix.
    pub nnz_x: usize,
    /// Symbolic non-zero count of XᵀWX, if it was computed.
    pub nnz_xtwx_symbolic: Option<usize>,
    /// Non-zero count of the penalty S_λ.
    pub nnz_s_lambda: usize,
    /// Estimated non-zero count of the penalized Hessian, if available.
    pub nnz_h_est: Option<usize>,
    /// Estimated Hessian density, if available (logged with 4 decimals).
    pub density_h_est: Option<f64>,
}
/// Renders an optional count for log output, using `"na"` for `None`.
fn fmt_opt_usize(v: Option<usize>) -> String {
    match v {
        Some(count) => count.to_string(),
        None => String::from("na"),
    }
}
/// Renders an optional float with four decimals for log output, using `"na"` for `None`.
fn fmt_opt_f64(v: Option<f64>) -> String {
    if let Some(value) = v {
        format!("{value:.4}")
    } else {
        String::from("na")
    }
}
impl SparsePirlsDecision {
    /// Stable log label for the chosen path.
    fn path_str(&self) -> &'static str {
        match self.path {
            PirlsLinearSolvePath::SparseNative => "sparse_native",
            PirlsLinearSolvePath::DenseTransformed => "dense_transformed",
        }
    }

    /// Formats every decision field into a single `key=value` line. This line is
    /// also the de-duplication key used by `log_once`.
    fn format_fields(&self, path: &str) -> String {
        let xtwx_symbolic = fmt_opt_usize(self.nnz_xtwx_symbolic);
        let h_est = fmt_opt_usize(self.nnz_h_est);
        let density = fmt_opt_f64(self.density_h_est);
        format!(
            "path={path} reason={} p={} nnz_x={} nnz_xtwx_symbolic={} nnz_s_lambda={} nnz_h_est={} density_h_est={}",
            self.reason,
            self.p,
            self.nnz_x,
            xtwx_symbolic,
            self.nnz_s_lambda,
            h_est,
            density,
        )
    }

    /// Logs the decision at debug level the first time it is seen.
    ///
    /// Identical repeats are suppressed, except for a summary line whenever the
    /// repeat count reaches a power of two (2, 4, 8, ...).
    fn log_once(&self) {
        let path = self.path_str();
        let key = self.format_fields(path);
        match pirls_decision_repetition_count(key.clone()) {
            1 => {
                log::debug!("[pirls-path] {key}");
            }
            seen if should_log_pirls_decision_summary(seen) => {
                log::debug!(
                    "[pirls-path] repeated path={} reason={} count={} (suppressing identical decisions)",
                    path,
                    self.reason,
                    seen,
                );
            }
            _ => {}
        }
    }
}
/// Increments and returns the process-wide number of times `log_key` has been seen.
///
/// The first call for a given key returns 1. Counts live in a lazily initialised
/// global map behind a mutex.
///
/// This counter only drives diagnostic logging. If another thread panicked while
/// holding the lock, the poisoned guard is recovered rather than propagating a
/// panic into the solver: the map holds plain integers, so it is still usable.
fn pirls_decision_repetition_count(log_key: String) -> usize {
    static PIRLS_DECISION_LOG_COUNTS: OnceLock<Mutex<HashMap<String, usize>>> = OnceLock::new();
    let mut counts = PIRLS_DECISION_LOG_COUNTS
        .get_or_init(|| Mutex::new(HashMap::new()))
        .lock()
        .unwrap_or_else(std::sync::PoisonError::into_inner);
    let count = counts.entry(log_key).or_insert(0);
    *count += 1;
    *count
}
/// Whether a repeated decision should emit a summary line.
///
/// True for repeat counts 2, 4, 8, ... (powers of two greater than one), so
/// summaries thin out logarithmically.
fn should_log_pirls_decision_summary(repetition_count: usize) -> bool {
    repetition_count >= 2 && repetition_count.count_ones() == 1
}
/// Hessian density threshold for the sparse-native path. It is not used in this
/// chunk; presumably sparse-native is rejected above this upper-triangle density
/// (TODO confirm at the use site).
const SPARSE_NATIVE_MAX_H_DENSITY: f64 = 0.30;
/// Upper-triangular non-zero pattern (with values) of a dense penalty matrix.
#[derive(Clone, Debug)]
struct SparsePenaltyPattern {
    /// `(row, col, value)` with `row <= col`. When built by `from_dense_upper`,
    /// entries are ordered by column and then by ascending row.
    upper_triplets: Vec<(usize, usize, f64)>,
    /// Number of entries in `upper_triplets`.
    nnz_upper: usize,
}
impl SparsePenaltyPattern {
    /// Extracts the upper triangle of `matrix` (over its leading square part),
    /// keeping entries whose absolute value exceeds `tol`.
    ///
    /// Entries are emitted column by column, rows ascending within each column.
    /// The cursor-based scatter in `SparsePenalizedSystemCache::assemble_upper`
    /// relies on this order.
    fn from_dense_upper(matrix: &Array2<f64>, tol: f64) -> Self {
        let dim = matrix.nrows().min(matrix.ncols());
        let upper_triplets: Vec<(usize, usize, f64)> = (0..dim)
            .flat_map(|col| (0..=col).map(move |row| (row, col)))
            .filter_map(|(row, col)| {
                let value = matrix[[row, col]];
                (value.abs() > tol).then_some((row, col, value))
            })
            .collect();
        Self {
            nnz_upper: upper_triplets.len(),
            upper_triplets,
        }
    }
}
/// Sparsity summary of the penalized system, as produced by
/// `SparsePenalizedSystemCache::stats`.
#[derive(Clone, Debug)]
pub(crate) struct SparsePenalizedSystemStats {
    /// Stored entries in the symbolic XᵀWX pattern.
    pub(crate) nnz_xtwx_symbolic: usize,
    /// Upper-triangular non-zeros of the penalty S_λ.
    pub(crate) nnz_s_lambda_upper: usize,
    /// Upper-triangular non-zeros of H = XᵀWX ∪ S_λ ∪ diag.
    pub(crate) nnz_h_upper: usize,
    /// `nnz_h_upper / (p (p + 1) / 2)`, or 0 when p = 0.
    pub(crate) density_upper: f64,
}
/// Cached symbolic structure and value buffer for assembling the upper triangle
/// of the penalized Hessian `H = XᵀWX + S_λ (+ ridge·I)` in sparse CSC form.
struct SparsePenalizedSystemCache {
    /// XᵀWX symbolic pattern plus numeric buffer.
    xtwx_cache: SparseXtWxCache,
    /// Penalty pattern and values this cache was built for.
    penalty_pattern: SparsePenaltyPattern,
    /// Union pattern of XᵀWX (upper), the penalty, and the full diagonal.
    h_upper_symbolic: SymbolicSparseColMat<usize>,
    /// Values aligned with `h_upper_symbolic`'s row indices.
    h_uppervalues: Vec<f64>,
    /// Owned copy of `h_upper_symbolic.col_ptr()` used during scatter.
    h_upper_col_ptr: Vec<usize>,
    /// Owned copy of `h_upper_symbolic.row_idx()` used during scatter.
    h_upperrow_idx: Vec<usize>,
    /// Number of columns (coefficients).
    p: usize,
}
impl SparsePenalizedSystemCache {
    /// Builds the cache for design `x` and penalty `penalty_pattern`.
    ///
    /// Computes the symbolic XᵀWX pattern, unions it with the penalty pattern and
    /// the diagonal (`build_penalized_symbolic`), and allocates a zeroed value
    /// buffer of matching length.
    ///
    /// # Errors
    /// Propagates errors from `SparseXtWxCache::new` and `build_penalized_symbolic`.
    fn new(
        x: &SparseColMat<usize, f64>,
        penalty_pattern: SparsePenaltyPattern,
    ) -> Result<Self, EstimationError> {
        let xtwx_cache = SparseXtWxCache::new(x)?;
        let p = x.ncols();
        let h_upper_symbolic = build_penalized_symbolic(
            p,
            xtwx_cache.xtwx_symbolic.col_ptr(),
            xtwx_cache.xtwx_symbolic.row_idx(),
            &penalty_pattern.upper_triplets,
        )?;
        let h_uppervalues = vec![0.0; h_upper_symbolic.row_idx().len()];
        Ok(Self {
            xtwx_cache,
            penalty_pattern,
            // Owned copies of the index arrays are walked by `assemble_upper`;
            // `h_upper_symbolic` itself is cloned into each returned matrix.
            h_upper_col_ptr: h_upper_symbolic.col_ptr().to_vec(),
            h_upperrow_idx: h_upper_symbolic.row_idx().to_vec(),
            h_upper_symbolic,
            h_uppervalues,
            p,
        })
    }
    /// Whether this cache can be reused for `x` and `penalty_pattern`.
    ///
    /// The penalty comparison includes the triplet *values*, not just positions,
    /// because `assemble_upper` adds the cached values. Any change in S_λ values
    /// (presumably each new set of smoothing parameters) therefore forces a full
    /// rebuild, including the XᵀWX symbolic pattern.
    fn matches(
        &self,
        x: &SparseColMat<usize, f64>,
        penalty_pattern: &SparsePenaltyPattern,
    ) -> bool {
        self.xtwx_cache.matches(x)
            && self.penalty_pattern.nnz_upper == penalty_pattern.nnz_upper
            && self.penalty_pattern.upper_triplets == penalty_pattern.upper_triplets
    }
    /// Sparsity summary of the cached structures.
    fn stats(&self) -> SparsePenalizedSystemStats {
        let upper_total = self.p.saturating_mul(self.p + 1) / 2;
        SparsePenalizedSystemStats {
            nnz_xtwx_symbolic: self.xtwx_cache.xtwx_symbolic.row_idx().len(),
            nnz_s_lambda_upper: self.penalty_pattern.nnz_upper,
            nnz_h_upper: self.h_upper_symbolic.row_idx().len(),
            density_upper: if upper_total == 0 {
                0.0
            } else {
                self.h_upper_symbolic.row_idx().len() as f64 / upper_total as f64
            },
        }
    }
    /// Assembles the upper triangle of `XᵀWX + S_λ + ridge·I` and returns it as an
    /// owned sparse matrix.
    ///
    /// If `precomputed_xtwx` has exactly the cached symbolic pattern, its values
    /// are copied. Otherwise (with a warning on mismatch) XᵀWX is recomputed from
    /// `x` and `weights`. The ridge is added only when `ridge > 0`.
    ///
    /// Each scatter pass walks a per-column cursor forward only. It therefore
    /// assumes row indices ascend within each column of both the source pattern and
    /// `h_upperrow_idx`. The BTreeSet-built union and `from_dense_upper` triplets
    /// are ascending; the XᵀWX pattern is assumed to be (TODO confirm).
    ///
    /// # Errors
    /// * `weights.len()` differs from the design row count;
    /// * a source entry is missing from the union pattern;
    /// * propagated errors from `compute_numeric`.
    fn assemble_upper(
        &mut self,
        x: &SparseColMat<usize, f64>,
        weights: &Array1<f64>,
        ridge: f64,
        precomputed_xtwx: Option<&SparseXtwxPrecomputed>,
    ) -> Result<SparseColMat<usize, f64>, EstimationError> {
        if weights.len() != self.xtwx_cache.nrows {
            crate::bail_invalid_estim!(
                "weights length {} does not match design rows {}",
                weights.len(),
                self.xtwx_cache.nrows
            );
        }
        let use_precomputed = match precomputed_xtwx {
            Some(pre) => {
                let col_ptr_ok =
                    pre.xtwx_symbolic_col_ptr.as_slice() == self.xtwx_cache.xtwx_symbolic.col_ptr();
                let row_idx_ok =
                    pre.xtwx_symbolic_row_idx.as_slice() == self.xtwx_cache.xtwx_symbolic.row_idx();
                let values_ok = pre.xtwxvalues.len() == self.xtwx_cache.xtwxvalues.len();
                if col_ptr_ok && row_idx_ok && values_ok {
                    self.xtwx_cache.xtwxvalues.copy_from_slice(&pre.xtwxvalues);
                    true
                } else {
                    log::warn!(
                        "[sparse-xtwx-cache] precomputed XᵀWX pattern mismatch; \
                         falling back to per-call recompute"
                    );
                    false
                }
            }
            None => false,
        };
        if !use_precomputed {
            self.xtwx_cache.compute_numeric(x, weights)?;
        }
        self.h_uppervalues.fill(0.0);
        // Pass 1: scatter the upper triangle of XᵀWX into H.
        let mut cursor = self.h_upper_col_ptr[..self.p].to_vec();
        let xtwx_col_ptr = self.xtwx_cache.xtwx_symbolic.col_ptr();
        let xtwxrow_idx = self.xtwx_cache.xtwx_symbolic.row_idx();
        for col in 0..self.p {
            let start = xtwx_col_ptr[col];
            let end = xtwx_col_ptr[col + 1];
            for idx in start..end {
                let row = xtwxrow_idx[idx];
                if row <= col {
                    let cursor_idx = &mut cursor[col];
                    while *cursor_idx < self.h_upper_col_ptr[col + 1]
                        && self.h_upperrow_idx[*cursor_idx] < row
                    {
                        *cursor_idx += 1;
                    }
                    if *cursor_idx >= self.h_upper_col_ptr[col + 1]
                        || self.h_upperrow_idx[*cursor_idx] != row
                    {
                        crate::bail_invalid_estim!("penalized symbolic pattern missing XtWX entry");
                    }
                    self.h_uppervalues[*cursor_idx] += self.xtwx_cache.xtwxvalues[idx];
                }
            }
        }
        // Pass 2: add the penalty triplets (cursor reset to column starts).
        cursor.copy_from_slice(&self.h_upper_col_ptr[..self.p]);
        for &(row, col, value) in &self.penalty_pattern.upper_triplets {
            let cursor_idx = &mut cursor[col];
            while *cursor_idx < self.h_upper_col_ptr[col + 1]
                && self.h_upperrow_idx[*cursor_idx] < row
            {
                *cursor_idx += 1;
            }
            if *cursor_idx >= self.h_upper_col_ptr[col + 1]
                || self.h_upperrow_idx[*cursor_idx] != row
            {
                crate::bail_invalid_estim!("penalized symbolic pattern missing penalty entry");
            }
            self.h_uppervalues[*cursor_idx] += value;
        }
        // Pass 3: optional ridge on the diagonal (always present in the union pattern).
        if ridge > 0.0 {
            cursor.copy_from_slice(&self.h_upper_col_ptr[..self.p]);
            for col in 0..self.p {
                let cursor_idx = &mut cursor[col];
                while *cursor_idx < self.h_upper_col_ptr[col + 1]
                    && self.h_upperrow_idx[*cursor_idx] < col
                {
                    *cursor_idx += 1;
                }
                if *cursor_idx >= self.h_upper_col_ptr[col + 1]
                    || self.h_upperrow_idx[*cursor_idx] != col
                {
                    crate::bail_invalid_estim!("penalized symbolic pattern missing diagonal entry");
                }
                self.h_uppervalues[*cursor_idx] += ridge;
            }
        }
        // The caller receives an owned copy; the cache keeps its buffers for reuse.
        Ok(SparseColMat::new(
            self.h_upper_symbolic.clone(),
            self.h_uppervalues.clone(),
        ))
    }
}
/// Builds the upper-triangular CSC symbolic pattern of the union of:
/// * the upper triangle (`row <= col`) of the XᵀWX pattern;
/// * every penalty triplet;
/// * the full diagonal.
///
/// Row indices within each column come out strictly ascending (BTreeSet order).
///
/// # Errors
/// Returns an invalid-estimation error if a penalty triplet is below the diagonal
/// or has `col >= p`.
fn build_penalized_symbolic(
    p: usize,
    xtwx_col_ptr: &[usize],
    xtwxrow_idx: &[usize],
    penalty_triplets: &[(usize, usize, f64)],
) -> Result<SymbolicSparseColMat<usize>, EstimationError> {
    let mut cols: Vec<BTreeSet<usize>> = (0..p).map(|_| BTreeSet::new()).collect();
    for col in 0..p {
        // Always include the diagonal so that a ridge can be added later.
        cols[col].insert(col);
        let start = xtwx_col_ptr[col];
        let end = xtwx_col_ptr[col + 1];
        for &row in &xtwxrow_idx[start..end] {
            if row <= col {
                cols[col].insert(row);
            }
        }
    }
    for &(row, col, _) in penalty_triplets {
        if row > col || col >= p {
            crate::bail_invalid_estim!(
                "penalty sparse pattern must be upper-triangular within bounds"
            );
        }
        cols[col].insert(row);
    }
    let mut col_ptr = Vec::with_capacity(p + 1);
    let mut row_idx = Vec::new();
    col_ptr.push(0);
    for rows in cols {
        row_idx.extend(rows.into_iter());
        col_ptr.push(row_idx.len());
    }
    // SAFETY: the CSC invariants hold by construction:
    // * `col_ptr` has p + 1 entries, starts at 0, is non-decreasing, and ends at
    //   `row_idx.len()`;
    // * within each column the rows come from a BTreeSet, so they are strictly
    //   ascending and unique;
    // * every row is < p, because rows are either `col` itself, XᵀWX rows with
    //   row <= col < p, or penalty rows with row <= col < p (checked above).
    Ok(unsafe { SymbolicSparseColMat::new_unchecked(p, p, col_ptr, None, row_idx) })
}
/// A model PIRLS can drive: given coefficients, it produces the working state
/// (deviance, penalty term, gradient, ...) used to choose the next step.
pub trait WorkingModel {
    /// Evaluates the working state at `beta`.
    fn update(&mut self, beta: &Coefficients) -> Result<WorkingState, EstimationError>;
    /// Evaluates the working state at `beta` with the requested Hessian curvature.
    ///
    /// The default implementation ignores `curvature_kind` and delegates to
    /// [`WorkingModel::update`].
    fn update_with_curvature(
        &mut self,
        beta: &Coefficients,
        curvature_kind: HessianCurvatureKind,
    ) -> Result<WorkingState, EstimationError> {
        // Explicitly discard the unused argument. An `assert!` on `size_of_val`
        // was only a lint silencer and would panic for a zero-sized type.
        let _ = curvature_kind;
        self.update(beta)
    }
    /// Evaluates a trial (not yet accepted) point. The default delegates to
    /// [`WorkingModel::update_with_curvature`].
    fn update_candidate(
        &mut self,
        beta: &Coefficients,
        curvature: HessianCurvatureKind,
    ) -> Result<WorkingState, EstimationError> {
        self.update_with_curvature(beta, curvature)
    }
    /// Screens a trial point. Implementations may return a cheap
    /// [`CandidateEvaluation::Screen`]; the default always performs a full update
    /// and ignores `arr` and `linear_predictor`.
    ///
    /// Non-finite inputs are deliberately *not* asserted on here. A failing trial
    /// step must reach the caller, which can reject it through
    /// `CandidateEvaluation::arithmetic_finite`, instead of panicking inside the
    /// line search.
    fn screen_candidate(
        &mut self,
        beta: &Coefficients,
        arr: &Array1<f64>,
        linear_predictor: &LinearPredictor,
        curvature: HessianCurvatureKind,
    ) -> Result<CandidateEvaluation, EstimationError> {
        let _ = (arr, linear_predictor);
        self.update_candidate(beta, curvature)
            .map(CandidateEvaluation::Full)
    }
    /// Whether the model can build observed-information curvature. Defaults to `false`.
    fn supports_observed_information_curvature(&self) -> bool {
        false
    }
}
/// Lightweight evaluation of a trial point: objective pieces only, no full
/// working state.
#[derive(Debug, Clone)]
pub struct CandidateScreen {
    /// Penalized objective compared during candidate acceptance.
    pub penalized_objective: f64,
    /// Deviance component.
    pub deviance: f64,
    /// Penalty component.
    pub penalty_term: f64,
    /// Whether the screening arithmetic stayed finite.
    pub arithmetic_finite: bool,
}
/// Result of screening a trial point: either a screen summary or a full working state.
pub enum CandidateEvaluation {
    Screen(CandidateScreen),
    Full(WorkingState),
}
impl CandidateEvaluation {
#[inline]
fn penalized_objective(&self, firth_bias_reduction: bool) -> f64 {
match self {
Self::Screen(s) => s.penalized_objective,
Self::Full(state) => {
let mut value = state.deviance + state.penalty_term;
if firth_bias_reduction && let Some(j) = state.jeffreys_logdet() {
value -= 2.0 * j;
}
value
}
}
}
#[inline]
fn arithmetic_finite(&self) -> bool {
match self {
Self::Screen(s) => s.arithmetic_finite,
Self::Full(state) => state.gradient.iter().all(|g| g.is_finite()),
}
}
#[inline]
fn into_full(self) -> Option<WorkingState> {
match self {
Self::Full(state) => Some(state),
Self::Screen(_) => None,
}
}
}
/// Exact-bit identity of an accepted PIRLS state. Coefficients (and the arrow
/// latent block, when configured) are compared by their IEEE-754 bit patterns, so
/// `Eq` is well defined even for floats.
#[derive(Clone, Debug, PartialEq, Eq)]
pub(super) struct PirlsAcceptedStateCacheKey {
    /// Curvature kind of the state.
    curvature: HessianCurvatureKind,
    /// Whether Firth bias reduction was active (or requested, for lookups).
    firth_active: bool,
    /// `f64::to_bits` of each coefficient.
    beta_bits: Vec<u64>,
    /// `f64::to_bits` of the arrow-Schur latent snapshot, if arrow-Schur is configured.
    arrow_latent_bits: Option<Vec<u64>>,
}
impl PirlsAcceptedStateCacheKey {
    /// Key for a lookup before evaluation: Firth activity comes from the options.
    fn requested(
        beta: &Coefficients,
        curvature: HessianCurvatureKind,
        options: &WorkingModelPirlsOptions,
    ) -> Self {
        let firth_requested = options.firth_bias_reduction;
        Self::new(beta, curvature, firth_requested, options)
    }
    /// Key for an accepted state: curvature and Firth activity come from the state.
    fn accepted(
        beta: &Coefficients,
        state: &WorkingState,
        options: &WorkingModelPirlsOptions,
    ) -> Self {
        let curvature = state.hessian_curvature;
        let firth_active = matches!(state.firth, FirthDiagnostics::Active { .. });
        Self::new(beta, curvature, firth_active, options)
    }
    /// Builds the key. If arrow-Schur is configured, its latent block is
    /// snapshotted (before encoding `beta`) and bit-encoded too.
    fn new(
        beta: &Coefficients,
        curvature: HessianCurvatureKind,
        firth_active: bool,
        options: &WorkingModelPirlsOptions,
    ) -> Self {
        let arrow_latent_bits: Option<Vec<u64>> = match options.arrow_schur.as_ref() {
            Some(arrow_cfg) => {
                let latent = arrow_cfg.snapshot_t.as_ref()();
                Some(latent.iter().map(|value| value.to_bits()).collect())
            }
            None => None,
        };
        let beta_bits: Vec<u64> = beta.as_ref().iter().map(|value| value.to_bits()).collect();
        Self {
            curvature,
            firth_active,
            beta_bits,
            arrow_latent_bits,
        }
    }
}
/// Extra inputs for the integrated (covariate-uncertainty) binomial IRLS update.
#[derive(Clone, Copy)]
pub(crate) struct IntegratedWorkingInput<'a> {
    /// Quadrature rules used for the integration.
    pub quadctx: &'a crate::quadrature::QuadratureContext,
    /// Per-row standard errors passed to the integrated update.
    pub se: ArrayView1<'a, f64>,
    /// Mixture-link state, if the link is a mixture.
    pub mixture_link_state: Option<&'a MixtureLinkState>,
    /// SAS-link state, if the link is SAS.
    pub sas_link_state: Option<&'a SasLinkState>,
}
/// Mutable output buffers for per-row derivative quantities filled by an IRLS
/// update. The names suggest `dmu_deta`..`d3mu_deta3` are the first three
/// derivatives of the inverse link; `c` and `d` are opaque here (TODO document).
pub struct WorkingDerivativeBuffersMut<'a> {
    c: &'a mut Array1<f64>,
    d: &'a mut Array1<f64>,
    dmu_deta: &'a mut Array1<f64>,
    d2mu_deta2: &'a mut Array1<f64>,
    d3mu_deta3: &'a mut Array1<f64>,
}
/// Contiguous slice views of the primary IRLS outputs (mean, weights, working response).
pub(super) struct WorkingSlices<'a> {
    pub mu: &'a mut [f64],
    pub weights: &'a mut [f64],
    pub z: &'a mut [f64],
}
/// Contiguous slice views of the `WorkingDerivativeBuffersMut` buffers.
pub(super) struct WorkingDerivSlices<'a> {
    pub c: &'a mut [f64],
    pub d: &'a mut [f64],
    pub dmu: &'a mut [f64],
    pub d2: &'a mut [f64],
    pub d3: &'a mut [f64],
}
/// Borrows `mu`, `weights` and `z` as contiguous mutable slices.
///
/// # Panics
/// Panics if any array is not contiguous in memory order. The arrays are checked
/// in the order `mu`, `weights`, `z`.
#[inline]
pub(super) fn working_slices<'a>(
    mu: &'a mut Array1<f64>,
    weights: &'a mut Array1<f64>,
    z: &'a mut Array1<f64>,
) -> WorkingSlices<'a> {
    let mu_slice = mu.as_slice_mut().expect("mu must be contiguous");
    let weight_slice = weights.as_slice_mut().expect("weights must be contiguous");
    let z_slice = z.as_slice_mut().expect("z must be contiguous");
    WorkingSlices {
        mu: mu_slice,
        weights: weight_slice,
        z: z_slice,
    }
}
#[inline]
pub(super) fn working_deriv_slices<'a>(
derivs: &'a mut WorkingDerivativeBuffersMut<'_>,
) -> WorkingDerivSlices<'a> {
WorkingDerivSlices {
c: derivs.c.as_slice_mut().expect("c must be contiguous"),
d: derivs.d.as_slice_mut().expect("d must be contiguous"),
dmu: derivs
.dmu_deta
.as_slice_mut()
.expect("dmu_deta must be contiguous"),
d2: derivs
.d2mu_deta2
.as_slice_mut()
.expect("d2mu_deta2 must be contiguous"),
d3: derivs
.d3mu_deta3
.as_slice_mut()
.expect("d3mu_deta3 must be contiguous"),
}
}
/// Per-row Bernoulli working quantities bundled together. The field names mirror
/// the IRLS buffers (mean, weight, working response, `c`, `d`); they are not used
/// in this chunk, so the exact semantics are to be confirmed at the use sites.
#[derive(Clone, Copy)]
struct WorkingBernoulliGeometry {
    mu: f64,
    weight: f64,
    z: f64,
    c: f64,
    d: f64,
}
/// Likelihood interface used by PIRLS to refresh the working quantities and to
/// evaluate the deviance.
pub(crate) trait WorkingLikelihood {
    /// Writes the mean `mu`, IRLS `weights` and working response `z` for
    /// predictor `eta`.
    ///
    /// `integrated` carries optional covariate-uncertainty inputs. `derivatives`,
    /// when provided, receives per-row derivative quantities.
    fn irls_update(
        &self,
        y: ArrayView1<f64>,
        eta: &Array1<f64>,
        priorweights: ArrayView1<f64>,
        mu: &mut Array1<f64>,
        weights: &mut Array1<f64>,
        z: &mut Array1<f64>,
        integrated: Option<IntegratedWorkingInput<'_>>,
        derivatives: Option<WorkingDerivativeBuffersMut<'_>>,
    ) -> Result<(), EstimationError>;
    /// Prior-weighted deviance of `y` at means `mu`.
    fn loglik_deviance(
        &self,
        y: ArrayView1<f64>,
        mu: &Array1<f64>,
        priorweights: ArrayView1<f64>,
    ) -> Result<f64, EstimationError>;
}
impl WorkingLikelihood for GlmLikelihoodSpec {
    /// Dispatches the IRLS update on the response family.
    ///
    /// * Binomial with `integrated`: the integrated update (quadrature over `se`).
    /// * Binomial without it: the standard update with the spec's link; a mixture
    ///   link is rejected because it needs explicit mixture state.
    /// * Gaussian: forced identity link. Weights are divided by a fixed `phi` when
    ///   one is set (it must be finite and positive).
    /// * Poisson, Tweedie, NegativeBinomial, Beta, Gamma: family-specific writers.
    ///   Gamma uses the stored shape, or `1.0` when none is set.
    /// * RoystonParmar: rejected as not a GLM IRLS family.
    fn irls_update(
        &self,
        y: ArrayView1<f64>,
        eta: &Array1<f64>,
        priorweights: ArrayView1<f64>,
        mu: &mut Array1<f64>,
        weights: &mut Array1<f64>,
        z: &mut Array1<f64>,
        integrated: Option<IntegratedWorkingInput<'_>>,
        derivatives: Option<WorkingDerivativeBuffersMut<'_>>,
    ) -> Result<(), EstimationError> {
        match (&self.spec.response, &self.spec.link, integrated.is_some()) {
            (ResponseFamily::Binomial, _, true) => {
                // Cannot panic: this arm requires `integrated.is_some()`.
                let integ = integrated.unwrap();
                update_glmvectors_integrated_by_family(
                    integ.quadctx,
                    y,
                    eta,
                    integ.se,
                    &self.spec,
                    priorweights,
                    mu,
                    weights,
                    z,
                    derivatives,
                    integ.mixture_link_state,
                    integ.sas_link_state,
                )?;
                Ok(())
            }
            (ResponseFamily::Binomial, link, false) => {
                if matches!(link, InverseLink::Mixture(_)) {
                    crate::bail_invalid_estim!(
                        "BinomialMixture IRLS update requires explicit mixture link state"
                            .to_string(),
                    );
                }
                update_glmvectors(
                    y,
                    eta,
                    &self.spec.link,
                    priorweights,
                    mu,
                    weights,
                    z,
                    derivatives,
                )?;
                Ok(())
            }
            (ResponseFamily::Gaussian, _, _) => {
                // NOTE(review): derivative buffers are not filled on this path
                // (`None` is passed), so any caller-provided buffers keep their
                // previous contents. Confirm callers do not rely on them for Gaussian.
                update_glmvectors(
                    y,
                    eta,
                    &InverseLink::Standard(StandardLink::Identity),
                    priorweights,
                    mu,
                    weights,
                    z,
                    None,
                )?;
                if let Some(phi) = self.scale.fixed_phi() {
                    if !(phi.is_finite() && phi > 0.0) {
                        crate::bail_invalid_estim!(
                            "Gaussian fixed dispersion phi must be finite and positive (got {})",
                            phi
                        );
                    }
                    // Skip the pass entirely for the common phi = 1 case.
                    if phi != 1.0 {
                        let inv_phi = 1.0 / phi;
                        weights.mapv_inplace(|w| w * inv_phi);
                    }
                }
                Ok(())
            }
            (ResponseFamily::Poisson, _, _) => {
                write_poisson_log_working_state(y, eta, priorweights, mu, weights, z, derivatives);
                Ok(())
            }
            (ResponseFamily::Tweedie { p }, _, _) => {
                let p = *p;
                write_tweedie_log_working_state(
                    y,
                    eta,
                    priorweights,
                    p,
                    fixed_glm_dispersion(self),
                    mu,
                    weights,
                    z,
                    derivatives,
                )?;
                Ok(())
            }
            (ResponseFamily::NegativeBinomial { theta, .. }, _, _) => {
                let theta = *theta;
                write_negative_binomial_log_working_state(
                    y,
                    eta,
                    priorweights,
                    theta,
                    mu,
                    weights,
                    z,
                    derivatives,
                )?;
                Ok(())
            }
            (ResponseFamily::Beta { phi }, _, _) => {
                let phi = *phi;
                write_beta_logit_working_state(
                    y,
                    eta,
                    priorweights,
                    phi,
                    mu,
                    weights,
                    z,
                    derivatives,
                )?;
                Ok(())
            }
            (ResponseFamily::Gamma, _, _) => {
                write_gamma_log_working_state(
                    y,
                    eta,
                    priorweights,
                    self.gamma_shape().unwrap_or(1.0),
                    mu,
                    weights,
                    z,
                    derivatives,
                );
                Ok(())
            }
            (ResponseFamily::RoystonParmar, _, _) => Err(EstimationError::InvalidInput(
                "RoystonParmar is survival-specific and not a GLM IRLS family".to_string(),
            )),
        }
    }
    /// Prior-weighted deviance via `calculate_deviance`. Tweedie responses are
    /// validated first, and validation errors are propagated.
    fn loglik_deviance(
        &self,
        y: ArrayView1<f64>,
        mu: &Array1<f64>,
        priorweights: ArrayView1<f64>,
    ) -> Result<f64, EstimationError> {
        if matches!(self.spec.response, ResponseFamily::Tweedie { .. }) {
            validate_tweedie_responses(&y, &priorweights)?;
        }
        Ok(calculate_deviance(y, mu, self, priorweights))
    }
}
/// Reusable scratch storage for PIRLS iterations.
///
/// `PirlsWorkspace::new` sizes some buffers by n (rows) or p (coefficients). The
/// matrix buffers start empty (0×0) and are presumably resized by their users
/// (TODO confirm).
pub struct PirlsWorkspace {
    pub wz: Array1<f64>,
    pub eta_buf: Array1<f64>,
    // Dense solve buffers; all start empty.
    pub scaled_matrix: Array2<f64>, pub final_aug_matrix: Array2<f64>, pub rhs_full: Array1<f64>, pub working_residual: Array1<f64>,
    pub weighted_residual: Array1<f64>,
    pub delta_eta: Array1<f64>,
    pub vec_buf_p: Array1<f64>,
    /// Cached sparse penalized-Hessian structure; rebuilt by
    /// `ensure_sparse_penalty_cache` when the design or penalty changes.
    sparse_penalized_system_cache: Option<SparsePenalizedSystemCache>,
    /// faer scratch memory. It is allocated for a 1×1 Cholesky at construction.
    pub factorization_scratch: MemBuffer,
    pub perm: Vec<usize>,
    pub perm_inv: Vec<usize>,
    pub factorization_matrix: Array2<f64>,
    pub weighted_xvalues: Vec<f64>,
    pub weighted_x_chunk: Array2<f64>,
    pub hessian_buf: Array2<f64>,
    pub matvec_buf: Array1<f64>,
}
impl PirlsWorkspace {
    /// Creates a workspace for `n` rows and `p` coefficients.
    ///
    /// `idx` and `idx2` are only asserted to be `< usize::MAX` and are otherwise
    /// unused here.
    pub fn new(n: usize, p: usize, idx: usize, idx2: usize) -> Self {
        assert!(idx < usize::MAX);
        assert!(idx2 < usize::MAX);
        PirlsWorkspace {
            wz: Array1::zeros(n),
            eta_buf: Array1::zeros(n),
            scaled_matrix: Array2::zeros((0, 0).f()),
            final_aug_matrix: Array2::zeros((0, 0).f()),
            rhs_full: Array1::zeros(0),
            working_residual: Array1::zeros(n),
            weighted_residual: Array1::zeros(n),
            delta_eta: Array1::zeros(n),
            vec_buf_p: Array1::zeros(p),
            sparse_penalized_system_cache: None,
            factorization_scratch: {
                // Sequential, dimension-1 Cholesky scratch; a minimal initial allocation.
                let par = faer::Par::Seq;
                let req = faer::linalg::cholesky::llt::factor::cholesky_in_place_scratch::<f64>(
                    1,
                    par,
                    Spec::new(<LltParams as Auto<f64>>::auto()),
                );
                MemBuffer::new(req)
            },
            perm: vec![0; p],
            perm_inv: vec![0; p],
            factorization_matrix: Array2::zeros((0, 0)),
            weighted_xvalues: Vec::new(),
            weighted_x_chunk: Array2::zeros((0, 0).f()),
            hessian_buf: Array2::zeros((0, 0).f()),
            matvec_buf: Array1::zeros(n),
        }
    }
    /// Computes `Xᵀ diag(weights) X` with `weighted_x_scratch` as scratch and
    /// stores it in `out`.
    ///
    /// NOTE(review): despite the `add_` prefix, this *replaces* `*out` with the
    /// result; it does not accumulate into it. Confirm callers expect overwrite.
    pub(super) fn add_dense_xtwx_signed(
        weights: &Array1<f64>,
        weighted_x_scratch: &mut Array2<f64>,
        x: &Array2<f64>,
        out: &mut Array2<f64>,
    ) {
        *out = crate::solver::estimate::reml::assembly::xt_diag_x_dense_into(
            x,
            weights,
            weighted_x_scratch,
        );
    }
    /// Ensures the sparse penalized-system cache matches `x` and the upper-triangle
    /// pattern of `s_lambda`, using entries with |value| > 1e-12. It rebuilds the
    /// cache when it is missing or stale.
    fn ensure_sparse_penalty_cache(
        &mut self,
        x: &SparseColMat<usize, f64>,
        s_lambda: &Array2<f64>,
    ) -> Result<(), EstimationError> {
        let penalty_pattern = SparsePenaltyPattern::from_dense_upper(s_lambda, 1e-12);
        let rebuild = match self.sparse_penalized_system_cache.as_ref() {
            Some(cache) => !cache.matches(x, &penalty_pattern),
            None => true,
        };
        if rebuild {
            self.sparse_penalized_system_cache =
                Some(SparsePenalizedSystemCache::new(x, penalty_pattern)?);
        }
        Ok(())
    }
    /// Sparsity statistics of the penalized system for `x` and `s_lambda`. Builds
    /// the cache if needed.
    pub(crate) fn sparse_penalized_system_stats(
        &mut self,
        x: &SparseColMat<usize, f64>,
        s_lambda: &Array2<f64>,
    ) -> Result<SparsePenalizedSystemStats, EstimationError> {
        self.ensure_sparse_penalty_cache(x, s_lambda)?;
        // Cannot panic: `ensure_sparse_penalty_cache` leaves the cache populated on Ok.
        Ok(self.sparse_penalized_system_cache.as_ref().unwrap().stats())
    }
    /// Assembles the sparse upper triangle of `XᵀWX + S_λ + ridge·I`. Builds or
    /// refreshes the cache first.
    pub(super) fn assemble_sparse_penalized_hessian(
        &mut self,
        x: &SparseColMat<usize, f64>,
        weights: &Array1<f64>,
        s_lambda: &Array2<f64>,
        ridge: f64,
        precomputed_xtwx: Option<&SparseXtwxPrecomputed>,
    ) -> Result<SparseColMat<usize, f64>, EstimationError> {
        self.ensure_sparse_penalty_cache(x, s_lambda)?;
        // Cannot panic: the cache was just ensured.
        self.sparse_penalized_system_cache
            .as_mut()
            .unwrap()
            .assemble_upper(x, weights, ridge, precomputed_xtwx)
    }
}
/// Tuning options for a PIRLS run over a [`WorkingModel`]. Semantics beyond the
/// field names are defined where they are consumed, outside this chunk.
#[derive(Clone, Debug)]
pub struct WorkingModelPirlsOptions {
    pub max_iterations: usize,
    pub convergence_tolerance: f64,
    /// Optional adaptive KKT tolerance.
    pub adaptive_kkt_tolerance: Option<AdaptiveKktTolerance>,
    pub max_step_halving: usize,
    pub min_step_size: f64,
    /// Request Firth bias reduction. It also feeds `PirlsAcceptedStateCacheKey::requested`.
    pub firth_bias_reduction: bool,
    pub coefficient_lower_bounds: Option<Array1<f64>>,
    pub linear_constraints: Option<LinearInequalityConstraints>,
    pub initial_lm_lambda: Option<f64>,
    pub geodesic_acceleration: bool,
    /// Arrow-Schur inner-solve configuration. When present, its latent snapshot is
    /// part of the accepted-state cache key.
    pub arrow_schur: Option<ArrowSchurInnerConfig>,
}
/// Configuration for an arrow-Schur inner solve with an auxiliary latent block.
///
/// The closures are shared (`Arc`) and thread-safe (`Send + Sync`).
#[derive(Clone)]
pub struct ArrowSchurInnerConfig {
    pub n_rows: usize,
    pub latent_dim: usize,
    pub n_beta: usize,
    /// Builds the arrow system for a coefficient vector; `None` if unavailable.
    pub build: std::sync::Arc<
        dyn Fn(&Array1<f64>) -> Option<crate::solver::arrow_schur::ArrowSchurSystem> + Send + Sync,
    >,
    pub solver_mode: Option<crate::solver::arrow_schur::ArrowSolverMode>,
    pub streaming_chunk_size: Option<usize>,
    pub trust_region_radius: f64,
    pub block_offsets: Option<Arc<[std::ops::Range<usize>]>>,
    /// Applies an update to the latent block.
    pub apply_delta_t: std::sync::Arc<dyn Fn(&Array1<f64>) + Send + Sync>,
    /// Returns a copy of the current latent block.
    pub snapshot_t: std::sync::Arc<dyn Fn() -> Array1<f64> + Send + Sync>,
    /// Restores the latent block from a previous snapshot.
    pub restore_t: std::sync::Arc<dyn Fn(&Array1<f64>) + Send + Sync>,
}
impl std::fmt::Debug for ArrowSchurInnerConfig {
    /// Prints the scalar settings. `block_offsets` is summarised by its length, and
    /// the closure fields are elided (`..`).
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let block_count: Option<usize> = self
            .block_offsets
            .as_ref()
            .map(|offsets| offsets.len());
        let mut builder = f.debug_struct("ArrowSchurInnerConfig");
        builder.field("n_rows", &self.n_rows);
        builder.field("latent_dim", &self.latent_dim);
        builder.field("n_beta", &self.n_beta);
        builder.field("solver_mode", &self.solver_mode);
        builder.field("streaming_chunk_size", &self.streaming_chunk_size);
        builder.field("trust_region_radius", &self.trust_region_radius);
        builder.field("block_offsets", &block_count);
        builder.finish_non_exhaustive()
    }
}
/// Restores the arrow-Schur latent block from `snapshot`. Does nothing when
/// arrow-Schur is not configured or no snapshot was taken.
fn restore_arrow_latent_if_needed(
    options: &WorkingModelPirlsOptions,
    snapshot: Option<Array1<f64>>,
) {
    let Some(arrow_cfg) = options.arrow_schur.as_ref() else {
        return;
    };
    if let Some(saved) = snapshot {
        arrow_cfg.restore_t.as_ref()(&saved);
    }
}
/// Consumes the pending latent snapshot, leaving `None`, and restores it when
/// arrow-Schur is configured.
pub(super) fn restore_pending_arrow_latent_if_needed(
    options: &WorkingModelPirlsOptions,
    pending_snapshot: &mut Option<Array1<f64>>,
) {
    let snapshot = std::mem::take(pending_snapshot);
    restore_arrow_latent_if_needed(options, snapshot);
}
/// Accepts the current latent state by discarding the pending rollback snapshot.
pub(super) fn commit_pending_arrow_latent(pending_snapshot: &mut Option<Array1<f64>>) {
    *pending_snapshot = None;
}
/// Small fixed ridge constant. It is not used in this chunk; presumably it is added
/// to the Hessian diagonal for stability (TODO confirm).
pub(super) const FIXED_STABILIZATION_RIDGE: f64 = 1e-8;
/// PIRLS working model for a GAM/GLM.
///
/// It holds the design (original and, optionally, transformed coordinates), the
/// response and prior weights, the penalty, the likelihood, and the per-row
/// quantities from the most recent update.
pub(super) struct GamWorkingModel<'a> {
    x_original: DesignMatrix,
    /// Coordinate frame PIRLS iterates in (original sparse, or Qs-transformed).
    coordinate_design: WorkingCoordinateDesign,
    offset: Array1<f64>,
    y: ArrayView1<'a, f64>,
    priorweights: ArrayView1<'a, f64>,
    penalty: PirlsPenalty,
    workspace: PirlsWorkspace,
    likelihood: GlmLikelihoodSpec,
    link_kind: InverseLink,
    firth_bias_reduction: bool,
    // Quantities from the most recent update.
    lastmu: Array1<f64>,
    lastweights: Array1<f64>,
    lastz: Array1<f64>,
    last_c: Array1<f64>,
    last_d: Array1<f64>,
    // Hessian-side weights; reported as the final weights by `into_final_state`.
    lasthessian_weights: Array1<f64>,
    lasthessian_c: Array1<f64>,
    lasthessian_d: Array1<f64>,
    lasthessian_curvature: HessianCurvatureKind,
    last_dmu_deta: Array1<f64>,
    last_d2mu_deta2: Array1<f64>,
    last_d3mu_deta3: Array1<f64>,
    last_penalty_term: f64,
    /// CSR copy of the original design, if `to_csr_cache` produced one.
    x_original_csr: Option<SparseRowMat<usize, f64>>,
    /// Per-row covariate standard errors, set via `with_covariate_se`.
    covariate_se: Option<Array1<f64>>,
    // Flags to freeze the corresponding dispersion/shape estimates.
    gamma_shape_locked: bool,
    beta_phi_locked: bool,
    tweedie_phi_locked: bool,
    negbin_theta_locked: bool,
    quadctx: crate::quadrature::QuadratureContext,
    /// Optional precomputed Gram matrix for the first GLM step, and whether it
    /// has been used.
    glm_first_step_gram: Option<Array2<f64>>,
    glm_first_step_gram_consumed: bool,
}
/// Quantities extracted from a finished [`GamWorkingModel`] by `into_final_state`.
pub(super) struct GamModelFinalState {
    likelihood: GlmLikelihoodSpec,
    coordinate_frame: PirlsCoordinateFrame,
    finalmu: Array1<f64>,
    /// Hessian-side weights (`lasthessian_weights`).
    finalweights: Array1<f64>,
    /// Score/working weights (`lastweights`).
    scoreweights: Array1<f64>,
    finalz: Array1<f64>,
    /// Hessian-side `c` (`lasthessian_c`).
    final_c: Array1<f64>,
    /// Hessian-side `d` (`lasthessian_d`).
    final_d: Array1<f64>,
    final_dmu_deta: Array1<f64>,
    final_d2mu_deta2: Array1<f64>,
    final_d3mu_deta3: Array1<f64>,
    penalty_term: f64,
}
impl<'a> GamWorkingModel<'a> {
/// Creates a working model in the requested coordinate frame.
///
/// For `TransformedQs`, an explicit `x_transformed` design is preferred (its CSR
/// cache is built eagerly). Otherwise the model applies `transform` implicitly on
/// each product. All per-row buffers are zero-initialised to the row count `n`.
///
/// # Panics
/// Panics if `coordinate_frame` is `TransformedQs` and both `x_transformed` and
/// `transform` are `None`.
fn new(
    x_transformed: Option<DesignMatrix>,
    x_original: DesignMatrix,
    coordinate_frame: PirlsCoordinateFrame,
    offset: ArrayView1<f64>,
    y: ArrayView1<'a, f64>,
    priorweights: ArrayView1<'a, f64>,
    penalty: PirlsPenalty,
    workspace: PirlsWorkspace,
    likelihood: GlmLikelihoodSpec,
    link_kind: InverseLink,
    firth_bias_reduction: bool,
    transform: Option<WorkingReparamTransform>,
    quadctx: crate::quadrature::QuadratureContext,
    glm_first_step_gram: Option<Array2<f64>>,
) -> Self {
    let coordinate_design = match coordinate_frame {
        PirlsCoordinateFrame::OriginalSparseNative => {
            WorkingCoordinateDesign::OriginalSparseNative
        }
        PirlsCoordinateFrame::TransformedQs => {
            if let Some(x_transformed) = x_transformed {
                WorkingCoordinateDesign::TransformedExplicit {
                    x_csr: x_transformed.to_csr_cache(),
                    x_transformed,
                }
            } else {
                WorkingCoordinateDesign::TransformedImplicit {
                    transform: transform.expect(
                        "TransformedQs PIRLS coordinate frame requires either x_transformed or qs",
                    ),
                }
            }
        }
    };
    let x_original_csr = x_original.to_csr_cache();
    // Row count of the design PIRLS actually multiplies with.
    let n = match &coordinate_design {
        WorkingCoordinateDesign::OriginalSparseNative => x_original.nrows(),
        WorkingCoordinateDesign::TransformedExplicit { x_transformed, .. } => {
            x_transformed.nrows()
        }
        WorkingCoordinateDesign::TransformedImplicit { .. } => x_original.nrows(),
    };
    GamWorkingModel {
        x_original,
        coordinate_design,
        offset: offset.to_owned(),
        y,
        priorweights,
        penalty,
        workspace,
        likelihood,
        link_kind,
        firth_bias_reduction,
        lastmu: Array1::zeros(n),
        lastweights: Array1::zeros(n),
        lastz: Array1::zeros(n),
        last_c: Array1::zeros(n),
        last_d: Array1::zeros(n),
        lasthessian_weights: Array1::zeros(n),
        lasthessian_c: Array1::zeros(n),
        lasthessian_d: Array1::zeros(n),
        lasthessian_curvature: HessianCurvatureKind::Fisher,
        last_dmu_deta: Array1::zeros(n),
        last_d2mu_deta2: Array1::zeros(n),
        last_d3mu_deta3: Array1::zeros(n),
        last_penalty_term: 0.0,
        x_original_csr,
        covariate_se: None,
        gamma_shape_locked: false,
        beta_phi_locked: false,
        tweedie_phi_locked: false,
        negbin_theta_locked: false,
        quadctx,
        glm_first_step_gram,
        glm_first_step_gram_consumed: false,
    }
}
/// Builder-style setter: attaches per-row covariate standard errors.
fn with_covariate_se(mut self, se: Array1<f64>) -> Self {
    let _ = self.covariate_se.replace(se);
    self
}
/// Consumes the model and packages the most recent update's quantities.
///
/// The reported `finalweights`, `final_c` and `final_d` are the *Hessian-side*
/// buffers (`lasthessian_*`). `scoreweights` is `lastweights`. `last_c` and
/// `last_d` are discarded. The coordinate frame is derived from the design variant:
/// both transformed variants map to `TransformedQs`.
fn into_final_state(self) -> GamModelFinalState {
    // Move everything needed out of `self`, including `likelihood`, instead of
    // cloning it from the partially moved struct.
    let GamWorkingModel {
        coordinate_design,
        likelihood,
        lastmu,
        lastweights,
        lastz,
        lasthessian_weights,
        lasthessian_c,
        lasthessian_d,
        last_dmu_deta,
        last_d2mu_deta2,
        last_d3mu_deta3,
        last_penalty_term,
        ..
    } = self;
    let coordinate_frame = match coordinate_design {
        WorkingCoordinateDesign::OriginalSparseNative => {
            PirlsCoordinateFrame::OriginalSparseNative
        }
        WorkingCoordinateDesign::TransformedExplicit { .. }
        | WorkingCoordinateDesign::TransformedImplicit { .. } => {
            PirlsCoordinateFrame::TransformedQs
        }
    };
    GamModelFinalState {
        likelihood,
        coordinate_frame,
        finalmu: lastmu,
        finalweights: lasthessian_weights,
        scoreweights: lastweights,
        finalz: lastz,
        final_c: lasthessian_c,
        final_d: lasthessian_d,
        final_dmu_deta: last_dmu_deta,
        final_d2mu_deta2: last_d2mu_deta2,
        final_d3mu_deta3: last_d3mu_deta3,
        penalty_term: last_penalty_term,
    }
}
/// Write the working-frame linear predictor contribution `X̃ β` into `out`.
///
/// Thin adapter that unwraps the `Coefficients` newtype and delegates to
/// `transformed_matvec_array_into`.
fn transformed_matvec_into(&self, beta: &Coefficients, out: &mut Array1<f64>) {
    let coefficients: &Array1<f64> = beta.as_ref();
    self.transformed_matvec_array_into(coefficients, out);
}
/// Compute `out = X̃ beta` in the working coordinate frame.
///
/// - `OriginalSparseNative`: multiplies by the original design directly.
/// - `TransformedExplicit`: multiplies by the materialized transformed design,
///   using the dense fast path when the matrix is stored densely.
/// - `TransformedImplicit`: maps `beta` back to the original basis through
///   `transform.apply`, then multiplies by the original design (dense fast
///   path when available).
fn transformed_matvec_array_into(&self, beta: &Array1<f64>, out: &mut Array1<f64>) {
    match &self.coordinate_design {
        WorkingCoordinateDesign::OriginalSparseNative => {
            out.assign(&self.x_original.matrixvectormultiply(beta));
        }
        WorkingCoordinateDesign::TransformedExplicit { x_transformed, .. } => {
            match x_transformed.as_dense() {
                Some(dense) => {
                    fast_av_into(dense, beta, out);
                }
                None => {
                    out.assign(&x_transformed.matrixvectormultiply(beta));
                }
            }
        }
        WorkingCoordinateDesign::TransformedImplicit { transform } => {
            let original_basis_beta = transform.apply(beta);
            match self.x_original.as_dense() {
                Some(dense) => {
                    fast_av_into(dense, &original_basis_beta, out);
                }
                None => {
                    out.assign(&self.x_original.apply(&original_basis_beta));
                }
            }
        }
    }
}
/// Compute `X̃ᵀ vec` in the working coordinate frame.
///
/// For the implicit transform, `Xᵀ vec` is formed in the original basis and
/// then pulled back with `transform.apply_transpose`.
fn transformed_transpose_matvec(&self, vec: &Array1<f64>) -> Array1<f64> {
    match &self.coordinate_design {
        WorkingCoordinateDesign::TransformedExplicit { x_transformed, .. } => {
            x_transformed.transpose_vector_multiply(vec)
        }
        WorkingCoordinateDesign::TransformedImplicit { transform } => {
            transform.apply_transpose(&self.x_original.transpose_vector_multiply(vec))
        }
        WorkingCoordinateDesign::OriginalSparseNative => {
            self.x_original.transpose_vector_multiply(vec)
        }
    }
}
/// Assemble the unpenalized Gram matrix `Xᵀ W X` for `design`.
///
/// Materialized dense designs:
/// - If CUDA is selected, the weighted cross-product is computed on the GPU
///   and returned directly.
/// - Otherwise the backend decision is logged and validated, and `XᵀWX` is
///   accumulated on the CPU into `workspace.hessian_buf`. That buffer is then
///   moved out and returned.
///
/// All other designs go through `crate::matrix::xt_diag_x_signed` and are
/// densified.
///
/// `weights` may contain negative entries (e.g. observed-information
/// curvature); every path used here is the signed-weight variant.
///
/// # Errors
/// Returns `EstimationError::InvalidInput` when the GPU kernel, the backend
/// decision, or the generic signed `XᵀWX` routine reports a failure.
fn compute_xtwx_blas(
    workspace: &mut PirlsWorkspace,
    design: &DesignMatrix,
    weights: &Array1<f64>,
) -> Result<Array2<f64>, EstimationError> {
    match design {
        DesignMatrix::Dense(x) if x.is_materialized_dense() => {
            let p = x.ncols();
            let x_dense = x.to_dense_arc();
            if crate::gpu::cuda_selected() {
                return crate::solver::gpu::pirls_gpu::weighted_crossprod_gpu(
                    x_dense.view(),
                    weights.view(),
                )
                .map_err(EstimationError::InvalidInput);
            }
            // Only the CPU path accumulates into the workspace buffer, so it is
            // (re)initialized here rather than before the GPU early return.
            if workspace.hessian_buf.nrows() != p || workspace.hessian_buf.ncols() != p {
                workspace.hessian_buf = Array2::zeros((p, p).f());
            } else {
                workspace.hessian_buf.fill(0.0);
            }
            crate::gpu::log_backend_inventory_once();
            let gpu_decision = crate::gpu::decide(
                crate::gpu::GpuKernel::DenseXtWX,
                crate::gpu::GpuEligibility::BackendNotCompiled,
            );
            gpu_decision
                .require_supported()
                .map_err(EstimationError::InvalidInput)?;
            gpu_decision.log();
            // The signed accumulator handles weights of either sign. The
            // previous code scanned for negative weights and then called this
            // same routine in both branches, so the scan is dropped.
            PirlsWorkspace::add_dense_xtwx_signed(
                weights,
                &mut workspace.weighted_x_chunk,
                x_dense.as_ref(),
                &mut workspace.hessian_buf,
            );
            Ok(std::mem::take(&mut workspace.hessian_buf))
        }
        _ => crate::matrix::xt_diag_x_signed(
            design,
            crate::matrix::SignedWeightsView::from_array(weights),
        )
        .map(|h| h.to_dense())
        .map_err(EstimationError::InvalidInput),
    }
}
/// Build the dense penalized Hessian `X̃ᵀ W X̃ + S` in the working frame.
///
/// A precomputed first-step Gram (`glm_first_step_gram`) may be used instead
/// of recomputing `XᵀWX` from `weights`. This happens at most once, when all
/// of these hold:
/// - it has not been consumed yet;
/// - the last curvature kind is Fisher;
/// - the design is not an explicit transformed matrix.
///
/// On that call `weights` is ignored. Every other call forms `XᵀWX` from
/// `weights` via `compute_xtwx_blas`. For the implicit transform, the
/// original-basis Gram is conjugated by the stored transform. The penalty is
/// always added last.
///
/// # Errors
/// Propagates `compute_xtwx_blas` failures. Returns `InvalidInput` if the
/// frozen-Gram path is somehow reached with an explicit transformed design.
fn penalized_hessian(&mut self, weights: &Array1<f64>) -> Result<Array2<f64>, EstimationError> {
    let frozen_gram_allowed = self.lasthessian_curvature == HessianCurvatureKind::Fisher
        && !self.glm_first_step_gram_consumed
        && !matches!(
            self.coordinate_design,
            WorkingCoordinateDesign::TransformedExplicit { .. }
        );
    let frozen_gram = if frozen_gram_allowed {
        self.glm_first_step_gram.take()
    } else {
        None
    };
    if let Some(xtwx) = frozen_gram {
        self.glm_first_step_gram_consumed = true;
        log::debug!(
            "[frozen-glm-gram] serving first Fisher-step XᵀWX n-free (p={})",
            xtwx.nrows()
        );
        let mut h = match &self.coordinate_design {
            WorkingCoordinateDesign::OriginalSparseNative => xtwx,
            WorkingCoordinateDesign::TransformedImplicit { transform } => {
                transform.conjugate_matrix(&xtwx)
            }
            WorkingCoordinateDesign::TransformedExplicit { .. } => {
                return Err(EstimationError::InvalidInput(
                    "frozen first-step Gram path reached with TransformedExplicit \
                     coordinate design, which the gate excludes"
                        .to_string(),
                ));
            }
        };
        self.penalty.add_to_hessian(&mut h);
        return Ok(h);
    }
    let mut h = match &self.coordinate_design {
        WorkingCoordinateDesign::OriginalSparseNative => {
            Self::compute_xtwx_blas(&mut self.workspace, &self.x_original, weights)?
        }
        WorkingCoordinateDesign::TransformedExplicit { x_transformed, .. } => {
            Self::compute_xtwx_blas(&mut self.workspace, x_transformed, weights)?
        }
        WorkingCoordinateDesign::TransformedImplicit { transform } => {
            let original_gram =
                Self::compute_xtwx_blas(&mut self.workspace, &self.x_original, weights)?;
            transform.conjugate_matrix(&original_gram)
        }
    };
    self.penalty.add_to_hessian(&mut h);
    Ok(h)
}
fn supports_observed_hessian_curvature(&self) -> bool {
supports_observed_hessian_curvature_for_likelihood(&self.likelihood, &self.link_kind)
}
/// Fill the `lasthessian_*` arrays for the requested curvature.
///
/// Observed curvature is computed only when it was requested and the
/// likelihood/link pair supports it. Otherwise the Fisher arrays
/// (`lastweights`, `last_c`, `last_d`) are copied over. Returns the curvature
/// kind actually used.
///
/// # Errors
/// Propagates failures from the observed-curvature computation.
fn update_hessian_curvature_arrays(
    &mut self,
    requested: HessianCurvatureKind,
) -> Result<HessianCurvatureKind, EstimationError> {
    let wants_observed = requested != HessianCurvatureKind::Fisher;
    if wants_observed && self.supports_observed_hessian_curvature() {
        compute_observed_hessian_curvature_arrays_into(
            &self.likelihood,
            &self.link_kind,
            &self.workspace.eta_buf,
            self.y,
            &self.lastweights,
            self.priorweights,
            &mut self.lasthessian_weights,
            &mut self.lasthessian_c,
            &mut self.lasthessian_d,
        )?;
        return Ok(HessianCurvatureKind::Observed);
    }
    // Fisher curvature (requested, or observed is unsupported): reuse the IRLS arrays.
    self.lasthessian_weights.assign(&self.lastweights);
    self.lasthessian_c.assign(&self.last_c);
    self.lasthessian_d.assign(&self.last_d);
    Ok(HessianCurvatureKind::Fisher)
}
fn sparse_penalized_hessian(
&mut self,
weights: &Array1<f64>,
ridge: f64,
) -> Result<SparseColMat<usize, f64>, EstimationError> {
let x_sparse = self.x_original.as_sparse().ok_or_else(|| {
EstimationError::InvalidInput(
"sparse-native PIRLS requires a sparse original design".to_string(),
)
})?;
let PirlsPenalty::Dense { s_transformed, .. } = &self.penalty else {
crate::bail_invalid_estim!(
"sparse-native PIRLS requires a dense transformed penalty matrix"
);
};
self.workspace.assemble_sparse_penalized_hessian(
x_sparse,
weights,
s_transformed,
ridge,
None,
)
}
/// Cheaply screen a candidate step without assembling a Hessian.
///
/// Steps:
/// 1. Form `eta = current_eta + X̃·direction` in `workspace.eta_buf`.
/// 2. Refresh `lastmu`, `lastweights` and `lastz` for that `eta` (derivative
///    buffers are not requested).
/// 3. Report `deviance + shifted_quadratic(beta)` together with a flag saying
///    whether the objective, `eta`, `mu` and weights are all finite.
///
/// NOTE(review): `beta` only feeds the penalty term here. It is presumably the
/// candidate coefficient vector matching `direction` — confirm with callers.
///
/// # Errors
/// Propagates failures from the working-vector update and the deviance.
fn screen_candidate_from_direction(
    &mut self,
    beta: &Coefficients,
    direction: &Array1<f64>,
    current_eta: &LinearPredictor,
) -> Result<CandidateScreen, EstimationError> {
    let n_obs = self.offset.len();
    if self.workspace.eta_buf.len() != n_obs {
        self.workspace.eta_buf = Array1::zeros(n_obs);
    }
    if self.workspace.delta_eta.len() != n_obs {
        self.workspace.delta_eta = Array1::zeros(n_obs);
    }
    // Moved out temporarily because the matvec helper borrows all of `self`.
    let mut step_eta = std::mem::take(&mut self.workspace.delta_eta);
    self.transformed_matvec_array_into(direction, &mut step_eta);
    Zip::from(&mut self.workspace.eta_buf)
        .and(current_eta.as_ref())
        .and(&step_eta)
        .par_for_each(|eta, &base, &d| *eta = base + d);
    self.workspace.delta_eta = step_eta;

    let integrated = self.covariate_se.as_ref().map(|se| IntegratedWorkingInput {
        quadctx: &self.quadctx,
        se: se.view(),
        mixture_link_state: self.link_kind.mixture_state(),
        sas_link_state: self.link_kind.sas_state(),
    });
    match &self.link_kind {
        InverseLink::Standard(_) => {
            self.likelihood.irls_update(
                self.y,
                &self.workspace.eta_buf,
                self.priorweights,
                &mut self.lastmu,
                &mut self.lastweights,
                &mut self.lastz,
                integrated,
                None,
            )?;
        }
        InverseLink::Mixture(_)
        | InverseLink::LatentCLogLog(_)
        | InverseLink::Sas(_)
        | InverseLink::BetaLogistic(_) => match integrated {
            Some(integ) => {
                update_glmvectors_integrated_for_link(
                    integ.quadctx,
                    self.y,
                    &self.workspace.eta_buf,
                    integ.se,
                    &self.link_kind,
                    self.priorweights,
                    &mut self.lastmu,
                    &mut self.lastweights,
                    &mut self.lastz,
                    None,
                )?;
            }
            None => {
                update_glmvectors(
                    self.y,
                    &self.workspace.eta_buf,
                    &self.link_kind,
                    self.priorweights,
                    &mut self.lastmu,
                    &mut self.lastweights,
                    &mut self.lastz,
                    None,
                )?;
            }
        },
    }

    let deviance = self
        .likelihood
        .loglik_deviance(self.y, &self.lastmu, self.priorweights)?;
    let penalty_term = self.penalty.shifted_quadratic(beta.as_ref());
    let penalized_objective = deviance + penalty_term;
    let all_finite = |values: &Array1<f64>| values.iter().all(|v| v.is_finite());
    let arithmetic_finite = penalized_objective.is_finite()
        && all_finite(&self.workspace.eta_buf)
        && all_finite(&self.lastmu)
        && all_finite(&self.lastweights);
    Ok(CandidateScreen {
        penalized_objective,
        deviance,
        penalty_term,
        arithmetic_finite,
    })
}
}
impl<'a> WorkingModel for GamWorkingModel<'a> {
    /// Evaluate the working state at `beta` with Fisher (expected-information)
    /// curvature.
    fn update(&mut self, beta: &Coefficients) -> Result<WorkingState, EstimationError> {
        self.update_with_curvature(beta, HessianCurvatureKind::Fisher)
    }

    /// Evaluate the full PIRLS working state at `beta`.
    ///
    /// In order, this:
    /// 1. forms `eta = offset + X̃ β` in the working coordinate frame;
    /// 2. estimates each estimated scale parameter (Gamma shape, Beta φ,
    ///    Tweedie φ, negative-binomial θ) once from that `eta`, then locks it
    ///    for the rest of the fit;
    /// 3. refreshes `mu`, the IRLS weights, the working response `z`, and the
    ///    link-derivative buffers;
    /// 4. if Firth/Jeffreys bias reduction is on, computes the hat diagonal and
    ///    Jeffreys log-determinant, and shifts `z` by the Firth score
    ///    correction on rows with positive weight;
    /// 5. forms the gradient `X̃ᵀ W (eta - z) + S β`, resolves Fisher vs.
    ///    observed curvature, and assembles a positive-definite penalized
    ///    Hessian (sparse for the sparse-native frame, dense otherwise);
    /// 6. reports deviance, log-likelihood and penalty. When a stabilization
    ///    ridge was needed, `ridge·βᵀβ` is added to the penalty and `ridge·β`
    ///    to the gradient.
    ///
    /// The dense Hessian is checked with `assert_symmetric_tol` (tol `1e-8`).
    ///
    /// # Errors
    /// Propagates failures from the link/likelihood updates, the Firth
    /// diagnostics, the curvature computation, and Hessian assembly or
    /// stabilization. Firth on a link without a Fisher-weight jet is rejected.
    fn update_with_curvature(
        &mut self,
        beta: &Coefficients,
        requested_curvature: HessianCurvatureKind,
    ) -> Result<WorkingState, EstimationError> {
        let n = self.offset.len();
        if self.workspace.eta_buf.len() != n {
            self.workspace.eta_buf = Array1::zeros(n);
        }
        if self.workspace.matvec_buf.len() != n {
            self.workspace.matvec_buf = Array1::zeros(n);
        }
        // Moved out temporarily because `transformed_matvec_into` borrows `self`.
        let mut matvec_tmp = std::mem::take(&mut self.workspace.matvec_buf);
        self.transformed_matvec_into(beta, &mut matvec_tmp);
        self.workspace.eta_buf.assign(&self.offset);
        self.workspace.eta_buf += &matvec_tmp;
        self.workspace.matvec_buf = matvec_tmp;

        // One-shot scale-parameter estimates from the first evaluated eta.
        if self.likelihood.scale.gamma_shape_is_estimated() && !self.gamma_shape_locked {
            let shape =
                estimate_gamma_shape_from_eta(self.y, &self.workspace.eta_buf, self.priorweights);
            self.likelihood = self.likelihood.clone().with_gamma_shape(shape);
            self.gamma_shape_locked = true;
        }
        if self.likelihood.scale.beta_phi_is_estimated() && !self.beta_phi_locked {
            let phi =
                estimate_beta_phi_from_eta(self.y, &self.workspace.eta_buf, self.priorweights);
            self.likelihood = self.likelihood.clone().with_beta_phi(phi);
            self.beta_phi_locked = true;
        }
        if self.likelihood.scale.tweedie_phi_is_estimated() && !self.tweedie_phi_locked {
            if let ResponseFamily::Tweedie { p } = self.likelihood.spec.response {
                let phi = estimate_tweedie_phi_from_eta(
                    self.y,
                    &self.workspace.eta_buf,
                    self.priorweights,
                    p,
                );
                self.likelihood = self.likelihood.clone().with_tweedie_phi(phi);
                self.tweedie_phi_locked = true;
            }
        }
        if self.likelihood.scale.negbin_theta_is_estimated() && !self.negbin_theta_locked {
            let theta =
                estimate_negbin_theta_from_eta(self.y, &self.workspace.eta_buf, self.priorweights);
            self.likelihood = self.likelihood.clone().with_negbin_theta(theta);
            self.negbin_theta_locked = true;
        }

        let integrated = self.covariate_se.as_ref().map(|se| IntegratedWorkingInput {
            quadctx: &self.quadctx,
            se: se.view(),
            mixture_link_state: self.link_kind.mixture_state(),
            sas_link_state: self.link_kind.sas_state(),
        });
        // Every link family fills the same derivative buffers.
        let derivative_buffers = WorkingDerivativeBuffersMut {
            c: &mut self.last_c,
            d: &mut self.last_d,
            dmu_deta: &mut self.last_dmu_deta,
            d2mu_deta2: &mut self.last_d2mu_deta2,
            d3mu_deta3: &mut self.last_d3mu_deta3,
        };
        match &self.link_kind {
            // Non-standard links share one code path (previously two identical arms).
            InverseLink::Mixture(_)
            | InverseLink::LatentCLogLog(_)
            | InverseLink::Sas(_)
            | InverseLink::BetaLogistic(_) => {
                if let Some(integ) = integrated {
                    update_glmvectors_integrated_for_link(
                        integ.quadctx,
                        self.y,
                        &self.workspace.eta_buf,
                        integ.se,
                        &self.link_kind,
                        self.priorweights,
                        &mut self.lastmu,
                        &mut self.lastweights,
                        &mut self.lastz,
                        Some(derivative_buffers),
                    )?;
                } else {
                    update_glmvectors(
                        self.y,
                        &self.workspace.eta_buf,
                        &self.link_kind,
                        self.priorweights,
                        &mut self.lastmu,
                        &mut self.lastweights,
                        &mut self.lastz,
                        Some(derivative_buffers),
                    )?;
                }
            }
            InverseLink::Standard(_) => {
                self.likelihood.irls_update(
                    self.y,
                    &self.workspace.eta_buf,
                    self.priorweights,
                    &mut self.lastmu,
                    &mut self.lastweights,
                    &mut self.lastz,
                    integrated,
                    Some(derivative_buffers),
                )?;
            }
        }

        let firth = if self.firth_bias_reduction {
            if !inverse_link_has_fisher_weight_jet(&self.link_kind) {
                crate::bail_invalid_estim!(
                    "Firth/Jeffreys PIRLS requested for unsupported inverse link {:?}",
                    self.link_kind
                );
            }
            let (hat_diag, jeffreys_logdet, firth_score_shift) = match &self.coordinate_design {
                WorkingCoordinateDesign::TransformedExplicit {
                    x_transformed,
                    x_csr,
                } => {
                    if x_transformed.as_sparse().is_some() {
                        let csr = x_csr.as_ref().ok_or_else(|| {
                            EstimationError::InvalidInput(
                                "missing CSR cache for sparse transformed design".to_string(),
                            )
                        })?;
                        compute_jeffreys_pirls_diagnostics_sparse(
                            &self.link_kind,
                            csr,
                            self.workspace.eta_buf.view(),
                            self.priorweights,
                        )?
                    } else {
                        let x_dense_cow = x_transformed.to_dense_cow();
                        compute_jeffreys_pirls_diagnostics(
                            &self.link_kind,
                            x_dense_cow.view(),
                            self.workspace.eta_buf.view(),
                            self.priorweights,
                        )?
                    }
                }
                WorkingCoordinateDesign::TransformedImplicit { transform } => {
                    let x_t_dense =
                        fast_ab(&self.x_original.to_dense(), &transform.materialize_dense());
                    compute_jeffreys_pirls_diagnostics(
                        &self.link_kind,
                        x_t_dense.view(),
                        self.workspace.eta_buf.view(),
                        self.priorweights,
                    )?
                }
                WorkingCoordinateDesign::OriginalSparseNative => {
                    if self.x_original.as_sparse().is_some() {
                        let csr = self.x_original_csr.as_ref().ok_or_else(|| {
                            EstimationError::InvalidInput(
                                "missing CSR cache for sparse original design".to_string(),
                            )
                        })?;
                        compute_jeffreys_pirls_diagnostics_sparse(
                            &self.link_kind,
                            csr,
                            self.workspace.eta_buf.view(),
                            self.priorweights,
                        )?
                    } else {
                        let x_dense = self
                            .x_original
                            .try_to_dense_arc(
                                "Firth diagnostics require dense access to the original design",
                            )
                            .map_err(EstimationError::InvalidInput)?;
                        compute_jeffreys_pirls_diagnostics(
                            &self.link_kind,
                            x_dense.view(),
                            self.workspace.eta_buf.view(),
                            self.priorweights,
                        )?
                    }
                }
            };
            // Apply the Firth score correction only to rows that carry weight.
            ndarray::Zip::from(&mut self.lastz)
                .and(&firth_score_shift)
                .and(&self.lastweights)
                .par_for_each(|zi, &delta_i, &wi| {
                    if wi > 0.0 {
                        *zi += delta_i;
                    }
                });
            // `hat_diag` is not used afterwards, so it is moved rather than cloned.
            FirthDiagnostics::Active {
                jeffreys_logdet,
                hat_diag,
            }
        } else {
            FirthDiagnostics::Inactive
        };

        // Keep the residual scratch buffers sized to the current number of
        // observations, exactly like `eta_buf`/`matvec_buf` above; a stale size
        // would otherwise make the `Zip` below panic.
        if self.workspace.weighted_residual.len() != n {
            self.workspace.weighted_residual = Array1::zeros(n);
        }
        if self.workspace.working_residual.len() != n {
            self.workspace.working_residual = Array1::zeros(n);
        }
        ndarray::Zip::from(&mut self.workspace.weighted_residual)
            .and(&mut self.workspace.working_residual)
            .and(&self.workspace.eta_buf)
            .and(&self.lastz)
            .and(&self.lastweights)
            .par_for_each(|wr, r, &eta, &zi, &wi| {
                let residual = eta - zi;
                *r = residual;
                *wr = residual * wi;
            });
        let mut gradient = self.transformed_transpose_matvec(&self.workspace.weighted_residual);
        let score_norm = array1_l2_norm(&gradient);
        let s_beta = self.penalty.shifted_gradient(beta.as_ref());
        let s_beta_norm = array1_l2_norm(&s_beta);
        gradient += &s_beta;

        let hessian_curvature = self.update_hessian_curvature_arrays(requested_curvature)?;
        self.lasthessian_curvature = hessian_curvature;
        if self.workspace.matvec_buf.len() != n {
            self.workspace.matvec_buf = Array1::zeros(n);
        }
        solver_hessian_weights_into(
            &self.lasthessian_weights,
            &self.lastweights,
            &mut self.workspace.matvec_buf,
        );
        let solver_weights = std::mem::take(&mut self.workspace.matvec_buf);
        let (penalized_hessian, sparsehessian, ridge_used) = if matches!(
            self.coordinate_design,
            WorkingCoordinateDesign::OriginalSparseNative
        ) {
            let (h_sparse, _factor, ridge_used) =
                ensure_sparse_positive_definitewithridge(|ridge| {
                    self.sparse_penalized_hessian(&solver_weights, ridge)
                })?;
            (Array2::zeros((0, 0)), Some(h_sparse), ridge_used)
        } else {
            let mut penalized_hessian = self.penalized_hessian(&solver_weights)?;
            assert_symmetric_tol(&penalized_hessian, "PIRLS penalized Hessian", 1e-8);
            let ridge_used = ensure_positive_definitewithridge(
                &mut penalized_hessian,
                "PIRLS penalized Hessian",
            )?;
            (penalized_hessian, None, ridge_used)
        };
        self.workspace.matvec_buf = solver_weights;

        let deviance = self
            .likelihood
            .loglik_deviance(self.y, &self.lastmu, self.priorweights)?;
        let log_likelihood = calculate_loglikelihood_omitting_constants(
            self.y,
            &self.lastmu,
            &self.likelihood,
            self.priorweights,
        );
        let mut penalty_term = self.penalty.shifted_quadratic(beta.as_ref());
        let mut ridge_grad_norm = 0.0;
        if ridge_used > 0.0 {
            // Keep objective and gradient consistent with the ridged Hessian.
            let ridge_penalty = ridge_used * beta.as_ref().dot(beta.as_ref());
            penalty_term += ridge_penalty;
            gradient.zip_mut_with(beta.as_ref(), |g, &b| *g += ridge_used * b);
            ridge_grad_norm = ridge_used * array1_l2_norm(beta.as_ref());
        }
        self.last_penalty_term = penalty_term;
        let gradient_natural_scale = score_norm + s_beta_norm + ridge_grad_norm;
        Ok(WorkingState {
            // The eta buffer is handed to the caller; it is re-allocated on the next call.
            eta: LinearPredictor::new(std::mem::take(&mut self.workspace.eta_buf)),
            gradient,
            hessian: match sparsehessian {
                Some(h_sparse) => crate::linalg::matrix::SymmetricMatrix::Sparse(h_sparse),
                None => crate::linalg::matrix::SymmetricMatrix::Dense(penalized_hessian),
            },
            log_likelihood,
            deviance,
            penalty_term,
            firth,
            ridge_used,
            hessian_curvature,
            gradient_natural_scale,
        })
    }

    /// Fully evaluate a candidate point with the Firth adjustment disabled.
    ///
    /// The `firth_bias_reduction` flag is restored afterwards whether the
    /// evaluation succeeds or fails.
    fn update_candidate(
        &mut self,
        beta: &Coefficients,
        curvature: HessianCurvatureKind,
    ) -> Result<WorkingState, EstimationError> {
        if !self.firth_bias_reduction {
            return self.update_with_curvature(beta, curvature);
        }
        let firth_enabled = self.firth_bias_reduction;
        self.firth_bias_reduction = false;
        let result = self.update_with_curvature(beta, curvature);
        self.firth_bias_reduction = firth_enabled;
        result
    }

    /// Evaluate a candidate step.
    ///
    /// With Firth enabled, a full (Firth-free) update is performed. Otherwise
    /// the cheaper direction-based screen is used.
    fn screen_candidate(
        &mut self,
        beta: &Coefficients,
        direction: &Array1<f64>,
        current_eta: &LinearPredictor,
        curvature: HessianCurvatureKind,
    ) -> Result<CandidateEvaluation, EstimationError> {
        if self.firth_bias_reduction {
            return self
                .update_candidate(beta, curvature)
                .map(CandidateEvaluation::Full);
        }
        self.screen_candidate_from_direction(beta, direction, current_eta)
            .map(CandidateEvaluation::Screen)
    }

    /// Whether observed-information curvature is available for this model.
    fn supports_observed_information_curvature(&self) -> bool {
        self.supports_observed_hessian_curvature()
    }
}
/// Designs with at most this many columns use the dense outer-product XᵀWX
/// backend (`DenseOuterState`); wider designs use sparse SpGEMM.
const DENSE_OUTER_MAX_P: usize = 1024;
/// Minimum work estimate (`nnz(X)² / n`, see `DenseOuterState::compute`) at
/// which dense outer-product accumulation is split across rayon threads.
const DENSE_OUTER_PARALLEL_FLOP_THRESHOLD: u64 = 100_000;
/// Numeric strategy used by `SparseXtWxCache` to form XᵀWX.
enum XtWxBackend {
/// Accumulate into a dense p×p matrix, then scatter into the sparse pattern.
Dense(DenseOuterState),
/// Sparse-sparse product `(√W X)ᵀ (√W X)` using a precomputed symbolic plan.
Sparse(SparseSpGemmState),
}
/// Scratch state for the dense outer-product XᵀWX backend.
struct DenseOuterState {
// p×p accumulator. `accumulate_outer_upper` writes entry (j, k) for each pair
// of stored entries with j listed before k in a row of X (the upper triangle,
// presumably, given sorted column indices).
xtwx_dense: Array2<f64>,
// Per-thread p×p partial sums used by the parallel path; resized lazily.
thread_buffers: Vec<Array2<f64>>,
}
/// Scratch state for the SpGEMM XᵀWX backend.
struct SparseSpGemmState {
// Values of √W X, aligned with X's CSC storage.
wxvalues: Vec<f64>,
// Values of (√W X)ᵀ, aligned with the storage of Xᵀ in CSC form.
wx_tvalues: Vec<f64>,
// Square roots of the current observation weights (length n).
sqrt_weights: Vec<f64>,
// Symbolic plan from `sparse_sparse_matmul_symbolic`.
info: SparseMatMulInfo,
// Workspace for `sparse_sparse_matmul_numeric`.
scratch: MemBuffer,
// Parallelism setting passed to the numeric product.
par: Par,
}
/// Reusable cache for repeatedly forming XᵀWX with a fixed sparse design
/// pattern and changing weights.
pub(crate) struct SparseXtWxCache {
// Sparsity pattern of XᵀX.
xtwx_symbolic: SymbolicSparseColMat<usize>,
// Numeric values aligned with `xtwx_symbolic`.
xtwxvalues: Vec<f64>,
// Shape and structure fingerprint of X, used by `matches`.
nrows: usize,
ncols: usize,
nnz: usize,
x_col_ptr: Vec<usize>,
xrow_idx: Vec<usize>,
// Xᵀ stored column-major (one column per observation).
x_t_csc: SparseColMat<usize, f64>,
backend: XtWxBackend,
}
impl SparseXtWxCache {
    /// Build the symbolic XᵀX pattern for `x` and pick a numeric backend.
    ///
    /// Designs with `ncols <= DENSE_OUTER_MAX_P` use the dense outer-product
    /// backend. Wider designs use SpGEMM with preallocated scratch.
    ///
    /// # Errors
    /// Returns `InvalidInput` if transposing `x` or building the symbolic
    /// product fails.
    fn new(x: &SparseColMat<usize, f64>) -> Result<Self, EstimationError> {
        let x_t_csc =
            x.as_ref().transpose().to_col_major().map_err(|_| {
                EstimationError::InvalidInput("failed to transpose to CSC".to_string())
            })?;
        let (xtwx_symbolic, info) = sparse_sparse_matmul_symbolic(x_t_csc.symbolic(), x.symbolic())
            .map_err(|_| {
                EstimationError::InvalidInput("failed to build symbolic XtWX cache".to_string())
            })?;
        let xtwxvalues = vec![0.0; xtwx_symbolic.row_idx().len()];
        let backend = if x.ncols() <= DENSE_OUTER_MAX_P {
            XtWxBackend::Dense(DenseOuterState {
                xtwx_dense: Array2::<f64>::zeros((x.ncols(), x.ncols())),
                thread_buffers: Vec::new(),
            })
        } else {
            let par = get_global_parallelism();
            let scratch = MemBuffer::new(sparse_sparse_matmul_numeric_scratch::<usize, f64>(
                xtwx_symbolic.as_ref(),
                par,
            ));
            XtWxBackend::Sparse(SparseSpGemmState {
                wxvalues: vec![0.0; x.val().len()],
                wx_tvalues: vec![0.0; x_t_csc.val().len()],
                sqrt_weights: vec![0.0; x.nrows()],
                info,
                scratch,
                par,
            })
        };
        Ok(Self {
            xtwx_symbolic,
            xtwxvalues,
            nrows: x.nrows(),
            ncols: x.ncols(),
            nnz: x.val().len(),
            x_col_ptr: x.symbolic().col_ptr().to_vec(),
            xrow_idx: x.symbolic().row_idx().to_vec(),
            x_t_csc,
            backend,
        })
    }

    /// Whether `x` has exactly the shape and sparsity structure this cache was
    /// built for. Numeric values are not compared.
    fn matches(&self, x: &SparseColMat<usize, f64>) -> bool {
        if self.nrows != x.nrows() || self.ncols != x.ncols() || self.nnz != x.val().len() {
            return false;
        }
        let sym = x.symbolic();
        self.x_col_ptr.as_slice() == sym.col_ptr() && self.xrow_idx.as_slice() == sym.row_idx()
    }

    /// Recompute the numeric values of XᵀWX into `xtwxvalues` for `weights`.
    ///
    /// Both backends fill every entry of the symbolic XᵀX pattern. The dense
    /// backend accumulates only the upper triangle and mirrors it into the
    /// lower-triangle slots, so consumers see the same full symmetric values
    /// whichever backend was chosen.
    ///
    /// The dense backend treats negative (or NaN) weights as zero. The SpGEMM
    /// backend needs `√w`, so it requires finite nonnegative weights.
    ///
    /// # Errors
    /// Returns `InvalidInput` if `weights.len()` differs from the design row
    /// count, or if the SpGEMM backend receives a negative or non-finite weight.
    fn compute_numeric(
        &mut self,
        x: &SparseColMat<usize, f64>,
        weights: &Array1<f64>,
    ) -> Result<(), EstimationError> {
        if weights.len() != self.nrows {
            crate::bail_invalid_estim!(
                "weights length {} does not match design rows {}",
                weights.len(),
                self.nrows
            );
        }
        match &mut self.backend {
            XtWxBackend::Dense(state) => {
                state.compute(self.x_t_csc.as_ref(), weights, self.nrows, self.ncols);
                let col_ptr = self.xtwx_symbolic.col_ptr();
                let row_idx = self.xtwx_symbolic.row_idx();
                let dense = &state.xtwx_dense;
                for col in 0..self.ncols {
                    for idx in col_ptr[col]..col_ptr[col + 1] {
                        let row = row_idx[idx];
                        // Only the upper triangle is accumulated; lower-triangle
                        // slots take the symmetric counterpart so this backend
                        // agrees with the full-pattern SpGEMM backend.
                        self.xtwxvalues[idx] = if row <= col {
                            dense[[row, col]]
                        } else {
                            dense[[col, row]]
                        };
                    }
                }
            }
            XtWxBackend::Sparse(state) => {
                // Reject invalid weights here instead of hitting the assertion
                // inside the SpGEMM kernel (which would panic).
                if let Some(bad) = weights
                    .iter()
                    .copied()
                    .find(|w| !(w.is_finite() && *w >= 0.0))
                {
                    crate::bail_invalid_estim!(
                        "sparse XᵀWX assembly requires finite nonnegative PIRLS weights (found {})",
                        bad
                    );
                }
                state.compute(
                    x,
                    self.x_t_csc.as_ref(),
                    weights,
                    self.ncols,
                    self.xtwx_symbolic.as_ref(),
                    &mut self.xtwxvalues,
                );
            }
        }
        Ok(())
    }
}
impl DenseOuterState {
    /// Accumulate XᵀWX into `xtwx_dense` from `x_t` (Xᵀ stored column-major,
    /// so column `i` holds observation `i`).
    ///
    /// The accumulator is cleared first. Empty problems (`n == 0` or `p == 0`)
    /// return immediately. The work is split across rayon threads only when
    /// more than one thread is available and the estimate `nnz(X)² / n`
    /// reaches `DENSE_OUTER_PARALLEL_FLOP_THRESHOLD`. Each thread fills its own
    /// buffer, and the buffers are summed at the end.
    ///
    /// # Panics
    /// Panics if `xtwx_dense` is not `p × p`.
    fn compute(
        &mut self,
        x_t: SparseColMatRef<'_, usize, f64>,
        weights: &Array1<f64>,
        n: usize,
        p: usize,
    ) {
        assert_eq!(self.xtwx_dense.dim(), (p, p));
        self.xtwx_dense.fill(0.0);
        if n == 0 || p == 0 {
            return;
        }
        let started = std::time::Instant::now();
        let stored_entries = x_t.symbolic().row_idx().len() as u64;
        let estimated_work = stored_entries
            .saturating_mul(stored_entries)
            .checked_div(n as u64)
            .unwrap_or(u64::MAX);
        let threads = rayon::current_num_threads();
        let flops_estimate = (n as u64).saturating_mul((p as u64).saturating_mul(p as u64));

        if threads <= 1 || estimated_work < DENSE_OUTER_PARALLEL_FLOP_THRESHOLD {
            accumulate_outer_upper(&mut self.xtwx_dense, x_t, weights, 0..n);
            log::info!(
                "[STAGE] PIRLS dense XᵀWX assembly (serial) n={} p={} flops~{} elapsed={:.3}s",
                n,
                p,
                flops_estimate,
                started.elapsed().as_secs_f64(),
            );
            return;
        }

        if self.thread_buffers.len() != threads {
            self.thread_buffers
                .resize_with(threads, || Array2::<f64>::zeros((p, p)));
        }
        let rows_per_thread = n.div_ceil(threads);
        self.thread_buffers
            .par_iter_mut()
            .enumerate()
            .for_each(|(t, partial)| {
                partial.fill(0.0);
                let lo = t * rows_per_thread;
                let hi = n.min(lo + rows_per_thread);
                if lo < hi {
                    accumulate_outer_upper(partial, x_t, weights, lo..hi);
                }
            });
        for partial in &self.thread_buffers {
            self.xtwx_dense += partial;
        }
        log::info!(
            "[STAGE] PIRLS dense XᵀWX assembly (parallel, threads={}) n={} p={} flops~{} elapsed={:.3}s",
            threads,
            n,
            p,
            flops_estimate,
            started.elapsed().as_secs_f64(),
        );
    }
}
impl SparseSpGemmState {
    /// Form XᵀWX as the sparse product `(√W X)ᵀ (√W X)`, writing into
    /// `xtwxvalues` (aligned with `xtwx_symbolic`, overwritten via
    /// `Accum::Replace`).
    ///
    /// `x` is the design in CSC form. `x_t` is its transpose in CSC form, so
    /// `x_t.ncols()` is the number of observations `n`.
    ///
    /// # Panics
    /// Panics on length mismatches, or if any weight is negative or non-finite.
    fn compute(
        &mut self,
        x: &SparseColMat<usize, f64>,
        x_t: SparseColMatRef<'_, usize, f64>,
        weights: &Array1<f64>,
        p: usize,
        xtwx_symbolic: SymbolicSparseColMatRef<'_, usize>,
        xtwxvalues: &mut [f64],
    ) {
        let n = x_t.ncols();
        assert_eq!(weights.len(), n);
        assert_eq!(self.sqrt_weights.len(), n);
        assert!(
            weights.iter().all(|&w| w.is_finite() && w >= 0.0),
            "SparseSpGemmState::compute requires finite nonnegative PIRLS weights"
        );
        self.sqrt_weights
            .iter_mut()
            .zip(weights.iter())
            .for_each(|(root, &w)| *root = w.sqrt());
        let sqrt_w: &[f64] = &self.sqrt_weights;

        // √W X: scale every stored entry by √w of its row (observation).
        let x_ref = x.as_ref();
        for j in 0..p {
            let scaled = &mut self.wxvalues[x_ref.col_range(j)];
            let entries = x_ref
                .val_of_col(j)
                .iter()
                .zip(x_ref.row_idx_of_col_raw(j).iter());
            for (out, (&v, row)) in scaled.iter_mut().zip(entries) {
                *out = v * sqrt_w[row.unbound()];
            }
        }

        // (√W X)ᵀ: column i of Xᵀ is observation i, scaled uniformly by √w_i.
        for i in 0..n {
            let root_w = sqrt_w[i];
            let scaled = &mut self.wx_tvalues[x_t.col_range(i)];
            for (out, &v) in scaled.iter_mut().zip(x_t.val_of_col(i).iter()) {
                *out = v * root_w;
            }
        }

        let wx_ref = SparseColMatRef::new(x.symbolic(), &self.wxvalues[..]);
        let wx_t_ref = SparseColMatRef::new(x_t.symbolic(), &self.wx_tvalues[..]);
        let stack = MemStack::new(&mut self.scratch);
        let product = SparseColMatMut::new(xtwx_symbolic, xtwxvalues);
        sparse_sparse_matmul_numeric(
            product,
            Accum::Replace,
            wx_t_ref,
            wx_ref,
            1.0,
            &self.info,
            self.par,
            stack,
        );
    }
}
/// Add `Σ_i w_i · x_i x_iᵀ` over `rows` into `acc`, writing only entries
/// `(j, k)` where `j` precedes or equals `k` in observation `i`'s stored
/// column list.
///
/// `x_t` is Xᵀ in CSC form, so column `i` of `x_t` lists the nonzero design
/// columns of observation `i`. Negative weights are clamped to zero and
/// zero-weight rows are skipped. NaN weights also become zero, because
/// `f64::max` ignores NaN.
///
/// NOTE(review): this is the upper triangle only if each column's indices in
/// `x_t` are sorted ascending — presumably guaranteed by `to_col_major`;
/// confirm.
///
/// # Panics
/// Panics if `acc` is not square or not contiguous in standard layout.
#[inline]
fn accumulate_outer_upper(
    acc: &mut Array2<f64>,
    x_t: SparseColMatRef<'_, usize, f64>,
    weights: &Array1<f64>,
    rows: std::ops::Range<usize>,
) {
    assert_eq!(acc.nrows(), acc.ncols());
    let p = acc.ncols();
    let acc_data = acc
        .as_slice_mut()
        .expect("dense XᵀWX accumulator is row-major and contiguous");
    for obs in rows {
        let weight = weights[obs].max(0.0);
        if weight == 0.0 {
            continue;
        }
        let cols = x_t.row_idx_of_col_raw(obs);
        let vals = x_t.val_of_col(obs);
        for (first, (col_j, &v_j)) in cols.iter().zip(vals.iter()).enumerate() {
            let j = col_j.unbound();
            let weighted_vj = weight * v_j;
            let acc_row = &mut acc_data[j * p..(j + 1) * p];
            for (col_k, &v_k) in cols[first..].iter().zip(vals[first..].iter()) {
                acc_row[col_k.unbound()] += weighted_vj * v_k;
            }
        }
    }
}
/// Densify a CSR design and compute the Jeffreys/Firth PIRLS diagnostics.
///
/// Returns `(hat_diag, jeffreys_logdet, firth_score_shift)` from
/// `compute_jeffreys_pirls_diagnostics`.
///
/// # Errors
/// Returns `InvalidInput` if a row's column-index and value slices differ in
/// length. Otherwise propagates errors from the dense diagnostics.
pub(super) fn compute_jeffreys_pirls_diagnostics_sparse(
    link: &InverseLink,
    x_design_csr: &SparseRowMat<usize, f64>,
    eta: ArrayView1<f64>,
    observation_weights: ArrayView1<f64>,
) -> Result<(Array1<f64>, f64, Array1<f64>), EstimationError> {
    let csr = x_design_csr.as_ref();
    let mut densified = Array2::<f64>::zeros((x_design_csr.nrows(), x_design_csr.ncols()));
    for (i, mut dense_row) in densified.rows_mut().into_iter().enumerate() {
        let row_vals = csr.val_of_row(i);
        let row_cols = csr.col_idx_of_row_raw(i);
        if row_cols.len() != row_vals.len() {
            crate::bail_invalid_estim!(
                "sparse row structure mismatch: column/value lengths differ"
            );
        }
        for (&col, &value) in row_cols.iter().zip(row_vals.iter()) {
            dense_row[col.unbound()] = value;
        }
    }
    compute_jeffreys_pirls_diagnostics(link, densified.view(), eta, observation_weights)
}
/// Compute the Jeffreys/Firth PIRLS diagnostics for a dense design.
///
/// Builds a `FirthDenseOperator` for `link` at `eta` with the given
/// observation weights. Returns, in order:
/// - the PIRLS hat diagonal;
/// - the Jeffreys log-determinant;
/// - the Firth working-response score shift.
///
/// # Errors
/// Propagates operator-construction failures.
pub(super) fn compute_jeffreys_pirls_diagnostics(
    link: &InverseLink,
    x_design: ArrayView2<f64>,
    eta: ArrayView1<f64>,
    observation_weights: ArrayView1<f64>,
) -> Result<(Array1<f64>, f64, Array1<f64>), EstimationError> {
    let owned_design = x_design.to_owned();
    let owned_eta = eta.to_owned();
    let operator = FirthDenseOperator::build_with_observation_weights_for_link(
        link,
        &owned_design,
        &owned_eta,
        observation_weights,
    )?;
    let hat_diag = operator.pirls_hat_diag();
    let jeffreys_logdet = operator.jeffreys_logdet();
    let score_shift = operator.pirls_firth_score_shift();
    Ok((hat_diag, jeffreys_logdet, score_shift))
}
/// Make sure `hess` admits a Cholesky factorization, adding at most one fixed
/// ridge.
///
/// Returns `Ok(0.0)` if `hess` already factors. If not, and
/// `FIXED_STABILIZATION_RIDGE > 0`, that ridge is added to the diagonal in
/// place. The ridge is returned if the shifted matrix factors.
///
/// # Errors
/// Returns `HessianNotPositiveDefinite` with the smallest eigenvalue of the
/// (possibly ridged) matrix, or `-inf` if the eigendecomposition also fails.
/// On this path `hess` keeps any ridge that was added.
fn ensure_positive_definitewithridge(
    hess: &mut Array2<f64>,
    label: &str,
) -> Result<f64, EstimationError> {
    if hess.cholesky(Side::Lower).is_ok() {
        return Ok(0.0);
    }
    let ridge = if FIXED_STABILIZATION_RIDGE > 0.0 {
        FIXED_STABILIZATION_RIDGE
    } else {
        0.0
    };
    if ridge > 0.0 {
        hess.diag_mut().iter_mut().for_each(|d| *d += ridge);
        if hess.cholesky(Side::Lower).is_ok() {
            log::debug!("{} stabilized with fixed ridge {:.1e}.", label, ridge);
            return Ok(ridge);
        }
    }
    let min_eigenvalue = match hess.eigh(Side::Lower) {
        Ok((evals, _)) => evals.iter().copied().fold(f64::INFINITY, f64::min),
        Err(_) => f64::NEG_INFINITY,
    };
    Err(EstimationError::HessianNotPositiveDefinite { min_eigenvalue })
}
/// Solve `H d = -g` for the dense Newton direction, discarding any factor.
///
/// See `solve_newton_direction_dense_with_factor` for backends and errors.
pub(super) fn solve_newton_direction_dense(
    hessian: &Array2<f64>,
    gradient: &Array1<f64>,
    direction_out: &mut Array1<f64>,
) -> Result<(), EstimationError> {
    let _reusable_factor = solve_newton_direction_dense_with_factor(hessian, gradient, direction_out)?;
    Ok(())
}
/// Use an existing dense factorization of `H` to compute `d = -H⁻¹ g`.
///
/// `direction_out` is reallocated if its length differs from `gradient`. It is
/// then overwritten with `gradient`, solved in place, and negated.
pub(super) fn solve_direction_with_dense_factor(
    factor: &FaerSymmetricFactor,
    gradient: &Array1<f64>,
    direction_out: &mut Array1<f64>,
) {
    let len = gradient.len();
    if direction_out.len() != len {
        *direction_out = Array1::zeros(len);
    }
    direction_out.assign(gradient);
    {
        let mut rhs = array1_to_col_matmut(direction_out);
        factor.solve_in_place(rhs.as_mut());
    }
    direction_out.iter_mut().for_each(|v| *v = -*v);
}
/// Solve the dense Newton system `H d = -g` for the PIRLS step direction.
///
/// Backend order:
/// 1. If CUDA is selected, cuSOLVER is tried first. A finite result is
///    returned as `Ok(None)` (no CPU factor exists to reuse). A non-finite
///    result falls through to the CPU path.
/// 2. Otherwise `H` is factored with `StableSolver` and the system solved.
///    The result is validated with `‖H d + g‖∞ / (1 + ‖g‖∞) ≤ 1e-3`.
/// 3. If validation fails (or the residual is non-finite), a pseudoinverse
///    solve is tried and used only when it is finite. Otherwise the factored
///    direction is kept and a warning is logged.
///
/// Every `Ok` return carries a finite direction. On the CPU path the factor is
/// returned for reuse.
///
/// # Errors
/// - `InvalidInput` when the CUDA path fails.
/// - `LinearSystemSolveFailed` when CPU factorization fails, or when no finite
///   direction could be produced.
pub(super) fn solve_newton_direction_dense_with_factor(
    hessian: &Array2<f64>,
    gradient: &Array1<f64>,
    direction_out: &mut Array1<f64>,
) -> Result<Option<FaerSymmetricFactor>, EstimationError> {
    let dense_solve_start = std::time::Instant::now();
    let p = hessian.nrows();
    if direction_out.len() != gradient.len() {
        *direction_out = Array1::zeros(gradient.len());
    }
    if crate::gpu::cuda_selected() {
        let rhs = Array2::from_shape_vec((p, 1), gradient.to_vec()).map_err(|e| {
            EstimationError::InvalidInput(format!("CUDA PIRLS RHS layout failed: {e}"))
        })?;
        let (solved, _) =
            crate::solver::gpu::pirls_gpu::cholesky_solve_gpu(hessian.view(), rhs.view())
                .map_err(EstimationError::InvalidInput)?;
        direction_out.assign(&solved.column(0));
        direction_out.mapv_inplace(|v| -v);
        if array_is_finite(direction_out) {
            log::info!(
                "[STAGE] PIRLS dense newton solve backend=CUDA p={} flops~{} elapsed={:.3}s route=\"cuSOLVER potrf/potrs\"",
                p,
                (p as u64).saturating_mul((p as u64).saturating_mul(p as u64)) / 3,
                dense_solve_start.elapsed().as_secs_f64(),
            );
            return Ok(None);
        }
    }
    // Constant label; no allocation needed.
    let cpu_route = "CPU stable solver";
    let factor = StableSolver::new("pirls newton direction")
        .factorize(hessian)
        .map_err(EstimationError::LinearSystemSolveFailed)?;
    solve_direction_with_dense_factor(&factor, gradient, direction_out);
    let validation_residual = {
        let h_delta = hessian.dot(direction_out);
        h_delta
            .iter()
            .zip(gradient.iter())
            .map(|(h, g)| (h + g).abs())
            .fold(0.0_f64, f64::max)
    };
    let g_inf = gradient.iter().map(|v| v.abs()).fold(0.0_f64, f64::max);
    let rel = validation_residual / (1.0 + g_inf);
    if !rel.is_finite() || rel > 1.0e-3 {
        let rhs = gradient.mapv(|v| -v);
        let pseudo = StableSolver::new("pirls newton direction (pseudoinverse fallback)")
            .solve_with_pseudoinverse_fallback(hessian, &rhs, 1.0e-10, 1.0e-3, 1.0e-10);
        match pseudo {
            // Previously a non-finite fallback was returned as `Ok` without the
            // finiteness guard below; only finite fallbacks are accepted now.
            Some(pseudo) if pseudo.iter().all(|v| v.is_finite()) => {
                direction_out.assign(&pseudo);
                log::info!(
                    "[STAGE] PIRLS dense newton solve backend=CPU p={} elapsed={:.3}s route=\"{} + pseudoinverse fallback (rel={:.3e} > 1e-3)\"",
                    p,
                    dense_solve_start.elapsed().as_secs_f64(),
                    cpu_route,
                    rel,
                );
                return Ok(Some(factor));
            }
            _ => {
                log::warn!(
                    "PIRLS dense newton solve: residual check failed (rel={:.3e} > 1e-3) and the pseudoinverse fallback gave no finite direction; keeping the factored direction",
                    rel,
                );
            }
        }
    }
    if array_is_finite(direction_out) {
        log::info!(
            "[STAGE] PIRLS dense newton solve backend=CPU p={} flops~{} elapsed={:.3}s route=\"{}\"",
            p,
            (p as u64).saturating_mul((p as u64).saturating_mul(p as u64)) / 3,
            dense_solve_start.elapsed().as_secs_f64(),
            cpu_route,
        );
        return Ok(Some(factor));
    }
    Err(EstimationError::LinearSystemSolveFailed(
        FaerLinalgError::FactorizationFailed {
            context: "PIRLS dense newton solve exhausted",
        },
    ))
}
/// Solve `H d = -g` matrix-free with Jacobi-preconditioned conjugate gradients.
///
/// The operator is
/// `H v = apply_xtwx(v) + ridge·v + Σ λ_k S_k v + Σ λ_m Op_m v`.
/// The ridge term applies only when `ridge > 0`; penalties with `λ == 0` are
/// skipped. The Jacobi preconditioner is the matching diagonal:
/// `xtwx_diag + ridge + Σ λ diag(S)`.
///
/// # Errors
/// - `InvalidInput` on any dimension mismatch with `gradient`.
/// - `LinearSystemSolveFailed` if PCG does not converge within `max_iter`, or
///   if the resulting direction is non-finite.
pub fn solve_newton_direction_implicit<F>(
    apply_xtwx: F,
    xtwx_diag: ArrayView1<'_, f64>,
    dense_penalties: &[(f64, &Array2<f64>)],
    op_penalties: &[(f64, &dyn crate::terms::penalty_op::PenaltyOp)],
    gradient: &Array1<f64>,
    direction_out: &mut Array1<f64>,
    ridge: f64,
    rel_tol: f64,
    max_iter: usize,
) -> Result<(), EstimationError>
where
    F: Fn(&Array1<f64>) -> Array1<f64>,
{
    let p = gradient.len();
    if xtwx_diag.len() != p {
        crate::bail_invalid_estim!(
            "solve_newton_direction_implicit: xtwx_diag length {} != gradient length {}",
            xtwx_diag.len(),
            p
        );
    }
    if let Some((_, s)) = dense_penalties
        .iter()
        .find(|(_, s)| s.nrows() != p || s.ncols() != p)
    {
        crate::bail_invalid_estim!(
            "solve_newton_direction_implicit: dense penalty dim {}×{} != p={}",
            s.nrows(),
            s.ncols(),
            p
        );
    }
    if let Some((_, op)) = op_penalties.iter().find(|(_, op)| op.dim() != p) {
        crate::bail_invalid_estim!(
            "solve_newton_direction_implicit: op penalty dim {} != p={}",
            op.dim(),
            p
        );
    }
    if direction_out.len() != p {
        *direction_out = Array1::zeros(p);
    }
    let pcg_start = std::time::Instant::now();

    // Jacobi preconditioner: the diagonal of the full operator.
    let mut precond_diag = xtwx_diag.to_owned();
    if ridge > 0.0 {
        precond_diag.iter_mut().for_each(|d| *d += ridge);
    }
    for &(lambda, s) in dense_penalties.iter().filter(|(lambda, _)| *lambda != 0.0) {
        for (i, d) in precond_diag.iter_mut().enumerate() {
            *d += lambda * s[[i, i]];
        }
    }
    for &(lambda, op) in op_penalties.iter().filter(|(lambda, _)| *lambda != 0.0) {
        let op_diag = op.diag();
        for (i, d) in precond_diag.iter_mut().enumerate() {
            *d += lambda * op_diag[i];
        }
    }

    let apply_h = |v: &Array1<f64>| -> Array1<f64> {
        let mut hv = apply_xtwx(v);
        if ridge > 0.0 {
            hv.zip_mut_with(v, |h, &x| *h += ridge * x);
        }
        for &(lambda, s) in dense_penalties.iter().filter(|(lambda, _)| *lambda != 0.0) {
            hv.scaled_add(lambda, &fast_av(s, v));
        }
        for &(lambda, op) in op_penalties.iter().filter(|(lambda, _)| *lambda != 0.0) {
            let mut sv = Array1::<f64>::zeros(p);
            op.matvec(v.view(), sv.view_mut());
            hv.scaled_add(lambda, &sv);
        }
        hv
    };
    let Some(solution) =
        crate::linalg::utils::solve_spd_pcg(apply_h, gradient, &precond_diag, rel_tol, max_iter)
    else {
        return Err(EstimationError::LinearSystemSolveFailed(
            FaerLinalgError::FactorizationFailed {
                context: "PIRLS implicit PCG solve exhausted",
            },
        ));
    };
    // d = -H⁻¹ g
    direction_out.zip_mut_with(&solution, |out, &s| *out = -s);
    if !array_is_finite(direction_out) {
        return Err(EstimationError::LinearSystemSolveFailed(
            FaerLinalgError::FactorizationFailed {
                context: "PIRLS implicit PCG non-finite direction",
            },
        ));
    }
    log::info!(
        "[STAGE] PIRLS implicit (PCG) newton solve p={} dense_pens={} op_pens={} elapsed={:.3}s",
        p,
        dense_penalties.len(),
        op_penalties.len(),
        pcg_start.elapsed().as_secs_f64(),
    );
    Ok(())
}
/// Clamp every coefficient that lies below its finite lower bound onto that
/// bound, in place.
///
/// Non-finite bounds (e.g. `-inf`) impose nothing. NaN coefficients are left
/// as they are because `NaN < lb` is false.
pub(super) fn project_coefficients_to_lower_bounds(
    beta: &mut Array1<f64>,
    lower_bounds: &Array1<f64>,
) {
    for (idx, coef) in beta.iter_mut().enumerate() {
        let bound = lower_bounds[idx];
        if bound.is_finite() && *coef < bound {
            *coef = bound;
        }
    }
}
/// Relative part of the "coefficient sits on its lower bound" tolerance; it is
/// multiplied by `max(|beta_i|, |lb_i|, 1)`.
const ACTIVE_BOUND_REL_TOL: f64 = 1.0e-6;
/// Absolute floor added to the relative activity tolerance.
const ACTIVE_BOUND_ABS_TOL: f64 = 1.0e-10;
/// Euclidean norm of the gradient projected onto the feasible directions of
/// the per-coordinate lower bounds.
///
/// A coordinate's gradient component is dropped when all of the following hold:
/// * the bound is finite;
/// * the gradient is positive, so descent would push the coordinate down;
/// * the slack `beta_i - lb_i` is below
///   `ACTIVE_BOUND_REL_TOL · max(|beta_i|, |lb_i|, 1) + ACTIVE_BOUND_ABS_TOL`.
///
/// With no bounds this is the plain gradient norm.
pub(super) fn projected_gradient_norm(
    gradient: &Array1<f64>,
    beta: &Array1<f64>,
    lower_bounds: Option<&Array1<f64>>,
) -> f64 {
    match lower_bounds {
        None => gradient.dot(gradient).sqrt(),
        Some(lb) => {
            // True when coordinate `i` is pinned at its bound with an outward gradient.
            let pinned = |i: usize, g: f64| -> bool {
                if !(lb[i].is_finite() && g > 0.0) {
                    return false;
                }
                let slack = beta[i] - lb[i];
                let scale = beta[i].abs().max(lb[i].abs()).max(1.0);
                slack < ACTIVE_BOUND_REL_TOL * scale + ACTIVE_BOUND_ABS_TOL
            };
            gradient
                .iter()
                .enumerate()
                .filter(|&(i, &g)| !pinned(i, g))
                .fold(0.0_f64, |acc, (_, &g)| acc + g * g)
                .sqrt()
        }
    }
}
/// Reason a PIRLS iterate was accepted even though the strict convergence
/// test did not fire (produced by `pirls_soft_acceptance`).
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub(super) enum PirlsSoftAccept {
    /// KKT residual is near-stationary and objective progress is negligible.
    NearStationaryPlateau,
    /// The linear predictor has reached the |eta| cap and the realized change
    /// is below tolerance.
    BoundarySaturation,
    /// Small projected gradient relative to the objective scale, with a tiny
    /// non-negative realized change.
    RelativeBandPlateau,
}
/// Progress signal handed to `pirls_soft_acceptance`.
#[derive(Debug, Clone, Copy)]
pub(super) enum SoftAcceptProgress {
    /// A step was realized. `dev_change` is the change the caller observed and
    /// compares against `progress_tol · scale`.
    Realized {
        dev_change: f64,
    },
    /// Only a model-predicted reduction is available, together with the
    /// current penalized objective used to size a floating-point noise floor.
    Predicted {
        predicted_reduction: f64,
        current_penalized: f64,
    },
}
/// Decide whether a PIRLS iterate that failed the strict stopping test may
/// still be accepted. Returns the reason, or `None`.
///
/// Tolerances are scaled by `scale = max(|deviance|, |penalty_term|, 1)`:
/// * `NearStationaryPlateau`: `state.near_stationary_kkt` holds and progress
///   is negligible. For a realized change this means `|dev_change| < progress_tol·scale`.
///   For a predicted reduction it means
///   `|predicted_reduction| ≤ 1e-12·max(|current_penalized|, 1)`.
/// * `BoundarySaturation` (realized only): `max_abs_eta` is at the
///   `PIRLS_ETA_ABS_CAP` cap and the change is below tolerance.
/// * `RelativeBandPlateau` (realized only): the projected gradient is at most
///   `max(progress_tol, 1e-6)·scale` and the change is non-negative and below a
///   tenth of the tolerance.
#[inline]
pub(super) fn pirls_soft_acceptance(
    state: &WorkingState,
    projected_grad: f64,
    progress: SoftAcceptProgress,
    max_abs_eta: f64,
    progress_tol: f64,
    kkt_tol: f64,
) -> Option<PirlsSoftAccept> {
    let objective_scale = state.deviance.abs().max(state.penalty_term.abs()).max(1.0);
    let scaled_dev_tol = progress_tol * objective_scale;

    let dev_change = match progress {
        SoftAcceptProgress::Predicted {
            predicted_reduction,
            current_penalized,
        } => {
            // Without a realized step only the plateau rule can apply, and the
            // predicted reduction must be at floating-point noise level.
            let noise_floor = current_penalized.abs().max(1.0) * 1e-12;
            let plateau = state.near_stationary_kkt(projected_grad, kkt_tol)
                && predicted_reduction.abs() <= noise_floor;
            return plateau.then_some(PirlsSoftAccept::NearStationaryPlateau);
        }
        SoftAcceptProgress::Realized { dev_change } => dev_change,
    };

    let stalled = dev_change.abs() < scaled_dev_tol;
    if state.near_stationary_kkt(projected_grad, kkt_tol) && stalled {
        Some(PirlsSoftAccept::NearStationaryPlateau)
    } else if max_abs_eta >= PIRLS_ETA_ABS_CAP * (1.0 - 1e-12) && stalled {
        Some(PirlsSoftAccept::BoundarySaturation)
    } else if projected_grad <= progress_tol.max(1e-6) * objective_scale
        && dev_change.abs() < scaled_dev_tol * 0.1
        && dev_change >= 0.0
    {
        Some(PirlsSoftAccept::RelativeBandPlateau)
    } else {
        None
    }
}
/// Scalar stationarity measure used for convergence checks.
///
/// With general linear constraints, this is the largest of the primal
/// feasibility, dual feasibility, complementarity and stationarity residuals
/// from `compute_constraint_kkt_diagnostics`. Otherwise it is the
/// bound-projected gradient norm.
pub(super) fn constrained_stationarity_norm(
    gradient: &Array1<f64>,
    beta: &Array1<f64>,
    lower_bounds: Option<&Array1<f64>>,
    linear_constraints: Option<&LinearInequalityConstraints>,
) -> f64 {
    match linear_constraints {
        Some(constraints) => {
            let diag = compute_constraint_kkt_diagnostics(beta, gradient, constraints);
            [diag.dual_feasibility, diag.complementarity, diag.stationarity]
                .into_iter()
                .fold(diag.primal_feasibility, f64::max)
        }
        None => projected_gradient_norm(gradient, beta, lower_bounds),
    }
}
/// Count entries in the upper triangle (diagonal included) of the leading
/// `min(nrows, ncols)` square block whose magnitude strictly exceeds `tol`.
/// NaN entries are never counted.
fn count_dense_upper_nnz(matrix: &Array2<f64>, tol: f64) -> usize {
    let side = matrix.nrows().min(matrix.ncols());
    (0..side)
        .flat_map(|col| (0..=col).map(move |row| (row, col)))
        .filter(|&(row, col)| matrix[[row, col]].abs() > tol)
        .count()
}
/// Count entries with `|v| > 1e-12` in a non-sparse design by streaming it in
/// row chunks sized by `row_chunk_for_byte_budget`.
///
/// Chunks that cannot be materialized are counted as fully dense. A `[STAGE]`
/// timing line is logged.
fn scan_dense_design_nnz(x: &DesignMatrix) -> usize {
    let started = std::time::Instant::now();
    let nrows = x.nrows();
    let ncols = x.ncols();
    let chunk = row_chunk_for_byte_budget(nrows, ncols);
    let mut nnz: usize = 0;
    let mut chunks_processed = 0usize;
    if chunk > 0 && nrows > 0 {
        for lo in (0..nrows).step_by(chunk) {
            let hi = (lo + chunk).min(nrows);
            chunks_processed += 1;
            let chunk_nnz = match x.try_row_chunk(lo..hi) {
                Ok(rows) => rows.iter().filter(|v| v.abs() > 1e-12).count(),
                Err(_) => (hi - lo).saturating_mul(ncols),
            };
            nnz = nnz.saturating_add(chunk_nnz);
        }
    }
    log::info!(
        "[STAGE] PIRLS row-chunk generation chunks={} n={} p={} nnz={} elapsed={:.3}s",
        chunks_processed,
        nrows,
        ncols,
        nnz,
        started.elapsed().as_secs_f64(),
    );
    nnz
}

/// Choose between the sparse-native and dense-transformed PIRLS solve paths.
///
/// * Any finite coefficient lower bound or any linear constraint forces the
///   dense path (`"constraints_present"`).
/// * A non-sparse design is rejected (`"design_not_sparse"`) after its
///   nonzeros are counted for the diagnostic record.
/// * For a sparse design, the symbolic penalized-Hessian density from
///   `sparse_penalized_system_stats` decides. Density at most
///   `SPARSE_NATIVE_MAX_H_DENSITY` selects the sparse-native path; otherwise
///   the dense path is used. If the stats call fails, the dense path is used
///   (`"sparse_stats_failed"`).
fn estimate_sparse_native_decision(
    workspace: &mut PirlsWorkspace,
    x_original: &DesignMatrix,
    s_lambda: &Array2<f64>,
    coefficient_lower_bounds: Option<&Array1<f64>>,
    linear_constraints_original: Option<&LinearInequalityConstraints>,
) -> SparsePirlsDecision {
    let p = x_original.ncols();
    let nnz_s_lambda = count_dense_upper_nnz(s_lambda, 1e-12);
    let reject_to_dense = |reason: &'static str, nnz_x: usize| SparsePirlsDecision {
        path: PirlsLinearSolvePath::DenseTransformed,
        reason,
        p,
        nnz_x,
        nnz_xtwx_symbolic: None,
        nnz_s_lambda,
        nnz_h_est: None,
        density_h_est: None,
    };

    let bounded = matches!(
        coefficient_lower_bounds,
        Some(lb) if lb.iter().any(|bound| bound.is_finite())
    );
    if bounded || linear_constraints_original.is_some() {
        return reject_to_dense("constraints_present", 0);
    }

    let Some(x_sparse) = x_original.as_sparse() else {
        let nnz = scan_dense_design_nnz(x_original);
        return reject_to_dense("design_not_sparse", nnz);
    };

    let nnz_x = x_sparse.val().len();
    let Ok(stats) = workspace.sparse_penalized_system_stats(x_sparse, s_lambda) else {
        return reject_to_dense("sparse_stats_failed", nnz_x);
    };

    let sparse_ok = stats.density_upper <= SPARSE_NATIVE_MAX_H_DENSITY;
    let (path, reason) = if sparse_ok {
        (PirlsLinearSolvePath::SparseNative, "sparse_native_eligible")
    } else {
        (
            PirlsLinearSolvePath::DenseTransformed,
            "penalized_hessian_too_dense",
        )
    };
    SparsePirlsDecision {
        path,
        reason,
        p,
        nnz_x,
        nnz_xtwx_symbolic: Some(stats.nnz_xtwx_symbolic),
        nnz_s_lambda: stats.nnz_s_lambda_upper,
        nnz_h_est: Some(stats.nnz_h_upper),
        density_h_est: Some(stats.density_upper),
    }
}
/// Entry point used by PIRLS to pick its linear-solve path. All of the logic
/// lives in `estimate_sparse_native_decision`; this only forwards the arguments.
pub(super) fn should_use_sparse_native_pirls(
    workspace: &mut PirlsWorkspace,
    x_original: &DesignMatrix,
    s_lambda: &Array2<f64>,
    coefficient_lower_bounds: Option<&Array1<f64>>,
    linear_constraints_original: Option<&LinearInequalityConstraints>,
) -> SparsePirlsDecision {
    let (bounds, constraints) = (coefficient_lower_bounds, linear_constraints_original);
    estimate_sparse_native_decision(workspace, x_original, s_lambda, bounds, constraints)
}
/// Build the sparse penalized Hessian used on the REML path from the design
/// `x`, working `weights`, penalty `s_lambda` and diagonal `ridge`.
///
/// Thin delegate to `PirlsWorkspace::assemble_sparse_penalized_hessian`; an
/// optional precomputed `X'WX` structure is passed through unchanged.
pub(crate) fn sparse_reml_penalized_hessian(
    workspace: &mut PirlsWorkspace,
    x: &SparseColMat<usize, f64>,
    weights: &Array1<f64>,
    s_lambda: &Array2<f64>,
    ridge: f64,
    precomputed_xtwx: Option<&SparseXtwxPrecomputed>,
) -> Result<SparseColMat<usize, f64>, EstimationError> {
    let xtwx_cache = precomputed_xtwx;
    workspace.assemble_sparse_penalized_hessian(x, weights, s_lambda, ridge, xtwx_cache)
}
/// Assemble and factorize a sparse SPD penalized Hessian, escalating a
/// diagonal ridge until the sparse Cholesky succeeds.
///
/// `assemble(ridge)` must build the Hessian with `ridge` added to its
/// diagonal. The ladder is:
/// 1. `ridge = 0` (exact curvature);
/// 2. `ridge = FIXED_STABILIZATION_RIDGE`;
/// 3. a Gershgorin-derived ridge that lifts the Gershgorin lower bound on
///    λ_min above `FIXED_STABILIZATION_RIDGE · max(diag scale, 1)`, then twice
///    that ridge. A warning is logged before this rung because the exported
///    curvature is then stabilized, not exact.
///
/// Returns the assembled matrix, its factor and the ridge actually used.
///
/// # Errors
/// Propagates any error from `assemble`. Returns `HessianNotPositiveDefinite`
/// when every rung fails, or when the Gershgorin bound or diagonal scale is
/// non-finite (a non-finite ridge cannot yield a meaningful factorization).
pub(super) fn ensure_sparse_positive_definitewithridge<F>(
    mut assemble: F,
) -> Result<
    (
        SparseColMat<usize, f64>,
        crate::linalg::sparse_exact::SparseExactFactor,
        f64,
    ),
    EstimationError,
>
where
    F: FnMut(f64) -> Result<SparseColMat<usize, f64>, EstimationError>,
{
    // Rung 1: exact curvature.
    let h0 = assemble(0.0)?;
    if let Ok(factor) = factorize_sparse_spd(&h0) {
        return Ok((h0, factor, 0.0));
    }
    // Rung 2: fixed tiny stabilization ridge.
    let h_eps = assemble(FIXED_STABILIZATION_RIDGE)?;
    if let Ok(factor) = factorize_sparse_spd(&h_eps) {
        return Ok((h_eps, factor, FIXED_STABILIZATION_RIDGE));
    }
    // Rung 3: lift the Gershgorin lower bound on λ_min above a margin
    // proportional to the diagonal scale.
    let (gershgorin_min, diag_scale) = gershgorin_min_eig_lower_bound(&h_eps);
    let scale = diag_scale.max(1.0);
    let margin = FIXED_STABILIZATION_RIDGE * scale;
    let direct_ridge = (margin - gershgorin_min).max(FIXED_STABILIZATION_RIDGE);
    // `gershgorin_min_eig_lower_bound` reports -∞ for a non-finite bound, and an
    // infinite diagonal scale gives an infinite margin. Either way the ridge
    // would be +∞, and assembling/factorizing with it is meaningless, so fail
    // fast with the diagnostic instead.
    if !direct_ridge.is_finite() {
        log::warn!(
            "sparse penalized Hessian is not positive definite and its Gershgorin \
             bound ({:.3e}) or diagonal scale ({:.3e}) is non-finite; no finite ridge \
             can stabilize it.",
            gershgorin_min,
            scale,
        );
        return Err(EstimationError::HessianNotPositiveDefinite {
            min_eigenvalue: gershgorin_min,
        });
    }
    log::warn!(
        "sparse penalized Hessian is not positive definite (Gershgorin λ_min ≥ {:.3e}, \
         diag scale {:.3e}); regularizing curvature with direct ridge {:.3e}. Exported \
         curvature/SEs are stabilized, not exact — investigate rank-deficiency or weight \
         underflow in the Hessian assembly.",
        gershgorin_min,
        scale,
        direct_ridge,
    );
    for ridge in [direct_ridge, direct_ridge * 2.0] {
        let h = assemble(ridge)?;
        if let Ok(factor) = factorize_sparse_spd(&h) {
            return Ok((h, factor, ridge));
        }
    }
    Err(EstimationError::HessianNotPositiveDefinite {
        min_eigenvalue: gershgorin_min,
    })
}
/// Gershgorin lower bound on the smallest eigenvalue of the sparse symmetric
/// matrix `h`, paired with the largest absolute diagonal entry.
///
/// Every off-diagonal value is added to the radius of both its row and its
/// column, which gives the exact Gershgorin radii when only one triangle is
/// stored. NOTE(review): this assumes single-triangle storage (callers'
/// naming suggests upper). With full symmetric storage the radii double and
/// the bound is still valid but looser — confirm.
///
/// Duplicate diagonal entries are summed. `f64::min` ignores NaN, so NaN rows
/// do not lower the bound. A non-finite result, including the empty-matrix
/// case of +∞, is reported as -∞.
fn gershgorin_min_eig_lower_bound(h: &SparseColMat<usize, f64>) -> (f64, f64) {
    let dim = h.ncols();
    let (symbolic, values) = h.parts();
    let col_ptr = symbolic.col_ptr();
    let row_idx = symbolic.row_idx();

    let mut diagonal = vec![0.0_f64; dim];
    let mut off_diag_abs = vec![0.0_f64; dim];
    for col in 0..dim {
        for k in col_ptr[col]..col_ptr[col + 1] {
            let (row, v) = (row_idx[k], values[k]);
            if row == col {
                diagonal[col] += v;
                continue;
            }
            off_diag_abs[row] += v.abs();
            off_diag_abs[col] += v.abs();
        }
    }

    let (lower, scale) = diagonal.iter().zip(off_diag_abs.iter()).fold(
        (f64::INFINITY, 0.0_f64),
        |(lo, sc), (&d, &r)| (lo.min(d - r), sc.max(d.abs())),
    );
    let lower = if lower.is_finite() {
        lower
    } else {
        f64::NEG_INFINITY
    };
    (lower, scale)
}
/// Solve the free-coordinate Newton system `H_ff d = -g_f` for the bounded
/// active-set iteration, writing `d` into `out`. `out` is resized to
/// `g_sub.len()` if needed.
///
/// Strategy, in order:
/// 1. `StableSolver::factorize_any` on `h_sub` as given;
/// 2. up to 12 ridged factorizations of `h_sub + τI`, with τ starting at
///    `1e-8 · max(max_i |h_ii|, 1)` and growing ×10 per attempt;
/// 3. a scaled steepest-descent step `-g / max(‖g‖, diag_scale)`;
/// 4. the zero vector when `‖g‖` is not positive.
///
/// A candidate from steps 1–2 is accepted only if every entry is finite.
/// Every path in this function returns `Ok`.
fn solve_subsystem_direction(
    h_sub: ndarray::ArrayView2<f64>,
    g_sub: ndarray::ArrayView1<f64>,
    out: &mut Array1<f64>,
) -> Result<(), EstimationError> {
    let dim = g_sub.len();
    if out.len() != dim {
        *out = Array1::zeros(dim);
    }

    // 1) Factorize the free block as-is.
    if let Ok(chol) = StableSolver::new("pirls bounded subsystem").factorize_any(&h_sub) {
        out.assign(&g_sub);
        let mut col = array1_to_col_matmut(out);
        chol.solve_in_place(col.as_mut());
        out.mapv_inplace(|v| -v);
        if array_is_finite(out) {
            return Ok(());
        }
    }

    // 2) Escalating diagonal ridge.
    let diag_scale = (0..dim)
        .map(|i| h_sub[[i, i]].abs())
        .fold(0.0_f64, f64::max)
        .max(1.0);
    let mut shifted = h_sub.to_owned();
    let mut tau = 1e-8 * diag_scale;
    let mut attempts_left = 12;
    while attempts_left > 0 {
        attempts_left -= 1;
        for i in 0..dim {
            shifted[[i, i]] = h_sub[[i, i]] + tau;
        }
        if let Ok(chol) = StableSolver::new("pirls bounded subsystem ridge").factorize(&shifted)
        {
            out.assign(&g_sub);
            let mut col = array1_to_col_matmut(out);
            chol.solve_in_place(col.as_mut());
            out.mapv_inplace(|v| -v);
            if array_is_finite(out) {
                return Ok(());
            }
        }
        tau *= 10.0;
    }

    // 3) / 4) Scaled steepest descent, or nothing when the gradient vanishes.
    let gnorm = g_sub.dot(&g_sub).sqrt();
    if !(gnorm > 0.0) {
        out.fill(0.0);
        return Ok(());
    }
    let step = 1.0 / gnorm.max(diag_scale);
    out.zip_mut_with(&g_sub, |o, &g| *o = -g * step);
    Ok(())
}
/// Express per-coordinate lower bounds as general linear inequalities.
///
/// Delegates to `LinearInequalityConstraints::from_per_coordinate_lower_bounds`
/// and returns its result unchanged, including `None`.
pub(super) fn linear_constraints_from_lower_bounds(
    lower_bounds: &Array1<f64>,
) -> Option<LinearInequalityConstraints> {
    LinearInequalityConstraints::from_per_coordinate_lower_bounds(lower_bounds)
}
/// KKT residuals of `beta` and `gradient` against linear inequality
/// `constraints`. Forwards to the `active_set` module implementation.
pub(super) fn compute_constraint_kkt_diagnostics(
    beta: &Array1<f64>,
    gradient: &Array1<f64>,
    constraints: &LinearInequalityConstraints,
) -> ConstraintKktDiagnostics {
    let diagnostics = active_set::compute_constraint_kkt_diagnostics;
    diagnostics(beta, gradient, constraints)
}
/// Pick an active lower-bound coordinate to release from the working set.
///
/// For active coordinate `i`, the multiplier of the pinned bound is
/// `λ_i = g_i + (H d)_i`. A clearly negative `λ_i` means leaving the bound
/// would further decrease the quadratic model, so the constraint should be
/// dropped. Multipliers within `64·ε·max(|g_i|, |(Hd)_i|, 1)` of zero are
/// treated as zero (rounding noise) in both modes. Without this tolerance the
/// most-negative rule used to release coordinates on noise-level multipliers,
/// only to re-activate them on the next pass.
///
/// * `use_blands = true`: Bland's anti-cycling rule releases the first
///   qualifying index in `active_idx` order.
/// * `use_blands = false`: Dantzig-style rule releases the most negative
///   qualifying multiplier; the earliest index wins ties.
///
/// Returns `None` when no active multiplier is meaningfully negative.
pub(super) fn select_active_set_release(
    gradient: &Array1<f64>,
    hd: &Array1<f64>,
    active_idx: &[usize],
    use_blands: bool,
) -> Option<usize> {
    // Multiplier of coordinate `i` if it is negative beyond rounding noise.
    let significant_negative = |i: usize| -> Option<f64> {
        let lambda_i = gradient[i] + hd[i];
        let scale = gradient[i].abs().max(hd[i].abs()).max(1.0);
        let tol = 64.0 * f64::EPSILON * scale;
        (lambda_i < -tol).then_some(lambda_i)
    };
    let mut candidates = active_idx
        .iter()
        .filter_map(|&i| significant_negative(i).map(|lambda| (i, lambda)));

    if use_blands {
        return candidates.next().map(|(i, _)| i);
    }

    let mut best: Option<(usize, f64)> = None;
    for (i, lambda) in candidates {
        let improves = match best {
            Some((_, worst)) => lambda < worst,
            None => true,
        };
        if improves {
            best = Some((i, lambda));
        }
    }
    best.map(|(i, _)| i)
}
/// Replace `hint`, when present, with the indices of the current working set.
fn store_lower_bound_active_hint(hint: Option<&mut Vec<usize>>, active: &[bool]) {
    if let Some(hint) = hint {
        hint.clear();
        hint.extend(
            active
                .iter()
                .enumerate()
                .filter_map(|(i, &is_active)| is_active.then_some(i)),
        );
    }
}

/// Newton direction for `min_d g'd + ½ d'Hd` subject to
/// `beta + d ≥ lower_bounds`. Non-finite bounds are ignored. The result is
/// written into `direction_out`, which is resized to `p` if needed.
///
/// 1. Without a usable warm-start hint, the unconstrained dense Newton step
///    is tried first. It is returned as-is when it keeps every finite bound
///    satisfied.
/// 2. Otherwise a primal active-set loop runs:
///    * working-set coordinates are pinned to their bound (`d_i = lb_i - β_i`);
///    * the free block is solved with `solve_subsystem_direction`;
///    * a free step that `boundary_hit_step_fraction` reports as crossing a
///      bound adds that blocking coordinate to the working set;
///    * coordinates whose multiplier is negative are released via
///      `select_active_set_release`, switching to Bland's rule after
///      `2·(p+1)` passes to prevent cycling.
/// 3. If `8·(p+1)` passes do not settle, a bound-clipped steepest-descent
///    step scaled by `1/max(max_i |H_ii|, 1)` is returned.
///
/// `active_hint`, when provided, seeds the working set and receives the final
/// working set on success. It is cleared on the steepest-descent fallback.
/// Hint entries that are out of range, or that name a coordinate without a
/// finite bound, are ignored. Such a stale entry would otherwise pin an
/// unbounded coordinate at `d_i = 0`, where the multiplier test need not
/// release it.
///
/// # Errors
/// Invalid-input error when `beta`, `lower_bounds` or `hessian` do not match
/// `gradient.len()`. Errors from the free-block solve are propagated.
pub(crate) fn solve_newton_directionwith_lower_bounds(
    hessian: &Array2<f64>,
    gradient: &Array1<f64>,
    beta: &Array1<f64>,
    lower_bounds: &Array1<f64>,
    direction_out: &mut Array1<f64>,
    active_hint: Option<&mut Vec<usize>>,
) -> Result<(), EstimationError> {
    let p = gradient.len();
    if lower_bounds.len() != p || beta.len() != p {
        crate::bail_invalid_estim!(
            "lower-bound size mismatch: beta={}, gradient={}, bounds={}",
            beta.len(),
            gradient.len(),
            lower_bounds.len()
        );
    }
    // The loops below index `hessian[[i, j]]` for all i, j < p; reject a
    // mis-sized Hessian up front instead of panicking mid-solve.
    if hessian.nrows() != p || hessian.ncols() != p {
        crate::bail_invalid_estim!(
            "lower-bound Newton solve: hessian is {}x{} but gradient length is {}",
            hessian.nrows(),
            hessian.ncols(),
            p
        );
    }
    if direction_out.len() != p {
        *direction_out = Array1::zeros(p);
    }
    direction_out.fill(0.0);

    // Only hint entries naming a coordinate with a finite bound are meaningful.
    let usable_hint = |idx: usize| idx < p && lower_bounds[idx].is_finite();
    let has_active_hint = active_hint
        .as_ref()
        .map(|hint| hint.iter().any(|&idx| usable_hint(idx)))
        .unwrap_or(false);

    // Fast path: accept the unconstrained Newton step when it is already feasible.
    if !has_active_hint && solve_newton_direction_dense(hessian, gradient, direction_out).is_ok()
    {
        let feasible = (0..p).all(|i| {
            let lb = lower_bounds[i];
            !(lb.is_finite() && beta[i] + direction_out[i] < lb)
        });
        if feasible {
            return Ok(());
        }
    }

    // Seed the working set from the hint and from bounds that are already
    // (numerically) active with an outward-pointing gradient.
    let mut active = vec![false; p];
    if let Some(hint) = active_hint.as_ref() {
        for &idx in hint.iter() {
            if usable_hint(idx) {
                active[idx] = true;
            }
        }
    }
    for i in 0..p {
        let lb = lower_bounds[i];
        if lb.is_finite() && gradient[i] > 0.0 {
            let scale = beta[i].abs().max(lb.abs()).max(1.0);
            let tol = ACTIVE_BOUND_REL_TOL * scale + ACTIVE_BOUND_ABS_TOL;
            if beta[i] <= lb + tol {
                active[i] = true;
            }
        }
    }

    const BLANDS_RULE_GRACE: usize = 2;
    let blands_threshold = BLANDS_RULE_GRACE * (p + 1);
    let max_iters = 8 * (p + 1);
    let mut d_free = Array1::<f64>::zeros(p);
    let mut h_ff_buf = Array2::<f64>::zeros((p, p));
    let mut g_f_buf = Array1::<f64>::zeros(p);
    for it in 0..max_iters {
        let use_blands = it >= blands_threshold;
        let free_idx: Vec<usize> = (0..p).filter(|&i| !active[i]).collect();
        let active_idx: Vec<usize> = (0..p).filter(|&i| active[i]).collect();

        // Pin working-set coordinates onto their bounds.
        direction_out.fill(0.0);
        for &i in &active_idx {
            let lb = lower_bounds[i];
            if lb.is_finite() {
                direction_out[i] = lb - beta[i];
            }
        }

        if free_idx.is_empty() {
            let hd = fast_av(hessian, direction_out);
            if let Some(idx) = select_active_set_release(gradient, &hd, &active_idx, use_blands) {
                active[idx] = false;
                continue;
            }
            store_lower_bound_active_hint(active_hint, &active);
            return Ok(());
        }

        // Reduced system on the free block: H_ff d_f = -(g_f + H_fa d_a).
        let n_free = free_idx.len();
        {
            let mut h_ff = h_ff_buf.slice_mut(ndarray::s![..n_free, ..n_free]);
            let mut g_f = g_f_buf.slice_mut(ndarray::s![..n_free]);
            for (ii, &i) in free_idx.iter().enumerate() {
                let mut gi = gradient[i];
                for &j in &active_idx {
                    gi += hessian[[i, j]] * direction_out[j];
                }
                g_f[ii] = gi;
                for (jj, &j) in free_idx.iter().enumerate() {
                    h_ff[[ii, jj]] = hessian[[i, j]];
                }
            }
        }
        solve_subsystem_direction(
            h_ff_buf.slice(ndarray::s![..n_free, ..n_free]),
            g_f_buf.slice(ndarray::s![..n_free]),
            &mut d_free,
        )?;
        for (ii, &i) in free_idx.iter().enumerate() {
            direction_out[i] = d_free[ii];
        }

        // Blocking bound: add it to the working set. The next pass re-solves
        // from scratch, so the truncated step itself is never needed.
        let mut hit_idx: Option<usize> = None;
        let mut best_alpha = 1.0_f64;
        for &i in &free_idx {
            let lb = lower_bounds[i];
            if !lb.is_finite() {
                continue;
            }
            let slack = beta[i] - lb;
            let di = direction_out[i];
            if let Some(alpha_i) = boundary_hit_step_fraction(slack, di, best_alpha) {
                best_alpha = alpha_i;
                hit_idx = Some(i);
            }
        }
        if let Some(i_hit) = hit_idx {
            active[i_hit] = true;
            continue;
        }

        // Feasible full step: drop any bound with a negative multiplier.
        let hd = fast_av(hessian, direction_out);
        if let Some(idx) = select_active_set_release(gradient, &hd, &active_idx, use_blands) {
            active[idx] = false;
            continue;
        }
        store_lower_bound_active_hint(active_hint, &active);
        return Ok(());
    }

    // The active-set loop did not settle: bound-clipped scaled steepest descent.
    let gnorm = gradient.dot(gradient).sqrt();
    if gnorm > 0.0 {
        let diag_scale = (0..p)
            .map(|i| hessian[[i, i]].abs())
            .fold(0.0_f64, f64::max)
            .max(1.0);
        let step_scale = 1.0 / diag_scale;
        for i in 0..p {
            let di = -gradient[i] * step_scale;
            let lb = lower_bounds[i];
            if lb.is_finite() && beta[i] + di < lb {
                direction_out[i] = lb - beta[i];
            } else {
                direction_out[i] = di;
            }
        }
    } else {
        direction_out.fill(0.0);
    }
    if let Some(hint) = active_hint {
        hint.clear();
    }
    Ok(())
}
/// Newton direction under general linear inequality constraints.
///
/// Forwards unchanged to the `active_set` module solver. `active_hint`, when
/// given, is handed through for that solver to read and update.
pub(super) fn solve_newton_directionwith_linear_constraints(
    hessian: &Array2<f64>,
    gradient: &Array1<f64>,
    beta: &Array1<f64>,
    constraints: &LinearInequalityConstraints,
    direction_out: &mut Array1<f64>,
    active_hint: Option<&mut Vec<usize>>,
) -> Result<(), EstimationError> {
    let solve = active_set::solve_newton_direction_with_linear_constraints;
    solve(hessian, gradient, beta, constraints, direction_out, active_hint)
}
use loop_driver::assert_symmetric_tol;
pub(crate) use loop_driver::fit_model_for_fixed_rho_with_adaptive_kkt;
pub use loop_driver::{PenaltyConfig, PirlsConfig, PirlsProblem, fit_model_for_fixed_rho};
/// Inverse-link value and η-derivatives (a `MixtureInverseLinkJet`) at `eta`
/// for any `InverseLink`. Delegates to
/// `mixture_link::inverse_link_jet_for_inverse_link` and propagates its error.
#[inline]
pub(super) fn standard_inverse_link_jet(
    inverse_link: &InverseLink,
    eta: f64,
) -> Result<MixtureInverseLinkJet, EstimationError> {
    let jet_for = crate::mixture_link::inverse_link_jet_for_inverse_link;
    jet_for(inverse_link, eta)
}
/// Working-state geometry for a Bernoulli row under the logit link.
///
/// The working weight is `priorweight · jet.d1` (with `jet.d1 = dμ/dη`), and
/// its η-derivatives `c`, `d` are `priorweight · jet.d2` and
/// `priorweight · jet.d3`. A row is non-smooth when η was altered
/// (`eta_raw != eta_used`, including NaN) or `jet.d1` is non-finite or
/// negative. When `zero_on_nonsmooth` is set, such rows get `c = d = 0`. The
/// working response comes from `bernoulli_exact_working_response` at
/// `eta_used`.
#[inline]
fn bernoulli_logit_geometry_from_jet(
    eta_raw: f64,
    eta_used: f64,
    y: f64,
    priorweight: f64,
    jet: crate::mixture_link::LogitJet5,
    zero_on_nonsmooth: bool,
) -> WorkingBernoulliGeometry {
    let fisher = jet.d1;
    let smooth = eta_raw == eta_used && fisher.is_finite() && fisher >= 0.0;
    let zero_curvature = zero_on_nonsmooth && !smooth;
    let (c, d) = if zero_curvature {
        (0.0, 0.0)
    } else {
        (priorweight * jet.d2, priorweight * jet.d3)
    };
    let z = bernoulli_exact_working_response(eta_used, y, jet.mu, jet.d1);
    WorkingBernoulliGeometry {
        mu: jet.mu,
        weight: priorweight * fisher,
        z,
        c,
        d,
    }
}
/// Working-state geometry for a Bernoulli row under a general inverse link
/// described by `jet`: `jet.mu` and its η-derivatives `jet.d1`, `jet.d2`,
/// `jet.d3`.
///
/// The Fisher weight is `priorweight · (dμ/dη)² / V(μ)` with `V(μ) = μ(1−μ)`.
/// A non-finite or non-positive `V` gives a zero Fisher weight. `c` and `d`
/// are the first and second η-derivatives of the weight, built with the
/// quotient rule on `N/V` where `N = (dμ/dη)²`. They are zeroed on non-smooth
/// rows: η altered (`eta_raw != eta_used`), `V` degenerate, or the Fisher
/// weight non-finite or negative.
#[inline]
fn bernoulli_geometry_from_jet(
    eta_raw: f64,
    eta_used: f64,
    y: f64,
    priorweight: f64,
    jet: MixtureInverseLinkJet,
) -> WorkingBernoulliGeometry {
    let mu = jet.mu;
    // Bernoulli variance function V(μ) = μ(1−μ).
    let v = mu * (1.0 - mu);
    // Numerator N = (dμ/dη)² of the Fisher weight N/V.
    let n0 = jet.d1 * jet.d1;
    let fisher = if v.is_finite() && v > 0.0 {
        n0 / v
    } else {
        0.0
    };
    let nonsmooth =
        eta_raw != eta_used || !v.is_finite() || v <= 0.0 || !fisher.is_finite() || fisher < 0.0;
    let (c, d) = if nonsmooth {
        (0.0, 0.0)
    } else {
        // V' = μ'(1−2μ),  V'' = μ''(1−2μ) − 2μ'².
        let v1 = jet.d1 * (1.0 - 2.0 * mu);
        let v2 = jet.d2 * (1.0 - 2.0 * mu) - 2.0 * jet.d1 * jet.d1;
        // N' = 2μ'μ'',  N'' = 2(μ''² + μ'μ''').
        let n1 = 2.0 * jet.d1 * jet.d2;
        let n2 = 2.0 * (jet.d2 * jet.d2 + jet.d1 * jet.d3);
        // (N/V)' = (N'V − NV')/V².
        let numer1 = n1 * v - n0 * v1;
        let c = priorweight * numer1 / (v * v);
        // (N/V)'' = (N''V − NV'')/V² − 2(N'V − NV')V'/V³.
        let d = priorweight * ((n2 * v - n0 * v2) / (v * v) - 2.0 * numer1 * v1 / (v * v * v));
        (c, d)
    };
    WorkingBernoulliGeometry {
        mu,
        weight: priorweight * fisher,
        z: bernoulli_exact_working_response(eta_used, y, mu, jet.d1),
        c,
        d,
    }
}
/// Fisher-scoring working response `z = η + (y − μ) / (dμ/dη)`.
///
/// Falls back to `z = η` whenever `dμ/dη` is not a finite positive number,
/// or when the correction term itself is non-finite. Degenerate rows
/// therefore never inject NaN or ∞ into the working response.
#[inline]
fn bernoulli_exact_working_response(eta: f64, y: f64, mu: f64, dmu_deta: f64) -> f64 {
    let usable_slope = dmu_deta.is_finite() && dmu_deta > 0.0;
    let correction = if usable_slope {
        (y - mu) / dmu_deta
    } else {
        f64::NAN
    };
    if correction.is_finite() {
        eta + correction
    } else {
        eta
    }
}
/// Identity-link working state: `mu = eta`, the working weights are the prior
/// weights, and the working response is `y` itself.
///
/// When derivative buffers are supplied, `dmu_deta` is set to 1 and every
/// higher derivative and curvature buffer (`c`, `d`, `d2mu_deta2`,
/// `d3mu_deta3`) is zeroed.
#[inline]
fn write_identityworking_state(
    y: ArrayView1<f64>,
    eta: &Array1<f64>,
    priorweights: ArrayView1<f64>,
    mu: &mut Array1<f64>,
    weights: &mut Array1<f64>,
    z: &mut Array1<f64>,
    derivatives: Option<WorkingDerivativeBuffersMut<'_>>,
) {
    z.assign(&y);
    weights.assign(&priorweights);
    mu.assign(eta);
    let Some(derivs) = derivatives else {
        return;
    };
    derivs.dmu_deta.fill(1.0);
    derivs.d2mu_deta2.fill(0.0);
    derivs.d3mu_deta3.fill(0.0);
    derivs.c.fill(0.0);
    derivs.d.fill(0.0);
}
/// Log-link working state with Poisson weights.
///
/// Delegates to `log_link_working_state::write_log_link_working_state` with:
/// * the `PoissonIdentity` weight rule;
/// * proportional curvature with `c_ratio = d_ratio = 1`;
/// * weight flooring on;
/// * no zeroing of the μ-jet on clamped rows.
#[inline]
fn write_poisson_log_working_state(
    y: ArrayView1<f64>,
    eta: &Array1<f64>,
    priorweights: ArrayView1<f64>,
    mu: &mut Array1<f64>,
    weights: &mut Array1<f64>,
    z: &mut Array1<f64>,
    derivatives: Option<WorkingDerivativeBuffersMut<'_>>,
) {
    // NOTE(review): unit ratios presumably encode dW/dη = W and d²W/dη² = W
    // for the w·μ Poisson weight — confirm against `WorkingCurvature`.
    let poisson_rule = log_link_working_state::LogLinkRule {
        weight: log_link_working_state::WorkingWeight::PoissonIdentity,
        curvature: log_link_working_state::WorkingCurvature::Proportional {
            c_ratio: 1.0,
            d_ratio: 1.0,
        },
        floor_weight: true,
        zero_mu_jet_on_clamp: false,
    };
    log_link_working_state::write_log_link_working_state(
        &poisson_rule,
        y,
        eta,
        priorweights,
        mu,
        weights,
        z,
        derivatives,
    );
}
/// Log-link working state for the Gamma family.
///
/// Delegates to `log_link_working_state::write_log_link_working_state` with:
/// * a constant weight factor equal to `shape`;
/// * zero curvature ratios (`c_ratio = d_ratio = 0`);
/// * weight flooring disabled.
#[inline]
fn write_gamma_log_working_state(
    y: ArrayView1<f64>,
    eta: &Array1<f64>,
    priorweights: ArrayView1<f64>,
    shape: f64,
    mu: &mut Array1<f64>,
    weights: &mut Array1<f64>,
    z: &mut Array1<f64>,
    derivatives: Option<WorkingDerivativeBuffersMut<'_>>,
) {
    // NOTE(review): zero ratios presumably reflect a working weight that does
    // not depend on η under the log link — confirm.
    let gamma_rule = log_link_working_state::LogLinkRule {
        weight: log_link_working_state::WorkingWeight::Constant { factor: shape },
        curvature: log_link_working_state::WorkingCurvature::Proportional {
            c_ratio: 0.0,
            d_ratio: 0.0,
        },
        floor_weight: false,
        zero_mu_jet_on_clamp: false,
    };
    log_link_working_state::write_log_link_working_state(
        &gamma_rule,
        y,
        eta,
        priorweights,
        mu,
        weights,
        z,
        derivatives,
    );
}
/// Margin that keeps beta-regression means strictly inside (0, 1). It also
/// floors the quantities derived from them (`μ(1−μ)` and the shape
/// parameters `μφ`, `(1−μ)φ`).
pub const BETA_MU_EPS: f64 = 1e-12;
/// `μ^(2−p)` with μ floored at `1e-300`. Zero, negative and underflowed
/// means stay finite, and NaN μ is also replaced by the floor because
/// `f64::max` ignores NaN.
#[inline]
fn tweedie_log_weight_mu_power(mu: f64, p: f64) -> f64 {
    let exponent = 2.0 - p;
    let floored = mu.max(1.0e-300);
    floored.powf(exponent)
}
/// A negative-binomial `theta` must be a finite, strictly positive number.
#[inline]
fn valid_negbin_theta(theta: f64) -> bool {
    theta > 0.0 && theta < f64::INFINITY
}
/// A count response must be finite, non-negative, and within `1e-9` of an
/// integer.
#[inline]
fn valid_count_response(y: f64) -> bool {
    y >= 0.0 && y.is_finite() && (y.round() - y).abs() <= 1e-9
}
/// Reject the first positive-weight row whose response is not a valid count
/// (finite, non-negative, integer-valued). Rows with non-positive (or NaN)
/// prior weight are not checked.
fn validate_count_responses(
    y: &ArrayView1<'_, f64>,
    priorweights: &ArrayView1<'_, f64>,
    family: &str,
) -> Result<(), EstimationError> {
    let offending = y
        .iter()
        .zip(priorweights.iter())
        .enumerate()
        .find(|&(_, (&yi, &wi))| wi > 0.0 && !valid_count_response(yi));
    if let Some((i, (&yi, _))) = offending {
        crate::bail_invalid_estim!(
            "{family} response must be a finite non-negative integer at positive-weight row {i}; got {yi}"
        );
    }
    Ok(())
}
/// The beta-regression precision `phi` must be finite and strictly positive.
#[inline]
fn valid_beta_phi(phi: f64) -> bool {
    phi > 0.0 && phi < f64::INFINITY
}
/// A beta-regression response must lie strictly inside the open interval (0, 1).
#[inline]
fn valid_beta_response(y: f64) -> bool {
    y.is_finite() && 0.0 < y && y < 1.0
}
/// Reject the first positive-weight row whose response is not strictly inside
/// (0, 1). Rows with non-positive (or NaN) prior weight are not checked.
fn validate_beta_responses(
    y: &ArrayView1<'_, f64>,
    priorweights: &ArrayView1<'_, f64>,
) -> Result<(), EstimationError> {
    let offending = y
        .iter()
        .zip(priorweights.iter())
        .enumerate()
        .find(|&(_, (&yi, &wi))| wi > 0.0 && !valid_beta_response(yi));
    if let Some((i, (&yi, _))) = offending {
        crate::bail_invalid_estim!(
            "beta-regression response must be finite and strictly inside (0, 1) at positive-weight row {i}; got {yi}"
        );
    }
    Ok(())
}
/// A Tweedie response must be finite and non-negative (exact zeros allowed).
#[inline]
fn valid_tweedie_response(y: f64) -> bool {
    y >= 0.0 && y < f64::INFINITY
}
/// Reject the first positive-weight row whose response is negative or
/// non-finite. Rows with non-positive (or NaN) prior weight are not checked.
fn validate_tweedie_responses(
    y: &ArrayView1<'_, f64>,
    priorweights: &ArrayView1<'_, f64>,
) -> Result<(), EstimationError> {
    let offending = y
        .iter()
        .zip(priorweights.iter())
        .enumerate()
        .find(|&(_, (&yi, &wi))| wi > 0.0 && !valid_tweedie_response(yi));
    if let Some((i, (&yi, _))) = offending {
        crate::bail_invalid_estim!(
            "Tweedie response must be finite and non-negative at positive-weight row {i}; got {yi}"
        );
    }
    Ok(())
}
/// Clamp a beta-regression mean into `[BETA_MU_EPS, 1 − BETA_MU_EPS]`.
/// NaN passes through unchanged, as `f64::clamp` does.
#[inline]
fn safe_beta_mu(mu: f64) -> f64 {
    let (lo, hi) = (BETA_MU_EPS, 1.0 - BETA_MU_EPS);
    mu.clamp(lo, hi)
}
/// Trigamma function ψ′(x) for finite `x > 0`; NaN for any other input.
///
/// The argument is shifted upward with ψ′(x) = ψ′(x+1) + 1/x² until it is at
/// least 8. Then the asymptotic series
/// ψ′(x) ≈ 1/x + 1/(2x²) + 1/(6x³) − 1/(30x⁵) + 1/(42x⁷) − 1/(30x⁹)
/// is evaluated.
#[inline]
fn trigamma(x: f64) -> f64 {
    let in_domain = x.is_finite() && x > 0.0;
    if !in_domain {
        return f64::NAN;
    }
    let mut arg = x;
    let mut shift_terms = 0.0;
    while arg < 8.0 {
        shift_terms += 1.0 / (arg * arg);
        arg += 1.0;
    }
    let r = 1.0 / arg;
    let r2 = r * r;
    // Evaluated left to right in the same order as the series above.
    shift_terms + r + 0.5 * r2 + r2 * r / 6.0 - r2 * r2 * r / 30.0
        + r2 * r2 * r2 * r / 42.0
        - r2 * r2 * r2 * r2 * r / 30.0
}
/// Second polygamma function ψ″(x) for finite `x > 0`; NaN otherwise.
///
/// The argument is shifted with ψ″(x) = ψ″(x+1) − 2/x³ until it is at least 8.
/// Then the asymptotic series is evaluated:
/// ψ″(x) ≈ −1/x² − 1/x³ − 1/(2x⁴) + 1/(6x⁶) − 1/(6x⁸) + 3/(10x¹⁰) − 5/(6x¹²).
#[inline]
fn polygamma2(x: f64) -> f64 {
    let in_domain = x.is_finite() && x > 0.0;
    if !in_domain {
        return f64::NAN;
    }
    let mut arg = x;
    let mut shift_terms = 0.0;
    while arg < 8.0 {
        shift_terms -= 2.0 / (arg * arg * arg);
        arg += 1.0;
    }
    let r = 1.0 / arg;
    let r2 = r * r;
    let r3 = r2 * r;
    shift_terms - r2 - r3 - 0.5 * r2 * r2 + r3 * r3 / 6.0 - r2 * r3 * r3 / 6.0
        + 0.3 * r2 * r2 * r3 * r3
        - 5.0 * r2 * r2 * r2 * r3 * r3 / 6.0
}
/// Third polygamma function ψ‴(x) for finite `x > 0`; NaN otherwise.
///
/// The argument is shifted with ψ‴(x) = ψ‴(x+1) + 6/x⁴ until it is at least 8.
/// Then the asymptotic series is evaluated:
/// ψ‴(x) ≈ 2/x³ + 3/x⁴ + 2/x⁵ − 1/x⁷ + 4/(3x⁹) − 3/x¹¹ + 10/x¹³.
#[inline]
fn polygamma3(x: f64) -> f64 {
    let in_domain = x.is_finite() && x > 0.0;
    if !in_domain {
        return f64::NAN;
    }
    let mut arg = x;
    let mut shift_terms = 0.0;
    while arg < 8.0 {
        shift_terms += 6.0 / (arg * arg * arg * arg);
        arg += 1.0;
    }
    let r = 1.0 / arg;
    let r2 = r * r;
    let r3 = r2 * r;
    let r4 = r2 * r2;
    shift_terms + 2.0 * r3 + 3.0 * r4 + 2.0 * r4 * r - r4 * r3 + 4.0 * r4 * r3 * r2 / 3.0
        - 3.0 * r4 * r3 * r4
        + 10.0 * r4 * r4 * r4 * r
}
/// First and second η-derivatives `(c, d)` of the beta-regression Fisher
/// working weight under the logit link.
///
/// The weight is `W(η) = prior_weight·φ²·q²·T` with `q = dμ/dη = μ(1−μ)`,
/// `a = μφ`, `b = (1−μ)φ`, and `T = trigamma_sum = ψ′(a) + ψ′(b)` supplied
/// by the caller. The derivation uses:
/// * `q′ = q(1−2μ)` and `q″ = q(1−2μ)² − 2q²`;
/// * `dT/dη = φq(ψ″(a) − ψ″(b))`;
/// * `d(ψ″(a) − ψ″(b))/dη = φq(ψ‴(a) + ψ‴(b))`.
#[inline]
fn beta_logit_working_curvature_eta_derivatives(
    prior_weight: f64,
    phi: f64,
    mu: f64,
    q: f64,
    a: f64,
    b: f64,
    trigamma_sum: f64,
) -> (f64, f64) {
    // dq/dη and d²q/dη² for q = μ(1−μ) under the logit link.
    let q_prime = q * (1.0 - 2.0 * mu);
    let q_double_prime = q * (1.0 - 2.0 * mu) * (1.0 - 2.0 * mu) - 2.0 * q * q;
    // ψ″(a) − ψ″(b): dT/dη = φ·q·psi2_diff.
    let psi2_diff = polygamma2(a) - polygamma2(b);
    // ψ‴(a) + ψ‴(b): d(psi2_diff)/dη = φ·q·psi3_sum.
    let psi3_sum = polygamma3(a) + polygamma3(b);
    let phi_sq = phi * phi;
    let q_sq = q * q;
    // c = dW/dη = wφ²[2qq′T + q²·φq·psi2_diff].
    let c = prior_weight * phi_sq * (2.0 * q * q_prime * trigamma_sum + q_sq * phi * q * psi2_diff);
    // d = d²W/dη²; the terms collect to
    // wφ²[2(q′² + qq″)T + 5φq²q′·psi2_diff + φ²q⁴·psi3_sum].
    let d = prior_weight
        * phi_sq
        * (2.0 * (q_prime * q_prime + q * q_double_prime) * trigamma_sum
            + 4.0 * q * q_prime * phi * q * psi2_diff
            + q_sq * (phi * q_prime * psi2_diff + phi_sq * q_sq * psi3_sum));
    (c, d)
}
/// Log-link working state for the Tweedie family with variance power `p` and
/// dispersion `phi`.
///
/// `p` (checked by `is_valid_tweedie_power`), `phi` (finite, > 0) and the
/// positive-weight responses are validated first. The rule then passed to
/// `log_link_working_state::write_log_link_working_state` uses:
/// * the `TweediePower` weight;
/// * proportional curvature ratios `2−p` and `(2−p)²`;
/// * weight flooring on;
/// * μ-jet zeroing on clamped rows.
///
/// # Errors
/// Invalid-input error for a bad `p`, a bad `phi`, or an invalid response.
#[inline]
fn write_tweedie_log_working_state(
    y: ArrayView1<f64>,
    eta: &Array1<f64>,
    priorweights: ArrayView1<f64>,
    p: f64,
    phi: f64,
    mu: &mut Array1<f64>,
    weights: &mut Array1<f64>,
    z: &mut Array1<f64>,
    derivatives: Option<WorkingDerivativeBuffersMut<'_>>,
) -> Result<(), EstimationError> {
    if !is_valid_tweedie_power(p) {
        crate::bail_invalid_estim!(
            "Tweedie variance power must be finite and strictly between 1 and 2; got {p}",
            p = p
        );
    }
    let phi_ok = phi.is_finite() && phi > 0.0;
    if !phi_ok {
        crate::bail_invalid_estim!(
            "Tweedie dispersion phi must be finite and > 0; got {phi}",
            phi = phi
        );
    }
    validate_tweedie_responses(&y, &priorweights)?;

    // NOTE(review): the ratios presumably encode W ∝ μ^(2−p), so that
    // dW/dη = (2−p)W and d²W/dη² = (2−p)²W — confirm against `WorkingCurvature`.
    let ratio = 2.0 - p;
    let tweedie_rule = log_link_working_state::LogLinkRule {
        weight: log_link_working_state::WorkingWeight::TweediePower { p, phi },
        curvature: log_link_working_state::WorkingCurvature::Proportional {
            c_ratio: ratio,
            d_ratio: ratio * ratio,
        },
        floor_weight: true,
        zero_mu_jet_on_clamp: true,
    };
    log_link_working_state::write_log_link_working_state(
        &tweedie_rule,
        y,
        eta,
        priorweights,
        mu,
        weights,
        z,
        derivatives,
    );
    Ok(())
}
/// Log-link working state for the negative-binomial family with size `theta`.
///
/// `theta` (finite, > 0) and the positive-weight count responses are
/// validated first. The rule then passed to
/// `log_link_working_state::write_log_link_working_state` uses negative-binomial
/// weight and curvature rules, with weight flooring on.
///
/// # Errors
/// Invalid-input error for a bad `theta` or an invalid count response.
#[inline]
fn write_negative_binomial_log_working_state(
    y: ArrayView1<f64>,
    eta: &Array1<f64>,
    priorweights: ArrayView1<f64>,
    theta: f64,
    mu: &mut Array1<f64>,
    weights: &mut Array1<f64>,
    z: &mut Array1<f64>,
    derivatives: Option<WorkingDerivativeBuffersMut<'_>>,
) -> Result<(), EstimationError> {
    if !valid_negbin_theta(theta) {
        crate::bail_invalid_estim!(
            "negative-binomial theta must be finite and > 0; got {theta}",
            theta = theta
        );
    }
    validate_count_responses(&y, &priorweights, "negative-binomial")?;

    let negbin_rule = log_link_working_state::LogLinkRule {
        weight: log_link_working_state::WorkingWeight::NegativeBinomial { theta },
        curvature: log_link_working_state::WorkingCurvature::NegativeBinomial { theta },
        floor_weight: true,
        zero_mu_jet_on_clamp: false,
    };
    log_link_working_state::write_log_link_working_state(
        &negbin_rule,
        y,
        eta,
        priorweights,
        mu,
        weights,
        z,
        derivatives,
    );
    Ok(())
}
#[inline]
fn write_beta_logit_working_state(
y: ArrayView1<f64>,
eta: &Array1<f64>,
priorweights: ArrayView1<f64>,
phi: f64,
mu: &mut Array1<f64>,
weights: &mut Array1<f64>,
z: &mut Array1<f64>,
derivatives: Option<WorkingDerivativeBuffersMut<'_>>,
) -> Result<(), EstimationError> {
if !valid_beta_phi(phi) {
crate::bail_invalid_estim!("beta-regression phi must be finite and > 0; got {phi}");
}
validate_beta_responses(&y, &priorweights)?;
if let Some(mut derivs) = derivatives {
let WorkingSlices {
mu: mu_s,
weights: weights_s,
z: z_s,
} = working_slices(mu, weights, z);
let WorkingDerivSlices {
c: c_s,
d: d_s,
dmu: dmu_s,
d2: d2_s,
d3: d3_s,
} = working_deriv_slices(&mut derivs);
mu_s.par_iter_mut()
.zip(weights_s.par_iter_mut())
.zip(z_s.par_iter_mut())
.zip(dmu_s.par_iter_mut())
.zip(d2_s.par_iter_mut())
.zip(d3_s.par_iter_mut())
.zip(c_s.par_iter_mut())
.zip(d_s.par_iter_mut())
.enumerate()
.for_each(
|(i, (((((((mu_o, w_o), z_o), dmu_o), d2_o), d3_o), c_o), d_o))| {
let eta_raw = eta[i];
let eta_i = eta_raw.clamp(-ETA_CLAMP, ETA_CLAMP);
let jet = logit_inverse_link_jet5(eta_i);
let mu_i = safe_beta_mu(jet.mu);
let q = (mu_i * (1.0 - mu_i)).max(BETA_MU_EPS);
let yi = y[i];
let a = (mu_i * phi).max(BETA_MU_EPS);
let b = ((1.0 - mu_i) * phi).max(BETA_MU_EPS);
let score_mu = phi * (digamma(b) - digamma(a) + yi.ln() - (1.0 - yi).ln());
let trigamma_sum = trigamma(a) + trigamma(b);
let info_mu = phi * phi * trigamma_sum;
let prior_weight = priorweights[i].max(0.0);
let raw_weight = prior_weight * q * q * info_mu;
let floor_active = raw_weight > 0.0 && raw_weight <= MIN_WEIGHT;
*mu_o = mu_i;
*w_o = if raw_weight > 0.0 {
raw_weight.max(MIN_WEIGHT)
} else {
0.0
};
*z_o = eta_i + score_mu / (q * info_mu).max(MIN_WEIGHT);
*dmu_o = q;
*d2_o = q * (1.0 - 2.0 * mu_i);
*d3_o = q * (1.0 - 6.0 * q);
if floor_active || eta_raw != eta_i {
*c_o = 0.0;
*d_o = 0.0;
} else {
let (c_i, d_i) = beta_logit_working_curvature_eta_derivatives(
prior_weight,
phi,
mu_i,
q,
a,
b,
trigamma_sum,
);
*c_o = c_i;
*d_o = d_i;
}
},
);
} else {
let WorkingSlices {
mu: mu_s,
weights: weights_s,
z: z_s,
} = working_slices(mu, weights, z);
mu_s.par_iter_mut()
.zip(weights_s.par_iter_mut())
.zip(z_s.par_iter_mut())
.enumerate()
.for_each(|(i, ((mu_o, w_o), z_o))| {
let eta_i = eta[i].clamp(-ETA_CLAMP, ETA_CLAMP);
let jet = logit_inverse_link_jet5(eta_i);
let mu_i = safe_beta_mu(jet.mu);
let q = (mu_i * (1.0 - mu_i)).max(BETA_MU_EPS);
let yi = y[i];
let a = (mu_i * phi).max(BETA_MU_EPS);
let b = ((1.0 - mu_i) * phi).max(BETA_MU_EPS);
let score_mu = phi * (digamma(b) - digamma(a) + yi.ln() - (1.0 - yi).ln());
let info_mu = phi * phi * (trigamma(a) + trigamma(b));
let raw_weight = priorweights[i].max(0.0) * q * q * info_mu;
*mu_o = mu_i;
*w_o = if raw_weight > 0.0 {
raw_weight.max(MIN_WEIGHT)
} else {
0.0
};
*z_o = eta_i + score_mu / (q * info_mu).max(MIN_WEIGHT);
});
}
Ok(())
}
/// Refresh the IRLS working vectors for a binomial-type, identity or log GLM.
///
/// For every observation, writes the mean `mu`, working weights `weights` and
/// working response `z` from the current linear predictor `eta`. When
/// `derivatives` is supplied, it also fills the weight derivatives `c`, `d`
/// and the inverse-link derivatives `dmu_deta`, `d2mu_deta2`, `d3mu_deta3`.
///
/// Dispatch:
/// * plain logit (no mixture or SAS state attached): fast path with η clamped
///   to ±`ETA_CLAMP` and curvature zeroed on clamped rows;
/// * logit, probit, cloglog, SAS and beta-logistic otherwise: η is mapped
///   through `eta_for_observed_hessian_jet`, then evaluated with the logit jet
///   (logit) or `standard_inverse_link_jet` (all others);
/// * identity: `write_identityworking_state`;
/// * log: `write_poisson_log_working_state`.
///
/// # Errors
/// Propagates failures from `standard_inverse_link_jet`.
#[inline]
pub fn update_glmvectors(
    y: ArrayView1<f64>,
    eta: &Array1<f64>,
    inverse_link: &InverseLink,
    priorweights: ArrayView1<f64>,
    mu: &mut Array1<f64>,
    weights: &mut Array1<f64>,
    z: &mut Array1<f64>,
    derivatives: Option<WorkingDerivativeBuffersMut<'_>>,
) -> Result<(), EstimationError> {
    let link = inverse_link.link_function();
    // Fast path: plain logit with neither a mixture nor a SAS state attached.
    if matches!(link, LinkFunction::Logit)
        && inverse_link.mixture_state().is_none()
        && inverse_link.sas_state().is_none()
    {
        if let Some(mut derivs) = derivatives {
            let WorkingSlices {
                mu: mu_s,
                weights: weights_s,
                z: z_s,
            } = working_slices(mu, weights, z);
            let WorkingDerivSlices {
                c: c_s,
                d: d_s,
                dmu: dmu_s,
                d2: d2_s,
                d3: d3_s,
            } = working_deriv_slices(&mut derivs);
            // NOTE: the zip order here (c, d, dmu, d2, d3) differs from the
            // beta-logit writer; the closure pattern below matches this order.
            mu_s.par_iter_mut()
                .zip(weights_s.par_iter_mut())
                .zip(z_s.par_iter_mut())
                .zip(c_s.par_iter_mut())
                .zip(d_s.par_iter_mut())
                .zip(dmu_s.par_iter_mut())
                .zip(d2_s.par_iter_mut())
                .zip(d3_s.par_iter_mut())
                .enumerate()
                .for_each(
                    |(i, (((((((mu_o, w_o), z_o), c_o), d_o), dmu_o), d2_o), d3_o))| {
                        let eta_raw = eta[i];
                        let eta_c = eta_raw.clamp(-ETA_CLAMP, ETA_CLAMP);
                        let jet = logit_inverse_link_jet5(eta_c);
                        // `true`: zero c/d on clamped or degenerate rows.
                        let geom = bernoulli_logit_geometry_from_jet(
                            eta_raw,
                            eta_c,
                            y[i],
                            priorweights[i],
                            jet,
                            true,
                        );
                        *mu_o = geom.mu;
                        *w_o = geom.weight;
                        *z_o = geom.z;
                        *c_o = geom.c;
                        *d_o = geom.d;
                        *dmu_o = jet.d1;
                        *d2_o = jet.d2;
                        *d3_o = jet.d3;
                    },
                );
        } else {
            let WorkingSlices {
                mu: mu_s,
                weights: weights_s,
                z: z_s,
            } = working_slices(mu, weights, z);
            mu_s.par_iter_mut()
                .zip(weights_s.par_iter_mut())
                .zip(z_s.par_iter_mut())
                .enumerate()
                .for_each(|(i, ((mu_o, w_o), z_o))| {
                    let eta_raw = eta[i];
                    let eta_c = eta_raw.clamp(-ETA_CLAMP, ETA_CLAMP);
                    let jet = logit_inverse_link_jet5(eta_c);
                    let geom = bernoulli_logit_geometry_from_jet(
                        eta_raw,
                        eta_c,
                        y[i],
                        priorweights[i],
                        jet,
                        true,
                    );
                    *mu_o = geom.mu;
                    *w_o = geom.weight;
                    *z_o = geom.z;
                });
        }
        return Ok(());
    }
    match link {
        // NOTE(review): `LinkFunction::Logit` only reaches this arm when a
        // mixture or SAS state is attached (the plain case returned above), yet
        // its sub-branches below still evaluate `logit_inverse_link_jet5`
        // rather than `standard_inverse_link_jet`. Any attached state is
        // therefore ignored for such links — confirm this is intended.
        LinkFunction::Logit
        | LinkFunction::Probit
        | LinkFunction::CLogLog
        | LinkFunction::Sas
        | LinkFunction::BetaLogistic => {
            // Read only inside the logit sub-branches, where it is always `true`.
            let zero_on_nonsmooth = matches!(link, LinkFunction::Logit);
            if let Some(mut derivs) = derivatives {
                let WorkingSlices {
                    mu: mu_s,
                    weights: weights_s,
                    z: z_s,
                } = working_slices(mu, weights, z);
                let WorkingDerivSlices {
                    c: c_s,
                    d: d_s,
                    dmu: dmu_s,
                    d2: d2_s,
                    d3: d3_s,
                } = working_deriv_slices(&mut derivs);
                mu_s.par_iter_mut()
                    .zip(weights_s.par_iter_mut())
                    .zip(z_s.par_iter_mut())
                    .zip(c_s.par_iter_mut())
                    .zip(d_s.par_iter_mut())
                    .zip(dmu_s.par_iter_mut())
                    .zip(d2_s.par_iter_mut())
                    .zip(d3_s.par_iter_mut())
                    .enumerate()
                    .try_for_each(
                        |(
                            i,
                            (((((((mu_o, w_o), z_o), c_o), d_o), dmu_o), d2_o), d3_o),
                        )|
                         -> Result<(), EstimationError> {
                            let eta_used = eta_for_observed_hessian_jet(inverse_link, eta[i]);
                            if matches!(link, LinkFunction::Logit) {
                                let jet = logit_inverse_link_jet5(eta_used);
                                let geom = bernoulli_logit_geometry_from_jet(
                                    eta[i],
                                    eta_used,
                                    y[i],
                                    priorweights[i],
                                    jet,
                                    zero_on_nonsmooth,
                                );
                                *mu_o = geom.mu;
                                *w_o = geom.weight;
                                *z_o = geom.z;
                                *c_o = geom.c;
                                *d_o = geom.d;
                                *dmu_o = jet.d1;
                                *d2_o = jet.d2;
                                *d3_o = jet.d3;
                            } else {
                                // Non-canonical link: Fisher weight μ'²/V(μ).
                                let jet = standard_inverse_link_jet(inverse_link, eta_used)?;
                                let geom = bernoulli_geometry_from_jet(
                                    eta[i],
                                    eta_used,
                                    y[i],
                                    priorweights[i],
                                    jet,
                                );
                                *mu_o = geom.mu;
                                *w_o = geom.weight;
                                *z_o = geom.z;
                                *c_o = geom.c;
                                *d_o = geom.d;
                                *dmu_o = jet.d1;
                                *d2_o = jet.d2;
                                *d3_o = jet.d3;
                            }
                            Ok(())
                        },
                    )?;
            } else {
                let WorkingSlices {
                    mu: mu_s,
                    weights: weights_s,
                    z: z_s,
                } = working_slices(mu, weights, z);
                mu_s.par_iter_mut()
                    .zip(weights_s.par_iter_mut())
                    .zip(z_s.par_iter_mut())
                    .enumerate()
                    .try_for_each(|(i, ((mu_o, w_o), z_o))| -> Result<(), EstimationError> {
                        let eta_used = eta_for_observed_hessian_jet(inverse_link, eta[i]);
                        if matches!(link, LinkFunction::Logit) {
                            let jet = logit_inverse_link_jet5(eta_used);
                            let geom = bernoulli_logit_geometry_from_jet(
                                eta[i],
                                eta_used,
                                y[i],
                                priorweights[i],
                                jet,
                                zero_on_nonsmooth,
                            );
                            *mu_o = geom.mu;
                            *w_o = geom.weight;
                            *z_o = geom.z;
                        } else {
                            let jet = standard_inverse_link_jet(inverse_link, eta_used)?;
                            let geom = bernoulli_geometry_from_jet(
                                eta[i],
                                eta_used,
                                y[i],
                                priorweights[i],
                                jet,
                            );
                            *mu_o = geom.mu;
                            *w_o = geom.weight;
                            *z_o = geom.z;
                        }
                        Ok(())
                    })?;
            }
            Ok(())
        }
        LinkFunction::Identity => {
            write_identityworking_state(y, eta, priorweights, mu, weights, z, derivatives);
            Ok(())
        }
        LinkFunction::Log => {
            write_poisson_log_working_state(y, eta, priorweights, mu, weights, z, derivatives);
            Ok(())
        }
    }
}
/// Refreshes the PIRLS working vectors (`mu`, `weights`, `z`) for the GLM
/// described by `likelihood` at the linear predictor `eta`.
///
/// Thin wrapper over `GlmLikelihoodSpec::irls_update` that passes `None` for
/// its two trailing optional arguments. NOTE(review): presumably these are the
/// working-derivative buffers and an optional link state, so no eta-derivative
/// arrays are produced here — confirm against `irls_update`'s signature.
///
/// # Errors
/// Propagates any `EstimationError` returned by `irls_update`.
#[inline]
pub fn update_glmvectors_by_family(
    y: ArrayView1<f64>,
    eta: &Array1<f64>,
    likelihood: &GlmLikelihoodSpec,
    priorweights: ArrayView1<f64>,
    mu: &mut Array1<f64>,
    weights: &mut Array1<f64>,
    z: &mut Array1<f64>,
) -> Result<(), EstimationError> {
    likelihood.irls_update(y, eta, priorweights, mu, weights, z, None, None)
}
/// Resolves the runtime [`InverseLink`] used by the integrated (SE-marginalised)
/// PIRLS update for a likelihood spec.
///
/// Only Binomial responses are supported:
/// * standard logit / probit / cloglog links are cloned from the spec as-is;
/// * SAS and beta-logistic links are rebuilt from the caller's `sas_link_state`;
/// * mixture links are rebuilt from the caller's `mixture_link_state`.
///
/// # Errors
/// `EstimationError::InvalidInput` when the runtime state required by the link
/// is missing, or when the response/link combination is not supported.
fn integrated_inverse_link_from_family(
    spec: &LikelihoodSpec,
    mixture_link_state: Option<&MixtureLinkState>,
    sas_link_state: Option<&SasLinkState>,
) -> Result<InverseLink, EstimationError> {
    let unsupported = || {
        EstimationError::InvalidInput(format!(
            "Integrated link-runtime update is not supported for likelihood (response={:?}, link={:?})",
            spec.response, spec.link
        ))
    };
    if !matches!(spec.response, ResponseFamily::Binomial) {
        return Err(unsupported());
    }
    let missing_state = |message: &str| EstimationError::InvalidInput(message.to_string());
    match &spec.link {
        InverseLink::Standard(
            StandardLink::Logit | StandardLink::Probit | StandardLink::CLogLog,
        ) => Ok(spec.link.clone()),
        InverseLink::Sas(_) => sas_link_state
            .map(|state| InverseLink::Sas(*state))
            .ok_or_else(|| {
                missing_state("Integrated BinomialSas update requires explicit SasLinkState")
            }),
        InverseLink::BetaLogistic(_) => sas_link_state
            .map(|state| InverseLink::BetaLogistic(*state))
            .ok_or_else(|| {
                missing_state(
                    "Integrated BinomialBetaLogistic update requires explicit SasLinkState",
                )
            }),
        InverseLink::Mixture(_) => mixture_link_state
            .map(|state| InverseLink::Mixture(state.clone()))
            .ok_or_else(|| {
                missing_state(
                    "Integrated BinomialMixture update requires explicit MixtureLinkState",
                )
            }),
        _ => Err(unsupported()),
    }
}
#[inline]
pub fn update_glmvectors_integrated_for_link(
quadctx: &crate::quadrature::QuadratureContext,
y: ArrayView1<f64>,
eta: &Array1<f64>,
se: ArrayView1<f64>,
inverse_link: &InverseLink,
priorweights: ArrayView1<f64>,
mu: &mut Array1<f64>,
weights: &mut Array1<f64>,
z: &mut Array1<f64>,
derivatives: Option<WorkingDerivativeBuffersMut<'_>>,
) -> Result<(), EstimationError> {
let link = inverse_link.link_function();
if !matches!(
inverse_link,
InverseLink::Standard(StandardLink::Logit)
| InverseLink::Standard(StandardLink::Probit)
| InverseLink::Standard(StandardLink::CLogLog)
| InverseLink::LatentCLogLog(_)
| InverseLink::Sas(_)
| InverseLink::BetaLogistic(_)
| InverseLink::Mixture(_)
) {
crate::bail_invalid_estim!(
"Integrated link-runtime update is not supported for inverse link {:?}",
inverse_link
);
}
if let Some(mut derivs) = derivatives {
let WorkingSlices {
mu: mu_s,
weights: weights_s,
z: z_s,
} = working_slices(mu, weights, z);
let WorkingDerivSlices {
c: c_s,
d: d_s,
dmu: dmu_s,
d2: d2_s,
d3: d3_s,
} = working_deriv_slices(&mut derivs);
mu_s.par_iter_mut()
.zip(weights_s.par_iter_mut())
.zip(z_s.par_iter_mut())
.zip(c_s.par_iter_mut())
.zip(d_s.par_iter_mut())
.zip(dmu_s.par_iter_mut())
.zip(d2_s.par_iter_mut())
.zip(d3_s.par_iter_mut())
.enumerate()
.try_for_each(
|(i, (((((((mu_o, w_o), z_o), c_o), d_o), dmu_o), d2_o), d3_o))|
-> Result<(), EstimationError> {
let jet = if let InverseLink::LatentCLogLog(state) = inverse_link {
crate::families::lognormal_kernel::latent_cloglog_inverse_link_jet(
quadctx,
eta[i],
se[i].hypot(state.latent_sd),
)?
} else if matches!(inverse_link, InverseLink::Standard(StandardLink::Logit)) {
crate::quadrature::integrated_logit_inverse_link_jet_pirls(
quadctx, eta[i], se[i],
)?
} else {
crate::quadrature::integrated_inverse_link_jetwith_state(
quadctx,
link,
eta[i],
se[i],
inverse_link.mixture_state(),
inverse_link.sas_state(),
)?
};
let local_jet = MixtureInverseLinkJet {
mu: jet.mean,
d1: jet.d1,
d2: jet.d2,
d3: jet.d3,
};
let e = eta[i].clamp(-ETA_CLAMP, ETA_CLAMP);
let geom = bernoulli_geometry_from_jet(
eta[i],
e,
y[i],
priorweights[i],
local_jet,
);
*mu_o = geom.mu;
*w_o = geom.weight;
*z_o = geom.z;
*c_o = geom.c;
*d_o = geom.d;
*dmu_o = local_jet.d1;
*d2_o = local_jet.d2;
*d3_o = local_jet.d3;
Ok(())
},
)?;
} else {
let WorkingSlices {
mu: mu_s,
weights: weights_s,
z: z_s,
} = working_slices(mu, weights, z);
mu_s.par_iter_mut()
.zip(weights_s.par_iter_mut())
.zip(z_s.par_iter_mut())
.enumerate()
.try_for_each(|(i, ((mu_o, w_o), z_o))| -> Result<(), EstimationError> {
let jet = if let InverseLink::LatentCLogLog(state) = inverse_link {
crate::families::lognormal_kernel::latent_cloglog_inverse_link_jet(
quadctx,
eta[i],
se[i].hypot(state.latent_sd),
)?
} else if matches!(inverse_link, InverseLink::Standard(StandardLink::Logit)) {
crate::quadrature::integrated_logit_inverse_link_jet_pirls(
quadctx, eta[i], se[i],
)?
} else {
crate::quadrature::integrated_inverse_link_jetwith_state(
quadctx,
link,
eta[i],
se[i],
inverse_link.mixture_state(),
inverse_link.sas_state(),
)?
};
let local_jet = MixtureInverseLinkJet {
mu: jet.mean,
d1: jet.d1,
d2: jet.d2,
d3: jet.d3,
};
let e = eta[i].clamp(-ETA_CLAMP, ETA_CLAMP);
let geom = bernoulli_geometry_from_jet(eta[i], e, y[i], priorweights[i], local_jet);
*mu_o = geom.mu;
*w_o = geom.weight;
*z_o = geom.z;
Ok(())
})?;
}
Ok(())
}
/// Integrated-link PIRLS update driven by a likelihood spec.
///
/// First resolves the runtime inverse link from `spec` plus the optional
/// mixture / SAS link state, then forwards every buffer unchanged to
/// `update_glmvectors_integrated_for_link`.
///
/// # Errors
/// Fails if the link cannot be resolved (unsupported response/link or missing
/// state), or if the underlying integrated update fails.
#[inline]
pub fn update_glmvectors_integrated_by_family(
    quadctx: &crate::quadrature::QuadratureContext,
    y: ArrayView1<f64>,
    eta: &Array1<f64>,
    se: ArrayView1<f64>,
    spec: &LikelihoodSpec,
    priorweights: ArrayView1<f64>,
    mu: &mut Array1<f64>,
    weights: &mut Array1<f64>,
    z: &mut Array1<f64>,
    derivatives: Option<WorkingDerivativeBuffersMut<'_>>,
    mixture_link_state: Option<&MixtureLinkState>,
    sas_link_state: Option<&SasLinkState>,
) -> Result<(), EstimationError> {
    integrated_inverse_link_from_family(spec, mixture_link_state, sas_link_state).and_then(
        |resolved_link| {
            update_glmvectors_integrated_for_link(
                quadctx,
                y,
                eta,
                se,
                &resolved_link,
                priorweights,
                mu,
                weights,
                z,
                derivatives,
            )
        },
    )
}
/// Computes the eta-derivatives of the Fisher (expected-information) working
/// weights for a GLM at the linear predictor `eta`.
///
/// Returns `(c, d, dmu_deta, d2mu_deta2, d3mu_deta3)`. Each array has length
/// `eta.len()` and starts at zero. `c` / `d` are presumably the first / second
/// eta-derivatives of the working weight. The other three are the first three
/// eta-derivatives of the inverse link.
///
/// Per response family:
/// * Gaussian: only `dmu_deta` is set (to 1); everything else stays zero.
/// * Poisson, Tweedie, negative binomial, Gamma: delegated to
///   `log_link_working_state::write_log_link_eta_curvature` with a
///   family-specific `LogLinkRule`.
/// * Beta: evaluated inline with the logit jet, regardless of `inverse_link`.
///   `c` / `d` are zeroed where the working weight falls to `MIN_WEIGHT` or
///   `eta` was clamped.
/// * Binomial: the per-row geometry is evaluated with `y` replaced by the
///   fitted mean `jet.mu`, i.e. at the expected rather than the observed
///   response.
///
/// # Errors
/// Returns an `EstimationError` in these cases:
/// * invalid Tweedie power or dispersion;
/// * invalid negative-binomial `theta` or beta `phi`;
/// * the survival-only `RoystonParmar` family;
/// * any failure propagated from a jet or curvature helper.
pub(crate) fn computeworkingweight_derivatives_from_eta(
    likelihood: &GlmLikelihoodSpec,
    inverse_link: &InverseLink,
    eta: &Array1<f64>,
    priorweights: ArrayView1<f64>,
) -> Result<
    (
        Array1<f64>,
        Array1<f64>,
        Array1<f64>,
        Array1<f64>,
        Array1<f64>,
    ),
    EstimationError,
> {
    let n = eta.len();
    let mut c = Array1::<f64>::zeros(n);
    let mut d = Array1::<f64>::zeros(n);
    let mut dmu_deta = Array1::<f64>::zeros(n);
    let mut d2mu_deta2 = Array1::<f64>::zeros(n);
    let mut d3mu_deta3 = Array1::<f64>::zeros(n);
    match &likelihood.spec.response {
        ResponseFamily::Gaussian => {
            // Only the identity-style unit slope is recorded; curvature stays zero.
            dmu_deta.fill(1.0);
        }
        ResponseFamily::Poisson => {
            log_link_working_state::write_log_link_eta_curvature(
                &log_link_working_state::LogLinkRule {
                    weight: log_link_working_state::WorkingWeight::PoissonIdentity,
                    curvature: log_link_working_state::WorkingCurvature::Proportional {
                        c_ratio: 1.0,
                        d_ratio: 1.0,
                    },
                    floor_weight: true,
                    zero_mu_jet_on_clamp: false,
                },
                inverse_link,
                eta,
                priorweights,
                WorkingDerivativeBuffersMut {
                    c: &mut c,
                    d: &mut d,
                    dmu_deta: &mut dmu_deta,
                    d2mu_deta2: &mut d2mu_deta2,
                    d3mu_deta3: &mut d3mu_deta3,
                },
            )?;
        }
        ResponseFamily::Tweedie { p } => {
            let p = *p;
            let phi = fixed_glm_dispersion(likelihood);
            if !is_valid_tweedie_power(p) {
                crate::bail_invalid_estim!(
                    "Tweedie variance power must be finite and strictly between 1 and 2; got {p}",
                    p = p
                );
            }
            if !(phi.is_finite() && phi > 0.0) {
                crate::bail_invalid_estim!(
                    "Tweedie dispersion phi must be finite and > 0; got {phi}",
                    phi = phi
                );
            }
            // Ratios (2 - p) and (2 - p)^2 are consistent with a log-link working
            // weight proportional to mu^(2 - p).
            let exponent = 2.0 - p;
            log_link_working_state::write_log_link_eta_curvature(
                &log_link_working_state::LogLinkRule {
                    weight: log_link_working_state::WorkingWeight::TweediePower { p, phi },
                    curvature: log_link_working_state::WorkingCurvature::Proportional {
                        c_ratio: exponent,
                        d_ratio: exponent * exponent,
                    },
                    floor_weight: true,
                    zero_mu_jet_on_clamp: true,
                },
                inverse_link,
                eta,
                priorweights,
                WorkingDerivativeBuffersMut {
                    c: &mut c,
                    d: &mut d,
                    dmu_deta: &mut dmu_deta,
                    d2mu_deta2: &mut d2mu_deta2,
                    d3mu_deta3: &mut d3mu_deta3,
                },
            )?;
        }
        ResponseFamily::NegativeBinomial { theta, .. } => {
            let theta = *theta;
            if !valid_negbin_theta(theta) {
                crate::bail_invalid_estim!(
                    "negative-binomial theta must be finite and > 0; got {theta}",
                    theta = theta
                );
            }
            log_link_working_state::write_log_link_eta_curvature(
                &log_link_working_state::LogLinkRule {
                    weight: log_link_working_state::WorkingWeight::NegativeBinomial { theta },
                    curvature: log_link_working_state::WorkingCurvature::NegativeBinomial { theta },
                    floor_weight: true,
                    zero_mu_jet_on_clamp: false,
                },
                inverse_link,
                eta,
                priorweights,
                WorkingDerivativeBuffersMut {
                    c: &mut c,
                    d: &mut d,
                    dmu_deta: &mut dmu_deta,
                    d2mu_deta2: &mut d2mu_deta2,
                    d3mu_deta3: &mut d3mu_deta3,
                },
            )?;
        }
        ResponseFamily::Beta { phi } => {
            let phi = *phi;
            if !valid_beta_phi(phi) {
                crate::bail_invalid_estim!("beta-regression phi must be finite and > 0; got {phi}");
            }
            let c_s = c.as_slice_mut().expect("c must be contiguous");
            let d_s = d.as_slice_mut().expect("d must be contiguous");
            let dmu_s = dmu_deta
                .as_slice_mut()
                .expect("dmu_deta must be contiguous");
            let d2_s = d2mu_deta2
                .as_slice_mut()
                .expect("d2mu_deta2 must be contiguous");
            let d3_s = d3mu_deta3
                .as_slice_mut()
                .expect("d3mu_deta3 must be contiguous");
            c_s.par_iter_mut()
                .zip(d_s.par_iter_mut())
                .zip(dmu_s.par_iter_mut())
                .zip(d2_s.par_iter_mut())
                .zip(d3_s.par_iter_mut())
                .enumerate()
                .for_each(|(i, ((((c_o, d_o), dmu_o), d2_o), d3_o))| {
                    let eta_raw = eta[i];
                    let eta_i = eta_raw.clamp(-ETA_CLAMP, ETA_CLAMP);
                    // Beta regression here always uses the logit inverse link.
                    let jet = logit_inverse_link_jet5(eta_i);
                    let mu_i = safe_beta_mu(jet.mu);
                    let q = (mu_i * (1.0 - mu_i)).max(BETA_MU_EPS);
                    let a = (mu_i * phi).max(BETA_MU_EPS);
                    let b = ((1.0 - mu_i) * phi).max(BETA_MU_EPS);
                    let trigamma_sum = trigamma(a) + trigamma(b);
                    let prior_weight = priorweights[i].max(0.0);
                    let raw_weight = prior_weight * q * q * phi * phi * trigamma_sum;
                    // A floored weight or a clamped eta is locally constant in eta,
                    // so its derivatives are reported as zero.
                    let floor_active = raw_weight > 0.0 && raw_weight <= MIN_WEIGHT;
                    if floor_active || eta_raw != eta_i {
                        *c_o = 0.0;
                        *d_o = 0.0;
                    } else {
                        let (c_i, d_i) = beta_logit_working_curvature_eta_derivatives(
                            prior_weight,
                            phi,
                            mu_i,
                            q,
                            a,
                            b,
                            trigamma_sum,
                        );
                        *c_o = c_i;
                        *d_o = d_i;
                    }
                    // Logistic derivatives written in terms of q = mu (1 - mu).
                    *dmu_o = q;
                    *d2_o = q * (1.0 - 2.0 * mu_i);
                    *d3_o = q * (1.0 - 6.0 * q);
                });
        }
        ResponseFamily::Gamma => {
            // Zero ratios: the Gamma working weight is treated as constant in eta.
            log_link_working_state::write_log_link_eta_curvature(
                &log_link_working_state::LogLinkRule {
                    weight: log_link_working_state::WorkingWeight::Constant { factor: 1.0 },
                    curvature: log_link_working_state::WorkingCurvature::Proportional {
                        c_ratio: 0.0,
                        d_ratio: 0.0,
                    },
                    floor_weight: false,
                    zero_mu_jet_on_clamp: false,
                },
                inverse_link,
                eta,
                priorweights,
                WorkingDerivativeBuffersMut {
                    c: &mut c,
                    d: &mut d,
                    dmu_deta: &mut dmu_deta,
                    d2mu_deta2: &mut d2mu_deta2,
                    d3mu_deta3: &mut d3mu_deta3,
                },
            )?;
        }
        ResponseFamily::Binomial => {
            let link = inverse_link.link_function();
            let zero_on_nonsmooth = matches!(link, LinkFunction::Logit);
            let c_s = c.as_slice_mut().expect("c must be contiguous");
            let d_s = d.as_slice_mut().expect("d must be contiguous");
            let dmu_s = dmu_deta
                .as_slice_mut()
                .expect("dmu_deta must be contiguous");
            let d2_s = d2mu_deta2
                .as_slice_mut()
                .expect("d2mu_deta2 must be contiguous");
            let d3_s = d3mu_deta3
                .as_slice_mut()
                .expect("d3mu_deta3 must be contiguous");
            c_s.par_iter_mut()
                .zip(d_s.par_iter_mut())
                .zip(dmu_s.par_iter_mut())
                .zip(d2_s.par_iter_mut())
                .zip(d3_s.par_iter_mut())
                .enumerate()
                .try_for_each(
                    |(i, ((((c_o, d_o), dmu_o), d2_o), d3_o))| -> Result<(), EstimationError> {
                        // NOTE(review): non-logit links are clamped to [-30, 30] here,
                        // whereas `eta_for_observed_hessian_jet` (used by the Bernoulli
                        // working-state update) clamps probit to [-6, 6], cloglog to
                        // [-23, 3] and SAS / beta-logistic / mixture to [-20, 20].
                        // Confirm whether these derivative arrays are meant to follow
                        // the same clamping as that update.
                        let eta_used = match link {
                            LinkFunction::Logit => eta[i].clamp(-ETA_CLAMP, ETA_CLAMP),
                            LinkFunction::Probit
                            | LinkFunction::CLogLog
                            | LinkFunction::Sas
                            | LinkFunction::BetaLogistic => eta[i].clamp(-30.0, 30.0),
                            LinkFunction::Log => eta[i].clamp(-ETA_CLAMP, ETA_CLAMP),
                            LinkFunction::Identity => eta[i],
                        };
                        if matches!(link, LinkFunction::Logit) {
                            let jet = logit_inverse_link_jet5(eta_used);
                            // `jet.mu` stands in for y: expected-information geometry.
                            let geom = bernoulli_logit_geometry_from_jet(
                                eta[i],
                                eta_used,
                                jet.mu,
                                priorweights[i],
                                jet,
                                zero_on_nonsmooth,
                            );
                            *c_o = geom.c;
                            *d_o = geom.d;
                            *dmu_o = jet.d1;
                            *d2_o = jet.d2;
                            *d3_o = jet.d3;
                        } else {
                            let jet = standard_inverse_link_jet(inverse_link, eta_used)?;
                            // `jet.mu` stands in for y: expected-information geometry.
                            let geom = bernoulli_geometry_from_jet(
                                eta[i],
                                eta_used,
                                jet.mu,
                                priorweights[i],
                                jet,
                            );
                            *c_o = geom.c;
                            *d_o = geom.d;
                            *dmu_o = jet.d1;
                            *d2_o = jet.d2;
                            *d3_o = jet.d3;
                        }
                        Ok(())
                    },
                )?;
        }
        ResponseFamily::RoystonParmar => {
            crate::bail_invalid_estim!(
                "RoystonParmar is survival-specific and not a GLM IRLS family"
            );
        }
    }
    Ok((c, d, dmu_deta, d2mu_deta2, d3mu_deta3))
}
/// Variance function `V(mu)` of a GLM family together with its first four
/// derivatives with respect to `mu`.
#[derive(Clone, Copy, Debug)]
pub struct VarianceJet {
    /// `V(mu)`.
    pub v: f64,
    /// `V'(mu)`.
    pub v1: f64,
    /// `V''(mu)`.
    pub v2: f64,
    /// `V'''(mu)`.
    pub v3: f64,
    /// `V''''(mu)`.
    pub v4: f64,
}
impl VarianceJet {
    /// Lower bound applied to `mu` by the Tweedie and negative-binomial jets
    /// before the variance is evaluated.
    const VARIANCE_MU_FLOOR: f64 = 1e-10;

    /// Jet of a variance function whose third and fourth derivatives vanish.
    #[inline]
    fn quadratic_jet(v: f64, v1: f64, v2: f64) -> Self {
        Self {
            v,
            v1,
            v2,
            v3: 0.0,
            v4: 0.0,
        }
    }

    /// Bernoulli variance `mu (1 - mu)`.
    #[inline]
    pub fn bernoulli(mu: f64) -> Self {
        Self::quadratic_jet(mu * (1.0 - mu), 1.0 - 2.0 * mu, -2.0)
    }

    /// Poisson variance `mu`.
    #[inline]
    pub fn poisson(mu: f64) -> Self {
        Self::quadratic_jet(mu, 1.0, 0.0)
    }

    /// Gamma variance `mu^2`.
    #[inline]
    pub fn gamma(mu: f64) -> Self {
        Self::quadratic_jet(mu * mu, 2.0 * mu, 2.0)
    }

    /// Tweedie variance `mu^p`, with `mu` floored at `VARIANCE_MU_FLOOR`.
    #[inline]
    pub fn tweedie(mu: f64, p: f64) -> Self {
        let m = mu.max(Self::VARIANCE_MU_FLOOR);
        // Falling-factorial coefficients p, p(p-1), p(p-1)(p-2), p(p-1)(p-2)(p-3).
        let k1 = p;
        let k2 = k1 * (p - 1.0);
        let k3 = k2 * (p - 2.0);
        let k4 = k3 * (p - 3.0);
        Self {
            v: m.powf(p),
            v1: k1 * m.powf(p - 1.0),
            v2: k2 * m.powf(p - 2.0),
            v3: k3 * m.powf(p - 3.0),
            v4: k4 * m.powf(p - 4.0),
        }
    }

    /// Negative-binomial variance `mu + mu^2 / theta`, with `mu` floored at
    /// `VARIANCE_MU_FLOOR`. An invalid `theta` yields NaN entries.
    #[inline]
    pub fn negative_binomial(mu: f64, theta: f64) -> Self {
        let m = mu.max(Self::VARIANCE_MU_FLOOR);
        let inv_size = match valid_negbin_theta(theta) {
            true => 1.0 / theta,
            false => f64::NAN,
        };
        Self::quadratic_jet(m + m * m * inv_size, 1.0 + 2.0 * m * inv_size, 2.0 * inv_size)
    }

    /// Gaussian (constant) variance `1`.
    #[inline]
    pub fn gaussian() -> Self {
        Self::quadratic_jet(1.0, 0.0, 0.0)
    }

    /// Binomial variance per trial; identical to [`VarianceJet::bernoulli`].
    #[inline]
    pub fn binomial_n(mu: f64) -> Self {
        Self::bernoulli(mu)
    }

    /// Beta-regression variance `mu (1 - mu) / (1 + phi)`, with `phi` floored at `1e-12`.
    #[inline]
    pub fn beta(mu: f64, phi: f64) -> Self {
        let shrink = 1.0 / (1.0 + phi.max(1e-12));
        let unit = Self::bernoulli(mu);
        Self::quadratic_jet(unit.v * shrink, unit.v1 * shrink, unit.v2 * shrink)
    }
}
/// Relative floor: a solver Hessian weight is kept at least this fraction of
/// the (non-negative part of the) Fisher weight.
const OBSERVED_HESSIAN_WEIGHT_FLOOR_FRAC: f64 = 1e-6;
/// Absolute floor used when the Fisher weight is tiny, zero, negative or NaN.
const OBSERVED_HESSIAN_WEIGHT_ABS_FLOOR: f64 = 1e-12;
/// Smallest Hessian weight the PIRLS solver accepts for a row with the given
/// Fisher weight.
///
/// Computes `max(FLOOR_FRAC * max(fisher_weight, 0), ABS_FLOOR)`. A NaN Fisher
/// weight is treated as zero, so the result is always at least `ABS_FLOOR`.
#[inline]
pub fn solver_hessian_weight_floor(fisher_weight: f64) -> f64 {
    let relative = OBSERVED_HESSIAN_WEIGHT_FLOOR_FRAC * fisher_weight.max(0.0);
    f64::max(relative, OBSERVED_HESSIAN_WEIGHT_ABS_FLOOR)
}
/// Builds the per-row weight and curvature arrays used by the outer Hessian.
///
/// For each row:
/// * if the Hessian weight is non-finite or at or below
///   `solver_hessian_weight_floor(fisher)`, the row gets that floor and zero
///   curvature;
/// * otherwise, if `eta` is clamped for `inverse_link`, the row keeps its
///   weight but its curvature is zeroed;
/// * otherwise the row keeps its weight and `c_array` / `d_array` values.
///
/// Returns `(weights, c, d)`.
pub fn outer_hessian_curvature_arrays(
    hessian_weights: crate::matrix::SignedWeightsView<'_>,
    fisher_weights: crate::matrix::PsdWeightsView<'_>,
    c_array: &Array1<f64>,
    d_array: &Array1<f64>,
    eta: &Array1<f64>,
    inverse_link: &InverseLink,
) -> (Array1<f64>, Array1<f64>, Array1<f64>) {
    let signed = hessian_weights.view();
    let psd = fisher_weights.view();
    let rows = signed.len();
    let mut weights_col = Vec::with_capacity(rows);
    let mut c_col = Vec::with_capacity(rows);
    let mut d_col = Vec::with_capacity(rows);
    for row in 0..rows {
        let floor = solver_hessian_weight_floor(psd[row]);
        let weight = signed[row];
        let clamped = eta_clamp_active(inverse_link, eta[row]);
        let (w, c, d) = if !(weight.is_finite() && weight > floor) {
            (floor, 0.0, 0.0)
        } else if clamped {
            (weight, 0.0, 0.0)
        } else {
            (weight, c_array[row], d_array[row])
        };
        weights_col.push(w);
        c_col.push(c);
        d_col.push(d);
    }
    (
        Array1::from(weights_col),
        Array1::from(c_col),
        Array1::from(d_col),
    )
}
/// Dispersion `phi` used by the GLM weight and deviance code: the likelihood's
/// fixed `phi` when one is configured, otherwise `1.0`.
#[inline]
fn fixed_glm_dispersion(likelihood: &GlmLikelihoodSpec) -> f64 {
    match likelihood.fixed_phi() {
        Some(phi) => phi,
        None => 1.0,
    }
}
/// Maps a likelihood's response family to the [`WeightFamily`] that selects
/// its variance function.
///
/// Family parameters (Tweedie `p`, negative-binomial `theta`, beta `phi`) are
/// copied across. `RoystonParmar` has no GLM variance and falls back to the
/// Gaussian (constant) variance.
#[inline]
pub fn weight_family_for_glm_likelihood(likelihood: &GlmLikelihoodSpec) -> WeightFamily {
    match likelihood.spec.response {
        ResponseFamily::Binomial => WeightFamily::Binomial,
        ResponseFamily::Poisson => WeightFamily::Poisson,
        ResponseFamily::Gamma => WeightFamily::Gamma,
        ResponseFamily::Tweedie { p } => WeightFamily::Tweedie { p },
        ResponseFamily::NegativeBinomial { theta, .. } => WeightFamily::NegativeBinomial { theta },
        ResponseFamily::Beta { phi } => WeightFamily::Beta { phi },
        ResponseFamily::Gaussian | ResponseFamily::RoystonParmar => WeightFamily::Gaussian,
    }
}
/// Classifies an inverse link for `observed_weight_dispatch`.
///
/// Identity, log and logit keep their own tag. Every other link (probit,
/// cloglog, latent cloglog, SAS, beta-logistic, mixture) maps to
/// [`WeightLink::Other`].
#[inline]
fn weight_link_for_inverse_link(inverse_link: &InverseLink) -> WeightLink {
    match inverse_link {
        InverseLink::Standard(standard) => match standard {
            StandardLink::Identity => WeightLink::Identity,
            StandardLink::Log => WeightLink::Log,
            StandardLink::Logit => WeightLink::Logit,
            StandardLink::Probit | StandardLink::CLogLog => WeightLink::Other,
        },
        InverseLink::LatentCLogLog(_)
        | InverseLink::Sas(_)
        | InverseLink::BetaLogistic(_)
        | InverseLink::Mixture(_) => WeightLink::Other,
    }
}
/// Returns whether observed-Hessian curvature arrays can be computed for this
/// likelihood / link pair:
/// * Gamma: always;
/// * negative binomial: only with the standard log link;
/// * Binomial: only with probit, cloglog, SAS, beta-logistic or mixture links;
/// * every other family: never.
#[inline]
fn supports_observed_hessian_curvature_for_likelihood(
    likelihood: &GlmLikelihoodSpec,
    inverse_link: &InverseLink,
) -> bool {
    match likelihood.spec.response {
        ResponseFamily::Gamma => true,
        ResponseFamily::NegativeBinomial { .. } => {
            matches!(inverse_link, InverseLink::Standard(StandardLink::Log))
        }
        // NOTE(review): this arm inspects the spec's link rather than
        // `inverse_link` (which the negative-binomial arm uses); confirm the two
        // always agree for Binomial likelihoods.
        ResponseFamily::Binomial => matches!(
            likelihood.spec.link,
            InverseLink::Standard(StandardLink::Probit | StandardLink::CLogLog)
                | InverseLink::Sas(_)
                | InverseLink::BetaLogistic(_)
                | InverseLink::Mixture(_)
        ),
        _ => false,
    }
}
/// Clamps `eta` into the range where the inverse-link jet for `inverse_link`
/// is evaluated. Identity is left untouched.
///
/// Ranges: logit / log use `±ETA_CLAMP`, probit `[-6, 6]`, cloglog and latent
/// cloglog `[-23, 3]`, and SAS / beta-logistic / mixture `[-20, 20]`.
#[inline]
fn eta_for_observed_hessian_jet(inverse_link: &InverseLink, eta: f64) -> f64 {
    let (lower, upper) = match inverse_link {
        InverseLink::Standard(StandardLink::Identity) => return eta,
        InverseLink::Standard(StandardLink::Logit | StandardLink::Log) => (-ETA_CLAMP, ETA_CLAMP),
        InverseLink::Standard(StandardLink::Probit) => (-6.0, 6.0),
        InverseLink::Standard(StandardLink::CLogLog) | InverseLink::LatentCLogLog(_) => {
            (-23.0, 3.0)
        }
        InverseLink::Sas(_) | InverseLink::BetaLogistic(_) | InverseLink::Mixture(_) => {
            (-20.0, 20.0)
        }
    };
    eta.clamp(lower, upper)
}
/// Returns `true` when `eta_for_observed_hessian_jet` would move `eta`,
/// i.e. `eta` lies outside the link's evaluation range. Also returns `true`
/// for a NaN `eta`, since NaN never compares equal.
#[inline]
pub fn eta_clamp_active(inverse_link: &InverseLink, eta: f64) -> bool {
    eta_for_observed_hessian_jet(inverse_link, eta) != eta
}
/// Writes the floored solver Hessian weights into `out`.
///
/// Each entry is the Hessian weight when it is finite and strictly above
/// `solver_hessian_weight_floor(fisher)`, and that floor otherwise. `out` is
/// replaced by a zero array of the right length first if its length differs.
/// Runs in parallel via `Zip::par_for_each`.
fn solver_hessian_weights_into(
    hessian_weights: &Array1<f64>,
    fisher_weights: &Array1<f64>,
    out: &mut Array1<f64>,
) {
    let len = hessian_weights.len();
    if out.len() != len {
        *out = Array1::<f64>::zeros(len);
    }
    ndarray::Zip::from(out)
        .and(hessian_weights)
        .and(fisher_weights)
        .par_for_each(|slot, &hess, &fisher| {
            let floor = solver_hessian_weight_floor(fisher);
            let usable = hess.is_finite() && hess > floor;
            *slot = if usable { hess } else { floor };
        });
}
/// Fills `hessian_weights`, `hessian_c` and `hessian_d` with the observed
/// (not expected) working weight and its first two eta-derivatives per row.
///
/// Each row clamps `eta` with `eta_for_observed_hessian_jet`, evaluates the
/// inverse-link jet and its fourth derivative there, and passes the result to
/// `observed_weight_dispatch`. Output buffers of the wrong length are replaced
/// by zero arrays first.
///
/// # Panics
/// If the likelihood / link pair is not supported (see
/// `supports_observed_hessian_curvature_for_likelihood`), or if an output
/// buffer is not contiguous.
///
/// # Errors
/// Fails when a jet evaluation fails, when a row's observed weight is not
/// positive and finite, or when its derivatives are non-finite.
fn compute_observed_hessian_curvature_arrays_into(
    likelihood: &GlmLikelihoodSpec,
    inverse_link: &InverseLink,
    eta: &Array1<f64>,
    y: ArrayView1<'_, f64>,
    fisher_weights: &Array1<f64>,
    priorweights: ArrayView1<'_, f64>,
    hessian_weights: &mut Array1<f64>,
    hessian_c: &mut Array1<f64>,
    hessian_d: &mut Array1<f64>,
) -> Result<(), EstimationError> {
    assert!(supports_observed_hessian_curvature_for_likelihood(
        likelihood,
        inverse_link
    ));
    let n = eta.len();
    for buffer in [&mut *hessian_weights, &mut *hessian_c, &mut *hessian_d] {
        if buffer.len() != n {
            *buffer = Array1::<f64>::zeros(n);
        }
    }
    let weight_family = weight_family_for_glm_likelihood(likelihood);
    let weight_link = weight_link_for_inverse_link(inverse_link);
    let phi = fixed_glm_dispersion(likelihood);
    // Observed (w, dw/deta, d2w/deta2) for one row, validated before use.
    let observed_row = |i: usize| -> Result<(f64, f64, f64), EstimationError> {
        let eta_used = eta_for_observed_hessian_jet(inverse_link, eta[i]);
        let jet =
            crate::mixture_link::inverse_link_jet_for_inverse_link(inverse_link, eta_used)?;
        let h4 = crate::mixture_link::inverse_link_pdfthird_derivative_for_inverse_link(
            inverse_link, eta_used,
        )?;
        let (w_obs, c_obs, d_obs) = observed_weight_dispatch(
            weight_family,
            weight_link,
            eta_used,
            y[i],
            jet.mu,
            phi,
            priorweights[i].max(0.0),
            jet,
            h4,
        );
        let fisher_weight = fisher_weights[i].max(0.0);
        if !(w_obs.is_finite() && w_obs > 0.0) {
            crate::bail_invalid_estim!(
                "observed Hessian curvature is not positive finite at row {i}: observed={w_obs}, fisher={fisher_weight}"
            );
        }
        if !c_obs.is_finite() || !d_obs.is_finite() {
            crate::bail_invalid_estim!(
                "observed Hessian curvature derivatives are non-finite at row {i}: c={c_obs}, d={d_obs}"
            );
        }
        Ok((w_obs, c_obs, d_obs))
    };
    let weight_slots = hessian_weights
        .as_slice_mut()
        .expect("hessian weights must be contiguous");
    let c_slots = hessian_c
        .as_slice_mut()
        .expect("hessian c must be contiguous");
    let d_slots = hessian_d
        .as_slice_mut()
        .expect("hessian d must be contiguous");
    weight_slots
        .par_iter_mut()
        .zip(c_slots.par_iter_mut())
        .zip(d_slots.par_iter_mut())
        .enumerate()
        .try_for_each(|(i, ((w_slot, c_slot), d_slot))| -> Result<(), EstimationError> {
            let (w, c, d) = observed_row(i)?;
            *w_slot = w;
            *c_slot = c;
            *d_slot = d;
            Ok(())
        })
}
/// Allocating variant of `compute_observed_hessian_curvature_arrays_into`.
///
/// Returns freshly allocated `(hessian_weights, hessian_c, hessian_d)`, each of
/// length `eta.len()`.
///
/// # Errors
/// Propagates any error from the `_into` variant.
pub(crate) fn compute_observed_hessian_curvature_arrays(
    likelihood: &GlmLikelihoodSpec,
    inverse_link: &InverseLink,
    eta: &Array1<f64>,
    y: ArrayView1<'_, f64>,
    fisher_weights: &Array1<f64>,
    priorweights: ArrayView1<'_, f64>,
) -> Result<(Array1<f64>, Array1<f64>, Array1<f64>), EstimationError> {
    let len = eta.len();
    let (mut weights_out, mut c_out, mut d_out) = (
        Array1::<f64>::zeros(len),
        Array1::<f64>::zeros(len),
        Array1::<f64>::zeros(len),
    );
    compute_observed_hessian_curvature_arrays_into(
        likelihood,
        inverse_link,
        eta,
        y,
        fisher_weights,
        priorweights,
        &mut weights_out,
        &mut c_out,
        &mut d_out,
    )
    .map(|()| (weights_out, c_out, d_out))
}
#[inline]
pub fn observed_weight_noncanonical(
y: f64,
mu: f64,
h1: f64,
h2: f64,
h3: f64,
h4: f64,
vj: VarianceJet,
phi: f64,
pw: f64,
) -> (f64, f64, f64) {
let VarianceJet {
v,
v1,
v2,
v3,
v4: _,
} = vj;
let phi_v = phi * v;
let phi_v2 = phi * v * v;
let phi_v3 = phi * v * v * v;
let h1_sq = h1 * h1;
let w_f = h1_sq / phi_v;
let n0 = h1_sq; let n1 = 2.0 * h1 * h2; let n2 = 2.0 * (h2 * h2 + h1 * h3); let vd1 = h1 * v1; let vd2 = h2 * v1 + h1_sq * v2;
let c_f = (n1 * v - n0 * vd1) / phi_v2;
let numer_cf = n1 * v - n0 * vd1;
let dnumer_cf = n2 * v - n0 * vd2;
let d_f = (dnumer_cf * v - 2.0 * numer_cf * vd1) / (phi_v3);
let b_num = h2 * v - h1_sq * v1;
let b = b_num / phi_v2;
let b_eta_num =
h3 * v * v - 3.0 * h1 * h2 * v * v1 - h1_sq * h1 * v * v2 + 2.0 * h1_sq * h1 * v1 * v1;
let b_eta = b_eta_num / phi_v3;
let h1_cu = h1_sq * h1;
let h1_qu = h1_sq * h1_sq;
let db_eta_num = h4 * v * v + 2.0 * h3 * v * h1 * v1
- 3.0 * (h2 * h2 + h1 * h3) * v * v1
- 3.0 * h1 * h2 * (h1 * v1 * v1 + v * h1 * v2)
- 3.0 * h1_sq * h2 * v * v2
- h1_cu * (h1 * v1 * v2 + v * h1 * v3)
+ 6.0 * h1_sq * h2 * v1 * v1
+ 4.0 * h1_qu * v1 * v2;
let phi_v4 = phi_v3 * v;
let b_etaeta = (db_eta_num * v - 3.0 * b_eta_num * h1 * v1) / phi_v4;
let resid = y - mu;
let w_obs = w_f - resid * b;
let c_obs = c_f + h1 * b - resid * b_eta;
let d_obs = d_f + h2 * b + 2.0 * h1 * b_eta - resid * b_etaeta;
(pw * w_obs, pw * c_obs, pw * d_obs)
}
#[inline]
pub fn e_obs_from_jets(
y: f64,
mu: f64,
h1: f64,
h2: f64,
h3: f64,
h4: f64,
h5: f64,
vj: VarianceJet,
phi: f64,
pw: f64,
) -> f64 {
let VarianceJet { v, v1, v2, v3, v4 } = vj;
let q = phi * v;
let h1_sq = h1 * h1;
let h1_cu = h1_sq * h1;
let h1_qu = h1_sq * h1_sq;
let q1 = phi * v1 * h1;
let q2 = phi * (v1 * h2 + v2 * h1_sq);
let q3 = phi * (v1 * h3 + 3.0 * v2 * h1 * h2 + v3 * h1_cu);
let q4 = phi
* (v1 * h4 + 4.0 * v2 * h1 * h3 + 3.0 * v2 * h2 * h2 + 6.0 * v3 * h1_sq * h2 + v4 * h1_qu);
let t0 = h1 / q;
let t1 = (h2 - t0 * q1) / q;
let t2 = (h3 - 2.0 * t1 * q1 - t0 * q2) / q;
let t3 = (h4 - 3.0 * t2 * q1 - 3.0 * t1 * q2 - t0 * q3) / q;
let t4 = (h5 - 4.0 * t3 * q1 - 6.0 * t2 * q2 - 4.0 * t1 * q3 - t0 * q4) / q;
let w_f3 = h1 * t3 + 3.0 * h2 * t2 + 3.0 * h3 * t1 + h4 * t0;
let resid = y - mu;
let e_obs = w_f3 + h3 * t1 + 3.0 * h2 * t2 + 3.0 * h1 * t3 - resid * t4;
pw * e_obs
}
/// Observed weight and its first two eta-derivatives for a Gaussian response
/// with log link (`mu = exp(eta)`), multiplied by `pw / phi`:
/// `mu (2mu - y)`, `mu (4mu - y)`, `mu (8mu - y)`.
#[inline]
pub fn observed_weight_gaussian_log(y: f64, mu: f64, phi: f64, pw: f64) -> (f64, f64, f64) {
    let scaled_mu = pw / phi * mu;
    (
        scaled_mu * (2.0 * mu - y),
        scaled_mu * (4.0 * mu - y),
        scaled_mu * (8.0 * mu - y),
    )
}
/// Observed weight and its first two eta-derivatives for a Gaussian response
/// with the reciprocal link `mu = 1 / eta`, multiplied by `pw / phi`:
/// `(3 - 2 eta y) / eta^4`, `6 (eta y - 2) / eta^5`, `12 (5 - 2 eta y) / eta^6`.
#[inline]
pub fn observed_weight_gaussian_inverse(y: f64, eta: f64, phi: f64, pw: f64) -> (f64, f64, f64) {
    let weight_scale = pw / phi;
    let eta_y = eta * y;
    let sq = eta * eta;
    let quart = sq * sq;
    let quint = quart * eta;
    let sext = quart * sq;
    let weight = weight_scale * (3.0 - 2.0 * eta_y) / quart;
    let slope = weight_scale * 6.0 * (eta_y - 2.0) / quint;
    let curvature = weight_scale * 12.0 * (5.0 - 2.0 * eta_y) / sext;
    (weight, slope, curvature)
}
#[inline]
fn observed_weight_binomial_logit_from_jet(
n_trials: f64,
jet: MixtureInverseLinkJet,
pw: f64,
) -> (f64, f64, f64) {
let scale = pw * n_trials;
(scale * jet.d1, scale * jet.d2, scale * jet.d3)
}
/// Response family used to select the variance function (`VarianceJet`) and
/// the closed-form special cases in `observed_weight_dispatch`.
#[derive(Debug, Clone, Copy, PartialEq)]
pub enum WeightFamily {
    /// Constant variance.
    Gaussian,
    /// Bernoulli / per-trial binomial variance `mu (1 - mu)`.
    Binomial,
    /// Variance `mu`.
    Poisson,
    /// Tweedie with variance power `p` (`V(mu) = mu^p`).
    Tweedie { p: f64 },
    /// Negative binomial with size `theta` (`V(mu) = mu + mu^2 / theta`).
    NegativeBinomial { theta: f64 },
    /// Beta regression with precision `phi` (`V(mu) = mu (1 - mu) / (1 + phi)`).
    Beta { phi: f64 },
    /// Variance `mu^2`.
    Gamma,
}
/// Link classification used by `observed_weight_dispatch` to pick a
/// closed-form observed weight. Family/link pairs without a dedicated formula
/// fall through to the generic non-canonical expression.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum WeightLink {
    Identity,
    Log,
    Logit,
    /// Reciprocal link `mu = 1 / eta`. NOTE(review): `weight_link_for_inverse_link`
    /// in this file never produces this variant.
    Inverse,
    /// Any other link (probit, cloglog, SAS, beta-logistic, mixture, ...).
    Other,
}
/// Variance-function jet of `family` evaluated at the mean `mu`.
#[inline]
pub fn variance_jet_for_weight_family(family: WeightFamily, mu: f64) -> VarianceJet {
    match family {
        // Families without a shape parameter.
        WeightFamily::Gaussian => VarianceJet::gaussian(),
        WeightFamily::Poisson => VarianceJet::poisson(mu),
        WeightFamily::Gamma => VarianceJet::gamma(mu),
        WeightFamily::Binomial => VarianceJet::binomial_n(mu),
        // Families whose variance depends on an extra parameter.
        WeightFamily::Tweedie { p } => VarianceJet::tweedie(mu, p),
        WeightFamily::NegativeBinomial { theta } => VarianceJet::negative_binomial(mu, theta),
        WeightFamily::Beta { phi } => VarianceJet::beta(mu, phi),
    }
}
/// Observed working weight and its first two eta-derivatives for one row.
///
/// Gaussian with a log or reciprocal link, and binomial with logit, use
/// closed-form expressions. Every other family/link pair uses the generic
/// non-canonical formula with the family's variance jet at `mu`, the
/// inverse-link derivatives from `jet`, and `h4`.
pub fn observed_weight_dispatch(
    family: WeightFamily,
    link: WeightLink,
    eta: f64,
    y: f64,
    mu: f64,
    phi: f64,
    prior_weight: f64,
    jet: MixtureInverseLinkJet,
    h4: f64,
) -> (f64, f64, f64) {
    match family {
        WeightFamily::Gaussian if link == WeightLink::Log => {
            observed_weight_gaussian_log(y, mu, phi, prior_weight)
        }
        WeightFamily::Gaussian if link == WeightLink::Inverse => {
            observed_weight_gaussian_inverse(y, eta, phi, prior_weight)
        }
        WeightFamily::Binomial if link == WeightLink::Logit => {
            observed_weight_binomial_logit_from_jet(1.0, jet, prior_weight)
        }
        _ => observed_weight_noncanonical(
            y,
            mu,
            jet.d1,
            jet.d2,
            jet.d3,
            h4,
            variance_jet_for_weight_family(family, mu),
            phi,
            prior_weight,
        ),
    }
}
/// Directional derivative of the working weights along an eta direction.
///
/// Currently only a diagonal form exists. It holds one value per row, as built
/// by `directionalworking_curvature_from_c_array` (`c * eta_direction`, zeroed
/// on inactive or non-finite rows).
#[derive(Clone)]
pub enum DirectionalWorkingCurvature {
    Diagonal(Array1<f64>),
}
/// Builds the diagonal directional working-weight derivative
/// `c_array * eta_direction`.
///
/// A row is set to zero when its Hessian weight is `<= 0` or the product is
/// non-finite.
pub fn directionalworking_curvature_from_c_array(
    c_array: &Array1<f64>,
    hessian_weights: &Array1<f64>,
    eta_direction: &Array1<f64>,
) -> DirectionalWorkingCurvature {
    let mut directional = c_array * eta_direction;
    for (row, value) in directional.iter_mut().enumerate() {
        let inactive = hessian_weights[row] <= 0.0;
        if inactive || !value.is_finite() {
            *value = 0.0;
        }
    }
    DirectionalWorkingCurvature::Diagonal(directional)
}
/// Margin kept between a clamped Bernoulli mean and the boundaries 0 and 1.
const BINOMIAL_MU_EPS: f64 = 1e-12;
/// Clamps a Bernoulli mean into `[BINOMIAL_MU_EPS, 1 - BINOMIAL_MU_EPS]` so
/// that `ln(mu)` and `ln(1 - mu)` stay finite. A NaN input is returned as NaN.
#[inline]
fn safe_mu_for_binomial(mu: f64) -> f64 {
    let lower = BINOMIAL_MU_EPS;
    let upper = 1.0 - BINOMIAL_MU_EPS;
    mu.clamp(lower, upper)
}
/// `x * ln(y)`, defined as `0` whenever `x == 0` (so `0 * ln 0` is `0`, not NaN).
#[inline]
fn xlogy(x: f64, y: f64) -> f64 {
    if x != 0.0 {
        x * y.ln()
    } else {
        0.0
    }
}
/// Truncated Stirling-series remainder of `ln Γ(x)`:
/// `1/(12x) - 1/(360x^3) + 1/(1260x^5)`.
#[inline]
fn log_gamma_stirling_correction(x: f64) -> f64 {
    let r = 1.0 / x;
    let r2 = r * r;
    let first = r / 12.0;
    let third = r * r2 / 360.0;
    let fifth = r * r2 * r2 / 1260.0;
    first - third + fifth
}
/// Stirling-series approximation of `ln Γ(base + delta) - ln Γ(base)`.
///
/// Avoids the cancellation of two huge `ln Γ` values when `base` is large.
/// The `(1 + delta/base)` factor is evaluated with `ln_1p`.
#[inline]
fn log_gamma_large_ratio(base: f64, delta: f64) -> f64 {
    let shifted = base + delta;
    let stirling_main = delta * base.ln() + (shifted - 0.5) * (delta / base).ln_1p() - delta;
    stirling_main + log_gamma_stirling_correction(shifted) - log_gamma_stirling_correction(base)
}
/// `ln Γ(sum) - ln Γ(a) - ln Γ(b)`, the beta-density log normaliser when
/// `sum = a + b`.
///
/// The direct `ln_gamma` evaluation is used when it is finite. Otherwise:
/// * if the smaller argument is below 8, the large-ratio Stirling form is used;
/// * if both arguments are at least 8, the full Stirling expansion is used.
#[inline]
fn beta_log_normalizer(a: f64, b: f64, sum: f64) -> f64 {
    let direct = ln_gamma(sum) - ln_gamma(a) - ln_gamma(b);
    if direct.is_finite() {
        direct
    } else if a.min(b) < 8.0 {
        let small = a.min(b);
        log_gamma_large_ratio(a.max(b), small) - ln_gamma(small)
    } else {
        let entropy = -xlogy(a, a / sum) - xlogy(b, b / sum);
        let half_log = 0.5 * (a.ln() + b.ln() - sum.ln() - (2.0 * std::f64::consts::PI).ln());
        entropy + half_log + log_gamma_stirling_correction(sum)
            - log_gamma_stirling_correction(a)
            - log_gamma_stirling_correction(b)
    }
}
/// Half the Poisson unit deviance: `y ln(y/mu) - (y - mu)`. The `y = 0` case
/// is handled by `xlogy`.
#[inline]
fn poisson_unit_deviance(yi: f64, mui_c: f64) -> f64 {
    let excess = yi - mui_c;
    xlogy(yi, yi / mui_c) - excess
}
/// Half the Gamma unit deviance: `y/mu - 1 - ln(y/mu)`.
#[inline]
fn gamma_unit_deviance(yi_c: f64, mui_c: f64) -> f64 {
    let r = yi_c / mui_c;
    (r - 1.0) - r.ln()
}
/// Half the Tweedie unit deviance for power `p`:
/// `y^(2-p)/((1-p)(2-p)) - y mu^(1-p)/(1-p) + mu^(2-p)/(2-p)`.
///
/// When `y == 0` only the last term remains. Returns NaN for an invalid power
/// or response.
#[inline]
fn tweedie_unit_deviance(yi: f64, mui_c: f64, p: f64) -> f64 {
    if !is_valid_tweedie_power(p) || !valid_tweedie_response(yi) {
        return f64::NAN;
    }
    let one_minus_p = 1.0 - p;
    let two_minus_p = 2.0 - p;
    let mu_term = mui_c.powf(two_minus_p) / two_minus_p;
    if yi == 0.0 {
        return mu_term;
    }
    yi.powf(two_minus_p) / (one_minus_p * two_minus_p) - yi * mui_c.powf(one_minus_p) / one_minus_p
        + mu_term
}
/// Half the negative-binomial unit deviance:
/// `y ln(y/mu) - (y + theta) ln((y + theta)/(mu + theta))`.
///
/// Computed as `theta ln((theta+mu)/(theta+y)) + y ln(y(theta+mu)/(mu(theta+y)))`.
/// Returns NaN for an invalid `theta` or a non-count response.
#[inline]
fn negative_binomial_unit_deviance(yi: f64, mui_c: f64, theta: f64) -> f64 {
    let valid = valid_negbin_theta(theta) && valid_count_response(yi);
    if !valid {
        return f64::NAN;
    }
    let theta_plus_mu = theta + mui_c;
    let theta_plus_y = theta + yi;
    let y_term = xlogy(yi, (yi * theta_plus_mu) / (mui_c * theta_plus_y));
    let theta_term = theta * (theta_plus_mu / theta_plus_y).ln();
    theta_term + y_term
}
/// Full beta log-density of `yi` under mean `mui` and precision `phi`
/// (shape parameters `mu * phi` and `(1 - mu) * phi`).
///
/// The mean is passed through `safe_beta_mu`, and both shapes are floored at
/// `BETA_MU_EPS`. Returns NaN for an invalid `phi` or response.
#[inline]
fn beta_loglikelihood_full_unit(yi: f64, mui: f64, phi: f64) -> f64 {
    if valid_beta_phi(phi) && valid_beta_response(yi) {
        let m = safe_beta_mu(mui);
        let shape_a = (m * phi).max(BETA_MU_EPS);
        let shape_b = ((1.0 - m) * phi).max(BETA_MU_EPS);
        let log_norm = beta_log_normalizer(shape_a, shape_b, phi);
        log_norm + phi * xlogy(m, yi) + phi * xlogy(1.0 - m, 1.0 - yi)
            - yi.ln()
            - (1.0 - yi).ln()
    } else {
        f64::NAN
    }
}
/// Half the beta unit deviance: the saturated log-density (mean = `yi`) minus
/// the fitted log-density (mean = `mui`). Returns NaN for an invalid response.
#[inline]
fn beta_unit_deviance(yi: f64, mui: f64, phi: f64) -> f64 {
    if valid_beta_response(yi) {
        let saturated = beta_loglikelihood_full_unit(yi, yi, phi);
        let fitted = beta_loglikelihood_full_unit(yi, mui, phi);
        saturated - fitted
    } else {
        f64::NAN
    }
}
/// Total (prior-weighted) deviance of the fitted means `mu` for the GLM family
/// in `likelihood`.
///
/// Per family:
/// * Binomial: `2 Σ w [y ln(y/mu) + (1-y) ln((1-y)/(1-mu))]`. `mu` is clamped by
///   `safe_mu_for_binomial`, and a term is dropped when `y` is within `EPS` of 0
///   (first term) or 1 (second term).
/// * Gaussian: `Σ w (y - mu)^2 / phi`, with `phi` from
///   `likelihood.scale.fixed_phi()` (default 1). Returns NaN for a non-positive
///   or non-finite `phi`.
/// * Poisson: `2 Σ w d_unit`, with `mu` floored at `MU_FLOOR`.
/// * Tweedie: `2 Σ w d_unit / phi`. Returns NaN for an invalid power or `phi`,
///   or when response validation fails.
/// * Negative binomial: `2 Σ w d_unit`. Rows with invalid `theta` or responses
///   contribute NaN.
/// * Beta: `2 Σ w d_unit`. Returns NaN for an invalid `phi`.
/// * Gamma: `2 Σ w shape d_unit`, with shape from `likelihood.gamma_shape()`
///   (default 1) and `y` floored at `EPS`.
/// * RoystonParmar: NaN.
///
/// Gaussian, Tweedie and Gamma values are dispersion-scaled; the others are
/// not. NOTE(review): Gaussian reads `likelihood.scale.fixed_phi()` while
/// Tweedie reads `likelihood.fixed_phi()` (via `fixed_glm_dispersion`) —
/// confirm these always agree.
#[inline]
pub fn calculate_deviance(
    y: ArrayView1<f64>,
    mu: &Array1<f64>,
    likelihood: &GlmLikelihoodSpec,
    priorweights: ArrayView1<f64>,
) -> f64 {
    // Response tolerance for the 0 / 1 boundary terms and the Gamma response floor.
    const EPS: f64 = 1e-8;
    // Lower bound on fitted means for the log-based unit deviances.
    const MU_FLOOR: f64 = 1e-10;
    match &likelihood.spec.response {
        ResponseFamily::Binomial => {
            use rayon::iter::{IntoParallelIterator, ParallelIterator};
            let total_residual: f64 = (0..y.len())
                .into_par_iter()
                .map(|i| {
                    let yi = y[i];
                    let mui_c = safe_mu_for_binomial(mu[i]);
                    let wi = priorweights[i];
                    let term1 = if yi > EPS {
                        yi * (yi.ln() - mui_c.ln())
                    } else {
                        0.0
                    };
                    let term2 = if yi < 1.0 - EPS {
                        (1.0 - yi) * ((1.0 - yi).ln() - (1.0 - mui_c).ln())
                    } else {
                        0.0
                    };
                    wi * (term1 + term2)
                })
                .sum();
            2.0 * total_residual
        }
        ResponseFamily::Gaussian => {
            let phi = likelihood.scale.fixed_phi().unwrap_or(1.0);
            if !(phi.is_finite() && phi > 0.0) {
                return f64::NAN;
            }
            let raw: f64 = ndarray::Zip::from(y)
                .and(mu)
                .and(priorweights)
                .map_collect(|&yi, &mui, &wi| wi * (yi - mui) * (yi - mui))
                .sum();
            raw / phi
        }
        ResponseFamily::Poisson => {
            use rayon::iter::{IntoParallelIterator, ParallelIterator};
            let total: f64 = (0..y.len())
                .into_par_iter()
                .map(|i| {
                    let yi = y[i];
                    let mui_c = mu[i].max(MU_FLOOR);
                    priorweights[i] * poisson_unit_deviance(yi, mui_c)
                })
                .sum();
            2.0 * total
        }
        ResponseFamily::Tweedie { p } => {
            let p = *p;
            let phi = fixed_glm_dispersion(likelihood);
            if !is_valid_tweedie_power(p) || !(phi.is_finite() && phi > 0.0) {
                return f64::NAN;
            }
            if validate_tweedie_responses(&y, &priorweights).is_err() {
                return f64::NAN;
            }
            use rayon::iter::{IntoParallelIterator, ParallelIterator};
            let total: f64 = (0..y.len())
                .into_par_iter()
                .map(|i| {
                    let yi = y[i];
                    let mui_c = mu[i].max(MU_FLOOR);
                    priorweights[i] * tweedie_unit_deviance(yi, mui_c, p) / phi
                })
                .sum();
            2.0 * total
        }
        ResponseFamily::NegativeBinomial { theta, .. } => {
            let theta = *theta;
            use rayon::iter::{IntoParallelIterator, ParallelIterator};
            let total: f64 = (0..y.len())
                .into_par_iter()
                .map(|i| {
                    let yi = y[i];
                    let mui_c = mu[i].max(MU_FLOOR);
                    priorweights[i] * negative_binomial_unit_deviance(yi, mui_c, theta)
                })
                .sum();
            2.0 * total
        }
        ResponseFamily::Beta { phi } => {
            let phi = *phi;
            if !valid_beta_phi(phi) {
                return f64::NAN;
            }
            use rayon::iter::{IntoParallelIterator, ParallelIterator};
            let total: f64 = (0..y.len())
                .into_par_iter()
                .map(|i| priorweights[i] * beta_unit_deviance(y[i], mu[i], phi))
                .sum();
            2.0 * total
        }
        ResponseFamily::Gamma => {
            // Scaling by the shape turns the unit deviance into a dispersion-scaled one.
            let shape = likelihood.gamma_shape().unwrap_or(1.0);
            use rayon::iter::{IntoParallelIterator, ParallelIterator};
            let total: f64 = (0..y.len())
                .into_par_iter()
                .map(|i| {
                    let yi_c = y[i].max(EPS);
                    let mui_c = mu[i].max(MU_FLOOR);
                    priorweights[i] * shape * gamma_unit_deviance(yi_c, mui_c)
                })
                .sum();
            2.0 * total
        }
        ResponseFamily::RoystonParmar => f64::NAN,
    }
}
/// Per-observation, prior-weighted log-likelihood contributions of `y` under mean `mu`.
///
/// Entry `i` is `priorweights[i] · ℓ_i`, where `ℓ_i` depends on the response family:
/// - Gaussian: `-0.5 · (y - μ)² / φ`, with `φ` taken from the fixed scale (default `1.0`).
/// - Binomial: `y ln μ + (1 - y) ln(1 - μ)`, with `μ` passed through `safe_mu_for_binomial`.
/// - Poisson: `y ln μ - μ`; the `y ln μ` term is skipped when `y <= 0`.
/// - Tweedie: `-tweedie_unit_deviance(y, μ, p) / φ`.
/// - Negative binomial: the full log-pmf, including the `ln Γ` normalising terms.
/// - Beta: `beta_loglikelihood_full_unit(y, μ, φ)`, also a full log-likelihood.
/// - Gamma: `-shape · gamma_unit_deviance(y, μ)`, with the shape from the likelihood spec
///   (default `1.0`).
/// - Royston–Parmar: not supported here; every entry is NaN.
///
/// For the Poisson, Tweedie, negative-binomial and Gamma branches `μ` is floored at `1e-10`
/// before logarithms are taken; Gamma additionally floors `y` at `1e-8`.
///
/// # Returns
/// A vector of length `y.len()`. It is entirely NaN when a family parameter is invalid
/// (non-finite or non-positive `φ`, invalid Tweedie power or responses, invalid negative-binomial
/// `θ`, invalid beta `φ`). A single entry is NaN when that row's negative-binomial response is
/// not a valid count.
#[inline]
pub fn pointwise_loglikelihood_omitting_constants(
    y: ArrayView1<f64>,
    mu: &Array1<f64>,
    likelihood: &GlmLikelihoodSpec,
    priorweights: ArrayView1<f64>,
) -> Array1<f64> {
    const MU_FLOOR: f64 = 1e-10;
    const EPS: f64 = 1e-8;
    use rayon::iter::{IntoParallelIterator, ParallelIterator};
    let n = y.len();
    let values: Vec<f64> = match &likelihood.spec.response {
        ResponseFamily::Gaussian => {
            let phi = likelihood.scale.fixed_phi().unwrap_or(1.0);
            if !(phi.is_finite() && phi > 0.0) {
                return Array1::from_elem(n, f64::NAN);
            }
            let inv_phi = 1.0 / phi;
            (0..n)
                .into_par_iter()
                .map(|i| {
                    let resid = y[i] - mu[i];
                    -0.5 * priorweights[i] * resid * resid * inv_phi
                })
                .collect()
        }
        ResponseFamily::Binomial => (0..n)
            .into_par_iter()
            .map(|i| {
                let mui_c = safe_mu_for_binomial(mu[i]);
                priorweights[i] * (y[i] * mui_c.ln() + (1.0 - y[i]) * (1.0 - mui_c).ln())
            })
            .collect(),
        ResponseFamily::Poisson => (0..n)
            .into_par_iter()
            .map(|i| {
                let mui_c = mu[i].max(MU_FLOOR);
                let log_term = if y[i] > 0.0 { y[i] * mui_c.ln() } else { 0.0 };
                priorweights[i] * (log_term - mui_c)
            })
            .collect(),
        ResponseFamily::Tweedie { p } => {
            let p = *p;
            let phi = fixed_glm_dispersion(likelihood);
            if !is_valid_tweedie_power(p) || !(phi.is_finite() && phi > 0.0) {
                return Array1::from_elem(n, f64::NAN);
            }
            if validate_tweedie_responses(&y, &priorweights).is_err() {
                return Array1::from_elem(n, f64::NAN);
            }
            (0..n)
                .into_par_iter()
                .map(|i| {
                    let yi = y[i];
                    let mui_c = mu[i].max(MU_FLOOR);
                    -priorweights[i] * tweedie_unit_deviance(yi, mui_c, p) / phi
                })
                .collect()
        }
        ResponseFamily::NegativeBinomial { theta, .. } => {
            let theta = *theta;
            // θ is shared by every row, so validate it once instead of inside the parallel map;
            // an invalid θ still yields an all-NaN vector exactly as before.
            if !valid_negbin_theta(theta) {
                return Array1::from_elem(n, f64::NAN);
            }
            (0..n)
                .into_par_iter()
                .map(|i| {
                    let yi = y[i];
                    // A non-count response only poisons its own entry.
                    if !valid_count_response(yi) {
                        return f64::NAN;
                    }
                    let mui_c = mu[i].max(MU_FLOOR);
                    priorweights[i]
                        * (ln_gamma(yi + theta) - ln_gamma(theta) - ln_gamma(yi + 1.0)
                            + theta * (theta.ln() - (theta + mui_c).ln())
                            + xlogy(yi, mui_c)
                            - yi * (theta + mui_c).ln())
                })
                .collect()
        }
        ResponseFamily::Beta { phi } => {
            let phi = *phi;
            // Loop-invariant precision check, hoisted out of the per-row closure.
            if !valid_beta_phi(phi) {
                return Array1::from_elem(n, f64::NAN);
            }
            (0..n)
                .into_par_iter()
                .map(|i| priorweights[i] * beta_loglikelihood_full_unit(y[i], mu[i], phi))
                .collect()
        }
        ResponseFamily::Gamma => {
            let shape = likelihood.gamma_shape().unwrap_or(1.0);
            (0..n)
                .into_par_iter()
                .map(|i| {
                    let yi_c = y[i].max(EPS);
                    let mui_c = mu[i].max(MU_FLOOR);
                    -priorweights[i] * shape * gamma_unit_deviance(yi_c, mui_c)
                })
                .collect()
        }
        ResponseFamily::RoystonParmar => vec![f64::NAN; n],
    };
    Array1::from_vec(values)
}
/// Total prior-weighted log-likelihood of `y` under mean `mu`, up to family-specific constants.
///
/// For Gaussian, Binomial, Poisson, negative-binomial and beta responses this sums the same
/// per-row terms used by `pointwise_loglikelihood_omitting_constants`. NB and beta are full
/// log-likelihoods that include their normalising terms. Tweedie and Gamma are computed as
/// `-0.5 · calculate_deviance(..)`.
///
/// # Returns
/// NaN when a family parameter is invalid: non-finite or non-positive `φ`, invalid Tweedie
/// power, invalid negative-binomial `θ`, or invalid beta `φ`. This holds even for empty input.
/// NaN is also returned for Royston–Parmar, and propagates from any row whose negative-binomial
/// response is not a valid count.
pub(crate) fn calculate_loglikelihood_omitting_constants(
    y: ArrayView1<f64>,
    mu: &Array1<f64>,
    likelihood: &GlmLikelihoodSpec,
    priorweights: ArrayView1<f64>,
) -> f64 {
    const MU_FLOOR: f64 = 1e-10;
    use rayon::iter::{IntoParallelIterator, ParallelIterator};
    let n = y.len();
    match &likelihood.spec.response {
        ResponseFamily::Gaussian => {
            let phi = likelihood.scale.fixed_phi().unwrap_or(1.0);
            if !(phi.is_finite() && phi > 0.0) {
                return f64::NAN;
            }
            let inv_phi = 1.0 / phi;
            (0..n)
                .into_par_iter()
                .map(|i| {
                    let resid = y[i] - mu[i];
                    -0.5 * priorweights[i] * resid * resid * inv_phi
                })
                .sum()
        }
        ResponseFamily::Binomial => (0..n)
            .into_par_iter()
            .map(|i| {
                let mui_c = safe_mu_for_binomial(mu[i]);
                priorweights[i] * (y[i] * mui_c.ln() + (1.0 - y[i]) * (1.0 - mui_c).ln())
            })
            .sum(),
        ResponseFamily::Poisson => (0..n)
            .into_par_iter()
            .map(|i| {
                let mui_c = mu[i].max(MU_FLOOR);
                let log_term = if y[i] > 0.0 { y[i] * mui_c.ln() } else { 0.0 };
                priorweights[i] * (log_term - mui_c)
            })
            .sum(),
        ResponseFamily::Tweedie { p } => {
            let p = *p;
            let phi = fixed_glm_dispersion(likelihood);
            if !is_valid_tweedie_power(p) || !(phi.is_finite() && phi > 0.0) {
                return f64::NAN;
            }
            -0.5 * calculate_deviance(y, mu, likelihood, priorweights)
        }
        ResponseFamily::NegativeBinomial { theta, .. } => {
            let theta = *theta;
            // Validate the shared θ once, up front, like the Gaussian/Tweedie branches do, so an
            // invalid θ yields NaN even when there are no rows to sum.
            if !valid_negbin_theta(theta) {
                return f64::NAN;
            }
            (0..n)
                .into_par_iter()
                .map(|i| {
                    let yi = y[i];
                    if !valid_count_response(yi) {
                        return f64::NAN;
                    }
                    let mui_c = mu[i].max(MU_FLOOR);
                    priorweights[i]
                        * (ln_gamma(yi + theta) - ln_gamma(theta) - ln_gamma(yi + 1.0)
                            + theta * (theta.ln() - (theta + mui_c).ln())
                            + xlogy(yi, mui_c)
                            - yi * (theta + mui_c).ln())
                })
                .sum()
        }
        ResponseFamily::Beta { phi } => {
            let phi = *phi;
            // Same up-front parameter validation as the other families.
            if !valid_beta_phi(phi) {
                return f64::NAN;
            }
            (0..n)
                .into_par_iter()
                .map(|i| priorweights[i] * beta_loglikelihood_full_unit(y[i], mu[i], phi))
                .sum()
        }
        ResponseFamily::Gamma => -0.5 * calculate_deviance(y, mu, likelihood, priorweights),
        ResponseFamily::RoystonParmar => f64::NAN,
    }
}
use crate::linalg::low_rank_weight::LowRankWeight;
/// Forms the weighted Gram matrix `Xᵀ W X` for a diagonal-plus-low-rank weight `W`.
///
/// The diagonal part goes through the BLAS Gram helper. When `weight` carries a non-empty
/// low-rank factor, `LowRankWeight::add_low_rank_xtwx_correction` then adds its contribution
/// in place.
///
/// # Errors
/// Propagates errors from the BLAS Gram helper. A failure of the low-rank correction is mapped
/// to `EstimationError::InvalidInput`.
pub fn compute_xtwx_low_rank(
    workspace: &mut PirlsWorkspace,
    design: &DesignMatrix,
    weight: &LowRankWeight<'_>,
) -> Result<Array2<f64>, EstimationError> {
    // Materialise the diagonal as an owned array for the BLAS helper.
    let diagonal = weight.diag.to_owned();
    let mut gram = GamWorkingModel::compute_xtwx_blas(workspace, design, &diagonal)?;
    if !weight.is_rank_zero() {
        weight
            .add_low_rank_xtwx_correction(design, &mut gram)
            .map_err(EstimationError::InvalidInput)?;
    }
    Ok(gram)
}
/// Computes `Xᵀ W y` for a diagonal-plus-low-rank weight `W` by delegating to
/// `LowRankWeight::xtw_y`.
///
/// # Errors
/// Any failure reported by `xtw_y` is mapped to `EstimationError::InvalidInput`.
pub fn compute_xtwy_low_rank(
    design: &DesignMatrix,
    weight: &LowRankWeight<'_>,
    y: &Array1<f64>,
) -> Result<Array1<f64>, EstimationError> {
    let response = y.view();
    let projected = weight.xtw_y(design, response);
    projected.map_err(EstimationError::InvalidInput)
}
/// Assembles the Gram matrix of a vector-response working model with per-row Fisher blocks.
///
/// With `k = design.ncols()` and `p = fisher_blocks.shape()[1]`, row `r` adds
/// `w_r · F_r[a, b] · x_r[i] · x_r[j]` to entry `(a·k + i, b·k + j)`. Here `F_r` is the
/// `p × p` slice `fisher_blocks[r, .., ..]` and `w_r` is `row_weights[r]` (default `1.0`).
/// The output is `(k·p) × (k·p)` and is symmetrised at the end by averaging mirrored entries.
/// Rows are accumulated in parallel with rayon `fold`/`reduce`.
///
/// # Errors
/// Returns an invalid-input error in these cases:
/// - `fisher_blocks` is not `(n, p, p)` with `n = design.nrows()`.
/// - `row_weights` has the wrong length, or contains a non-finite or negative value.
/// - Any product `w_r · F_r[a, b]` is non-finite. The lowest `(row, a, b)` is reported.
pub fn dense_block_xtwx(
    design: ArrayView2<'_, f64>,
    fisher_blocks: ArrayView3<'_, f64>,
    row_weights: Option<ArrayView1<'_, f64>>,
) -> Result<Array2<f64>, EstimationError> {
    use rayon::iter::{IntoParallelIterator, ParallelIterator};
    let n = design.nrows();
    let k = design.ncols();
    let shape = fisher_blocks.shape();
    let blocks_well_formed = shape.len() == 3 && shape[0] == n && shape[1] == shape[2];
    if !blocks_well_formed {
        crate::bail_invalid_estim!(
            "dense block Fisher shape mismatch: expected ({n}, p, p), got {shape:?}"
        );
    }
    if let Some(weights) = row_weights.as_ref() {
        if weights.len() != n {
            crate::bail_invalid_estim!(
                "dense block row weight length mismatch: expected {n}, got {}",
                weights.len()
            );
        }
        let any_invalid = weights.iter().any(|v| !(v.is_finite() && *v >= 0.0));
        if any_invalid {
            crate::bail_invalid_estim!("dense block row weights must be finite and non-negative");
        }
    }
    let p_out = shape[1];
    let dim = k * p_out;
    // Missing row weights mean every row counts with weight one.
    let weight_of = |row: usize| row_weights.as_ref().map_or(1.0, |w| w[row]);

    // Scan each row's (a, b) pairs in lexicographic order; `min` then picks the globally first
    // offending entry, so the reported index does not depend on thread scheduling.
    let first_bad_entry = (0..n)
        .into_par_iter()
        .filter_map(|row| {
            let rw = weight_of(row);
            (0..p_out)
                .flat_map(|a| (0..p_out).map(move |b| (a, b)))
                .find(|&(a, b)| !(rw * fisher_blocks[[row, a, b]]).is_finite())
                .map(|(a, b)| (row, a, b))
        })
        .min();
    if let Some((row, a, b)) = first_bad_entry {
        crate::bail_invalid_estim!("dense block Fisher entry ({row},{a},{b}) is not finite");
    }

    let mut out = (0..n)
        .into_par_iter()
        .fold(
            || Array2::<f64>::zeros((dim, dim)),
            |mut acc, row| {
                let rw = weight_of(row);
                let x_row = design.row(row);
                for a in 0..p_out {
                    let row_block = a * k;
                    for b in 0..p_out {
                        let wab = rw * fisher_blocks[[row, a, b]];
                        if wab == 0.0 {
                            continue;
                        }
                        let col_block = b * k;
                        for (i, &xi) in x_row.iter().enumerate() {
                            if xi == 0.0 {
                                continue;
                            }
                            let scaled = wab * xi;
                            for (j, &xj) in x_row.iter().enumerate() {
                                acc[[row_block + i, col_block + j]] += scaled * xj;
                            }
                        }
                    }
                }
                acc
            },
        )
        .reduce(
            || Array2::<f64>::zeros((dim, dim)),
            |mut lhs, rhs| {
                lhs += &rhs;
                lhs
            },
        );

    // Average mirrored entries to remove any asymmetry from non-symmetric blocks or rounding.
    for (i, j) in (0..dim).flat_map(|i| ((i + 1)..dim).map(move |j| (i, j))) {
        let mean = 0.5 * (out[[i, j]] + out[[j, i]]);
        out[[i, j]] = mean;
        out[[j, i]] = mean;
    }
    Ok(out)
}
/// Right-hand side companion of `dense_block_xtwx` for a vector-response working model.
///
/// With `k = design.ncols()` and `p = fisher_blocks.shape()[1]`, entry `a·k + i` accumulates
/// `x_r[i] · Σ_b w_r · F_r[a, b] · response[r, b]` over rows `r`. Here `w_r` is
/// `row_weights[r]` (default `1.0`).
///
/// # Errors
/// Returns an invalid-input error in these cases:
/// - `fisher_blocks` is not `(n, p, p)` with `n = design.nrows()`.
/// - `response` is not `(n, p)`.
/// - `row_weights` has the wrong length, or contains a non-finite or negative value.
/// - Any product `w_r · F_r[a, b]` is non-finite.
pub fn dense_block_xtwy(
    design: ArrayView2<'_, f64>,
    fisher_blocks: ArrayView3<'_, f64>,
    response: ArrayView2<'_, f64>,
    row_weights: Option<ArrayView1<'_, f64>>,
) -> Result<Array1<f64>, EstimationError> {
    let n = design.nrows();
    let k = design.ncols();
    let shape = fisher_blocks.shape();
    if shape.len() != 3 || shape[0] != n || shape[1] != shape[2] {
        crate::bail_invalid_estim!(
            "dense block Fisher shape mismatch: expected ({n}, p, p), got {shape:?}"
        );
    }
    let p_out = shape[1];
    if response.dim() != (n, p_out) {
        crate::bail_invalid_estim!(
            "dense block response shape mismatch: expected ({n}, {p_out}), got {}x{}",
            response.nrows(),
            response.ncols()
        );
    }
    if let Some(w) = row_weights.as_ref() {
        if w.len() != n {
            crate::bail_invalid_estim!(
                "dense block row weight length mismatch: expected {n}, got {}",
                w.len()
            );
        }
        // Apply the same weight domain as `dense_block_xtwx`, so both sides of the normal
        // equations accept exactly the same inputs. Previously negative weights were accepted
        // here but rejected for the Gram matrix.
        if w.iter().any(|v| !v.is_finite() || *v < 0.0) {
            crate::bail_invalid_estim!("dense block row weights must be finite and non-negative");
        }
    }
    let mut out = Array1::<f64>::zeros(k * p_out);
    for row in 0..n {
        let rw = row_weights.as_ref().map(|w| w[row]).unwrap_or(1.0);
        for a in 0..p_out {
            // Weighted response for output component `a` in this row: Σ_b w·F[a,b]·y[b].
            let mut wy = 0.0_f64;
            for b in 0..p_out {
                let wab = rw * fisher_blocks[[row, a, b]];
                if !wab.is_finite() {
                    crate::bail_invalid_estim!(
                        "dense block Fisher entry ({row},{a},{b}) is not finite"
                    );
                }
                wy += wab * response[[row, b]];
            }
            for i in 0..k {
                out[a * k + i] += design[[row, i]] * wy;
            }
        }
    }
    Ok(out)
}
/// Builds the Woodbury capacitance matrix `I + V̂ᵀ (A⁻¹ Û)` from the pre-solved factor
/// `a_inv_uhat` and `vhat`, by delegating to `LowRankWeight::gram_capacitance`.
///
/// # Errors
/// A failure from `gram_capacitance` is mapped to `EstimationError::InvalidInput`.
pub fn woodbury_gram_capacitance(
    a_inv_uhat: &Array2<f64>,
    vhat: &Array2<f64>,
) -> Result<Array2<f64>, EstimationError> {
    let capacitance = LowRankWeight::gram_capacitance(a_inv_uhat, vhat);
    capacitance.map_err(EstimationError::InvalidInput)
}
#[cfg(test)]
mod low_rank_weight_pirls_tests {
    use super::{
        DesignMatrix, LowRankWeight, PirlsWorkspace, compute_xtwx_low_rank, compute_xtwy_low_rank,
        dense_block_xtwx, dense_block_xtwy, woodbury_gram_capacitance,
    };
    use crate::linalg::matrix::{LinearOperator, SignedWeightsView};
    use ndarray::{Array1, Array2, Array3, array};

    /// Raw 5x3 design shared by the low-rank and dense-block tests.
    fn tiny_x() -> Array2<f64> {
        array![
            [1.0, 0.5, -0.2],
            [0.3, 1.2, 0.4],
            [-0.1, 0.7, 1.0],
            [0.6, -0.3, 0.8],
            [0.2, 0.9, -0.5],
        ]
    }

    fn tiny_design() -> DesignMatrix {
        DesignMatrix::Dense(crate::matrix::DenseDesignMatrix::from(tiny_x()))
    }

    /// Symmetric 2x2 Fisher block for each of the 5 rows of `tiny_x`.
    fn tiny_fisher_blocks() -> Array3<f64> {
        let mut f = Array3::<f64>::zeros((5, 2, 2));
        for r in 0..5 {
            let t = r as f64;
            f[[r, 0, 0]] = 1.0 + 0.3 * t;
            f[[r, 1, 1]] = 0.5 + 0.2 * t;
            f[[r, 0, 1]] = 0.1 * (t - 2.0);
            f[[r, 1, 0]] = 0.1 * (t - 2.0);
        }
        f
    }

    #[test]
    fn xtwx_low_rank_matches_diagonal_when_rank_zero() {
        let design = tiny_design();
        let d = array![1.0, 2.0, 0.5, 1.5, 0.8];
        let u = Array2::<f64>::zeros((5, 0));
        let v = Array2::<f64>::zeros((5, 0));
        let weight = LowRankWeight::new(d.view(), u.view(), v.view()).unwrap();
        let mut ws = PirlsWorkspace::new(5, 3, 0, 0);
        let got = compute_xtwx_low_rank(&mut ws, &design, &weight).unwrap();
        let want = design
            .xt_diag_x_signed_op(SignedWeightsView::from_array(&d))
            .unwrap();
        let diff = (&got - &want).mapv(f64::abs).sum();
        assert!(diff < 1e-12, "rank-0 path diverged from diagonal: {}", diff);
    }

    #[test]
    fn xtwx_low_rank_matches_dense_reference_for_symmetric_update() {
        // W = diag(d) + U Uᵀ is symmetric, so the reference is unambiguous.
        let design = tiny_design();
        let x = tiny_x();
        let d = array![1.0, 2.0, 0.5, 1.5, 0.8];
        let u = array![
            [0.1, -0.2],
            [0.4, 0.3],
            [-0.1, 0.5],
            [0.2, 0.1],
            [0.0, -0.3]
        ];
        let weight = LowRankWeight::new(d.view(), u.view(), u.view()).unwrap();
        let mut ws = PirlsWorkspace::new(5, 3, 0, 0);
        let got = compute_xtwx_low_rank(&mut ws, &design, &weight).unwrap();
        let mut w = Array2::<f64>::zeros((5, 5));
        for i in 0..5 {
            w[[i, i]] = d[i];
        }
        w += &u.dot(&u.t());
        let want = x.t().dot(&w.dot(&x));
        let diff = (&got - &want).mapv(f64::abs).sum();
        assert!(diff < 1e-10, "xtwx_low_rank diverged: {}", diff);
    }

    #[test]
    fn xtwy_low_rank_matches_dense_reference() {
        let design = tiny_design();
        let d = array![1.0, 2.0, 0.5, 1.5, 0.8];
        let u = array![
            [0.1, -0.2],
            [0.4, 0.3],
            [-0.1, 0.5],
            [0.2, 0.1],
            [0.0, -0.3]
        ];
        let v = array![[0.2, 0.1], [0.0, 0.4], [0.3, -0.2], [-0.1, 0.6], [0.5, 0.0]];
        let weight = LowRankWeight::new(d.view(), u.view(), v.view()).unwrap();
        let y = array![0.7, -1.2, 0.3, 0.9, -0.4];
        let got = compute_xtwy_low_rank(&design, &weight, &y).unwrap();
        let xdense = design.as_dense().unwrap().to_owned();
        let mut w = Array2::<f64>::zeros((5, 5));
        for i in 0..5 {
            w[[i, i]] = d[i];
        }
        w += &u.dot(&v.t());
        let want = xdense.t().dot(&w.dot(&y));
        let diff: f64 = got
            .iter()
            .zip(want.iter())
            .map(|(a, b)| (a - b).abs())
            .sum();
        assert!(diff < 1e-10, "xtwy_low_rank diverged: {}", diff);
    }

    #[test]
    fn woodbury_capacitance_is_well_formed() {
        let uhat = array![[0.5, 0.1], [-0.2, 0.7], [0.3, -0.4]];
        let vhat = array![[0.1, 0.2], [0.6, -0.1], [-0.3, 0.4]];
        let cap = woodbury_gram_capacitance(&uhat, &vhat).unwrap();
        let want = {
            let mut m = vhat.t().dot(&uhat);
            for k in 0..2 {
                m[[k, k]] += 1.0;
            }
            m
        };
        let diff: f64 = cap
            .iter()
            .zip(want.iter())
            .map(|(a, b)| (a - b).abs())
            .sum();
        assert!(diff < 1e-12);
    }

    #[test]
    fn dense_block_xtwx_matches_kronecker_reference() {
        let x = tiny_x();
        let f = tiny_fisher_blocks();
        let rw = array![1.0, 0.5, 2.0, 0.0, 1.5];
        let got = dense_block_xtwx(x.view(), f.view(), Some(rw.view())).unwrap();
        let (n, k) = x.dim();
        let p = f.shape()[1];
        let mut want = Array2::<f64>::zeros((k * p, k * p));
        for r in 0..n {
            for a in 0..p {
                for b in 0..p {
                    for i in 0..k {
                        for j in 0..k {
                            want[[a * k + i, b * k + j]] +=
                                rw[r] * f[[r, a, b]] * x[[r, i]] * x[[r, j]];
                        }
                    }
                }
            }
        }
        let diff = (&got - &want).mapv(f64::abs).sum();
        assert!(diff < 1e-10, "dense_block_xtwx diverged: {}", diff);

        // Omitting row weights must behave like unit weights.
        let unweighted = dense_block_xtwx(x.view(), f.view(), None).unwrap();
        let ones = Array1::<f64>::ones(n);
        let unit = dense_block_xtwx(x.view(), f.view(), Some(ones.view())).unwrap();
        let diff = (&unweighted - &unit).mapv(f64::abs).sum();
        assert!(diff < 1e-10, "unit weights differ from no weights: {}", diff);
    }

    #[test]
    fn dense_block_xtwy_matches_reference() {
        let x = tiny_x();
        let f = tiny_fisher_blocks();
        let rw = array![1.0, 0.5, 2.0, 0.0, 1.5];
        let resp = array![[0.7, -0.2], [-1.2, 0.4], [0.3, 0.9], [0.9, -0.5], [-0.4, 0.1]];
        let got = dense_block_xtwy(x.view(), f.view(), resp.view(), Some(rw.view())).unwrap();
        let (n, k) = x.dim();
        let p = f.shape()[1];
        let mut want = Array1::<f64>::zeros(k * p);
        for r in 0..n {
            for a in 0..p {
                for b in 0..p {
                    for i in 0..k {
                        want[a * k + i] += rw[r] * f[[r, a, b]] * resp[[r, b]] * x[[r, i]];
                    }
                }
            }
        }
        let diff: f64 = got
            .iter()
            .zip(want.iter())
            .map(|(a, b)| (a - b).abs())
            .sum();
        assert!(diff < 1e-10, "dense_block_xtwy diverged: {}", diff);
    }

    #[test]
    fn dense_block_builders_reject_malformed_inputs() {
        let x = tiny_x();
        let f = tiny_fisher_blocks();
        let short_blocks = Array3::<f64>::zeros((4, 2, 2));
        assert!(dense_block_xtwx(x.view(), short_blocks.view(), None).is_err());
        let ragged_blocks = Array3::<f64>::zeros((5, 2, 3));
        assert!(dense_block_xtwx(x.view(), ragged_blocks.view(), None).is_err());
        let short_weights = array![1.0, 1.0];
        assert!(dense_block_xtwx(x.view(), f.view(), Some(short_weights.view())).is_err());
        let negative_weights = array![1.0, -0.5, 1.0, 1.0, 1.0];
        assert!(dense_block_xtwx(x.view(), f.view(), Some(negative_weights.view())).is_err());
        let mut nan_blocks = f.clone();
        nan_blocks[[2, 1, 0]] = f64::NAN;
        assert!(dense_block_xtwx(x.view(), nan_blocks.view(), None).is_err());

        let wrong_response = Array2::<f64>::zeros((5, 3));
        assert!(dense_block_xtwy(x.view(), f.view(), wrong_response.view(), None).is_err());
        let response = Array2::<f64>::zeros((5, 2));
        assert!(dense_block_xtwy(x.view(), nan_blocks.view(), response.view(), None).is_err());
    }
}