use super::*;
use crate::cache::Fingerprinter;
use crate::construction::{
create_balanced_penalty_root_from_canonical, precompute_reparam_invariant_from_canonical,
};
use crate::faer_ndarray::array2_to_matmut;
use crate::linalg::sparse_exact::build_sparse_penalty_blocks_from_canonical;
use crate::linalg::utils::{
StableSolver, boundary_hit_indices, enforce_symmetry, symmetric_spectrum_condition_number,
};
use crate::mixture_link::inverse_link_has_fisher_weight_jet;
use crate::pirls::PirlsWorkspace;
use crate::solver::estimate::reml::inner_strategy::HessianEvalStrategyKind;
use crate::solver::outer_strategy::{HessianResult, OuterEval};
use crate::solver::persistent_warm_start::{PersistentWarmStartRecord, load_record, store_record};
use crate::types::{
GlmLikelihoodSpec, InverseLink, LikelihoodSpec, LinkFunction, ResponseFamily, RhoPrior,
SasLinkState, StandardLink,
};
use std::collections::{HashMap, VecDeque};
use std::sync::atomic::{AtomicBool, AtomicU64, AtomicUsize, Ordering};
use std::sync::{Arc, LazyLock, Mutex, OnceLock};
// Module-level tuning constants. Only some of them are referenced in the visible part of this
// file; the others carry hedged notes and should be checked against their use sites.

// `TK_*` limits: not referenced in the visible portion of this file.
// NOTE(review): presumably block/chunk sizes and problem-size caps for a "TK" computation — confirm.
const TK_BLOCK_SIZE: usize = 128;
const TK_CHUNK_MAX_ROWS: usize = 2048;
const TK_CHUNK_OVERSUBSCRIBE: usize = 4;
const TK_MAX_OBSERVATIONS: usize = 20_000;
const TK_MAX_COEFFICIENTS: usize = 2_000;
// NOTE(review): adaptive KKT tolerance parameters; use sites are outside this view — confirm.
const ADAPTIVE_KKT_ETA: f64 = 0.1;
const ADAPTIVE_KKT_FLOOR_REML_DIVISOR: f64 = 100.0;
const TK_MAX_DENSE_WORK: usize = 5_000_000;
// NOTE(review): not referenced in the visible portion of this file — confirm meaning at use site.
const LARGE_N_EFS_THRESHOLD: f64 = 1.0e8;
// Added to `EFS_SINGLE_LOOP_PIRLS_CAP_SENTINEL` by `efs_single_loop_encoded_cap`.
const EFS_SINGLE_LOOP_PIRLS_SWEEPS: usize = 2;
// Offset subtracted by `decode_efs_single_loop_cap`; raw caps below it decode to `None`.
const EFS_SINGLE_LOOP_PIRLS_CAP_SENTINEL: usize = usize::MAX / 4;
// NOTE(review): bias-guard parameters; use sites are outside this view — confirm.
const EFS_SINGLE_LOOP_BIAS_THRESHOLD: f64 = 0.10;
const EFS_SINGLE_LOOP_BIAS_CONSECUTIVE_LIMIT: usize = 3;
// Initial values of `HyperGradientBudget::{inner_floor, linear_floor, trace_floor}`.
const HGB_INNER_FLOOR: f64 = 1e-12;
const HGB_LINEAR_FLOOR: f64 = 1e-12;
const HGB_TRACE_FLOOR: f64 = 1e-12;
// Maximum number of entries `HyperGradientBudget::push` keeps in `history`.
const HGB_HISTORY_CAP: usize = 10;
// Length cap of `sensitivity_history`, and the minimum run length required by
// `estimate_trace_sensitivity` and `sensitivities_stable`.
const HGB_WARMUP_ITERS_MIN: usize = 3;
// NOTE(review): not referenced in the visible portion of this file — confirm.
const HGB_WARMUP_ITERS_MAX: usize = 10;
// NOTE(review): not referenced in the visible portion of this file — confirm.
const HGB_TARGET_FRACTION: f64 = 0.1;
// Largest squared rho step accepted as a finite-difference pair.
const HGB_FD_DRHO_MAX_SQUARED: f64 = 1.0;
// Minimum number of finite-difference pairs before sensitivities are re-estimated.
const HGB_MIN_PAIRS_FOR_SENSITIVITY: usize = 3;
// Ridge placed on the diagonal of the regression normal matrix.
const HGB_REGRESSION_RIDGE: f64 = 1e-6;
// `sensitivities_stable` fails when max/min of a channel reaches this ratio.
const HGB_SENS_STABILITY_RATIO: f64 = 1.5;
// Initial values of `HyperGradientBudget::{s_inner, s_linear, s_trace}`.
const S_INNER_INIT: f64 = 1.0;
const S_LINEAR_INIT: f64 = 1.0;
const S_TRACE_INIT: f64 = 1.0;
// Lower bound applied to every re-estimated sensitivity.
const HGB_SENS_FLOOR: f64 = 1e-6;
// NOTE(review): the remaining constants are not referenced in the visible portion of this
// file; their names suggest IFT quality tracking, KKT tolerances and numerical tolerances —
// confirm at the use sites.
const IFT_QUALITY_HISTORY_CAP: usize = 5;
const ETA_OVERFLOW_CLAMP: f64 = 700.0;
const IFT_QUALITY_GROW_BAND: f64 = 1e-3;
const IFT_QUALITY_SHRINK_BAND: f64 = 1e-1;
const IFT_QUALITY_FLAT_FALLBACK_BAND: f64 = 0.5;
const IFT_STEP_CAP_GROW_FACTOR: f64 = 1.5;
const IFT_STEP_CAP_SHRINK_FACTOR: f64 = 0.5;
const KKT_TOL_PRIMAL: f64 = 1e-7;
const KKT_TOL_DUAL: f64 = 1e-7;
const KKT_TOL_COMP: f64 = 1e-7;
const KKT_TOL_STAT: f64 = 5e-6;
const ACTIVE_CONSTRAINT_SLACK_TOL: f64 = 1e-8;
const ORTHONORM_DROP_TOL: f64 = 1e-10;
/// Value bundle for one ALO stabilization evaluation.
///
/// NOTE(review): the producer of this struct is outside the visible chunk; the field notes
/// below are read off the names and types only — confirm against the code that fills them.
#[derive(Debug, Clone)]
struct AloStabilizationEval {
/// Scalar cost.
cost: f64,
/// Optional gradient vector (presumably of `cost`; `None` when not computed — confirm).
gradient: Option<Array1<f64>>,
/// Optional diagnostic scalar (meaning of `k_hat` not visible here — confirm).
k_hat: Option<f64>,
/// Presumably the largest leverage encountered — confirm.
max_leverage: f64,
/// Presumably the smallest denominator encountered — confirm.
min_denominator: f64,
}
// ALO stabilization constants. Only `ALO_MAX_LEVERAGE_THRESHOLD` and
// `ALO_DEVIANCE_SATURATION` are referenced in the visible portion of this file.
// NOTE(review): the others are used outside this view — confirm their meaning there.
const ALO_STABILIZATION_MIN_N: usize = 20;
const ALO_EDF_FRACTION_SATURATION: f64 = 0.70;
const ALO_PERVASIVE_LEVERAGE_FRACTION: f64 = 0.25;
const ALO_PARAMETRIC_LEVERAGE_SHARE: f64 = 0.75;
const ALO_DENOM_INSTABILITY_THRESHOLD: f64 = 0.20;
// Leverage above this value is penalised by `alo_leverage_barrier`.
const ALO_MAX_LEVERAGE_THRESHOLD: f64 = 0.80;
const ALO_TAU: f64 = 0.5;
const ALO_GAMMA: f64 = 0.5;
// Scale `S` of the `S * tanh(raw / S)` saturation applied in `gaussian_alo_deviance`.
const ALO_DEVIANCE_SATURATION: f64 = 9.0;
const ALO_GRADIENT_MAX_WORK: usize = 4_000_000;
/// Borrowed view over three objects that share the lifetime `'a`.
///
/// NOTE(review): consumers are outside the visible chunk; the field notes are inferred from
/// names and types — confirm.
struct AloFactoredHessian<'a> {
/// Dense matrix (presumably the design matrix — confirm).
x: &'a Array2<f64>,
/// Cholesky factor (presumably of the Hessian — confirm).
chol: &'a crate::linalg::faer_ndarray::FaerCholeskyFactor,
/// Dense matrix (presumably the precomputed product H^{-1} X^T — confirm).
h_inv_xt: &'a Array2<f64>,
}
/// Quadratic penalty on leverage above `ALO_MAX_LEVERAGE_THRESHOLD`.
///
/// Returns `max(h - threshold, 0)^2`, which is zero at or below the threshold.
fn alo_leverage_barrier(h: f64) -> f64 {
    let overshoot = f64::max(h - ALO_MAX_LEVERAGE_THRESHOLD, 0.0);
    overshoot * overshoot
}
/// Derivative of [`alo_leverage_barrier`] with respect to `h`.
///
/// Returns `2 * (h - threshold)` when `h` is strictly above `ALO_MAX_LEVERAGE_THRESHOLD`,
/// and `0.0` otherwise (including for NaN input, since the comparison is then false).
fn alo_leverage_barrier_derivative(h: f64) -> f64 {
    let above_threshold = h > ALO_MAX_LEVERAGE_THRESHOLD;
    if above_threshold {
        let overshoot = h - ALO_MAX_LEVERAGE_THRESHOLD;
        return 2.0 * overshoot;
    }
    0.0
}
/// Weighted squared residual `prior_weight * (y - eta_loo)^2 / phi`.
///
/// `phi` is floored at `f64::MIN_POSITIVE`, so a zero or negative dispersion cannot produce
/// a division by zero or a sign flip.
fn gaussian_alo_raw_deviance(y: f64, eta_loo: f64, prior_weight: f64, phi: f64) -> f64 {
    let dispersion = phi.max(f64::MIN_POSITIVE);
    let err = y - eta_loo;
    let weighted_square = prior_weight * err * err;
    weighted_square / dispersion
}
/// Saturated deviance `S * tanh(raw / S)`.
///
/// `raw` is the value of [`gaussian_alo_raw_deviance`] for the same arguments and `S` is
/// `ALO_DEVIANCE_SATURATION`.
fn gaussian_alo_deviance(y: f64, eta_loo: f64, prior_weight: f64, phi: f64) -> f64 {
    let scaled = gaussian_alo_raw_deviance(y, eta_loo, prior_weight, phi) / ALO_DEVIANCE_SATURATION;
    ALO_DEVIANCE_SATURATION * scaled.tanh()
}
/// Returns `1 - tanh^2(raw / S)` with `S = ALO_DEVIANCE_SATURATION`.
///
/// This is the derivative of `S * tanh(raw / S)` with respect to `raw`.
fn gaussian_alo_deviance_saturation_factor(raw: f64) -> f64 {
    let tanh_scaled = f64::tanh(raw / ALO_DEVIANCE_SATURATION);
    1.0 - tanh_scaled * tanh_scaled
}
/// Applies one canonical penalty block to `beta` and scatters the result into a full-length
/// vector.
///
/// The slice of `beta` selected by `penalty.col_range` is centred by `penalty.prior_mean`,
/// multiplied by `penalty.local`, and written back into the same column range of a zero
/// vector with `beta.len()` entries. Entries outside the range stay zero.
fn transformed_penalty_matvec(
    penalty: &crate::construction::CanonicalPenalty,
    beta: &Array1<f64>,
) -> Array1<f64> {
    let mut out = Array1::<f64>::zeros(beta.len());
    let beta_block = beta.slice(ndarray::s![penalty.col_range.clone()]);
    let centered = &beta_block - &penalty.prior_mean;
    // Restored from a mis-encoded `&centered` token.
    let local = penalty.local.dot(&centered);
    out.slice_mut(ndarray::s![penalty.col_range.clone()])
        .assign(&local);
    out
}
/// Lazily created process-wide map from a `Vec<u64>` key to an `(energy, outer_iter)` pair.
static OUTER_IFT_RESIDUAL_ENERGY: OnceLock<Mutex<HashMap<Vec<u64>, (f64, u64)>>> = OnceLock::new();
/// Most recently recorded outer iteration index; starts at zero.
static OUTER_IFT_RESIDUAL_ENERGY_ITER: AtomicU64 = AtomicU64::new(0);

/// Returns the shared residual-energy map, creating an empty one on first use.
fn outer_ift_residual_energy_cache() -> &'static Mutex<HashMap<Vec<u64>, (f64, u64)>> {
    OUTER_IFT_RESIDUAL_ENERGY.get_or_init(Mutex::default)
}
pub(crate) fn record_current_outer_iter_for_ift(iter: u64) {
OUTER_IFT_RESIDUAL_ENERGY_ITER.store(iter, Ordering::Relaxed);
}
pub(crate) fn current_outer_iter() -> u64 {
OUTER_IFT_RESIDUAL_ENERGY_ITER.load(Ordering::Relaxed)
}
pub(crate) fn clear_outer_ift_residual_energy_for_fit() {
if let Some(cache) = OUTER_IFT_RESIDUAL_ENERGY.get()
&& let Ok(mut cache) = cache.lock()
{
cache.clear();
}
OUTER_IFT_RESIDUAL_ENERGY_ITER.store(0, Ordering::Relaxed);
}
/// Records, or forgets, the residual energy associated with `theta`.
///
/// Does nothing when `sanitized_rhokey` yields no key or the cache lock is poisoned.
/// A finite, non-negative `energy` is stored together with the current outer iteration;
/// any other value (including `None`) removes the existing entry for the key.
fn store_ift_residual_energy_for_outer_theta(theta: &Array1<f64>, energy: Option<f64>) {
    let Some(key) = super::cache::sanitized_rhokey(theta) else {
        return;
    };
    let Ok(mut cache) = outer_ift_residual_energy_cache().lock() else {
        return;
    };
    match energy {
        Some(value) if value.is_finite() && value >= 0.0 => {
            cache.insert(key, (value, current_outer_iter()));
        }
        _ => {
            cache.remove(&key);
        }
    }
}
/// Eigen-style description of a penalty subspace.
///
/// NOTE(review): construction and use are outside the visible chunk; the field notes are
/// inferred from names and types — confirm.
pub(super) struct PenaltySubspace {
/// Presumably eigenvalues — confirm ordering and sign convention at the producer.
evals: Array1<f64>,
/// Presumably the matching eigenvectors (column layout not visible here — confirm).
evecs: Array2<f64>,
/// Presumably the numerical rank of the subspace — confirm.
rank: usize,
}
/// One recorded outer-gradient evaluation, kept in `HyperGradientBudget::history`.
pub(crate) struct HyperGradHistoryEntry {
/// Outer parameter vector; consecutive entries are differenced to form finite-difference steps.
pub(crate) rho: Array1<f64>,
/// Outer gradient recorded with `rho`; differenced alongside `rho`.
pub(crate) g_outer: Array1<f64>,
/// Energy value read by the "inner" sensitivity estimate.
pub(crate) e_inner: f64,
/// Energy value read by the "linear" sensitivity estimate.
pub(crate) e_linear: f64,
/// Must be finite and non-negative for the trace sensitivity estimate to proceed.
pub(crate) sigma_sq: f64,
/// Trailing entries sharing the latest `k` are pooled by the trace sensitivity estimate;
/// `k == 0` disables that estimate. NOTE(review): presumably a probe count — confirm.
pub(crate) k: usize,
}
/// Error budget split across three channels (inner, linear, trace), plus the history used to
/// re-estimate each channel's sensitivity.
pub(crate) struct HyperGradientBudget {
/// Total budget handed to `allocate_with_sensitivities`; initialised to `0.0`.
pub(crate) target_mse: f64,
/// Mandatory minimum allocation for the inner channel.
pub(crate) inner_floor: f64,
/// Mandatory minimum allocation for the linear channel.
pub(crate) linear_floor: f64,
/// Mandatory minimum allocation for the trace channel.
pub(crate) trace_floor: f64,
/// Current sensitivity of the inner channel.
pub(crate) s_inner: f64,
/// Current sensitivity of the linear channel.
pub(crate) s_linear: f64,
/// Current sensitivity of the trace channel.
pub(crate) s_trace: f64,
/// Most recent evaluations, capped at `HGB_HISTORY_CAP` entries.
pub(crate) history: VecDeque<HyperGradHistoryEntry>,
/// Most recent `[s_inner, s_linear, s_trace]` triples, capped at `HGB_WARMUP_ITERS_MIN`.
sensitivity_history: VecDeque<[f64; 3]>,
/// Starts `false`. NOTE(review): not read or written in the visible methods — confirm use.
warmup_engaged: bool,
}
impl HyperGradientBudget {
fn new() -> Self {
Self {
target_mse: 0.0,
inner_floor: HGB_INNER_FLOOR,
linear_floor: HGB_LINEAR_FLOOR,
trace_floor: HGB_TRACE_FLOOR,
s_inner: S_INNER_INIT,
s_linear: S_LINEAR_INIT,
s_trace: S_TRACE_INIT,
history: VecDeque::with_capacity(HGB_HISTORY_CAP),
sensitivity_history: VecDeque::with_capacity(HGB_WARMUP_ITERS_MIN),
warmup_engaged: false,
}
}
fn push(&mut self, entry: HyperGradHistoryEntry) {
self.history.push_back(entry);
while self.history.len() > HGB_HISTORY_CAP {
self.history.pop_front();
}
}
fn previous_gradient_norm(&self) -> f64 {
self.history
.iter()
.rev()
.nth(1)
.or_else(|| self.history.back())
.map(|entry| l2_norm(&entry.g_outer))
.filter(|norm| norm.is_finite())
.unwrap_or(0.0)
}
fn reestimate_sensitivities(&mut self) -> Option<[f64; 3]> {
let pairs = self.finite_difference_pairs();
if pairs.len() < HGB_MIN_PAIRS_FOR_SENSITIVITY {
log::info!(
"[HGB] small-sample fallback to defaults: pairs={}, threshold={}",
pairs.len(),
HGB_MIN_PAIRS_FOR_SENSITIVITY
);
return None;
}
let Some(s_inner) = self.estimate_energy_sensitivity(&pairs, |entry| entry.e_inner) else {
return None;
};
let Some(s_linear) = self.estimate_energy_sensitivity(&pairs, |entry| entry.e_linear)
else {
return None;
};
let Some(s_trace) = self.estimate_trace_sensitivity() else {
return None;
};
let sensitivities = [
s_inner.max(HGB_SENS_FLOOR),
s_linear.max(HGB_SENS_FLOOR),
s_trace.max(HGB_SENS_FLOOR),
];
self.s_inner = sensitivities[0];
self.s_linear = sensitivities[1];
self.s_trace = sensitivities[2];
self.sensitivity_history.push_back(sensitivities);
while self.sensitivity_history.len() > HGB_WARMUP_ITERS_MIN {
self.sensitivity_history.pop_front();
}
Some(sensitivities)
}
fn estimate_energy_sensitivity<F>(
&self,
pairs: &[(Array1<f64>, Array1<f64>, usize)],
energy: F,
) -> Option<f64>
where
F: Fn(&HyperGradHistoryEntry) -> f64,
{
let mut estimates = Vec::new();
for i in 0..pairs.len() {
let (drho_i, dg_i, left_idx) = &pairs[i];
let rho_dim = drho_i.len();
let grad_dim = dg_i.len();
if rho_dim == 0 || grad_dim == 0 {
continue;
}
let mut xtx = Array2::<f64>::zeros((rho_dim, rho_dim));
for d in 0..rho_dim {
xtx[[d, d]] = HGB_REGRESSION_RIDGE;
}
let mut xty = Array2::<f64>::zeros((rho_dim, grad_dim));
let mut fit_pairs = 0usize;
for (j, (drho_j, dg_j, _)) in pairs.iter().enumerate() {
if i == j || drho_j.len() != rho_dim || dg_j.len() != grad_dim {
continue;
}
if drho_j.iter().any(|v| !v.is_finite()) || dg_j.iter().any(|v| !v.is_finite()) {
continue;
}
for row in 0..rho_dim {
for col in 0..rho_dim {
xtx[[row, col]] += drho_j[row] * drho_j[col];
}
for grad in 0..grad_dim {
xty[[row, grad]] += drho_j[row] * dg_j[grad];
}
}
fit_pairs += 1;
}
if fit_pairs == 0 {
continue;
}
let Ok(chol) = xtx.cholesky(Side::Lower) else {
continue;
};
chol.solve_mat_in_place(&mut xty);
let mut predicted = Array1::<f64>::zeros(grad_dim);
for grad in 0..grad_dim {
let mut value = 0.0;
for rho in 0..rho_dim {
value += drho_i[rho] * xty[[rho, grad]];
}
predicted[grad] = value;
}
let residual = dg_i - &predicted;
let e0 = energy(&self.history[*left_idx]);
let e1 = energy(&self.history[*left_idx + 1]);
let denom_energy = e0.max(e1).max(1e-300);
if !denom_energy.is_finite() || denom_energy < 0.0 {
continue;
}
let estimate = l2_norm(&residual) / (2.0 * denom_energy).sqrt();
if estimate.is_finite() && estimate > 0.0 {
estimates.push(estimate);
}
}
mean_positive(&estimates)
}
fn finite_difference_pairs(&self) -> Vec<(Array1<f64>, Array1<f64>, usize)> {
let entries: Vec<_> = self.history.iter().collect();
let mut pairs = Vec::new();
for i in 0..entries.len().saturating_sub(1) {
let a = entries[i];
let b = entries[i + 1];
if a.rho.len() != b.rho.len() || a.g_outer.len() != b.g_outer.len() {
continue;
}
let drho = &b.rho - &a.rho;
let dg = &b.g_outer - &a.g_outer;
let drho_norm_squared = drho.dot(&drho);
if drho.iter().all(|v| v.is_finite())
&& dg.iter().all(|v| v.is_finite())
&& drho_norm_squared > 0.0
&& drho_norm_squared <= HGB_FD_DRHO_MAX_SQUARED
{
pairs.push((drho, dg, i));
}
}
pairs
}
fn estimate_trace_sensitivity(&self) -> Option<f64> {
let last_k = self.history.back()?.k;
if last_k == 0 {
return None;
}
let fixed: Vec<&HyperGradHistoryEntry> = self
.history
.iter()
.rev()
.take_while(|entry| entry.k == last_k)
.collect();
if fixed.len() < HGB_WARMUP_ITERS_MIN {
return None;
}
if fixed
.iter()
.any(|entry| !entry.sigma_sq.is_finite() || entry.sigma_sq < 0.0)
{
return None;
}
let dim = fixed[0].g_outer.len();
if dim == 0 || fixed.iter().any(|entry| entry.g_outer.len() != dim) {
return None;
}
let mut means = Array1::<f64>::zeros(dim);
for entry in fixed.iter() {
means += &entry.g_outer;
}
means /= fixed.len() as f64;
let mut variance_sum = 0.0;
for entry in fixed.iter() {
let diff = &entry.g_outer - &means;
variance_sum += diff.dot(&diff);
}
let denom = ((fixed.len() - 1) * dim) as f64;
let std = (variance_sum / denom).max(0.0).sqrt();
(std.is_finite() && std > 0.0).then_some(std)
}
fn allocate_with_sensitivities(
&self,
s_inner: f64,
s_linear: f64,
s_trace: f64,
) -> (f64, f64, f64, [bool; 3]) {
let floors = self.inner_floor + self.linear_floor + self.trace_floor;
let usable = self.target_mse - floors;
if usable <= 0.0 || !usable.is_finite() {
log::warn!(
"[HGB] target_mse below mandatory floors; target_mse={:.3e} floors={:.3e}",
self.target_mse,
floors
);
return (
self.inner_floor,
self.linear_floor,
self.trace_floor,
[true, true, true],
);
}
let wi = s_inner * s_inner;
let wl = s_linear * s_linear;
let wt = s_trace * s_trace;
let sum = (wi + wl + wt).max(HGB_SENS_FLOOR * HGB_SENS_FLOOR);
(
self.inner_floor + usable * wi / sum,
self.linear_floor + usable * wl / sum,
self.trace_floor + usable * wt / sum,
[false, false, false],
)
}
fn sensitivities_stable(&self) -> bool {
if self.sensitivity_history.len() < HGB_WARMUP_ITERS_MIN {
return false;
}
for channel in 0..3 {
let mut min_recent = f64::INFINITY;
let mut max_recent: f64 = 0.0;
for sensitivities in self.sensitivity_history.iter() {
let value = sensitivities[channel];
if !value.is_finite() || value <= 0.0 {
return false;
}
min_recent = min_recent.min(value);
max_recent = max_recent.max(value);
}
if max_recent / min_recent >= HGB_SENS_STABILITY_RATIO {
return false;
}
}
true
}
}
/// Per-owner runtime state stored in the `HYPERGRADIENT_BUDGETS` map.
struct HyperGradientRuntimeState {
/// Error budget and its history.
budget: HyperGradientBudget,
/// Starts as `None`. NOTE(review): presumably an overriding KKT tolerance — confirm at use site.
adaptive_kkt_override: Option<f64>,
/// Shared, lock-protected stochastic trace state; starts as the type's `Default` value.
trace_state: Arc<Mutex<super::unified::StochasticTraceState>>,
}
impl HyperGradientRuntimeState {
    /// Creates a state with a fresh budget, no KKT override and a default trace state behind
    /// a new `Arc<Mutex<_>>`.
    fn new() -> Self {
        let budget = HyperGradientBudget::new();
        let initial_trace = super::unified::StochasticTraceState::default();
        let trace_state = Arc::new(Mutex::new(initial_trace));
        Self {
            budget,
            adaptive_kkt_override: None,
            trace_state,
        }
    }
}
/// Lazily created process-wide map from a `usize` key to its hypergradient runtime state.
static HYPERGRADIENT_BUDGETS: OnceLock<Mutex<HashMap<usize, HyperGradientRuntimeState>>> =
    OnceLock::new();

/// Returns the shared hypergradient state map, creating an empty one on first use.
fn hypergradient_budgets() -> &'static Mutex<HashMap<usize, HyperGradientRuntimeState>> {
    HYPERGRADIENT_BUDGETS.get_or_init(Mutex::default)
}
/// Per-owner IFT quality tracking state stored in the `IFT_QUALITY_STATES` map.
///
/// NOTE(review): readers and writers are outside the visible chunk; the field notes are
/// inferred from names and types — confirm.
#[derive(Default)]
struct IftQualityRuntimeState {
/// Presumably recent quality measurements — confirm.
quality_history: Vec<f64>,
/// Presumably the step cap to apply next, if any — confirm.
next_step_cap: Option<f64>,
/// Presumably requests a flat fallback on the next evaluation — confirm.
fallback_next_flat: bool,
}
/// Lazily created process-wide map from a `usize` key to its IFT quality state.
static IFT_QUALITY_STATES: OnceLock<Mutex<HashMap<usize, IftQualityRuntimeState>>> =
    OnceLock::new();

/// Returns the shared IFT quality map, creating an empty one on first use.
fn ift_quality_states() -> &'static Mutex<HashMap<usize, IftQualityRuntimeState>> {
    IFT_QUALITY_STATES.get_or_init(Mutex::default)
}
/// Cached mode-response columns stored in the `IFT_MODE_RESPONSE_CACHES` map.
///
/// NOTE(review): readers and writers are outside the visible chunk; the field notes are
/// inferred from names and types — confirm.
#[derive(Clone)]
struct IftModeResponseRuntimeCache {
/// Presumably the rho vector the cached columns were computed at — confirm.
rho: Array1<f64>,
/// Optional matrix of mode-response columns for the rho coordinates — confirm layout.
rho_mode_response_cols: Option<Array2<f64>>,
/// Optional matrix of mode-response columns for the extra coordinates — confirm layout.
ext_mode_response_cols: Option<Array2<f64>>,
}
/// Lazily created process-wide map from a `usize` key to its cached mode-response columns.
static IFT_MODE_RESPONSE_CACHES: OnceLock<Mutex<HashMap<usize, IftModeResponseRuntimeCache>>> =
    OnceLock::new();

/// Returns the shared mode-response cache map, creating an empty one on first use.
fn ift_mode_response_caches() -> &'static Mutex<HashMap<usize, IftModeResponseRuntimeCache>> {
    IFT_MODE_RESPONSE_CACHES.get_or_init(Mutex::default)
}
/// Cached joint mode response stored in the `IFT_JOINT_MODE_RESPONSE_CACHES` map.
#[derive(Clone)]
struct IftJointModeResponseRuntimeCache {
/// Full outer parameter vector of the cache entry; `joint_ift_cache_matches_theta` compares
/// its entries from index `rho_dim` onward bit-for-bit.
theta: Array1<f64>,
/// Number of leading entries of `theta` that form the rho part.
rho_dim: usize,
/// NOTE(review): presumably coefficients on the original scale — confirm at the producer.
beta_original: Array1<f64>,
/// NOTE(review): presumably one mode-response column per outer coordinate — confirm layout.
mode_response_cols: Array2<f64>,
/// NOTE(review): presumably whether constraints were active when cached — confirm.
active_constraints: bool,
}
/// Lazily created process-wide map from a `usize` key to its cached joint mode response.
static IFT_JOINT_MODE_RESPONSE_CACHES: OnceLock<
    Mutex<HashMap<usize, IftJointModeResponseRuntimeCache>>,
> = OnceLock::new();

/// Returns the shared joint mode-response cache map, creating an empty one on first use.
fn ift_joint_mode_response_caches()
-> &'static Mutex<HashMap<usize, IftJointModeResponseRuntimeCache>> {
    IFT_JOINT_MODE_RESPONSE_CACHES.get_or_init(Mutex::default)
}
/// Decides whether `cache` is valid for `theta`, given `new_rho`.
///
/// Requires the cached theta to be strictly longer than `cache.rho_dim`, `theta` to have the
/// cached length and `new_rho` to have exactly `rho_dim` entries. The first `rho_dim` entries
/// of `theta` must then equal `new_rho` bit-for-bit, and the remaining entries must equal the
/// cached theta bit-for-bit.
fn joint_ift_cache_matches_theta(
    cache: &IftJointModeResponseRuntimeCache,
    theta: &Array1<f64>,
    new_rho: &Array1<f64>,
) -> bool {
    let rho_dim = cache.rho_dim;
    let shapes_agree = cache.theta.len() > rho_dim
        && theta.len() == cache.theta.len()
        && new_rho.len() == rho_dim;
    if !shapes_agree {
        return false;
    }
    (0..rho_dim).all(|i| theta[i].to_bits() == new_rho[i].to_bits())
        && (rho_dim..theta.len()).all(|i| theta[i].to_bits() == cache.theta[i].to_bits())
}
/// Most recently recorded outer theta, or `None` if none was recorded or the last one was rejected.
static IFT_LATEST_OUTER_THETA: OnceLock<Mutex<Option<Array1<f64>>>> = OnceLock::new();
/// Most recently recorded rho upper bounds, or `None` if none were recorded or the last ones were rejected.
static IFT_LATEST_OUTER_RHO_UPPER_BOUNDS: OnceLock<Mutex<Option<Array1<f64>>>> = OnceLock::new();
/// Replaces the stored outer theta with a copy of `theta`.
///
/// An empty vector, or one containing a non-finite entry, clears the stored value.
///
/// # Panics
///
/// Panics if the backing mutex is poisoned.
pub(crate) fn record_current_outer_theta_for_ift(theta: &Array1<f64>) {
    let usable = !theta.is_empty() && theta.iter().all(|v| v.is_finite());
    let snapshot = usable.then(|| theta.clone());
    let mut slot = IFT_LATEST_OUTER_THETA
        .get_or_init(|| Mutex::new(None))
        .lock()
        .unwrap();
    *slot = snapshot;
}
/// Replaces the stored rho upper bounds with a copy of `upper`.
///
/// An empty vector, or one containing a non-finite entry, clears the stored value.
///
/// # Panics
///
/// Panics if the backing mutex is poisoned.
pub(crate) fn record_current_outer_rho_upper_bounds_for_ift(upper: &Array1<f64>) {
    let usable = !upper.is_empty() && upper.iter().all(|v| v.is_finite());
    let snapshot = usable.then(|| upper.clone());
    let mut slot = IFT_LATEST_OUTER_RHO_UPPER_BOUNDS
        .get_or_init(|| Mutex::new(None))
        .lock()
        .unwrap();
    *slot = snapshot;
}
/// Returns a copy of the stored rho upper bounds, or `None` when nothing is stored.
///
/// # Panics
///
/// Panics if the backing mutex is poisoned.
pub(crate) fn latest_outer_rho_upper_bounds_for_ift() -> Option<Array1<f64>> {
    let slot = IFT_LATEST_OUTER_RHO_UPPER_BOUNDS.get_or_init(|| Mutex::new(None));
    let guard = slot.lock().unwrap();
    (*guard).clone()
}
/// Returns a copy of the stored outer theta, or `None` when nothing is stored.
///
/// # Panics
///
/// Panics if the backing mutex is poisoned.
pub(crate) fn latest_outer_theta_for_ift() -> Option<Array1<f64>> {
    let slot = IFT_LATEST_OUTER_THETA.get_or_init(|| Mutex::new(None));
    let guard = slot.lock().unwrap();
    (*guard).clone()
}
/// Euclidean norm: the square root of the sequentially accumulated sum of squares.
fn l2_norm(values: &Array1<f64>) -> f64 {
    let sum_of_squares: f64 = values.iter().map(|&v| v * v).sum();
    sum_of_squares.sqrt()
}
/// Arithmetic mean of the finite, strictly positive entries of `values`.
///
/// Zero, negative, NaN and infinite entries are ignored. Returns `None` when no entry
/// qualifies, which includes the empty slice.
fn mean_positive(values: &[f64]) -> Option<f64> {
    let (total, kept) = values
        .iter()
        .copied()
        .filter(|v| v.is_finite() && *v > 0.0)
        .fold((0.0_f64, 0usize), |(acc, n), v| (acc + v, n + 1));
    if kept == 0 {
        None
    } else {
        Some(total / kept as f64)
    }
}
/// State held behind the `EFS_SINGLE_LOOP_BIAS_GUARD` mutex; both fields default to zero.
///
/// NOTE(review): readers and writers are outside the visible chunk; the field notes are
/// inferred from names — confirm.
#[derive(Default)]
struct EfsSingleLoopBiasGuardState {
/// Presumably identifies which fit/owner the counter belongs to — confirm.
owner: usize,
/// Presumably a count of consecutive bias-threshold violations — confirm.
consecutive: usize,
}
/// Process-wide bias-guard state, created on first access with default (all-zero) fields.
static EFS_SINGLE_LOOP_BIAS_GUARD: LazyLock<Mutex<EfsSingleLoopBiasGuardState>> =
LazyLock::new(|| Mutex::new(EfsSingleLoopBiasGuardState::default()));
/// Returns `true` for every evaluation mode except `EvalMode::ValueOnly`.
#[inline]
fn compute_gradient_for_tk(mode: super::unified::EvalMode) -> bool {
    !matches!(mode, super::unified::EvalMode::ValueOnly)
}
/// Encodes the single-loop sweep count as `sentinel + sweeps`.
///
/// `decode_efs_single_loop_cap` recovers the sweep count from the result.
#[inline]
fn efs_single_loop_encoded_cap() -> usize {
    let sentinel = EFS_SINGLE_LOOP_PIRLS_CAP_SENTINEL;
    sentinel + EFS_SINGLE_LOOP_PIRLS_SWEEPS
}
#[inline]
fn decode_efs_single_loop_cap(raw_cap: usize) -> Option<usize> {
(raw_cap >= EFS_SINGLE_LOOP_PIRLS_CAP_SENTINEL)
.then(|| raw_cap - EFS_SINGLE_LOOP_PIRLS_CAP_SENTINEL)
.filter(|cap| *cap > 0)
}
/// Adds a residual penalty to `cost` when the inner solve stopped at its iteration limit.
///
/// `cost` is returned unchanged when it is non-finite or when `pr.status` does not report a
/// max-iteration failure. Otherwise the result is `cost + 0.5 * r^2`, with `r` the relative
/// gradient norm of `pr`, or `f64::INFINITY` when that norm is not finite.
#[inline]
fn screening_residual_penalty(cost: f64, pr: &PirlsResult) -> f64 {
    if !(cost.is_finite() && pr.status.is_failed_max_iterations()) {
        return cost;
    }
    let rel_grad = pr.relative_gradient_norm();
    if !rel_grad.is_finite() {
        return f64::INFINITY;
    }
    cost + 0.5 * rel_grad * rel_grad
}
/// Feeds a 1-D view into `hasher`: its length first, then every element in iteration order.
fn hash_array_view(hasher: &mut Fingerprinter, values: ndarray::ArrayView1<'_, f64>) {
    hasher.write_usize(values.len());
    values.iter().for_each(|&entry| hasher.write_f64(entry));
}
/// Feeds a 2-D array into `hasher`: row count, column count, then every element in the
/// array's iteration order.
fn hash_array2(hasher: &mut Fingerprinter, values: &Array2<f64>) {
    let (rows, cols) = (values.nrows(), values.ncols());
    hasher.write_usize(rows);
    hasher.write_usize(cols);
    values.iter().for_each(|&entry| hasher.write_f64(entry));
}
/// Feeds an auxiliary-prior strength into `hasher`.
///
/// Writes the tag `"auto"`, or the tag `"fixed"` followed by the fixed value.
fn hash_aux_prior_strength(
    hasher: &mut Fingerprinter,
    strength: crate::terms::latent_coord::AuxPriorStrength,
) {
    use crate::terms::latent_coord::AuxPriorStrength as Strength;
    let tag = match strength {
        Strength::Auto => "auto",
        Strength::Fixed(_) => "fixed",
    };
    hasher.write_str(tag);
    if let Strength::Fixed(value) = strength {
        hasher.write_f64(value);
    }
}
/// Computes a 64-bit cache fingerprint of a latent identification mode.
///
/// The hash starts with the version tag `"latent-id-mode-cache-v1"` and a per-variant tag.
/// For the two auxiliary-prior variants it also covers the matrix `u`, the prior family and
/// the prior strength. Any other fields of `AuxPriorDimSelection`, and all fields of
/// `DimSelection`, do not contribute.
pub(in crate::solver::estimate) fn latent_id_mode_cache_fingerprint(
    id_mode: &crate::terms::latent_coord::LatentIdMode,
) -> u64 {
    use crate::terms::latent_coord::{AuxPriorFamily, LatentIdMode};

    // Tag written for each auxiliary-prior family.
    let family_tag = |family: &AuxPriorFamily| -> &'static str {
        match family {
            AuxPriorFamily::Ridge => "ridge",
            AuxPriorFamily::Linear => "linear",
        }
    };

    let mut hasher = Fingerprinter::new();
    hasher.write_str("latent-id-mode-cache-v1");
    match id_mode {
        LatentIdMode::None => hasher.write_str("none"),
        LatentIdMode::DimSelection { .. } => hasher.write_str("dim-selection"),
        LatentIdMode::AuxPrior {
            u,
            family,
            strength,
        } => {
            hasher.write_str("aux-prior");
            hash_array2(&mut hasher, u);
            hasher.write_str(family_tag(family));
            hash_aux_prior_strength(&mut hasher, *strength);
        }
        LatentIdMode::AuxPriorDimSelection {
            u,
            family,
            strength,
            ..
        } => {
            hasher.write_str("aux-prior-dim-selection");
            hash_array2(&mut hasher, u);
            hasher.write_str(family_tag(family));
            hash_aux_prior_strength(&mut hasher, *strength);
        }
    }
    hasher.finish_u64()
}
/// Feeds a 3-D array into `hasher`: its three extents, then every element in the array's
/// iteration order.
fn hash_array3(hasher: &mut Fingerprinter, values: &ndarray::Array3<f64>) {
    let shape = values.dim();
    hasher.write_usize(shape.0);
    hasher.write_usize(shape.1);
    hasher.write_usize(shape.2);
    values.iter().for_each(|&entry| hasher.write_f64(entry));
}
/// Feeds a psi slice into `hasher`: range start, range end, then a presence flag for
/// `latent_dim`, followed by its value when present.
fn hash_psi_slice(hasher: &mut Fingerprinter, target: &crate::terms::analytic_penalties::PsiSlice) {
    hasher.write_usize(target.range.start);
    hasher.write_usize(target.range.end);
    if let Some(latent_dim) = target.latent_dim {
        hasher.write_bool(true);
        hasher.write_usize(latent_dim);
    } else {
        hasher.write_bool(false);
    }
}
/// Feeds a scalar weight schedule into `hasher`.
///
/// Order: `w_start`, `w_end`, the schedule-kind tag, the kind's parameter (if it has one),
/// then `iter_count`.
fn hash_scalar_weight_schedule(
    hasher: &mut Fingerprinter,
    schedule: &crate::terms::analytic_penalties::ScalarWeightSchedule,
) {
    use crate::terms::sae_manifold::ScheduleKind;
    hasher.write_f64(schedule.w_start);
    hasher.write_f64(schedule.w_end);
    let kind_tag = match &schedule.kind {
        ScheduleKind::Geometric { .. } => "geometric",
        ScheduleKind::Linear { .. } => "linear",
        ScheduleKind::ReciprocalIter => "reciprocal-iter",
    };
    hasher.write_str(kind_tag);
    match &schedule.kind {
        ScheduleKind::Geometric { rate } => hasher.write_f64(*rate),
        ScheduleKind::Linear { steps } => hasher.write_usize(*steps),
        ScheduleKind::ReciprocalIter => {}
    }
    hasher.write_usize(schedule.iter_count);
}
/// Feeds an optional weight schedule into `hasher`: a presence flag, then the schedule
/// itself when present.
fn hash_weight_schedule_option(
    hasher: &mut Fingerprinter,
    schedule: &Option<crate::terms::analytic_penalties::ScalarWeightSchedule>,
) {
    hasher.write_bool(schedule.is_some());
    if let Some(inner) = schedule {
        hash_scalar_weight_schedule(hasher, inner);
    }
}
/// Feeds a Gumbel temperature schedule into `hasher`.
///
/// Order: `tau_start`, `tau_min`, the decay-kind tag, the kind's parameter (if it has one),
/// then `iter_count`.
fn hash_gumbel_temperature_schedule(
    hasher: &mut Fingerprinter,
    schedule: &crate::terms::sae_manifold::GumbelTemperatureSchedule,
) {
    use crate::terms::sae_manifold::ScheduleKind;
    hasher.write_f64(schedule.tau_start);
    hasher.write_f64(schedule.tau_min);
    let decay_tag = match &schedule.decay {
        ScheduleKind::Geometric { .. } => "geometric",
        ScheduleKind::Linear { .. } => "linear",
        ScheduleKind::ReciprocalIter => "reciprocal-iter",
    };
    hasher.write_str(decay_tag);
    match &schedule.decay {
        ScheduleKind::Geometric { rate } => hasher.write_f64(*rate),
        ScheduleKind::Linear { steps } => hasher.write_usize(*steps),
        ScheduleKind::ReciprocalIter => {}
    }
    hasher.write_usize(schedule.iter_count);
}
/// Feeds an optional Gumbel temperature schedule into `hasher`: a presence flag, then the
/// schedule itself when present.
fn hash_gumbel_schedule_option(
    hasher: &mut Fingerprinter,
    schedule: &Option<crate::terms::sae_manifold::GumbelTemperatureSchedule>,
) {
    hasher.write_bool(schedule.is_some());
    if let Some(inner) = schedule {
        hash_gumbel_temperature_schedule(hasher, inner);
    }
}
/// Feeds an isometry reference into `hasher`: a variant tag, plus the supplied matrix for
/// the user-supplied variant.
fn hash_isometry_reference(
    hasher: &mut Fingerprinter,
    reference: &crate::terms::analytic_penalties::IsometryReference,
) {
    use crate::terms::analytic_penalties::IsometryReference as Reference;
    match reference {
        Reference::UserSupplied(matrix) => {
            hasher.write_str("user-supplied");
            hash_array2(hasher, matrix.as_ref());
        }
        Reference::Euclidean => hasher.write_str("euclidean"),
        Reference::MeanProfiled => hasher.write_str("mean-profiled"),
    }
}
/// Feeds a weight field into `hasher`.
///
/// Writes the tag `"identity"`, or the tag `"factored"` followed by the factor matrix, the
/// rank and the output dimension.
fn hash_weight_field(
    hasher: &mut Fingerprinter,
    field: &crate::terms::analytic_penalties::WeightField,
) {
    use crate::terms::analytic_penalties::WeightField as Field;
    match field {
        Field::Factored {
            u: factor,
            rank,
            p_out,
        } => {
            hasher.write_str("factored");
            hash_array2(hasher, factor.as_ref());
            hasher.write_usize(*rank);
            hasher.write_usize(*p_out);
        }
        Field::Identity => hasher.write_str("identity"),
    }
}
/// Feeds a sparsity kind into `hasher`: a variant tag, followed by the variant's scalar
/// parameter when it has one (`eps` for smoothed L1, `delta` for log).
fn hash_sparsity_kind(
    hasher: &mut Fingerprinter,
    kind: crate::terms::analytic_penalties::SparsityKind,
) {
    use crate::terms::analytic_penalties::SparsityKind as Kind;
    match kind {
        Kind::Hoyer => hasher.write_str("hoyer"),
        Kind::SmoothedL1 { eps: smoothing } => {
            hasher.write_str("smoothed-l1");
            hasher.write_f64(smoothing);
        }
        Kind::Log { delta: offset } => {
            hasher.write_str("log");
            hasher.write_f64(offset);
        }
    }
}
/// Feeds a difference-operator kind into `hasher`.
///
/// Writes the tag `"forward-diff-1d"`, or the tag `"graph-edges"` followed by the edge count
/// and each edge's two endpoints in order.
fn hash_difference_op_kind(
    hasher: &mut Fingerprinter,
    kind: &crate::terms::analytic_penalties::DifferenceOpKind,
) {
    use crate::terms::analytic_penalties::DifferenceOpKind as Op;
    match kind {
        Op::GraphEdges(edge_list) => {
            hasher.write_str("graph-edges");
            hasher.write_usize(edge_list.len());
            for edge in edge_list {
                hasher.write_usize(edge.0);
                hasher.write_usize(edge.1);
            }
        }
        Op::ForwardDiff1D => hasher.write_str("forward-diff-1d"),
    }
}
/// Feeds a list of index groups into `hasher`: the group count, then for each group its
/// length followed by its members in order.
fn hash_groups(hasher: &mut Fingerprinter, groups: &[Vec<usize>]) {
    hasher.write_usize(groups.len());
    for members in groups {
        hasher.write_usize(members.len());
        members.iter().for_each(|&axis| hasher.write_usize(axis));
    }
}
/// Feeds one analytic penalty into `hasher`.
///
/// Every penalty first contributes its `name()`, the `Debug` rendering of its `tier()` and
/// its `rho_count()`. The match below then writes a variant tag string followed by that
/// variant's configuration fields. The order of writes determines the fingerprint, so fields
/// must not be reordered without also invalidating previously stored fingerprints.
fn hash_analytic_penalty_kind(
hasher: &mut Fingerprinter,
penalty: &crate::terms::analytic_penalties::AnalyticPenaltyKind,
) {
use crate::terms::analytic_penalties::{AnalyticPenaltyKind, PenaltyConcavity};
// Header common to all variants.
hasher.write_str(penalty.name());
hasher.write_str(&format!("{:?}", penalty.tier()));
hasher.write_usize(penalty.rho_count());
match penalty {
AnalyticPenaltyKind::Isometry(p) => {
hasher.write_str("isometry");
hash_psi_slice(hasher, &p.target);
hash_isometry_reference(hasher, &p.reference);
hasher.write_usize(p.rho_index);
hasher.write_usize(p.p_out);
hash_weight_field(hasher, &p.weight);
hasher.write_f64(p.scalar_weight);
hash_weight_schedule_option(hasher, &p.weight_schedule);
// Each optional cache below is hashed as a presence flag followed by its contents.
match p.jacobian_cache() {
Some(values) => {
hasher.write_bool(true);
hash_array2(hasher, values.as_ref());
}
None => hasher.write_bool(false),
}
match p.jacobian_second_cache() {
Some(values) => {
hasher.write_bool(true);
hash_array2(hasher, values.as_ref());
}
None => hasher.write_bool(false),
}
match p.duchon_radial_source.as_ref() {
Some(source) => {
hasher.write_bool(true);
hash_array2(hasher, source.centers.as_ref());
hash_array2(hasher, source.radial_coefficients.as_ref());
match source.length_scale {
Some(length_scale) => {
hasher.write_bool(true);
hasher.write_f64(length_scale);
}
None => hasher.write_bool(false),
}
// The nullspace order is hashed through its `Debug` rendering.
hasher.write_str(&format!("{:?}", source.nullspace_order));
}
None => hasher.write_bool(false),
}
match p.third_decoder_derivative() {
Some(values) => {
hasher.write_bool(true);
hash_array3(hasher, values.as_ref());
}
None => hasher.write_bool(false),
}
}
AnalyticPenaltyKind::Sparsity(p) => {
hasher.write_str("sparsity");
hasher.write_str(&format!("{:?}", p.target_tier));
hash_sparsity_kind(hasher, p.kind);
hasher.write_f64(p.weight);
hash_weight_schedule_option(hasher, &p.weight_schedule);
hasher.write_usize(p.strength_rho_index);
match p.eps_rho_index {
Some(idx) => {
hasher.write_bool(true);
hasher.write_usize(idx);
}
None => hasher.write_bool(false),
}
}
AnalyticPenaltyKind::SoftmaxAssignmentSparsity(p) => {
hasher.write_str("softmax-assignment-sparsity");
hasher.write_usize(p.k_atoms);
hasher.write_f64(p.temperature);
hasher.write_f64(p.weight);
hash_weight_schedule_option(hasher, &p.weight_schedule);
}
AnalyticPenaltyKind::IBPAssignment(p) => {
hasher.write_str("ibp-assignment");
hasher.write_usize(p.k_max);
hasher.write_f64(p.alpha);
hasher.write_f64(p.tau);
hash_gumbel_schedule_option(hasher, &p.temperature_schedule);
hasher.write_bool(p.learnable_alpha);
hasher.write_f64(p.weight);
hash_weight_schedule_option(hasher, &p.weight_schedule);
}
AnalyticPenaltyKind::Ard(p) => {
hasher.write_str("ard");
hash_psi_slice(hasher, &p.target);
hasher.write_usize(p.latent_dim);
hasher.write_f64(p.weight);
hash_weight_schedule_option(hasher, &p.weight_schedule);
hasher.write_usize(p.rho_indices.len());
for &idx in &p.rho_indices {
hasher.write_usize(idx);
}
hasher.write_f64(p.n_eff);
}
AnalyticPenaltyKind::TopKActivation(p) => {
hasher.write_str("topk-activation");
hash_psi_slice(hasher, &p.target);
hasher.write_usize(p.k);
hasher.write_usize(p.latent_dim);
hasher.write_f64(p.weight);
hash_weight_schedule_option(hasher, &p.weight_schedule);
}
AnalyticPenaltyKind::JumpReLU(p) => {
hasher.write_str("jumprelu");
hash_psi_slice(hasher, &p.target);
hasher.write_usize(p.latent_dim);
hash_array_view(hasher, p.thresholds.view());
hasher.write_f64(p.weight);
hasher.write_f64(p.smoothing_eps);
hash_weight_schedule_option(hasher, &p.weight_schedule);
}
AnalyticPenaltyKind::TotalVariation(p) => {
hasher.write_str("total-variation");
hasher.write_f64(p.weight);
hasher.write_usize(p.n_eff);
hash_difference_op_kind(hasher, &p.difference_op);
hasher.write_f64(p.smoothing_eps);
hasher.write_bool(p.learnable_weight);
hasher.write_usize(p.rho_index);
hash_weight_schedule_option(hasher, &p.weight_schedule);
}
AnalyticPenaltyKind::NuclearNorm(p) => {
hasher.write_str("nuclear-norm");
hash_psi_slice(hasher, &p.target);
hasher.write_f64(p.weight);
hasher.write_usize(p.n_eff);
hasher.write_f64(p.smoothing_eps);
match p.max_rank {
Some(max_rank) => {
hasher.write_bool(true);
hasher.write_usize(max_rank);
}
None => hasher.write_bool(false),
}
hasher.write_bool(p.learnable_weight);
hasher.write_usize(p.rho_index);
hash_weight_schedule_option(hasher, &p.weight_schedule);
}
AnalyticPenaltyKind::BlockSparsity(p) => {
hasher.write_str("block-sparsity");
hash_psi_slice(hasher, &p.target);
hash_groups(hasher, &p.groups);
hasher.write_f64(p.weight);
hasher.write_usize(p.n_eff);
hasher.write_f64(p.smoothing_eps);
hasher.write_bool(p.learnable_weight);
hasher.write_usize(p.rho_index);
hash_weight_schedule_option(hasher, &p.weight_schedule);
}
AnalyticPenaltyKind::MechanismSparsity(p) => {
hasher.write_str("mechanism-sparsity");
hash_psi_slice(hasher, &p.target);
hash_groups(hasher, &p.feature_groups);
hasher.write_f64(p.weight);
hasher.write_f64(p.smoothing_eps);
hasher.write_f64(p.n_eff);
hasher.write_bool(p.learnable_weight);
hasher.write_usize(p.rho_index);
// Hashed inline rather than via `hash_weight_schedule_option`; the schedule here goes
// through `.as_ref()`. NOTE(review): presumably it is stored behind a pointer type — confirm.
match &p.weight_schedule {
Some(schedule) => {
hasher.write_bool(true);
hash_scalar_weight_schedule(hasher, schedule.as_ref());
}
None => hasher.write_bool(false),
}
}
AnalyticPenaltyKind::RowPrecisionPrior(p) => {
hasher.write_str("row-precision-prior");
hash_array3(hasher, &p.lambda_per_row);
hasher.write_f64(p.weight);
hasher.write_usize(p.n_eff);
hasher.write_bool(p.learnable_weight);
hasher.write_usize(p.rho_index);
hash_psi_slice(hasher, &p.target);
hash_weight_schedule_option(hasher, &p.weight_schedule);
}
AnalyticPenaltyKind::IvaeRidgeMeanGauge(p) => {
hasher.write_str("ivae-ridge-mean-gauge");
hash_array2(hasher, &p.aux);
hash_array2(hasher, &p.ridge_inv);
hasher.write_f64(p.ridge_eps);
hasher.write_f64(p.weight);
hasher.write_usize(p.n_eff);
hasher.write_bool(p.learnable_weight);
hasher.write_usize(p.rho_index);
hash_psi_slice(hasher, &p.target);
hash_weight_schedule_option(hasher, &p.weight_schedule);
}
AnalyticPenaltyKind::ParametricRowPrecisionPrior(p) => {
hasher.write_str("parametric-row-precision-prior");
hash_array2(hasher, &p.aux);
hash_array_view(hasher, p.log_alpha.view());
hash_array_view(hasher, p.raw_beta.view());
hash_array2(hasher, &p.mu);
hasher.write_f64(p.weight);
hasher.write_usize(p.n_eff);
hasher.write_bool(p.learnable_weight);
hash_psi_slice(hasher, &p.target);
hash_weight_schedule_option(hasher, &p.weight_schedule);
}
AnalyticPenaltyKind::ScadMcp(p) => {
hasher.write_str("scad-mcp");
hash_psi_slice(hasher, &p.target);
hasher.write_f64(p.weight);
hasher.write_usize(p.n_eff);
hasher.write_f64(p.gamma);
hasher.write_f64(p.smoothing_eps);
match p.variant {
PenaltyConcavity::Mcp => hasher.write_str("mcp"),
PenaltyConcavity::Scad => hasher.write_str("scad"),
}
hasher.write_bool(p.learnable_weight);
hasher.write_usize(p.rho_index);
hash_weight_schedule_option(hasher, &p.weight_schedule);
}
AnalyticPenaltyKind::BlockOrthogonality(p) => {
hasher.write_str("block-orthogonality");
hash_psi_slice(hasher, &p.target);
hash_groups(hasher, &p.groups);
hasher.write_f64(p.weight);
hasher.write_usize(p.n_eff);
hasher.write_bool(p.learnable_weight);
hasher.write_usize(p.rho_index);
hash_weight_schedule_option(hasher, &p.weight_schedule);
}
AnalyticPenaltyKind::DecoderIncoherence(p) => {
hasher.write_str("decoder-incoherence");
hash_psi_slice(hasher, &p.target);
hasher.write_usize(p.block_sizes.len());
for &m in &p.block_sizes {
hasher.write_usize(m);
}
hasher.write_usize(p.p_out);
hash_array2(hasher, &p.coactivation);
hasher.write_f64(p.weight);
hasher.write_bool(p.learnable_weight);
hasher.write_usize(p.rho_index);
hash_weight_schedule_option(hasher, &p.weight_schedule);
}
AnalyticPenaltyKind::Orthogonality(p) => {
hasher.write_str("orthogonality");
hash_psi_slice(hasher, &p.target);
hasher.write_usize(p.latent_dim);
hasher.write_f64(p.weight);
hasher.write_usize(p.n_eff);
hasher.write_bool(p.learnable_weight);
hasher.write_usize(p.rho_index);
hash_weight_schedule_option(hasher, &p.weight_schedule);
}
AnalyticPenaltyKind::NestedPrefix(p) => {
hasher.write_str("nested-prefix");
hash_psi_slice(hasher, &p.target);
hasher.write_str(&format!("{:?}", p.target_tier));
hasher.write_usize(p.prefix_sizes.len());
for &m in &p.prefix_sizes {
hasher.write_usize(m);
}
hasher.write_usize(p.shell_weights.len());
for &w in &p.shell_weights {
hasher.write_f64(w);
}
hasher.write_f64(p.eps);
hasher.write_usize(p.rho_indices.len());
for &idx in &p.rho_indices {
hasher.write_usize(idx);
}
hash_weight_schedule_option(hasher, &p.weight_schedule);
}
AnalyticPenaltyKind::Monotonicity(p) => {
hasher.write_str("monotonicity");
hasher.write_f64(p.weight);
hasher.write_usize(p.n_eff);
hasher.write_f64(p.direction);
hasher.write_f64(p.smoothing_eps);
hasher.write_bool(p.learnable_weight);
hasher.write_usize(p.rho_index);
hash_weight_schedule_option(hasher, &p.weight_schedule);
}
AnalyticPenaltyKind::SheafConsistency(p) => {
hasher.write_str("sheaf-consistency");
// Only the weight and the stalk dimensions are hashed for this variant.
hasher.write_f64(p.weight());
let dims = p.stalk_dims();
hasher.write_usize(dims.len());
for &d in dims {
hasher.write_usize(d);
}
}
}
}
/// Computes a 64-bit fingerprint of an analytic-penalty registry.
///
/// The hash input is, in order: the tag `"analytic-penalty-registry-v1"`,
/// the number of registered penalties, and every penalty as written by
/// `hash_analytic_penalty_kind`, visited in registry order.
///
/// # Returns
/// The value produced by `Fingerprinter::finish_u64` after all writes.
pub(crate) fn analytic_penalty_registry_fingerprint(
    registry: &crate::terms::analytic_penalties::AnalyticPenaltyRegistry,
) -> u64 {
    let mut hasher = Fingerprinter::new();
    hasher.write_str("analytic-penalty-registry-v1");
    // The count is written before the elements so the element stream is
    // length-prefixed.
    hasher.write_usize(registry.penalties.len());
    // Iterate by shared reference: the registry is only read here.
    for penalty in &registry.penalties {
        hash_analytic_penalty_kind(&mut hasher, penalty);
    }
    hasher.finish_u64()
}
/// Feeds the shape and the contents of `design` into `hasher`.
///
/// Rows are pulled through `try_row_chunk` in bounded slabs so the whole
/// matrix never has to be materialised at once. The slab height aims at
/// `HASH_CHUNK_TARGET_BYTES` worth of `f64` values per slab and is clamped to
/// `[HASH_CHUNK_MIN_ROWS, HASH_CHUNK_MAX_ROWS]`.
///
/// # Errors
/// Returns a descriptive message when a row chunk cannot be obtained.
fn hash_design_matrix(hasher: &mut Fingerprinter, design: &DesignMatrix) -> Result<(), String> {
    const HASH_CHUNK_TARGET_BYTES: usize = 8 * 1024 * 1024;
    const HASH_CHUNK_MIN_ROWS: usize = 1;
    const HASH_CHUNK_MAX_ROWS: usize = 4096;

    let row_count = design.nrows();
    let col_count = design.ncols();
    hasher.write_usize(row_count);
    hasher.write_usize(col_count);

    // `.max(1)` guards the division below when the matrix has zero columns.
    let row_bytes = col_count
        .saturating_mul(std::mem::size_of::<f64>())
        .max(1);
    let rows_per_slab =
        (HASH_CHUNK_TARGET_BYTES / row_bytes).clamp(HASH_CHUNK_MIN_ROWS, HASH_CHUNK_MAX_ROWS);

    let mut first_row = 0;
    while first_row < row_count {
        let past_last_row = (first_row + rows_per_slab).min(row_count);
        let slab = design
            .try_row_chunk(first_row..past_last_row)
            .map_err(|e| format!("persistent warm-start design hash failed: {e}"))?;
        hash_array2(hasher, &slab);
        first_row = past_last_row;
    }
    Ok(())
}
/// Writes a length-prefixed description of every canonical penalty into
/// `hasher`.
///
/// Per penalty the following are hashed, in this order: column range bounds,
/// `total_dim`, `nullity`, the `root` and `local` matrices, the prior mean,
/// the positive eigenvalues (length-prefixed), and whether `op` is present.
fn hash_canonical_penalties(
    hasher: &mut Fingerprinter,
    penalties: &[crate::construction::CanonicalPenalty],
) {
    hasher.write_usize(penalties.len());
    for cp in penalties.iter() {
        let columns = &cp.col_range;
        hasher.write_usize(columns.start);
        hasher.write_usize(columns.end);

        hasher.write_usize(cp.total_dim);
        hasher.write_usize(cp.nullity);

        hash_array2(hasher, &cp.root);
        hash_array2(hasher, &cp.local);
        hash_array_view(hasher, cp.prior_mean.view());

        let eigenvalues = &cp.positive_eigenvalues;
        hasher.write_usize(eigenvalues.len());
        eigenvalues
            .iter()
            .for_each(|&eigenvalue| hasher.write_f64(eigenvalue));

        // Only the presence of the operator is hashed, not its contents.
        hasher.write_bool(cp.op.is_some());
    }
}
/// Decodes an `f64` from its raw bit pattern, accepting only finite values
/// that are strictly greater than zero.
///
/// The all-zero pattern (positive zero) is rejected up front; NaN, the
/// infinities, negative values and negative zero are rejected by the filter.
fn finite_positive_from_bits(bits: u64) -> Option<f64> {
    match bits {
        0 => None,
        raw => Some(f64::from_bits(raw)).filter(|decoded| decoded.is_finite() && *decoded > 0.0),
    }
}
/// Decodes an `f64` from its raw bit pattern, accepting only finite values
/// that compare `>= 0.0`.
///
/// Negative zero compares equal to zero and is therefore accepted with its
/// sign bit intact; NaN and the infinities are rejected.
fn finite_nonnegative_from_bits(bits: u64) -> Option<f64> {
    let decoded = f64::from_bits(bits);
    (decoded.is_finite() && decoded >= 0.0).then_some(decoded)
}
/// Encodes an optional value as raw `f64` bits.
///
/// A present, finite value that compares `>= 0.0` is stored as its bit
/// pattern; anything else (absent, NaN, infinite, negative) maps to
/// `IFT_RESIDUAL_NO_SIGNAL_BITS`.
fn finite_nonnegative_bits_or_no_signal(value: Option<f64>) -> u64 {
    match value {
        Some(v) if v.is_finite() && v >= 0.0 => v.to_bits(),
        _ => IFT_RESIDUAL_NO_SIGNAL_BITS,
    }
}
/// Value and optional derivatives of a correction term.
///
/// NOTE(review): the `Tk` prefix presumably refers to the Tierney-Kadane
/// correction named in this file's error messages; the code that builds this
/// struct is outside the visible part of the file — confirm there.
struct TkCorrectionTerms {
/// Scalar value of the correction.
value: f64,
/// Gradient of the correction, when one was computed.
gradient: Option<Array1<f64>>,
/// Hessian of the correction, when one was computed.
hessian: Option<Array2<f64>>,
}
/// Quantities computed once by `tk_shared_intermediates` and reused by the
/// scalar and gradient routines of the Tierney-Kadane correction.
struct TkSharedIntermediates {
/// Per-row value `x_dense.row(i) · z.column(i)` (reported as "leverage" in
/// the error message of `tk_shared_intermediates`).
h_diag: Array1<f64>,
/// `x_denseᵀ (c ⊙ h_diag)`, where `c` is the `c_array` argument.
x_m: Array1<f64>,
/// Result of applying the caller-supplied `h_inv_solve` to `x_m`.
y: Array1<f64>,
/// Row blocks of `c_array` that contain at least one nonzero entry.
active_blocks: Vec<TkActiveBlock>,
}
/// One contiguous row block `[start, end)` of the `c_array` vector, holding
/// only its nonzero entries (built by `tk_active_blocks`).
struct TkActiveBlock {
/// First row index covered by the block.
start: usize,
/// One past the last row index covered by the block.
end: usize,
/// `(offset, value)` pairs for nonzero entries; `offset` is relative to
/// `start`, so the absolute row is `start + offset`.
entries: Vec<(usize, f64)>,
}
/// Bundle of inputs for derivative evaluation.
///
/// NOTE(review): this struct is neither constructed nor consumed in the
/// visible part of the file; the field notes below only restate the declared
/// types — confirm the semantics at the use sites.
struct DerivativeContext {
/// Boxed provider of Hessian derivatives.
deriv_provider: Box<dyn super::unified::HessianDerivativeProvider>,
/// How dispersion is handled.
dispersion: super::unified::DispersionHandling,
/// Log-likelihood value carried along with the context.
log_likelihood: f64,
/// Optional shared Firth dense operator.
firth_op: Option<std::sync::Arc<super::FirthDenseOperator>>,
/// Optional barrier configuration.
barrier_config: Option<super::unified::BarrierConfig>,
}
/// Returns an owned clone of the `spec` field of `likelihood`.
#[inline]
fn reml_spec(likelihood: &GlmLikelihoodSpec) -> LikelihoodSpec {
likelihood.spec.clone()
}
/// Reports whether the likelihood's specification is the Gaussian/identity
/// case, as decided by `LikelihoodSpec::is_gaussian_identity`.
#[inline]
fn reml_is_gaussian_identity(likelihood: &GlmLikelihoodSpec) -> bool {
    let spec = reml_spec(likelihood);
    spec.is_gaussian_identity()
}
/// Returns the inverse link of `likelihood` when it qualifies for the
/// Jeffreys path.
///
/// A link qualifies only if the response family is binomial and
/// `inverse_link_has_fisher_weight_jet` accepts the link; otherwise `None`.
#[inline]
fn reml_jeffreys_supported_link(likelihood: &GlmLikelihoodSpec) -> Option<InverseLink> {
    let spec = reml_spec(likelihood);
    let binomial = matches!(spec.response, ResponseFamily::Binomial);
    // `&&` short-circuits, so the jet check is only consulted for binomial fits.
    (binomial && inverse_link_has_fisher_weight_jet(&spec.link)).then(|| spec.link.clone())
}
/// Returns the Jeffreys-supported inverse link for `config`, or `None` when
/// Firth bias reduction is disabled or the likelihood does not qualify.
#[inline]
pub(super) fn reml_robust_jeffreys_link(config: &RemlConfig) -> Option<InverseLink> {
    if config.firth_bias_reduction {
        reml_jeffreys_supported_link(&config.likelihood)
    } else {
        None
    }
}
/// `upper` parameter of the default penalized-complexity prior built by
/// `firth_default_pc_prior`.
const FIRTH_DEFAULT_PC_UPPER: f64 = 10.0;
/// `tail_prob` parameter of the default penalized-complexity prior built by
/// `firth_default_pc_prior`.
const FIRTH_DEFAULT_PC_TAIL_PROB: f64 = 0.01;
/// Builds the default `RhoPrior::PenalizedComplexity` prior from the two
/// constants above. `resolve_effective_rho_prior` substitutes it wherever a
/// configured prior is `RhoPrior::Flat`.
#[inline]
fn firth_default_pc_prior() -> RhoPrior {
RhoPrior::PenalizedComplexity {
upper: FIRTH_DEFAULT_PC_UPPER,
tail_prob: FIRTH_DEFAULT_PC_TAIL_PROB,
}
}
/// Builds a per-coordinate mask of length `len` marking which coordinates of
/// the configured prior are `RhoPrior::Flat`.
///
/// * `Flat` marks every coordinate.
/// * `Independent` with exactly `len` entries marks the flat entries.
/// * Anything else, including an `Independent` of the wrong length, marks
///   nothing.
fn firth_default_coord_mask(configured: &RhoPrior, len: usize) -> Vec<bool> {
    if let RhoPrior::Independent(priors) = configured {
        if priors.len() == len {
            return priors
                .iter()
                .map(|prior| matches!(prior, RhoPrior::Flat))
                .collect();
        }
    }
    // Uniform mask: all `true` for a flat prior, all `false` otherwise.
    vec![matches!(configured, RhoPrior::Flat); len]
}
/// Resolves the prior actually used, replacing every `RhoPrior::Flat` with
/// the default penalized-complexity prior.
///
/// The configured prior is borrowed unchanged when it contains no flat
/// component at the top level or directly inside an `Independent` list; an
/// owned replacement is allocated only when a substitution is needed.
fn resolve_effective_rho_prior(configured: &RhoPrior) -> std::borrow::Cow<'_, RhoPrior> {
    use std::borrow::Cow;

    let is_flat = |prior: &RhoPrior| matches!(prior, RhoPrior::Flat);

    if is_flat(configured) {
        return Cow::Owned(firth_default_pc_prior());
    }
    if let RhoPrior::Independent(priors) = configured {
        if priors.iter().any(|prior| is_flat(prior)) {
            let filled = priors
                .iter()
                .map(|prior| {
                    if is_flat(prior) {
                        firth_default_pc_prior()
                    } else {
                        prior.clone()
                    }
                })
                .collect();
            return Cow::Owned(RhoPrior::Independent(filled));
        }
    }
    Cow::Borrowed(configured)
}
/// Returns the fixed dispersion used for the given likelihood.
///
/// * Beta: the family's own `phi`.
/// * Negative binomial: always `1.0`.
/// * Every other family: `likelihood.fixed_phi()`, defaulting to `1.0`.
///
/// The link does not influence the result. The match stays exhaustive, with
/// no wildcard arm, so adding a response family forces a decision here.
#[inline]
fn reml_fixed_glm_dispersion(likelihood: &GlmLikelihoodSpec) -> f64 {
    let spec = reml_spec(likelihood);
    match &spec.response {
        ResponseFamily::Beta { phi } => *phi,
        ResponseFamily::NegativeBinomial { .. } => 1.0,
        ResponseFamily::Tweedie { .. }
        | ResponseFamily::Gaussian
        | ResponseFamily::Binomial
        | ResponseFamily::Poisson
        | ResponseFamily::Gamma
        | ResponseFamily::RoystonParmar => likelihood.fixed_phi().unwrap_or(1.0),
    }
}
// NOTE(review): not referenced in the visible part of the file; presumably the
// minimum acceptable effective-sample-size fraction for an importance-sampling
// step — confirm at the use site.
const MIN_IMPORTANCE_ESS_FRACTION: f64 = 0.10;
/// Target evaluated by the `BlockExcessTarget` implementation below: the
/// `excess` of the objective around a fitted mode, restricted to the block
/// spanned by the columns of `block_vecs`.
struct Gam784BlockTarget<'t> {
/// Design matrix used to map a coefficient displacement to a displacement
/// of the linear predictor.
x_transformed: &'t Array2<f64>,
/// Matrix mapping block coordinates `t` to a coefficient displacement
/// (`block_vecs.dot(t)`).
block_vecs: Array2<f64>,
/// Returned by `block_curvatures`; its length is the block dimension.
block_lambdas: Array1<f64>,
/// Linear predictor at the mode; displacements are added to it.
eta_hat: Array1<f64>,
/// Per-observation weights of the quadratic term subtracted in `excess`.
weights_obs: Array1<f64>,
/// Response vector passed to the deviance calculation.
y: Array1<f64>,
/// Prior weights passed to the deviance calculation.
prior_weights: Array1<f64>,
/// Likelihood specification passed to the deviance calculation.
likelihood: GlmLikelihoodSpec,
/// Inverse link used to map the displaced linear predictor to the mean.
inverse_link: InverseLink,
/// Dispersion; the deviance difference is divided by `2 * phi`.
phi: f64,
/// One vector per penalty; dotted with the coefficient displacement and
/// scaled by the matching entry of `lambdas`.
penalty_scores: Vec<Array1<f64>>,
/// Per-penalty multipliers; the length is the rho dimension.
lambdas: Vec<f64>,
/// Deviance at the mode, subtracted from the displaced deviance.
base_deviance: f64,
}
impl Gam784BlockTarget<'_> {
    /// Maps block coordinates `t` to the pair
    /// `(coefficient displacement, linear-predictor displacement)`.
    fn displacement(&self, t: &Array1<f64>) -> (Array1<f64>, Array1<f64>) {
        let coef_shift = self.block_vecs.dot(t);
        let eta_shift = crate::faer_ndarray::fast_av(self.x_transformed, &coef_shift);
        (coef_shift, eta_shift)
    }
}
impl crate::inference::hmc::BlockExcessTarget for Gam784BlockTarget<'_> {
    /// Dimension of the block coordinates `t`.
    fn block_dim(&self) -> usize {
        self.block_lambdas.len()
    }

    /// Number of penalty multipliers.
    fn rho_dim(&self) -> usize {
        self.lambdas.len()
    }

    /// Per-direction curvatures of the block.
    fn block_curvatures(&self) -> &Array1<f64> {
        &self.block_lambdas
    }

    /// Evaluates the excess at block coordinates `t`:
    /// the scaled deviance difference, plus the λ-weighted penalty scores
    /// dotted with the coefficient displacement, minus half the weighted
    /// squared linear-predictor displacement.
    ///
    /// Returns `+∞` when the inverse link fails for any observation or the
    /// displaced deviance is not finite.
    fn excess(&self, t: &Array1<f64>) -> f64 {
        let (coef_shift, eta_shift) = self.displacement(t);

        let mut mu_shifted = Array1::<f64>::zeros(self.eta_hat.len());
        for (i, mu_slot) in mu_shifted.iter_mut().enumerate() {
            let eta_i = self.eta_hat[i] + eta_shift[i];
            let Ok(jet) =
                crate::mixture_link::inverse_link_jet_for_inverse_link(&self.inverse_link, eta_i)
            else {
                return f64::INFINITY;
            };
            *mu_slot = jet.mu;
        }

        let deviance_shifted = crate::pirls::calculate_deviance(
            self.y.view(),
            &mu_shifted,
            &self.likelihood,
            self.prior_weights.view(),
        );
        if !deviance_shifted.is_finite() {
            return f64::INFINITY;
        }
        let neg_loglik_diff = (deviance_shifted - self.base_deviance) / (2.0 * self.phi);

        // Explicit folds from 0.0 keep the left-to-right accumulation order.
        let penalty_term = self
            .penalty_scores
            .iter()
            .zip(self.lambdas.iter())
            .fold(0.0_f64, |acc, (score, &lam)| {
                acc + lam * score.dot(&coef_shift)
            });
        let curvature = (0..eta_shift.len()).fold(0.0_f64, |acc, i| {
            acc + self.weights_obs[i] * eta_shift[i] * eta_shift[i]
        });

        neg_loglik_diff + penalty_term - 0.5 * curvature
    }

    /// Per-penalty contribution `λ_k · (score_k · δ)` at block coordinates
    /// `t`, where `δ` is the coefficient displacement.
    fn excess_rho_gradient(&self, t: &Array1<f64>) -> Array1<f64> {
        let (coef_shift, _) = self.displacement(t);
        let mut grad = Array1::<f64>::zeros(self.lambdas.len());
        let weighted_scores = self.penalty_scores.iter().zip(self.lambdas.iter());
        for (slot, (score, &lam)) in grad.iter_mut().zip(weighted_scores) {
            *slot = lam * score.dot(&coef_shift);
        }
        grad
    }
}
impl<'a> RemlState<'a> {
/// Largest accepted ratio `polish_norm² / beta_norm²`; larger polish steps are
/// skipped in `apply_inner_polish_step_to_warm_start`.
const POLISH_NORM_RATIO: f64 = 0.25;
/// Key under which this instance's runtime state is stored in the shared
/// maps: the address of `self` converted to `usize`.
///
/// NOTE(review): an address-based key is only unique while this instance is
/// alive and not moved; entries are removed by the explicit clear/reset
/// methods — confirm callers invoke them before the state is dropped.
fn hypergradient_owner_key(&self) -> usize {
self as *const _ as usize
}
/// Returns the step cap recorded for this instance, or `default_cap` when no
/// finite, strictly positive cap has been recorded.
fn ift_quality_step_cap(&self, default_cap: f64) -> f64 {
    let registry = ift_quality_states().lock().unwrap();
    let recorded = registry
        .get(&self.hypergradient_owner_key())
        .and_then(|entry| entry.next_step_cap);
    match recorded {
        Some(cap) if cap.is_finite() && cap > 0.0 => cap,
        _ => default_cap,
    }
}
/// Reads and clears the one-shot "fall back to flat" flag for this instance.
///
/// Returns `false` when no state has been recorded for this instance.
fn take_ift_quality_flat_override(&self) -> bool {
    let mut registry = ift_quality_states().lock().unwrap();
    registry
        .get_mut(&self.hypergradient_owner_key())
        .map(|entry| std::mem::replace(&mut entry.fallback_next_flat, false))
        .unwrap_or(false)
}
/// Removes this instance's entry from the shared IFT-quality state map.
fn clear_ift_quality_runtime_state(&self) {
    ift_quality_states()
        .lock()
        .unwrap()
        .remove(&self.hypergradient_owner_key());
}
/// Records one prediction-quality sample and derives the next step cap.
///
/// The sample is appended to a history bounded by `IFT_QUALITY_HISTORY_CAP`,
/// with the oldest samples dropped first. The mean of the retained history
/// then selects the next cap:
/// * below `IFT_QUALITY_GROW_BAND`: grow by `IFT_STEP_CAP_GROW_FACTOR`;
/// * below `IFT_QUALITY_SHRINK_BAND`: keep `current_cap`;
/// * otherwise: shrink by `IFT_STEP_CAP_SHRINK_FACTOR`.
///
/// The flat-fallback flag is set when the mean reaches
/// `IFT_QUALITY_FLAT_FALLBACK_BAND`.
///
/// Returns `None`, recording nothing, unless `quality` is finite and
/// non-negative and `current_cap` is finite and strictly positive.
fn record_ift_prediction_quality(&self, quality: f64, current_cap: f64) -> Option<f64> {
    let quality_usable = quality.is_finite() && quality >= 0.0;
    let cap_usable = current_cap.is_finite() && current_cap > 0.0;
    if !(quality_usable && cap_usable) {
        return None;
    }

    let mut registry = ift_quality_states().lock().unwrap();
    let entry = registry.entry(self.hypergradient_owner_key()).or_default();

    let history = &mut entry.quality_history;
    history.push(quality);
    // Drop the oldest samples in one pass so at most the cap remains.
    let overflow = history.len().saturating_sub(IFT_QUALITY_HISTORY_CAP);
    history.drain(..overflow);
    let rolling_quality = history.iter().sum::<f64>() / history.len() as f64;

    let next_step_cap = if rolling_quality < IFT_QUALITY_GROW_BAND {
        current_cap * IFT_STEP_CAP_GROW_FACTOR
    } else if rolling_quality < IFT_QUALITY_SHRINK_BAND {
        current_cap
    } else {
        current_cap * IFT_STEP_CAP_SHRINK_FACTOR
    };

    entry.next_step_cap = Some(next_step_cap);
    entry.fallback_next_flat = rolling_quality >= IFT_QUALITY_FLAT_FALLBACK_BAND;
    Some(next_step_cap)
}
/// Removes this instance's entry from the shared hypergradient budget map.
fn reset_hypergradient_budget_controller(&self) {
    hypergradient_budgets()
        .lock()
        .unwrap()
        .remove(&self.hypergradient_owner_key());
}
/// Returns a shared handle to this instance's stochastic-trace state,
/// creating the runtime entry on first use.
fn hypergradient_trace_state(&self) -> Arc<Mutex<super::unified::StochasticTraceState>> {
    let mut registry = hypergradient_budgets().lock().unwrap();
    let owner = self.hypergradient_owner_key();
    let entry = registry
        .entry(owner)
        .or_insert_with(HyperGradientRuntimeState::new);
    Arc::clone(&entry.trace_state)
}
/// Clears the per-evaluation telemetry fields of the trace state: the last
/// linear residual norm, the last probe variance, and the last probe count.
///
/// A poisoned mutex is recovered, not propagated, so the reset always runs.
fn reset_hypergradient_trace_telemetry(
    trace_state: &Arc<Mutex<super::unified::StochasticTraceState>>,
) {
    let mut telemetry = trace_state
        .lock()
        .unwrap_or_else(|poisoned| poisoned.into_inner());
    telemetry.last_linear_residual_norm = None;
    telemetry.last_probe_sigma_sq = None;
    telemetry.last_probe_count = 0;
}
/// Converts the stored adaptive-KKT override for this instance into a pinned
/// `AdaptiveKktTolerance`.
///
/// The stored value is clamped to `[floor, ceiling]`, where `ceiling` is the
/// P-IRLS convergence tolerance and `floor` is the REML tolerance divided by
/// `ADAPTIVE_KKT_FLOOR_REML_DIVISOR`, itself capped at the ceiling. The
/// returned tolerance sets floor, ceiling and outer gradient norm to the same
/// clamped value, with `eta = 1.0`.
///
/// Returns `None` when no state or override is stored, when the override is
/// not finite and strictly positive, or when the clamp bounds are invalid.
fn hypergradient_adaptive_kkt_override(
    &self,
    pirls_config: &pirls::PirlsConfig,
) -> Option<pirls::AdaptiveKktTolerance> {
    let registry = hypergradient_budgets().lock().unwrap();
    let requested = registry
        .get(&self.hypergradient_owner_key())
        .and_then(|entry| entry.adaptive_kkt_override)
        .filter(|tau| tau.is_finite() && *tau > 0.0)?;

    let upper = pirls_config.convergence_tolerance;
    let lower =
        (self.config.reml_convergence_tolerance / ADAPTIVE_KKT_FLOOR_REML_DIVISOR).min(upper);
    // Written as a positive test so NaN bounds are rejected as well.
    let bounds_valid = lower > 0.0 && upper >= lower;
    if !bounds_valid {
        return None;
    }

    let pinned = requested.clamp(lower, upper);
    Some(pirls::AdaptiveKktTolerance {
        eta: 1.0,
        floor: pinned,
        ceiling: pinned,
        outer_grad_norm: pinned,
    })
}
/// Feeds one outer evaluation into this instance's hypergradient budget
/// controller and refreshes the tolerances it drives.
///
/// Steps, in order:
/// 1. Ignore the evaluation if `rho` or `gradient` has a non-finite entry.
/// 2. Snapshot the trace telemetry and push a history entry.
/// 3. Re-estimate sensitivities and decide whether warmup is over.
/// 4. During warmup, clear both overrides and return.
/// 5. Otherwise allocate the error budget across the three channels and
///    derive: the adaptive KKT override, the linear-solve relative tolerance
///    override, and a probe floor that is only ever raised.
///
/// `ift_residual_energy` is the inner-channel energy for this evaluation;
/// `None` is treated as zero.
fn update_hypergradient_budget_after_outer_eval(
&self,
rho: &Array1<f64>,
gradient: &Array1<f64>,
ift_residual_energy: Option<f64>,
) {
if rho.iter().any(|v| !v.is_finite()) || gradient.iter().any(|v| !v.is_finite()) {
return;
}
let mut budgets = hypergradient_budgets().lock().unwrap();
let state = budgets
.entry(self.hypergradient_owner_key())
.or_insert_with(HyperGradientRuntimeState::new);
// Snapshot telemetry in a short-lived scope so the trace lock is released
// before the budget is updated. A poisoned lock is recovered.
let (e_linear, sigma_sq, k, current_floor) = {
let trace = match state.trace_state.lock() {
Ok(guard) => guard,
Err(poisoned) => poisoned.into_inner(),
};
// Missing telemetry counts as zero; negative values are floored at zero.
let linear_residual_norm = trace.last_linear_residual_norm.unwrap_or(0.0).max(0.0);
(
0.5 * linear_residual_norm * linear_residual_norm,
trace.last_probe_sigma_sq.unwrap_or(0.0).max(0.0),
trace.last_probe_count,
trace.monotone_probe_floor,
)
};
let e_inner = ift_residual_energy.unwrap_or(0.0).max(0.0);
state.budget.push(HyperGradHistoryEntry {
rho: rho.clone(),
g_outer: gradient.clone(),
e_inner,
e_linear,
sigma_sq,
k,
});
let sensitivity_estimate = state.budget.reestimate_sensitivities();
let sensitivity_stable = state.budget.sensitivities_stable();
// Warmup ends once at least HGB_WARMUP_ITERS_MIN entries exist and either the
// sensitivities are stable or HGB_WARMUP_ITERS_MAX entries have accumulated.
let force_engage = state.budget.history.len() >= HGB_WARMUP_ITERS_MAX;
if !state.budget.warmup_engaged
&& state.budget.history.len() >= HGB_WARMUP_ITERS_MIN
&& (sensitivity_stable || force_engage)
{
if sensitivity_stable {
log::info!(
"[HGB] engage after {} iters (sensitivity stable)",
state.budget.history.len()
);
} else {
log::info!(
"[HGB] engage after {} iters (max warmup reached)",
state.budget.history.len()
);
}
state.budget.warmup_engaged = true;
}
// Still warming up: clear both overrides and stop.
if !state.budget.warmup_engaged {
state.adaptive_kkt_override = None;
match state.trace_state.lock() {
Ok(mut trace) => trace.solve_rel_tol_override = None,
Err(poisoned) => {
let mut trace = poisoned.into_inner();
trace.solve_rel_tol_override = None;
}
}
return;
}
// Fall back to the initial per-channel sensitivities when no estimate is
// available.
let [s_inner, s_linear, s_trace] = if let Some(sensitivities) = sensitivity_estimate {
sensitivities
} else {
log::warn!("[HGB] sensitivity_unavailable falling_back_to_per_channel");
[S_INNER_INIT, S_LINEAR_INIT, S_TRACE_INIT]
};
// Target mean-squared error: (HGB_TARGET_FRACTION * previous gradient norm)²,
// with the norm floored at 1e-12.
let previous_grad_norm = state.budget.previous_gradient_norm().max(1e-12);
state.budget.target_mse = (HGB_TARGET_FRACTION * previous_grad_norm).powi(2);
let (eps2_inner, eps2_linear, eps2_trace, floor_active) = state
.budget
.allocate_with_sensitivities(s_inner, s_linear, s_trace);
// Inner channel: tolerance is sqrt(eps²) / sensitivity, clamped to
// [pirls_floor, ceiling] when those bounds are valid.
let ceiling = self.config.pirls_convergence_tolerance;
let pirls_floor =
(self.config.reml_convergence_tolerance / ADAPTIVE_KKT_FLOOR_REML_DIVISOR).min(ceiling);
let tau_raw = eps2_inner.sqrt() / s_inner;
let tau_inner = if pirls_floor > 0.0 && ceiling >= pirls_floor {
tau_raw.clamp(pirls_floor, ceiling)
} else {
tau_raw
}
.max(0.0);
// Only a finite, strictly positive tolerance is stored as an override.
state.adaptive_kkt_override =
(tau_inner.is_finite() && tau_inner > 0.0).then_some(tau_inner);
// Linear channel: relative solve tolerance.
let rel_tol = eps2_linear.sqrt() / s_linear;
// Trace channel: probe count needed so sigma² / k meets the trace budget;
// without usable inputs the current floor is kept.
let k_target = if eps2_trace > 0.0 && sigma_sq.is_finite() && sigma_sq > 0.0 {
(sigma_sq / eps2_trace).ceil().clamp(0.0, usize::MAX as f64) as usize
} else {
current_floor
};
let raised_floor = current_floor.max(k_target);
{
let mut trace = match state.trace_state.lock() {
Ok(guard) => guard,
Err(poisoned) => poisoned.into_inner(),
};
trace.solve_rel_tol_override =
(rel_tol.is_finite() && rel_tol > 0.0).then_some(rel_tol);
// The probe floor only moves upward.
if raised_floor > trace.monotone_probe_floor {
trace.monotone_probe_floor = raised_floor;
}
}
// Names of the channels flagged in `floor_active`, for the log line
// ("i" = inner, "l" = linear, "t" = trace).
let active = ["i", "l", "t"]
.iter()
.zip(floor_active.iter())
.filter_map(|(name, active)| active.then_some(*name))
.collect::<Vec<_>>()
.join(",");
log::info!(
"[HGB] target_mse={:.3e} s_i={:.3e} s_l={:.3e} s_t={:.3e} eps²_i={:.3e} eps²_l={:.3e} eps²_t={:.3e} τ={:.3e} rtol={:.3e} k={} floor_active=[{}]",
state.budget.target_mse,
s_inner,
s_linear,
s_trace,
eps2_inner,
eps2_linear,
eps2_trace,
tau_inner,
rel_tol,
k_target,
active,
);
}
/// Folds an inner polish step into the cached warm start.
///
/// The polished coefficients are `solution_beta - polish_step`. They are
/// stored as the next warm-start beta and, when an IFT warm-start cache
/// exists, its `beta_original` and per-penalty `lambda_s_beta_blocks` are
/// refreshed to match.
///
/// Nothing is stored when any of the following holds:
/// * warm starts are disabled, or the two vectors differ in length;
/// * either squared norm is non-finite, or the squared polish norm exceeds
///   `POLISH_NORM_RATIO` times the squared beta norm (logged as
///   `[POLISH-SKIP]`);
/// * an active constraint free basis exists for the P-IRLS result;
/// * the polished vector does not have length `self.p` or has a non-finite
///   entry.
fn apply_inner_polish_step_to_warm_start(
    &self,
    bundle: &EvalShared,
    solution_beta: &Array1<f64>,
    polish_step: &Array1<f64>,
) {
    if !self.warm_start_enabled.load(Ordering::Relaxed)
        || solution_beta.len() != polish_step.len()
    {
        return;
    }
    let polish_norm_squared = polish_step.dot(polish_step);
    let beta_norm_squared = solution_beta.dot(solution_beta);
    if !polish_norm_squared.is_finite()
        || !beta_norm_squared.is_finite()
        || polish_norm_squared > Self::POLISH_NORM_RATIO * beta_norm_squared
    {
        log::info!(
            "[POLISH-SKIP] reason=large_step polish_norm²={} beta_norm²={}",
            polish_norm_squared,
            beta_norm_squared
        );
        return;
    }
    let polished_solution_beta = solution_beta - polish_step;
    let pirls_result = bundle.pirls_result.as_ref();
    // The match is kept exhaustive so a new coordinate frame forces a
    // decision here.
    let beta_original = match pirls_result.coordinate_frame {
        pirls::PirlsCoordinateFrame::OriginalSparseNative => {
            if self.active_constraint_free_basis(pirls_result).is_some() {
                return;
            }
            polished_solution_beta
        }
        pirls::PirlsCoordinateFrame::TransformedQs => {
            if self.active_constraint_free_basis(pirls_result).is_some()
                || polished_solution_beta.len() != self.p
            {
                return;
            }
            polished_solution_beta
        }
    };
    if beta_original.len() != self.p || beta_original.iter().any(|v| !v.is_finite()) {
        return;
    }
    self.warm_start_beta
        .write()
        .unwrap()
        .replace(Coefficients::new(beta_original.clone()));
    // Keep the IFT cache consistent with the polished beta. The read lock is
    // released before the write lock is taken below.
    if self.ift_warm_start_cache.read().unwrap().is_some() {
        let lambda_s_beta_blocks = {
            use rayon::prelude::*;
            let blocks: Vec<ndarray::Array1<f64>> = self
                .canonical_penalties
                .par_iter()
                .map(|cp| {
                    let r = &cp.col_range;
                    let beta_block = beta_original.slice(s![r.start..r.end]);
                    // Penalty acts on the deviation from the prior mean.
                    let centered = &beta_block - &cp.prior_mean;
                    cp.local.dot(&centered)
                })
                .collect();
            (!blocks.is_empty()).then_some(blocks)
        };
        if let Some(cache) = self.ift_warm_start_cache.write().unwrap().as_mut() {
            cache.beta_original = beta_original.clone();
            cache.lambda_s_beta_blocks = lambda_s_beta_blocks;
        }
    }
}
/// Reports whether the design is large enough (`n * p` above
/// `LARGE_N_EFS_THRESHOLD`) to take the large-n EFS single-loop lane.
#[inline]
fn large_n_efs_single_loop_lane(&self) -> bool {
    let rows = self.x.nrows() as f64;
    let cols = self.x.ncols() as f64;
    rows * cols > LARGE_N_EFS_THRESHOLD
}
/// Reports whether the currently stored outer/inner cap decodes to an EFS
/// single-loop cap.
#[inline]
fn efs_single_loop_cap_active(&self) -> bool {
    let encoded_cap = self.outer_inner_cap.load(Ordering::Relaxed);
    decode_efs_single_loop_cap(encoded_cap).is_some()
}
/// Tracks consecutive high-bias evaluations on the EFS single-loop route and
/// fails once the limit is reached.
///
/// Does nothing unless the single-loop cap is active. The shared guard is
/// re-bound, and its counter reset, whenever a different instance reports.
/// A `bias_proxy` at or above `EFS_SINGLE_LOOP_BIAS_THRESHOLD` increments the
/// counter; any lower value resets it.
///
/// # Errors
/// Returns `EstimationError::RemlOptimizationFailed`, tagged with
/// `EFS_FIRST_ORDER_FALLBACK_MARKER`, after
/// `EFS_SINGLE_LOOP_BIAS_CONSECUTIVE_LIMIT` consecutive high-bias
/// evaluations. The counter is reset before the error is returned.
fn record_efs_single_loop_bias(
    &self,
    rho: &Array1<f64>,
    diagnostics: super::unified::EfsSingleLoopDiagnostics,
) -> Result<(), EstimationError> {
    if !self.efs_single_loop_cap_active() {
        return Ok(());
    }

    let owner_key = self.hypergradient_owner_key();
    let mut guard = EFS_SINGLE_LOOP_BIAS_GUARD.lock().unwrap();
    if guard.owner != owner_key {
        guard.owner = owner_key;
        guard.consecutive = 0;
    }

    let biased = diagnostics.bias_proxy >= EFS_SINGLE_LOOP_BIAS_THRESHOLD;
    guard.consecutive = if biased {
        guard.consecutive.saturating_add(1)
    } else {
        0
    };

    log::info!(
        "[EFS-single-loop] bias_proxy={:.3e} gradient_residual={:.3e} inner_residual={:.3e} \
        |g|={:.3e} |step|inf={:.3e} consecutive={}/{} rho[..4]=[{}]",
        diagnostics.bias_proxy,
        diagnostics.gradient_residual,
        diagnostics.inner_residual,
        diagnostics.gradient_norm,
        diagnostics.step_inf_norm,
        guard.consecutive,
        EFS_SINGLE_LOOP_BIAS_CONSECUTIVE_LIMIT,
        rho.iter()
            .take(4)
            .map(|v| format!("{v:.3}"))
            .collect::<Vec<_>>()
            .join(","),
    );

    if guard.consecutive < EFS_SINGLE_LOOP_BIAS_CONSECUTIVE_LIMIT {
        return Ok(());
    }
    guard.consecutive = 0;
    Err(EstimationError::RemlOptimizationFailed(format!(
        "{} EFS single-loop bias guard fired: bias_proxy={:.3e} \
        threshold={:.3e} consecutive_limit={} rho_dim={}",
        crate::solver::outer_strategy::EFS_FIRST_ORDER_FALLBACK_MARKER,
        diagnostics.bias_proxy,
        EFS_SINGLE_LOOP_BIAS_THRESHOLD,
        EFS_SINGLE_LOOP_BIAS_CONSECUTIVE_LIMIT,
        rho.len(),
    )))
}
/// Decides whether the analytic outer Hessian should be used.
///
/// It is declined, returning `false`, when:
/// * the large-n EFS single-loop lane applies;
/// * there are more than 50 000 observations and the matrix-free outer
///   Hessian operator path is not preferred for this problem size;
/// * a robust Jeffreys link is configured but the model is not the canonical
///   binomial/logit case.
pub(crate) fn analytic_outer_hessian_enabled(&self) -> bool {
    let n_obs = self.x.nrows();
    let p_dim = self.x.ncols();

    if self.large_n_efs_single_loop_lane() {
        log::info!(
            "[EFS-single-loop] large-n lane engaged: n={} p={} n*p={:.3e} threshold={:.3e}; \
            declining analytic outer Hessian so the EFS fixed-point route runs first",
            n_obs,
            p_dim,
            (n_obs as f64) * (p_dim as f64),
            LARGE_N_EFS_THRESHOLD,
        );
        return false;
    }

    let k_outer = self.canonical_penalties.len();
    let operator_path_available =
        super::unified::prefer_outer_hessian_operator(n_obs, p_dim, k_outer);
    if n_obs > 50_000 && !operator_path_available {
        log::info!(
            "[standard-GAM] declining analytic outer Hessian for \
            n={n_obs} p={p_dim} k={k_outer} (matrix-free operator \
            path unavailable, dense LAML pairwise assembly is \
            O(k²·n·p²)); routing to BFGS"
        );
        return false;
    }

    let jeffreys_without_canonical_logit = reml_robust_jeffreys_link(&self.config).is_some()
        && !self.tk_correction_is_canonical_logit();
    !jeffreys_without_canonical_logit
}
/// Reports whether the model is binomial with the standard logit link and no
/// runtime mixture-link state.
fn tk_correction_is_canonical_logit(&self) -> bool {
    let spec = reml_spec(&self.config.likelihood);
    let binomial = matches!(spec.response, ResponseFamily::Binomial);
    let logit = matches!(spec.link, InverseLink::Standard(StandardLink::Logit));
    binomial && logit && self.runtime_mixture_link_state.is_none()
}
/// Returns the P-IRLS coefficients expressed in the original basis.
///
/// In the `OriginalSparseNative` frame the stored coefficients are copied
/// unchanged; in the `TransformedQs` frame they are multiplied by `qs`.
pub(super) fn sparse_exact_beta_original(&self, pirls_result: &PirlsResult) -> Array1<f64> {
    let stored_beta: &Array1<f64> = pirls_result.beta_transformed.as_ref();
    match pirls_result.coordinate_frame {
        pirls::PirlsCoordinateFrame::OriginalSparseNative => stored_beta.clone(),
        pirls::PirlsCoordinateFrame::TransformedQs => {
            pirls_result.reparam_result.qs.dot(stored_beta)
        }
    }
}
/// Expresses `matrix` in the original basis.
///
/// In the `OriginalSparseNative` frame the matrix is cloned unchanged; in the
/// `TransformedQs` frame the result is `qs · matrix · qsᵀ`.
fn bundle_matrix_in_original_basis(
    &self,
    pirls_result: &PirlsResult,
    matrix: &Array2<f64>,
) -> Array2<f64> {
    match pirls_result.coordinate_frame {
        pirls::PirlsCoordinateFrame::OriginalSparseNative => matrix.clone(),
        pirls::PirlsCoordinateFrame::TransformedQs => {
            let rotation = &pirls_result.reparam_result.qs;
            let left_rotated = crate::faer_ndarray::fast_ab(rotation, matrix);
            crate::faer_ndarray::fast_abt(&left_rotated, rotation)
        }
    }
}
/// Returns the ridge `delta` recorded in the current evaluation bundle, or
/// `None` when no bundle is cached.
pub(crate) fn last_ridge_used(&self) -> Option<f64> {
    let cached_bundle = self.cache_manager.current_eval_bundle.read().unwrap();
    cached_bundle
        .as_ref()
        .map(|bundle| bundle.ridge_passport.delta)
}
/// Computes the penalty rank together with `log|S|` and its derivatives with
/// respect to `rho`.
///
/// Route selection:
/// * If a Kronecker penalty system is present, factored, and has one penalty
///   per `rho` entry, it supplies value, rank and both derivatives.
/// * Otherwise the value and rank come from the penalty subspace (the one
///   supplied, or one computed from `e_for_logdet`), and the derivatives
///   come from the block-local canonical penalties when their count matches
///   `rho`, else from `penalty_roots` when non-empty, else they are zero.
///
/// The second derivative is returned only in
/// `EvalMode::ValueGradientHessian`.
///
/// # Errors
/// Propagates failures from the subspace and derivative helpers.
fn dense_penalty_logdet_derivs(
    &self,
    rho: &Array1<f64>,
    e_for_logdet: &Array2<f64>,
    penalty_roots: &[Array2<f64>],
    ridge_passport: RidgePassport,
    penalty_subspace: Option<&PenaltySubspace>,
    mode: super::unified::EvalMode,
) -> Result<(usize, super::unified::PenaltyLogdetDerivs), EstimationError> {
    let stage_timer = std::time::Instant::now();
    let lambdas = rho.mapv(f64::exp);
    let ridge = ridge_passport.penalty_logdet_ridge();

    let kron_logdet = match self.kronecker_penalty_system.as_ref() {
        Some(kron)
            if self.kronecker_factored.is_some() && kron.num_penalties() == rho.len() =>
        {
            Some(kron.logdet_rank_and_derivatives(lambdas.as_slice().unwrap(), ridge))
        }
        _ => None,
    };

    let (penalty_rank, log_det_s) = match kron_logdet.as_ref() {
        Some((logdet, rank, _, _)) => (*rank, *logdet),
        None => {
            // Declared up front so a freshly computed subspace outlives the
            // borrow handed to the helper below.
            let computed_subspace;
            let subspace = match penalty_subspace {
                Some(provided) => provided,
                None => {
                    computed_subspace =
                        self.compute_penalty_subspace(e_for_logdet, ridge_passport)?;
                    &computed_subspace
                }
            };
            self.fixed_subspace_penalty_rank_and_logdet_from_subspace(subspace)
        }
    };
    log::info!(
        "[STAGE] logdet S rho_dim={} penalty_rank={} elapsed={:.3}s",
        rho.len(),
        penalty_rank,
        stage_timer.elapsed().as_secs_f64(),
    );

    let (det1, det2_full) = match kron_logdet {
        Some((_, _, first, second)) => (first, second),
        None if !self.canonical_penalties.is_empty()
            && self.canonical_penalties.len() == rho.len() =>
        {
            self.structural_penalty_logdet_derivatives_block_local(&lambdas, ridge)?
        }
        None if !penalty_roots.is_empty() => {
            self.structural_penalty_logdet_derivatives(penalty_roots, &lambdas, ridge)?
        }
        None => (
            Array1::zeros(rho.len()),
            Array2::zeros((rho.len(), rho.len())),
        ),
    };

    let wants_hessian = mode == super::unified::EvalMode::ValueGradientHessian;
    let derivs = super::unified::PenaltyLogdetDerivs {
        value: log_det_s,
        first: det1,
        second: wants_hessian.then_some(det2_full),
    };
    Ok((penalty_rank, derivs))
}
/// Builds the intermediates shared by the Tierney-Kadane scalar and gradient
/// routines.
///
/// * `h_diag[i] = x_dense.row(i) · z.column(i)`, computed in parallel;
/// * `x_m = x_denseᵀ (c_array ⊙ h_diag)`;
/// * `y = h_inv_solve(x_m)`;
/// * `active_blocks`: the nonzero structure of `c_array`.
///
/// `context` prefixes the error message for a non-finite leverage.
///
/// # Errors
/// Fails when any `h_diag` entry is non-finite, or when `h_inv_solve` fails.
fn tk_shared_intermediates<S>(
    x_dense: &Array2<f64>,
    z: &Array2<f64>,
    c_array: &Array1<f64>,
    context: &str,
    h_inv_solve: &S,
) -> Result<TkSharedIntermediates, EstimationError>
where
    S: Fn(&Array1<f64>) -> Result<Array1<f64>, EstimationError>,
{
    use rayon::prelude::*;

    let active_blocks = Self::tk_active_blocks(c_array);

    let row_count = x_dense.nrows();
    let leverages = (0..row_count)
        .into_par_iter()
        .map(|i| {
            let val = x_dense.row(i).dot(&z.column(i));
            if !val.is_finite() {
                crate::bail_invalid_estim!(
                    "{context} produced non-finite leverage at row {i}: {val}"
                );
            }
            Ok(val)
        })
        .collect::<Result<Vec<f64>, _>>()?;
    let h_diag = Array1::from(leverages);

    let weighted_leverage = c_array * &h_diag;
    let x_m = crate::faer_ndarray::fast_atv(x_dense, &weighted_leverage);
    let y = h_inv_solve(&x_m)?;

    Ok(TkSharedIntermediates {
        h_diag,
        x_m,
        y,
        active_blocks,
    })
}
/// Evaluates the scalar Tierney-Kadane correction from shared intermediates.
///
/// The value is the sum of three terms:
/// * `-1/8 · Σ d_i h_i²`;
/// * `1/12 · Σ_{i,j} c_i c_j K_ij³`, accumulated over pairs of active blocks
///   with `K = x_dense · z`; off-diagonal block pairs are counted twice;
/// * `1/8 · x_m · y`.
///
/// `gram` is scratch space, overwritten for every block pair.
///
/// # Errors
/// Fails when the resulting value is not finite.
fn tk_scalar_from_shared(
    x_dense: &Array2<f64>,
    z: &Array2<f64>,
    d_array: &Array1<f64>,
    shared: &TkSharedIntermediates,
    gram: &mut Array2<f64>,
) -> Result<f64, EstimationError> {
    let weighted_sq_leverage = d_array
        .iter()
        .zip(shared.h_diag.iter())
        .map(|(&d_i, &h_i)| d_i * h_i * h_i)
        .sum::<f64>();
    let q_term = -0.125 * weighted_sq_leverage;
    let t2_term = 0.125 * shared.x_m.dot(&shared.y);

    let blocks = &shared.active_blocks;
    let mut t1_sum = 0.0_f64;
    for (col_idx, col_block) in blocks.iter().enumerate() {
        // Visit only the lower triangle of block pairs, diagonal included.
        for row_block in blocks.iter().take(col_idx + 1) {
            Self::tk_fill_gram_block(
                x_dense,
                z,
                row_block.start,
                row_block.end,
                col_block.start,
                col_block.end,
                gram,
            );
            let mut block_sum = 0.0_f64;
            for &(bi, ci) in &row_block.entries {
                for &(bj, cj) in &col_block.entries {
                    let kij = gram[[bi, bj]];
                    block_sum += ci * cj * kij * kij * kij;
                }
            }
            t1_sum += if row_block.start == col_block.start {
                block_sum
            } else {
                2.0 * block_sum
            };
        }
    }

    let value = q_term + t1_sum / 12.0 + t2_term;
    if !value.is_finite() {
        crate::bail_invalid_estim!(
            "Tierney-Kadane correction produced non-finite value: {value}"
        );
    }
    Ok(value)
}
/// Splits `c_array` into consecutive blocks of at most `TK_BLOCK_SIZE` rows
/// and keeps, per block, only the nonzero entries as `(offset, value)` pairs.
/// Blocks with no nonzero entry are omitted.
fn tk_active_blocks(c_array: &Array1<f64>) -> Vec<TkActiveBlock> {
    let n = c_array.len();
    let mut blocks = Vec::with_capacity(n.div_ceil(TK_BLOCK_SIZE));

    let mut start = 0;
    while start < n {
        let end = (start + TK_BLOCK_SIZE).min(n);
        let entries: Vec<(usize, f64)> = (start..end)
            .filter_map(|row| {
                let value = c_array[row];
                (value != 0.0).then_some((row - start, value))
            })
            .collect();
        if !entries.is_empty() {
            blocks.push(TkActiveBlock {
                start,
                end,
                entries,
            });
        }
        start = end;
    }
    blocks
}
/// Sums `c_i · x_vk[i] · lev_p[i]` over every nonzero entry `c_i` recorded in
/// `active_blocks`, where `i` is the absolute row index.
fn tk_active_weighted_trace(
    active_blocks: &[TkActiveBlock],
    x_vk: &Array1<f64>,
    lev_p: &Array1<f64>,
) -> f64 {
    active_blocks
        .iter()
        .flat_map(|block| {
            block
                .entries
                .iter()
                .map(move |&(offset, c_i)| (block.start + offset, c_i))
        })
        // Sequential fold from 0.0 keeps the block-by-block accumulation order.
        .fold(0.0, |acc, (i, c_i)| acc + c_i * x_vk[i] * lev_p[i])
}
/// Writes `x_dense[i0..i1, ..] · z[.., j0..j1]` into the top-left corner of
/// `gram`.
///
/// # Panics
/// Panics when the requested block does not fit inside `gram`.
fn tk_fill_gram_block(
    x_dense: &Array2<f64>,
    z: &Array2<f64>,
    i0: usize,
    i1: usize,
    j0: usize,
    j1: usize,
    gram: &mut Array2<f64>,
) {
    let rows = i1 - i0;
    let cols = j1 - j0;
    assert!(rows <= gram.nrows());
    assert!(cols <= gram.ncols());

    let lhs_rows = x_dense.slice(s![i0..i1, ..]);
    let rhs_cols = z.slice(s![.., j0..j1]);
    let mut destination = gram.slice_mut(s![..rows, ..cols]);
    // beta = 0.0: the destination is overwritten, not accumulated into.
    ndarray::linalg::general_mat_mul(1.0, &lhs_rows, &rhs_cols, 0.0, &mut destination);
}
/// Fills `gram[[bi, bj]]` with `x_dense.row(i) · z.column(j)` only for the
/// active entry pairs of the two blocks, using a plain scalar dot product.
/// Other cells of `gram` are left untouched.
fn tk_fill_gram_block_entries_scalar(
    x_dense: &Array2<f64>,
    z: &Array2<f64>,
    i_block: &TkActiveBlock,
    j_block: &TkActiveBlock,
    gram: &mut Array2<f64>,
) {
    let width = x_dense.ncols();
    for &(row_offset, _) in &i_block.entries {
        let row = i_block.start + row_offset;
        for &(col_offset, _) in &j_block.entries {
            let column = j_block.start + col_offset;
            let dot = (0..width)
                .map(|col| x_dense[[row, col]] * z[[col, column]])
                .sum::<f64>();
            gram[[row_offset, col_offset]] = dot;
        }
    }
}
/// Assembles the gradient of the Tierney-Kadane correction over
/// `tk_penalties.len() + ext_drifts.len()` coordinates.
///
/// Visible structure of the computation:
/// 1. Validate that `x_vks` and `beta_dirs` each have one entry per
///    coordinate.
/// 2. Build the `p × p` matrix `p_total` from three contributions: a weighted
///    sum of outer products of the columns of `z`, a rank-one term in
///    `shared.y`, and a sum over pairs of active entries.
/// 3. Compute `lev_p[i] = x_iᵀ · p_total · x_i` for every row of `x_dense`.
/// 4. For each penalty coordinate (sequential) and each external coordinate
///    (parallel), combine a trace term, an active weighted trace, a Firth
///    trace and a direct term into one gradient entry.
/// 5. Reject any non-finite gradient entry.
///
/// `gram` is scratch space for the penalty coordinates; each external
/// coordinate allocates its own scratch matrix because those run in parallel.
///
/// NOTE(review): `z` is indexed as `z[[coefficient, observation]]`, i.e.
/// `p × n`; it is presumably the inverse Hessian applied to `x_denseᵀ` —
/// confirm at the call site.
///
/// # Errors
/// Fails on an arity mismatch, a shape mismatch of an external drift, fixed
/// eta or fixed design, a non-finite gradient entry, or any error propagated
/// from the Firth and direct-gradient helpers.
fn tk_gradient_from_shared(
x_dense: &Array2<f64>,
z: &Array2<f64>,
c_array: &Array1<f64>,
d_array: &Array1<f64>,
e_array: &Array1<f64>,
tk_penalties: &[crate::construction::CanonicalPenalty],
lambdas: &[f64],
ext_drifts: &[Array2<f64>],
ext_eta_fixed: &[Option<Array1<f64>>],
ext_x_fixed: &[Option<Array2<f64>>],
x_vks: &[Array1<f64>],
beta_dirs: &[Array1<f64>],
firth_op: Option<&super::FirthDenseOperator>,
shared: &TkSharedIntermediates,
gram: &mut Array2<f64>,
) -> Result<Array1<f64>, EstimationError> {
let n = x_dense.nrows();
let p = x_dense.ncols();
let k = tk_penalties.len();
// Penalty coordinates come first, external coordinates after them.
let total_k = k + ext_drifts.len();
if x_vks.len() != total_k {
crate::bail_invalid_estim!(
"Tierney-Kadane correction internal gradient arity mismatch: {} response modes for {} coordinates",
x_vks.len(),
total_k
);
}
if beta_dirs.len() != total_k {
crate::bail_invalid_estim!(
"Tierney-Kadane correction internal beta-direction arity mismatch: {} beta directions for {} coordinates",
beta_dirs.len(),
total_k
);
}
// Per-row weight: diag_combined[i] = d_i * h_i - c_i * (x_dense · y)_i.
let x_y = crate::faer_ndarray::fast_av(x_dense, &shared.y);
let mut diag_combined = Array1::<f64>::zeros(n);
ndarray::Zip::from(&mut diag_combined)
.and(d_array)
.and(&shared.h_diag)
.and(c_array)
.and(&x_y)
.par_for_each(|o, &d, &h, &c, &xy| *o = d * h - c * xy);
// Rows per parallel chunk: n divided by (threads * oversubscription),
// clamped to [TK_BLOCK_SIZE, TK_CHUNK_MAX_ROWS].
let chunk_len = (n
/ (rayon::current_num_threads()
.saturating_mul(TK_CHUNK_OVERSUBSCRIBE)
.max(1)))
.clamp(TK_BLOCK_SIZE, TK_CHUNK_MAX_ROWS);
let chunks = n.div_ceil(chunk_len);
// First contribution: Σ_i diag_combined[i] · z_i z_iᵀ, accumulated per chunk
// into thread-local p × p matrices and then reduced. Only the upper triangle
// is computed; the lower triangle is mirrored.
let mut p_total = (0..chunks)
.into_par_iter()
.fold(
|| Array2::<f64>::zeros((p, p)),
|mut local, chunk_idx| {
let i0 = chunk_idx * chunk_len;
let i1 = (i0 + chunk_len).min(n);
for i in i0..i1 {
let wi = diag_combined[i];
// Rows with zero weight contribute nothing.
if wi == 0.0 {
continue;
}
for a in 0..p {
let wa = wi * z[[a, i]];
for b in a..p {
let val = wa * z[[b, i]];
local[[a, b]] += val;
if a != b {
local[[b, a]] += val;
}
}
}
}
local
},
)
.reduce(
|| Array2::<f64>::zeros((p, p)),
|mut left, right| {
left += &right;
left
},
);
p_total.mapv_inplace(|v| 0.25 * v);
// Second contribution: -1/8 · y yᵀ.
for a in 0..p {
for b in 0..p {
p_total[[a, b]] -= 0.125 * shared.y[a] * shared.y[b];
}
}
// Third contribution: pairs of active blocks, lower triangle including the
// diagonal; off-diagonal block pairs are counted twice via `sym_factor`.
let active_pairs: Vec<(usize, usize)> = (0..shared.active_blocks.len())
.flat_map(|j_block_idx| {
(0..=j_block_idx).map(move |i_block_idx| (i_block_idx, j_block_idx))
})
.collect();
let active_total = active_pairs
.par_iter()
.fold(
|| Array2::<f64>::zeros((p, p)),
|mut local, &(i_block_idx, j_block_idx)| {
let i_block = &shared.active_blocks[i_block_idx];
let j_block = &shared.active_blocks[j_block_idx];
let sym_factor = if i_block_idx == j_block_idx { 1.0 } else { 2.0 };
for &(bi, ci) in &i_block.entries {
let ii = i_block.start + bi;
for &(bj, cj) in &j_block.entries {
let jj = j_block.start + bj;
// gij = x_dense.row(ii) · z.column(jj), as a scalar dot product.
let gij = (0..p)
.map(|col| x_dense[[ii, col]] * z[[col, jj]])
.sum::<f64>();
let weight = ci * cj * gij * gij;
let scale = -0.25 * weight * sym_factor;
if ii == jj {
// Same observation: add scale · z_ii z_iiᵀ via the upper triangle.
for a in 0..p {
let za = z[[a, ii]];
for b in a..p {
let val = scale * za * z[[b, ii]];
local[[a, b]] += val;
if a != b {
local[[b, a]] += val;
}
}
}
} else {
// Distinct observations: add the symmetrised outer product
// scale/2 · (z_ii z_jjᵀ + z_jj z_iiᵀ).
let half_scale = 0.5 * scale;
for a in 0..p {
let z_ii_a = z[[a, ii]];
let z_jj_a = z[[a, jj]];
for b in 0..p {
local[[a, b]] += half_scale
* (z_ii_a * z[[b, jj]] + z_jj_a * z[[b, ii]]);
}
}
}
}
}
local
},
)
.reduce(
|| Array2::<f64>::zeros((p, p)),
|mut left, right| {
left += &right;
left
},
);
p_total += &active_total;
// lev_p[i] = (x_dense · p_total).row(i) · x_dense.row(i).
let xp = crate::faer_ndarray::fast_ab(x_dense, &p_total);
let mut lev_p = Array1::<f64>::zeros(n);
ndarray::Zip::from(&mut lev_p)
.and(xp.rows())
.and(x_dense.rows())
.par_for_each(|o, xp_row, x_row| *o = xp_row.dot(&x_row));
let mut gradient = Array1::<f64>::zeros(total_k);
// Penalty coordinates: sequential, because they share the `gram` scratch.
for idx in 0..k {
let cp = &tk_penalties[idx];
let r = &cp.col_range;
// trace_ak_p = λ · Σ_row (root · P_block).row(row) · root.row(row), with
// P_block the sub-block of p_total on the penalty's column range.
let p_block = p_total.slice(s![r.start..r.end, r.start..r.end]);
let rk_p = cp.root.dot(&p_block);
let trace_ak_p = lambdas[idx]
* (0..cp.rank())
.map(|row| rk_p.row(row).dot(&cp.root.row(row)))
.sum::<f64>();
let correction_trace =
Self::tk_active_weighted_trace(&shared.active_blocks, &x_vks[idx], &lev_p);
let firth_trace =
Self::tk_firth_beta_hessian_trace(firth_op, &beta_dirs[idx], &p_total)?;
// The direct term receives the negated response mode as `eta_total`.
let eta_total = x_vks[idx].mapv(|value| -value);
let direct = Self::tk_direct_gradient_from_cd_and_design(
x_dense, z, c_array, d_array, e_array, &eta_total, None, shared, gram, true,
)?;
gradient[idx] = trace_ak_p - correction_trace + firth_trace + direct;
}
// External coordinates: evaluated in parallel, each with its own scratch.
let ext_values = (0..ext_drifts.len())
.into_par_iter()
.map(|extra_idx| -> Result<(usize, f64), EstimationError> {
let drift = &ext_drifts[extra_idx];
if drift.raw_dim() != p_total.raw_dim() {
crate::bail_invalid_estim!(
"Tierney-Kadane ext penalty drift shape mismatch: expected {}x{}, got {}x{}",
p,
p,
drift.nrows(),
drift.ncols()
);
}
// trace(drift · p_total), written as an explicit double sum.
let mut trace_ak_p = 0.0;
for row in 0..p {
for col in 0..p {
trace_ak_p += drift[[row, col]] * p_total[[col, row]];
}
}
let x_vk_idx = k + extra_idx;
let correction_trace =
Self::tk_active_weighted_trace(&shared.active_blocks, &x_vks[x_vk_idx], &lev_p);
let firth_trace =
Self::tk_firth_beta_hessian_trace(firth_op, &beta_dirs[x_vk_idx], &p_total)?;
// Negated response mode, plus the optional fixed eta for this coordinate.
let mut eta_total = x_vks[x_vk_idx].mapv(|value| -value);
if let Some(eta_fixed) = ext_eta_fixed
.get(extra_idx)
.and_then(|value| value.as_ref())
{
if eta_fixed.len() != n {
crate::bail_invalid_estim!(
"Tierney-Kadane ext fixed eta length mismatch: expected {}, got {}",
n,
eta_fixed.len()
);
}
eta_total += eta_fixed;
}
// Optional fixed design for this coordinate; it must match the shape of
// `x_dense`.
let x_fixed = ext_x_fixed.get(extra_idx).and_then(|value| value.as_ref());
if let Some(x_fixed) = x_fixed
&& x_fixed.raw_dim() != x_dense.raw_dim() {
crate::bail_invalid_estim!(
"Tierney-Kadane ext fixed design shape mismatch: expected {}x{}, got {}x{}",
x_dense.nrows(),
x_dense.ncols(),
x_fixed.nrows(),
x_fixed.ncols()
);
}
let mut local_gram = Array2::<f64>::zeros((TK_BLOCK_SIZE, TK_BLOCK_SIZE));
let direct = Self::tk_direct_gradient_from_cd_and_design(
x_dense,
z,
c_array,
d_array,
e_array,
&eta_total,
x_fixed,
shared,
&mut local_gram,
false,
)?;
Ok((x_vk_idx, trace_ak_p - correction_trace + firth_trace + direct))
})
.collect::<Result<Vec<_>, _>>()?;
for (idx, value) in ext_values {
gradient[idx] = value;
}
// Final guard: every entry must be finite.
for g in gradient.iter_mut() {
if !g.is_finite() {
crate::bail_invalid_estim!(
"Tierney-Kadane correction produced a non-finite gradient entry"
);
}
}
Ok(gradient)
}
/// Returns `-trace(Hφ(dir) · p_total)` for the Firth operator, or `0.0` when
/// no Firth operator is supplied.
///
/// `beta_dir` is mapped to a linear-predictor direction through the
/// operator's own dense design, then to the Hessian-derivative matrix via
/// `hphi_direction`.
///
/// # Errors
/// Fails when `beta_dir` does not have one entry per row of `p_total`, or
/// when the Hessian-derivative matrix differs in shape from `p_total`.
fn tk_firth_beta_hessian_trace(
    firth_op: Option<&super::FirthDenseOperator>,
    beta_dir: &Array1<f64>,
    p_total: &Array2<f64>,
) -> Result<f64, EstimationError> {
    let firth_op = match firth_op {
        Some(op) => op,
        None => return Ok(0.0),
    };
    if beta_dir.len() != p_total.nrows() {
        crate::bail_invalid_estim!(
            "Tierney-Kadane Firth beta-direction length mismatch: expected {}, got {}",
            p_total.nrows(),
            beta_dir.len()
        );
    }

    let eta_direction = crate::faer_ndarray::fast_av(&firth_op.x_dense, beta_dir);
    let direction = firth_op.direction_from_deta(eta_direction);
    let hphi = firth_op.hphi_direction(&direction);
    if hphi.raw_dim() != p_total.raw_dim() {
        crate::bail_invalid_estim!(
            "Tierney-Kadane Firth Hessian derivative shape mismatch: expected {}x{}, got {}x{}",
            p_total.nrows(),
            p_total.ncols(),
            hphi.nrows(),
            hphi.ncols()
        );
    }

    // Row-major walk over hphi, subtracting each product in turn.
    let mut trace = 0.0;
    for ((row, col), &h_entry) in hphi.indexed_iter() {
        trace -= h_entry * p_total[[col, row]];
    }
    Ok(trace)
}
/// Direct part of one Tierney-Kadane gradient entry for a given linear-predictor derivative.
///
/// From `eta_total` it forms `c_prime = d_array * eta_total` and
/// `d_prime = e_array * eta_total` element-wise and returns
/// `d_term_prime + c_term_prime + q_term_prime` as computed below. When `x_fixed` is
/// supplied it also contributes the design-derivative pieces (`h_prime`, `design_q_prime`
/// and the per-pair `k_direct`).
///
/// * `z` is indexed as `z[[col, row]]` and its columns are zipped against the `n` rows of
///   the design, i.e. it carries one column per observation.
/// * `gram` is caller-provided scratch: every block pair is written into it by
///   `tk_fill_gram_block` / `tk_fill_gram_block_entries_scalar` and read back as
///   `gram[[bi, bj]]`.
/// * `use_dense_kernels` selects the dense fill helper and matrix products for the design
///   derivative; otherwise the scalar fill helper and a per-pair dot product are used.
///
/// # Errors
/// Fails when `eta_total` or `e_array` does not have `x_dense.nrows()` entries, or when the
/// resulting value is not finite.
///
/// NOTE(review): the lengths of `c_array` and `d_array` are not validated here; the
/// element-wise `Zip`s below presumably panic on a mismatch — confirm callers guarantee it.
fn tk_direct_gradient_from_cd_and_design(
x_dense: &Array2<f64>,
z: &Array2<f64>,
c_array: &Array1<f64>,
d_array: &Array1<f64>,
e_array: &Array1<f64>,
eta_total: &Array1<f64>,
x_fixed: Option<&Array2<f64>>,
shared: &TkSharedIntermediates,
gram: &mut Array2<f64>,
use_dense_kernels: bool,
) -> Result<f64, EstimationError> {
let n = x_dense.nrows();
if eta_total.len() != n || e_array.len() != n {
crate::bail_invalid_estim!(
"Tierney-Kadane direct derivative length mismatch: n={}, eta={}, e={}",
n,
eta_total.len(),
e_array.len()
);
}
// Element-wise derivatives of the per-row arrays along `eta_total`:
// c' = d * eta', d' = e * eta'.
let mut c_prime = Array1::<f64>::zeros(n);
let mut d_prime = Array1::<f64>::zeros(n);
ndarray::Zip::from(&mut c_prime)
.and(d_array)
.and(eta_total)
.par_for_each(|out, &d, &eta| *out = d * eta);
ndarray::Zip::from(&mut d_prime)
.and(e_array)
.and(eta_total)
.par_for_each(|out, &e, &eta| *out = e * eta);
// Both design-derivative accumulators stay zero unless `x_fixed` is present.
let mut h_prime = Array1::<f64>::zeros(n);
let mut design_q_prime = Array1::<f64>::zeros(x_dense.ncols());
let has_design_deriv = x_fixed.is_some();
if let Some(x_theta) = x_fixed {
let ch = c_array * &shared.h_diag;
design_q_prime += &crate::faer_ndarray::fast_atv(x_theta, &ch);
// h_prime[i] = 2 * <row i of x_theta, column i of z>.
ndarray::Zip::from(&mut h_prime)
.and(x_theta.rows())
.and(z.columns())
.par_for_each(|o, xr, zc| *o = 2.0 * xr.dot(&zc));
}
// Product rule on the weights c * h_diag, mapped through the design and combined with the
// design-derivative part, then contracted with `shared.y`.
let q_weight_prime = &(&c_prime * &shared.h_diag) + &(c_array * &h_prime);
let q_prime = crate::faer_ndarray::fast_atv(x_dense, &q_weight_prime) + design_q_prime;
let q_term_prime = 0.25 * q_prime.dot(&shared.y);
// -1/8 * sum(d' * h^2) - 1/4 * sum(d * h * h').
let d_term_prime = -0.125
* d_prime
.iter()
.zip(shared.h_diag.iter())
.map(|(&dp, &h)| dp * h * h)
.sum::<f64>()
- 0.25
* d_array
.iter()
.zip(shared.h_diag.iter())
.zip(h_prime.iter())
.map(|((&d, &h), &hp)| d * h * hp)
.sum::<f64>();
// Rows with both c == 0 and c' == 0 are dropped from the pairwise sum.
let direct_blocks = Self::tk_cd_direct_active_blocks(c_array, &c_prime);
let mut c_term_prime = 0.0_f64;
let mut block_scratch = Array2::<f64>::zeros((TK_BLOCK_SIZE, TK_BLOCK_SIZE));
let mut reverse_scratch = Array2::<f64>::zeros((TK_BLOCK_SIZE, TK_BLOCK_SIZE));
// Only block pairs with i_block at or before j_block are visited; off-diagonal pairs are
// counted twice through `sym_factor`.
for (j_block_idx, j_block) in direct_blocks.iter().enumerate() {
let j0 = j_block.start;
let j1 = j_block.end;
for i_block in &direct_blocks[..=j_block_idx] {
let i0 = i_block.start;
let i1 = i_block.end;
if use_dense_kernels {
Self::tk_fill_gram_block(x_dense, z, i0, i1, j0, j1, gram);
} else {
Self::tk_fill_gram_block_entries_scalar(x_dense, z, i_block, j_block, gram);
}
// Dense path: precompute x_theta[i-rows] * z[:, j-cols] and x_theta[j-rows] * z[:, i-cols]
// for the whole block pair so the inner loop only reads scratch entries.
let design_gram_active = if has_design_deriv && use_dense_kernels {
let x_theta = x_fixed.expect("design derivative checked above");
let rows = i1 - i0;
let cols = j1 - j0;
let mut block = block_scratch.slice_mut(s![..rows, ..cols]);
let x_theta_i = x_theta.slice(s![i0..i1, ..]);
let z_j = z.slice(s![.., j0..j1]);
ndarray::linalg::general_mat_mul(1.0, &x_theta_i, &z_j, 0.0, &mut block);
let mut reverse = reverse_scratch.slice_mut(s![..cols, ..rows]);
let x_theta_j = x_theta.slice(s![j0..j1, ..]);
let z_i = z.slice(s![.., i0..i1]);
ndarray::linalg::general_mat_mul(1.0, &x_theta_j, &z_i, 0.0, &mut reverse);
true
} else {
false
};
let mut block_sum = 0.0_f64;
for &(bi, _) in &i_block.entries {
let ii = i0 + bi;
let ci = c_array[ii];
let cpi = c_prime[ii];
for &(bj, _) in &j_block.entries {
let jj = j0 + bj;
let cj = c_array[jj];
let gij = gram[[bi, bj]];
let cpj = c_prime[jj];
// Derivative of c_i * c_j * g^3 / 12 with respect to the c factors only.
let c_direct = (cpi * cj + ci * cpj) * gij * gij * gij / 12.0;
// Derivative of the same term through the gram entry itself (g^3 / 12 -> g^2 * g' / 4);
// `kp` is the symmetric design-derivative of the gram entry.
let k_direct = if design_gram_active {
let kp = block_scratch[[bi, bj]] + reverse_scratch[[bj, bi]];
0.25 * ci * cj * gij * gij * kp
} else if let Some(x_theta) = x_fixed {
let kp = (0..x_dense.ncols())
.map(|col| {
x_theta[[ii, col]] * z[[col, jj]]
+ x_theta[[jj, col]] * z[[col, ii]]
})
.sum::<f64>();
0.25 * ci * cj * gij * gij * kp
} else {
0.0
};
block_sum += c_direct + k_direct;
}
}
let sym_factor = if i0 == j0 { 1.0 } else { 2.0 };
c_term_prime += sym_factor * block_sum;
}
}
let value = d_term_prime + c_term_prime + q_term_prime;
if !value.is_finite() {
crate::bail_invalid_estim!(
"Tierney-Kadane direct c/d derivative produced non-finite value: {value}"
);
}
Ok(value)
}
/// Partitions the observation index range into `TK_BLOCK_SIZE`-wide blocks and keeps, per
/// block, the offsets of rows where `c_array` or `c_prime` is non-zero.
///
/// Each kept row is recorded as `(offset_within_block, 0.0)`. Blocks without any such row
/// are omitted from the result.
fn tk_cd_direct_active_blocks(
    c_array: &Array1<f64>,
    c_prime: &Array1<f64>,
) -> Vec<TkActiveBlock> {
    let total = c_array.len();
    let mut active = Vec::with_capacity(total.div_ceil(TK_BLOCK_SIZE));
    let mut start = 0;
    while start < total {
        let end = total.min(start + TK_BLOCK_SIZE);
        let entries: Vec<_> = (start..end)
            .filter(|&idx| c_array[idx] != 0.0 || c_prime[idx] != 0.0)
            .map(|idx| (idx - start, 0.0))
            .collect();
        if !entries.is_empty() {
            active.push(TkActiveBlock {
                start,
                end,
                entries,
            });
        }
        start = end;
    }
    active
}
/// Expands one canonical penalty into a dense `p x p` matrix.
///
/// The block at rows/columns `cp.col_range.start ..` is filled with
/// `lambda * root^T root` (summed over the `cp.rank()` root rows); everything else is zero.
fn tk_penalty_dense(
    cp: &crate::construction::CanonicalPenalty,
    lambda: f64,
    p: usize,
) -> Array2<f64> {
    let offset = cp.col_range.start;
    let block_dim = cp.block_dim();
    let rank = cp.rank();
    let mut dense = Array2::<f64>::zeros((p, p));
    for a in 0..block_dim {
        for b in 0..block_dim {
            // Entry (a, b) of root^T root, accumulated over root rows in ascending order.
            let gram_ab = (0..rank).fold(0.0_f64, |acc, row| {
                acc + cp.root[[row, a]] * cp.root[[row, b]]
            });
            dense[[offset + a, offset + b]] = lambda * gram_ab;
        }
    }
    dense
}
/// Convenience wrapper around `xt_diag_x_dense_into` that supplies its own zeroed scratch
/// buffer (same shape as `x_dense`) and returns that helper's result.
///
/// NOTE(review): presumably the result is `X^T diag(diag) X` — confirm against the helper.
fn tk_xt_diag_x(x_dense: &Array2<f64>, diag: &Array1<f64>) -> Array2<f64> {
    let mut scratch = Array2::<f64>::zeros(x_dense.raw_dim());
    Self::xt_diag_x_dense_into(x_dense, diag, &mut scratch)
}
/// Analytic `k x k` Hessian of the Tierney-Kadane correction with respect to the `k`
/// smoothing coordinates (one per entry of `tk_penalties`).
///
/// The routine materialises `K = H^{-1}` through `h_inv_solve`, builds first- and
/// second-order mode responses and Hessian derivatives per coordinate (pair), and then
/// propagates value/gradient/Hessian "jets" through three sums:
/// `-1/8 * sum_r d_r h_r^2`, `1/12 * sum_{i,j} c_i c_j K_ij^3` and `1/8 * q^T K q` with
/// `q = X^T (c * h)`. Only the Hessian part of the accumulated jet is returned.
///
/// Returns an empty `0 x 0` matrix when there are no penalties.
///
/// # Errors
/// Fails when the `c/d/e/f` arrays do not all have `x_dense.nrows()` entries, when
/// `h_inv_solve` fails, or when any Hessian entry is non-finite.
///
/// NOTE(review): `lambdas` is indexed by penalty index without a length check.
/// NOTE(review): the pairwise term loops over all `n^2` row pairs with `k^2` dense
/// quadratic forms each; callers presumably gate problem size before reaching this.
fn tk_hessian_rho_canonical_logit<S>(
x_dense: &Array2<f64>,
c_array: &Array1<f64>,
d_array: &Array1<f64>,
e_array: &Array1<f64>,
f_array: &Array1<f64>,
tk_penalties: &[crate::construction::CanonicalPenalty],
lambdas: &[f64],
beta: &Array1<f64>,
firth_op: Option<&super::FirthDenseOperator>,
h_inv_solve: &S,
) -> Result<Array2<f64>, EstimationError>
where
S: Fn(&Array1<f64>) -> Result<Array1<f64>, EstimationError>,
{
let n = x_dense.nrows();
let p = x_dense.ncols();
let k = tk_penalties.len();
if k == 0 {
return Ok(Array2::zeros((0, 0)));
}
if c_array.len() != n || d_array.len() != n || e_array.len() != n || f_array.len() != n {
crate::bail_invalid_estim!(
"Tierney-Kadane Hessian derivative arrays have inconsistent lengths"
);
}
// K = H^{-1}, assembled one unit-vector solve at a time and then symmetrised.
let mut k_mat = Array2::<f64>::zeros((p, p));
for col in 0..p {
let mut rhs = Array1::<f64>::zeros(p);
rhs[col] = 1.0;
let sol = h_inv_solve(&rhs)?;
k_mat.column_mut(col).assign(&sol);
}
enforce_symmetry(&mut k_mat);
// Per coordinate: dense penalty A_i, v_i = H^{-1} A_i beta, and the linear-predictor
// derivative eta_i = X * (-v_i).
// NOTE(review): `a.dot(beta)` uses the uncentred `beta`, whereas
// `tierney_kadane_analytic_core` subtracts `cp.prior_mean` first — confirm this is intended.
let mut a_mats = Vec::with_capacity(k);
let mut v = Vec::with_capacity(k);
let mut eta_i = Vec::with_capacity(k);
for idx in 0..k {
let a = Self::tk_penalty_dense(&tk_penalties[idx], lambdas[idx], p);
let rhs = a.dot(beta);
let vi = h_inv_solve(&rhs)?;
let bi = vi.mapv(|value| -value);
let ei = crate::faer_ndarray::fast_av(x_dense, &bi);
a_mats.push(a);
v.push(vi);
eta_i.push(ei);
}
// First-order Hessian derivatives: H_i = A_i + tk_xt_diag_x(X, c * eta_i), minus the Firth
// directional term when an operator is present.
let mut h_i = Vec::with_capacity(k);
for idx in 0..k {
let diag = c_array * &eta_i[idx];
let mut h = &a_mats[idx] + &Self::tk_xt_diag_x(x_dense, &diag);
if let Some(op) = firth_op {
let dir = op.direction_from_deta(eta_i[idx].clone());
h -= &op.hphi_direction(&dir);
}
enforce_symmetry(&mut h);
h_i.push(h);
}
// Second-order mode responses, filled symmetrically from the lower triangle:
// beta_ij = H^{-1} (H_j v_i + A_i v_j - [i == j] A_i beta).
let mut beta_ij: Vec<Vec<Array1<f64>>> = (0..k)
.map(|_| (0..k).map(|_| Array1::<f64>::zeros(p)).collect())
.collect();
let mut h_ij: Vec<Vec<Array2<f64>>> = (0..k)
.map(|_| (0..k).map(|_| Array2::<f64>::zeros((p, p))).collect())
.collect();
for i in 0..k {
for j in 0..=i {
let mut rhs = h_i[j].dot(&v[i]);
rhs += &a_mats[i].dot(&v[j]);
if i == j {
rhs -= &a_mats[i].dot(beta);
}
let bij = h_inv_solve(&rhs)?;
beta_ij[i][j] = bij.clone();
beta_ij[j][i] = bij;
}
}
// Second-order Hessian derivatives H_ij built from the weight
// c * eta_ij + d * eta_i * eta_j, plus A_i on the diagonal and minus the Firth terms.
for i in 0..k {
for j in 0..=i {
let eta_ij = crate::faer_ndarray::fast_av(x_dense, &beta_ij[i][j]);
let diag = c_array * &eta_ij + &(d_array * &(&eta_i[i] * &eta_i[j]));
let mut h = Self::tk_xt_diag_x(x_dense, &diag);
if i == j {
h += &a_mats[i];
}
if let Some(op) = firth_op {
let dir_ij = op.direction_from_deta(eta_ij);
h -= &op.hphi_direction(&dir_ij);
let dir_i = op.direction_from_deta(eta_i[i].clone());
let dir_j = op.direction_from_deta(eta_i[j].clone());
// The second-order operator is applied to the identity to obtain a full p x p matrix.
let eye = Array2::<f64>::eye(p);
h -= &op.hphisecond_direction_apply(&dir_i, &dir_j, &eye);
}
enforce_symmetry(&mut h);
h_ij[i][j] = h.clone();
h_ij[j][i] = h;
}
}
// Derivatives of the inverse: K_i = -K H_i K and
// K_ij = K H_j K H_i K + K H_i K H_j K - K H_ij K.
let mut k_i = Vec::with_capacity(k);
for i in 0..k {
k_i.push(-k_mat.dot(&h_i[i]).dot(&k_mat));
}
let mut k_ij: Vec<Vec<Array2<f64>>> = (0..k)
.map(|_| (0..k).map(|_| Array2::<f64>::zeros((p, p))).collect())
.collect();
for i in 0..k {
for j in 0..=i {
let kij = k_mat.dot(&h_i[j]).dot(&k_mat).dot(&h_i[i]).dot(&k_mat)
+ k_mat.dot(&h_i[i]).dot(&k_mat).dot(&h_i[j]).dot(&k_mat)
- k_mat.dot(&h_ij[i][j]).dot(&k_mat);
k_ij[i][j] = kij.clone();
k_ij[j][i] = kij;
}
}
// Second-order jet: a scalar value `v` with its gradient `g` (length k) and Hessian `h`
// (k x k) over the smoothing coordinates. Arithmetic below applies the usual sum, scaling
// and product rules.
#[derive(Clone)]
struct Jet {
v: f64,
g: Array1<f64>,
h: Array2<f64>,
}
impl Jet {
// A jet with the given value and zero derivatives.
fn constant(v: f64, k: usize) -> Self {
Self {
v,
g: Array1::zeros(k),
h: Array2::zeros((k, k)),
}
}
fn add(&self, other: &Self) -> Self {
Self {
v: self.v + other.v,
g: &self.g + &other.g,
h: &self.h + &other.h,
}
}
fn scale(&self, a: f64) -> Self {
Self {
v: self.v * a,
g: self.g.mapv(|x| x * a),
h: self.h.mapv(|x| x * a),
}
}
// Product rule to second order: the Hessian picks up both cross terms g_a g'_b + g'_a g_b.
fn mul(&self, other: &Self) -> Self {
let k = self.g.len();
let mut h = &self.h * other.v + &other.h * self.v;
for i in 0..k {
for j in 0..k {
h[[i, j]] += self.g[i] * other.g[j] + other.g[i] * self.g[j];
}
}
Self {
v: self.v * other.v,
g: &self.g * other.v + &other.g * self.v,
h,
}
}
fn square(&self) -> Self {
self.mul(self)
}
fn cube(&self) -> Self {
self.mul(self).mul(self)
}
}
// Per-row jet of the quadratic form x_r^T K x_r, using K, K_i and K_ij.
let mut hdiag = Vec::with_capacity(n);
for row in 0..n {
let x = x_dense.row(row);
let mut jet = Jet::constant(x.dot(&k_mat.dot(&x.to_owned())), k);
for a in 0..k {
jet.g[a] = x.dot(&k_i[a].dot(&x.to_owned()));
}
for a in 0..k {
for b in 0..k {
jet.h[[a, b]] = x.dot(&k_ij[a][b].dot(&x.to_owned()));
}
}
hdiag.push(jet);
}
// Per-row jets of c and d, chained through the linear-predictor jet `eta`
// (gradient eta_i[row], Hessian x_row . beta_ij); `d`, `e`, `f` act as successive
// derivative coefficients.
let mut cjet = Vec::with_capacity(n);
let mut djet = Vec::with_capacity(n);
for row in 0..n {
let mut eta = Jet::constant(0.0, k);
for a in 0..k {
eta.g[a] = eta_i[a][row];
for b in 0..k {
eta.h[[a, b]] = x_dense.row(row).dot(&beta_ij[a][b]);
}
}
let mut c = Jet::constant(c_array[row], k);
let mut d = Jet::constant(d_array[row], k);
for a in 0..k {
c.g[a] = d_array[row] * eta.g[a];
d.g[a] = e_array[row] * eta.g[a];
for b in 0..k {
c.h[[a, b]] = d_array[row] * eta.h[[a, b]] + e_array[row] * eta.g[a] * eta.g[b];
d.h[[a, b]] = e_array[row] * eta.h[[a, b]] + f_array[row] * eta.g[a] * eta.g[b];
}
}
cjet.push(c);
djet.push(d);
}
// Term 1: -1/8 * sum_r d_r * h_r^2.
let mut total = Jet::constant(0.0, k);
for row in 0..n {
total = total.add(&djet[row].mul(&hdiag[row].square()).scale(-0.125));
}
// Term 2: 1/12 * sum over all ordered row pairs of c_i * c_j * (x_i^T K x_j)^3.
for irow in 0..n {
let xi = x_dense.row(irow).to_owned();
for jrow in 0..n {
let xj = x_dense.row(jrow).to_owned();
let mut kg = Jet::constant(xi.dot(&k_mat.dot(&xj)), k);
for a in 0..k {
kg.g[a] = xi.dot(&k_i[a].dot(&xj));
}
for a in 0..k {
for b in 0..k {
kg.h[[a, b]] = xi.dot(&k_ij[a][b].dot(&xj));
}
}
let term = cjet[irow]
.mul(&cjet[jrow])
.mul(&kg.cube())
.scale(1.0 / 12.0);
total = total.add(&term);
}
}
// q = X^T (c * h) as a vector of jets, one per coefficient.
let mut qjets: Vec<Jet> = (0..p).map(|_| Jet::constant(0.0, k)).collect();
for row in 0..n {
let wh = cjet[row].mul(&hdiag[row]);
for col in 0..p {
qjets[col] = qjets[col].add(&wh.scale(x_dense[[row, col]]));
}
}
// Term 3: 1/8 * sum_{a,b} q_a * K_ab * q_b.
for a in 0..p {
for b in 0..p {
let mut kj = Jet::constant(k_mat[[a, b]], k);
for i in 0..k {
kj.g[i] = k_i[i][[a, b]];
}
for i in 0..k {
for j in 0..k {
kj.h[[i, j]] = k_ij[i][j][[a, b]];
}
}
total = total.add(&qjets[a].mul(&kj).mul(&qjets[b]).scale(0.125));
}
}
if total.h.iter().any(|v| !v.is_finite()) {
crate::bail_invalid_estim!(
"Tierney-Kadane analytic Hessian produced a non-finite entry"
);
}
Ok(total.h)
}
fn tierney_kadane_analytic_core<S>(
&self,
x_dense: &Array2<f64>,
z: &Array2<f64>,
c_array: &Array1<f64>,
d_array: &Array1<f64>,
e_array: &Array1<f64>,
f_array: &Array1<f64>,
tk_penalties: &[crate::construction::CanonicalPenalty],
lambdas: &[f64],
ext_coords: &[super::unified::HyperCoord],
beta: &Array1<f64>,
firth_op: Option<&super::FirthDenseOperator>,
compute_gradient: bool,
compute_hessian: bool,
h_inv_solve: &S,
) -> Result<TkCorrectionTerms, EstimationError>
where
S: Fn(&Array1<f64>) -> Result<Array1<f64>, EstimationError>,
{
let p = x_dense.ncols();
let k = tk_penalties.len();
let shared = Self::tk_shared_intermediates(
x_dense,
z,
c_array,
"Tierney-Kadane correction",
h_inv_solve,
)?;
let mut gram = Array2::<f64>::zeros((TK_BLOCK_SIZE, TK_BLOCK_SIZE));
let value = Self::tk_scalar_from_shared(x_dense, z, d_array, &shared, &mut gram)?;
if !compute_gradient {
return Ok(TkCorrectionTerms {
value,
gradient: None,
hessian: None,
});
}
let mut x_vks: Vec<Array1<f64>> = Vec::with_capacity(k + ext_coords.len());
let mut beta_dirs: Vec<Array1<f64>> = Vec::with_capacity(k + ext_coords.len());
for idx in 0..k {
let cp = &tk_penalties[idx];
let r = &cp.col_range;
let beta_block = beta.slice(s![r.start..r.end]);
let centered = &beta_block - &cp.prior_mean;
let r_beta = cp.root.dot(¢ered);
let mut s_k_beta = Array1::<f64>::zeros(p);
for a in 0..cp.block_dim() {
s_k_beta[r.start + a] = (0..cp.rank())
.map(|row| cp.root[[row, a]] * r_beta[row])
.sum::<f64>();
}
let a_k_beta = &s_k_beta * lambdas[idx];
let v_k = h_inv_solve(&a_k_beta)?;
x_vks.push(crate::faer_ndarray::fast_av(x_dense, &v_k));
beta_dirs.push(v_k.mapv(|value| -value));
}
let mut ext_drifts = Vec::with_capacity(ext_coords.len());
let mut ext_eta_fixed = Vec::with_capacity(ext_coords.len());
let mut ext_x_fixed = Vec::with_capacity(ext_coords.len());
for coord in ext_coords {
let drift = coord.drift.materialize();
if drift.ncols() != beta.len() || drift.nrows() != beta.len() {
crate::bail_invalid_estim!(
"Tierney-Kadane ext drift shape mismatch: expected {}x{}, got {}x{}",
beta.len(),
beta.len(),
drift.nrows(),
drift.ncols()
);
}
if coord.g.len() != beta.len() {
crate::bail_invalid_estim!(
"Tierney-Kadane ext mode RHS length mismatch: expected {}, got {}",
beta.len(),
coord.g.len()
);
}
let beta_theta = h_inv_solve(&coord.g)?;
x_vks.push(crate::faer_ndarray::fast_av(x_dense, &beta_theta));
beta_dirs.push(beta_theta.mapv(|value| -value));
ext_drifts.push(drift);
ext_eta_fixed.push(coord.tk_eta_fixed.clone());
ext_x_fixed.push(coord.tk_x_fixed.clone());
}
let gradient = Self::tk_gradient_from_shared(
x_dense,
z,
c_array,
d_array,
e_array,
tk_penalties,
lambdas,
&ext_drifts,
&ext_eta_fixed,
&ext_x_fixed,
&x_vks,
&beta_dirs,
firth_op,
&shared,
&mut gram,
)?;
let hessian = if compute_hessian {
Some(Self::tk_hessian_rho_canonical_logit(
x_dense,
c_array,
d_array,
e_array,
f_array,
tk_penalties,
lambdas,
beta,
firth_op,
h_inv_solve,
)?)
} else {
None
};
Ok(TkCorrectionTerms {
value,
gradient: Some(gradient),
hessian,
})
}
/// Evaluates the Tierney-Kadane correction terms at `rho` for the requested `mode`.
///
/// Outcomes, in the order they are checked:
/// * Gaussian-identity likelihood, no link from `reml_robust_jeffreys_link`, or
///   `tk_correction_is_canonical_logit()` false: value `0.0` with no gradient and no Hessian.
/// * Problem rejected by `firth_problem_scale_allows` or by the `TK_MAX_*` size gates, or
///   empty `c`/`d` arrays: value `0.0` with zero-filled derivatives sized
///   `rho.len() + ext_coords.len()` (as requested by `mode`).
/// * Otherwise a dense design, Hessian solve, penalties, coefficient vector and Firth
///   operator are prepared in one consistent basis and passed to
///   `tierney_kadane_analytic_core`.
///
/// # Errors
/// Fails on non-finite `c`/`d` entries, when dense design access is unavailable, when the
/// effective Hessian cannot be factored or is not positive definite, and on any failure
/// propagated from the helpers.
fn tierney_kadane_terms(
&self,
rho: &Array1<f64>,
bundle: &EvalShared,
mode: super::unified::EvalMode,
ext_coords: &[super::unified::HyperCoord],
) -> Result<TkCorrectionTerms, EstimationError> {
if reml_is_gaussian_identity(&self.config.likelihood) {
return Ok(TkCorrectionTerms {
value: 0.0,
gradient: None,
hessian: None,
});
}
if reml_robust_jeffreys_link(&self.config).is_none() {
return Ok(TkCorrectionTerms {
value: 0.0,
gradient: None,
hessian: None,
});
}
if !self.tk_correction_is_canonical_logit() {
return Ok(TkCorrectionTerms {
value: 0.0,
gradient: None,
hessian: None,
});
}
let compute_gradient = compute_gradient_for_tk(mode);
// Zero correction whose derivative slots match what `mode` asks for.
let zero_correction = || TkCorrectionTerms {
value: 0.0,
gradient: if compute_gradient {
Some(Array1::zeros(rho.len() + ext_coords.len()))
} else {
None
},
hessian: if mode == super::unified::EvalMode::ValueGradientHessian {
Some(Array2::zeros((
rho.len() + ext_coords.len(),
rho.len() + ext_coords.len(),
)))
} else {
None
},
};
// Size gates: the correction is skipped (as zero) for problems that are too large.
let n_x = self.x().nrows();
let p_x = self.x().ncols();
if !super::firth_problem_scale_allows(n_x, p_x) {
return Ok(zero_correction());
}
let dense_work = n_x.saturating_mul(p_x);
if n_x > TK_MAX_OBSERVATIONS || p_x > TK_MAX_COEFFICIENTS || dense_work > TK_MAX_DENSE_WORK
{
return Ok(zero_correction());
}
let pirls_result = bundle.pirls_result.as_ref();
let (c_array, d_array, e_array, f_array) = self.hessian_cdef_arrays(pirls_result)?;
// NOTE(review): only `c` and `d` are screened for non-finite entries; `e` and `f` are not.
if let Some(idx) = c_array.iter().position(|v| !v.is_finite()) {
crate::bail_invalid_estim!(
"Tierney-Kadane correction received non-finite c derivative at row {idx}: {}",
c_array[idx]
);
}
if let Some(idx) = d_array.iter().position(|v| !v.is_finite()) {
crate::bail_invalid_estim!(
"Tierney-Kadane correction received non-finite d derivative at row {idx}: {}",
d_array[idx]
);
}
if c_array.is_empty() || d_array.is_empty() {
return Ok(zero_correction());
}
// Sparse-exact path: work in the original basis, solving with the sparse factor.
if let Some(sparse) = bundle.sparse_exact.as_ref() {
let x_dense = self
.x()
.try_to_dense_arc("frozen-curvature TK correction requires dense design access")
.map_err(EstimationError::InvalidInput)?;
let xt = x_dense.t().to_owned();
// z_mat solves the factored system against X^T (one right-hand side per observation).
let z_mat =
crate::linalg::sparse_exact::solve_sparse_spdmulti(sparse.factor.as_ref(), &xt)?;
let factor_ref = sparse.factor.clone();
let h_inv_solve = |rhs: &Array1<f64>| -> Result<Array1<f64>, EstimationError> {
crate::linalg::sparse_exact::solve_sparse_spd(&factor_ref, rhs)
};
let lambdas: Vec<f64> = rho.iter().map(|r| r.exp()).collect();
let beta = self.sparse_exact_beta_original(pirls_result);
// Reuse the operator cached on the bundle when available; otherwise build one.
let firth_op = if let Some(jeffreys_link) = reml_robust_jeffreys_link(&self.config) {
if let Some(cached) = bundle.firth_dense_operator_original.as_ref() {
Some(cached.clone())
} else {
Some(std::sync::Arc::new(
Self::build_firth_dense_operator_for_link(
&jeffreys_link,
x_dense.as_ref(),
&pirls_result.final_eta,
self.weights,
)?,
))
}
} else {
None
};
return self.tierney_kadane_analytic_core(
x_dense.as_ref(),
&z_mat,
&c_array,
&d_array,
&e_array,
&f_array,
&self.canonical_penalties,
&lambdas,
ext_coords,
&beta,
firth_op.as_deref(),
compute_gradient,
mode == super::unified::EvalMode::ValueGradientHessian,
&h_inv_solve,
);
}
// Dense path. Three bases are possible:
//   * original basis  - transformed-Qs frame with no active-constraint free basis;
//   * free basis `z`  - everything projected onto the active-constraint free basis;
//   * transformed     - quantities used as stored on the PIRLS result.
let free_basis_opt = self.active_constraint_free_basis(pirls_result);
let use_original_basis = matches!(
pirls_result.coordinate_frame,
pirls::PirlsCoordinateFrame::TransformedQs
) && free_basis_opt.is_none();
let h_tk_source = bundle.h_total.as_ref();
let h_tk_eval = if use_original_basis {
self.bundle_matrix_in_original_basis(pirls_result, h_tk_source)
} else if let Some(z) = free_basis_opt.as_ref() {
Self::projectwith_basis(h_tk_source, z)
} else {
h_tk_source.clone()
};
let x_eff_dense = if use_original_basis {
self.x()
.try_to_dense_arc("Tierney-Kadane correction requires dense original design access")
.map_err(EstimationError::InvalidInput)?
.as_ref()
.clone()
} else if let Some(z) = free_basis_opt.as_ref() {
pirls_result.x_transformed.to_dense().dot(z)
} else {
pirls_result.x_transformed.to_dense()
};
let xt = x_eff_dense.t().to_owned();
let p = x_eff_dense.ncols();
let n = x_eff_dense.nrows();
// Factorisation of the effective Hessian: Cholesky when it succeeds, otherwise a symmetric
// eigendecomposition that must have every eigenvalue above the floor.
enum HFactor {
Cholesky(crate::linalg::faer_ndarray::FaerCholeskyFactor),
Eigh {
evals: Array1<f64>,
evecs: Array2<f64>,
},
}
let h_factor = if let Ok(chol) = h_tk_eval.cholesky(Side::Lower) {
HFactor::Cholesky(chol)
} else if let Ok((evals, evecs)) = h_tk_eval.eigh(Side::Lower) {
const TK_HESSIAN_PD_EIGENVALUE_FLOOR: f64 = 1e-12;
if let Some((idx, ev)) = evals
.iter()
.enumerate()
.find(|(_, ev)| **ev <= TK_HESSIAN_PD_EIGENVALUE_FLOOR)
{
crate::bail_invalid_estim!(
"Tierney-Kadane correction requires a positive definite Hessian; eigenvalue {idx} is {ev}"
);
}
HFactor::Eigh { evals, evecs }
} else {
crate::bail_invalid_estim!(
"Tierney-Kadane correction could not factor the effective Hessian"
);
};
// z_mat = H^{-1} X^T (p x n), via the factor chosen above.
let z_mat = match &h_factor {
HFactor::Cholesky(chol) => {
let mut solved = xt.clone();
chol.solve_mat_in_place(&mut solved);
solved
}
HFactor::Eigh { evals, evecs } => {
// Spectral inverse: accumulate u_m * (X u_m / ev_m)^T over every eigenpair.
let mut solved = Array2::<f64>::zeros((p, n));
for m in 0..evals.len() {
let ev = evals[m];
let u = evecs.column(m);
let coeffs = xt.t().dot(&u).mapv(|v| v / ev);
for row in 0..p {
let u_row = u[row];
for col in 0..n {
solved[[row, col]] += u_row * coeffs[col];
}
}
}
solved
}
};
// Single right-hand-side solve with the same factor; takes ownership of `h_factor`.
let h_inv_solve = move |rhs: &Array1<f64>| -> Result<Array1<f64>, EstimationError> {
match &h_factor {
HFactor::Cholesky(chol) => Ok(chol.solvevec(rhs)),
HFactor::Eigh { evals, evecs } => {
let mut sol = Array1::<f64>::zeros(rhs.len());
for m in 0..evals.len() {
let u = evecs.column(m);
let coeff = u.dot(rhs) / evals[m];
for row in 0..sol.len() {
sol[row] += coeff * u[row];
}
}
Ok(sol)
}
}
};
let p_eff = x_eff_dense.ncols();
// Penalties expressed in the same basis as `x_eff_dense`.
let tk_penalties: Vec<crate::construction::CanonicalPenalty> = if use_original_basis {
self.canonical_penalties.as_ref().clone()
} else if let Some(z) = free_basis_opt.as_ref() {
pirls_result
.reparam_result
.canonical_transformed
.iter()
.map(|cp| {
crate::construction::CanonicalPenalty::from_dense_root(cp.root.dot(z), p_eff)
})
.collect()
} else {
pirls_result.reparam_result.canonical_transformed.clone()
};
let lambdas: Vec<f64> = rho.iter().map(|r| r.exp()).collect();
// Coefficients expressed in the same basis as `x_eff_dense`.
let beta = if use_original_basis {
self.sparse_exact_beta_original(pirls_result)
} else if let Some(z) = free_basis_opt.as_ref() {
z.t().dot(pirls_result.beta_transformed.as_ref())
} else {
pirls_result.beta_transformed.as_ref().clone()
};
// NOTE(review): unlike the sparse-exact path, this branch always rebuilds the Firth
// operator and never consults `bundle.firth_dense_operator_original`.
let firth_op = if let Some(jeffreys_link) = reml_robust_jeffreys_link(&self.config) {
Some(std::sync::Arc::new(
Self::build_firth_dense_operator_for_link(
&jeffreys_link,
&x_eff_dense,
&pirls_result.final_eta,
self.weights,
)?,
))
} else {
None
};
self.tierney_kadane_analytic_core(
&x_eff_dense,
&z_mat,
&c_array,
&d_array,
&e_array,
&f_array,
&tk_penalties,
&lambdas,
ext_coords,
&beta,
firth_op.as_deref(),
compute_gradient,
mode == super::unified::EvalMode::ValueGradientHessian,
&h_inv_solve,
)
}
/// Checks that every external coordinate carries the analytic carriers the Tierney-Kadane
/// gradient needs.
///
/// Validation is skipped (returns `Ok`) for Gaussian-identity likelihoods, when
/// `reml_robust_jeffreys_link` yields no link, or when `mode` does not request a TK
/// gradient.
///
/// # Errors
/// Fails on the first coordinate whose `tk_eta_fixed` or `tk_x_fixed` is `None`.
///
/// NOTE(review): `tierney_kadane_terms` also short-circuits when
/// `tk_correction_is_canonical_logit()` is false; that case is not exempted here — confirm.
fn validate_tk_ext_coords(
    &self,
    mode: super::unified::EvalMode,
    ext_coords: &[super::unified::HyperCoord],
) -> Result<(), EstimationError> {
    let tk_gradient_not_needed = reml_is_gaussian_identity(&self.config.likelihood)
        || reml_robust_jeffreys_link(&self.config).is_none()
        || !compute_gradient_for_tk(mode);
    if tk_gradient_not_needed {
        return Ok(());
    }
    let first_incomplete = ext_coords
        .iter()
        .position(|coord| coord.tk_eta_fixed.is_none() || coord.tk_x_fixed.is_none());
    if let Some(idx) = first_incomplete {
        crate::bail_invalid_estim!(
            "Tierney-Kadane external gradient coordinate {idx} is missing analytic fixed-beta design/eta derivative carriers"
        );
    }
    Ok(())
}
fn apply_tk_to_result(
&self,
mut result: super::unified::RemlLamlResult,
tk_terms: TkCorrectionTerms,
) -> Result<super::unified::RemlLamlResult, EstimationError> {
result.cost += tk_terms.value;
if let Some(tk_hess) = tk_terms.hessian.as_ref() {
result
.hessian
.add_rho_block_dense(tk_hess)
.map_err(EstimationError::InvalidInput)?;
}
if let (Some(ref mut grad), Some(tk_grad)) = (result.gradient.as_mut(), tk_terms.gradient) {
if tk_grad.len() == grad.len() {
**grad += &tk_grad;
} else {
crate::bail_invalid_estim!(
"Tierney-Kadane gradient coordinate count mismatch: evaluator produced {} entries, analytic c/d propagation produced {}; this indicates the TK term and the unified evaluator were assembled against different coordinate sets",
grad.len(),
tk_grad.len()
);
}
}
Ok(result)
}
/// Resolves the inverse link currently in effect.
///
/// Precedence: a runtime mixture state wins; otherwise a runtime SAS state yields
/// `BetaLogistic` when the configured link is `LinkFunction::BetaLogistic` and `Sas`
/// otherwise; with neither state the configured link is converted to a `StandardLink`.
///
/// # Panics
/// Panics when no runtime state exists and the configured link cannot be converted to a
/// `StandardLink`.
fn runtime_inverse_link(&self) -> InverseLink {
    let link_function = self.config.link_function();
    if let Some(mixture_state) = self.runtime_mixture_link_state.clone() {
        return InverseLink::Mixture(mixture_state);
    }
    match self.runtime_sas_link_state {
        Some(sas_state) if matches!(link_function, LinkFunction::BetaLogistic) => {
            InverseLink::BetaLogistic(sas_state)
        }
        Some(sas_state) => InverseLink::Sas(sas_state),
        None => InverseLink::Standard(
            StandardLink::try_from(link_function)
                .expect("state-bearing link without runtime state in runtime_inverse_link"),
        ),
    }
}
/// Sampled, block-local marginal correction restricted to Hessian eigen-directions flagged
/// as untrustworthy by the directional-cubic diagnostic.
///
/// Returns a zero correction (value `0.0`, zero gradient of length `n_rho + n_ext`, no
/// Hessian) whenever the correction does not apply: Gaussian-identity likelihood, `rho`
/// length mismatch or no penalties, empty or inconsistent inputs, problem above the
/// `TK_MAX_*` size gates, a zero or non-finite diagnostic, a verdict that needs no fallback,
/// no usable flagged direction, or an importance ESS below the acceptance threshold.
///
/// Otherwise the returned value and the first `n_rho` gradient entries are the negated
/// outputs of `block_sampled_marginal_correction`; the `n_ext` trailing gradient entries
/// stay zero and no Hessian is produced.
///
/// # Errors
/// Fails when the diagnostic, the eigendecomposition, dense design access or the sampler
/// reports an error.
fn block_local_sampled_correction(
&self,
rho: &Array1<f64>,
bundle: &EvalShared,
n_ext: usize,
) -> Result<TkCorrectionTerms, EstimationError> {
use crate::inference::hmc::{
block_sampled_marginal_correction, laplace_directional_cubic_diagnostic,
laplace_trustworthiness_from_skewness,
};
let n_rho = self.canonical_penalties.len();
let zero = || TkCorrectionTerms {
value: 0.0,
gradient: Some(Array1::zeros(n_rho + n_ext)),
hessian: None,
};
if reml_is_gaussian_identity(&self.config.likelihood) {
return Ok(zero());
}
if rho.len() != n_rho || n_rho == 0 {
return Ok(zero());
}
let pirls_result = bundle.pirls_result.as_ref();
let h_total = bundle.h_total.as_ref();
let c_weights = &pirls_result.solve_c_array;
let x_design = &pirls_result.x_transformed;
let p = h_total.nrows();
if p == 0 || c_weights.len() != x_design.nrows() {
return Ok(zero());
}
// Same size gates as the analytic Tierney-Kadane path.
let n_obs = x_design.nrows();
let dense_work = n_obs.saturating_mul(p);
if n_obs > TK_MAX_OBSERVATIONS || p > TK_MAX_COEFFICIENTS || dense_work > TK_MAX_DENSE_WORK
{
return Ok(zero());
}
let (max_abs, directional) =
laplace_directional_cubic_diagnostic(h_total, x_design, c_weights)
.map_err(EstimationError::InvalidInput)?;
if !max_abs.is_finite() || max_abs == 0.0 {
return Ok(zero());
}
// Effective sample size for the verdict: number of rows with a non-zero c weight.
let n_eff = c_weights.iter().filter(|&&c| c != 0.0).count() as f64;
let verdict = laplace_trustworthiness_from_skewness(&directional, n_eff);
if !verdict.fallback_required() {
return Ok(zero());
}
// Eigendecompose the symmetrised Hessian and keep only flagged directions that are in
// range and have a strictly positive eigenvalue.
let sym_h = (h_total + &h_total.t()) * 0.5;
let (evals, evecs) = sym_h.eigh(Side::Lower).map_err(|e| {
EstimationError::InvalidInput(format!(
"#784 block-local fallback eigendecomposition failed: {e}"
))
})?;
let mut block_cols: Vec<usize> = Vec::new();
for &r in &verdict.untrustworthy_directions {
if r < evals.len() && evals[r] > 0.0 {
block_cols.push(r);
}
}
if block_cols.is_empty() {
return Ok(zero());
}
let m = block_cols.len();
let mut block_vecs = Array2::<f64>::zeros((p, m));
let mut block_lambdas = Array1::<f64>::zeros(m);
for (j, &r) in block_cols.iter().enumerate() {
block_vecs.column_mut(j).assign(&evecs.column(r));
block_lambdas[j] = evals[r];
}
let beta_hat = pirls_result.beta_transformed.as_ref().clone();
// NOTE(review): `self.canonical_penalties` is paired with the transformed-frame `beta_hat`
// here; presumably `transformed_penalty_matvec` reconciles the frames — confirm.
let penalty_scores: Vec<Array1<f64>> = self
.canonical_penalties
.iter()
.map(|pen| transformed_penalty_matvec(pen, &beta_hat))
.collect();
let lambdas: Vec<f64> = rho.iter().map(|&r| r.exp()).collect();
// Dispersion: 1 for Gaussian responses, otherwise the fixed GLM dispersion, with any
// non-finite or non-positive value replaced by 1.
let phi = match reml_spec(&self.config.likelihood).response {
ResponseFamily::Gaussian => 1.0,
_ => reml_fixed_glm_dispersion(&self.config.likelihood),
};
let phi = if phi.is_finite() && phi > 0.0 {
phi
} else {
1.0
};
let x_dense = x_design
.try_to_dense_arc("#784 block-local fallback requires dense design access")
.map_err(EstimationError::InvalidInput)?;
let target = Gam784BlockTarget {
x_transformed: x_dense.as_ref(),
block_vecs,
block_lambdas,
eta_hat: pirls_result.final_eta.clone(),
weights_obs: pirls_result.finalweights.clone(),
y: self.y.to_owned(),
prior_weights: self.weights.to_owned(),
likelihood: self.config.likelihood.clone(),
inverse_link: self.runtime_inverse_link(),
phi,
penalty_scores,
lambdas,
base_deviance: pirls_result.deviance,
};
let sampled =
block_sampled_marginal_correction(&target).map_err(EstimationError::InvalidInput)?;
// Decline the sampled estimate when the importance ESS falls below the required fraction
// of the draws (never requiring less than one effective draw).
let min_ess = (sampled.n_draws as f64 * MIN_IMPORTANCE_ESS_FRACTION).max(1.0);
if sampled.importance_ess < min_ess {
log::info!(
"[#784] block-local fallback declined: importance ESS {:.1} < {:.1} \
(m={m} dirs, max|γ|={:.3}, τ={:.3})",
sampled.importance_ess,
min_ess,
verdict.max_abs_skewness,
verdict.threshold,
);
return Ok(zero());
}
log::info!(
"[#784] block-local sampled marginalization ENGAGED: m={m} curvature-heavy dirs, \
max|γ|={:.3}, τ={:.3}, Δ_b={:.4e}, ESS={:.1}/{}",
verdict.max_abs_skewness,
verdict.threshold,
sampled.value,
sampled.importance_ess,
sampled.n_draws,
);
// Sign flip: the sampled value and rho-gradient are negated before being returned.
// Entries beyond `n_rho` (the external coordinates) remain zero.
let mut gradient = Array1::<f64>::zeros(n_rho + n_ext);
for k in 0..n_rho.min(sampled.rho_gradient.len()) {
gradient[k] = -sampled.rho_gradient[k];
}
Ok(TkCorrectionTerms {
value: -sampled.value,
gradient: Some(gradient),
hessian: None,
})
}
/// Decides whether the expensive hot-path diagnostics should run for evaluation `eval_idx`.
///
/// They run only when the `Info` or `Warn` log level is enabled, and then only on
/// evaluation `1` and on every multiple of `HOT_DIAGNOSTIC_EVAL_INTERVAL` (which includes
/// index `0`).
pub(super) fn should_compute_hot_diagnostics(&self, eval_idx: u64) -> bool {
    const HOT_DIAGNOSTIC_EVAL_INTERVAL: u64 = 200;
    let diagnostics_visible =
        log::log_enabled!(log::Level::Info) || log::log_enabled!(log::Level::Warn);
    if !diagnostics_visible {
        return false;
    }
    eval_idx == 1 || eval_idx.is_multiple_of(HOT_DIAGNOSTIC_EVAL_INTERVAL)
}
/// Discards cached state: the evaluation/factor caches, the PIRLS cache, the warm-start
/// predictor state and the warm-start adaptive signals, in that order.
///
/// # Panics
/// Panics if the PIRLS cache lock cannot be acquired (`unwrap` on the write guard).
fn invalidate_link_dependent_state(&self) {
    let caches = &self.cache_manager;
    caches.clear_eval_and_factor_caches();
    {
        // Hold the write guard only for the duration of the clear.
        let mut pirls_cache = caches.pirls_cache.write().unwrap();
        pirls_cache.clear();
    }
    self.clear_warm_start_predictor_state();
    self.clear_warm_start_adaptive_signals();
}
/// Empties every warm-start predictor slot held by this evaluator and then clears the
/// per-instance IFT mode-response cache.
///
/// Each slot is emptied with `take()` under its write lock; the `unwrap()` calls panic if a
/// lock cannot be acquired.
fn clear_warm_start_predictor_state(&self) {
// Current and previous warm-start coefficient / rho pairs.
self.warm_start_beta.write().unwrap().take();
self.warm_start_rho.write().unwrap().take();
self.prev_warm_start_beta.write().unwrap().take();
self.prev_warm_start_rho.write().unwrap().take();
// IFT warm-start cache and its cached factor.
self.ift_warm_start_cache.write().unwrap().take();
self.ift_cached_factor.write().unwrap().take();
self.clear_ift_mode_response_cache();
}
/// Key under which this instance's entries live in the global IFT mode-response caches:
/// the address of `self` as an integer.
///
/// NOTE(review): the key changes if the value is moved, and a later instance at the same
/// address would share it — presumably instances are pinned for their lifetime; confirm.
fn ift_mode_response_cache_key(&self) -> usize {
    let this: *const Self = self;
    this as usize
}
/// Returns whatever `latest_outer_theta_for_ift()` currently reports, unchanged.
///
/// NOTE(review): presumably the most recent joint outer-parameter vector published by the
/// outer optimiser — confirm against `latest_outer_theta_for_ift`.
fn pending_joint_ift_theta(&self) -> Option<Array1<f64>> {
latest_outer_theta_for_ift()
}
/// Removes this instance's entry from the global joint IFT mode-response cache.
///
/// Does nothing when the global cache has not been initialised yet.
///
/// # Panics
/// Panics if the cache mutex cannot be locked.
fn clear_joint_ift_mode_response_cache(&self) {
    let Some(joint_caches) = IFT_JOINT_MODE_RESPONSE_CACHES.get() else {
        return;
    };
    let mut entries = joint_caches.lock().unwrap();
    entries.remove(&self.ift_mode_response_cache_key());
}
/// Removes this instance's entry from the global IFT mode-response cache.
///
/// Does nothing when the global cache has not been initialised yet.
///
/// # Panics
/// Panics if the cache mutex cannot be locked.
fn clear_ift_mode_response_cache(&self) {
    let Some(mode_caches) = IFT_MODE_RESPONSE_CACHES.get() else {
        return;
    };
    let mut entries = mode_caches.lock().unwrap();
    entries.remove(&self.ift_mode_response_cache_key());
}
/// Returns a copy of `cols` when it can be reused for warm starts, `None` otherwise.
///
/// `cols` must have `self.p` rows and at least one column. It is accepted when the bundle
/// uses the sparse-exact SPD backend, when the PIRLS frame is `OriginalSparseNative`, or
/// when the frame is `TransformedQs` and no active-constraint free basis exists.
fn mode_response_cols_for_warm_start(
    &self,
    bundle: &EvalShared,
    cols: &Array2<f64>,
) -> Option<Array2<f64>> {
    let shape_ok = cols.ncols() != 0 && cols.nrows() == self.p;
    if !shape_ok {
        return None;
    }
    let reusable = bundle.backend_kind() == GeometryBackendKind::SparseExactSpd
        || match bundle.pirls_result.coordinate_frame {
            pirls::PirlsCoordinateFrame::OriginalSparseNative => true,
            pirls::PirlsCoordinateFrame::TransformedQs
                if self
                    .active_constraint_free_basis(bundle.pirls_result.as_ref())
                    .is_none() =>
            {
                true
            }
            _ => false,
        };
    reusable.then(|| cols.clone())
}
/// Publishes the mode-response columns of `result` into the global IFT caches.
///
/// Stage 1 (per-rho cache): the rho / external columns are filtered through
/// `mode_response_cols_for_warm_start`. If neither survives, both caches are cleared when
/// the result carried a gradient and left untouched otherwise. If at least one survives,
/// the per-rho cache entry for this instance is replaced.
///
/// Stage 2 (joint cache): requires a pending joint theta of length
/// `rho.len() + ext columns` (strictly longer than `rho`), both column blocks with
/// consistent shapes, no active constraints, and finite original-frame coefficients and
/// columns. Any failed requirement clears the joint cache entry instead of storing one.
fn store_ift_mode_response_cache_from_result(
    &self,
    rho: &Array1<f64>,
    bundle: &EvalShared,
    result: &super::unified::RemlLamlResult,
) {
    let rho_cols = result
        .rho_mode_response_cols
        .as_ref()
        .and_then(|cols| self.mode_response_cols_for_warm_start(bundle, cols));
    let ext_cols = result
        .ext_mode_response_cols
        .as_ref()
        .and_then(|cols| self.mode_response_cols_for_warm_start(bundle, cols));
    if let (None, None) = (&rho_cols, &ext_cols) {
        // A value-only evaluation says nothing about the cached responses; keep them.
        if result.gradient.is_some() {
            self.clear_ift_mode_response_cache();
            self.clear_joint_ift_mode_response_cache();
        }
        return;
    }
    let rho_col_count = rho_cols.as_ref().map_or(0, Array2::ncols);
    let ext_col_count = ext_cols.as_ref().map_or(0, Array2::ncols);
    let cache_key = self.ift_mode_response_cache_key();
    ift_mode_response_caches().lock().unwrap().insert(
        cache_key,
        IftModeResponseRuntimeCache {
            rho: rho.clone(),
            rho_mode_response_cols: rho_cols.clone(),
            ext_mode_response_cols: ext_cols.clone(),
        },
    );
    log::debug!(
        "[IFT-CACHE] outcome=mode_response_store rho_cols={} ext_cols={} p={}",
        rho_col_count,
        ext_col_count,
        self.p,
    );
    // Joint cache: only meaningful when theta extends rho by exactly the external columns.
    let theta = match self.pending_joint_ift_theta() {
        Some(theta)
            if theta.len() > rho.len() && theta.len() == rho.len() + ext_col_count =>
        {
            theta
        }
        _ => {
            self.clear_joint_ift_mode_response_cache();
            return;
        }
    };
    let (Some(rho_block), Some(ext_block)) = (rho_cols.as_ref(), ext_cols.as_ref()) else {
        self.clear_joint_ift_mode_response_cache();
        return;
    };
    let shapes_consistent = rho_block.nrows() == self.p
        && rho_block.ncols() == rho.len()
        && ext_block.nrows() == self.p
        && ext_block.ncols() == ext_col_count;
    if !shapes_consistent {
        self.clear_joint_ift_mode_response_cache();
        return;
    }
    let active_constraints = self
        .active_constraint_free_basis(bundle.pirls_result.as_ref())
        .is_some();
    if active_constraints {
        self.clear_joint_ift_mode_response_cache();
        log::info!(
            "[IFT-REJECTED] reason=active_constraints joint_dim={}",
            theta.len()
        );
        return;
    }
    // Coefficients in the original frame: stored as-is for the sparse-native frame, mapped
    // back through `qs` for the transformed frame.
    let pirls_result = &bundle.pirls_result;
    let beta_original = match pirls_result.coordinate_frame {
        pirls::PirlsCoordinateFrame::OriginalSparseNative => {
            pirls_result.beta_transformed.as_ref().clone()
        }
        pirls::PirlsCoordinateFrame::TransformedQs => pirls_result
            .reparam_result
            .qs
            .dot(pirls_result.beta_transformed.as_ref()),
    };
    let beta_usable = beta_original.len() == self.p && beta_original.iter().all(|v| v.is_finite());
    if !beta_usable {
        self.clear_joint_ift_mode_response_cache();
        return;
    }
    // Joint response matrix: rho columns first, external columns after.
    let mut mode_response_cols = Array2::<f64>::zeros((self.p, theta.len()));
    mode_response_cols
        .slice_mut(s![.., ..rho.len()])
        .assign(rho_block);
    mode_response_cols
        .slice_mut(s![.., rho.len()..])
        .assign(ext_block);
    if !mode_response_cols.iter().all(|v| v.is_finite()) {
        self.clear_joint_ift_mode_response_cache();
        return;
    }
    ift_joint_mode_response_caches().lock().unwrap().insert(
        cache_key,
        IftJointModeResponseRuntimeCache {
            theta,
            rho_dim: rho.len(),
            beta_original,
            mode_response_cols,
            active_constraints,
        },
    );
    log::debug!(
        "[IFT-CACHE] outcome=joint_mode_response_store rho_cols={} ext_cols={} p={}",
        rho_col_count,
        ext_col_count,
        self.p,
    );
}
/// Returns a copy of the cached rho mode-response columns when the cached
/// entry was stored for a bit-identical `rho` and all cached shapes agree with
/// the current coefficient count `self.p`.
///
/// Returns `None` when there is no entry for this cache key, when any `rho`
/// component differs bitwise, when cached ext columns have the wrong row
/// count, or when the rho columns are absent or mis-shaped.
///
/// # Panics
/// Panics if the cache mutex is poisoned.
fn cached_ift_rho_mode_response_cols(
    &self,
    cache: &super::IftWarmStartCache,
) -> Option<Array2<f64>> {
    let caches = ift_mode_response_caches().lock().unwrap();
    let entry = caches.get(&self.ift_mode_response_cache_key())?;

    // Bitwise comparison: the entry is only valid for exactly the same rho.
    let same_rho = entry.rho.len() == cache.rho.len()
        && entry
            .rho
            .iter()
            .zip(cache.rho.iter())
            .all(|(&stored, &wanted)| stored.to_bits() == wanted.to_bits());
    if !same_rho {
        return None;
    }

    // Ext columns are optional, but if present they must be p rows tall.
    if let Some(ext) = entry.ext_mode_response_cols.as_ref() {
        if ext.nrows() != self.p {
            return None;
        }
    }

    entry
        .rho_mode_response_cols
        .as_ref()
        .filter(|cols| cols.nrows() == self.p && cols.ncols() == cache.rho.len())
        .cloned()
}
/// Predicts a warm-start coefficient vector for the pending joint outer point
/// `theta` from the cached joint mode responses.
///
/// The prediction is `beta_cached - mode_response_cols * (theta - theta_cached)`.
///
/// Returns `None` when there is no pending theta, no cache entry for this
/// key, the entry was stored with active constraints, the entry does not
/// match (`joint_ift_cache_matches_theta` or a shape check fails), the step is
/// non-finite or exceeds `max_dtheta_cap`, or the prediction is non-finite.
/// Returns the cached coefficients with `IftPredictionOutcome::Noop` when
/// every component of the step is within `IFT_WARM_START_DRHO_EPS`.
///
/// # Panics
/// Panics if the joint cache mutex is poisoned.
fn predict_warm_start_beta_joint_ift_with_outcome(
    &self,
    new_rho: &Array1<f64>,
    max_dtheta_cap: f64,
) -> Option<(Coefficients, IftPredictionOutcome)> {
    let theta = self.pending_joint_ift_theta()?;
    // Clone the entry so the mutex is released before any numeric work.
    let cache = {
        let guard = ift_joint_mode_response_caches().lock().unwrap();
        guard.get(&self.ift_mode_response_cache_key())?.clone()
    };
    if cache.active_constraints {
        log::info!(
            "[IFT-REJECTED] reason=active_constraints joint_dim={}",
            cache.theta.len(),
        );
        return None;
    }
    if !joint_ift_cache_matches_theta(&cache, &theta, new_rho)
        || cache.beta_original.len() != self.p
        || cache.mode_response_cols.nrows() != self.p
        || cache.mode_response_cols.ncols() != cache.theta.len()
    {
        return None;
    }
    let mut max_abs_dtheta = 0.0_f64;
    let dtheta: Array1<f64> = theta
        .iter()
        .zip(cache.theta.iter())
        .map(|(&new_value, &old_value)| {
            let d = new_value - old_value;
            if !d.is_finite() {
                // Record the non-finite step in the running maximum so the
                // guard below rejects it immediately. Without this the
                // maximum stays finite and the bad step is only caught, if
                // at all, after the matrix-vector product further down.
                max_abs_dtheta = f64::INFINITY;
                return f64::INFINITY;
            }
            if d.abs() > max_abs_dtheta {
                max_abs_dtheta = d.abs();
            }
            d
        })
        .collect();
    if !max_abs_dtheta.is_finite() || max_abs_dtheta > max_dtheta_cap {
        log::info!(
            "[IFT-REJECTED] reason=large_dtheta max_dtheta={:.3e} cap={:.3e} joint_dim={}",
            max_abs_dtheta,
            max_dtheta_cap,
            cache.theta.len(),
        );
        return None;
    }
    if dtheta.iter().all(|d| d.abs() <= IFT_WARM_START_DRHO_EPS) {
        log::info!(
            "[IFT-NOOP] reason=all_dtheta_below_eps max_dtheta={:.3e} joint_dim={}",
            max_abs_dtheta,
            cache.theta.len(),
        );
        return Some((
            Coefficients::new(cache.beta_original),
            IftPredictionOutcome::Noop,
        ));
    }
    // First-order correction in the original coefficient frame.
    let solution_original = cache.mode_response_cols.dot(&dtheta);
    if !solution_original.iter().all(|v| v.is_finite()) {
        log::info!(
            "[IFT-REJECTED] reason=non_finite_solution max_dtheta={:.3e} joint_dim={}",
            max_abs_dtheta,
            cache.theta.len(),
        );
        return None;
    }
    let mut predicted = cache.beta_original;
    for (target, &correction) in predicted.iter_mut().zip(solution_original.iter()) {
        *target -= correction;
    }
    if !predicted.iter().all(|v| v.is_finite()) {
        log::info!(
            "[IFT-REJECTED] reason=non_finite_predicted max_dtheta={:.3e} joint_dim={}",
            max_abs_dtheta,
            cache.theta.len(),
        );
        return None;
    }
    log::info!(
        "[IFT-CACHE] outcome=joint_mode_response_hit joint_dim={} rho_dim={} p={}",
        cache.theta.len(),
        cache.rho_dim,
        self.p,
    );
    log::debug!(
        "[warm-start] joint IFT prediction reused mode responses: max|Δθ|={:.3e}, ‖Δβ‖={:.3e}",
        max_abs_dtheta,
        solution_original.dot(&solution_original).sqrt(),
    );
    Some((
        Coefficients::new(predicted),
        IftPredictionOutcome::Predicted,
    ))
}
/// Reports whether the joint IFT cache entry for this key matches the
/// currently pending joint `theta` and `new_rho`, as judged by
/// `joint_ift_cache_matches_theta`.
///
/// Returns `false` when there is no pending theta or no cache entry.
///
/// # Panics
/// Panics if the joint cache mutex is poisoned.
fn joint_ift_cache_matches_pending_theta(&self, new_rho: &Array1<f64>) -> bool {
    self.pending_joint_ift_theta().is_some_and(|theta| {
        let caches = ift_joint_mode_response_caches().lock().unwrap();
        caches
            .get(&self.ift_mode_response_cache_key())
            .is_some_and(|cache| joint_ift_cache_matches_theta(cache, &theta, new_rho))
    })
}
/// Resets the adaptive warm-start signals to their initial values: inner
/// iteration count 0, inner convergence `false`, LM lambda bits 0, both
/// residual/accept-rho slots to `IFT_RESIDUAL_NO_SIGNAL_BITS`, and finally the
/// IFT quality runtime state.
fn clear_warm_start_adaptive_signals(&self) {
    let relaxed = Ordering::Relaxed;
    self.last_inner_iters.store(0, relaxed);
    self.last_inner_converged.store(false, relaxed);
    self.last_pirls_lm_lambda.store(0, relaxed);
    // Both slots use the same "no signal" sentinel bit pattern.
    for slot in [
        &self.last_ift_prediction_residual,
        &self.last_pirls_accept_rho,
    ] {
        slot.store(IFT_RESIDUAL_NO_SIGNAL_BITS, relaxed);
    }
    self.clear_ift_quality_runtime_state();
}
/// Installs new runtime mixture/SAS link states.
///
/// A call that changes neither state is a no-op. Otherwise both states are
/// replaced, the persistent warm-start key and its "loaded" flag are reset,
/// and link-dependent state is invalidated.
pub(crate) fn set_link_states(
    &mut self,
    mixture_link_state: Option<crate::types::MixtureLinkState>,
    sas_link_state: Option<SasLinkState>,
) {
    let mixture_unchanged = self.runtime_mixture_link_state == mixture_link_state;
    let sas_unchanged = self.runtime_sas_link_state == sas_link_state;
    if mixture_unchanged && sas_unchanged {
        return;
    }

    self.runtime_mixture_link_state = mixture_link_state;
    self.runtime_sas_link_state = sas_link_state;

    // The persistent warm-start key is reset whenever a link state changes.
    {
        let mut key_slot = self.persistent_warm_start_key.write().unwrap();
        *key_slot = None;
    }
    self.persistent_warm_start_loaded
        .store(false, Ordering::Relaxed);
    self.invalidate_link_dependent_state();
}
/// Returns owned copies of the PIRLS `solve_c_array` and `solve_d_array`,
/// in that order. The current implementation never returns `Err`.
fn hessian_cd_arrays(
    &self,
    pirls_result: &PirlsResult,
) -> Result<(Array1<f64>, Array1<f64>), EstimationError> {
    let solve_c = pirls_result.solve_c_array.clone();
    let solve_d = pirls_result.solve_d_array.clone();
    Ok((solve_c, solve_d))
}
/// Builds the three outer-Hessian curvature arrays by passing the signed and
/// PSD solve weights, the c/d arrays, the final linear predictor and the
/// configured link kind to `pirls::outer_hessian_curvature_arrays`.
///
/// # Errors
/// Propagates any error from `hessian_cd_arrays`.
fn hessian_surface_arrays(
    &self,
    pirls_result: &PirlsResult,
) -> Result<(Array1<f64>, Array1<f64>, Array1<f64>), EstimationError> {
    let (solve_c, solve_d) = self.hessian_cd_arrays(pirls_result)?;
    let curvature = crate::pirls::outer_hessian_curvature_arrays(
        pirls_result.final_weights_signed(),
        pirls_result.solve_weights_psd(),
        &solve_c,
        &solve_d,
        &pirls_result.final_eta,
        &self.config.link_kind,
    );
    Ok(curvature)
}
/// Returns the `(c, d, e)` per-observation arrays: `c` and `d` are copies of
/// the PIRLS solve arrays, and `e` is computed here.
///
/// `e[i]` is forced to `0.0` whenever `pirls::eta_clamp_active` reports the
/// clamp as active for `final_eta[i]`. Otherwise:
/// * canonical Binomial/Logit without a runtime mixture state: the
///   non-negative part of the prior weight times the `d4` entry of
///   `logit_inverse_link_jet5`;
/// * every other likelihood: `pirls::e_obs_from_jets`, with non-finite inputs
///   or results replaced by `0.0`.
///
/// # Errors
/// Fails when the `solve_*mu_deta` or `solvemu` arrays do not have length
/// `n`, or when one of the inverse-link derivative helpers returns an error.
///
/// # Panics
/// Panics if the configured link needs runtime state that is absent
/// (`StandardLink::try_from` fails), or if `e_array` is not contiguous.
fn hessian_cde_arrays(
&self,
pirls_result: &PirlsResult,
) -> Result<(Array1<f64>, Array1<f64>, Array1<f64>), EstimationError> {
let c_array = pirls_result.solve_c_array.clone();
let d_array = pirls_result.solve_d_array.clone();
let n = d_array.len();
let mut e_array = Array1::<f64>::zeros(n);
let link_function = self.config.link_function();
// Resolve the inverse link: runtime mixture state wins, then runtime SAS
// state (as Beta-Logistic or SAS depending on the configured link), and
// finally the plain standard link.
let inverse_link = if let Some(state) = self.runtime_mixture_link_state.clone() {
InverseLink::Mixture(state)
} else if let Some(state) = self.runtime_sas_link_state {
if matches!(link_function, LinkFunction::BetaLogistic) {
InverseLink::BetaLogistic(state)
} else {
InverseLink::Sas(state)
}
} else {
InverseLink::Standard(
StandardLink::try_from(link_function)
.expect("state-bearing link without runtime state in hessian_cde_arrays"),
)
};
// Fast path applies only to Binomial + standard Logit with no mixture state.
let canonical_logit = {
let spec = reml_spec(&pirls_result.likelihood);
matches!(spec.response, ResponseFamily::Binomial)
&& matches!(spec.link, InverseLink::Standard(StandardLink::Logit))
} && self.runtime_mixture_link_state.is_none();
if canonical_logit {
use rayon::prelude::*;
let final_eta = &pirls_result.final_eta;
let weights = &self.weights;
let e_s = e_array.as_slice_mut().expect("e_array must be contiguous");
e_s.par_iter_mut().enumerate().for_each(|(i, e_o)| {
let eta_raw = final_eta[i];
if pirls::eta_clamp_active(&inverse_link, eta_raw) {
*e_o = 0.0;
} else {
let jet = crate::mixture_link::logit_inverse_link_jet5(eta_raw);
// Negative prior weights are treated as zero.
*e_o = weights[i].max(0.0) * jet.d4;
}
});
return Ok((c_array, d_array, e_array));
}
// General path: needs the stored mu-derivative arrays from the PIRLS solve.
let likelihood = &pirls_result.likelihood;
let weight_family = pirls::weight_family_for_glm_likelihood(likelihood);
let phi = reml_fixed_glm_dispersion(likelihood);
let dmu_deta = &pirls_result.solve_dmu_deta;
let d2mu_deta2 = &pirls_result.solve_d2mu_deta2;
let d3mu_deta3 = &pirls_result.solve_d3mu_deta3;
if dmu_deta.len() != n || d2mu_deta2.len() != n || d3mu_deta3.len() != n {
crate::bail_invalid_estim!(
"Tierney-Kadane e_obs requires populated solve_*mu_deta arrays (n={}, dmu={}, d2mu={}, d3mu={}); ensure PIRLS rehydration ran",
n,
dmu_deta.len(),
d2mu_deta2.len(),
d3mu_deta3.len(),
);
}
let mu = &pirls_result.solvemu;
if mu.len() != n {
crate::bail_invalid_estim!(
"Tierney-Kadane e_obs requires solvemu populated (n={}, len={})",
n,
mu.len(),
);
}
use rayon::prelude::*;
let final_eta = &pirls_result.final_eta;
let weights = &self.weights;
let y_view = &self.y;
let inverse_link_ref = &inverse_link;
let e_s = e_array.as_slice_mut().expect("e_array must be contiguous");
e_s.par_iter_mut()
.enumerate()
.try_for_each(|(i, e_o)| -> Result<(), EstimationError> {
let eta_raw = final_eta[i];
if pirls::eta_clamp_active(inverse_link_ref, eta_raw) {
*e_o = 0.0;
return Ok(());
}
// h1..h3 come from the stored solve arrays; h4/h5 are evaluated on demand.
let h1 = dmu_deta[i];
let h2 = d2mu_deta2[i];
let h3 = d3mu_deta3[i];
let h4 = crate::mixture_link::inverse_link_pdfthird_derivative_for_inverse_link(
inverse_link_ref,
eta_raw,
)?;
let h5 = crate::mixture_link::inverse_link_pdffourth_derivative_for_inverse_link(
inverse_link_ref,
eta_raw,
)?;
// Any non-finite derivative zeroes this observation's contribution.
if !h1.is_finite()
|| !h2.is_finite()
|| !h3.is_finite()
|| !h4.is_finite()
|| !h5.is_finite()
{
*e_o = 0.0;
return Ok(());
}
let mu_i = mu[i];
let vj = pirls::variance_jet_for_weight_family(weight_family, mu_i);
// A non-finite or non-positive variance also zeroes the contribution.
if !(vj.v.is_finite() && vj.v > 0.0) {
*e_o = 0.0;
return Ok(());
}
let pw = weights[i].max(0.0);
let y_i = y_view[i];
let e_i = pirls::e_obs_from_jets(y_i, mu_i, h1, h2, h3, h4, h5, vj, phi, pw);
*e_o = if e_i.is_finite() { e_i } else { 0.0 };
Ok(())
})?;
Ok((c_array, d_array, e_array))
}
/// Returns the `(c, d, e, f)` per-observation arrays, extending
/// `hessian_cde_arrays` with `f`.
///
/// `f[i]` is the non-negative part of the prior weight times the `d5` entry of
/// `logit_inverse_link_jet5`, or `0.0` when clamping `final_eta[i]` to
/// `±ETA_OVERFLOW_CLAMP` changes its value (which includes NaN).
///
/// # Errors
/// Propagates errors from `hessian_cde_arrays`, and fails unless the
/// likelihood is Binomial with a standard Logit link and no runtime mixture
/// link state is set.
///
/// # Panics
/// Panics if the freshly allocated `f_array` is not contiguous.
fn hessian_cdef_arrays(
    &self,
    pirls_result: &PirlsResult,
) -> Result<(Array1<f64>, Array1<f64>, Array1<f64>, Array1<f64>), EstimationError> {
    use rayon::prelude::*;

    let (c_array, d_array, e_array) = self.hessian_cde_arrays(pirls_result)?;

    let spec = reml_spec(&pirls_result.likelihood);
    let binomial_logit = matches!(spec.response, ResponseFamily::Binomial)
        && matches!(spec.link, InverseLink::Standard(StandardLink::Logit));
    let canonical_logit = binomial_logit && self.runtime_mixture_link_state.is_none();
    if !canonical_logit {
        crate::bail_invalid_estim!(
            "Tierney-Kadane outer Hessian is implemented for canonical Binomial Logit Firth fits only"
        );
    }

    let final_eta = &pirls_result.final_eta;
    let weights = &self.weights;
    let mut f_array = Array1::<f64>::zeros(e_array.len());
    f_array
        .as_slice_mut()
        .expect("f_array must be contiguous")
        .par_iter_mut()
        .enumerate()
        .for_each(|(i, f_o)| {
            let eta_raw = final_eta[i];
            let eta_used = eta_raw.clamp(-ETA_OVERFLOW_CLAMP, ETA_OVERFLOW_CLAMP);
            // Equality fails for clamped and for NaN inputs; both yield zero.
            *f_o = if eta_raw == eta_used {
                let jet = crate::mixture_link::logit_inverse_link_jet5(eta_used);
                weights[i].max(0.0) * jet.d5
            } else {
                0.0
            };
        });

    Ok((c_array, d_array, e_array, f_array))
}
/// Evaluates the soft log-cosh prior cost on `rho`:
/// `RHO_SOFT_PRIOR_WEIGHT * Σ ln cosh(RHO_SOFT_PRIOR_SHARPNESS * (rho_i - anchor) / RHO_BOUND)`,
/// where `anchor` is `self.rho_weight_anchor()`.
///
/// Returns `0.0` for an empty `rho` or a zero prior weight.
///
/// The per-coordinate term is evaluated without overflow: `cosh(x)` becomes
/// `+inf` once `|x|` exceeds roughly 710 even though `ln cosh(x)` is finite
/// there, so large arguments use the asymptotic form `|x| - ln 2`.
pub(super) fn compute_soft_priorcost(&self, rho: &Array1<f64>) -> f64 {
    // Above this magnitude `ln(1 + exp(-2|x|))` is far below f64 resolution,
    // so `|x| - ln 2` equals `ln cosh(x)` to working precision. The threshold
    // sits well below the point where `cosh` overflows, and below it the
    // direct formula is kept unchanged.
    const LN_COSH_ASYMPTOTIC_THRESHOLD: f64 = 350.0;

    let len = rho.len();
    if len == 0 || RHO_SOFT_PRIOR_WEIGHT == 0.0 {
        return 0.0;
    }
    let anchor = self.rho_weight_anchor();
    let inv_bound = 1.0 / RHO_BOUND;
    let sharp = RHO_SOFT_PRIOR_SHARPNESS;
    let mut cost = 0.0;
    for &ri in rho.iter() {
        let scaled = sharp * (ri - anchor) * inv_bound;
        let magnitude = scaled.abs();
        cost += if magnitude > LN_COSH_ASYMPTOTIC_THRESHOLD {
            magnitude - std::f64::consts::LN_2
        } else {
            // NaN inputs fall through here and propagate as before.
            scaled.cosh().ln()
        };
    }
    cost * RHO_SOFT_PRIOR_WEIGHT
}
/// Gradient of the soft log-cosh prior with respect to each `rho_i`:
/// `(sharp / bound) * tanh(sharp * (rho_i - anchor) / bound) * weight`.
///
/// Returns an all-zero vector of `rho.len()` entries when `rho` is empty or
/// the prior weight is zero.
pub(super) fn compute_soft_priorgrad(&self, rho: &Array1<f64>) -> Array1<f64> {
    if rho.len() == 0 || RHO_SOFT_PRIOR_WEIGHT == 0.0 {
        return Array1::<f64>::zeros(rho.len());
    }
    let anchor = self.rho_weight_anchor();
    let inv_bound = 1.0 / RHO_BOUND;
    let sharp = RHO_SOFT_PRIOR_SHARPNESS;
    let slope = sharp * inv_bound;
    rho.mapv(|ri| {
        let scaled = sharp * (ri - anchor) * inv_bound;
        slope * scaled.tanh() * RHO_SOFT_PRIOR_WEIGHT
    })
}
/// Adds the soft log-cosh prior curvature to the diagonal of `hess`:
/// `weight * a^2 * (1 - tanh^2(a * (rho_i - anchor)))` with
/// `a = RHO_SOFT_PRIOR_SHARPNESS / RHO_BOUND`.
///
/// Does nothing for an empty `rho` or a zero prior weight.
///
/// # Panics
/// Panics if `hess` has fewer than `rho.len()` rows or columns.
pub(super) fn add_soft_priorhessian_in_place(&self, rho: &Array1<f64>, hess: &mut Array2<f64>) {
    if rho.len() == 0 || RHO_SOFT_PRIOR_WEIGHT == 0.0 {
        return;
    }
    let anchor = self.rho_weight_anchor();
    let slope = RHO_SOFT_PRIOR_SHARPNESS / RHO_BOUND;
    let curvature_scale = RHO_SOFT_PRIOR_WEIGHT * slope * slope;
    for (i, &ri) in rho.iter().enumerate() {
        let t = (slope * (ri - anchor)).tanh();
        hess[[i, i]] += curvature_scale * (1.0 - t * t);
    }
}
/// Builds the soft-prior Hessian as a fresh `len x len` matrix.
///
/// Returns `None` for an empty `rho`, a zero prior weight, or when every
/// entry of the resulting matrix is exactly zero.
pub(super) fn compute_soft_priorhess(&self, rho: &Array1<f64>) -> Option<Array2<f64>> {
    let dim = rho.len();
    if dim == 0 || RHO_SOFT_PRIOR_WEIGHT == 0.0 {
        return None;
    }
    let mut hess = Array2::<f64>::zeros((dim, dim));
    self.add_soft_priorhessian_in_place(rho, &mut hess);
    let has_curvature = hess.iter().any(|&v| v != 0.0);
    has_curvature.then_some(hess)
}
/// Evaluates the configured rho prior (cost, gradient, optional Hessian) at
/// `rho`, shifted by the weight anchor.
///
/// The effective prior is evaluated at `rho - anchor` (or at `rho` itself when
/// the anchor is exactly zero) under the `Saturate` invalid-prior policy.
/// When the resulting cost is finite, each coordinate flagged by
/// `firth_default_coord_mask` has its PC-prior terms subtracted and the Firth
/// default barrier terms added, in cost, gradient and Hessian diagonal.
///
/// # Panics
/// Panics if `rho_prior_eval::evaluate` returns an error under `Saturate`.
fn evaluate_configured_rho_prior(
&self,
rho: &Array1<f64>,
) -> super::rho_prior_eval::RhoPriorEval {
let effective = self.effective_rho_prior();
let anchor = self.rho_weight_anchor();
// Only allocate a shifted copy when the anchor actually moves rho.
let rho_anchored = (anchor != 0.0).then(|| rho.mapv(|r| r - anchor));
let rho_eff: &Array1<f64> = rho_anchored.as_ref().unwrap_or(rho);
let mut eval = super::rho_prior_eval::evaluate(
effective.as_ref(),
rho_eff,
super::rho_prior_eval::InvalidPriorPolicy::Saturate,
)
.expect("Saturate policy never errors");
// A non-finite cost is returned as-is; no per-coordinate adjustment.
if eval.cost.is_finite() {
let mask = firth_default_coord_mask(&self.rho_prior, rho.len());
if mask.iter().any(|&d| d) {
let theta = super::rho_prior_eval::pc_prior_rate(
FIRTH_DEFAULT_PC_UPPER,
FIRTH_DEFAULT_PC_TAIL_PROB,
);
// Take the Hessian out (or start from zeros) so the diagonal can be edited.
let mut hess = eval
.hessian
.take()
.unwrap_or_else(|| Array2::<f64>::zeros((rho.len(), rho.len())));
for (idx, &is_default) in mask.iter().enumerate() {
if !is_default {
continue;
}
let r = rho_eff[idx];
let (pc_c, pc_g, pc_h) = super::rho_prior_eval::pc_prior_terms(theta, r);
let (b_c, b_g, b_h) = super::rho_prior_eval::firth_default_barrier_terms(
theta,
FIRTH_DEFAULT_PC_UPPER,
r,
);
// Swap the PC-prior contribution for the barrier contribution.
eval.cost += b_c - pc_c;
eval.gradient[idx] += b_g - pc_g;
hess[[idx, idx]] += b_h - pc_h;
}
// An all-zero Hessian is reported as `None`.
eval.hessian = hess.iter().any(|&v| v != 0.0).then_some(hess);
}
}
eval
}
/// Resolves the configured rho prior through `resolve_effective_rho_prior`,
/// returning it as a `Cow`.
fn effective_rho_prior(&self) -> std::borrow::Cow<'_, RhoPrior> {
    let configured = &self.rho_prior;
    resolve_effective_rho_prior(configured)
}
/// Returns half the sum of `ln(w)` over the strictly positive prior weights.
/// Zero, negative and NaN weights are skipped.
fn gaussian_weight_log_sum_half(&self) -> f64 {
    let log_sum: f64 = self
        .weights
        .iter()
        .copied()
        .filter(|&w| w > 0.0)
        .map(f64::ln)
        .sum();
    0.5 * log_sum
}
/// Returns the mean of `ln(w)` over the strictly positive prior weights, or
/// `0.0` when no weight is strictly positive.
fn rho_weight_anchor(&self) -> f64 {
    let (log_total, positive) = self
        .weights
        .iter()
        .filter(|&&w| w > 0.0)
        .fold((0.0_f64, 0usize), |(acc, seen), &w| (acc + w.ln(), seen + 1));
    match positive {
        0 => 0.0,
        n => log_total / n as f64,
    }
}
/// Cost component of `evaluate_configured_rho_prior(rho)`.
fn compute_configured_rho_prior_cost(&self, rho: &Array1<f64>) -> f64 {
    let eval = self.evaluate_configured_rho_prior(rho);
    eval.cost
}
/// Gradient component of `evaluate_configured_rho_prior(rho)`.
fn compute_configured_rho_prior_grad(&self, rho: &Array1<f64>) -> Array1<f64> {
    let eval = self.evaluate_configured_rho_prior(rho);
    eval.gradient
}
/// Hessian component of `evaluate_configured_rho_prior(rho)`; `None` when the
/// evaluation carries no Hessian.
fn compute_configured_rho_prior_hess(&self, rho: &Array1<f64>) -> Option<Array2<f64>> {
    let eval = self.evaluate_configured_rho_prior(rho);
    eval.hessian
}
/// Returns a dense copy of the stabilized transformed Hessian together with
/// the PIRLS ridge passport, provided the stabilized Hessian factorizes.
///
/// # Errors
/// Returns `EstimationError::ModelIsIllConditioned` with an infinite condition
/// number when factorization fails.
pub(super) fn effectivehessian(
    &self,
    pr: &PirlsResult,
) -> Result<(Array2<f64>, RidgePassport), EstimationError> {
    let stabilized = &pr.stabilizedhessian_transformed;
    let factorizable = stabilized.factorize().is_ok();
    if !factorizable {
        return Err(EstimationError::ModelIsIllConditioned {
            condition_number: f64::INFINITY,
        });
    }
    Ok((stabilized.to_dense(), pr.ridge_passport))
}
/// Convenience constructor: wraps the penalties and a clone of `config` in
/// `Arc`s and delegates to `newwith_offset_shared`.
///
/// # Errors
/// Propagates any error from `newwith_offset_shared`.
pub(crate) fn newwith_offset<X>(
    y: ArrayView1<'a, f64>,
    x: X,
    weights: ArrayView1<'a, f64>,
    offset: ArrayView1<'_, f64>,
    canonical_penalties: Vec<crate::construction::CanonicalPenalty>,
    p: usize,
    config: &'a RemlConfig,
    nullspace_dims: Option<Vec<usize>>,
    coefficient_lower_bounds: Option<Array1<f64>>,
    linear_constraints: Option<crate::pirls::LinearInequalityConstraints>,
) -> Result<Self, EstimationError>
where
    X: Into<DesignMatrix>,
{
    let shared_penalties = Arc::new(canonical_penalties);
    let shared_config = Arc::new(config.clone());
    Self::newwith_offset_shared(
        y,
        x,
        weights,
        offset,
        shared_penalties,
        p,
        shared_config,
        nullspace_dims,
        coefficient_lower_bounds,
        linear_constraints,
    )
}
/// Primary constructor: builds the state from shared (`Arc`) penalties and
/// configuration.
///
/// Reports penalty-pair redundancy, validates `nullspace_dims` (defaulting to
/// all zeros when `None`), precomputes the balanced penalty root, the
/// reparameterization invariant and the optional sparse penalty blocks, seeds
/// the runtime link states from `config.link_kind`, and initializes every
/// cache and warm-start slot to its empty/default value.
///
/// # Errors
/// Fails when `nullspace_dims` is provided with a length different from the
/// number of penalties, or when any of the three penalty precomputations
/// fails.
pub(crate) fn newwith_offset_shared<X>(
y: ArrayView1<'a, f64>,
x: X,
weights: ArrayView1<'a, f64>,
offset: ArrayView1<'_, f64>,
canonical_penalties: Arc<Vec<crate::construction::CanonicalPenalty>>,
p: usize,
config: Arc<RemlConfig>,
nullspace_dims: Option<Vec<usize>>,
coefficient_lower_bounds: Option<Array1<f64>>,
linear_constraints: Option<crate::pirls::LinearInequalityConstraints>,
) -> Result<Self, EstimationError>
where
X: Into<DesignMatrix>,
{
let x = x.into();
crate::construction::report_penalty_pair_redundancy(canonical_penalties.as_ref());
let expected_len = canonical_penalties.len();
// One null-space dimension per penalty; missing input means all zeros.
let nullspace_dims = match nullspace_dims {
Some(dims) => {
if dims.len() != expected_len {
crate::bail_invalid_estim!(
"nullspace_dims length {} does not match penalties {}",
dims.len(),
expected_len
);
}
dims
}
None => vec![0; expected_len],
};
let balanced_penalty_root =
create_balanced_penalty_root_from_canonical(&canonical_penalties, p)?;
let reparam_invariant =
precompute_reparam_invariant_from_canonical(&canonical_penalties, p)?;
let sparse_penalty_blocks =
build_sparse_penalty_blocks_from_canonical(canonical_penalties.as_ref(), p)?
.map(Arc::new);
// Runtime link states start from whatever the configured link kind carries.
let runtime_mixture_link_state = config.link_kind.mixture_state().cloned();
let runtime_sas_link_state = config.link_kind.sas_state().copied();
Ok(Self {
y,
x,
weights,
// The offset is copied; `y` and `weights` stay borrowed views.
offset: offset.to_owned(),
canonical_penalties,
balanced_penalty_root,
reparam_invariant,
sparse_penalty_blocks,
p,
config,
runtime_mixture_link_state,
runtime_sas_link_state,
nullspace_dims,
coefficient_lower_bounds,
linear_constraints,
penalty_shrinkage_floor: None,
rho_prior: RhoPrior::Flat,
cache_manager: EvalCacheManager::new(),
arena: RemlArena::new(),
warm_start_beta: RwLock::new(None),
warm_start_rho: RwLock::new(None),
prev_warm_start_beta: RwLock::new(None),
prev_warm_start_rho: RwLock::new(None),
warm_start_enabled: AtomicBool::new(true),
screening_max_inner_iterations: Arc::new(AtomicUsize::new(0)),
outer_inner_cap: Arc::new(AtomicUsize::new(0)),
last_inner_iters: Arc::new(AtomicUsize::new(0)),
last_inner_converged: Arc::new(AtomicBool::new(false)),
ift_warm_start_cache: RwLock::new(None),
last_pirls_lm_lambda: Arc::new(AtomicU64::new(0)),
// Both slots start at the "no signal" sentinel bit pattern.
last_ift_prediction_residual: Arc::new(AtomicU64::new(IFT_RESIDUAL_NO_SIGNAL_BITS)),
last_pirls_accept_rho: Arc::new(AtomicU64::new(IFT_RESIDUAL_NO_SIGNAL_BITS)),
ift_cached_factor: RwLock::new(None),
kronecker_penalty_system: None,
kronecker_factored: None,
gaussian_fixed_cache: RwLock::new(None),
alo_frozen_nuisance: RwLock::new(None),
persistent_warm_start_key: RwLock::new(None),
persistent_latent_values_fingerprint: None,
persistent_latent_values_cache: RwLock::new(PersistentLatentValuesCache::default()),
analytic_penalty_registry_fingerprint: 0,
persistent_warm_start_loaded: AtomicBool::new(false),
})
}
/// Replaces the design matrix, penalties, constraints and Kronecker
/// structures in place, then clears the caches and warm-start state listed
/// in the body.
///
/// All fallible work (length validation and the three penalty
/// precomputations) happens before any field is modified, so an `Err`
/// return leaves `self` unchanged.
///
/// # Errors
/// Fails when `nullspace_dims.len()` differs from the number of penalties or
/// when a penalty precomputation fails.
///
/// # Panics
/// Panics if one of the `RwLock`s written here is poisoned.
pub(in crate::solver::estimate) fn reset_surface<X>(
&mut self,
x: X,
canonical_penalties: Arc<Vec<crate::construction::CanonicalPenalty>>,
p: usize,
nullspace_dims: Vec<usize>,
coefficient_lower_bounds: Option<Array1<f64>>,
linear_constraints: Option<crate::pirls::LinearInequalityConstraints>,
kronecker_penalty_system: Option<crate::smooth::KroneckerPenaltySystem>,
kronecker_factored: Option<crate::basis::KroneckerFactoredBasis>,
) -> Result<(), EstimationError>
where
X: Into<DesignMatrix>,
{
let expected_len = canonical_penalties.len();
if nullspace_dims.len() != expected_len {
crate::bail_invalid_estim!(
"nullspace_dims length {} does not match penalties {}",
nullspace_dims.len(),
expected_len
);
}
// Precompute everything that can fail before touching `self`.
let balanced_penalty_root =
create_balanced_penalty_root_from_canonical(&canonical_penalties, p)?;
let reparam_invariant =
precompute_reparam_invariant_from_canonical(&canonical_penalties, p)?;
let sparse_penalty_blocks =
build_sparse_penalty_blocks_from_canonical(canonical_penalties.as_ref(), p)?
.map(Arc::new);
// Commit the new surface definition.
self.x = x.into();
self.canonical_penalties = canonical_penalties;
self.balanced_penalty_root = balanced_penalty_root;
self.reparam_invariant = reparam_invariant;
self.sparse_penalty_blocks = sparse_penalty_blocks;
self.p = p;
self.nullspace_dims = nullspace_dims;
self.coefficient_lower_bounds = coefficient_lower_bounds;
self.linear_constraints = linear_constraints;
self.kronecker_penalty_system = kronecker_penalty_system;
self.kronecker_factored = kronecker_factored;
// Clear cached and warm-start state now that the surface has changed.
*self.gaussian_fixed_cache.write().unwrap() = None;
*self.alo_frozen_nuisance.write().unwrap() = None;
*self.persistent_warm_start_key.write().unwrap() = None;
self.persistent_warm_start_loaded
.store(false, Ordering::Relaxed);
self.cache_manager.clear_eval_and_factor_caches();
self.cache_manager.pirls_cache.write().unwrap().clear();
self.clear_warm_start_predictor_state();
self.clear_warm_start_adaptive_signals();
self.reset_hypergradient_budget_controller();
Ok(())
}
/// Installs `system` as the Kronecker penalty system, replacing any
/// previously stored one.
pub(crate) fn set_kronecker_penalty_system(
    &mut self,
    system: crate::smooth::KroneckerPenaltySystem,
) {
    let installed = Some(system);
    self.kronecker_penalty_system = installed;
}
/// Installs `factored` as the Kronecker-factored basis, replacing any
/// previously stored one.
pub(crate) fn set_kronecker_factored(
    &mut self,
    factored: crate::basis::KroneckerFactoredBasis,
) {
    let installed = Some(factored);
    self.kronecker_factored = installed;
}
/// Sets the penalty shrinkage floor and resets the persistent warm-start key
/// and its "loaded" flag. The reset happens on every call, including when
/// `floor` equals the current value.
///
/// # Panics
/// Panics if the persistent warm-start key lock is poisoned.
pub(crate) fn set_penalty_shrinkage_floor(&mut self, floor: Option<f64>) {
    self.penalty_shrinkage_floor = floor;
    {
        let mut key_slot = self.persistent_warm_start_key.write().unwrap();
        *key_slot = None;
    }
    self.persistent_warm_start_loaded
        .store(false, Ordering::Relaxed);
}
/// Sets the rho prior and resets the persistent warm-start key and its
/// "loaded" flag. The reset happens on every call.
///
/// # Panics
/// Panics if the persistent warm-start key lock is poisoned.
pub(crate) fn set_rho_prior(&mut self, prior: RhoPrior) {
    self.rho_prior = prior;
    {
        let mut key_slot = self.persistent_warm_start_key.write().unwrap();
        *key_slot = None;
    }
    self.persistent_warm_start_loaded
        .store(false, Ordering::Relaxed);
}
/// Records a new analytic-penalty registry fingerprint.
///
/// When the fingerprint actually changes, the persistent warm-start key and
/// its "loaded" flag are reset; an unchanged fingerprint is a no-op.
///
/// # Panics
/// Panics if the persistent warm-start key lock is poisoned.
pub(in crate::solver::estimate) fn set_analytic_penalty_registry_fingerprint(
    &mut self,
    fingerprint: u64,
) {
    if self.analytic_penalty_registry_fingerprint != fingerprint {
        self.analytic_penalty_registry_fingerprint = fingerprint;
        *self.persistent_warm_start_key.write().unwrap() = None;
        self.persistent_warm_start_loaded
            .store(false, Ordering::Relaxed);
    }
}
/// Stores `fingerprint` as the persistent latent-values fingerprint.
pub(crate) fn set_persistent_latent_values_fingerprint(&mut self, fingerprint: u64) {
    let recorded = Some(fingerprint);
    self.persistent_latent_values_fingerprint = recorded;
}
/// Derives the cache key for `rho` by delegating to
/// `EvalCacheManager::sanitized_rhokey`.
pub(super) fn rhokey_sanitized(&self, rho: &Array1<f64>) -> Option<Vec<u64>> {
    let key = EvalCacheManager::sanitized_rhokey(rho);
    key
}
/// Builds an evaluation bundle for `rho` using the geometry chosen by
/// `select_reml_geometry`.
///
/// For the sparse-exact geometry the sparse builder is tried first; on
/// failure a warning is logged and the dense spectral builder is used with
/// the same `key`.
///
/// # Errors
/// Returns the dense builder's error when the dense path (direct or fallback)
/// fails.
pub(super) fn prepare_eval_bundlewithkey(
    &self,
    rho: &Array1<f64>,
    key: Option<Vec<u64>>,
) -> Result<EvalShared, EstimationError> {
    let decision = self.select_reml_geometry(rho);
    match decision.geometry {
        RemlGeometry::DenseSpectral => self.prepare_dense_eval_bundlewithkey(rho, key),
        RemlGeometry::SparseExactSpd => {
            // The key is cloned so it remains available for the dense fallback.
            let failure = match self.prepare_sparse_eval_bundlewithkey(rho, key.clone()) {
                Ok(bundle) => {
                    log::info!(
                        "[reml-geometry] sparse_exact_spd reason={} p={} nnz_x={} nnz_h_est={} density_h_est={}",
                        decision.reason,
                        decision.p,
                        decision.nnz_x,
                        decision
                            .nnz_h_upper_est
                            .map_or_else(|| "na".to_string(), |v| v.to_string()),
                        decision
                            .density_h_upper_est
                            .map_or_else(|| "na".to_string(), |v| format!("{v:.4}")),
                    );
                    return Ok(bundle);
                }
                Err(err) => err,
            };
            log::warn!(
                "[reml-geometry] sparse_exact_spd failed ({}); falling back to dense spectral",
                failure
            );
            self.prepare_dense_eval_bundlewithkey(rho, key)
        }
    }
}
/// Returns the evaluation bundle for `rho`, reusing the cached bundle when
/// the cache manager has one for the sanitized key, and otherwise building a
/// fresh one and storing a clone in the cache.
///
/// # Errors
/// Propagates any error from `prepare_eval_bundlewithkey`.
pub(crate) fn obtain_eval_bundle(
    &self,
    rho: &Array1<f64>,
) -> Result<EvalShared, EstimationError> {
    let key = self.rhokey_sanitized(rho);
    if let Some(hit) = self.cache_manager.cached_eval_bundle(&key) {
        return Ok(hit.clone());
    }
    self.prepare_eval_bundlewithkey(rho, key).map(|fresh| {
        self.cache_manager.store_eval_bundle(fresh.clone());
        fresh
    })
}
/// Like `obtain_eval_bundle`, except that the cache key is derived from the
/// full outer vector `theta` while the bundle itself is built from `rho`.
///
/// # Errors
/// Propagates any error from `prepare_eval_bundlewithkey`.
pub(crate) fn obtain_eval_bundle_for_outer_theta(
    &self,
    rho: &Array1<f64>,
    theta: &Array1<f64>,
) -> Result<EvalShared, EstimationError> {
    let key = self.rhokey_sanitized(theta);
    if let Some(hit) = self.cache_manager.cached_eval_bundle(&key) {
        return Ok(hit.clone());
    }
    self.prepare_eval_bundlewithkey(rho, key).map(|fresh| {
        self.cache_manager.store_eval_bundle(fresh.clone());
        fresh
    })
}
/// Returns the dense inner Hessian of the objective at `rho`.
///
/// When the bundle carries a sparse-exact factor, the Hessian is assembled
/// from that factor and must be `p x p`; otherwise a copy of the bundle's
/// `h_total` is returned.
///
/// # Errors
/// Propagates bundle-construction and assembly errors, and fails when the
/// assembled sparse Hessian is not `p x p`.
pub(crate) fn objective_innerhessian(
    &self,
    rho: &Array1<f64>,
) -> Result<Array2<f64>, EstimationError> {
    let bundle = self.obtain_eval_bundle(rho)?;
    let Some(sparse) = bundle.sparse_exact.as_ref() else {
        return Ok(bundle.h_total.as_ref().clone());
    };
    let dense = crate::linalg::sparse_exact::assemble_sparse_factor_h_dense(&sparse.factor)?;
    let (rows, cols) = dense.dim();
    if (rows, cols) != (self.p, self.p) {
        crate::bail_invalid_estim!(
            "sparse exact objective inner Hessian shape {}x{} != {}x{}",
            rows,
            cols,
            self.p,
            self.p
        );
    }
    Ok(dense)
}
/// Returns the Euclidean norm of the gradient of the cached outer evaluation,
/// provided that evaluation belongs to a different key than `current_key`.
///
/// Returns `None` when nothing is cached, when `current_key` is `Some` and
/// equals the cached key, or when the norm is not finite.
///
/// # Panics
/// Panics if the `current_outer_eval` lock is poisoned.
fn previous_outer_gradient_norm(&self, current_key: &Option<Vec<u64>>) -> Option<f64> {
    let guard = self.cache_manager.current_outer_eval.read().unwrap();
    let (cached_key, eval) = guard.as_ref()?;
    if let Some(current_key) = current_key.as_ref() {
        // The cached evaluation is for the current point, not a previous one.
        if current_key == cached_key {
            return None;
        }
    }
    let sum_sq: f64 = eval.gradient.iter().map(|g| g * g).sum();
    let norm = sum_sq.sqrt();
    if norm.is_finite() && norm >= 0.0 {
        Some(norm)
    } else {
        None
    }
}
/// Builds a basis for the directions left free by the active linear
/// constraints at the PIRLS solution, in the transformed coordinate frame.
///
/// A constraint row `i` counts as active when its slack
/// `a_i · beta - b_i` is at most `ACTIVE_CONSTRAINT_SLACK_TOL`.
///
/// Returns `None` when there are no transformed constraints, no active rows,
/// or the active rows have numerical rank zero. Returns a `p_t x 0` matrix
/// when the active rows span the whole space. Otherwise returns a matrix
/// whose columns are unit vectors orthogonalized against the active-row basis
/// and against each other (assuming `orthonormalize_columns` returns
/// orthonormal columns — TODO confirm).
pub(super) fn active_constraint_free_basis(&self, pr: &PirlsResult) -> Option<Array2<f64>> {
let lin = pr.linear_constraints_transformed.as_ref()?;
let beta_t = pr.beta_transformed.as_ref();
// Collect the constraint rows that are active at beta.
let mut activerows: Vec<Array1<f64>> = Vec::new();
for i in 0..lin.a.nrows() {
let slack = lin.a.row(i).dot(beta_t) - lin.b[i];
if slack <= ACTIVE_CONSTRAINT_SLACK_TOL {
activerows.push(lin.a.row(i).to_owned());
}
}
if activerows.is_empty() {
return None;
}
// Stack the active rows as columns of a p_t x (#active) matrix.
let p_t = lin.a.ncols();
let mut a_t = Array2::<f64>::zeros((p_t, activerows.len()));
for (j, row) in activerows.iter().enumerate() {
for k in 0..p_t {
a_t[[k, j]] = row[k];
}
}
let qrow = Self::orthonormalize_columns(&a_t, ORTHONORM_DROP_TOL); let rank = qrow.ncols();
if rank == 0 {
return None;
}
// Active rows span everything: no free directions remain.
if rank >= p_t {
return Some(Array2::<f64>::zeros((p_t, 0)));
}
let mut z = Array2::<f64>::zeros((p_t, p_t - rank));
let mut kept = 0usize;
// Sweep the unit vectors e_j, removing components along the active-row
// basis and along the columns already accepted into `z`.
for j in 0..p_t {
let mut v = Array1::<f64>::zeros(p_t);
v[j] = 1.0;
for t in 0..rank {
let qt = qrow.column(t);
let proj = qt.dot(&v);
v -= &qt.mapv(|x| x * proj);
}
for t in 0..kept {
let zt = z.column(t);
let proj = zt.dot(&v);
v -= &zt.mapv(|x| x * proj);
}
// Keep the residual only if it is numerically non-zero; normalize it.
let nrm = v.dot(&v).sqrt();
if nrm > ORTHONORM_DROP_TOL {
z.column_mut(kept).assign(&v.mapv(|x| x / nrm));
kept += 1;
if kept == p_t - rank {
break;
}
}
}
// Fewer than `p_t - rank` columns may have been kept; return only those.
Some(z.slice(ndarray::s![.., 0..kept]).to_owned())
}
/// Builds the barrier configuration for `constraints`, returning `None` when
/// `BarrierConfig::from_constraints` yields none.
///
/// When trace logging is enabled, a diagnostic probe is also evaluated: each
/// constrained coefficient is placed `DIAGNOSTIC_BOUND_SLACK` away from its
/// bound and the result of `barrier_curvature_is_significant` is logged. The
/// probe only feeds the trace message, so it is skipped entirely when trace
/// logging is disabled.
fn barrier_config_from_constraints(
    constraints: &crate::pirls::LinearInequalityConstraints,
) -> Option<super::unified::BarrierConfig> {
    let config = super::unified::BarrierConfig::from_constraints(Some(constraints))?;
    if log::log_enabled!(log::Level::Trace) {
        const DIAGNOSTIC_BOUND_SLACK: f64 = 0.01;
        const DIAGNOSTIC_BETA_MAGNITUDE: f64 = 1.0;
        const DIAGNOSTIC_CURVATURE_REL_THRESHOLD: f64 = 0.05;
        // The probe vector only needs to reach the largest constrained index.
        let max_idx = config
            .constrained_indices
            .iter()
            .max()
            .copied()
            .unwrap_or(0);
        let mut beta_test = Array1::<f64>::zeros(max_idx + 1);
        for ((&idx, &rhs), &sign) in config
            .constrained_indices
            .iter()
            .zip(config.lower_bounds.iter())
            .zip(config.bound_signs.iter())
        {
            beta_test[idx] = (rhs + DIAGNOSTIC_BOUND_SLACK) / sign;
        }
        let significant = config.barrier_curvature_is_significant(
            &beta_test,
            DIAGNOSTIC_BETA_MAGNITUDE,
            DIAGNOSTIC_CURVATURE_REL_THRESHOLD,
        );
        log::trace!(
            "[barrier] curvature significant={significant} (tau={:.2e}, n_constrained={})",
            config.tau,
            config.constrained_indices.len(),
        );
    }
    Some(config)
}
/// Checks the constraint KKT residuals recorded on `pr` against their
/// tolerances.
///
/// Succeeds when no KKT record is present or when every residual is within
/// tolerance. The stationarity tolerance is relaxed to
/// `ACTIVE_SET_KKT_DEGENERATE_STATIONARITY_TOL` when the working set is rank
/// deficient.
///
/// NOTE(review): the checks use `>`, so a NaN residual compares false and is
/// accepted — confirm this is intended.
///
/// # Errors
/// Returns `EstimationError::ParameterConstraintViolation` describing the
/// residuals and, when a transformed constraint row is violated, the
/// worst-violated row.
pub(super) fn enforce_constraint_kkt(&self, pr: &PirlsResult) -> Result<(), EstimationError> {
    let Some(kkt) = pr.constraint_kkt.as_ref() else {
        return Ok(());
    };
    let degenerate = kkt.working_set_rank_deficient;
    let stationarity_tol = if degenerate {
        crate::solver::active_set::ACTIVE_SET_KKT_DEGENERATE_STATIONARITY_TOL
    } else {
        KKT_TOL_STAT
    };
    let violated = kkt.primal_feasibility > KKT_TOL_PRIMAL
        || kkt.dual_feasibility > KKT_TOL_DUAL
        || kkt.complementarity > KKT_TOL_COMP
        || kkt.stationarity > stationarity_tol;
    if !violated {
        return Ok(());
    }

    // Locate the most violated constraint row (first one wins on ties).
    let worstrow_msg = pr
        .linear_constraints_transformed
        .as_ref()
        .map(|lin| {
            (0..lin.a.nrows()).fold((0usize, 0.0_f64), |(row, worst), i| {
                let slack = lin.a.row(i).dot(&pr.beta_transformed.0) - lin.b[i];
                let viol = (-slack).max(0.0);
                if viol > worst { (i, viol) } else { (row, worst) }
            })
        })
        .filter(|&(_, worst)| worst > 0.0)
        .map(|(row, worst)| format!("; worstrow={} worstviolation={:.3e}", row, worst))
        .unwrap_or_default();
    let degenerate_note = if degenerate { ", degenerate face" } else { "" };

    Err(EstimationError::ParameterConstraintViolation(format!(
        "KKT residuals exceed tolerance: primal={:.3e}, dual={:.3e}, comp={:.3e}, stat={:.3e} (tol={:.3e}{}); active={}/{}{}",
        kkt.primal_feasibility,
        kkt.dual_feasibility,
        kkt.complementarity,
        kkt.stationarity,
        stationarity_tol,
        degenerate_note,
        kkt.n_active,
        kkt.n_constraints,
        worstrow_msg
    )))
}
/// Projects `matrix` onto the column basis `z`, computing `zᵀ · matrix · z`
/// as `fast_ab(fast_atb(z, matrix), z)`.
pub(super) fn projectwith_basis(matrix: &Array2<f64>, z: &Array2<f64>) -> Array2<f64> {
    crate::faer_ndarray::fast_ab(&crate::faer_ndarray::fast_atb(z, matrix), z)
}
/// Computes, via the cache manager, the eigen-decomposition of
/// `S = EᵀE (+ ridge·I)` and the penalty rank.
///
/// An empty `e_transformed` (no rows or no columns) yields zero
/// eigenvalues/eigenvectors and rank 0. Otherwise the rank is the sum of the
/// canonical penalty ranks capped at `p`, or — with no canonical penalties —
/// the positive-eigenvalue rank reported by `positive_penalty_rank_and_logdet`.
///
/// # Errors
/// Returns `EstimationError::EigendecompositionFailed` when `eigh` fails.
pub(super) fn compute_penalty_subspace(
    &self,
    e_transformed: &Array2<f64>,
    ridge_passport: RidgePassport,
) -> Result<PenaltySubspace, EstimationError> {
    let p = e_transformed.ncols();
    if e_transformed.nrows() == 0 || p == 0 {
        return Ok(PenaltySubspace {
            evals: Array1::zeros(p),
            evecs: Array2::zeros((p, p)),
            rank: 0,
        });
    }
    let cached =
        self.cache_manager
            .cached_penalty_subspace(e_transformed, &ridge_passport, || {
                let mut s_lambda = e_transformed.t().dot(e_transformed);
                let ridge = ridge_passport.penalty_logdet_ridge();
                if ridge > 0.0 {
                    s_lambda.diag_mut().iter_mut().for_each(|d| *d += ridge);
                }
                let (evals, evecs) = s_lambda
                    .eigh(Side::Lower)
                    .map_err(EstimationError::EigendecompositionFailed)?;
                let rank = if !self.canonical_penalties.is_empty() {
                    let structural: usize = self
                        .canonical_penalties
                        .iter()
                        .map(crate::construction::CanonicalPenalty::rank)
                        .sum();
                    structural.min(p)
                } else {
                    positive_penalty_rank_and_logdet(evals.as_slice().unwrap()).0
                };
                Ok(PenaltySubspace { evals, evecs, rank })
            })?;
    // Hand back an owned copy of the cached decomposition.
    Ok(PenaltySubspace {
        evals: cached.evals.clone(),
        evecs: cached.evecs.clone(),
        rank: cached.rank,
    })
}
/// Returns `(rank, log-determinant)` for the penalty subspace.
///
/// Rank 0 or an empty spectrum yields `(0, 0.0)`. With no canonical penalties
/// the result is delegated to `positive_penalty_rank_and_logdet`. Otherwise
/// the log-determinant sums `ln(ev)` over the strictly positive eigenvalues
/// among the last `rank` entries of the spectrum.
///
/// # Panics
/// Panics if the eigenvalue array is not contiguous.
fn fixed_subspace_penalty_rank_and_logdet_from_subspace(
    &self,
    penalty_subspace: &PenaltySubspace,
) -> (usize, f64) {
    let rank = penalty_subspace.rank;
    if rank == 0 || penalty_subspace.evals.is_empty() {
        return (0, 0.0);
    }
    let spectrum = penalty_subspace.evals.as_slice().unwrap();
    if self.canonical_penalties.is_empty() {
        return positive_penalty_rank_and_logdet(spectrum);
    }
    let first_kept = spectrum.len().saturating_sub(rank);
    let log_det: f64 = spectrum[first_kept..]
        .iter()
        .filter(|&&ev| ev > 0.0)
        .map(|&ev| ev.ln())
        .sum();
    (rank, log_det)
}
/// Projects `h_total` onto the last `rank` penalty eigenvectors and returns
/// the pseudo log-determinant of the projected matrix together with the
/// projection basis `u_s` and the pseudo-inverse of the projected matrix.
///
/// Returns `(0.0, None)` when `h_total` has no columns or the rank is zero.
/// Eigenvalues of the projected matrix at or below
/// `positive_eigenvalue_threshold` are excluded from the pseudo-inverse.
///
/// # Errors
/// Fails when `h_total` is not square, when the penalty eigenvector matrix is
/// not `p x p`, or when the eigen-decomposition of the projection fails.
pub(super) fn fixed_subspace_hessian_projected_parts(
    &self,
    h_total: &Array2<f64>,
    penalty_subspace: &PenaltySubspace,
) -> Result<(f64, Option<super::unified::PenaltySubspaceTrace>), EstimationError> {
    let p = h_total.ncols();
    if p == 0 {
        return Ok((0.0, None));
    }
    if h_total.nrows() != p {
        crate::bail_invalid_estim!(
            "fixed_subspace_hessian_projected_parts: H must be square, got {}x{}",
            h_total.nrows(),
            p
        );
    }
    let evecs = &penalty_subspace.evecs;
    if evecs.nrows() != p || evecs.ncols() != p {
        crate::bail_invalid_estim!(
            "fixed_subspace_hessian_projected_parts: penalty eigenspace dim {}x{} does not match H dim {}",
            evecs.nrows(),
            evecs.ncols(),
            p
        );
    }
    let r = penalty_subspace.rank;
    if r == 0 {
        return Ok((0.0, None));
    }

    // Copy the trailing `r` eigenvectors into a fresh p x r basis.
    let mut u_s = Array2::<f64>::zeros((p, r));
    for (out_col, src_col) in (p - r..p).enumerate() {
        u_s.column_mut(out_col).assign(&evecs.column(src_col));
    }

    // Projected matrix uᵀ H u, symmetrized before decomposition.
    let h_times_u = crate::faer_ndarray::fast_ab(h_total, &u_s);
    let mut h_proj = crate::faer_ndarray::fast_atb(&u_s, &h_times_u);
    enforce_symmetry(&mut h_proj);
    let (h_proj_evals, h_proj_evecs) = h_proj
        .eigh(Side::Lower)
        .map_err(EstimationError::EigendecompositionFailed)?;
    let spectrum = h_proj_evals.as_slice().unwrap();
    let h_thr = super::unified::positive_eigenvalue_threshold(spectrum);
    let log_det = super::unified::exact_pseudo_logdet(spectrum, h_thr);

    // Accumulate the pseudo-inverse one retained eigenpair at a time.
    let mut h_proj_inverse = Array2::<f64>::zeros((r, r));
    for (a, &sigma) in h_proj_evals.iter().enumerate() {
        if sigma <= h_thr {
            continue;
        }
        let inv = 1.0 / sigma;
        let mode = h_proj_evecs.column(a);
        for i in 0..r {
            let scaled = inv * mode[i];
            for j in 0..r {
                h_proj_inverse[[i, j]] += scaled * mode[j];
            }
        }
    }

    let trace_parts = super::unified::PenaltySubspaceTrace {
        u_s,
        h_proj_inverse,
    };
    Ok((log_det, Some(trace_parts)))
}
/// Computes `Σ uᵀ · S_direction · u / ev` over the last `rank` eigenpairs
/// `(ev, u)` of the penalty subspace.
///
/// Returns `0.0` when the rank is zero or the spectrum is empty.
///
/// # Errors
/// Fails when `s_direction` is not `p_dim x p_dim`.
pub(super) fn fixed_subspace_penalty_trace_from_subspace(
    &self,
    penalty_subspace: &PenaltySubspace,
    s_direction: &Array2<f64>,
) -> Result<f64, EstimationError> {
    let p_dim = penalty_subspace.evals.len();
    if penalty_subspace.rank == 0 || p_dim == 0 {
        return Ok(0.0);
    }
    if s_direction.nrows() != p_dim || s_direction.ncols() != p_dim {
        crate::bail_invalid_estim!(
            "fixed_subspace_penalty_trace_from_subspace: S_direction must be {}x{}, got {}x{}",
            p_dim,
            p_dim,
            s_direction.nrows(),
            s_direction.ncols()
        );
    }
    let first_kept = p_dim - penalty_subspace.rank;
    let trace = (first_kept..p_dim).fold(0.0, |acc, idx| {
        let ev = penalty_subspace.evals[idx];
        // Owned copy of the eigenvector, as in the direct loop form.
        let u = penalty_subspace.evecs.column(idx).to_owned();
        let spsi_u = s_direction.dot(&u);
        acc + u.dot(&spsi_u) / ev
    });
    Ok(trace)
}
/// Refreshes the in-memory warm-start state from a finished PIRLS solve.
///
/// Does nothing when warm starts are disabled. For a `Converged` or
/// `StalledAtValidMinimum` result it:
/// 1. maps the solution back to the original coefficient basis,
/// 2. shifts the current (β, ρ) slots into the "previous" slots and stores the
///    new β (the current ρ slot is left empty here),
/// 3. rebuilds the IFT warm-start cache entry (with an empty `rho`), and
/// 4. drops the cached IFT factor and the IFT mode-response cache.
///
/// For any other status the warm-start predictor state is cleared.
pub(super) fn updatewarm_start_from(&self, pr: &PirlsResult) {
    if !self.warm_start_enabled.load(Ordering::Relaxed) {
        return;
    }
    match pr.status {
        pirls::PirlsStatus::Converged | pirls::PirlsStatus::StalledAtValidMinimum => {
            // One pass over the coordinate frame yields both the flag and
            // the coefficients expressed in the original basis.
            let (frame_was_original, beta_original) = match pr.coordinate_frame {
                pirls::PirlsCoordinateFrame::OriginalSparseNative => {
                    (true, pr.beta_transformed.as_ref().clone())
                }
                pirls::PirlsCoordinateFrame::TransformedQs => (
                    false,
                    pr.reparam_result.qs.dot(pr.beta_transformed.as_ref()),
                ),
            };
            {
                // All four slots are locked together (in this fixed order)
                // so the rotation current -> previous is observed atomically.
                let mut prev_beta_w = self.prev_warm_start_beta.write().unwrap();
                let mut prev_rho_w = self.prev_warm_start_rho.write().unwrap();
                let mut cur_beta_w = self.warm_start_beta.write().unwrap();
                let mut cur_rho_w = self.warm_start_rho.write().unwrap();
                *prev_beta_w = cur_beta_w.take();
                *prev_rho_w = cur_rho_w.take();
                cur_beta_w.replace(Coefficients::new(beta_original.clone()));
            }
            // Per-penalty products `local · (β_block − prior_mean)`; `None`
            // when there are no canonical penalties.
            let lambda_s_beta_blocks: Option<Vec<ndarray::Array1<f64>>> = {
                use rayon::prelude::*;
                let blocks: Vec<ndarray::Array1<f64>> = self
                    .canonical_penalties
                    .par_iter()
                    .map(|cp| {
                        let r = &cp.col_range;
                        let beta_block = beta_original.slice(s![r.start..r.end]);
                        let centered = &beta_block - &cp.prior_mean;
                        cp.local.dot(&centered)
                    })
                    .collect();
                if blocks.is_empty() {
                    None
                } else {
                    Some(blocks)
                }
            };
            {
                let mut cache_w = self.ift_warm_start_cache.write().unwrap();
                cache_w.replace(super::IftWarmStartCache {
                    beta_original,
                    // Filled in later by `record_warm_start_rho`.
                    rho: ndarray::Array1::zeros(0),
                    penalized_hessian_transformed: pr.penalized_hessian_transformed.clone(),
                    qs: pr.reparam_result.qs.clone(),
                    frame_was_original,
                    lambda_s_beta_blocks,
                });
            }
            // The cached factor and mode responses belong to the old Hessian.
            self.ift_cached_factor.write().unwrap().take();
            self.clear_ift_mode_response_cache();
        }
        _ => {
            self.clear_warm_start_predictor_state();
        }
    }
}
/// Stores `rho` as the current warm-start smoothing parameters and, when an
/// IFT warm-start cache entry exists, stamps the same `rho` onto it.
///
/// No-op when warm starts are disabled.
pub(super) fn record_warm_start_rho(&self, rho: &Array1<f64>) {
    let enabled = self.warm_start_enabled.load(Ordering::Relaxed);
    if enabled {
        *self.warm_start_rho.write().unwrap() = Some(rho.clone());
        let mut ift_cache = self.ift_warm_start_cache.write().unwrap();
        if let Some(entry) = ift_cache.as_mut() {
            entry.rho = rho.clone();
        }
    }
}
/// Opens the outer-loop cache session keyed by the persistent warm-start key.
///
/// Yields `None` when warm starts are disabled, when no key can be derived,
/// or when `open_outer_session` itself returns `None`.
pub(crate) fn outer_cache_session(&self) -> Option<std::sync::Arc<crate::cache::Session>> {
    if self.warm_start_enabled.load(Ordering::Relaxed) {
        self.persistent_warm_start_cache_key()
            .and_then(|key| crate::solver::persistent_warm_start::open_outer_session(&key))
    } else {
        None
    }
}
/// Returns the hex fingerprint used as the persistent warm-start cache key.
///
/// The key is memoized in `self.persistent_warm_start_key`: a stored value is
/// returned as-is, otherwise the fingerprint is computed, stored and returned.
/// Returns `None` (without memoizing) when hashing the design matrix fails.
///
/// NOTE: the order of the `hasher.write_*` / `hash_*` calls below defines the
/// key; reordering or adding entries changes every key.
fn persistent_warm_start_cache_key(&self) -> Option<String> {
// Fast path: key already computed.
if let Some(key) = self.persistent_warm_start_key.read().unwrap().clone() {
return Some(key);
}
let mut hasher = Fingerprinter::new();
// Version / schema tags.
hasher.write_str("gamfit-persistent-warm-start-v2");
hasher.write_str(&crate::solver::persistent_warm_start::cache_schema_tag());
// Problem dimensions and solver configuration.
hasher.write_usize(self.y.len());
hasher.write_usize(self.p);
hasher.write_str(&format!("{:?}", self.config.likelihood));
hasher.write_str(&format!("{:?}", self.config.link_kind));
hasher.write_f64(self.config.pirls_convergence_tolerance);
hasher.write_f64(self.config.reml_convergence_tolerance);
hasher.write_usize(self.config.max_iterations);
hasher.write_bool(self.config.firth_bias_reduction);
// Runtime link state (hashed via its Debug representation).
hasher.write_str(&format!("{:?}", self.runtime_mixture_link_state));
hasher.write_str(&format!("{:?}", self.runtime_sas_link_state));
// Optional values are encoded as a presence flag followed by the payload.
match self.penalty_shrinkage_floor {
Some(value) => {
hasher.write_bool(true);
hasher.write_f64(value);
}
None => hasher.write_bool(false),
}
hasher.write_str(&format!("{:?}", self.rho_prior));
// Data: response, weights, offset, design matrix.
hash_array_view(&mut hasher, self.y);
hash_array_view(&mut hasher, self.weights);
hash_array_view(&mut hasher, self.offset.view());
if hash_design_matrix(&mut hasher, &self.x).is_err() {
return None;
}
// Penalty structure.
hash_canonical_penalties(&mut hasher, self.canonical_penalties.as_ref());
hasher.write_u64(self.analytic_penalty_registry_fingerprint);
hasher.write_usize(self.nullspace_dims.len());
for &dim in &self.nullspace_dims {
hasher.write_usize(dim);
}
// Constraints (presence flag + payload).
match self.coefficient_lower_bounds.as_ref() {
Some(bounds) => {
hasher.write_bool(true);
hash_array_view(&mut hasher, bounds.view());
}
None => hasher.write_bool(false),
}
match self.linear_constraints.as_ref() {
Some(constraints) => {
hasher.write_bool(true);
hash_array2(&mut hasher, &constraints.a);
hash_array_view(&mut hasher, constraints.b.view());
}
None => hasher.write_bool(false),
}
// Only the presence of the Kronecker structures is hashed, not their contents.
hasher.write_bool(self.kronecker_penalty_system.is_some());
hasher.write_bool(self.kronecker_factored.is_some());
let key = hasher.finish_hex();
// Memoize for subsequent calls.
self.persistent_warm_start_key
.write()
.unwrap()
.replace(key.clone());
Some(key)
}
/// Builds the cache key for persisted latent values by combining the
/// persistent warm-start key with the latent fingerprint (16 hex digits).
///
/// `None` when either the latent fingerprint or the warm-start key is
/// unavailable.
fn persistent_latent_values_cache_key(&self) -> Option<String> {
    let latent_fingerprint = self.persistent_latent_values_fingerprint?;
    let key = self.persistent_warm_start_cache_key()?;
    Some(format!(
        "persistent-latent-values-v2:{key}:{latent_fingerprint:016x}"
    ))
}
/// Looks up previously stored latent values for this problem.
///
/// `n_obs` and `latent_dim` are forwarded to the cache's `lookup`. Returns
/// `None` when warm starts are disabled, no cache key can be derived, or the
/// lookup yields nothing.
pub(crate) fn load_persistent_latent_values(
    &self,
    n_obs: usize,
    latent_dim: usize,
) -> Option<Array2<f64>> {
    if self.warm_start_enabled.load(Ordering::Relaxed) {
        self.persistent_latent_values_cache_key().and_then(|key| {
            self.persistent_latent_values_cache
                .write()
                .unwrap()
                .lookup(&key, n_obs, latent_dim)
        })
    } else {
        None
    }
}
/// Inserts a copy of `values` into the latent-values cache under this
/// problem's latent cache key.
///
/// No-op when warm starts are disabled or no cache key can be derived.
pub(crate) fn store_persistent_latent_values(&self, values: &Array2<f64>) {
    let enabled = self.warm_start_enabled.load(Ordering::Relaxed);
    if !enabled {
        return;
    }
    if let Some(key) = self.persistent_latent_values_cache_key() {
        self.persistent_latent_values_cache
            .write()
            .unwrap()
            .insert(key, values.clone());
    }
}
/// Restores warm-start state from the persistent record store, at most once.
///
/// The attempt is skipped when warm starts are disabled, when an attempt was
/// already made (tracked by `persistent_warm_start_loaded`), or when an
/// in-memory β seed is already present. A loaded record is ignored unless it
/// is compatible with this problem and its `rho` length matches the number of
/// canonical penalties. On success the β/ρ slots (current and previous) and
/// the adaptive solver signals are repopulated from the record.
fn load_persistent_warm_start_once(&self) {
    if !self.warm_start_enabled.load(Ordering::Relaxed) {
        return;
    }
    // `swap` both reads and sets the flag, so only the first caller proceeds.
    let already_attempted = self
        .persistent_warm_start_loaded
        .swap(true, Ordering::Relaxed);
    if already_attempted {
        return;
    }
    let seed_already_present = self.warm_start_beta.read().unwrap().is_some();
    if seed_already_present {
        return;
    }
    let Some(key) = self.persistent_warm_start_cache_key() else {
        return;
    };
    let Some(record) = load_record(&key) else {
        return;
    };
    if !record.is_compatible(&key, self.y.len(), self.p) {
        return;
    }
    if record.rho.len() != self.canonical_penalties.len() {
        return;
    }

    // Coefficient / smoothing-parameter slots, current then previous.
    *self.warm_start_beta.write().unwrap() =
        Some(Coefficients::new(Array1::from_vec(record.beta)));
    *self.warm_start_rho.write().unwrap() = Some(Array1::from_vec(record.rho));
    *self.prev_warm_start_beta.write().unwrap() = record
        .prev_beta
        .map(|values| Coefficients::new(Array1::from_vec(values)));
    *self.prev_warm_start_rho.write().unwrap() = record.prev_rho.map(Array1::from_vec);

    // Adaptive signals from the previous run.
    self.last_inner_iters
        .store(record.last_inner_iters, Ordering::Relaxed);
    self.last_inner_converged
        .store(record.last_inner_converged, Ordering::Relaxed);
    // Zero bits are stored when the recorded LM lambda is absent, non-finite
    // or not strictly positive.
    let lm_lambda_bits = match record.last_pirls_lm_lambda {
        Some(lambda) if lambda.is_finite() && lambda > 0.0 => lambda.to_bits(),
        _ => 0,
    };
    self.last_pirls_lm_lambda
        .store(lm_lambda_bits, Ordering::Relaxed);
    let residual_bits = finite_nonnegative_bits_or_no_signal(record.last_ift_prediction_residual);
    self.last_ift_prediction_residual
        .store(residual_bits, Ordering::Relaxed);
    let accept_rho_bits = finite_nonnegative_bits_or_no_signal(record.last_pirls_accept_rho);
    self.last_pirls_accept_rho
        .store(accept_rho_bits, Ordering::Relaxed);

    log::info!("[warm-start-cache] restored persistent warm start key={key}");
}
/// Writes the current warm-start state to the persistent record store.
///
/// Nothing is written when warm starts are disabled, no cache key can be
/// derived, the current β or ρ slot is empty, or their lengths disagree with
/// `self.p` / the number of canonical penalties. A failing `store_record` is
/// logged as a warning and otherwise ignored.
fn store_persistent_warm_start(&self) {
    let enabled = self.warm_start_enabled.load(Ordering::Relaxed);
    if !enabled {
        return;
    }
    let Some(key) = self.persistent_warm_start_cache_key() else {
        return;
    };
    let Some(beta) = self.warm_start_beta.read().unwrap().clone() else {
        return;
    };
    let Some(rho) = self.warm_start_rho.read().unwrap().as_ref().cloned() else {
        return;
    };
    let shapes_match = beta.0.len() == self.p && rho.len() == self.canonical_penalties.len();
    if !shapes_match {
        return;
    }

    let mut record = PersistentWarmStartRecord::new(key, self.y.len(), self.p);
    record.updated_unix_secs = record.created_unix_secs;
    record.rho = rho.to_vec();
    record.beta = beta.0.to_vec();
    record.prev_rho = self
        .prev_warm_start_rho
        .read()
        .unwrap()
        .as_ref()
        .map(|previous| previous.to_vec());
    record.prev_beta = self
        .prev_warm_start_beta
        .read()
        .unwrap()
        .as_ref()
        .map(|previous| previous.0.to_vec());
    record.last_inner_iters = self.last_inner_iters.load(Ordering::Relaxed);
    record.last_inner_converged = self.last_inner_converged.load(Ordering::Relaxed);

    // The f64 signals live in atomics as raw bits; decode them for the record.
    let lm_lambda_bits = self.last_pirls_lm_lambda.load(Ordering::Relaxed);
    record.last_pirls_lm_lambda = finite_positive_from_bits(lm_lambda_bits);
    let residual_bits = self.last_ift_prediction_residual.load(Ordering::Relaxed);
    record.last_ift_prediction_residual = finite_nonnegative_from_bits(residual_bits);
    let accept_rho_bits = self.last_pirls_accept_rho.load(Ordering::Relaxed);
    record.last_pirls_accept_rho = finite_nonnegative_from_bits(accept_rho_bits);

    if let Err(err) = store_record(&record) {
        log::warn!("[warm-start-cache] failed to persist warm start: {err}");
    }
}
/// Predicts a warm-start coefficient vector for `new_rho` from the cached IFT
/// state, together with an outcome tag.
///
/// Returns `None` when warm starts are disabled, no IFT cache entry exists,
/// the largest |Δρ| component is non-finite or exceeds the current step cap,
/// or the cached penalized Hessian cannot be factorized. The final delegated
/// predictor may presumably also decline (its contract is not visible here).
///
/// Returns the cached β with `IftPredictionOutcome::Noop` when every |Δρ|
/// component is at most `IFT_WARM_START_DRHO_EPS`.
pub(crate) fn predict_warm_start_beta_ift_with_outcome(
&self,
new_rho: &Array1<f64>,
) -> Option<(Coefficients, IftPredictionOutcome)> {
if !self.warm_start_enabled.load(Ordering::Relaxed) {
return None;
}
// The last prediction residual is stored as raw f64 bits; anything that is
// not a finite non-negative number is treated as "no residual available".
let last_residual_bits = self.last_ift_prediction_residual.load(Ordering::Relaxed);
let r = f64::from_bits(last_residual_bits);
let last_residual = if r.is_finite() && r >= 0.0 {
Some(r)
} else {
None
};
let current_ift_step_cap = self.ift_quality_step_cap(adaptive_ift_max_drho(last_residual));
// Joint-IFT path takes precedence when its cache matches the pending theta.
if self.joint_ift_cache_matches_pending_theta(new_rho) {
return self
.predict_warm_start_beta_joint_ift_with_outcome(new_rho, current_ift_step_cap);
}
// The read guard on the IFT cache is held for the rest of the function.
let cache_guard = self.ift_warm_start_cache.read().unwrap();
let cache = cache_guard.as_ref()?;
// Step-size screening only applies when the cache carries a rho of matching length.
if !cache.rho.is_empty() && cache.rho.len() == new_rho.len() {
let mut max_abs_drho = 0.0_f64;
let mut any_non_finite = false;
for i in 0..cache.rho.len() {
let d = new_rho[i] - cache.rho[i];
if !d.is_finite() {
any_non_finite = true;
max_abs_drho = f64::INFINITY;
break;
}
if d.abs() > max_abs_drho {
max_abs_drho = d.abs();
}
}
// Effectively unchanged rho: reuse the cached β verbatim.
if !any_non_finite && max_abs_drho <= IFT_WARM_START_DRHO_EPS {
log::info!(
"[IFT-NOOP] reason=all_drho_below_eps max_drho={:.3e} drho_dim={}",
max_abs_drho,
cache.rho.len(),
);
return Some((
Coefficients::new(cache.beta_original.clone()),
IftPredictionOutcome::Noop,
));
}
// Step too large (or non-finite) for the prediction to be trusted.
let max_drho_cap = current_ift_step_cap;
if !max_abs_drho.is_finite() || max_abs_drho > max_drho_cap {
log::info!(
"[IFT-REJECTED] reason=large_drho max_drho={:.3e} cap={:.3e} drho_dim={}",
max_abs_drho,
max_drho_cap,
cache.rho.len(),
);
return None;
}
}
// First choice: cached mode-response columns, when available.
if let Some(rho_mode_response_cols) = self.cached_ift_rho_mode_response_cols(cache) {
if let Some(prediction) = predict_warm_start_beta_ift_from_mode_response_cols(
cache,
new_rho,
self.p,
last_residual,
Some(current_ift_step_cap),
&rho_mode_response_cols,
) {
log::info!(
"[IFT-CACHE] outcome=mode_response_hit drho_dim={} p={}",
new_rho.len(),
self.p,
);
return Some(prediction);
}
log::debug!(
"[IFT-CACHE] outcome=mode_response_fallback drho_dim={} p={}",
new_rho.len(),
self.p,
);
}
// Second choice: a factorization of the cached penalized Hessian, reused
// across calls through `ift_cached_factor`.
let factor_arc: Arc<dyn crate::linalg::matrix::FactorizedSystem> = {
let read_guard = self.ift_cached_factor.read().unwrap();
if let Some(arc) = read_guard.as_ref() {
log::info!(
"[IFT-CACHE] outcome=hit drho_dim={} p={}",
new_rho.len(),
self.p,
);
Arc::clone(arc)
} else {
// Release the read lock before factorizing so the factorization runs
// without holding the factor lock.
drop(read_guard);
let factorize_start = std::time::Instant::now();
let new_factor = match cache.penalized_hessian_transformed.factorize() {
Ok(f) => f,
Err(_) => {
log::info!(
"[IFT-REJECTED] reason=hessian_factorize_failed_cached drho_dim={}",
new_rho.len(),
);
return None;
}
};
log::info!(
"[IFT-CACHE] outcome=miss drho_dim={} p={} elapsed={:.3}s",
new_rho.len(),
self.p,
factorize_start.elapsed().as_secs_f64(),
);
let arc: Arc<dyn crate::linalg::matrix::FactorizedSystem> = Arc::from(new_factor);
// Re-check under the write lock: if a factor was installed in the
// meantime, use that one and discard the one just built.
let mut write_guard = self.ift_cached_factor.write().unwrap();
if let Some(existing) = write_guard.as_ref() {
Arc::clone(existing)
} else {
*write_guard = Some(Arc::clone(&arc));
arc
}
}
};
predict_warm_start_beta_ift_inner_with_outcome(
cache,
self.canonical_penalties.as_ref(),
new_rho,
self.p,
last_residual,
Some(current_ift_step_cap),
Some(factor_arc.as_ref()),
)
}
/// Produces a warm-start β for `new_rho` and reports which predictor made it.
///
/// Strategies are tried in this order:
/// 1. a one-shot "flat" override (`take_ift_quality_flat_override`) that
///    returns the current β unchanged,
/// 2. the IFT predictor (`Noop` outcomes are reported as `Flat`),
/// 3. a tangent-line extrapolation through the previous and current (ρ, β),
/// 4. otherwise the current β unchanged (`Flat`).
///
/// Returns `None` when warm starts are disabled, or when the IFT predictor
/// declines and no current β is stored.
pub(crate) fn predict_warm_start_beta_with_source(
&self,
new_rho: &Array1<f64>,
) -> Option<(Coefficients, WarmStartPredictionSource)> {
if !self.warm_start_enabled.load(Ordering::Relaxed) {
return None;
}
// If the override fires but no β is stored, fall through to the predictors below.
if self.take_ift_quality_flat_override()
&& let Some(cur_beta) = self.warm_start_beta.read().unwrap().clone()
{
return Some((cur_beta, WarmStartPredictionSource::Flat));
}
if let Some((predicted, outcome)) = self.predict_warm_start_beta_ift_with_outcome(new_rho) {
log::debug!("[warm-start] IFT prediction accepted");
let source = match outcome {
IftPredictionOutcome::Predicted => WarmStartPredictionSource::Ift,
IftPredictionOutcome::Noop => WarmStartPredictionSource::Flat,
};
return Some((predicted, source));
}
// Tangent-line fallback needs the current β plus current/previous ρ and previous β.
let cur_beta = self.warm_start_beta.read().unwrap().clone()?;
let cur_rho = self.warm_start_rho.read().unwrap().clone();
let prev_beta = self.prev_warm_start_beta.read().unwrap().clone();
let prev_rho = self.prev_warm_start_rho.read().unwrap().clone();
let (cur_rho, prev_beta, prev_rho) = match (cur_rho, prev_beta, prev_rho) {
(Some(cr), Some(pb), Some(pr)) => (cr, pb, pr),
_ => return Some((cur_beta, WarmStartPredictionSource::Flat)),
};
if cur_rho.len() != new_rho.len() || cur_rho.len() != prev_rho.len() {
log::info!(
"[TANGENT-REJECTED] reason=rho_dim_mismatch new_rho_dim={} cur_rho_dim={} prev_rho_dim={}",
new_rho.len(),
cur_rho.len(),
prev_rho.len(),
);
return Some((cur_beta, WarmStartPredictionSource::Flat));
}
if cur_beta.0.len() != prev_beta.0.len() {
log::info!(
"[TANGENT-REJECTED] reason=beta_dim_mismatch cur_beta_dim={} prev_beta_dim={}",
cur_beta.0.len(),
prev_beta.0.len(),
);
return Some((cur_beta, WarmStartPredictionSource::Flat));
}
// ||ρ_cur − ρ_prev||²; a (near-)zero previous step gives no usable direction.
const DEGENERATE_DRHO_NORM_SQ: f64 = 1e-24;
let d_rho_norm_sq: f64 = cur_rho
.iter()
.zip(prev_rho.iter())
.map(|(c, p)| (c - p) * (c - p))
.sum();
if !d_rho_norm_sq.is_finite() || d_rho_norm_sq <= DEGENERATE_DRHO_NORM_SQ {
log::info!(
"[TANGENT-REJECTED] reason=degenerate_drho d_rho_norm_sq={:.3e}",
d_rho_norm_sq,
);
return Some((cur_beta, WarmStartPredictionSource::Flat));
}
// <ρ_new − ρ_cur, ρ_cur − ρ_prev>: the new step projected on the previous step.
let step_dot_d: f64 = new_rho
.iter()
.zip(cur_rho.iter())
.zip(prev_rho.iter())
.map(|((n, c), p)| (n - c) * (c - p))
.sum();
// alpha is the extrapolation factor along the previous step direction.
let alpha = step_dot_d / d_rho_norm_sq;
if !alpha.is_finite() {
log::info!(
"[TANGENT-REJECTED] reason=nonfinite_alpha step_dot_d={:.3e} d_rho_norm_sq={:.3e}",
step_dot_d,
d_rho_norm_sq,
);
return Some((cur_beta, WarmStartPredictionSource::Flat));
}
// Same residual decoding as the IFT predictor: non-finite or negative means "none".
let last_residual_bits = self.last_ift_prediction_residual.load(Ordering::Relaxed);
let r = f64::from_bits(last_residual_bits);
let last_residual = if r.is_finite() && r >= 0.0 {
Some(r)
} else {
None
};
let alpha_cap = adaptive_tangent_alpha_cap(last_residual);
// Only forward extrapolation within the cap is accepted. Note that
// alpha == 0.0 is also reported with reason "alpha_negative".
if alpha <= 0.0 || alpha > alpha_cap {
let reason = if alpha <= 0.0 {
"alpha_negative"
} else {
"alpha_above_cap"
};
log::info!(
"[TANGENT-REJECTED] reason={} alpha={:.3e} cap={:.3e}",
reason,
alpha,
alpha_cap,
);
return Some((cur_beta, WarmStartPredictionSource::Flat));
}
// Reached only for 0 < alpha <= cap; a tiny alpha would reproduce cur_beta anyway.
const TANGENT_ALPHA_NOOP_EPS: f64 = 1e-12;
if alpha.abs() <= TANGENT_ALPHA_NOOP_EPS {
log::info!(
"[TANGENT-NOOP] reason=alpha_below_eps alpha={:.3e} eps={:.3e}",
alpha,
TANGENT_ALPHA_NOOP_EPS,
);
return Some((cur_beta, WarmStartPredictionSource::Flat));
}
// β_pred = β_cur + alpha · (β_cur − β_prev), element-wise.
let mut predicted = cur_beta.0.clone();
for ((p, c), pp) in predicted
.iter_mut()
.zip(cur_beta.0.iter())
.zip(prev_beta.0.iter())
{
*p = c + alpha * (c - pp);
}
if !predicted.iter().all(|v: &f64| v.is_finite()) {
log::info!(
"[TANGENT-REJECTED] reason=non_finite_predicted alpha={:.3e} cap={:.3e}",
alpha,
alpha_cap,
);
return Some((cur_beta, WarmStartPredictionSource::Flat));
}
// NOTE(review): the `drho_step_norm_sq` field below prints |step_dot_d|, i.e. the
// absolute dot product of the new and previous steps, not a squared norm — confirm
// whether the label or the value is the intended one.
log::info!(
"[TANGENT-PREDICT] alpha={:.3e} cap={:.3e} drho_step_norm_sq={:.3e} drho_prev_norm_sq={:.3e}",
alpha,
alpha_cap,
step_dot_d.abs(),
d_rho_norm_sq,
);
Some((
Coefficients::new(predicted),
WarmStartPredictionSource::TangentLine,
))
}
/// Installs an externally supplied β (original basis) as the warm-start seed.
///
/// Passing `None`, or calling while warm starts are disabled, changes nothing.
/// Passing `Some(beta)` always clears the predictor state and the adaptive
/// signals first; the seed itself is stored only when its length equals
/// `self.p` and every entry is finite. A rejected seed is reported with a
/// warning and the β slot is left as the clearing step left it.
pub(crate) fn setwarm_start_original_beta(&self, beta_original: Option<ArrayView1<'_, f64>>) {
    if !self.warm_start_enabled.load(Ordering::Relaxed) {
        return;
    }
    let Some(beta) = beta_original else {
        return;
    };
    self.clear_warm_start_predictor_state();
    self.clear_warm_start_adaptive_signals();

    if beta.len() != self.p {
        log::warn!(
            "[warm-start] external β setter rejected length mismatch: got {}, expected {}",
            beta.len(),
            self.p,
        );
        return;
    }
    if beta.iter().any(|v: &f64| !v.is_finite()) {
        log::warn!(
            "[warm-start] external β setter rejected non-finite seed (len={}); slot left empty",
            beta.len(),
        );
        return;
    }
    *self.warm_start_beta.write().unwrap() = Some(Coefficients::new(beta.to_owned()));
}
/// Returns a copy of the stored warm-start β when it is usable.
///
/// "Usable" means the slot is populated, the vector has length `self.p`, and
/// every entry is finite. A poisoned lock is reported as `None`.
pub(crate) fn current_original_basis_beta(&self) -> Option<Array1<f64>> {
    let slot = self.warm_start_beta.read().ok()?;
    slot.as_ref()
        .filter(|stored| stored.0.len() == self.p && stored.0.iter().all(|v| v.is_finite()))
        .map(|stored| stored.0.clone())
}
/// Resets the per-seed state used by the outer optimization.
///
/// In order: invalidates the cached eval bundle, empties the PIRLS result
/// cache, clears the warm-start predictor state and adaptive signals, sets
/// the outer and screening inner-iteration caps to 0, and resets the
/// hypergradient budget controller.
pub(crate) fn reset_outer_seed_state(&self) {
self.cache_manager.invalidate_eval_bundle();
self.cache_manager.pirls_cache.write().unwrap().clear();
self.clear_warm_start_predictor_state();
self.clear_warm_start_adaptive_signals();
// A stored cap of 0 is read as "no cap" by the PIRLS driver in this file
// (it only applies caps that are > 0).
self.outer_inner_cap.store(0, Ordering::Relaxed);
self.screening_max_inner_iterations
.store(0, Ordering::Relaxed);
self.reset_hypergradient_budget_controller();
}
/// Borrows the design matrix held by this state.
pub(crate) fn x(&self) -> &DesignMatrix {
&self.x
}
/// Borrows the stored balanced penalty root matrix.
pub(crate) fn balanced_penalty_root(&self) -> &Array2<f64> {
&self.balanced_penalty_root
}
/// Returns the shared Gaussian fixed-weights cache, building it on first use.
///
/// The cache is only offered when the response family is Gaussian, the link
/// is the standard identity link, Firth bias reduction is off, and there are
/// neither coefficient lower bounds nor linear constraints; otherwise `None`.
///
/// `None` is also returned (and nothing is stored) when building XᵀWX fails.
pub(crate) fn gaussian_fixed_cache_if_eligible(
&self,
) -> Option<Arc<crate::pirls::GaussianFixedCache>> {
let spec = reml_spec(&self.config.likelihood);
let family_ok = matches!(spec.response, ResponseFamily::Gaussian);
let link_ok = matches!(
self.config.link_kind,
crate::types::InverseLink::Standard(StandardLink::Identity)
);
if !family_ok
|| !link_ok
|| self.config.firth_bias_reduction
|| self.coefficient_lower_bounds.is_some()
|| self.linear_constraints.is_some()
{
return None;
}
// Fast path under the read lock.
{
let guard = self.gaussian_fixed_cache.read().unwrap();
if let Some(cache) = guard.as_ref() {
return Some(Arc::clone(cache));
}
}
// Slow path: take the write lock and re-check, since the cache may have been
// filled between releasing the read lock and acquiring the write lock. The
// write lock is held for the whole build below.
let mut guard = self.gaussian_fixed_cache.write().unwrap();
if let Some(cache) = guard.as_ref() {
return Some(Arc::clone(cache));
}
let build_start = std::time::Instant::now();
let weights_owned = self.weights.to_owned();
// wz = w ⊙ (y − offset)
let mut wz = self.y.to_owned();
wz -= &self.offset;
wz *= &weights_owned;
// Σ w · (y − offset)²
let centered_weighted_y_sq = self
.y
.iter()
.zip(self.offset.iter())
.zip(weights_owned.iter())
.map(|((&y, &offset), &w)| {
let centered = y - offset;
w * centered * centered
})
.sum::<f64>();
let xtwx = match crate::linalg::matrix::LinearOperator::xt_diag_x_signed_op(
&self.x,
crate::linalg::matrix::SignedWeightsView::from_array(&weights_owned),
) {
Ok(m) => m,
Err(e) => {
log::warn!("[gaussian-fixed-cache] disabling cache: failed to build XᵀWX: {e}");
return None;
}
};
// Xᵀ · wz
let xtwy = self.x.transpose_vector_multiply(&wz);
// Optional sparse precomputation; a failure here is non-fatal and only drops
// the sparse variant from the cache.
let xtwx_sparse_orig = if let Some(sparse_design) = self.x.as_sparse() {
let sparse_start = std::time::Instant::now();
match crate::pirls::SparseXtwxPrecomputed::build(sparse_design.as_ref(), &weights_owned)
{
Ok(precomp) => {
log::info!(
"[gaussian-fixed-cache] sparse XᵀWX nnz={} built in {:.3} ms",
precomp.xtwxvalues.len(),
sparse_start.elapsed().as_secs_f64() * 1e3
);
Some(Arc::new(precomp))
}
Err(e) => {
log::warn!(
"[gaussian-fixed-cache] sparse XᵀWX build failed; falling back: {e}"
);
None
}
}
} else {
None
};
let cache = Arc::new(crate::pirls::GaussianFixedCache {
xtwx_orig: xtwx,
xtwy_orig: xtwy,
centered_weighted_y_sq,
xtwx_sparse_orig,
});
log::info!(
"[gaussian-fixed-cache] built p={} n={} in {:.3} ms",
self.p,
self.y.len(),
build_start.elapsed().as_secs_f64() * 1e3
);
*guard = Some(Arc::clone(&cache));
Some(cache)
}
/// Borrows the canonical penalties as a slice.
pub(crate) fn canonical_penalties(&self) -> &[crate::construction::CanonicalPenalty] {
&self.canonical_penalties
}
/// Evaluates the penalty log-determinant pieces for block-separable penalties.
///
/// For every block, each positive eigenvalue `eig` contributes
/// `rho[k] + ln(eig)` to the log-determinant, where `k` is the block's
/// penalty index. The block's eigenvalue count is added to the rank and
/// written to `det1[k]` when `k` is within `rho`'s length.
///
/// Returns `(min(total_rank, self.p), logdet, det1)`.
pub(super) fn sparse_penalty_logdet_runtime(
    &self,
    rho: &Array1<f64>,
    blocks: &[SparsePenaltyBlock],
) -> (usize, f64, Array1<f64>) {
    let mut det1 = Array1::<f64>::zeros(rho.len());
    let mut logdet = 0.0_f64;
    let mut penalty_rank = 0usize;
    for block in blocks {
        let k = block.penalty_idx.get();
        let block_rank = block.positive_eigenvalues.len();
        // `rho[k]` is indexed per eigenvalue (not hoisted) so that a block
        // without positive eigenvalues never touches `rho`.
        logdet = block
            .positive_eigenvalues
            .iter()
            .fold(logdet, |acc, &eig| acc + (rho[k] + eig.ln()));
        penalty_rank += block_rank;
        if let Some(slot) = det1.get_mut(k) {
            *slot = block_rank as f64;
        }
    }
    (std::cmp::min(penalty_rank, self.p), logdet, det1)
}
/// Builds the shared evaluation bundle for the dense-spectral REML geometry.
///
/// Obtains the PIRLS result for `rho`, takes the effective Hessian as
/// `h_total`, optionally subtracts a Firth/Jeffreys Hessian term, optionally
/// adds the constraint-barrier Hessian diagonal, and packages everything in an
/// `EvalShared` tagged `RemlGeometry::DenseSpectral`. `key` is stored in the
/// bundle unchanged.
///
/// # Errors
/// Propagates failures from the PIRLS solve, `effectivehessian`, the dense
/// conversion of the transformed design (as `InvalidInput`), and the Firth
/// operator construction.
pub(super) fn prepare_dense_eval_bundlewithkey(
&self,
rho: &Array1<f64>,
key: Option<Vec<u64>>,
) -> Result<EvalShared, EstimationError> {
let pirls_result = self.execute_pirls_if_needed(rho)?;
let (mut h_total, ridge_passport) = self.effectivehessian(pirls_result.as_ref())?;
let mut firth_dense_operator: Option<Arc<FirthDenseOperator>> = None;
if let Some(jeffreys_link) = reml_robust_jeffreys_link(&self.config) {
let firth_n = pirls_result.x_transformed.nrows();
let firth_p = pirls_result.x_transformed.ncols();
// The Firth operator is skipped (with an info log) when the problem-size gate rejects it.
if !super::firth_problem_scale_allows(firth_n, firth_p) {
log::info!(
"disabling Firth bias reduction for large model (n={}, p={}, n*p={}, n*p^2={}): \
exact Firth operator is small-model-only",
firth_n,
firth_p,
firth_n.saturating_mul(firth_p),
firth_n.saturating_mul(firth_p).saturating_mul(firth_p),
);
} else {
let x_dense = pirls_result
.x_transformed
.try_to_dense_arc(
"dense REML eval bundle requires dense transformed design for Firth operator",
)
.map_err(EstimationError::InvalidInput)?;
let firth_build_start = std::time::Instant::now();
let firth_op = Arc::new(Self::build_firth_dense_operator_for_link(
&jeffreys_link,
x_dense.as_ref(),
&pirls_result.final_eta,
self.weights,
)?);
log::debug!(
"[Firth-op] build n={} p={} r={} half_logdet={:.3e} elapsed={:.3}s",
firth_op.x_dense.nrows(),
firth_op.x_dense.ncols(),
firth_op.k_reduced.nrows(),
firth_op.half_log_det,
firth_build_start.elapsed().as_secs_f64(),
);
// hphi = 0.5 · (diag_term − b_baseᵀ · p_b_base), where diag_term is what
// `xt_diag_x_dense_into` returns for weights w2 ⊙ h_diag (presumably
// Xᵀ·diag(w2 ⊙ h_diag)·X — confirm against that helper).
let mut weighted_xtdx = Array2::<f64>::zeros((0, 0));
let diag_term = Self::xt_diag_x_dense_into(
&firth_op.x_dense,
&(&firth_op.w2 * &firth_op.h_diag),
&mut weighted_xtdx,
);
let bpb = crate::faer_ndarray::fast_atb(&firth_op.b_base, &firth_op.p_b_base);
let mut hphi = 0.5 * (diag_term - bpb);
enforce_symmetry(&mut hphi);
// The correction is applied only when it is entirely finite; otherwise it is
// skipped silently. The operator is still recorded in the bundle either way.
if hphi.iter().all(|v| v.is_finite()) {
h_total -= &hphi;
}
firth_dense_operator = Some(firth_op);
} }
// Barrier contribution in transformed coordinates; a failure is logged and ignored.
if let Some(ref lin) = pirls_result.linear_constraints_transformed
&& let Some(barrier_cfg) = Self::barrier_config_from_constraints(lin)
{
let beta_t = pirls_result.beta_transformed.as_ref();
if let Err(e) = barrier_cfg.add_barrier_hessian_diagonal(&mut h_total, beta_t) {
log::warn!("Barrier Hessian diagonal skipped: {e}");
}
}
Ok(EvalShared {
key,
pirls_result,
ridge_passport,
geometry: RemlGeometry::DenseSpectral,
h_total: Arc::new(h_total),
sparse_exact: None,
firth_dense_operator,
firth_dense_operator_original: None,
})
}
/// Builds the shared evaluation bundle for the sparse-exact SPD REML geometry.
///
/// Requires a PIRLS result in `OriginalSparseNative` coordinates, a sparse
/// original design, and precomputed sparse penalty blocks. Assembles and
/// factors the sparse penalized system, computes the penalty log-determinant
/// pieces, optionally builds the Firth operator on the original design, and
/// returns an `EvalShared` tagged `RemlGeometry::SparseExactSpd` whose
/// `h_total` is an empty 0x0 placeholder. `key` is stored unchanged.
///
/// # Errors
/// * PIRLS result not in sparse-native coordinates.
/// * Design is not sparse, or sparse penalty blocks are missing (`InvalidInput`).
/// * Failures propagated from the PIRLS solve, `hessian_surface_arrays`, the
///   sparse assembly/factorization, the Firth operator build, the simplicial
///   factorization, or the Takahashi inverse.
pub(super) fn prepare_sparse_eval_bundlewithkey(
&self,
rho: &Array1<f64>,
key: Option<Vec<u64>>,
) -> Result<EvalShared, EstimationError> {
let pirls_result = self.execute_pirls_if_needed(rho)?;
if !matches!(
pirls_result.coordinate_frame,
pirls::PirlsCoordinateFrame::OriginalSparseNative
) {
crate::bail_invalid_estim!(
"sparse exact geometry requires sparse-native PIRLS coordinates"
);
}
let ridge_passport = pirls_result.ridge_passport;
let x_sparse = self.x().as_sparse().ok_or_else(|| {
EstimationError::InvalidInput(
"sparse exact geometry requires sparse original design".to_string(),
)
})?;
let penalty_blocks = self
.sparse_penalty_blocks
.as_ref()
.ok_or_else(|| {
EstimationError::InvalidInput(
"sparse exact geometry requires block-separable penalties".to_string(),
)
})?
.clone();
// λ_k = exp(ρ_k); each canonical penalty with a non-zero λ_k is accumulated
// into the dense p x p matrix s_lambda.
let lambdas = rho.mapv(f64::exp);
let mut s_lambda = Array2::<f64>::zeros((self.p, self.p));
for (k, cp) in self.canonical_penalties.iter().enumerate() {
if k < lambdas.len() && lambdas[k] != 0.0 {
cp.accumulate_weighted(&mut s_lambda, lambdas[k]);
}
}
// The barrier Hessian diagonal is folded into s_lambda so it enters the
// assembled system below; a failure is logged and ignored.
if let Some(ref lin) = self.linear_constraints
&& let Some(barrier_cfg) = Self::barrier_config_from_constraints(lin)
{
let beta_orig = self.sparse_exact_beta_original(pirls_result.as_ref());
if let Err(e) = barrier_cfg.add_barrier_hessian_diagonal(&mut s_lambda, &beta_orig) {
log::warn!("Sparse barrier Hessian diagonal skipped: {e}");
}
}
let mut workspace = PirlsWorkspace::new(self.y.len(), self.p, 0, 0);
// Reuse the sparse XᵀWX precomputation from the Gaussian cache when there is one.
let gaussian_cache = self.gaussian_fixed_cache_if_eligible();
let precomputed_xtwx = gaussian_cache
.as_ref()
.and_then(|c| c.xtwx_sparse_orig.as_ref().map(|arc| arc.as_ref()));
let (hessian_weights, _, _) = self.hessian_surface_arrays(pirls_result.as_ref())?;
let sparse_system = assemble_and_factor_sparse_penalized_system(
&mut workspace,
x_sparse,
&hessian_weights,
&s_lambda,
ridge_passport.delta,
precomputed_xtwx,
)?;
let (penalty_rank, logdet_s_pos, det1_values) =
self.sparse_penalty_logdet_runtime(rho, penalty_blocks.as_ref());
// Firth operator on the ORIGINAL design (the PIRLS frame is sparse-native here);
// skipped with an info log when the problem-size gate rejects it.
let firth_dense_operator_original = if let Some(jeffreys_link) =
reml_robust_jeffreys_link(&self.config)
{
let firth_n = self.x().nrows();
let firth_p = self.x().ncols();
if !super::firth_problem_scale_allows(firth_n, firth_p) {
log::info!(
"disabling Firth bias reduction for large model (n={}, p={}, n*p={}, n*p^2={}): \
exact Firth operator is small-model-only",
firth_n,
firth_p,
firth_n.saturating_mul(firth_p),
firth_n.saturating_mul(firth_p).saturating_mul(firth_p),
);
None
} else {
let x_dense = self
.x()
.try_to_dense_arc(
"sparse exact REML runtime requires dense design for Firth operator",
)
.map_err(EstimationError::InvalidInput)?;
Some(Arc::new(Self::build_firth_dense_operator_for_link(
&jeffreys_link,
x_dense.as_ref(),
&pirls_result.final_eta,
self.weights,
)?))
}
} else {
None
};
Ok(EvalShared {
key,
pirls_result,
ridge_passport,
geometry: RemlGeometry::SparseExactSpd,
// Dense Hessian is not materialized for this geometry.
h_total: Arc::new(Array2::zeros((0, 0))),
sparse_exact: Some(Arc::new({
let factor = Arc::new(sparse_system.factor);
// A separate simplicial factorization of the assembled sparse Hessian
// feeds the Takahashi inverse.
let sfactor =
crate::linalg::sparse_exact::factorize_simplicial(&sparse_system.h_sparse)?;
let takahashi = Some(Arc::new(
crate::linalg::sparse_exact::TakahashiInverse::compute(&sfactor)?,
));
SparseExactEvalData {
factor,
takahashi,
logdet_h: sparse_system.logdet_h,
logdet_s_pos,
penalty_rank,
det1_values: Arc::new(det1_values),
}
})),
firth_dense_operator: None,
firth_dense_operator_original,
})
}
pub(super) fn execute_pirls_if_needed(
&self,
rho: &Array1<f64>,
) -> Result<Arc<PirlsResult>, EstimationError> {
let use_cache = self
.cache_manager
.pirls_cache_enabled
.load(Ordering::Relaxed);
let key_opt = self.rhokey_sanitized(rho);
if use_cache
&& let Some(key) = &key_opt
&& let Some(cached) = self.cache_manager.pirls_cache.write().unwrap().get(key)
{
if cached.cache_compacted {
let mut pirls_config = self.config.as_pirls_config();
pirls_config.link_kind =
if let Some(state) = self.runtime_mixture_link_state.clone() {
InverseLink::Mixture(state)
} else if let Some(state) = self.runtime_sas_link_state {
if matches!(self.config.link_function(), LinkFunction::BetaLogistic) {
InverseLink::BetaLogistic(state)
} else {
InverseLink::Sas(state)
}
} else {
InverseLink::Standard(
StandardLink::try_from(self.config.link_function())
.expect("state-bearing link without runtime state"),
)
};
return Ok(Arc::new(cached.rehydrate_after_reml_cache(
self.x(),
self.y,
self.weights,
self.offset.view(),
&pirls_config.link_kind,
)?));
}
return Ok(cached);
}
let screening_cap = self.screening_max_inner_iterations.load(Ordering::Relaxed);
let in_screening = screening_cap > 0;
let raw_outer_cap = self.outer_inner_cap.load(Ordering::Relaxed);
let efs_single_loop_cap = decode_efs_single_loop_cap(raw_outer_cap);
let in_efs_single_loop = efs_single_loop_cap.is_some();
let outer_cap = efs_single_loop_cap.unwrap_or(raw_outer_cap);
if !in_screening {
self.load_persistent_warm_start_once();
}
let predicted_warm_start_with_source = if self.warm_start_enabled.load(Ordering::Relaxed) {
self.predict_warm_start_beta_with_source(rho)
} else {
None
};
let predicted_warm_start = predicted_warm_start_with_source
.as_ref()
.map(|(c, _)| c.clone());
let prediction_source = predicted_warm_start_with_source.as_ref().map(|(_, s)| *s);
let pirls_result = {
let warm_start_holder = self.warm_start_beta.read().unwrap();
let fallback_warm_start_ref = if self.warm_start_enabled.load(Ordering::Relaxed) {
warm_start_holder.as_ref()
} else {
None
};
let warm_start_ref = predicted_warm_start.as_ref().or(fallback_warm_start_ref);
let mut pirls_config = self.config.as_pirls_config();
let original_cap = pirls_config.max_iterations;
if in_screening {
pirls_config.max_iterations = pirls_config.max_iterations.min(screening_cap);
}
if outer_cap > 0 {
pirls_config.max_iterations = pirls_config.max_iterations.min(outer_cap);
}
if pirls_config.max_iterations != original_cap {
log::debug!(
"[PIRLS cap] inner_max_iterations={} (full={} screening={} outer={})",
pirls_config.max_iterations,
original_cap,
if in_screening {
screening_cap as i64
} else {
-1
},
if outer_cap > 0 { outer_cap as i64 } else { -1 },
);
}
pirls_config.link_kind = if let Some(state) = self.runtime_mixture_link_state.clone() {
InverseLink::Mixture(state)
} else if let Some(state) = self.runtime_sas_link_state {
if matches!(self.config.link_function(), LinkFunction::BetaLogistic) {
InverseLink::BetaLogistic(state)
} else {
InverseLink::Sas(state)
}
} else {
InverseLink::Standard(
StandardLink::try_from(self.config.link_function())
.expect("state-bearing link without runtime state"),
)
};
let cached_lambda_bits = self.last_pirls_lm_lambda.load(Ordering::Relaxed);
if cached_lambda_bits != 0 {
let cached_lambda = f64::from_bits(cached_lambda_bits);
let last_iters = self.last_inner_iters.load(Ordering::Relaxed);
let last_converged = self.last_inner_converged.load(Ordering::Relaxed);
pirls_config.initial_lm_lambda =
adaptive_lm_lambda_hint(cached_lambda, last_iters, last_converged);
}
let adaptive_kkt_tolerance = if !in_screening {
if let Some(override_tol) = self.hypergradient_adaptive_kkt_override(&pirls_config)
{
Some(override_tol)
} else if let Some(outer_grad_norm) = self.previous_outer_gradient_norm(&key_opt) {
let ceiling = pirls_config.convergence_tolerance;
let floor = (self.config.reml_convergence_tolerance
/ ADAPTIVE_KKT_FLOOR_REML_DIVISOR)
.min(ceiling);
(floor > 0.0 && ceiling >= floor).then_some(pirls::AdaptiveKktTolerance {
eta: ADAPTIVE_KKT_ETA,
floor,
ceiling,
outer_grad_norm,
})
} else {
None
}
} else {
None
};
let cache_handle = self.gaussian_fixed_cache_if_eligible();
let problem = pirls::PirlsProblem {
x: &self.x,
offset: self.offset.view(),
y: self.y,
priorweights: self.weights,
covariate_se: None,
gaussian_fixed_cache: cache_handle.as_deref(),
};
let penalty = pirls::PenaltyConfig {
canonical_penalties: &self.canonical_penalties,
balanced_penalty_root: Some(&self.balanced_penalty_root),
reparam_invariant: Some(&self.reparam_invariant),
p: self.p,
coefficient_lower_bounds: self.coefficient_lower_bounds.as_ref(),
linear_constraints_original: self.linear_constraints.as_ref(),
penalty_shrinkage_floor: self.penalty_shrinkage_floor,
kronecker_factored: self.kronecker_factored.as_ref(),
};
let pirls_start = std::time::Instant::now();
let result = pirls::fit_model_for_fixed_rho_with_adaptive_kkt(
LogSmoothingParamsView::new(rho.view()),
problem,
penalty,
&pirls_config,
warm_start_ref,
adaptive_kkt_tolerance,
false,
);
let pirls_elapsed = pirls_start.elapsed();
if let Ok((ref res, ref wm)) = result {
log::info!(
"[STAGE] inner pirls solve iters={} status={:?} max_eta={:.1} jeffreys_logdet={} elapsed={:.3}s",
wm.iterations,
res.status,
res.max_abs_eta,
res.jeffreys_logdet()
.map(|v| format!("{v:.3e}"))
.unwrap_or_else(|| "none".to_string()),
pirls_elapsed.as_secs_f64(),
);
}
result
};
if let Err(e) = &pirls_result {
if in_screening {
log::debug!("[seed-screen] P-IRLS rejected candidate: {e:?}");
} else {
log::warn!("[GAM COST] -> P-IRLS INNER LOOP FAILED. Error: {e:?}");
}
}
let (pirls_result, _) = pirls_result?; let pirls_result = Arc::new(pirls_result);
if !in_screening && !in_efs_single_loop {
self.enforce_constraint_kkt(pirls_result.as_ref())?;
}
match pirls_result.status {
pirls::PirlsStatus::Converged | pirls::PirlsStatus::StalledAtValidMinimum => {
if !in_screening {
if let Some(predicted) = predicted_warm_start.as_ref() {
let converged_original = match pirls_result.coordinate_frame {
pirls::PirlsCoordinateFrame::OriginalSparseNative => {
pirls_result.beta_transformed.as_ref().clone()
}
pirls::PirlsCoordinateFrame::TransformedQs => pirls_result
.reparam_result
.qs
.dot(pirls_result.beta_transformed.as_ref()),
};
if matches!(prediction_source, Some(WarmStartPredictionSource::Ift))
&& predicted.0.len() == converged_original.len()
{
let mut diff_sq = 0.0_f64;
let mut conv_sq = 0.0_f64;
for (p_val, c_val) in predicted.0.iter().zip(converged_original.iter())
{
let d = c_val - p_val;
diff_sq += d * d;
conv_sq += c_val * c_val;
}
let conv_norm = conv_sq.sqrt();
let pred_residual = diff_sq.sqrt();
let quality = pred_residual / (1.0 + conv_norm);
if quality.is_finite() && quality >= 0.0 {
let last_residual_bits =
self.last_ift_prediction_residual.load(Ordering::Relaxed);
let r = f64::from_bits(last_residual_bits);
let last_residual = if r.is_finite() && r >= 0.0 {
Some(r)
} else {
None
};
let current_cap =
self.ift_quality_step_cap(adaptive_ift_max_drho(last_residual));
let cap_predicted = self
.record_ift_prediction_quality(quality, current_cap)
.unwrap_or(current_cap);
log::info!(
"[IFT-QUALITY] quality={:.3e} ift={:.3e} pred_residual={:.3e} cap_predicted={:.3e}",
quality,
current_cap,
pred_residual,
cap_predicted,
);
self.last_ift_prediction_residual
.store(quality.to_bits(), Ordering::Relaxed);
}
}
}
self.updatewarm_start_from(pirls_result.as_ref());
self.record_warm_start_rho(rho);
self.last_inner_iters
.store(pirls_result.iteration, Ordering::Relaxed);
self.last_inner_converged.store(true, Ordering::Relaxed);
if pirls_result.final_lm_lambda.is_finite()
&& pirls_result.final_lm_lambda > 0.0
{
self.last_pirls_lm_lambda
.store(pirls_result.final_lm_lambda.to_bits(), Ordering::Relaxed);
}
if let Some(rho) = pirls_result.final_accept_rho
&& rho.is_finite()
&& rho >= 0.0
{
self.last_pirls_accept_rho
.store(rho.to_bits(), Ordering::Relaxed);
}
self.store_persistent_warm_start();
if use_cache && let Some(key) = key_opt {
self.cache_manager
.pirls_cache
.write()
.unwrap()
.insert(key, Arc::new(pirls_result.compact_for_reml_cache()));
}
}
Ok(pirls_result)
}
pirls::PirlsStatus::Unstable => {
Err(EstimationError::PerfectSeparationDetected {
iteration: pirls_result.iteration,
max_abs_eta: pirls_result.max_abs_eta,
})
}
pirls::PirlsStatus::MaxIterationsReached
| pirls::PirlsStatus::LmStepSearchExhausted => {
let kind = match pirls_result.status {
pirls::PirlsStatus::LmStepSearchExhausted => "LM step search exhausted",
_ => "max iterations reached",
};
if in_efs_single_loop
&& pirls_result.deviance.is_finite()
&& pirls_result.stable_penalty_term.is_finite()
&& pirls_result.gradient_natural_scale.is_finite()
&& pirls_result.lastgradient_norm.is_finite()
&& pirls_result
.beta_transformed
.0
.iter()
.all(|v| v.is_finite())
{
log::info!(
"[EFS-single-loop] accepted partial PIRLS sweep: {kind} \
(cap={} |g_beta|={:.3e} r_g={:.3e} iter={})",
efs_single_loop_cap.unwrap_or(outer_cap),
pirls_result.lastgradient_norm,
pirls_result.relative_gradient_norm(),
pirls_result.iteration,
);
self.updatewarm_start_from(pirls_result.as_ref());
self.record_warm_start_rho(rho);
self.last_inner_iters
.store(pirls_result.iteration, Ordering::Relaxed);
self.last_inner_converged.store(false, Ordering::Relaxed);
return Ok(pirls_result);
}
if in_screening
&& pirls_result.deviance.is_finite()
&& pirls_result.stable_penalty_term.is_finite()
&& pirls_result.gradient_natural_scale.is_finite()
&& pirls_result.lastgradient_norm.is_finite()
&& pirls_result
.beta_transformed
.0
.iter()
.all(|v| v.is_finite())
{
log::debug!(
"[seed-screen] partial-fit accepted for ranking: {kind} (|g| {:.3e}, r_g {:.3e}, iter {})",
pirls_result.lastgradient_norm,
pirls_result.relative_gradient_norm(),
pirls_result.iteration
);
return Ok(pirls_result);
}
if in_screening {
log::debug!(
"[seed-screen] P-IRLS rejected: {kind} (gradient norm {:.3e}, iter {})",
pirls_result.lastgradient_norm,
pirls_result.iteration
);
} else {
log::error!(
"P-IRLS could not certify a valid minimum: {kind} (gradient norm {:.3e}, iter {})",
pirls_result.lastgradient_norm,
pirls_result.iteration
);
self.last_inner_iters
.store(pirls_result.iteration, Ordering::Relaxed);
self.last_inner_converged.store(false, Ordering::Relaxed);
self.last_pirls_lm_lambda.store(0, Ordering::Relaxed);
self.last_ift_prediction_residual
.store(IFT_RESIDUAL_NO_SIGNAL_BITS, Ordering::Relaxed);
self.clear_ift_quality_runtime_state();
}
Err(EstimationError::PirlsDidNotConverge {
max_iterations: pirls_result.iteration,
last_change: pirls_result.lastgradient_norm,
})
}
}
}
/// Runs one cold-start P-IRLS fit at `rho` for the sigma-cubature path.
///
/// The solver is invoked without a warm start and without an adaptive KKT
/// tolerance, and this method itself performs no warm-start or cache
/// bookkeeping on the result.
///
/// # Errors
/// * Whatever the P-IRLS solver or `enforce_constraint_kkt` returns.
/// * `PerfectSeparationDetected` when the fit status is `Unstable`.
/// * `PirlsDidNotConverge` when the status is `MaxIterationsReached` or
///   `LmStepSearchExhausted`.
///
/// # Panics
/// Panics when neither a runtime mixture state nor a runtime SAS state is
/// present and the configured link function has no `StandardLink` conversion.
pub(super) fn execute_pirls_stateless_for_cubature(
    &self,
    rho: &Array1<f64>,
) -> Result<Arc<PirlsResult>, EstimationError> {
    let mut pirls_config = self.config.as_pirls_config();
    // Link resolution order: mixture state, then SAS state (surfaced as
    // Beta-logistic when that is the configured link), then the plain
    // standard link.
    pirls_config.link_kind = match self.runtime_mixture_link_state.clone() {
        Some(mixture_state) => InverseLink::Mixture(mixture_state),
        None => match self.runtime_sas_link_state {
            Some(sas_state) => {
                if matches!(self.config.link_function(), LinkFunction::BetaLogistic) {
                    InverseLink::BetaLogistic(sas_state)
                } else {
                    InverseLink::Sas(sas_state)
                }
            }
            None => InverseLink::Standard(
                StandardLink::try_from(self.config.link_function())
                    .expect("state-bearing link without runtime state"),
            ),
        },
    };

    let fixed_cache = self.gaussian_fixed_cache_if_eligible();
    let problem = pirls::PirlsProblem {
        x: &self.x,
        offset: self.offset.view(),
        y: self.y,
        priorweights: self.weights,
        covariate_se: None,
        gaussian_fixed_cache: fixed_cache.as_deref(),
    };
    let penalty = pirls::PenaltyConfig {
        canonical_penalties: &self.canonical_penalties,
        balanced_penalty_root: Some(&self.balanced_penalty_root),
        reparam_invariant: Some(&self.reparam_invariant),
        p: self.p,
        coefficient_lower_bounds: self.coefficient_lower_bounds.as_ref(),
        linear_constraints_original: self.linear_constraints.as_ref(),
        penalty_shrinkage_floor: self.penalty_shrinkage_floor,
        kronecker_factored: self.kronecker_factored.as_ref(),
    };

    let solve_started = std::time::Instant::now();
    let outcome = pirls::fit_model_for_fixed_rho_with_adaptive_kkt(
        LogSmoothingParamsView::new(rho.view()),
        problem,
        penalty,
        &pirls_config,
        None,
        None,
        false,
    );
    let solve_elapsed = solve_started.elapsed();
    if let Ok((fit, working)) = &outcome {
        log::info!(
            "[STAGE] sigma-cubature pirls solve iters={} status={:?} max_eta={:.1} elapsed={:.3}s",
            working.iterations,
            fit.status,
            fit.max_abs_eta,
            solve_elapsed.as_secs_f64(),
        );
    }

    let (fit, _) = outcome?;
    let fit = Arc::new(fit);
    // Constraint KKT enforcement runs before the status is inspected, so a
    // KKT failure is reported ahead of any non-convergence error.
    self.enforce_constraint_kkt(fit.as_ref())?;

    match fit.status {
        pirls::PirlsStatus::Unstable => Err(EstimationError::PerfectSeparationDetected {
            iteration: fit.iteration,
            max_abs_eta: fit.max_abs_eta,
        }),
        pirls::PirlsStatus::MaxIterationsReached
        | pirls::PirlsStatus::LmStepSearchExhausted => {
            Err(EstimationError::PirlsDidNotConverge {
                max_iterations: fit.iteration,
                last_change: fit.lastgradient_norm,
            })
        }
        pirls::PirlsStatus::Converged | pirls::PirlsStatus::StalledAtValidMinimum => Ok(fit),
    }
}
}
/// Max-|Δρ| cap used when no usable prediction residual is available.
const IFT_WARM_START_DEFAULT_MAX_DRHO: f64 = 2.0;
/// Residuals strictly below this value fall in the best tier.
const IFT_RESIDUAL_TIER_EXCELLENT: f64 = 0.01;
/// Residuals strictly below this value (and not in a better tier) fall in the second tier.
const IFT_RESIDUAL_TIER_VERY_GOOD: f64 = 0.05;
/// Residuals strictly below this value (and not in a better tier) fall in the third tier.
const IFT_RESIDUAL_TIER_OK: f64 = 0.20;
/// Residuals strictly below this value (and not in a better tier) fall in the fourth tier;
/// anything at or above it lands in the last tier.
const IFT_RESIDUAL_TIER_MARGINAL: f64 = 0.50;
/// Tangent-line alpha cap used when no usable prediction residual is available.
const TANGENT_ALPHA_DEFAULT_CAP: f64 = 1.5;

/// Shared tier lookup behind the two adaptive caps below.
///
/// Returns `default_cap` when `last_residual` is `None`, NaN or negative.
/// Otherwise returns the entry of `tier_caps` for the first tier bound the
/// residual is strictly below, or the last entry when it is below none.
fn residual_tier_cap(last_residual: Option<f64>, default_cap: f64, tier_caps: [f64; 5]) -> f64 {
    const TIER_BOUNDS: [f64; 4] = [
        IFT_RESIDUAL_TIER_EXCELLENT,
        IFT_RESIDUAL_TIER_VERY_GOOD,
        IFT_RESIDUAL_TIER_OK,
        IFT_RESIDUAL_TIER_MARGINAL,
    ];
    let residual = match last_residual {
        Some(value) if !value.is_nan() && value >= 0.0 => value,
        _ => return default_cap,
    };
    let tier = TIER_BOUNDS
        .iter()
        .position(|&bound| residual < bound)
        .unwrap_or(TIER_BOUNDS.len());
    tier_caps[tier]
}

/// Maps the last IFT prediction residual to the largest |Δρ| for which an IFT
/// warm-start prediction is attempted.
///
/// Smaller residuals allow larger steps (4.0, 3.0, 2.0, 1.0, then 0.5); a
/// missing, NaN or negative residual yields `IFT_WARM_START_DEFAULT_MAX_DRHO`.
fn adaptive_ift_max_drho(last_residual: Option<f64>) -> f64 {
    residual_tier_cap(
        last_residual,
        IFT_WARM_START_DEFAULT_MAX_DRHO,
        [4.0, 3.0, 2.0, 1.0, 0.5],
    )
}

/// Maps the last prediction residual to a cap on the tangent-line alpha.
///
/// Uses the same residual tiers as `adaptive_ift_max_drho` with caps
/// 2.0, 1.75, 1.5, 1.0, then 0.5; a missing, NaN or negative residual yields
/// `TANGENT_ALPHA_DEFAULT_CAP`.
fn adaptive_tangent_alpha_cap(last_residual: Option<f64>) -> f64 {
    residual_tier_cap(
        last_residual,
        TANGENT_ALPHA_DEFAULT_CAP,
        [2.0, 1.75, 1.5, 1.0, 0.5],
    )
}
/// Outcome tag returned alongside the coefficients by the IFT warm-start
/// predictors in this file.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub(crate) enum IftPredictionOutcome {
/// A correction was computed and subtracted from the cached coefficients.
Predicted,
/// Every effective Δρ component was at or below `IFT_WARM_START_DRHO_EPS`,
/// so the cached coefficients were returned unchanged.
Noop,
}
/// Identifies which mechanism produced a warm-start prediction.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub(crate) enum WarmStartPredictionSource {
/// Prediction from the implicit-function-theorem step; only this source
/// feeds the IFT prediction-quality bookkeeping after a converged solve.
Ift,
/// NOTE(review): presumably a tangent-line extrapolation; the producer is
/// not visible in this part of the file — confirm.
TangentLine,
/// NOTE(review): presumably the unmodified previous solution; the producer
/// is not visible in this part of the file — confirm.
Flat,
}
/// Δρ components with absolute value at or below this threshold are treated
/// as zero by the IFT warm-start predictors.
const IFT_WARM_START_DRHO_EPS: f64 = 1e-12;
/// Bit pattern of an `f64` NaN, stored in the atomic residual slot to mean
/// "no residual recorded"; readers decode it with `f64::from_bits` and discard
/// non-finite values.
pub(crate) const IFT_RESIDUAL_NO_SIGNAL_BITS: u64 = 0x7ff8_0000_0000_0000;
/// Turns the Levenberg–Marquardt damping cached from the previous inner solve
/// into a starting hint for the next one.
///
/// The cached value is clamped to a window chosen from how the previous solve
/// went:
/// * converged in 1–2 iterations: `[1e-9, 1e-3]`;
/// * did not converge, or needed at least 10 iterations: `[1e-3, 1.0]`;
/// * anything else: `[1e-6, 1e-3]`.
///
/// Returns `None` when `cached_lambda` is not a finite positive number, or
/// when the previous solve reports zero iterations without convergence.
pub(crate) fn adaptive_lm_lambda_hint(
    cached_lambda: f64,
    last_iters: usize,
    last_converged: bool,
) -> Option<f64> {
    const NEWTON_FRIENDLY_MAX_ITERS: usize = 2;
    const HARD_FIT_MIN_ITERS: usize = 10;
    const NEWTON_WINDOW: (f64, f64) = (1e-9, 1e-3);
    const HARD_FIT_WINDOW: (f64, f64) = (1e-3, 1.0);
    const DEFAULT_WINDOW: (f64, f64) = (1e-6, 1e-3);

    let lambda_usable = cached_lambda.is_finite() && cached_lambda > 0.0;
    let no_history = last_iters == 0 && !last_converged;
    if !lambda_usable || no_history {
        return None;
    }

    let newton_friendly =
        last_converged && matches!(last_iters, 1..=NEWTON_FRIENDLY_MAX_ITERS);
    let hard_fit = !last_converged || last_iters >= HARD_FIT_MIN_ITERS;
    let (lower, upper) = match (newton_friendly, hard_fit) {
        (true, _) => NEWTON_WINDOW,
        (false, true) => HARD_FIT_WINDOW,
        (false, false) => DEFAULT_WINDOW,
    };
    Some(cached_lambda.max(lower).min(upper))
}
/// Predicts the coefficients at `new_rho` with one implicit-function step away
/// from the cached inner solution.
///
/// With `Δρ = new_rho − cache.rho`, the value assembled here is
/// `β̂ − H⁻¹ · Σ_k Δρ_k · exp(ρ_k) · S_k (β̂_k − μ_k)`. The right-hand side is
/// built per penalty block in original coordinates; unless
/// `cache.frame_was_original` is set it is rotated with `cache.qsᵀ` before the
/// solve and the solution is rotated back with `cache.qs`.
///
/// Components whose cached ρ sits within `1e-8` of a finite upper bound are
/// given `Δρ = 0`.
///
/// # Arguments
/// * `cache` – cached solution (ρ, β, Hessian, optional per-block products).
/// * `canonical_penalties` – one penalty per ρ component.
/// * `new_rho` – target log smoothing parameters.
/// * `p` – expected coefficient count.
/// * `last_ift_residual` – previous prediction residual, used to pick the
///   |Δρ| cap when no override is given.
/// * `max_drho_cap_override` – explicit |Δρ| cap; ignored unless finite and
///   positive.
/// * `factor_override` – pre-built factorization to use instead of factorizing
///   the cached Hessian.
///
/// # Returns
/// * `Some((β, Predicted))` when a finite correction was applied.
/// * `Some((cached β, Noop))` when every `|Δρ_k|` is at or below
///   `IFT_WARM_START_DRHO_EPS`.
/// * `None` when the cache is empty or the prediction is rejected (dimension
///   mismatch, non-finite or too-large Δρ, failed factorization/solve, or a
///   non-finite intermediate); rejections are logged with `[IFT-REJECTED]`.
fn predict_warm_start_beta_ift_inner_with_outcome(
    cache: &super::IftWarmStartCache,
    canonical_penalties: &[crate::construction::CanonicalPenalty],
    new_rho: &Array1<f64>,
    p: usize,
    last_ift_residual: Option<f64>,
    max_drho_cap_override: Option<f64>,
    factor_override: Option<&dyn crate::linalg::matrix::FactorizedSystem>,
) -> Option<(Coefficients, IftPredictionOutcome)> {
    if cache.rho.is_empty() {
        return None;
    }
    let k = cache.rho.len();
    if new_rho.len() != k {
        log::info!(
            "[IFT-REJECTED] reason=rho_dim_mismatch new_rho_dim={} cache_rho_dim={}",
            new_rho.len(),
            k,
        );
        return None;
    }
    if canonical_penalties.len() != k {
        log::info!(
            "[IFT-REJECTED] reason=penalty_dim_mismatch penalties_dim={} cache_rho_dim={}",
            canonical_penalties.len(),
            k,
        );
        return None;
    }
    if cache.beta_original.len() != p {
        log::info!(
            "[IFT-REJECTED] reason=beta_dim_mismatch cache_beta_dim={} expected_p={}",
            cache.beta_original.len(),
            p,
        );
        return None;
    }

    let upper_bounds = latest_outer_rho_upper_bounds_for_ift();
    let upper_active = |idx: usize| -> bool {
        let upper = upper_bounds
            .as_ref()
            .and_then(|bounds| bounds.get(idx))
            .copied()
            .unwrap_or(RHO_BOUND);
        upper.is_finite() && cache.rho[idx] >= upper - 1.0e-8
    };
    let mut max_abs_drho = 0.0_f64;
    let drho: Array1<f64> = (0..k)
        .map(|i| {
            if upper_active(i) {
                return 0.0;
            }
            let d = new_rho[i] - cache.rho[i];
            if !d.is_finite() {
                // Record the non-finite step so the `large_drho` guard below
                // rejects it directly instead of relying on later checks.
                max_abs_drho = f64::INFINITY;
                return f64::INFINITY;
            }
            max_abs_drho = max_abs_drho.max(d.abs());
            d
        })
        .collect();

    let max_drho_cap = max_drho_cap_override
        .filter(|cap| cap.is_finite() && *cap > 0.0)
        .unwrap_or_else(|| adaptive_ift_max_drho(last_ift_residual));
    if !max_abs_drho.is_finite() || max_abs_drho > max_drho_cap {
        log::info!(
            "[IFT-REJECTED] reason=large_drho max_drho={:.3e} cap={:.3e} drho_dim={}",
            max_abs_drho,
            max_drho_cap,
            k,
        );
        return None;
    }

    let beta_cur = &cache.beta_original;
    let mut rhs_original = Array1::<f64>::zeros(p);
    let mut any_active = false;
    // Only trust the precomputed per-block products when there is exactly one
    // per penalty.
    // NOTE(review): these blocks are scaled by the same `Δρ·exp(ρ)` factor as
    // the fallback `S_k(β − μ)`; confirm they are stored without the λ factor.
    let precomputed_blocks = cache
        .lambda_s_beta_blocks
        .as_ref()
        .filter(|blocks| blocks.len() == canonical_penalties.len());
    for (idx, cp) in canonical_penalties.iter().enumerate() {
        let dr = drho[idx];
        if dr.abs() <= IFT_WARM_START_DRHO_EPS {
            continue;
        }
        any_active = true;
        let r = &cp.col_range;
        let scale = dr * cache.rho[idx].exp();
        let mut rhs_slice = rhs_original.slice_mut(s![r.start..r.end]);
        if let Some(blocks) = precomputed_blocks {
            let sb_block = &blocks[idx];
            if sb_block.len() == rhs_slice.len() {
                for (target, src) in rhs_slice.iter_mut().zip(sb_block.iter()) {
                    *target += scale * *src;
                }
                continue;
            }
        }
        // Fallback: recompute S_k (β_k − μ_k) for this block.
        let beta_block = beta_cur.slice(s![r.start..r.end]);
        let centered = &beta_block - &cp.prior_mean;
        let sb_block = cp.local.dot(&centered);
        for (target, src) in rhs_slice.iter_mut().zip(sb_block.iter()) {
            *target += scale * *src;
        }
    }
    if !any_active {
        log::info!(
            "[IFT-NOOP] reason=all_drho_below_eps max_drho={:.3e} drho_dim={}",
            max_abs_drho,
            k,
        );
        return Some((
            Coefficients::new(beta_cur.clone()),
            IftPredictionOutcome::Noop,
        ));
    }
    if !rhs_original.iter().all(|v| v.is_finite()) {
        log::info!(
            "[IFT-REJECTED] reason=non_finite_rhs max_drho={:.3e} drho_dim={}",
            max_abs_drho,
            k,
        );
        return None;
    }

    let solve_in_original = cache.frame_was_original;
    let rhs_in_h_basis = if solve_in_original {
        rhs_original
    } else {
        if cache.qs.nrows() != p || cache.qs.ncols() != p {
            log::info!(
                "[IFT-REJECTED] reason=qs_dim_mismatch qs_dim={}x{} expected_p={}",
                cache.qs.nrows(),
                cache.qs.ncols(),
                p,
            );
            return None;
        }
        cache.qs.t().dot(&rhs_original)
    };

    // `owned_factor` is declared outside the match so a locally built
    // factorization outlives the borrow taken for `factor_ref`.
    let owned_factor;
    let factor_ref: &dyn crate::linalg::matrix::FactorizedSystem = match factor_override {
        Some(f) => f,
        None => {
            owned_factor = match cache.penalized_hessian_transformed.factorize() {
                Ok(f) => f,
                Err(_) => {
                    log::info!(
                        "[IFT-REJECTED] reason=hessian_factorize_failed max_drho={:.3e} drho_dim={}",
                        max_abs_drho,
                        k,
                    );
                    return None;
                }
            };
            owned_factor.as_ref()
        }
    };
    let solution_in_h_basis = match factor_ref.solve(&rhs_in_h_basis) {
        Ok(u) => u,
        Err(_) => {
            log::info!(
                "[IFT-REJECTED] reason=hessian_solve_failed max_drho={:.3e} drho_dim={}",
                max_abs_drho,
                k,
            );
            return None;
        }
    };
    let solution_original = if solve_in_original {
        solution_in_h_basis
    } else {
        cache.qs.dot(&solution_in_h_basis)
    };
    if !solution_original.iter().all(|v| v.is_finite()) {
        log::info!(
            "[IFT-REJECTED] reason=non_finite_solution max_drho={:.3e} drho_dim={}",
            max_abs_drho,
            k,
        );
        return None;
    }

    let mut predicted = beta_cur.clone();
    for (target, &correction) in predicted.iter_mut().zip(solution_original.iter()) {
        *target -= correction;
    }
    if !predicted.iter().all(|v| v.is_finite()) {
        log::info!(
            "[IFT-REJECTED] reason=non_finite_predicted max_drho={:.3e} drho_dim={}",
            max_abs_drho,
            k,
        );
        return None;
    }
    log::debug!(
        "[warm-start] IFT prediction: max|Δρ|={:.3e}, ‖rhs‖={:.3e}, ‖Δβ‖={:.3e}",
        max_abs_drho,
        rhs_in_h_basis.dot(&rhs_in_h_basis).sqrt(),
        solution_original.dot(&solution_original).sqrt(),
    );
    Some((
        Coefficients::new(predicted),
        IftPredictionOutcome::Predicted,
    ))
}
/// Predicts the coefficients at `new_rho` from already-available mode-response
/// columns instead of solving against the cached Hessian.
///
/// With `Δρ = new_rho − cache.rho`, the prediction is
/// `cache.beta_original − rho_mode_response_cols · Δρ`. Components whose
/// cached ρ sits within `1e-8` of a finite upper bound are given `Δρ = 0`.
///
/// # Arguments
/// * `cache` – cached solution providing ρ and β.
/// * `new_rho` – target log smoothing parameters.
/// * `p` – expected coefficient count.
/// * `last_ift_residual` – previous prediction residual, used to pick the
///   |Δρ| cap when no override is given.
/// * `max_drho_cap_override` – explicit |Δρ| cap; ignored unless finite and
///   positive.
/// * `rho_mode_response_cols` – `p × k` matrix, one column per ρ component.
///
/// # Returns
/// * `Some((β, Predicted))` when a finite correction was applied.
/// * `Some((cached β, Noop))` when every `|Δρ_k|` is at or below
///   `IFT_WARM_START_DRHO_EPS`.
/// * `None` when the cache is empty or the prediction is rejected; rejections
///   are logged with `[IFT-REJECTED]`.
fn predict_warm_start_beta_ift_from_mode_response_cols(
    cache: &super::IftWarmStartCache,
    new_rho: &Array1<f64>,
    p: usize,
    last_ift_residual: Option<f64>,
    max_drho_cap_override: Option<f64>,
    rho_mode_response_cols: &Array2<f64>,
) -> Option<(Coefficients, IftPredictionOutcome)> {
    if cache.rho.is_empty() {
        return None;
    }
    let k = cache.rho.len();
    if new_rho.len() != k {
        log::info!(
            "[IFT-REJECTED] reason=rho_dim_mismatch new_rho_dim={} cache_rho_dim={}",
            new_rho.len(),
            k,
        );
        return None;
    }
    if cache.beta_original.len() != p {
        log::info!(
            "[IFT-REJECTED] reason=beta_dim_mismatch cache_beta_dim={} expected_p={}",
            cache.beta_original.len(),
            p,
        );
        return None;
    }
    if rho_mode_response_cols.nrows() != p || rho_mode_response_cols.ncols() != k {
        log::info!(
            "[IFT-REJECTED] reason=mode_response_dim_mismatch cols_dim={}x{} expected_dim={}x{}",
            rho_mode_response_cols.nrows(),
            rho_mode_response_cols.ncols(),
            p,
            k,
        );
        return None;
    }

    let upper_bounds = latest_outer_rho_upper_bounds_for_ift();
    let upper_active = |idx: usize| -> bool {
        let upper = upper_bounds
            .as_ref()
            .and_then(|bounds| bounds.get(idx))
            .copied()
            .unwrap_or(RHO_BOUND);
        upper.is_finite() && cache.rho[idx] >= upper - 1.0e-8
    };
    let mut max_abs_drho = 0.0_f64;
    let drho: Array1<f64> = (0..k)
        .map(|i| {
            if upper_active(i) {
                return 0.0;
            }
            let d = new_rho[i] - cache.rho[i];
            if !d.is_finite() {
                // Record the non-finite step so the `large_drho` guard below
                // rejects it directly instead of relying on later checks.
                max_abs_drho = f64::INFINITY;
                return f64::INFINITY;
            }
            max_abs_drho = max_abs_drho.max(d.abs());
            d
        })
        .collect();

    let max_drho_cap = max_drho_cap_override
        .filter(|cap| cap.is_finite() && *cap > 0.0)
        .unwrap_or_else(|| adaptive_ift_max_drho(last_ift_residual));
    if !max_abs_drho.is_finite() || max_abs_drho > max_drho_cap {
        log::info!(
            "[IFT-REJECTED] reason=large_drho max_drho={:.3e} cap={:.3e} drho_dim={}",
            max_abs_drho,
            max_drho_cap,
            k,
        );
        return None;
    }
    if drho.iter().all(|d| d.abs() <= IFT_WARM_START_DRHO_EPS) {
        log::info!(
            "[IFT-NOOP] reason=all_drho_below_eps max_drho={:.3e} drho_dim={}",
            max_abs_drho,
            k,
        );
        return Some((
            Coefficients::new(cache.beta_original.clone()),
            IftPredictionOutcome::Noop,
        ));
    }

    let solution_original = rho_mode_response_cols.dot(&drho);
    if !solution_original.iter().all(|v| v.is_finite()) {
        log::info!(
            "[IFT-REJECTED] reason=non_finite_solution max_drho={:.3e} drho_dim={}",
            max_abs_drho,
            k,
        );
        return None;
    }
    let mut predicted = cache.beta_original.clone();
    for (target, &correction) in predicted.iter_mut().zip(solution_original.iter()) {
        *target -= correction;
    }
    if !predicted.iter().all(|v| v.is_finite()) {
        log::info!(
            "[IFT-REJECTED] reason=non_finite_predicted max_drho={:.3e} drho_dim={}",
            max_abs_drho,
            k,
        );
        return None;
    }
    log::debug!(
        "[warm-start] IFT prediction reused mode responses: max|Δρ|={:.3e}, ‖Δβ‖={:.3e}",
        max_abs_drho,
        solution_original.dot(&solution_original).sqrt(),
    );
    Some((
        Coefficients::new(predicted),
        IftPredictionOutcome::Predicted,
    ))
}
impl<'a> RemlState<'a> {
/// Evaluates the outer objective (value only) at the log smoothing
/// parameters `p`.
///
/// Returns `Ok(f64::INFINITY)` instead of an error when the prior cost is not
/// finite, or when the inner solve reports ill-conditioning, perfect
/// separation or non-convergence, so the caller can retreat from this point.
///
/// # Errors
/// * Any other error from `obtain_eval_bundle` (after invalidating the cached
///   eval bundle).
/// * `LayoutError` when the penalty or nullspace dimensions disagree with
///   `p.len()` on the dense path.
/// * Errors from `evaluate_unified` / `evaluate_unified_sparse`.
pub fn compute_cost(&self, p: &Array1<f64>) -> Result<f64, EstimationError> {
// Bump the shared evaluation counter; the new value labels every log line.
let cost_call_idx = {
let mut calls = self.arena.cost_eval_count.write().unwrap();
*calls += 1;
*calls
};
let t_eval_start = std::time::Instant::now();
{
// Only the first four rho entries are printed.
let prefix: Vec<String> = p.iter().take(4).map(|v| format!("{:.3}", v)).collect();
log::debug!(
"[REML] eval#{} begin cost-only | rho[..4]=[{}] | k={}",
cost_call_idx,
prefix.join(","),
p.len()
);
}
// Fast path: reuse a cached outer evaluation for this rho key.
let rho_key = EvalCacheManager::sanitized_rhokey(p);
if let Some(eval) = self.cache_manager.cached_outer_eval(&rho_key) {
log::debug!(
"[REML] eval#{} cache hit | cost {:.6e} | elapsed {:.1}ms",
cost_call_idx,
eval.cost,
t_eval_start.elapsed().as_secs_f64() * 1000.0
);
return Ok(eval.cost);
}
// A non-finite prior rejects the step before any inner solve is run.
let prior_cost = self.compute_soft_priorcost(p) + self.compute_configured_rho_prior_cost(p);
if !prior_cost.is_finite() {
log::debug!(
"[REML] eval#{} prior short-circuit | prior_cost {:.6e} | rejecting step \
without inner solve | elapsed {:.1}ms",
cost_call_idx,
prior_cost,
t_eval_start.elapsed().as_secs_f64() * 1000.0
);
return Ok(f64::INFINITY);
}
let t_pirls = std::time::Instant::now();
// Every failure arm invalidates the cached eval bundle before returning.
let bundle = match self.obtain_eval_bundle(p) {
Ok(bundle) => bundle,
Err(EstimationError::ModelIsIllConditioned { .. }) => {
self.cache_manager.invalidate_eval_bundle();
log::debug!(
"P-IRLS flagged ill-conditioning for current rho; returning +inf cost to retreat."
);
// Report which rho components sit on the box bounds, if any.
let (at_lower, at_upper) = boundary_hit_indices(p.view(), RHO_BOUND, 1e-8);
if !(at_lower.is_empty() && at_upper.is_empty()) {
log::debug!(
"[Diag] rho bounds: lower={:?} upper={:?}",
at_lower,
at_upper
);
}
return Ok(f64::INFINITY);
}
Err(EstimationError::PerfectSeparationDetected { .. })
| Err(EstimationError::PirlsDidNotConverge { .. }) => {
self.cache_manager.invalidate_eval_bundle();
log::debug!(
"P-IRLS separation/non-convergence at current rho; returning +inf cost to retreat."
);
return Ok(f64::INFINITY);
}
// Anything else is a hard error and is propagated to the caller.
Err(e) => {
self.cache_manager.invalidate_eval_bundle();
let (at_lower, at_upper) = boundary_hit_indices(p.view(), RHO_BOUND, 1e-8);
if !(at_lower.is_empty() && at_upper.is_empty()) {
log::debug!(
"[Diag] rho bounds: lower={:?} upper={:?}",
at_lower,
at_upper
);
}
return Err(e);
}
};
let pirls_ms = t_pirls.elapsed().as_secs_f64() * 1000.0;
log::debug!(
"[REML] eval#{} pirls done | elapsed {:.1}ms | backend {:?}",
cost_call_idx,
pirls_ms,
bundle.backend_kind()
);
// Sparse-exact backend: evaluate through the sparse path and return early.
if bundle.backend_kind() == GeometryBackendKind::SparseExactSpd {
let t_assemble = std::time::Instant::now();
let result =
self.evaluate_unified_sparse(p, &bundle, super::unified::EvalMode::ValueOnly)?;
let cost = screening_residual_penalty(result.cost, bundle.pirls_result.as_ref());
log::debug!(
"[REML] eval#{} sparse cost {:.6e} | assemble {:.1}ms | total {:.1}ms",
cost_call_idx,
cost,
t_assemble.elapsed().as_secs_f64() * 1000.0,
t_eval_start.elapsed().as_secs_f64() * 1000.0
);
return Ok(cost);
}
// Dense path: layout sanity checks and optional Hessian diagnostics.
{
let pirls_result = bundle.pirls_result.as_ref();
let ridge_used = bundle.ridge_passport.delta;
if !p.is_empty() {
let k_lambda = p.len();
let k_r = pirls_result.reparam_result.canonical_transformed.len();
let k_d = pirls_result.reparam_result.det1.len();
if !(k_lambda == k_r && k_r == k_d) {
return Err(EstimationError::LayoutError(format!(
"Penalty dimension mismatch: lambdas={}, R={}, det1={}",
k_lambda, k_r, k_d
)));
}
if self.nullspace_dims.len() != k_lambda {
return Err(EstimationError::LayoutError(format!(
"Nullspace dimension mismatch: expected {} entries, got {}",
k_lambda,
self.nullspace_dims.len()
)));
}
}
const MIN_ACCEPTABLE_HESSIAN_EIGENVALUE: f64 = 1e-12;
// The eigendecomposition below is diagnostic only; it runs when a ridge
// was applied and hot diagnostics are enabled for this evaluation.
let want_hot_diag = !pirls_result.status.is_failed_max_iterations()
&& self.should_compute_hot_diagnostics(cost_call_idx);
if ridge_used > 0.0 && want_hot_diag {
let pht_dense = pirls_result.penalized_hessian_transformed.to_dense();
if let Ok((eigs, _)) = pht_dense.eigh(Side::Lower)
&& let Some(min_eig) = eigs.iter().cloned().reduce(f64::min)
{
if should_emit_h_min_eig_diag(min_eig) {
log::debug!(
"[Diag] H min_eig={:.3e} (ridge={:.3e})",
min_eig,
ridge_used
);
}
if min_eig <= 0.0 {
log::warn!(
"Penalized Hessian not PD (min eig <= 0) before stabilization; proceeding with ridge {:.3e}.",
ridge_used
);
}
if !min_eig.is_finite() || min_eig <= MIN_ACCEPTABLE_HESSIAN_EIGENVALUE {
let condition_number = symmetric_spectrum_condition_number(&pht_dense);
log::warn!(
"Penalized Hessian extremely ill-conditioned (cond={:.3e}); continuing with stabilized Hessian.",
condition_number
);
}
}
}
}
let t_assemble = std::time::Instant::now();
let result = self.evaluate_unified(p, &bundle, super::unified::EvalMode::ValueOnly)?;
// The raw cost is passed through `screening_residual_penalty` together
// with the inner-solve result before being returned.
let cost = screening_residual_penalty(result.cost, bundle.pirls_result.as_ref());
log::debug!(
"[REML] eval#{} dense cost {:.6e} | assemble {:.1}ms | total {:.1}ms",
cost_call_idx,
cost,
t_assemble.elapsed().as_secs_f64() * 1000.0,
t_eval_start.elapsed().as_secs_f64() * 1000.0
);
Ok(cost)
}
/// Cheap ranking score for seed screening.
///
/// When the screening iteration cap is not positive this simply delegates to
/// `compute_cost`. Otherwise the score is the inner solve's
/// `min_penalized_deviance`, with `+inf` substituted when that value is not
/// finite or when the inner solve fails with a "retreat" class error.
///
/// # Errors
/// Any non-retreat error from `obtain_eval_bundle`, or whatever
/// `compute_cost` returns on the delegated path. The cached eval bundle is
/// invalidated on every `obtain_eval_bundle` failure.
pub fn compute_screening_proxy(&self, p: &Array1<f64>) -> Result<f64, EstimationError> {
    let screening_active = self.screening_max_inner_iterations.load(Ordering::Relaxed) > 0;
    if screening_active {
        match self.obtain_eval_bundle(p) {
            Ok(bundle) => {
                let proxy = bundle.pirls_result.min_penalized_deviance;
                Ok(if proxy.is_finite() { proxy } else { f64::INFINITY })
            }
            Err(err) => {
                let retreat = err.is_inner_solve_retreat();
                self.cache_manager.invalidate_eval_bundle();
                if retreat {
                    Ok(f64::INFINITY)
                } else {
                    Err(err)
                }
            }
        }
    } else {
        self.compute_cost(p)
    }
}
/// Produces the dense Firth operator matching the basis used by the outer
/// evaluation, if Firth/Jeffreys handling is configured.
///
/// * No Jeffreys link configured: `Ok(None)`.
/// * `free_basis_opt` is `Some(z)`: the operator is built on the projected
///   design `X·z`, or `Ok(None)` (with an info log) when
///   `firth_problem_scale_allows` rejects the projected size.
/// * Otherwise: the operator cached on `bundle` is reused when present, else
///   one is built from the full dense transformed design.
///
/// # Errors
/// Propagates failures from `build_firth_dense_operator_for_link`.
fn build_dense_firth_operator_for_outer_basis(
    &self,
    pirls_result: &PirlsResult,
    bundle: &EvalShared,
    free_basis_opt: &Option<Array2<f64>>,
) -> Result<Option<std::sync::Arc<super::FirthDenseOperator>>, EstimationError> {
    let jeffreys_link = match reml_robust_jeffreys_link(&self.config) {
        Some(link) => link,
        None => return Ok(None),
    };

    match free_basis_opt.as_ref() {
        Some(z) => {
            let x_projected = pirls_result.x_transformed.to_dense().dot(z);
            let rows = x_projected.nrows();
            let cols = x_projected.ncols();
            if !super::firth_problem_scale_allows(rows, cols) {
                log::info!(
                    "disabling Firth bias reduction for projected outer basis (n={}, p={}, n*p={}, n*p^2={}): \
                     exact Firth operator is small-model-only",
                    rows,
                    cols,
                    rows.saturating_mul(cols),
                    rows.saturating_mul(cols).saturating_mul(cols),
                );
                return Ok(None);
            }
            let operator = Self::build_firth_dense_operator_for_link(
                &jeffreys_link,
                &x_projected,
                &pirls_result.final_eta,
                self.weights,
            )?;
            Ok(Some(std::sync::Arc::new(operator)))
        }
        None => {
            if let Some(cached) = bundle.firth_dense_operator.clone() {
                return Ok(Some(cached));
            }
            let x_dense = pirls_result.x_transformed.to_dense();
            let operator = Self::build_firth_dense_operator_for_link(
                &jeffreys_link,
                &x_dense,
                &pirls_result.final_eta,
                self.weights,
            )?;
            Ok(Some(std::sync::Arc::new(operator)))
        }
    }
}
/// Collects everything the dense outer evaluation needs besides the Hessian
/// operator: derivative provider, dispersion handling, log-likelihood, Firth
/// operator and barrier configuration.
///
/// * The derivative provider is `GaussianDerivatives` for Gaussian-identity
///   fits or when the inner result flags derivatives as unsupported;
///   otherwise a GLM provider (wrapped Firth-aware when `include_firth_derivs`
///   is set and a Firth operator exists). With a free basis the design handed
///   to the GLM provider is the projected dense `X·z`.
/// * The barrier configuration is only derived when no free basis is active.
///
/// # Errors
/// Propagates failures from building the Firth operator and from
/// `hessian_surface_arrays`.
fn build_dense_derivative_context(
    &self,
    pirls_result: &PirlsResult,
    bundle: &EvalShared,
    free_basis_opt: &Option<Array2<f64>>,
    include_firth_derivs: bool,
) -> Result<DerivativeContext, EstimationError> {
    use super::unified::{
        DispersionHandling, FirthAwareGlmDerivatives, GaussianDerivatives,
        SinglePredictorGlmDerivatives,
    };

    let gaussian_identity = reml_is_gaussian_identity(&pirls_result.likelihood);
    let firth_op =
        self.build_dense_firth_operator_for_outer_basis(pirls_result, bundle, free_basis_opt)?;

    let deriv_provider: Box<dyn super::unified::HessianDerivativeProvider> =
        if gaussian_identity || pirls_result.derivatives_unsupported {
            Box::new(GaussianDerivatives)
        } else {
            let (hessian_weights, c_array, d_array) =
                self.hessian_surface_arrays(pirls_result)?;
            let x_transformed = match free_basis_opt.as_ref() {
                Some(z) => {
                    let x_dense = pirls_result.x_transformed.to_dense();
                    crate::linalg::matrix::DesignMatrix::Dense(
                        crate::matrix::DenseDesignMatrix::from(x_dense.dot(z)),
                    )
                }
                None => pirls_result.x_transformed.clone(),
            };
            let base = SinglePredictorGlmDerivatives {
                c_array,
                d_array: Some(d_array),
                hessian_weights,
                x_transformed,
            };
            match firth_op.as_ref() {
                Some(op) if include_firth_derivs => Box::new(FirthAwareGlmDerivatives {
                    base,
                    firth_op: op.clone(),
                }),
                _ => Box::new(base),
            }
        };

    let dispersion = if gaussian_identity {
        DispersionHandling::ProfiledGaussian
    } else {
        DispersionHandling::Fixed {
            phi: reml_fixed_glm_dispersion(&pirls_result.likelihood),
            include_logdet_h: true,
            include_logdet_s: true,
        }
    };

    let log_likelihood = crate::pirls::calculate_loglikelihood_omitting_constants(
        self.y,
        &pirls_result.finalmu,
        &pirls_result.likelihood,
        self.weights,
    );

    let barrier_config = match free_basis_opt {
        Some(_) => None,
        None => pirls_result
            .linear_constraints_transformed
            .as_ref()
            .and_then(Self::barrier_config_from_constraints),
    };

    Ok(DerivativeContext {
        deriv_provider,
        dispersion,
        log_likelihood,
        firth_op,
        barrier_config,
    })
}
/// Sparse-backend counterpart of the dense derivative context builder.
///
/// Works in original coordinates: the GLM derivative provider is given
/// `self.x()`, the Firth operator (when a Jeffreys link is configured) is the
/// one cached on `bundle` or is rebuilt from a densified design, and the
/// barrier configuration comes from `self.linear_constraints`.
///
/// # Errors
/// * `InvalidInput` when the design cannot be densified for the Firth
///   operator.
/// * Failures from `build_firth_dense_operator_for_link` and
///   `hessian_surface_arrays`.
fn build_sparse_derivative_context(
    &self,
    pirls_result: &PirlsResult,
    bundle: &EvalShared,
) -> Result<DerivativeContext, EstimationError> {
    use super::unified::{
        DispersionHandling, FirthAwareGlmDerivatives, GaussianDerivatives,
        SinglePredictorGlmDerivatives,
    };

    let gaussian_identity = reml_is_gaussian_identity(&pirls_result.likelihood);

    let firth_op = match reml_robust_jeffreys_link(&self.config) {
        None => None,
        Some(jeffreys_link) => match bundle.firth_dense_operator_original.clone() {
            Some(cached) => Some(cached),
            None => {
                let x_dense = self
                    .x()
                    .try_to_dense_arc(
                        "sparse exact REML runtime requires dense design for Firth operator",
                    )
                    .map_err(EstimationError::InvalidInput)?;
                let operator = Self::build_firth_dense_operator_for_link(
                    &jeffreys_link,
                    x_dense.as_ref(),
                    &pirls_result.final_eta,
                    self.weights,
                )?;
                Some(std::sync::Arc::new(operator))
            }
        },
    };

    // Fixed-dispersion handling shared by both non-Gaussian branches; built
    // lazily so the Gaussian branch never evaluates it.
    let fixed_dispersion = || DispersionHandling::Fixed {
        phi: reml_fixed_glm_dispersion(&pirls_result.likelihood),
        include_logdet_h: true,
        include_logdet_s: true,
    };

    let (dispersion, deriv_provider): (_, Box<dyn super::unified::HessianDerivativeProvider>) =
        if gaussian_identity {
            (
                DispersionHandling::ProfiledGaussian,
                Box::new(GaussianDerivatives),
            )
        } else if pirls_result.derivatives_unsupported {
            (fixed_dispersion(), Box::new(GaussianDerivatives))
        } else {
            let (hessian_weights, c_array, d_array) =
                self.hessian_surface_arrays(pirls_result)?;
            (fixed_dispersion(), {
                let base = SinglePredictorGlmDerivatives {
                    c_array,
                    d_array: Some(d_array),
                    hessian_weights,
                    x_transformed: self.x().clone(),
                };
                match firth_op.clone() {
                    Some(firth_op) => Box::new(FirthAwareGlmDerivatives { base, firth_op }),
                    None => Box::new(base),
                }
            })
        };

    let log_likelihood = crate::pirls::calculate_loglikelihood_omitting_constants(
        self.y,
        &pirls_result.finalmu,
        &pirls_result.likelihood,
        self.weights,
    );
    let barrier_config = self
        .linear_constraints
        .as_ref()
        .and_then(Self::barrier_config_from_constraints);

    Ok(DerivativeContext {
        deriv_provider,
        dispersion,
        log_likelihood,
        firth_op,
        barrier_config,
    })
}
fn build_penalty_coords(&self) -> Vec<super::unified::PenaltyCoordinate> {
if let Some(ref kron) = self.kronecker_penalty_system
&& self.kronecker_factored.is_some()
{
let d = kron.ndim();
let total_dim = kron.p_total();
let eigenvalues: Vec<ndarray::Array1<f64>> = kron
.marginal_eigensystems
.iter()
.map(|(evals, _)| evals.clone())
.collect();
let mut coords = Vec::with_capacity(kron.num_penalties());
for k in 0..d {
coords.push(super::unified::PenaltyCoordinate::KroneckerMarginal {
eigenvalues: eigenvalues.clone(),
dim_index: k,
marginal_dims: kron.marginal_dims.clone(),
total_dim,
});
}
if kron.has_double_penalty {
let identity_root = ndarray::Array2::<f64>::eye(total_dim);
coords.push(super::unified::PenaltyCoordinate::from_dense_root(
identity_root,
));
}
return coords;
}
self.canonical_penalties
.iter()
.map(|cp| cp.to_penalty_coordinate())
.collect()
}
fn finish_assembly(
&self,
pirls_result: &PirlsResult,
ctx: DerivativeContext,
hessian_op: std::sync::Arc<dyn super::unified::HessianOperator>,
beta: Array1<f64>,
penalty_logdet: super::unified::PenaltyLogdetDerivs,
nullspace_dim: f64,
hessian_logdet_correction: f64,
penalty_subspace_trace: Option<std::sync::Arc<super::unified::PenaltySubspaceTrace>>,
free_basis: Option<&Array2<f64>>,
) -> super::assembly::InnerAssembly<'static> {
let penalty_coords = match free_basis {
Some(z) => self
.build_penalty_coords()
.iter()
.map(|coord| coord.project_into_subspace(z))
.collect(),
None => self.build_penalty_coords(),
};
super::assembly::InnerAssembly {
log_likelihood: ctx.log_likelihood,
penalty_quadratic: pirls_result.stable_penalty_term,
beta,
n_observations: self.weights.iter().filter(|&&wi| wi > 0.0).count(),
hessian_op,
penalty_coords,
penalty_logdet,
dispersion: ctx.dispersion,
rho_curvature_scale: 1.0,
rho_prior: self.effective_rho_prior().into_owned(),
hessian_logdet_correction,
penalty_subspace_trace,
deriv_provider: Some(ctx.deriv_provider),
tk_correction: 0.0,
tk_gradient: None,
firth: ctx.firth_op,
nullspace_dim: Some(nullspace_dim),
barrier_config: ctx.barrier_config,
ext_coords: Vec::new(),
ext_coord_pair_fn: None,
rho_ext_pair_fn: None,
fixed_drift_deriv: None,
kkt_residual: None,
active_constraints: None,
}
}
/// Builds the `InnerAssembly` for the dense backend from a completed inner
/// solve.
///
/// When `active_constraint_free_basis` yields a basis `z`, the Hessian, the
/// penalty root and the coefficients are all projected onto it before use.
///
/// # Errors
/// * `InvalidInput` when the spectral Hessian operator cannot be built.
/// * Failures from the penalty-subspace, penalty-logdet, projected-Hessian
///   and derivative-context helpers.
fn build_dense_assembly(
&self,
rho: &Array1<f64>,
bundle: &EvalShared,
mode: super::unified::EvalMode,
) -> Result<super::assembly::InnerAssembly<'static>, EstimationError> {
use super::unified::{
DenseCholeskyValueOnlyOperator, DenseSpectralOperator, PseudoLogdetMode,
};
use std::borrow::Cow;
let pirls_result = bundle.pirls_result.as_ref();
let ridge_passport = pirls_result.ridge_passport;
let free_basis_opt = self.active_constraint_free_basis(pirls_result);
// `Cow` lets the unprojected case borrow the bundle's matrices while the
// projected case owns freshly computed ones.
let (h_for_operator, e_for_logdet) = if let Some(z) = free_basis_opt.as_ref() {
(
Cow::Owned(Self::projectwith_basis(bundle.h_total.as_ref(), z)),
Cow::Owned(pirls_result.reparam_result.e_transformed.dot(z)),
)
} else {
(
Cow::Borrowed(bundle.h_total.as_ref()),
Cow::Borrowed(&pirls_result.reparam_result.e_transformed),
)
};
// A cached dense Firth operator on the bundle switches the Hessian
// pseudo-logdet to the hard mode; otherwise the smooth mode is used.
let hessian_mode = if bundle.firth_dense_operator.is_some() {
PseudoLogdetMode::HardPseudo
} else {
PseudoLogdetMode::Smooth
};
// Value-only, smooth-mode, unprojected evaluations first try a Cholesky
// operator and fall back to the spectral operator if that fails; every
// other combination uses the spectral operator directly.
let hessian_op: std::sync::Arc<dyn super::unified::HessianOperator> = if mode
== super::unified::EvalMode::ValueOnly
&& matches!(hessian_mode, PseudoLogdetMode::Smooth)
&& free_basis_opt.is_none()
{
match DenseCholeskyValueOnlyOperator::from_spd(h_for_operator.as_ref()) {
Ok(chol_op) => std::sync::Arc::new(chol_op),
Err(_) => std::sync::Arc::new(
DenseSpectralOperator::from_symmetric_with_mode(
h_for_operator.as_ref(),
hessian_mode,
)
.map_err(|e| {
EstimationError::InvalidInput(format!(
"DenseSpectralOperator from PIRLS Hessian: {e}"
))
})?,
),
}
} else {
std::sync::Arc::new(
DenseSpectralOperator::from_symmetric_with_mode(
h_for_operator.as_ref(),
hessian_mode,
)
.map_err(|e| {
EstimationError::InvalidInput(format!(
"DenseSpectralOperator from PIRLS Hessian: {e}"
))
})?,
)
};
// True when any entry of the inner solve's `solve_c_array` is non-zero.
let c_nontrivial = pirls_result.solve_c_array.iter().any(|&c| c != 0.0);
let uses_kron_penalty_logdet = self.kronecker_penalty_system.as_ref().is_some_and(|kron| {
self.kronecker_factored.is_some() && kron.num_penalties() == rho.len()
});
// The penalty subspace is skipped only when the Kronecker logdet path is
// in use and the projected-Hessian correction below is not needed.
let needs_penalty_subspace = !uses_kron_penalty_logdet
|| (matches!(hessian_mode, PseudoLogdetMode::Smooth) && c_nontrivial);
let penalty_subspace = if needs_penalty_subspace {
Some(self.compute_penalty_subspace(e_for_logdet.as_ref(), ridge_passport)?)
} else {
None
};
let (penalty_rank, penalty_logdet) = self.dense_penalty_logdet_derivs(
rho,
e_for_logdet.as_ref(),
&[],
ridge_passport,
penalty_subspace.as_ref(),
mode,
)?;
let beta = if let Some(z) = free_basis_opt.as_ref() {
z.t().dot(pirls_result.beta_transformed.as_ref())
} else {
pirls_result.beta_transformed.as_ref().clone()
};
// Columns of the (possibly projected) Hessian not covered by the penalty rank.
let nullspace_dim = h_for_operator.ncols().saturating_sub(penalty_rank) as f64;
// In smooth mode with a non-trivial c array, the correction is the
// projected-subspace logdet minus the operator's own logdet.
let (hessian_logdet_correction, penalty_subspace_trace) =
if matches!(hessian_mode, PseudoLogdetMode::Smooth) && c_nontrivial {
let Some(penalty_subspace) = penalty_subspace.as_ref() else {
crate::bail_invalid_estim!(
"projected Hessian logdet requires penalty subspace"
);
};
let (log_det_h_proj, kernel) =
self.fixed_subspace_hessian_projected_parts(&h_for_operator, penalty_subspace)?;
(
log_det_h_proj - hessian_op.logdet(),
kernel.map(std::sync::Arc::new),
)
} else {
(0.0, None)
};
let ctx =
self.build_dense_derivative_context(pirls_result, bundle, &free_basis_opt, true)?;
Ok(self.finish_assembly(
pirls_result,
ctx,
hessian_op,
beta,
penalty_logdet,
nullspace_dim,
hessian_logdet_correction,
penalty_subspace_trace,
free_basis_opt.as_ref(),
))
}
/// Builds the inner assembly for the sparse-exact backend.
///
/// The Hessian operator wraps the sparse factor stored on the bundle (plus the
/// Takahashi data when present). Second derivatives of the penalty
/// log-determinant are computed only in `ValueGradientHessian` mode.
///
/// # Errors
/// Returns `EstimationError::InvalidInput` when the bundle carries no sparse
/// payload, and propagates errors from the penalty-derivative and
/// derivative-context helpers.
fn build_sparse_assembly(
    &self,
    rho: &Array1<f64>,
    bundle: &EvalShared,
    mode: super::unified::EvalMode,
) -> Result<super::assembly::InnerAssembly<'static>, EstimationError> {
    use super::unified::{HessianOperator, PenaltyLogdetDerivs, SparseCholeskyOperator};
    let Some(sparse) = bundle.sparse_exact.as_ref() else {
        return Err(EstimationError::InvalidInput(
            "missing sparse exact evaluation payload".to_string(),
        ));
    };
    let pirls_result = bundle.pirls_result.as_ref();
    let beta = self.sparse_exact_beta_original(pirls_result);
    let p_dim = beta.len();

    let base_op = SparseCholeskyOperator::new(sparse.factor.clone(), sparse.logdet_h, p_dim);
    let sparse_op = match &sparse.takahashi {
        Some(taka) => base_op.with_takahashi(taka.clone()),
        None => base_op,
    };
    let hessian_op: std::sync::Arc<dyn HessianOperator> = std::sync::Arc::new(sparse_op);
    log::trace!(
        "SparseCholeskyOperator: dim={}, active_rank={}",
        hessian_op.dim(),
        hessian_op.active_rank()
    );

    let nullspace_dim = p_dim.saturating_sub(sparse.penalty_rank) as f64;

    // Only the full second-order mode pays for the second derivatives.
    let second = if mode != super::unified::EvalMode::ValueGradientHessian {
        None
    } else {
        let lambdas = rho.mapv(f64::exp);
        let det2 = self
            .structural_penalty_logdet_derivatives_block_local(
                &lambdas,
                bundle.ridge_passport.penalty_logdet_ridge(),
            )?
            .1;
        Some(det2)
    };
    let penalty_logdet = PenaltyLogdetDerivs {
        value: sparse.logdet_s_pos,
        first: sparse.det1_values.as_ref().clone(),
        second,
    };

    let ctx = self.build_sparse_derivative_context(pirls_result, bundle)?;
    let assembly = self.finish_assembly(
        pirls_result,
        ctx,
        hessian_op,
        beta,
        penalty_logdet,
        nullspace_dim,
        0.0,
        None,
        None,
    );
    Ok(assembly)
}
/// Builds the dense inner assembly with the Hessian and `beta` expressed in
/// the original basis.
///
/// The bundle Hessian is mapped through `bundle_matrix_in_original_basis`, any
/// barrier diagonal derived from the transformed constraints is subtracted,
/// and the barrier diagonal derived from `self.linear_constraints` is added.
///
/// # Errors
/// Returns `EstimationError::InvalidInput` when the spectral Hessian operator
/// cannot be constructed, and propagates errors from the penalty-subspace,
/// penalty log-determinant, projected-Hessian and derivative-context helpers.
/// Barrier-diagonal failures are logged as warnings, not returned.
fn build_dense_original_assembly(
&self,
rho: &Array1<f64>,
bundle: &EvalShared,
mode: super::unified::EvalMode,
) -> Result<super::assembly::InnerAssembly<'static>, EstimationError> {
use super::unified::{DenseSpectralOperator, PseudoLogdetMode};
let pirls_result = bundle.pirls_result.as_ref();
let ridge_passport = pirls_result.ridge_passport;
let mut h_total_original =
self.bundle_matrix_in_original_basis(pirls_result, bundle.h_total.as_ref());
let beta = self.sparse_exact_beta_original(pirls_result);
let transformed_barrier_to_strip = pirls_result
.linear_constraints_transformed
.as_ref()
.and_then(Self::barrier_config_from_constraints);
// Rebuild the barrier Hessian diagonal at `beta_transformed`, rotate it as
// `qs * D * qs^T`, and subtract it from the original-basis Hessian.
if let Some(ref barrier_cfg_t) = transformed_barrier_to_strip {
let mut diag_trans = Array2::<f64>::zeros((beta.len(), beta.len()));
let beta_t = pirls_result.beta_transformed.as_ref();
match barrier_cfg_t.add_barrier_hessian_diagonal(&mut diag_trans, beta_t) {
Ok(()) => {
let qs = &pirls_result.reparam_result.qs;
let tmp = qs.dot(&diag_trans);
let rotated = tmp.dot(&qs.t());
h_total_original -= &rotated;
}
Err(e) => {
log::warn!(
"Transformed-basis barrier-diagonal reconstruction failed ({e}); \
leaving bundle Hessian unchanged before adding original-basis barrier"
);
}
}
}
// Add the barrier diagonal built from the untransformed constraints at the
// original-basis `beta`; a failure is only logged.
if let Some(barrier_cfg) = self
.linear_constraints
.as_ref()
.and_then(Self::barrier_config_from_constraints)
&& let Err(e) = barrier_cfg.add_barrier_hessian_diagonal(&mut h_total_original, &beta)
{
log::warn!(
"Original-basis barrier Hessian diagonal skipped: {e}; \
cost/gradient/logdet consistency may regress on infeasible \
candidates (slack ≤ 0). BFGS line search must maintain \
feasibility for the barrier-aware outer objective."
);
}
// Either Firth operator on the bundle selects the hard pseudo-logdet mode.
let hessian_mode = if bundle.firth_dense_operator.is_some()
|| bundle.firth_dense_operator_original.is_some()
{
PseudoLogdetMode::HardPseudo
} else {
PseudoLogdetMode::Smooth
};
// Value-only smooth-mode evaluations try the Cholesky operator first and
// fall back to the spectral operator when `from_spd` fails.
let hessian_op: std::sync::Arc<dyn super::unified::HessianOperator> = {
use super::unified::DenseCholeskyValueOnlyOperator;
if mode == super::unified::EvalMode::ValueOnly
&& matches!(hessian_mode, PseudoLogdetMode::Smooth)
{
match DenseCholeskyValueOnlyOperator::from_spd(&h_total_original) {
Ok(chol_op) => std::sync::Arc::new(chol_op),
Err(_) => std::sync::Arc::new(
DenseSpectralOperator::from_symmetric_with_mode(
&h_total_original,
hessian_mode,
)
.map_err(|e| {
EstimationError::InvalidInput(format!(
"DenseSpectralOperator from original-basis PIRLS Hessian: {e}"
))
})?,
),
}
} else {
std::sync::Arc::new(
DenseSpectralOperator::from_symmetric_with_mode(
&h_total_original,
hessian_mode,
)
.map_err(|e| {
EstimationError::InvalidInput(format!(
"DenseSpectralOperator from original-basis PIRLS Hessian: {e}"
))
})?,
)
}
};
// The penalty log-determinant pieces use `e_transformed` as stored on the
// reparameterization result.
let e_for_logdet = &pirls_result.reparam_result.e_transformed;
let c_nontrivial = pirls_result.solve_c_array.iter().any(|&c| c != 0.0);
let uses_kron_penalty_logdet = self.kronecker_penalty_system.as_ref().is_some_and(|kron| {
self.kronecker_factored.is_some() && kron.num_penalties() == rho.len()
});
// The penalty subspace is skipped only when the Kronecker penalty path is
// active and no smooth-mode projected-Hessian correction will be needed.
let needs_penalty_subspace = !uses_kron_penalty_logdet
|| (matches!(hessian_mode, PseudoLogdetMode::Smooth) && c_nontrivial);
let penalty_subspace = if needs_penalty_subspace {
Some(self.compute_penalty_subspace(e_for_logdet, ridge_passport)?)
} else {
None
};
let (penalty_rank, penalty_logdet) = self.dense_penalty_logdet_derivs(
rho,
e_for_logdet,
&[],
ridge_passport,
penalty_subspace.as_ref(),
mode,
)?;
let nullspace_dim = beta.len().saturating_sub(penalty_rank) as f64;
let (hessian_logdet_correction, penalty_subspace_trace) =
if matches!(hessian_mode, PseudoLogdetMode::Smooth) && c_nontrivial {
let Some(penalty_subspace) = penalty_subspace.as_ref() else {
crate::bail_invalid_estim!(
"projected Hessian logdet requires penalty subspace"
);
};
let qs = &pirls_result.reparam_result.qs;
// NOTE(review): presumably `fast_atb(qs, H)` is `qs^T * H`, making this
// `qs^T * H * qs`, i.e. the Hessian moved back to the transformed basis
// — confirm against the helper definitions.
let h_transformed = crate::faer_ndarray::fast_ab(
&crate::faer_ndarray::fast_atb(qs, &h_total_original),
qs,
);
let (log_det_h_proj, kernel_trans) =
self.fixed_subspace_hessian_projected_parts(&h_transformed, penalty_subspace)?;
// Map the kernel's `u_s` through `qs`; `h_proj_inverse` is carried over
// unchanged.
let kernel_orig = kernel_trans.map(|kernel_trans| {
let u_s_orig = qs.dot(&kernel_trans.u_s);
std::sync::Arc::new(super::unified::PenaltySubspaceTrace {
u_s: u_s_orig,
h_proj_inverse: kernel_trans.h_proj_inverse,
})
});
(log_det_h_proj - hessian_op.logdet(), kernel_orig)
} else {
(0.0, None)
};
// This path uses `build_sparse_derivative_context` and passes no free basis
// to `finish_assembly`.
let ctx = self.build_sparse_derivative_context(pirls_result, bundle)?;
Ok(self.finish_assembly(
pirls_result,
ctx,
hessian_op,
beta,
penalty_logdet,
nullspace_dim,
hessian_logdet_correction,
penalty_subspace_trace,
None,
))
}
/// Picks the assembly builder that matches the bundle.
///
/// The sparse-exact backend always uses the sparse builder. Otherwise a fit in
/// the `TransformedQs` frame with no active-constraint free basis uses the
/// original-basis dense builder, and everything else uses the dense builder.
fn build_auto_assembly(
    &self,
    rho: &Array1<f64>,
    bundle: &EvalShared,
    mode: super::unified::EvalMode,
) -> Result<super::assembly::InnerAssembly<'static>, EstimationError> {
    if bundle.backend_kind() == GeometryBackendKind::SparseExactSpd {
        return self.build_sparse_assembly(rho, bundle, mode);
    }
    let in_transformed_frame = matches!(
        bundle.pirls_result.coordinate_frame,
        pirls::PirlsCoordinateFrame::TransformedQs
    );
    // The free-basis lookup is only performed for transformed-frame fits.
    let use_original_basis = in_transformed_frame
        && self
            .active_constraint_free_basis(bundle.pirls_result.as_ref())
            .is_none();
    if use_original_basis {
        self.build_dense_original_assembly(rho, bundle, mode)
    } else {
        self.build_dense_assembly(rho, bundle, mode)
    }
}
/// Adds the ALO stabilization term, when active, to an evaluated result.
///
/// The cost is always augmented. In gradient-bearing modes the leading entries
/// of the gradient are augmented too when both gradients exist and the result
/// gradient is at least as long as the ALO gradient; otherwise a warning is
/// logged and only the cost carries the term.
fn apply_alo_stabilization_to_result(
    &self,
    rho: &Array1<f64>,
    bundle: &EvalShared,
    mode: super::unified::EvalMode,
    mut result: super::unified::RemlLamlResult,
) -> Result<super::unified::RemlLamlResult, EstimationError> {
    let want_gradient = mode != super::unified::EvalMode::ValueOnly;
    let alo_eval = match self.alo_stabilization_eval(rho, bundle, want_gradient)? {
        Some(eval) => eval,
        None => return Ok(result),
    };
    result.cost += alo_eval.cost;
    if want_gradient {
        let augmented = match (result.gradient.as_mut(), alo_eval.gradient.as_ref()) {
            (Some(gradient), Some(alo_gradient)) if gradient.len() >= alo_gradient.len() => {
                // `zip` stops at the shorter ALO gradient, so only its entries are touched.
                gradient
                    .iter_mut()
                    .zip(alo_gradient.iter())
                    .for_each(|(g, a)| *g += *a);
                true
            }
            _ => false,
        };
        if !augmented {
            log::warn!(
                "[ALO-STABILIZED-REML] unstable ALO detected but analytic gradient \
                augmentation was unavailable; retaining REML gradient and adding \
                value term only (n={} max_h={:.3} min_denom={:.3})",
                self.y.len(),
                alo_eval.max_leverage,
                alo_eval.min_denominator,
            );
        }
    }
    log::info!(
        "[ALO-STABILIZED-REML] active cost_add={:.6e} max_h={:.3} min_denom={:.3} k_hat={:?}",
        alo_eval.cost,
        alo_eval.max_leverage,
        alo_eval.min_denominator,
        alo_eval.k_hat,
    );
    Ok(result)
}
/// Evaluates the ALO stabilization term for the current fit, if it applies.
///
/// Returns `Ok(None)` when any activation gate below declines, or when the
/// accumulated cost is non-finite or negative. Otherwise returns the cost, an
/// optional gradient with respect to `rho`, and the leverage diagnostics.
///
/// # Errors
/// Propagates errors from the ALO diagnostics, the pure-parametric gate, the
/// stabilized-Hessian accessor and the gradient helper.
///
/// # Panics
/// The `unwrap` on `alo_frozen_nuisance.write()` panics if that call returns
/// an error.
fn alo_stabilization_eval(
&self,
rho: &Array1<f64>,
bundle: &EvalShared,
want_gradient: bool,
) -> Result<Option<AloStabilizationEval>, EstimationError> {
// Gate: only Gaussian response with identity link, Firth bias reduction
// off, `TransformedQs` frame, and a backend other than sparse-exact.
if self.config.firth_bias_reduction
|| !matches!(
self.config.likelihood.spec.response,
ResponseFamily::Gaussian
)
|| !matches!(self.config.link_function(), LinkFunction::Identity)
|| !matches!(
bundle.pirls_result.coordinate_frame,
pirls::PirlsCoordinateFrame::TransformedQs
)
|| bundle.backend_kind() == GeometryBackendKind::SparseExactSpd
{
return Ok(None);
}
// Gate: too few observations.
let n = self.y.len();
if n < ALO_STABILIZATION_MIN_N {
return Ok(None);
}
// Gate: a finite EDF above the saturation fraction of `n`.
let edf = bundle.pirls_result.edf;
if edf.is_finite() && edf > ALO_EDF_FRACTION_SATURATION * (n as f64) {
return Ok(None);
}
let alo = crate::inference::alo::compute_alo_diagnostics_from_pirls(
bundle.pirls_result.as_ref(),
self.y,
self.config.link_function(),
)?;
let max_leverage = alo
.leverage
.iter()
.copied()
.fold(0.0_f64, |acc, h| acc.max(h));
let min_denominator = alo
.leverage
.iter()
.copied()
.map(|h| 1.0 - h)
.fold(f64::INFINITY, |acc, d| acc.min(d));
// Gate: leverage below the threshold and `1 - h` comfortably positive.
if max_leverage < ALO_MAX_LEVERAGE_THRESHOLD
&& min_denominator > ALO_DENOM_INSTABILITY_THRESHOLD
{
return Ok(None);
}
// Gate: the high-leverage rows are attributed to pure-parametric columns.
if self.high_leverage_is_pure_parametric(&alo.leverage, bundle.pirls_result.as_ref())? {
return Ok(None);
}
// Gate: too large a fraction of rows exceed the leverage threshold.
let n_high_leverage = alo
.leverage
.iter()
.filter(|&&h| h > ALO_MAX_LEVERAGE_THRESHOLD)
.count();
if (n_high_leverage as f64) > ALO_PERVASIVE_LEVERAGE_FRACTION * (n as f64) {
return Ok(None);
}
// Raw influence is `h / (1 - h)`, with `h` floored at 0 and the
// denominator floored at 1e-12.
let raw_influence: Vec<f64> = alo
.leverage
.iter()
.copied()
.map(|h| h.max(0.0) / (1.0 - h).max(1e-12))
.collect();
// Use the smoothed weights when `pareto_smooth_weights` returns a value,
// otherwise the raw influence.
let psis = crate::inference::psis::pareto_smooth_weights(&raw_influence);
let smoothed_influence = psis
.as_ref()
.map(|p| p.smoothed.as_slice())
.unwrap_or(raw_influence.as_slice());
let mean_influence =
smoothed_influence.iter().sum::<f64>() / smoothed_influence.len() as f64;
// Normalize by the mean influence and clamp each scale to [0.25, 4.0].
let fresh_influence_scale: Vec<f64> = smoothed_influence
.iter()
.map(|w| (*w / mean_influence.max(f64::MIN_POSITIVE)).clamp(0.25, 4.0))
.collect();
// A finite positive fixed phi is used as-is; any other fixed phi becomes
// 1.0; with no fixed phi it is `(deviance + penalty) / max(n - edf, 1)`.
let fresh_phi = match self.config.likelihood.scale.fixed_phi() {
Some(phi) if phi.is_finite() && phi > 0.0 => phi,
Some(_) => 1.0,
None => {
let dp = bundle.pirls_result.deviance + bundle.pirls_result.stable_penalty_term;
let denom = (n as f64 - bundle.pirls_result.edf).max(1.0);
(dp / denom).max(f64::MIN_POSITIVE)
}
};
// A cached entry whose sizes match `n` is reused and the fresh values are
// discarded; otherwise the fresh values are stored and used.
let (influence_scale, phi) = {
let mut frozen = self.alo_frozen_nuisance.write().unwrap();
match frozen.as_ref() {
Some(cached) if cached.n_obs == n && cached.influence_scale.len() == n => {
(cached.influence_scale.clone(), cached.phi)
}
_ => {
*frozen = Some(super::AloFrozenNuisance {
n_obs: n,
influence_scale: fresh_influence_scale.clone(),
phi: fresh_phi,
});
(fresh_influence_scale, fresh_phi)
}
}
};
// Sum a leverage barrier term and a scaled ALO deviance term per row.
let mut cost = 0.0_f64;
for i in 0..n {
cost += ALO_TAU * alo_leverage_barrier(alo.leverage[i]);
cost += ALO_GAMMA
* influence_scale[i]
* gaussian_alo_deviance(self.y[i], alo.eta_tilde[i], self.weights[i], phi);
}
if !(cost.is_finite() && cost >= 0.0) {
return Ok(None);
}
// The gradient is attempted only when requested, when `rho` has one entry
// per transformed canonical penalty, and when `n * p` fits the work budget.
let gradient = if want_gradient
&& rho.len()
== bundle
.pirls_result
.reparam_result
.canonical_transformed
.len()
&& n.saturating_mul(bundle.pirls_result.beta_transformed.as_ref().len())
<= ALO_GRADIENT_MAX_WORK
{
let x = bundle.pirls_result.x_transformed.to_dense();
match bundle
.pirls_result
.dense_stabilizedhessian_transformed("ALO-stabilized REML gradient")?
.cholesky(Side::Lower)
{
Ok(chol) => {
// Solve against `x^T` once and hand the factor and solution to
// the gradient helper.
let mut h_inv_xt = x.t().to_owned();
chol.solve_mat_in_place(&mut h_inv_xt);
self.alo_stabilization_gradient(
rho,
bundle,
&alo,
&influence_scale,
phi,
&AloFactoredHessian {
x: &x,
chol: &chol,
h_inv_xt: &h_inv_xt,
},
)?
}
// A failed Cholesky factorization yields no gradient, not an error.
Err(_) => None,
}
} else {
None
};
Ok(Some(AloStabilizationEval {
cost,
gradient,
k_hat: psis.as_ref().map(|p| p.k_hat),
max_leverage,
min_denominator,
}))
}
/// Reports whether every high-leverage row is explained by the
/// pure-parametric columns.
///
/// Returns `Ok(false)` when there is at most one pure-parametric column or no
/// row exceeds `ALO_MAX_LEVERAGE_THRESHOLD`. Otherwise each high-leverage row
/// must have parametric leverage at or above the threshold and at or above
/// `ALO_PARAMETRIC_LEVERAGE_SHARE` times its ALO leverage.
fn high_leverage_is_pure_parametric(
    &self,
    alo_leverage: &Array1<f64>,
    pirls_result: &PirlsResult,
) -> Result<bool, EstimationError> {
    let parametric_cols = self.pure_parametric_column_indices();
    if parametric_cols.len() <= 1 {
        return Ok(false);
    }

    let mut flagged_rows: Vec<usize> = Vec::new();
    for (row, &h) in alo_leverage.iter().enumerate() {
        if h > ALO_MAX_LEVERAGE_THRESHOLD {
            flagged_rows.push(row);
        }
    }
    if flagged_rows.is_empty() {
        return Ok(false);
    }

    let parametric_leverage =
        self.pure_parametric_projection_leverage(&parametric_cols, pirls_result)?;

    let mut explained = true;
    for &row in &flagged_rows {
        let h = alo_leverage[row];
        let hp = parametric_leverage[row];
        let row_explained =
            hp >= ALO_MAX_LEVERAGE_THRESHOLD && hp >= ALO_PARAMETRIC_LEVERAGE_SHARE * h;
        if !row_explained {
            explained = false;
            break;
        }
    }

    if explained {
        log::info!(
            "[ALO-STABILIZED-REML] suppressed: {} high-leverage rows are explained by {} pure-parametric columns",
            flagged_rows.len(),
            parametric_cols.len(),
        );
    }
    Ok(explained)
}
/// Lists, in ascending order, the column indices below `self.p` that fall in
/// no canonical penalty's `col_range`.
fn pure_parametric_column_indices(&self) -> Vec<usize> {
    let mut penalized = vec![false; self.p];
    for penalty in self.canonical_penalties.iter() {
        for col in penalty.col_range.clone() {
            // `get_mut` ignores range entries at or beyond `self.p`.
            if let Some(flag) = penalized.get_mut(col) {
                *flag = true;
            }
        }
    }
    (0..self.p).filter(|&col| !penalized[col]).collect()
}
/// Computes, per observation, `max(w_i * x_i^T G^{-1} x_i, 0)` where `x_i`
/// holds the entries of row `i` at `cols` and `G` is the weighted Gram matrix
/// of those columns using `pirls_result.finalweights`.
///
/// Returns an all-zero vector when the design cannot be densified, when any
/// weight is non-finite or non-positive, or when the Gram factorization fails.
///
/// # Errors
/// Fails when the dense design's row count differs from the weight count or
/// when any index in `cols` is out of range.
fn pure_parametric_projection_leverage(
    &self,
    cols: &[usize],
    pirls_result: &PirlsResult,
) -> Result<Array1<f64>, EstimationError> {
    let dense_attempt = self
        .x
        .try_to_dense_arc("ALO pure-parametric activation gate requires dense design");
    let x_dense = match dense_attempt {
        Err(reason) => {
            log::debug!("[ALO-STABILIZED-REML] pure-parametric gate skipped: {reason}");
            return Ok(Array1::<f64>::zeros(pirls_result.finalweights.len()));
        }
        Ok(dense) => dense,
    };
    let n = x_dense.nrows();
    let weights_mismatch = n != pirls_result.finalweights.len();
    if weights_mismatch || cols.iter().any(|&col| col >= x_dense.ncols()) {
        crate::bail_invalid_estim!(
            "ALO pure-parametric activation gate received inconsistent dimensions"
        );
    }

    // Accumulate the lower triangle of the weighted Gram matrix row by row.
    let q = cols.len();
    let mut gram = Array2::<f64>::zeros((q, q));
    for row in 0..n {
        let w = pirls_result.finalweights[row];
        if !w.is_finite() || w <= 0.0 {
            return Ok(Array1::<f64>::zeros(n));
        }
        for (a, &ca) in cols.iter().enumerate() {
            let xa = x_dense[[row, ca]];
            for (b, &cb) in cols[..=a].iter().enumerate() {
                gram[[a, b]] += w * xa * x_dense[[row, cb]];
            }
        }
    }
    // Mirror the lower triangle into the upper triangle.
    for lower in 0..q {
        for upper in 0..lower {
            gram[[upper, lower]] = gram[[lower, upper]];
        }
    }

    let solver = StableSolver::new("ALO pure-parametric activation gate");
    let factor = match solver.factorize(&gram) {
        Err(_) => return Ok(Array1::<f64>::zeros(n)),
        Ok(factor) => factor,
    };
    // Solve against the identity in place to obtain the inverse Gram matrix.
    let mut gram_inv = Array2::<f64>::eye(q);
    let mut gram_inv_view = array2_to_matmut(&mut gram_inv);
    factor.solve_in_place(gram_inv_view.as_mut());

    let mut leverage = Array1::<f64>::zeros(n);
    for (row, slot) in leverage.iter_mut().enumerate() {
        let w = pirls_result.finalweights[row];
        let mut quad = 0.0;
        for (a, &ca) in cols.iter().enumerate() {
            let xa = x_dense[[row, ca]];
            for (b, &cb) in cols.iter().enumerate() {
                quad += xa * gram_inv[[a, b]] * x_dense[[row, cb]];
            }
        }
        *slot = (w * quad).max(0.0);
    }
    Ok(leverage)
}
/// Computes the gradient of the ALO stabilization cost with respect to `rho`.
///
/// For each penalty, per-row derivatives of the leverage and of the
/// leave-one-out linear predictor are tabulated first; the barrier and
/// deviance contributions are then summed over rows. Returns `Ok(None)` when
/// any gradient entry is non-finite.
fn alo_stabilization_gradient(
    &self,
    rho: &Array1<f64>,
    bundle: &EvalShared,
    alo: &crate::inference::alo::AloDiagnostics,
    influence_scale: &[f64],
    phi: f64,
    factored: &AloFactoredHessian<'_>,
) -> Result<Option<Array1<f64>>, EstimationError> {
    let &AloFactoredHessian { x, chol, h_inv_xt } = factored;
    let beta = bundle.pirls_result.beta_transformed.as_ref();
    let n_penalties = rho.len();
    let n_rows = x.nrows();
    let phi_safe = phi.max(f64::MIN_POSITIVE);

    let mut deta_loo = Array2::<f64>::zeros((n_rows, n_penalties));
    let mut lev_deriv = Array2::<f64>::zeros((n_rows, n_penalties));
    for penalty_idx in 0..n_penalties {
        let lambda = rho[penalty_idx].exp();
        let penalty = &bundle.pirls_result.reparam_result.canonical_transformed[penalty_idx];
        let scaled_penalty_beta = transformed_penalty_matvec(penalty, beta).mapv(|v| lambda * v);
        // Negated solve against the scaled penalty-times-beta vector.
        let mut mode_deriv = chol.solvevec(&scaled_penalty_beta);
        mode_deriv.mapv_inplace(|v| -v);
        let eta_deriv = x.dot(&mode_deriv);
        for row in 0..n_rows {
            let m_row = h_inv_xt.column(row);
            let penalty_m = transformed_penalty_matvec(penalty, &m_row.to_owned());
            let leverage = alo.leverage[row];
            let denom = (1.0 - leverage).max(1e-12);
            let residual = self.y[row] - bundle.pirls_result.final_eta[row];
            let v_ik = -self.weights[row] * lambda * m_row.dot(&penalty_m);
            lev_deriv[[row, penalty_idx]] = v_ik;
            deta_loo[[row, penalty_idx]] =
                eta_deriv[row] / denom - residual * v_ik / (denom * denom);
        }
    }

    let mut grad = Array1::<f64>::zeros(n_penalties);
    for row in 0..n_rows {
        let raw_dev_eta_grad =
            -2.0 * self.weights[row] * (self.y[row] - alo.eta_tilde[row]) / phi_safe;
        let raw_dev = gaussian_alo_raw_deviance(
            self.y[row],
            alo.eta_tilde[row],
            self.weights[row],
            phi_safe,
        );
        let dev_eta_grad = gaussian_alo_deviance_saturation_factor(raw_dev) * raw_dev_eta_grad;
        let b_prime = alo_leverage_barrier_derivative(alo.leverage[row]);
        for kk in 0..n_penalties {
            let u_ik = deta_loo[[row, kk]];
            let v_ik = lev_deriv[[row, kk]];
            grad[kk] += ALO_TAU * b_prime * v_ik
                + ALO_GAMMA * influence_scale[row] * dev_eta_grad * u_ik;
        }
    }

    let all_finite = grad.iter().all(|g| g.is_finite());
    Ok(all_finite.then_some(grad))
}
/// Assembles the rho prior for `mode` by handing `soft_prior_for_mode` the sum
/// of the soft prior and the configured rho prior for cost, gradient and
/// Hessian.
///
/// The Hessian closure returns `None` when the combined matrix is entirely
/// zero.
fn build_prior(
    &self,
    rho: &Array1<f64>,
    mode: super::unified::EvalMode,
) -> Option<(f64, Array1<f64>, Option<Array2<f64>>)> {
    super::assembly::soft_prior_for_mode(
        rho,
        mode,
        |r| self.compute_soft_priorcost(r) + self.compute_configured_rho_prior_cost(r),
        |r| self.compute_soft_priorgrad(r) + &self.compute_configured_rho_prior_grad(r),
        |r| {
            // Start from the soft-prior Hessian, or zeros when it is absent.
            let mut hess = match self.compute_soft_priorhess(r) {
                Some(soft) => soft,
                None => Array2::<f64>::zeros((r.len(), r.len())),
            };
            if let Some(configured) = self.compute_configured_rho_prior_hess(r) {
                hess += &configured;
            }
            let has_curvature = hess.iter().any(|&v| v != 0.0);
            has_curvature.then_some(hess)
        },
    )
}
/// Evaluates an assembled inner solution and applies the post-hoc corrections.
///
/// In order: the prior and Tierney-Kadane terms are prepared, the solution is
/// built and evaluated, then the Tierney-Kadane, block-local sampled and ALO
/// stabilization terms are applied. The IFT mode-response cache is updated,
/// and any inner polish step on the result is forwarded to the warm start.
///
/// # Panics
/// Panics if `rho.as_slice()` returns `None`.
fn assemble_and_evaluate(
    &self,
    rho: &Array1<f64>,
    bundle: &EvalShared,
    mode: super::unified::EvalMode,
    assembly: super::assembly::InnerAssembly<'static>,
) -> Result<super::unified::RemlLamlResult, EstimationError> {
    let prior = self.build_prior(rho, mode);
    self.validate_tk_ext_coords(mode, &assembly.ext_coords)?;
    let tk_terms = self.tierney_kadane_terms(rho, bundle, mode, &assembly.ext_coords)?;

    let trace_state = self.hypergradient_trace_state();
    Self::reset_hypergradient_trace_telemetry(&trace_state);

    // Record what is needed from the assembly before `build` consumes it.
    let assembly_ext_len = assembly.ext_coords.len();
    let mut inner_solution = assembly.build();
    inner_solution.stochastic_trace_state = trace_state;
    inner_solution.gaussian_weight_log_sum_half = self.gaussian_weight_log_sum_half();
    let solution_beta = inner_solution.beta.clone();

    let rho_slice = rho.as_slice().unwrap();
    let evaluated = super::assembly::evaluate_solution(&inner_solution, rho_slice, mode, prior)
        .map_err(EstimationError::InvalidInput)?;

    let with_tk = self.apply_tk_to_result(evaluated, tk_terms)?;
    let block_terms = self.block_local_sampled_correction(rho, bundle, assembly_ext_len)?;
    let with_block = self.apply_tk_to_result(with_tk, block_terms)?;
    let result = self.apply_alo_stabilization_to_result(rho, bundle, mode, with_block)?;

    self.store_ift_mode_response_cache_from_result(rho, bundle, &result);
    match result.inner_polish_step.as_ref() {
        Some(polish_step) => {
            self.apply_inner_polish_step_to_warm_start(bundle, &solution_beta, polish_step);
        }
        None => {}
    }
    Ok(result)
}
/// Evaluates an assembled inner solution in `ValueAndGradient` mode and
/// derives the EFS update from it.
///
/// If any extended coordinate is not penalty-like, the hybrid EFS update is
/// used and its psi gradient and indices are reported when non-empty;
/// otherwise the plain EFS update is used. Both paths record the single-loop
/// bias diagnostics.
///
/// # Errors
/// Returns `EstimationError::GradientUnavailable` when the evaluation yields
/// no gradient, and propagates errors from the correction helpers and the
/// bias recorder.
///
/// # Panics
/// Panics if `as_slice()` returns `None` for `rho` or for the gradient.
fn assemble_and_evaluate_efs(
    &self,
    rho: &Array1<f64>,
    bundle: &EvalShared,
    assembly: super::assembly::InnerAssembly<'static>,
) -> Result<crate::solver::outer_strategy::EfsEval, EstimationError> {
    use super::unified::{compute_efs_update, compute_hybrid_efs_update};
    let beta_for_barrier = assembly.beta.clone();
    let has_psi = assembly.ext_coords.iter().any(|c| !c.is_penalty_like);
    let eval_mode = super::unified::EvalMode::ValueAndGradient;
    self.validate_tk_ext_coords(eval_mode, &assembly.ext_coords)?;
    let tk_terms = self.tierney_kadane_terms(rho, bundle, eval_mode, &assembly.ext_coords)?;
    let assembly_ext_len = assembly.ext_coords.len();

    let mut inner_solution = assembly.build();
    inner_solution.gaussian_weight_log_sum_half = self.gaussian_weight_log_sum_half();
    let inner_hessian_scale =
        super::unified::hessian_operator_geometric_scale(inner_solution.hessian_op.as_ref());

    let prior = self.build_prior(rho, eval_mode);
    let rho_slice = rho.as_slice().unwrap();
    let evaluated =
        super::assembly::evaluate_solution(&inner_solution, rho_slice, eval_mode, prior)
            .map_err(EstimationError::InvalidInput)?;
    let with_tk = self.apply_tk_to_result(evaluated, tk_terms)?;
    let block_terms = self.block_local_sampled_correction(rho, bundle, assembly_ext_len)?;
    let with_block = self.apply_tk_to_result(with_tk, block_terms)?;
    let cost_result =
        self.apply_alo_stabilization_to_result(rho, bundle, eval_mode, with_block)?;
    self.store_ift_mode_response_cache_from_result(rho, bundle, &cost_result);

    let gradient =
        cost_result
            .gradient
            .as_ref()
            .ok_or(EstimationError::GradientUnavailable {
                context: concat!(
                    "[outer-efs-first-order-fallback] EFS needs gradient; ",
                    "switch to BFGS or compass search"
                ),
                mode: "ValueAndGradient",
            })?;
    let gradient_slice = gradient.as_slice().unwrap();

    // Each branch yields the steps and the optional psi payload; the
    // `EfsEval` is then built once.
    let (steps, psi_gradient, psi_indices) = if has_psi {
        let hybrid = compute_hybrid_efs_update(&inner_solution, rho_slice, gradient_slice);
        let diagnostics = super::unified::efs_single_loop_diagnostics(
            &inner_solution,
            rho_slice,
            gradient_slice,
            &hybrid.steps,
            bundle.pirls_result.relative_gradient_norm(),
        );
        self.record_efs_single_loop_bias(rho, diagnostics)?;
        let (psi_gradient, psi_indices) = if hybrid.psi_indices.is_empty() {
            (None, None)
        } else {
            (
                Some(ndarray::Array1::from_vec(hybrid.psi_gradient)),
                Some(hybrid.psi_indices),
            )
        };
        (hybrid.steps, psi_gradient, psi_indices)
    } else {
        let steps = compute_efs_update(&inner_solution, rho_slice, gradient_slice);
        let diagnostics = super::unified::efs_single_loop_diagnostics(
            &inner_solution,
            rho_slice,
            gradient_slice,
            &steps,
            bundle.pirls_result.relative_gradient_norm(),
        );
        self.record_efs_single_loop_bias(rho, diagnostics)?;
        (steps, None, None)
    };

    Ok(crate::solver::outer_strategy::EfsEval {
        cost: cost_result.cost,
        steps,
        beta: Some(beta_for_barrier),
        psi_gradient,
        psi_indices,
        inner_hessian_scale,
    })
}
/// Evaluates the unified objective for `bundle` using the assembly chosen by
/// `build_auto_assembly`.
///
/// # Errors
/// Propagates errors from assembly construction and from evaluation.
pub fn evaluate_unified(
    &self,
    rho: &Array1<f64>,
    bundle: &EvalShared,
    mode: super::unified::EvalMode,
) -> Result<super::unified::RemlLamlResult, EstimationError> {
    self.build_auto_assembly(rho, bundle, mode)
        .and_then(|assembly| self.assemble_and_evaluate(rho, bundle, mode, assembly))
}
/// Evaluates the unified objective for `bundle` using the sparse assembly
/// builder directly.
///
/// # Errors
/// Propagates errors from assembly construction and from evaluation.
pub fn evaluate_unified_sparse(
    &self,
    rho: &Array1<f64>,
    bundle: &EvalShared,
    mode: super::unified::EvalMode,
) -> Result<super::unified::RemlLamlResult, EstimationError> {
    self.build_sparse_assembly(rho, bundle, mode)
        .and_then(|assembly| self.assemble_and_evaluate(rho, bundle, mode, assembly))
}
/// Evaluates the unified objective with extended hyper-coordinates built from
/// `hyper_dirs`.
///
/// The bundle comes from the outer-theta lookup when `cache_theta` is given
/// and from the plain rho lookup otherwise. In `ValueGradientHessian` mode the
/// pair functions and fixed drift derivative are built along with the
/// coordinates; in other modes only the coordinates are built, with the
/// builder chosen by backend and coordinate frame. Phase timings are logged at
/// debug level.
pub fn evaluate_unified_with_psi_ext(
    &self,
    rho: &Array1<f64>,
    cache_theta: Option<&Array1<f64>>,
    mode: super::unified::EvalMode,
    hyper_dirs: &[crate::estimate::reml::DirectionalHyperParam],
) -> Result<super::unified::RemlLamlResult, EstimationError> {
    let t0 = std::time::Instant::now();
    let bundle = match cache_theta {
        Some(theta) => self.obtain_eval_bundle_for_outer_theta(rho, theta)?,
        None => self.obtain_eval_bundle(rho)?,
    };
    let pirls_ms = t0.elapsed().as_secs_f64() * 1000.0;

    let t1 = std::time::Instant::now();
    let (ext_coords, ext_pair_fn, rho_ext_pair_fn, fixed_drift_deriv) = if hyper_dirs.is_empty()
    {
        (Vec::new(), None, None, None)
    } else if mode == super::unified::EvalMode::ValueGradientHessian {
        let (coords, epf, repf, drift) =
            self.build_tau_unified_objects_from_bundle(rho, &bundle, hyper_dirs)?;
        (coords, Some(epf), Some(repf), drift)
    } else if bundle.backend_kind() == GeometryBackendKind::SparseExactSpd {
        let coords = self.build_tau_hyper_coords_sparse_exact(rho, &bundle, hyper_dirs)?;
        (coords, None, None, None)
    } else if matches!(
        bundle.pirls_result.coordinate_frame,
        pirls::PirlsCoordinateFrame::TransformedQs
    ) && self
        .active_constraint_free_basis(bundle.pirls_result.as_ref())
        .is_none()
    {
        let coords = self.build_tau_hyper_coords_original_basis(rho, &bundle, hyper_dirs)?;
        (coords, None, None, None)
    } else {
        let coords = self.build_tau_hyper_coords(rho, &bundle, hyper_dirs)?;
        (coords, None, None, None)
    };
    let tau_build_ms = t1.elapsed().as_secs_f64() * 1000.0;

    let t2 = std::time::Instant::now();
    let mut assembly = self.build_auto_assembly(rho, &bundle, mode)?;
    assembly.ext_coords = ext_coords;
    assembly.ext_coord_pair_fn = ext_pair_fn;
    assembly.rho_ext_pair_fn = rho_ext_pair_fn;
    assembly.fixed_drift_deriv = fixed_drift_deriv;
    // The outcome is held so the timing line is logged on failure as well.
    let outcome = self.assemble_and_evaluate(rho, &bundle, mode, assembly);
    let reml_eval_ms = t2.elapsed().as_secs_f64() * 1000.0;
    log::debug!(
        "[outer-timing] evaluate_unified_with_psi_ext: PIRLS={:.1}ms tau_build={:.1}ms reml_eval={:.1}ms total={:.1}ms",
        pirls_ms,
        tau_build_ms,
        reml_eval_ms,
        t0.elapsed().as_secs_f64() * 1000.0,
    );
    outcome
}
/// Computes the EFS evaluation at `p`.
///
/// In the large-N single-loop lane the eval bundle is invalidated before and
/// after the inner call, and `outer_inner_cap` is swapped to the encoded
/// single-loop cap for its duration and then restored. Outside that lane the
/// inner evaluation runs directly.
pub fn compute_efs_steps(
    &self,
    p: &Array1<f64>,
) -> Result<crate::solver::outer_strategy::EfsEval, EstimationError> {
    if !self.large_n_efs_single_loop_lane() {
        return self.compute_efs_steps_inner(p);
    }
    self.cache_manager.invalidate_eval_bundle();
    let single_loop_cap = efs_single_loop_encoded_cap();
    let previous_cap = self.outer_inner_cap.swap(single_loop_cap, Ordering::Relaxed);
    let outcome = self.compute_efs_steps_inner(p);
    // Restore the caller's cap whether or not the inner evaluation succeeded.
    self.outer_inner_cap.store(previous_cap, Ordering::Relaxed);
    self.cache_manager.invalidate_eval_bundle();
    outcome
}
/// Obtains the eval bundle at `p` and runs the EFS evaluation with no
/// extended coordinates.
///
/// # Errors
/// Any bundle failure invalidates the cached bundle first. An ill-conditioned
/// model is reported as `RemlOptimizationFailed`; other errors pass through.
fn compute_efs_steps_inner(
    &self,
    p: &Array1<f64>,
) -> Result<crate::solver::outer_strategy::EfsEval, EstimationError> {
    let bundle = self.obtain_eval_bundle(p).map_err(|err| {
        self.cache_manager.invalidate_eval_bundle();
        match err {
            EstimationError::ModelIsIllConditioned { .. } => {
                EstimationError::RemlOptimizationFailed(
                    "inner solve ill-conditioned during EFS evaluation".to_string(),
                )
            }
            other => other,
        }
    })?;
    self.evaluate_efs(p, &bundle, Vec::new())
}
/// Computes the EFS evaluation at `rho` with extended hyper-coordinates built
/// from `hyper_dirs`.
///
/// In the large-N single-loop lane the eval bundle is invalidated before and
/// after the inner call, and `outer_inner_cap` is swapped to the encoded
/// single-loop cap for its duration and then restored. Outside that lane the
/// inner evaluation runs directly.
pub fn compute_efs_steps_with_psi_ext(
    &self,
    rho: &Array1<f64>,
    hyper_dirs: &[crate::estimate::reml::DirectionalHyperParam],
) -> Result<crate::solver::outer_strategy::EfsEval, EstimationError> {
    if !self.large_n_efs_single_loop_lane() {
        return self.compute_efs_steps_with_psi_ext_inner(rho, hyper_dirs);
    }
    self.cache_manager.invalidate_eval_bundle();
    let single_loop_cap = efs_single_loop_encoded_cap();
    let previous_cap = self.outer_inner_cap.swap(single_loop_cap, Ordering::Relaxed);
    let outcome = self.compute_efs_steps_with_psi_ext_inner(rho, hyper_dirs);
    // Restore the caller's cap whether or not the inner evaluation succeeded.
    self.outer_inner_cap.store(previous_cap, Ordering::Relaxed);
    self.cache_manager.invalidate_eval_bundle();
    outcome
}
/// Obtains the eval bundle at `rho`, builds the extended coordinates for
/// `hyper_dirs`, and runs the EFS evaluation with them.
///
/// The coordinate builder is chosen by backend and coordinate frame. When no
/// extended coordinates result, the call is delegated to `compute_efs_steps`.
///
/// # Errors
/// Any bundle failure invalidates the cached bundle first. An ill-conditioned
/// model is reported as `RemlOptimizationFailed`; other errors pass through.
fn compute_efs_steps_with_psi_ext_inner(
    &self,
    rho: &Array1<f64>,
    hyper_dirs: &[crate::estimate::reml::DirectionalHyperParam],
) -> Result<crate::solver::outer_strategy::EfsEval, EstimationError> {
    let bundle = self.obtain_eval_bundle(rho).map_err(|err| {
        self.cache_manager.invalidate_eval_bundle();
        match err {
            EstimationError::ModelIsIllConditioned { .. } => {
                EstimationError::RemlOptimizationFailed(
                    "inner solve ill-conditioned during psi-ext EFS evaluation".to_string(),
                )
            }
            other => other,
        }
    })?;

    let ext_coords = if hyper_dirs.is_empty() {
        Vec::new()
    } else if bundle.backend_kind() == GeometryBackendKind::SparseExactSpd {
        self.build_tau_hyper_coords_sparse_exact(rho, &bundle, hyper_dirs)?
    } else if matches!(
        bundle.pirls_result.coordinate_frame,
        pirls::PirlsCoordinateFrame::TransformedQs
    ) && self
        .active_constraint_free_basis(bundle.pirls_result.as_ref())
        .is_none()
    {
        self.build_tau_hyper_coords_original_basis(rho, &bundle, hyper_dirs)?
    } else {
        self.build_tau_hyper_coords(rho, &bundle, hyper_dirs)?
    };

    // NOTE(review): `compute_efs_steps` invalidates the eval bundle in the
    // large-N lane, so the bundle obtained above is discarded there — confirm
    // the second solve is intended.
    if ext_coords.is_empty() {
        return self.compute_efs_steps(rho);
    }
    self.evaluate_efs(rho, &bundle, ext_coords)
}
/// Computes the outer gradient at `p`.
///
/// A cached outer evaluation for the sanitized rho key is returned directly.
/// Otherwise the eval bundle is obtained, the objective is evaluated in
/// `ValueAndGradient` mode (through the sparse entry point for the
/// sparse-exact backend, the auto-dispatching one otherwise), the IFT residual
/// energy is stored, and the hypergradient budget is updated.
///
/// # Errors
/// Any bundle failure invalidates the cached bundle and is returned as-is.
/// Returns `EstimationError::GradientUnavailable` when the evaluation yields
/// no gradient.
pub fn compute_gradient(&self, p: &Array1<f64>) -> Result<Array1<f64>, EstimationError> {
    self.arena
        .lastgradient_used_stochastic_fallback
        .store(false, Ordering::Relaxed);
    let t_eval_start = std::time::Instant::now();
    let l2_norm = |g: &Array1<f64>| g.iter().map(|v| v * v).sum::<f64>().sqrt();
    {
        let prefix: Vec<String> = p.iter().take(4).map(|v| format!("{:.3}", v)).collect();
        log::debug!(
            "[REML] grad-only begin | rho[..4]=[{}] | k={}",
            prefix.join(","),
            p.len()
        );
    }

    let rho_key = EvalCacheManager::sanitized_rhokey(p);
    if let Some(eval) = self.cache_manager.cached_outer_eval(&rho_key) {
        let gnorm = l2_norm(&eval.gradient);
        log::debug!(
            "[REML] grad-only cache hit | |g| {:.3e} | elapsed {:.1}ms",
            gnorm,
            t_eval_start.elapsed().as_secs_f64() * 1000.0
        );
        return Ok(eval.gradient);
    }

    let t_pirls = std::time::Instant::now();
    let bundle = match self.obtain_eval_bundle(p) {
        Ok(bundle) => bundle,
        Err(err) => {
            // Every failure kind drops the cached bundle before being returned.
            self.cache_manager.invalidate_eval_bundle();
            return Err(err);
        }
    };
    let pirls_ms = t_pirls.elapsed().as_secs_f64() * 1000.0;
    log::debug!(
        "[REML] grad-only pirls done | elapsed {:.1}ms | backend {:?}",
        pirls_ms,
        bundle.backend_kind()
    );

    let t_assemble = std::time::Instant::now();
    let eval_mode = super::unified::EvalMode::ValueAndGradient;
    if bundle.backend_kind() == GeometryBackendKind::SparseExactSpd {
        let result = self.evaluate_unified_sparse(p, &bundle, eval_mode)?;
        let ift_residual_energy = result.ift_residual_energy;
        store_ift_residual_energy_for_outer_theta(p, ift_residual_energy);
        let grad = result
            .gradient
            .ok_or(EstimationError::GradientUnavailable {
                context: "REML sparse gradient evaluation requires gradient",
                mode: "ValueAndGradient",
            })?;
        let gnorm = l2_norm(&grad);
        log::debug!(
            "[REML] grad-only sparse done | |g| {:.3e} | assemble {:.1}ms | total {:.1}ms",
            gnorm,
            t_assemble.elapsed().as_secs_f64() * 1000.0,
            t_eval_start.elapsed().as_secs_f64() * 1000.0
        );
        self.update_hypergradient_budget_after_outer_eval(p, &grad, ift_residual_energy);
        Ok(grad)
    } else {
        let result = self.evaluate_unified(p, &bundle, eval_mode)?;
        let ift_residual_energy = result.ift_residual_energy;
        store_ift_residual_energy_for_outer_theta(p, ift_residual_energy);
        let grad = result
            .gradient
            .ok_or(EstimationError::GradientUnavailable {
                context: "REML dense gradient evaluation requires gradient",
                mode: "ValueAndGradient",
            })?;
        let gnorm = l2_norm(&grad);
        log::debug!(
            "[REML] grad-only dense done | |g| {:.3e} | assemble {:.1}ms | total {:.1}ms",
            gnorm,
            t_assemble.elapsed().as_secs_f64() * 1000.0,
            t_eval_start.elapsed().as_secs_f64() * 1000.0
        );
        self.update_hypergradient_budget_after_outer_eval(p, &grad, ift_residual_energy);
        Ok(grad)
    }
}
/// Evaluates the outer objective at `p` and packages cost, gradient and
/// (optionally) the Hessian into an [`OuterEval`].
///
/// Second-order output is attempted only when `order` requests a Hessian
/// and `analytic_outer_hessian_enabled()` is true. A cached evaluation
/// stored under the sanitized rho key is returned directly when it already
/// satisfies the request (no second-order output needed, or the cached
/// Hessian reports `is_analytic()`). Fresh results are written back to the
/// cache before returning.
///
/// # Errors
///
/// An inner-solve retreat reported by `obtain_eval_bundle` is not surfaced
/// as an error: the cached eval bundle is invalidated and
/// `OuterEval::infeasible(p.len())` is returned. Any other bundle error is
/// propagated after the same invalidation. Evaluator failures are
/// propagated unchanged, and an evaluator result without a gradient yields
/// `EstimationError::InvalidInput`.
pub fn compute_outer_eval_with_order(
    &self,
    p: &Array1<f64>,
    order: crate::solver::outer_strategy::OuterEvalOrder,
) -> Result<OuterEval, EstimationError> {
    self.arena
        .lastgradient_used_stochastic_fallback
        .store(false, Ordering::Relaxed);

    let hessian_requested = matches!(
        order,
        crate::solver::outer_strategy::OuterEvalOrder::ValueGradientHessian
    );
    let allow_second_order = hessian_requested && self.analytic_outer_hessian_enabled();
    let t_eval_start = std::time::Instant::now();

    let prefix: Vec<String> = p.iter().take(4).map(|v| format!("{:.3}", v)).collect();
    log::debug!(
        "[REML] outer-eval begin {:?} | rho[..4]=[{}] | k={} | 2nd-order={}",
        order,
        prefix.join(","),
        p.len(),
        allow_second_order
    );

    // Reuse a cached evaluation only when it carries everything this call needs.
    let rho_key = EvalCacheManager::sanitized_rhokey(p);
    let reusable = self
        .cache_manager
        .cached_outer_eval(&rho_key)
        .filter(|cached| !allow_second_order || cached.hessian.is_analytic());
    if let Some(eval) = reusable {
        let gnorm = eval.gradient.iter().map(|g| g * g).sum::<f64>().sqrt();
        log::debug!(
            "[REML] outer-eval cache hit | cost {:.6e} | |g| {:.3e} | elapsed {:.1}ms",
            eval.cost,
            gnorm,
            t_eval_start.elapsed().as_secs_f64() * 1000.0
        );
        return Ok(eval);
    }

    let t_pirls = std::time::Instant::now();
    let bundle = match self.obtain_eval_bundle(p) {
        Ok(bundle) => bundle,
        Err(err) => {
            // Classify first, then drop the stale bundle on every failure path.
            let retreat = err.is_inner_solve_retreat();
            self.cache_manager.invalidate_eval_bundle();
            if !retreat {
                return Err(err);
            }
            log::debug!(
                "P-IRLS inner-solve retreat at current rho ({}); returning infeasible outer eval.",
                err
            );
            return Ok(OuterEval::infeasible(p.len()));
        }
    };

    // `allow_second_order` already implies that `order` asked for a Hessian.
    let decision = allow_second_order.then(|| self.selecthessian_strategy_policy(&bundle));
    let spectral_hessian = matches!(
        decision.as_ref().map(|chosen| chosen.strategy),
        Some(HessianEvalStrategyKind::SpectralExact)
    );
    let eval_mode = if spectral_hessian {
        super::unified::EvalMode::ValueGradientHessian
    } else {
        super::unified::EvalMode::ValueAndGradient
    };

    let pirls_ms = t_pirls.elapsed().as_secs_f64() * 1000.0;
    log::debug!(
        "[REML] outer-eval pirls done | elapsed {:.1}ms | backend {:?} | mode {:?}",
        pirls_ms,
        bundle.backend_kind(),
        eval_mode
    );

    let t_assemble = std::time::Instant::now();
    let use_sparse = bundle.backend_kind() == GeometryBackendKind::SparseExactSpd;
    let result = if use_sparse {
        self.evaluate_unified_sparse(p, &bundle, eval_mode)?
    } else {
        self.evaluate_unified(p, &bundle, eval_mode)?
    };
    let assemble_ms = t_assemble.elapsed().as_secs_f64() * 1000.0;

    let ift_residual_energy = result.ift_residual_energy;
    store_ift_residual_energy_for_outer_theta(p, ift_residual_energy);

    let gradient = result.gradient.ok_or_else(|| {
        EstimationError::InvalidInput(format!(
            "unified evaluator returned no gradient in {:?} mode",
            eval_mode
        ))
    })?;
    let hessian = match decision.map(|chosen| chosen.strategy) {
        Some(HessianEvalStrategyKind::SpectralExact) => result.hessian,
        None => HessianResult::Unavailable,
    };
    let eval = OuterEval {
        cost: result.cost,
        gradient,
        hessian,
        inner_beta_hint: self.current_original_basis_beta(),
    };

    let gnorm = eval.gradient.iter().map(|g| g * g).sum::<f64>().sqrt();
    log::debug!(
        "[REML] outer-eval done | cost {:.6e} | |g| {:.3e} | assemble {:.1}ms | total {:.1}ms",
        eval.cost,
        gnorm,
        assemble_ms,
        t_eval_start.elapsed().as_secs_f64() * 1000.0
    );

    self.update_hypergradient_budget_after_outer_eval(p, &eval.gradient, ift_residual_energy);
    self.cache_manager.store_outer_eval(&rho_key, &eval);
    Ok(eval)
}
/// Builds the link-parameter extension coordinates for the active runtime
/// link state.
///
/// A SAS link state takes precedence and is built with a flag telling the
/// builder whether the configured link function is `BetaLogistic`. Otherwise
/// a mixture link state is used if present. With neither state the result is
/// an empty vector.
fn build_link_ext_coords(
    &self,
    bundle: &EvalShared,
) -> Result<Vec<super::unified::HyperCoord>, EstimationError> {
    if let Some(sas_state) = &self.runtime_sas_link_state {
        let is_beta_logistic = matches!(
            self.config.link_function(),
            crate::types::LinkFunction::BetaLogistic
        );
        return self.build_sas_link_ext_coords(bundle, sas_state, is_beta_logistic);
    }
    match &self.runtime_mixture_link_state {
        Some(mix_state) => self.build_mixture_link_ext_coords(bundle, mix_state),
        None => Ok(Vec::new()),
    }
}
/// Guards the link-extension entry points against Firth-adjusted fits.
///
/// # Errors
///
/// Bails via `bail_invalid_estim!` whenever
/// `reml_robust_jeffreys_link(&self.config)` yields `Some`.
fn reject_firth_link_ext(&self) -> Result<(), EstimationError> {
    let firth_link = reml_robust_jeffreys_link(&self.config);
    if firth_link.is_some() {
        crate::bail_invalid_estim!(
            "link-parameter ext_coord optimization is incompatible with \
             Firth-adjusted outer gradients"
                .to_string(),
        );
    }
    Ok(())
}
/// Rotates link-extension coordinates by the reparameterization matrix `qs`
/// taken from the P-IRLS result.
///
/// Nothing happens when `coords` is empty, when the P-IRLS coordinate frame
/// is not `TransformedQs`, or when `active_constraint_free_basis` returns
/// `Some`. Otherwise each coordinate's `g` becomes `qs * g` and its dense
/// drift `B` becomes `qs * B * qs^T`; both are touched only when their
/// dimensions match `qs.nrows()`.
///
/// # Errors
///
/// Bails when a coordinate carries a `block_local` or `operator` drift,
/// since only the dense drift is rotated here. Coordinates processed before
/// the offending one (and that coordinate's `g`) have already been rotated
/// at that point.
fn rotate_link_ext_coords_to_original(
    &self,
    bundle: &EvalShared,
    coords: &mut [super::unified::HyperCoord],
) -> Result<(), EstimationError> {
    if coords.is_empty() {
        return Ok(());
    }
    let pirls_result = bundle.pirls_result.as_ref();
    let in_transformed_frame = matches!(
        pirls_result.coordinate_frame,
        pirls::PirlsCoordinateFrame::TransformedQs
    );
    // The constraint-free-basis lookup is only consulted in the transformed frame.
    if !in_transformed_frame || self.active_constraint_free_basis(pirls_result).is_some() {
        return Ok(());
    }

    let qs = &pirls_result.reparam_result.qs;
    let qs_t = qs.t();
    for (coord_idx, coord) in coords.iter_mut().enumerate() {
        if coord.g.len() == qs.nrows() {
            coord.g = qs.dot(&coord.g);
        }

        let has_block_local = coord.drift.block_local.is_some();
        let has_operator = coord.drift.operator.is_some();
        if has_block_local || has_operator {
            crate::bail_invalid_estim!(
                "link-ext HyperCoord[{coord_idx}] carries a non-dense drift \
                 variant (block_local={} operator={}); the coord-frame rotation \
                 helper only handles `HyperCoordDrift::from_dense`. Update the \
                 link-ext builder (or extend `rotate_link_ext_coords_to_original`) \
                 before attaching non-dense drifts to original-basis assemblies.",
                has_block_local,
                has_operator,
            );
        }

        if let Some(dense) = coord.drift.dense.as_mut()
            && dense.nrows() == qs.nrows()
            && dense.ncols() == qs.nrows()
        {
            let left_rotated = qs.dot(&*dense);
            *dense = left_rotated.dot(&qs_t);
        }
    }
    Ok(())
}
/// Evaluates the unified objective at `rho` with link-parameter extension
/// coordinates attached to the assembly.
///
/// When no extension coordinates are produced, this defers to
/// `evaluate_unified`. In `ValueGradientHessian` mode the rho/ext and
/// ext/ext pair callbacks are both installed as all-zero
/// `HyperCoordPair`s sized by the length of the first coordinate's `g`.
///
/// # Errors
///
/// Fails when Firth-adjusted gradients are configured (see
/// `reject_firth_link_ext`) and propagates any error from bundle
/// construction, coordinate building/rotation, or assembly evaluation.
pub fn evaluate_unified_with_link_ext(
    &self,
    rho: &Array1<f64>,
    mode: super::unified::EvalMode,
) -> Result<super::unified::RemlLamlResult, EstimationError> {
    /// All-zero second-order pair for a coefficient dimension of `p_dim`.
    fn zero_pair(p_dim: usize) -> super::unified::HyperCoordPair {
        super::unified::HyperCoordPair {
            a: 0.0,
            g: Array1::zeros(p_dim),
            b_mat: Array2::zeros((p_dim, p_dim)),
            b_operator: None,
            ld_s: 0.0,
        }
    }

    self.reject_firth_link_ext()?;
    let bundle = self.obtain_eval_bundle(rho)?;
    let mut ext_coords = self.build_link_ext_coords(&bundle)?;
    if ext_coords.is_empty() {
        return self.evaluate_unified(rho, &bundle, mode);
    }
    self.rotate_link_ext_coords_to_original(&bundle, &mut ext_coords)?;

    let mut assembly = self.build_auto_assembly(rho, &bundle, mode)?;
    let ext_dim = ext_coords.len();
    let p_dim = match ext_coords.first() {
        Some(coord) => coord.g.len(),
        None => 0,
    };
    assembly.ext_coords = ext_coords;

    if mode == super::unified::EvalMode::ValueGradientHessian {
        assembly.rho_ext_pair_fn = Some(Box::new(move |_, _| zero_pair(p_dim)));
        assembly.ext_coord_pair_fn = Some(Box::new(move |_, _| zero_pair(p_dim)));
        assert!(ext_dim > 0);
    }
    self.assemble_and_evaluate(rho, &bundle, mode, assembly)
}
/// Computes EFS steps at `rho`, including link-parameter extension
/// coordinates when the active link provides any.
///
/// With no extension coordinates this defers to `compute_efs_steps`.
/// Otherwise the coordinates are rotated and handed to `evaluate_efs`.
///
/// # Errors
///
/// Fails when Firth-adjusted gradients are configured and propagates errors
/// from bundle construction, coordinate building/rotation, and evaluation.
pub fn compute_efs_steps_with_link_ext(
    &self,
    rho: &Array1<f64>,
) -> Result<crate::solver::outer_strategy::EfsEval, EstimationError> {
    self.reject_firth_link_ext()?;
    let bundle = self.obtain_eval_bundle(rho)?;
    match self.build_link_ext_coords(&bundle)? {
        coords if coords.is_empty() => self.compute_efs_steps(rho),
        mut coords => {
            self.rotate_link_ext_coords_to_original(&bundle, &mut coords)?;
            self.evaluate_efs(rho, &bundle, coords)
        }
    }
}
/// Runs the EFS evaluation for `rho` on a value-only assembly carrying the
/// supplied extension coordinates.
///
/// The assembly is built in `EvalMode::ValueOnly`, its `tk_gradient` is
/// cleared, and `ext_coords` replaces whatever coordinates it held.
fn evaluate_efs(
    &self,
    rho: &Array1<f64>,
    bundle: &EvalShared,
    ext_coords: Vec<super::unified::HyperCoord>,
) -> Result<crate::solver::outer_strategy::EfsEval, EstimationError> {
    let value_only = super::unified::EvalMode::ValueOnly;
    let mut efs_assembly = self.build_auto_assembly(rho, bundle, value_only)?;
    efs_assembly.tk_gradient = None;
    efs_assembly.ext_coords = ext_coords;
    self.assemble_and_evaluate_efs(rho, bundle, efs_assembly)
}
}
/// Returns `(rank, log_det)` for a penalty spectrum.
///
/// `rank` counts the eigenvalues strictly above the cutoff reported by
/// `positive_eigenvalue_threshold`; `log_det` is `exact_pseudo_logdet`
/// evaluated with that same cutoff.
fn positive_penalty_rank_and_logdet(eigenvalues: &[f64]) -> (usize, f64) {
    let cutoff = super::unified::positive_eigenvalue_threshold(eigenvalues);
    let mut rank = 0usize;
    for &ev in eigenvalues {
        if ev > cutoff {
            rank += 1;
        }
    }
    (
        rank,
        super::unified::exact_pseudo_logdet(eigenvalues, cutoff),
    )
}
#[cfg(test)]
mod tk_math_tests {
use super::*;
use crate::faer_ndarray::FaerCholesky;
use crate::mixture_link::{
beta_logistic_inverse_link_jet, beta_logistic_inverse_link_pdffourth_derivative,
beta_logistic_inverse_link_pdfthird_derivative, inverse_link_jet_for_inverse_link,
inverse_link_pdffourth_derivative_for_inverse_link,
inverse_link_pdfthird_derivative_for_inverse_link, mixture_inverse_link_jet,
sas_inverse_link_jet, sas_inverse_link_pdffourth_derivative,
sas_inverse_link_pdfthird_derivative, state_fromspec,
};
use crate::pirls::{VarianceJet, e_obs_from_jets};
use crate::types::{LinkComponent, MixtureLinkSpec};
use faer::Side;
use ndarray::array;
use num_dual::{Dual3_64, Dual64, DualNum, third_derivative};
/// A flat prior resolves to the Firth default PC prior; configured priors
/// pass through unchanged, both on their own and inside `Independent`.
#[test]
fn firth_default_pc_prior_fills_flat_holes() {
    let default_prior = firth_default_pc_prior();
    let user_prior = RhoPrior::Normal { mean: 0.1, sd: 2.0 };

    // Scalar priors: flat is replaced, an explicit prior is kept.
    assert_eq!(*resolve_effective_rho_prior(&RhoPrior::Flat), default_prior);
    assert_eq!(*resolve_effective_rho_prior(&user_prior), user_prior);

    // Per-coordinate priors: only the flat slot is filled in.
    let with_hole = RhoPrior::Independent(vec![RhoPrior::Flat, user_prior.clone()]);
    let hole_filled = RhoPrior::Independent(vec![default_prior.clone(), user_prior.clone()]);
    assert_eq!(*resolve_effective_rho_prior(&with_hole), hole_filled);

    let fully_configured = RhoPrior::Independent(vec![user_prior.clone(), user_prior.clone()]);
    assert_eq!(
        *resolve_effective_rho_prior(&fully_configured),
        fully_configured
    );
}
/// The coordinate mask is `true` exactly where the configured prior is flat.
#[test]
fn firth_default_coord_mask_marks_only_flat_coordinates() {
    let standard_normal = || RhoPrior::Normal { mean: 0.0, sd: 1.0 };

    let all_flat_mask = firth_default_coord_mask(&RhoPrior::Flat, 3);
    assert_eq!(all_flat_mask, vec![true; 3]);

    let mixed = RhoPrior::Independent(vec![RhoPrior::Flat, standard_normal(), RhoPrior::Flat]);
    let mixed_mask = firth_default_coord_mask(&mixed, 3);
    assert_eq!(mixed_mask, vec![true, false, true]);

    let no_flat_mask = firth_default_coord_mask(&standard_normal(), 2);
    assert_eq!(no_flat_mask, vec![false; 2]);
}
/// Checks the Firth default barrier terms `(cost, gradient, curvature)`:
/// exactly zero at and above the gate `rho = -2 ln(upper)`, strictly
/// active below it, continuous at the gate, and consistent with central
/// finite differences.
#[test]
fn firth_default_barrier_is_byte_zero_on_identified_side() {
    use super::super::rho_prior_eval::{firth_default_barrier_terms, pc_prior_rate};
    let upper = FIRTH_DEFAULT_PC_UPPER;
    let theta = pc_prior_rate(upper, FIRTH_DEFAULT_PC_TAIL_PROB);
    let rho_gate = -2.0 * upper.ln();
    let terms_at = |rho: f64| firth_default_barrier_terms(theta, upper, rho);

    // At or above the gate every term must be exactly zero.
    for r in [rho_gate, rho_gate + 1e-9, 0.0, 5.0, 50.0] {
        let (c, g, h) = terms_at(r);
        assert_eq!((c, g, h), (0.0, 0.0, 0.0), "must be byte-zero at ρ={r}");
    }

    // Below the gate the barrier is active with the expected signs.
    for r in [rho_gate - 1.0, rho_gate - 5.0, -20.0] {
        let (c, g, h) = terms_at(r);
        assert!(c > 0.0, "cost must be positive below the gate at ρ={r}");
        assert!(g < 0.0, "gradient must push ρ up (away from λ→0) at ρ={r}");
        assert!(h > 0.0, "curvature must be positive at ρ={r}");
    }

    // Approaching the gate from below, the cost tends to zero.
    let c_below = terms_at(rho_gate - 1e-6).0;
    assert!(
        c_below.abs() < 1e-9,
        "cost continuous at the gate, got {c_below}"
    );

    // Analytic gradient/curvature against central finite differences.
    let r = rho_gate - 2.0;
    let (_, g, h) = terms_at(r);
    let fd_g = (terms_at(r + 1e-6).0 - terms_at(r + -1e-6).0) / 2e-6;
    let fd_h = (terms_at(r + 1e-5).1 - terms_at(r + -1e-5).1) / 2e-5;
    assert!((fd_g - g).abs() < 1e-6, "grad FD {fd_g} vs {g}");
    assert!((fd_h - h).abs() < 1e-5, "hess FD {fd_h} vs {h}");
}
/// A root with two identical rows spans one direction, so the penalty
/// `EᵀE` has a single positive eigenvalue (2) even though the root has
/// three rows.
#[test]
fn penalty_rank_uses_actual_positive_eigenspace_not_root_rows() {
    let root = array![[1.0, 0.0, 0.0], [1.0, 0.0, 0.0], [0.0, 0.0, 0.0],];
    let penalty = root.t().dot(&root);
    let (spectrum, _) = penalty.eigh(Side::Lower).expect("penalty eigensystem");
    let (rank, log_det) = positive_penalty_rank_and_logdet(spectrum.as_slice().unwrap());
    assert_eq!(rank, 1);
    let expected_log_det = 2.0_f64.ln();
    assert!(
        (log_det - expected_log_det).abs() < 1e-12,
        "logdet should use the single positive eigenvalue, got {log_det}"
    );
}
/// Solves `h * x = rhs` through a lower Cholesky factorization of `h`.
///
/// Panics with `"chol(H)"` when the factorization fails.
fn solve_vec(h: &Array2<f64>, rhs: &Array1<f64>) -> Array1<f64> {
    let factorization = h.cholesky(Side::Lower);
    factorization.expect("chol(H)").solvevec(rhs)
}
/// Returns the solution of `h * Z = xᵀ`, computed in place on an owned copy
/// of `xᵀ` through a lower Cholesky factorization of `h`.
///
/// Panics with `"chol(H)"` when the factorization fails.
fn solve_xt(h: &Array2<f64>, x: &Array2<f64>) -> Array2<f64> {
    let mut solution = x.t().to_owned();
    let factorization = h.cholesky(Side::Lower);
    factorization
        .expect("chol(H)")
        .solve_mat_in_place(&mut solution);
    solution
}
/// Closed-form reference for the Tierney-Kadane scalar with a scalar
/// (1x1) Hessian `h`, generic over dual numbers so the derivative can be
/// obtained by automatic differentiation.
///
/// With `h_ii = x_i^2 / h` and `k_ij = x_i x_j / h`, the returned value is
///
/// ```text
/// -1/8  * sum_i   d_i * h_ii^2
/// +1/12 * sum_ij  c_i * c_j * k_ij^3
/// +1/8  * (sum_i x_i * c_i * h_ii)^2 / h
/// ```
///
/// `x`, `c` and `d_arr` are indexed over `0..x.len()`; a shorter `c` or
/// `d_arr` panics on indexing.
fn tk_scalar_dual<D: DualNum<f64> + Copy>(h: D, x: &[D], c: &[D], d_arr: &[D]) -> D {
    let n = x.len();
    let inv_h = D::one() / h;
    // Leverage-like diagonal h_ii = x_i^2 / h.
    let mut h_diag = vec![D::from(0.0); n];
    for i in 0..n {
        h_diag[i] = x[i] * x[i] * inv_h;
    }
    // Fourth-derivative term: -1/8 * sum_i d_i h_ii^2.
    let mut d_term = D::from(0.0);
    for i in 0..n {
        d_term += d_arr[i] * h_diag[i] * h_diag[i];
    }
    d_term *= D::from(-0.125);
    // Third-derivative cross term: 1/12 * sum_ij c_i c_j k_ij^3.
    let mut c_term = D::from(0.0);
    for i in 0..n {
        for j in 0..n {
            let k_ij = x[i] * x[j] * inv_h;
            c_term += c[i] * c[j] * k_ij * k_ij * k_ij;
        }
    }
    c_term *= D::from(1.0 / 12.0);
    // Third-derivative quadratic term: 1/8 * q^2 / h with q = sum_i x_i c_i h_ii.
    let mut q = D::from(0.0);
    for i in 0..n {
        q += x[i] * c[i] * h_diag[i];
    }
    let q_term = D::from(0.125) * q * q * inv_h;
    d_term + c_term + q_term
}
/// Cross-checks the production Tierney-Kadane scalar and its analytic
/// derivative against a `Dual64` automatic-differentiation reference on a
/// four-observation, one-coefficient fixture.
///
/// The dual parts encode one direction of change: `h_dot_val` for the
/// Hessian, `x_dot_vec` for the design column, `d * eta_dot` for `c`, and
/// `e * eta_dot` for `d`.
#[test]
fn tierney_kadane_gradient_matches_dual_ad_reference_scalarhessian() {
    // Fixture values (real parts).
    let x_vec = vec![1.3_f64, -0.4, 0.7, -0.9];
    let h_val = 2.5_f64;
    let c = array![0.21_f64, -0.13, 0.18, 0.07];
    let d = array![-0.05_f64, 0.09, 0.04, -0.07];
    let e = array![0.03_f64, -0.04, 0.02, 0.018];
    // Directional perturbations (dual parts).
    let eta_dot = array![0.11_f64, -0.06, 0.07, -0.085];
    let x_dot_vec = vec![0.025_f64, -0.018, 0.022, -0.014];
    let h_dot_val = 0.07_f64;
    let h_dual = Dual64::new(h_val, h_dot_val);
    let x_dual: Vec<Dual64> = x_vec
        .iter()
        .zip(x_dot_vec.iter())
        .map(|(&v, &dv)| Dual64::new(v, dv))
        .collect();
    // c moves with eta through the next-higher coefficient: dc = d * eta_dot.
    let c_dual: Vec<Dual64> = c
        .iter()
        .zip(d.iter().zip(eta_dot.iter()))
        .map(|(&cv, (&dv, &edot))| Dual64::new(cv, dv * edot))
        .collect();
    // Likewise dd = e * eta_dot.
    let d_dual: Vec<Dual64> = d
        .iter()
        .zip(e.iter().zip(eta_dot.iter()))
        .map(|(&dv, (&ev, &edot))| Dual64::new(dv, ev * edot))
        .collect();
    // AD reference: value in `.re`, directional derivative in `.eps`.
    let v_tk_dual = tk_scalar_dual(h_dual, &x_dual, &c_dual, &d_dual);
    let dv_tk_ad = v_tk_dual.eps;
    let v_tk_value = v_tk_dual.re;
    // Production path on the same fixture expressed as a 4x1 design and 1x1 Hessian.
    let x_mat = Array2::from_shape_vec((4, 1), x_vec.clone()).unwrap();
    let h_mat = Array2::from_shape_vec((1, 1), vec![h_val]).unwrap();
    let z_mat = solve_xt(&h_mat, &x_mat);
    let solve = |rhs: &Array1<f64>| -> Result<Array1<f64>, EstimationError> {
        Ok(solve_vec(&h_mat, rhs))
    };
    let shared = RemlState::tk_shared_intermediates(
        &x_mat,
        &z_mat,
        &c,
        "tk scalar dual baseline",
        &solve,
    )
    .expect("shared TK intermediates");
    let mut gram = Array2::<f64>::zeros((TK_BLOCK_SIZE, TK_BLOCK_SIZE));
    let v_tk_prod = RemlState::tk_scalar_from_shared(&x_mat, &z_mat, &d, &shared, &mut gram)
        .expect("TK scalar (production)");
    // First check: the scalar itself agrees with the closed form.
    let scalar_rel =
        (v_tk_value - v_tk_prod).abs() / v_tk_value.abs().max(v_tk_prod.abs()).max(1.0e-14);
    assert!(
        scalar_rel < 1.0e-12,
        "TK scalar mismatch (production vs closed-form): prod={v_tk_prod:.12e}, dual_re={v_tk_value:.12e}, rel={scalar_rel:.3e}"
    );
    let h_dot_mat = Array2::from_shape_vec((1, 1), vec![h_dot_val]).unwrap();
    let x_dot_mat = Array2::from_shape_vec((4, 1), x_dot_vec.clone()).unwrap();
    // Second check: analytic derivative for the single perturbation direction.
    let gradient = RemlState::tk_gradient_from_shared(
        &x_mat,
        &z_mat,
        &c,
        &d,
        &e,
        &[],
        &[],
        &[h_dot_mat],
        &[Some(eta_dot.clone())],
        &[Some(x_dot_mat)],
        &[Array1::<f64>::zeros(x_mat.nrows())],
        &[Array1::<f64>::zeros(x_mat.ncols())],
        None,
        &shared,
        &mut gram,
    )
    .expect("analytic TK derivative");
    let analytic = gradient[0];
    let rel = (analytic - dv_tk_ad).abs() / analytic.abs().max(dv_tk_ad.abs()).max(1.0e-14);
    assert!(
        rel < 1.0e-12,
        "Tierney-Kadane analytic c/d propagation does not match Dual64 AD reference: analytic={analytic:.12e}, ad={dv_tk_ad:.12e}, rel={rel:.3e}"
    );
}
/// Cubic Taylor polynomial `h0 + h1·δ + h2·δ²/2 + h3·δ³/6` evaluated on a
/// third-order dual number.
fn taylor_mu_dual3(delta: Dual3_64, h0: f64, h1: f64, h2: f64, h3: f64) -> Dual3_64 {
    let half = 0.5_f64;
    let sixth = 1.0_f64 / 6.0_f64;
    let delta_sq = delta * delta;
    let delta_cu = delta_sq * delta;
    Dual3_64::from_re(h0) + delta * h1 + delta_sq * (h2 * half) + delta_cu * (h3 * sixth)
}
/// Cubic Taylor polynomial `h1 + h2·δ + h3·δ²/2 + h4·δ³/6` evaluated on a
/// third-order dual number (the coefficients are shifted by one relative to
/// `taylor_mu_dual3`).
fn taylor_mu_eta_dual3(delta: Dual3_64, h1: f64, h2: f64, h3: f64, h4: f64) -> Dual3_64 {
    let half = 0.5_f64;
    let sixth = 1.0_f64 / 6.0_f64;
    let delta_sq = delta * delta;
    let delta_cu = delta_sq * delta;
    Dual3_64::from_re(h1) + delta * h2 + delta_sq * (h3 * half) + delta_cu * (h4 * sixth)
}
/// Cubic Taylor polynomial `h2 + h3·δ + h4·δ²/2 + h5·δ³/6` evaluated on a
/// third-order dual number (the coefficients are shifted by two relative to
/// `taylor_mu_dual3`).
fn taylor_mu_etaeta_dual3(delta: Dual3_64, h2: f64, h3: f64, h4: f64, h5: f64) -> Dual3_64 {
    let half = 0.5_f64;
    let sixth = 1.0_f64 / 6.0_f64;
    let delta_sq = delta * delta;
    let delta_cu = delta_sq * delta;
    Dual3_64::from_re(h2) + delta * h3 + delta_sq * (h4 * half) + delta_cu * (h5 * sixth)
}
/// Weight expression for a Bernoulli-variance response, evaluated on
/// third-order dual numbers.
///
/// With `V = μ(1-μ)` and `V' = 1-2μ`, the result is
/// `μ_η² / (φV) - (y - μ) · (μ_ηη·V - μ_η²·V') / (φV·V)`.
fn binomial_w_obs_dual3(
    mu: Dual3_64,
    mu_eta: Dual3_64,
    mu_etaeta: Dual3_64,
    y: f64,
    phi: f64,
) -> Dual3_64 {
    let one = Dual3_64::from_re(1.0);
    let two = Dual3_64::from_re(2.0);
    let variance = mu * (one - mu);
    let variance_slope = one - two * mu;
    let mu_eta_sq = mu_eta * mu_eta;
    let scaled_variance = variance * phi;
    let fisher_part = mu_eta_sq / scaled_variance;
    let residual = Dual3_64::from_re(y) - mu;
    let t_prime =
        (mu_etaeta * variance - mu_eta_sq * variance_slope) / (scaled_variance * variance);
    fisher_part - residual * t_prime
}
/// Asserts that `e_obs_from_jets` agrees with the third derivative (in the
/// linear predictor) of `binomial_w_obs_dual3`, obtained by `Dual3`
/// automatic differentiation.
///
/// `h0..h5` are the inverse-link value and its first five derivatives at the
/// evaluation point; the AD reference rebuilds `μ`, `μ_η` and `μ_ηη` from
/// them as cubic Taylor polynomials in the offset `delta` and differentiates
/// at `delta = 0`. `link_label` and `eta0` are used only in failure
/// messages. The comparison is relative, with the scale floored at `1e-12`,
/// and must fall below `tol`.
///
/// Panics up front if the Bernoulli variance `h0 * (1 - h0)` is not finite
/// and strictly positive.
fn assert_e_obs_matches_dual3_ad(
    link_label: &str,
    eta0: f64,
    y: f64,
    h0: f64,
    h1: f64,
    h2: f64,
    h3: f64,
    h4: f64,
    h5: f64,
    phi: f64,
    tol: f64,
) {
    let v = h0 * (1.0 - h0);
    let vj = VarianceJet::bernoulli(h0);
    // Reject saturated fixtures, where the variance is zero or non-finite.
    assert!(
        v.is_finite() && v > 0.0,
        "fixture {link_label} (eta={eta0}) has degenerate V(μ)={v}; pick a non-saturated point"
    );
    let analytic = e_obs_from_jets(y, h0, h1, h2, h3, h4, h5, vj, phi, 1.0);
    // Only the third derivative is kept from the AD result.
    let (_, _, _, d3_w) = third_derivative(
        |delta| {
            let mu = taylor_mu_dual3(delta, h0, h1, h2, h3);
            let mu_eta = taylor_mu_eta_dual3(delta, h1, h2, h3, h4);
            let mu_etaeta = taylor_mu_etaeta_dual3(delta, h2, h3, h4, h5);
            binomial_w_obs_dual3(mu, mu_eta, mu_etaeta, y, phi)
        },
        0.0_f64,
    );
    let scale = analytic.abs().max(d3_w.abs()).max(1.0e-12);
    let rel = (analytic - d3_w).abs() / scale;
    assert!(
        rel < tol,
        "{link_label}: analytic e_obs ({analytic:.12e}) does not match Dual3 AD ({d3_w:.12e}) at eta={eta0}, y={y}; rel={rel:.3e}"
    );
}
/// Probit link: analytic `e_obs` must match the Dual3 AD reference for both
/// Bernoulli outcomes over a grid of linear-predictor values.
#[test]
fn e_obs_from_jets_matches_dual3_ad_probit_bernoulli() {
    assert!(file!().ends_with(".rs"));
    let link = InverseLink::Standard(StandardLink::Probit);
    for eta in [-1.4_f64, -0.4, 0.0, 0.6, 1.3] {
        let jet = inverse_link_jet_for_inverse_link(&link, eta).expect("probit jet");
        let h4 =
            inverse_link_pdfthird_derivative_for_inverse_link(&link, eta).expect("probit h4");
        let h5 =
            inverse_link_pdffourth_derivative_for_inverse_link(&link, eta).expect("probit h5");
        for y in [0.0_f64, 1.0] {
            assert_e_obs_matches_dual3_ad(
                "probit", eta, y, jet.mu, jet.d1, jet.d2, jet.d3, h4, h5, 1.0, 1.0e-9,
            );
        }
    }
}
/// Complementary log-log link: analytic `e_obs` must match the Dual3 AD
/// reference for both Bernoulli outcomes over a grid of linear-predictor
/// values.
#[test]
fn e_obs_from_jets_matches_dual3_ad_cloglog_bernoulli() {
    assert!(file!().ends_with(".rs"));
    let link = InverseLink::Standard(StandardLink::CLogLog);
    for eta in [-1.6_f64, -0.5, 0.0, 0.4, 1.2] {
        let jet = inverse_link_jet_for_inverse_link(&link, eta).expect("cloglog jet");
        let h4 =
            inverse_link_pdfthird_derivative_for_inverse_link(&link, eta).expect("cloglog h4");
        let h5 =
            inverse_link_pdffourth_derivative_for_inverse_link(&link, eta).expect("cloglog h5");
        for y in [0.0_f64, 1.0] {
            assert_e_obs_matches_dual3_ad(
                "cloglog", eta, y, jet.mu, jet.d1, jet.d2, jet.d3, h4, h5, 1.0, 1.0e-9,
            );
        }
    }
}
/// Logit link: besides matching the Dual3 AD reference, the general
/// `e_obs` evaluated at `y = μ` must equal `h4`, the value this test treats
/// as the canonical fast path.
#[test]
fn e_obs_from_jets_matches_dual3_ad_logit_bernoulli_canonical_consistency() {
    let link = InverseLink::Standard(StandardLink::Logit);
    for eta in [-1.7_f64, -0.6, 0.0, 0.5, 1.4] {
        let jet = inverse_link_jet_for_inverse_link(&link, eta).expect("logit jet");
        let h4 =
            inverse_link_pdfthird_derivative_for_inverse_link(&link, eta).expect("logit h4");
        let h5 =
            inverse_link_pdffourth_derivative_for_inverse_link(&link, eta).expect("logit h5");

        // With y = μ the residual vanishes, leaving only the Fisher part.
        let canonical_fast_path = h4;
        let bernoulli_variance = VarianceJet::bernoulli(jet.mu);
        let general_at_y_eq_mu = e_obs_from_jets(
            jet.mu,
            jet.mu,
            jet.d1,
            jet.d2,
            jet.d3,
            h4,
            h5,
            bernoulli_variance,
            1.0,
            1.0,
        );
        let rel_fast = (general_at_y_eq_mu - canonical_fast_path).abs()
            / canonical_fast_path.abs().max(1.0e-12);
        assert!(
            rel_fast < 1.0e-10,
            "Logit canonical fast path ({canonical_fast_path:.12e}) does not match general e_obs at y=μ ({general_at_y_eq_mu:.12e}) at eta={eta}; rel={rel_fast:.3e}"
        );

        for y in [0.0_f64, 1.0] {
            assert_e_obs_matches_dual3_ad(
                "logit", eta, y, jet.mu, jet.d1, jet.d2, jet.d3, h4, h5, 1.0, 1.0e-9,
            );
        }
    }
}
/// SAS link: the direct higher-derivative helpers must agree with the
/// `InverseLink` dispatch, and analytic `e_obs` must match the Dual3 AD
/// reference, for two parameter configurations.
#[test]
fn e_obs_from_jets_matches_dual3_ad_sas_bernoulli() {
    const ETA_GRID: [f64; 5] = [-1.1, -0.3, 0.0, 0.5, 1.0];
    const OUTCOMES: [f64; 2] = [0.0, 1.0];
    for (epsilon, log_delta) in [(-0.25_f64, 0.35_f64), (0.4_f64, -0.2_f64)] {
        let state = SasLinkState::new(epsilon, log_delta).expect("sas state");
        let link = InverseLink::Sas(state);
        for eta in ETA_GRID {
            let jet = sas_inverse_link_jet(eta, state.epsilon, state.log_delta);
            let h4 = sas_inverse_link_pdfthird_derivative(eta, state.epsilon, state.log_delta);
            let h4_dispatch = inverse_link_pdfthird_derivative_for_inverse_link(&link, eta)
                .expect("sas h4 via dispatch");
            assert!(
                (h4 - h4_dispatch).abs() <= 1.0e-12 * h4.abs().max(1.0),
                "sas h4 dispatch mismatch at eta={eta}, eps={epsilon}, log_delta={log_delta}: direct={h4} dispatch={h4_dispatch}"
            );

            let h5 = sas_inverse_link_pdffourth_derivative(eta, state.epsilon, state.log_delta);
            let h5_dispatch = inverse_link_pdffourth_derivative_for_inverse_link(&link, eta)
                .expect("sas h5 via dispatch");
            assert!(
                (h5 - h5_dispatch).abs() <= 1.0e-12 * h5.abs().max(1.0),
                "sas h5 dispatch mismatch at eta={eta}, eps={epsilon}, log_delta={log_delta}: direct={h5} dispatch={h5_dispatch}"
            );

            for y in OUTCOMES {
                assert_e_obs_matches_dual3_ad(
                    "sas", eta, y, jet.mu, jet.d1, jet.d2, jet.d3, h4, h5, 1.0, 1.0e-9,
                );
            }
        }
    }
}
/// Beta-logistic link: the direct higher-derivative helpers must agree with
/// the `InverseLink` dispatch, and analytic `e_obs` must match the Dual3 AD
/// reference, for two parameter configurations.
#[test]
fn e_obs_from_jets_matches_dual3_ad_beta_logistic_bernoulli() {
    const ETA_GRID: [f64; 5] = [-1.0, -0.25, 0.0, 0.4, 0.9];
    const OUTCOMES: [f64; 2] = [0.0, 1.0];
    for (epsilon, log_delta) in [(0.18_f64, -0.22_f64), (-0.3_f64, 0.4_f64)] {
        let state = SasLinkState {
            epsilon,
            log_delta,
            delta: log_delta.exp(),
        };
        let link = InverseLink::BetaLogistic(state);
        for eta in ETA_GRID {
            // The beta-logistic helpers take `log_delta` before `epsilon`.
            let jet = beta_logistic_inverse_link_jet(eta, state.log_delta, state.epsilon);
            let h4 = beta_logistic_inverse_link_pdfthird_derivative(
                eta,
                state.log_delta,
                state.epsilon,
            );
            let h4_dispatch = inverse_link_pdfthird_derivative_for_inverse_link(&link, eta)
                .expect("beta-logistic h4 dispatch");
            assert!(
                (h4 - h4_dispatch).abs() <= 1.0e-12 * h4.abs().max(1.0),
                "beta-logistic h4 dispatch mismatch at eta={eta}: direct={h4} dispatch={h4_dispatch}"
            );

            let h5 = beta_logistic_inverse_link_pdffourth_derivative(
                eta,
                state.log_delta,
                state.epsilon,
            );
            let h5_dispatch = inverse_link_pdffourth_derivative_for_inverse_link(&link, eta)
                .expect("beta-logistic h5 dispatch");
            assert!(
                (h5 - h5_dispatch).abs() <= 1.0e-12 * h5.abs().max(1.0),
                "beta-logistic h5 dispatch mismatch at eta={eta}: direct={h5} dispatch={h5_dispatch}"
            );

            for y in OUTCOMES {
                assert_e_obs_matches_dual3_ad(
                    "beta-logistic",
                    eta,
                    y,
                    jet.mu,
                    jet.d1,
                    jet.d2,
                    jet.d3,
                    h4,
                    h5,
                    1.0,
                    1.0e-9,
                );
            }
        }
    }
}
/// Four-component mixture link (probit, logit, cloglog, cauchit): analytic
/// `e_obs` must match the Dual3 AD reference for both Bernoulli outcomes.
#[test]
fn e_obs_from_jets_matches_dual3_ad_mixture_bernoulli() {
    assert!(file!().ends_with(".rs"));
    let components = vec![
        LinkComponent::Probit,
        LinkComponent::Logit,
        LinkComponent::CLogLog,
        LinkComponent::Cauchit,
    ];
    let spec = MixtureLinkSpec {
        components,
        initial_rho: Array1::from_vec(vec![0.35, -0.45, 0.2]),
    };
    let state = state_fromspec(&spec).expect("mixture state");
    let link = InverseLink::Mixture(state.clone());
    for eta in [-1.2_f64, -0.4, 0.0, 0.5, 1.1] {
        let jet = mixture_inverse_link_jet(&state, eta);
        let h4 = inverse_link_pdfthird_derivative_for_inverse_link(&link, eta)
            .expect("mixture h4 dispatch");
        let h5 = inverse_link_pdffourth_derivative_for_inverse_link(&link, eta)
            .expect("mixture h5 dispatch");
        for y in [0.0_f64, 1.0] {
            assert_e_obs_matches_dual3_ad(
                "mixture", eta, y, jet.mu, jet.d1, jet.d2, jet.d3, h4, h5, 1.0, 1.0e-9,
            );
        }
    }
}
}
/// Regime-by-regime checks of `adaptive_lm_lambda_hint`, written as small
/// tables of `(cached λ, expected hint)` cases.
#[cfg(test)]
mod adaptive_lm_lambda_tests {
    use super::adaptive_lm_lambda_hint;

    /// Non-finite, zero and negative cached values yield no hint.
    #[test]
    fn pathological_cached_lambdas_fall_through_to_cold_default() {
        for cached in [f64::NAN, f64::INFINITY, 0.0, -1e-3] {
            assert_eq!(adaptive_lm_lambda_hint(cached, 5, true), None);
        }
    }

    /// Without feedback (count 0, flag false) there is no hint.
    #[test]
    fn no_feedback_yet_falls_through_to_cold_default() {
        let hint = adaptive_lm_lambda_hint(1e-5, 0, false);
        assert_eq!(hint, None);
    }

    /// With a count of 1 the hint is clamped into `[1e-9, 1e-3]`.
    #[test]
    fn newton_friendly_regime_admits_floor_to_1e_minus_9() {
        let cases = [
            (
                1e-9,
                1e-9,
                "cached λ at the LM floor must pass through unchanged",
            ),
            (1e-12, 1e-9, "below-floor cached λ clamped up to 1e-9"),
            (
                1e-2,
                1e-3,
                "above-ceiling cached λ clamped down to 1e-3 even in Newton-friendly regime",
            ),
        ];
        for (cached, expected, why) in cases {
            assert_eq!(
                adaptive_lm_lambda_hint(cached, 1, true),
                Some(expected),
                "{why}"
            );
        }
        assert_eq!(adaptive_lm_lambda_hint(1e-9, 2, true), Some(1e-9));
    }

    /// With a count of 12 the hint is clamped into `[1e-3, 1.0]`.
    #[test]
    fn hard_fit_regime_preserves_heavy_damping_signal() {
        let cases = [
            (0.5, 0.5, "heavy-damping cached λ passes through unchanged"),
            (2.0, 1.0, "above-ceiling cached λ clamped to 1.0"),
            (
                1e-6,
                1e-3,
                "below-floor cached λ clamped up to 1e-3 in hard-fit regime",
            ),
        ];
        for (cached, expected, why) in cases {
            assert_eq!(
                adaptive_lm_lambda_hint(cached, 12, true),
                Some(expected),
                "{why}"
            );
        }
        assert_eq!(adaptive_lm_lambda_hint(0.5, 5, false), Some(0.5));
    }

    /// With counts of 3, 5 and 9 the hint is clamped into `[1e-6, 1e-3]`.
    #[test]
    fn default_regime_matches_historical_static_clamp() {
        let cases = [
            (1e-5, 5, 1e-5),
            (1e-9, 5, 1e-6),
            (1e-1, 5, 1e-3),
            (1e-5, 3, 1e-5),
            (1e-5, 9, 1e-5),
        ];
        for (cached, count, expected) in cases {
            assert_eq!(adaptive_lm_lambda_hint(cached, count, true), Some(expected));
        }
    }
}
#[cfg(test)]
mod ift_warm_start_tests {
use super::*;
use crate::construction::CanonicalPenalty;
use crate::linalg::matrix::SymmetricMatrix;
use ndarray::Array2;
/// The cached theta and the pending theta share their first two (rho)
/// entries but differ in the third, so the cache must not match.
#[test]
fn joint_ift_cache_rejects_pending_theta_when_extended_hyperparameters_change() {
    let new_rho = Array1::from_vec(vec![0.1, -0.2]);
    let pending_theta = Array1::from_vec(vec![0.1, -0.2, 0.9]);
    let cached = super::IftJointModeResponseRuntimeCache {
        theta: Array1::from_vec(vec![0.1, -0.2, 0.5]),
        rho_dim: 2,
        beta_original: Array1::from_vec(vec![1.0]),
        mode_response_cols: Array2::zeros((1, 3)),
        active_constraints: false,
    };
    let matches = super::joint_ift_cache_matches_theta(&cached, &pending_theta, &new_rho);
    assert!(
        !matches,
        "extended hyperparameters must participate in joint IFT cache validity"
    );
}
/// Test shim over `predict_warm_start_beta_ift_inner_with_outcome` that
/// passes `None` for both trailing optional arguments and keeps only the
/// predicted coefficients.
fn predict_warm_start_beta_ift_inner(
    cache: &super::IftWarmStartCache,
    canonical_penalties: &[CanonicalPenalty],
    new_rho: &Array1<f64>,
    p: usize,
    last_ift_residual: Option<f64>,
) -> Option<Coefficients> {
    let (coef, _outcome) = super::predict_warm_start_beta_ift_inner_with_outcome(
        cache,
        canonical_penalties,
        new_rho,
        p,
        last_ift_residual,
        None,
        None,
    )?;
    Some(coef)
}
/// Test shim over `predict_warm_start_beta_ift_inner_with_outcome` that
/// supplies `factor_override` as the final argument and keeps only the
/// predicted coefficients.
pub(super) fn predict_warm_start_beta_ift_with_factor(
    cache: &super::IftWarmStartCache,
    canonical_penalties: &[CanonicalPenalty],
    new_rho: &Array1<f64>,
    p: usize,
    last_ift_residual: Option<f64>,
    factor_override: &dyn crate::linalg::matrix::FactorizedSystem,
) -> Option<Coefficients> {
    let (coef, _outcome) = super::predict_warm_start_beta_ift_inner_with_outcome(
        cache,
        canonical_penalties,
        new_rho,
        p,
        last_ift_residual,
        None,
        Some(factor_override),
    )?;
    Some(coef)
}
/// Builds a dense `CanonicalPenalty` covering columns `0..p` from a
/// symmetric penalty matrix.
///
/// Eigenpairs with eigenvalue above `1e-12` are kept; each contributes one
/// row `sqrt(λ) · v` to the root, so `rootᵀ · root` rebuilds the kept part
/// of `local`. `nullity` is `p` minus the number of kept eigenpairs, the
/// prior mean is zero and no operator is attached.
fn dense_canonical_from_local(local: Array2<f64>, p: usize) -> CanonicalPenalty {
    use crate::faer_ndarray::FaerEigh;
    use faer::Side;
    let (evals, evecs) = local.eigh(Side::Lower).expect("eigh penalty");

    // Indices of the eigenpairs that count as strictly positive.
    let kept: Vec<usize> = evals
        .iter()
        .enumerate()
        .filter_map(|(idx, &lam)| (lam > 1e-12).then_some(idx))
        .collect();
    let positive_eigenvalues: Vec<f64> = kept.iter().map(|&idx| evals[idx]).collect();

    let rank = kept.len();
    let mut root = Array2::<f64>::zeros((rank, p));
    for (i, &idx) in kept.iter().enumerate() {
        let scale = evals[idx].sqrt();
        let scaled = evecs.column(idx).mapv(|x| x * scale);
        for j in 0..p {
            root[[i, j]] = scaled[j];
        }
    }

    CanonicalPenalty {
        root,
        col_range: 0..p,
        total_dim: p,
        nullity: p - rank,
        local,
        prior_mean: Array1::zeros(p),
        positive_eigenvalues,
        op: None,
    }
}
/// Verifies that the IFT warm-start prediction satisfies the linearized
/// first-order condition in the original basis:
/// `H · Δβ = -Σ_k Δρ_k · exp(ρ_k) · S_k · β`, with
/// `H = I + Σ_k exp(ρ_k) · S_k`.
///
/// The residual `‖H·Δβ + rhs‖` must be below `1e-9` relative to `‖rhs‖`,
/// and the predicted update must be non-trivial.
#[test]
fn ift_predictor_satisfies_linearized_foc_original_basis() {
    let p = 5usize;
    // Two dense, full-range penalties.
    let s1 = Array2::from_shape_vec(
        (p, p),
        vec![
            2.0, 0.3, 0.0, 0.0, 0.0, 0.3, 2.5, 0.4, 0.0, 0.0, 0.0, 0.4, 1.8, 0.2, 0.0, 0.0,
            0.0, 0.2, 1.2, 0.1, 0.0, 0.0, 0.0, 0.1, 1.5,
        ],
    )
    .unwrap();
    let s2 = Array2::from_shape_vec(
        (p, p),
        vec![
            1.1, 0.0, 0.2, 0.0, 0.1, 0.0, 0.9, 0.0, 0.3, 0.0, 0.2, 0.0, 1.4, 0.0, 0.2, 0.0,
            0.3, 0.0, 1.7, 0.0, 0.1, 0.0, 0.2, 0.0, 2.1,
        ],
    )
    .unwrap();
    let cp1 = dense_canonical_from_local(s1.clone(), p);
    let cp2 = dense_canonical_from_local(s2.clone(), p);
    let canonical = vec![cp1, cp2];
    let beta_cur = ndarray::array![0.4, -0.7, 0.2, 0.9, -0.3];
    let rho_cur = ndarray::array![0.2_f64, -0.1];
    // Penalized Hessian: identity plus the lambda-weighted penalties.
    let mut h_pen = Array2::<f64>::eye(p);
    for (k_idx, cp) in canonical.iter().enumerate() {
        let lam = rho_cur[k_idx].exp();
        for i in 0..p {
            for j in 0..p {
                h_pen[[i, j]] += lam * cp.local[[i, j]];
            }
        }
    }
    // Cache flagged as original-frame with an identity `qs` and no precomputed blocks.
    let cache = super::super::IftWarmStartCache {
        beta_original: beta_cur.clone(),
        rho: rho_cur.clone(),
        penalized_hessian_transformed: SymmetricMatrix::Dense(h_pen.clone()),
        qs: Array2::eye(p),
        frame_was_original: true,
        lambda_s_beta_blocks: None,
    };
    let new_rho = ndarray::array![0.25_f64, -0.05];
    let predicted = predict_warm_start_beta_ift_inner(&cache, &canonical, &new_rho, p, None)
        .expect("IFT predictor should accept small Δρ");
    let dbeta = &predicted.0 - &beta_cur;
    let lhs = h_pen.dot(&dbeta);
    // Right-hand side of the linearized condition: sum_k drho_k * exp(rho_k) * S_k * beta.
    let mut rhs = Array1::<f64>::zeros(p);
    for (k_idx, cp) in canonical.iter().enumerate() {
        let drho = new_rho[k_idx] - rho_cur[k_idx];
        let scale = drho * rho_cur[k_idx].exp();
        let sb = cp.local.dot(&beta_cur);
        for i in 0..p {
            rhs[i] += scale * sb[i];
        }
    }
    // The condition reads lhs = -rhs, so the residual is the norm of lhs + rhs.
    let residual: f64 = lhs
        .iter()
        .zip(rhs.iter())
        .map(|(&a, &b)| (a + b) * (a + b))
        .sum::<f64>()
        .sqrt();
    let scale: f64 = rhs.iter().map(|x| x * x).sum::<f64>().sqrt().max(1e-12);
    assert!(
        residual / scale < 1e-9,
        "IFT predictor violates linearized FOC: residual {residual:.3e}, scale {scale:.3e}"
    );
    // Guard against a trivially satisfied check where nothing moved.
    let dbeta_norm: f64 = dbeta.iter().map(|x| x * x).sum::<f64>().sqrt();
    assert!(
        dbeta_norm > 1e-6,
        "IFT predictor produced zero β-update; check that Δρ propagated"
    );
}
/// Exercises the alternative code paths of the IFT warm-start predictor on
/// one fixture and checks that they agree:
///
/// 1. the inline path (`lambda_s_beta_blocks: None`) against the precompute
///    path, where the per-penalty `S_k · (β_k - prior_mean_k)` blocks are
///    supplied up front;
/// 2. a precompute vector of the wrong length, which must give the same
///    prediction as the inline path;
/// 3. the `_with_factor` entry point with a factor of the same Hessian,
///    which must match the inline path;
/// 4. the `_with_factor` entry point with a factor of a perturbed Hessian,
///    which must differ, proving the factor argument is actually used.
#[test]
fn ift_predictor_precompute_path_matches_inline_path() {
    let p = 5usize;
    // Penalty on the leading 3x3 block only.
    let s1 = Array2::from_shape_vec(
        (p, p),
        vec![
            1.0, 0.2, 0.0, 0.0, 0.0, 0.2, 1.5, 0.1, 0.0, 0.0, 0.0, 0.1, 1.3, 0.0, 0.0, 0.0,
            0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0,
        ],
    )
    .unwrap();
    // Penalty on the trailing 3x3 block only.
    let s2 = Array2::from_shape_vec(
        (p, p),
        vec![
            0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.1, 0.05, 0.0, 0.0,
            0.0, 0.05, 1.7, 0.0, 0.0, 0.0, 0.0, 0.0, 2.1,
        ],
    )
    .unwrap();
    let cp1 = dense_canonical_from_local(s1, p);
    let cp2 = dense_canonical_from_local(s2, p);
    let canonical = vec![cp1, cp2];
    let beta_cur = ndarray::array![0.4_f64, -0.7, 0.2, 0.9, -0.3];
    let rho_cur = ndarray::array![0.2_f64, -0.1];

    // Penalized Hessian: identity plus the lambda-weighted penalties.
    let mut h_pen = Array2::<f64>::eye(p);
    for (k_idx, cp) in canonical.iter().enumerate() {
        let lam = rho_cur[k_idx].exp();
        for i in 0..p {
            for j in 0..p {
                h_pen[[i, j]] += lam * cp.local[[i, j]];
            }
        }
    }

    // Per-penalty blocks S_k · (β_k - prior_mean_k) for the precompute path.
    let lambda_s_beta_blocks: Vec<ndarray::Array1<f64>> = canonical
        .iter()
        .map(|cp| {
            let r = &cp.col_range;
            let beta_block = beta_cur.slice(s![r.start..r.end]);
            let centered = &beta_block - &cp.prior_mean;
            cp.local.dot(&centered)
        })
        .collect();

    let cache_inline = super::super::IftWarmStartCache {
        beta_original: beta_cur.clone(),
        rho: rho_cur.clone(),
        penalized_hessian_transformed: SymmetricMatrix::Dense(h_pen.clone()),
        qs: Array2::eye(p),
        frame_was_original: true,
        lambda_s_beta_blocks: None,
    };
    let cache_precompute = super::super::IftWarmStartCache {
        beta_original: beta_cur.clone(),
        rho: rho_cur.clone(),
        penalized_hessian_transformed: SymmetricMatrix::Dense(h_pen.clone()),
        qs: Array2::eye(p),
        frame_was_original: true,
        lambda_s_beta_blocks: Some(lambda_s_beta_blocks),
    };
    let new_rho = ndarray::array![0.25_f64, -0.05];

    // (1) Inline and precompute paths must agree.
    let predicted_inline =
        predict_warm_start_beta_ift_inner(&cache_inline, &canonical, &new_rho, p, None)
            .expect("inline-path predict");
    let predicted_precompute =
        predict_warm_start_beta_ift_inner(&cache_precompute, &canonical, &new_rho, p, None)
            .expect("precompute-path predict");
    for i in 0..p {
        let diff = (predicted_inline.0[i] - predicted_precompute.0[i]).abs();
        assert!(
            diff < 1e-12,
            "precompute and inline IFT paths diverged at index {i}: \
             inline={} precompute={} diff={:.3e}",
            predicted_inline.0[i],
            predicted_precompute.0[i],
            diff,
        );
    }

    // (2) One block instead of two: the predictor must still match the inline result.
    let cache_wrong_len = super::super::IftWarmStartCache {
        beta_original: beta_cur.clone(),
        rho: rho_cur.clone(),
        penalized_hessian_transformed: SymmetricMatrix::Dense(h_pen.clone()),
        qs: Array2::eye(p),
        frame_was_original: true,
        lambda_s_beta_blocks: Some(vec![ndarray::Array1::zeros(0)]),
    };
    let predicted_fallback =
        predict_warm_start_beta_ift_inner(&cache_wrong_len, &canonical, &new_rho, p, None)
            .expect("wrong-length precompute should still predict via inline fallback");
    for i in 0..p {
        let diff = (predicted_inline.0[i] - predicted_fallback.0[i]).abs();
        assert!(
            diff < 1e-12,
            "wrong-length precompute fell back incorrectly at index {i}",
        );
    }

    // (3) A pre-built factor of the same Hessian must reproduce the inline result.
    let pre_built_factor = SymmetricMatrix::Dense(h_pen.clone())
        .factorize()
        .expect("factorize for _with_factor test");
    let predicted_via_factor = predict_warm_start_beta_ift_with_factor(
        &cache_inline,
        &canonical,
        &new_rho,
        p,
        None,
        pre_built_factor.as_ref(),
    )
    .expect("with_factor predict");
    for i in 0..p {
        let diff = (predicted_inline.0[i] - predicted_via_factor.0[i]).abs();
        assert!(
            diff < 1e-12,
            "_with_factor and inline paths diverged at index {i}: \
             inline={} factor={} diff={:.3e}",
            predicted_inline.0[i],
            predicted_via_factor.0[i],
            diff,
        );
    }

    // (4) A factor of a different matrix must change the answer; otherwise the
    // override is being ignored.
    let mut h_perturbed = h_pen.clone();
    for i in 0..p {
        h_perturbed[[i, i]] *= 5.0;
    }
    let perturbed_factor = SymmetricMatrix::Dense(h_perturbed)
        .factorize()
        .expect("factorize perturbed H");
    let predicted_with_wrong_factor = predict_warm_start_beta_ift_with_factor(
        &cache_inline,
        &canonical,
        &new_rho,
        p,
        None,
        perturbed_factor.as_ref(),
    )
    .expect("with_factor predict (wrong factor)");
    let mut max_diff = 0.0_f64;
    for i in 0..p {
        let diff = (predicted_inline.0[i] - predicted_with_wrong_factor.0[i]).abs();
        if diff > max_diff {
            max_diff = diff;
        }
    }
    assert!(
        max_diff > 1e-3,
        "wrong factor produced bit-identical output (max_diff={:.3e}); \
         _with_factor path is silently ignoring its factor argument",
        max_diff,
    );
}
/// Checks that the IFT predictor returns the same β whether the cache stores the
/// matrix in the original frame (`qs = I`, `frame_was_original: true`) or rotated
/// by an orthogonal `qs` (`Qᵀ H Q`, `frame_was_original: false`).
#[test]
fn ift_predictor_basis_conversion_matches_original() {
    let p = 4usize;
    let penalty = Array2::from_shape_vec(
        (p, p),
        vec![
            1.5, 0.2, 0.0, 0.0, 0.2, 1.8, 0.1, 0.0, 0.0, 0.1, 1.2, 0.05, 0.0, 0.0, 0.05, 1.4,
        ],
    )
    .unwrap();
    let canonical = vec![dense_canonical_from_local(penalty, p)];
    let beta_cur = ndarray::array![0.5, -0.3, 0.8, 0.2];
    let rho_cur = ndarray::array![0.1_f64];

    // Original-frame matrix: identity plus exp(ρ) times the penalty's local matrix.
    let lam = rho_cur[0].exp();
    let mut h_orig = Array2::<f64>::eye(p);
    for ((row, col), entry) in h_orig.indexed_iter_mut() {
        *entry += lam * canonical[0].local[[row, col]];
    }

    // Orthogonal frame taken from the Q factor of an arbitrary matrix.
    let raw = Array2::from_shape_vec(
        (p, p),
        vec![
            1.0, 0.5, 0.0, 0.1, 0.0, 1.0, 0.3, 0.0, 0.2, 0.0, 1.0, 0.4, 0.0, 0.1, 0.0, 1.0,
        ],
    )
    .unwrap();
    let qs: Array2<f64> = {
        use crate::faer_ndarray::FaerQr;
        let (q, _r) = raw.qr().expect("QR");
        q
    };

    // Sanity check on the fixture itself: QᵀQ must be the identity.
    let gram = qs.t().dot(&qs);
    for i in 0..p {
        for j in 0..p {
            let value = gram[[i, j]];
            let target = if i != j { 0.0 } else { 1.0 };
            assert!(
                (value - target).abs() < 1e-10,
                "qs not orthogonal at [{},{}]: {}",
                i,
                j,
                value
            );
        }
    }

    // Both caches share β and ρ; only the stored matrix, frame, and flag differ.
    let make_cache = |hessian: Array2<f64>, frame: Array2<f64>, frame_was_original: bool| {
        super::super::IftWarmStartCache {
            beta_original: beta_cur.clone(),
            rho: rho_cur.clone(),
            penalized_hessian_transformed: SymmetricMatrix::Dense(hessian),
            qs: frame,
            frame_was_original,
            lambda_s_beta_blocks: None,
        }
    };
    let h_tfd = qs.t().dot(&h_orig).dot(&qs);
    let cache_tfd = make_cache(h_tfd, qs, false);
    let cache_orig = make_cache(h_orig, Array2::eye(p), true);

    let new_rho = ndarray::array![0.3_f64];
    let predicted_tfd =
        predict_warm_start_beta_ift_inner(&cache_tfd, &canonical, &new_rho, p, None)
            .expect("tfd predict");
    let predicted_orig =
        predict_warm_start_beta_ift_inner(&cache_orig, &canonical, &new_rho, p, None)
            .expect("orig predict");
    for i in 0..p {
        let (from_tfd, from_orig) = (predicted_tfd.0[i], predicted_orig.0[i]);
        assert!(
            (from_tfd - from_orig).abs() < 1e-10,
            "basis-conversion path mismatch at index {i}: tfd={}, orig={}",
            from_tfd,
            from_orig,
        );
    }
}
/// Exercises the predictor's Δρ gate: the same cache is queried with no prior
/// quality signal, with a small prior residual, and with a large prior residual.
#[test]
fn ift_predictor_rejects_large_drho() {
    let p = 3usize;
    // One penalty whose local matrix is the 3×3 identity.
    let canonical = vec![dense_canonical_from_local(Array2::<f64>::eye(p), p)];
    let cache = super::super::IftWarmStartCache {
        beta_original: ndarray::array![1.0, 1.0, 1.0],
        rho: ndarray::array![0.0_f64],
        penalized_hessian_transformed: SymmetricMatrix::Dense(Array2::<f64>::eye(p) * 2.0),
        qs: Array2::eye(p),
        frame_was_original: true,
        lambda_s_beta_blocks: None,
    };

    // Δρ = 3 with no history: expected to be refused.
    let far_rho = ndarray::array![3.0_f64];
    let without_history =
        predict_warm_start_beta_ift_inner(&cache, &canonical, &far_rho, p, None);
    assert!(
        without_history.is_none(),
        "predictor should reject Δρ above default cap, got {:?}",
        without_history
    );

    // Same Δρ = 3, but a small prior residual: expected to be accepted.
    let with_good_history =
        predict_warm_start_beta_ift_inner(&cache, &canonical, &far_rho, p, Some(0.005));
    assert!(
        with_good_history.is_some(),
        "predictor should accept Δρ=3 under expanded cap (good prior quality)",
    );

    // Δρ = 1 with a large prior residual: expected to be refused.
    let near_rho = ndarray::array![1.0_f64];
    let with_bad_history =
        predict_warm_start_beta_ift_inner(&cache, &canonical, &near_rho, p, Some(0.6));
    assert!(
        with_bad_history.is_none(),
        "predictor should reject Δρ=1 under tightened cap (poor prior quality)",
    );
}
/// With a Δρ of 1e-15 the predictor must still return `Some`, and the returned
/// coefficients must compare equal to the cached β element by element.
#[test]
fn ift_predictor_returns_noop_when_all_drho_below_eps() {
    use crate::estimate::reml::IftWarmStartCache;
    let p = 3usize;
    // One penalty whose local matrix is the 3×3 identity.
    let canonical = vec![dense_canonical_from_local(Array2::<f64>::eye(p), p)];
    let beta_cur = ndarray::array![0.5_f64, -0.3, 1.7];
    let cache = IftWarmStartCache {
        beta_original: beta_cur.clone(),
        rho: ndarray::array![0.0_f64],
        penalized_hessian_transformed: SymmetricMatrix::Dense(Array2::<f64>::eye(p) * 2.0),
        qs: Array2::eye(p),
        frame_was_original: true,
        lambda_s_beta_blocks: None,
    };

    let tiny_step_rho = ndarray::array![1e-15_f64];
    let predicted =
        predict_warm_start_beta_ift_inner(&cache, &canonical, &tiny_step_rho, p, None)
            .expect("noop must return Some(β_cur)");
    for i in 0..p {
        let (got, cached) = (predicted.0[i], beta_cur[i]);
        assert_eq!(
            got, cached,
            "noop must return cached β bit-equal at idx {i}: got {} vs cached {}",
            got, cached
        );
    }
}
/// A cache whose stored `rho` is empty, queried with no penalties and a
/// one-element `new_rho`, must make the predictor return `None`.
#[test]
fn ift_predictor_rejects_unstamped_cache() {
    use super::super::IftWarmStartCache;
    let p = 2usize;
    let requested_rho = ndarray::array![0.1_f64];
    let unstamped = IftWarmStartCache {
        beta_original: ndarray::array![1.0, 2.0],
        // Zero-length rho is what marks this cache as unstamped for the test.
        rho: Array1::zeros(0),
        penalized_hessian_transformed: SymmetricMatrix::Dense(Array2::eye(p)),
        qs: Array2::eye(p),
        frame_was_original: true,
        lambda_s_beta_blocks: None,
    };
    let predicted = predict_warm_start_beta_ift_inner(&unstamped, &[], &requested_rho, p, None);
    assert!(predicted.is_none());
}
/// Pins the value `adaptive_ift_max_drho` returns for a table of prior
/// residuals, then checks the cap never grows as the residual grows.
#[test]
fn adaptive_ift_max_drho_follows_quality_tiers() {
    use super::adaptive_ift_max_drho;
    // (prior residual, expected cap), asserted in this order.
    let expectations = [
        (None, 2.0),
        (Some(f64::NAN), 2.0),
        (Some(-1.0), 2.0),
        (Some(f64::INFINITY), 0.5),
        (Some(0.0), 4.0),
        (Some(0.005), 4.0),
        (Some(0.01), 3.0),
        (Some(0.04), 3.0),
        (Some(0.05), 2.0),
        (Some(0.10), 2.0),
        (Some(0.20), 1.0),
        (Some(0.30), 1.0),
        (Some(0.50), 0.5),
        (Some(1.5), 0.5),
    ];
    for (residual, expected_cap) in expectations.iter().copied() {
        assert_eq!(adaptive_ift_max_drho(residual), expected_cap);
    }

    // Increasing residuals must map to non-increasing caps.
    let caps: Vec<f64> = [0.001, 0.02, 0.10, 0.30, 0.80]
        .iter()
        .map(|&residual| adaptive_ift_max_drho(Some(residual)))
        .collect();
    assert!(
        caps.windows(2).all(|pair| pair[0] >= pair[1]),
        "adaptive cap is not monotone non-increasing in residual: {caps:?}"
    );
}
/// Confirms that evaluating `cp.local · (β_block − cp.prior_mean)` for every
/// penalty through rayon's `par_iter` yields bit-identical vectors to the
/// serial `iter` evaluation.
#[test]
fn parallel_lambda_s_beta_blocks_matches_serial() {
    use rayon::prelude::*;
    let p = 50usize;
    let n_penalties = 8usize;
    // Each penalty is a symmetric tridiagonal matrix: `scale` on the diagonal
    // and `0.1 * scale` on the first off-diagonals, with a distinct scale per k.
    let mut canonical = Vec::with_capacity(n_penalties);
    for k in 0..n_penalties {
        let scale = (k as f64 + 1.0) * 0.5;
        let mut s = Array2::<f64>::zeros((p, p));
        for i in 0..p {
            s[[i, i]] = scale;
            if i + 1 < p {
                s[[i, i + 1]] = scale * 0.1;
                s[[i + 1, i]] = scale * 0.1;
            }
        }
        canonical.push(dense_canonical_from_local(s, p));
    }
    let beta_cur =
        ndarray::Array1::from_shape_fn(p, |i| (i as f64 * 0.1).sin() + (i as f64 * 0.05).cos());

    // Reference: one mat-vec per penalty, evaluated sequentially.
    let serial: Vec<ndarray::Array1<f64>> = canonical
        .iter()
        .map(|cp| {
            let r = &cp.col_range;
            let beta_block = beta_cur.slice(s![r.start..r.end]);
            let centered = &beta_block - &cp.prior_mean;
            cp.local.dot(&centered)
        })
        .collect();
    // Same computation dispatched across rayon workers; `collect` keeps input order.
    let parallel: Vec<ndarray::Array1<f64>> = canonical
        .par_iter()
        .map(|cp| {
            let r = &cp.col_range;
            let beta_block = beta_cur.slice(s![r.start..r.end]);
            let centered = &beta_block - &cp.prior_mean;
            cp.local.dot(&centered)
        })
        .collect();

    assert_eq!(serial.len(), parallel.len());
    for (s_block, p_block) in serial.iter().zip(parallel.iter()) {
        assert_eq!(s_block.len(), p_block.len());
        for (a, b) in s_block.iter().zip(p_block.iter()) {
            // Compare raw bit patterns so even a one-ulp difference fails.
            assert_eq!(
                a.to_bits(),
                b.to_bits(),
                "parallel mat-vec diverged from serial: {a} vs {b}",
            );
        }
    }
}
/// Guards the encoding of the "no residual yet" sentinel: it must decode to NaN,
/// differ from the bit pattern of `0.0`, and be refused by a
/// finite-and-non-negative acceptance check, while ordinary residuals
/// round-trip through `to_bits` / `from_bits` unchanged.
#[test]
fn ift_residual_sentinel_is_distinguishable_from_zero() {
    use super::IFT_RESIDUAL_NO_SIGNAL_BITS;
    // Acceptance predicate this test holds the sentinel against.
    let accepts = |r: f64| r.is_finite() && r >= 0.0;
    let decoded_sentinel = f64::from_bits(IFT_RESIDUAL_NO_SIGNAL_BITS);

    assert!(
        decoded_sentinel.is_nan(),
        "sentinel must decode to NaN so reads can detect 'no signal yet'",
    );
    assert_eq!(0.0_f64.to_bits(), 0, "f64::to_bits(0.0) is 0 by IEEE 754");
    assert_ne!(
        IFT_RESIDUAL_NO_SIGNAL_BITS, 0,
        "sentinel must not collide with f64::to_bits(0.0)",
    );

    let decoded_zero = f64::from_bits(0.0_f64.to_bits());
    assert!(accepts(decoded_zero), "0.0 is genuine signal");
    assert!(
        !accepts(decoded_sentinel),
        "sentinel must fail the reader's accept predicate",
    );

    for val in [0.0_f64, 1e-10, 0.05, 0.5, 4.0].iter().copied() {
        let restored = f64::from_bits(val.to_bits());
        assert_eq!(restored, val, "round-trip failed for {val}");
    }
}
/// The predictor must return `None` when `new_rho`, the penalty list, or the
/// coefficient count passed in disagrees with the cache built here.
#[test]
fn ift_predictor_rejects_dim_mismatches() {
    let p = 3usize;
    // One penalty whose local matrix is the 3×3 identity.
    let canonical = vec![dense_canonical_from_local(Array2::<f64>::eye(p), p)];
    let cache = super::super::IftWarmStartCache {
        beta_original: ndarray::array![1.0_f64, 1.0, 1.0],
        rho: ndarray::array![0.0_f64],
        penalized_hessian_transformed: SymmetricMatrix::Dense(Array2::<f64>::eye(p) * 2.0),
        qs: Array2::eye(p),
        frame_was_original: true,
        lambda_s_beta_blocks: None,
    };

    // Two-element rho against a cache stamped with one.
    let two_element_rho = ndarray::array![0.1_f64, 0.2];
    let rho_rejected =
        predict_warm_start_beta_ift_inner(&cache, &canonical, &two_element_rho, p, None)
            .is_none();
    assert!(rho_rejected, "new_rho dim mismatch must reject");

    // Correctly sized rho, but an empty penalty list.
    let matching_rho = ndarray::array![0.1_f64];
    let penalties_rejected =
        predict_warm_start_beta_ift_inner(&cache, &[], &matching_rho, p, None).is_none();
    assert!(penalties_rejected, "penalty dim mismatch must reject");

    // Correct rho and penalties, but a coefficient count of 4 instead of 3.
    let beta_rejected =
        predict_warm_start_beta_ift_inner(&cache, &canonical, &matching_rho, 4, None).is_none();
    assert!(beta_rejected, "beta dim mismatch must reject");
}
/// Pins the value `adaptive_tangent_alpha_cap` returns for a table of prior
/// residuals, then checks the cap never grows as the residual grows.
#[test]
fn adaptive_tangent_alpha_cap_follows_quality_tiers() {
    use super::adaptive_tangent_alpha_cap;
    // (prior residual, expected cap), asserted in this order.
    let expectations = [
        (None, 1.5),
        (Some(f64::NAN), 1.5),
        (Some(-1.0), 1.5),
        (Some(f64::INFINITY), 0.5),
        (Some(0.0), 2.0),
        (Some(0.005), 2.0),
        (Some(0.01), 1.75),
        (Some(0.04), 1.75),
        (Some(0.05), 1.5),
        (Some(0.10), 1.5),
        (Some(0.20), 1.0),
        (Some(0.49), 1.0),
        (Some(0.50), 0.5),
        (Some(2.0), 0.5),
    ];
    for (residual, expected_cap) in expectations.iter().copied() {
        assert_eq!(adaptive_tangent_alpha_cap(residual), expected_cap);
    }

    // Increasing residuals must map to non-increasing caps.
    let caps: Vec<f64> = [0.001, 0.02, 0.10, 0.30, 0.80]
        .iter()
        .map(|&residual| adaptive_tangent_alpha_cap(Some(residual)))
        .collect();
    assert!(
        caps.windows(2).all(|pair| pair[0] >= pair[1]),
        "tangent α cap is not monotone non-increasing in residual: {caps:?}"
    );
}
}
#[cfg(test)]
mod tests_diagnostics {
    use super::*;

    // Test-only accessors added to `RemlState`; compiled only under `cfg(test)`.
    impl<'a> RemlState<'a> {
        /// Returns the Hessian log-determinant at `rho`: the value reported by
        /// the assembled Hessian operator plus the assembly's
        /// `hessian_logdet_correction`.
        ///
        /// # Errors
        /// Propagates any error from `obtain_eval_bundle` or `build_auto_assembly`.
        pub(crate) fn objective_logdet_h_proj(
            &self,
            rho: &Array1<f64>,
        ) -> Result<f64, EstimationError> {
            let bundle = self.obtain_eval_bundle(rho)?;
            // Value-only mode is requested because just the scalar is needed here.
            let assembly =
                self.build_auto_assembly(rho, &bundle, super::unified::EvalMode::ValueOnly)?;
            let operator_logdet =
                super::unified::HessianOperator::logdet(assembly.hessian_op.as_ref());
            Ok(operator_logdet + assembly.hessian_logdet_correction)
        }

        /// Returns clones of `final_eta`, `finalweights`, and `solve_c_array`
        /// (in that order) from the result of `execute_pirls_if_needed(rho)`.
        ///
        /// # Errors
        /// Propagates any error from `execute_pirls_if_needed`.
        pub(crate) fn debug_eta_w_c(
            &self,
            rho: &Array1<f64>,
        ) -> Result<(Array1<f64>, Array1<f64>, Array1<f64>), EstimationError> {
            let pirls = self.execute_pirls_if_needed(rho)?;
            let eta = pirls.final_eta.clone();
            let weights = pirls.finalweights.clone();
            let c_array = pirls.solve_c_array.clone();
            Ok((eta, weights, c_array))
        }
    }
}