use super::*;
// Tuning constants for this module. Several are consumed by code visible in
// this file (noted individually); the rest are used elsewhere.
// NOTE(review): the `TK_*` values look like block/chunk sizes and problem-size
// limits for a "TK" code path — inferred from names only, confirm against callers.
pub(crate) const TK_BLOCK_SIZE: usize = 128;
pub(crate) const TK_CHUNK_MAX_ROWS: usize = 2048;
pub(crate) const TK_CHUNK_OVERSUBSCRIBE: usize = 4;
pub(crate) const TK_MAX_OBSERVATIONS: usize = 20_000;
pub(crate) const TK_MAX_COEFFICIENTS: usize = 2_000;
// NOTE(review): adaptive KKT tolerance parameters; usage is not visible here.
pub(crate) const ADAPTIVE_KKT_ETA: f64 = 0.1;
pub(crate) const ADAPTIVE_KKT_FLOOR_REML_DIVISOR: f64 = 100.0;
pub(crate) const TK_MAX_DENSE_WORK: usize = 5_000_000;
pub(crate) const LARGE_N_EFS_THRESHOLD: f64 = 1.0e8;
// `efs_single_loop_encoded_cap` returns SENTINEL + SWEEPS and
// `decode_efs_single_loop_cap` subtracts the sentinel again.
pub(crate) const EFS_SINGLE_LOOP_PIRLS_SWEEPS: usize = 2;
pub(crate) const EFS_SINGLE_LOOP_PIRLS_CAP_SENTINEL: usize = usize::MAX / 4;
pub(crate) const EFS_SINGLE_LOOP_BIAS_THRESHOLD: f64 = 0.10;
pub(crate) const EFS_SINGLE_LOOP_BIAS_CONSECUTIVE_LIMIT: usize = 3;
// Initial per-channel floors installed by `HyperGradientBudget::new`.
pub(crate) const HGB_INNER_FLOOR: f64 = 1e-12;
pub(crate) const HGB_LINEAR_FLOOR: f64 = 1e-12;
pub(crate) const HGB_TRACE_FLOOR: f64 = 1e-12;
// Maximum number of entries `HyperGradientBudget::push` keeps in `history`.
pub(crate) const HGB_HISTORY_CAP: usize = 10;
// Length cap of `sensitivity_history`, and the minimum run length required by
// `estimate_trace_sensitivity` and `sensitivities_stable`.
pub(crate) const HGB_WARMUP_ITERS_MIN: usize = 3;
pub(crate) const HGB_WARMUP_ITERS_MAX: usize = 10;
pub(crate) const HGB_TARGET_FRACTION: f64 = 0.1;
// `secant_gradient_pairs` drops pairs whose squared rho step exceeds this.
pub(crate) const HGB_SECANT_DRHO_MAX_SQUARED: f64 = 1.0;
// Fewer secant pairs than this makes `reestimate_sensitivities` return `None`.
pub(crate) const HGB_MIN_PAIRS_FOR_SENSITIVITY: usize = 3;
// Diagonal ridge added to the normal matrix in `estimate_energy_sensitivity`.
pub(crate) const HGB_REGRESSION_RIDGE: f64 = 1e-6;
// `sensitivities_stable` requires max/min per channel to stay below this ratio.
pub(crate) const HGB_SENS_STABILITY_RATIO: f64 = 1.5;
// Initial sensitivities installed by `HyperGradientBudget::new`.
pub(crate) const S_INNER_INIT: f64 = 1.0;
pub(crate) const S_LINEAR_INIT: f64 = 1.0;
pub(crate) const S_TRACE_INIT: f64 = 1.0;
// Lower clamp applied to re-estimated sensitivities.
pub(crate) const HGB_SENS_FLOOR: f64 = 1e-6;
pub(crate) const IFT_QUALITY_HISTORY_CAP: usize = 5;
pub(crate) const ETA_OVERFLOW_CLAMP: f64 = 700.0;
// NOTE(review): IFT quality bands and step-cap factors; consumers are not
// visible in this part of the file.
pub(crate) const IFT_QUALITY_GROW_BAND: f64 = 1e-3;
pub(crate) const IFT_QUALITY_SHRINK_BAND: f64 = 1e-1;
pub(crate) const IFT_QUALITY_FLAT_FALLBACK_BAND: f64 = 0.5;
pub(crate) const IFT_STEP_CAP_GROW_FACTOR: f64 = 1.5;
pub(crate) const IFT_STEP_CAP_SHRINK_FACTOR: f64 = 0.5;
// NOTE(review): KKT tolerances (primal / dual / complementarity /
// stationarity, judging by the names) — confirm against the solver.
pub(crate) const KKT_TOL_PRIMAL: f64 = 1e-7;
pub(crate) const KKT_TOL_DUAL: f64 = 1e-7;
pub(crate) const KKT_TOL_COMP: f64 = 1e-7;
pub(crate) const KKT_TOL_STAT: f64 = 5e-6;
pub(crate) const ACTIVE_CONSTRAINT_SLACK_TOL: f64 = 1e-8;
pub(crate) const ORTHONORM_DROP_TOL: f64 = 1e-10;
/// Result bundle of one ALO stabilization evaluation.
///
/// NOTE(review): the producer of this struct is not visible here; the field
/// remarks below are inferred from names and types only — confirm.
#[derive(Debug, Clone)]
pub(crate) struct AloStabilizationEval {
/// Scalar objective value.
pub(crate) cost: f64,
/// Gradient of `cost`, when one was computed.
pub(crate) gradient: Option<Array1<f64>>,
/// Optional diagnostic statistic (presumably a tail-shape estimate — TODO confirm).
pub(crate) k_hat: Option<f64>,
/// Presumably the largest leverage encountered — TODO confirm.
pub(crate) max_leverage: f64,
/// Presumably the smallest leave-one-out denominator encountered — TODO confirm.
pub(crate) min_denominator: f64,
}
// ALO stabilization thresholds. Only `ALO_MAX_LEVERAGE_THRESHOLD` and
// `ALO_DEVIANCE_SATURATION` are consumed by code visible in this file.
// NOTE(review): the remaining values are used elsewhere; meanings are inferred
// from names only.
pub(crate) const ALO_STABILIZATION_MIN_N: usize = 20;
pub(crate) const ALO_EDF_FRACTION_SATURATION: f64 = 0.70;
pub(crate) const ALO_PERVASIVE_LEVERAGE_FRACTION: f64 = 0.25;
pub(crate) const ALO_PARAMETRIC_LEVERAGE_SHARE: f64 = 0.75;
pub(crate) const ALO_DENOM_INSTABILITY_THRESHOLD: f64 = 0.20;
// Leverage above which `alo_leverage_barrier` becomes non-zero.
pub(crate) const ALO_MAX_LEVERAGE_THRESHOLD: f64 = 0.80;
pub(crate) const ALO_TAU: f64 = 0.5;
pub(crate) const ALO_GAMMA: f64 = 0.5;
// Scale of the tanh saturation in `gaussian_alo_deviance`.
pub(crate) const ALO_DEVIANCE_SATURATION: f64 = 9.0;
pub(crate) const ALO_GRADIENT_MAX_WORK: usize = 4_000_000;
/// Borrowed inputs bundled for ALO computations.
///
/// NOTE(review): how these fields are consumed is not visible here; the
/// remarks below are inferred from names — confirm against the users.
pub(crate) struct AloFactoredHessian<'a> {
/// Presumably the design matrix — TODO confirm.
pub(crate) x: &'a Array2<f64>,
/// Fit sensitivity object borrowed from `crate::solver::sensitivity`.
pub(crate) sensitivity: &'a crate::solver::sensitivity::FitSensitivity<'a>,
/// Presumably a precomputed H^{-1} X^T product — TODO confirm.
pub(crate) h_inv_xt: &'a Array2<f64>,
}
/// Quadratic penalty on leverage above `ALO_MAX_LEVERAGE_THRESHOLD`.
///
/// Returns `(h - threshold)^2` when `h` exceeds the threshold and `0.0`
/// otherwise (including for NaN input).
pub(crate) fn alo_leverage_barrier(h: f64) -> f64 {
    let overshoot = h - ALO_MAX_LEVERAGE_THRESHOLD;
    if overshoot > 0.0 {
        overshoot * overshoot
    } else {
        0.0
    }
}
/// Derivative of [`alo_leverage_barrier`] with respect to the leverage `h`.
///
/// Evaluates to `2 * (h - threshold)` strictly above
/// `ALO_MAX_LEVERAGE_THRESHOLD` and to `0.0` everywhere else.
pub(crate) fn alo_leverage_barrier_derivative(h: f64) -> f64 {
    let slope = 2.0 * (h - ALO_MAX_LEVERAGE_THRESHOLD);
    if h > ALO_MAX_LEVERAGE_THRESHOLD {
        slope
    } else {
        0.0
    }
}
/// Weighted squared leave-one-out residual divided by the dispersion.
///
/// Computes `prior_weight * (y - eta_loo)^2 / phi`, with `phi` clamped from
/// below at `f64::MIN_POSITIVE` so the division never sees zero or a negative
/// dispersion.
pub(crate) fn gaussian_alo_raw_deviance(y: f64, eta_loo: f64, prior_weight: f64, phi: f64) -> f64 {
    let dispersion = phi.max(f64::MIN_POSITIVE);
    let miss = y - eta_loo;
    prior_weight * miss * miss / dispersion
}
/// Saturated Gaussian ALO deviance.
///
/// The raw deviance from [`gaussian_alo_raw_deviance`] is squashed through
/// `S * tanh(raw / S)` with `S = ALO_DEVIANCE_SATURATION`, so the result is
/// bounded by `S` in magnitude.
pub(crate) fn gaussian_alo_deviance(y: f64, eta_loo: f64, prior_weight: f64, phi: f64) -> f64 {
    let unsaturated = gaussian_alo_raw_deviance(y, eta_loo, prior_weight, phi);
    let squashed = (unsaturated / ALO_DEVIANCE_SATURATION).tanh();
    ALO_DEVIANCE_SATURATION * squashed
}
/// Derivative of the tanh saturation with respect to the raw deviance.
///
/// For `f(raw) = S * tanh(raw / S)` this is `1 - tanh(raw / S)^2`, with
/// `S = ALO_DEVIANCE_SATURATION`.
pub(crate) fn gaussian_alo_deviance_saturation_factor(raw: f64) -> f64 {
    let scaled = raw / ALO_DEVIANCE_SATURATION;
    let squashed = scaled.tanh();
    1.0 - squashed * squashed
}
/// Applies one canonical penalty to a full coefficient vector.
///
/// The coefficients in `penalty.col_range` are centred by
/// `penalty.prior_mean`, multiplied by `penalty.local`, and written back into
/// the same column range of an otherwise zero vector of length `beta.len()`.
///
/// NOTE(review): slicing panics if `penalty.col_range` does not fit inside
/// `beta`; callers are assumed to pass a matching layout.
pub(crate) fn transformed_penalty_matvec(
    penalty: &crate::construction::CanonicalPenalty,
    beta: &Array1<f64>,
) -> Array1<f64> {
    let mut out = Array1::<f64>::zeros(beta.len());
    let beta_block = beta.slice(ndarray::s![penalty.col_range.clone()]);
    let centered = &beta_block - &penalty.prior_mean;
    // The original text had a mis-decoded `&cent;` entity here (`¢ered`).
    let local = penalty.local.dot(&centered);
    out.slice_mut(ndarray::s![penalty.col_range.clone()])
        .assign(&local);
    out
}
impl EvalShared {
    /// Returns the per-penalty score vectors at the current mode, computing
    /// and caching them on first use.
    ///
    /// Each score is `transformed_penalty_matvec(penalty, beta_hat)` with
    /// `beta_hat` taken from `self.pirls_result.beta_transformed`.
    ///
    /// # Errors
    /// Returns `EstimationError::LayoutError` when a cached result exists but
    /// holds a different number of vectors than `canonical_penalties`.
    pub(crate) fn canonical_penalty_scores_at_mode(
        &self,
        canonical_penalties: &[crate::construction::CanonicalPenalty],
    ) -> Result<Arc<Vec<Array1<f64>>>, EstimationError> {
        let requested = canonical_penalties.len();
        match self.penalty_scores_at_mode.get() {
            Some(cached) if cached.len() == requested => return Ok(Arc::clone(cached)),
            Some(cached) => {
                return Err(EstimationError::LayoutError(format!(
                    "shared penalty-score cache mismatch: cached {} score vectors, \
                     requested {} canonical penalties",
                    cached.len(),
                    requested
                )));
            }
            None => {}
        }

        let beta_hat = self.pirls_result.beta_transformed.as_ref();
        let computed: Vec<_> = canonical_penalties
            .iter()
            .map(|pen| transformed_penalty_matvec(pen, beta_hat))
            .collect();
        let fresh = Arc::new(computed);

        if self.penalty_scores_at_mode.set(Arc::clone(&fresh)).is_ok() {
            return Ok(fresh);
        }
        // Another caller populated the cell first; hand out its value instead.
        let winner = self
            .penalty_scores_at_mode
            .get()
            .expect("OnceLock set raced, so it is initialized");
        Ok(Arc::clone(winner))
    }
}
/// Lazily created process-wide map from a rho key to
/// `(residual energy, outer iteration at which it was stored)`.
pub(crate) static OUTER_IFT_RESIDUAL_ENERGY: OnceLock<Mutex<HashMap<Vec<u64>, (f64, u64)>>> =
    OnceLock::new();
/// Most recently recorded outer-iteration counter.
pub(crate) static OUTER_IFT_RESIDUAL_ENERGY_ITER: AtomicU64 = AtomicU64::new(0);

/// Returns the shared residual-energy map, creating it empty on first access.
pub(crate) fn outer_ift_residual_energy_cache() -> &'static Mutex<HashMap<Vec<u64>, (f64, u64)>> {
    OUTER_IFT_RESIDUAL_ENERGY.get_or_init(Default::default)
}
/// Publishes `iter` as the current outer-iteration counter (relaxed store).
pub(crate) fn record_current_outer_iter_for_ift(iter: u64) {
    let counter = &OUTER_IFT_RESIDUAL_ENERGY_ITER;
    counter.store(iter, Ordering::Relaxed);
}
/// Reads the outer-iteration counter last published through
/// `record_current_outer_iter_for_ift` (relaxed load).
pub(crate) fn current_outer_iter() -> u64 {
    let counter = &OUTER_IFT_RESIDUAL_ENERGY_ITER;
    counter.load(Ordering::Relaxed)
}
pub(crate) fn clear_outer_ift_residual_energy_for_fit() {
if let Some(cache) = OUTER_IFT_RESIDUAL_ENERGY.get()
&& let Ok(mut cache) = cache.lock()
{
cache.clear();
}
OUTER_IFT_RESIDUAL_ENERGY_ITER.store(0, Ordering::Relaxed);
}
/// Records, or forgets, the IFT residual energy associated with `theta`.
///
/// Nothing happens when `theta` has no sanitized rho key or the cache lock
/// cannot be taken. A finite, non-negative `energy` is stored together with
/// the current outer iteration; any other value removes the existing entry.
pub(crate) fn store_ift_residual_energy_for_outer_theta(theta: &Array1<f64>, energy: Option<f64>) {
    let Some(key) = super::rho_key::sanitized_rhokey(theta) else {
        return;
    };
    let Ok(mut entries) = outer_ift_residual_energy_cache().lock() else {
        return;
    };
    match energy {
        Some(value) if value.is_finite() && value >= 0.0 => {
            entries.insert(key, (value, current_outer_iter()));
        }
        _ => {
            entries.remove(&key);
        }
    }
}
/// Summary of a penalty's spectrum.
///
/// NOTE(review): producers and consumers are not visible here; field remarks
/// are inferred from names — confirm.
pub(crate) struct PenaltySubspace {
/// Presumably the eigenvalues of the penalty — TODO confirm.
pub(crate) evals: Array1<f64>,
/// Presumably the numerical rank of the penalty — TODO confirm.
pub(crate) rank: usize,
}
/// One recorded outer-gradient evaluation kept in `HyperGradientBudget::history`.
pub(crate) struct HyperGradHistoryEntry {
/// Outer parameter vector; consecutive entries are differenced in
/// `secant_gradient_pairs`.
pub(crate) rho: Array1<f64>,
/// Outer gradient at `rho`; also differenced in `secant_gradient_pairs`.
pub(crate) g_outer: Array1<f64>,
/// Energy read for the "inner" channel by `reestimate_sensitivities`.
pub(crate) e_inner: f64,
/// Energy read for the "linear" channel by `reestimate_sensitivities`.
pub(crate) e_linear: f64,
/// Must be finite and non-negative for `estimate_trace_sensitivity` to use the entry.
pub(crate) sigma_sq: f64,
/// Grouping key in `estimate_trace_sensitivity`: only the trailing run of
/// entries sharing the latest `k` is used (presumably a probe count — TODO confirm).
pub(crate) k: usize,
}
/// Error budget for the outer (hyper-)gradient, split over three channels
/// (inner, linear, trace).
pub(crate) struct HyperGradientBudget {
/// Total budget distributed by `allocate_with_sensitivities`.
pub(crate) target_mse: f64,
/// Mandatory minimum allocation of the inner channel.
pub(crate) inner_floor: f64,
/// Mandatory minimum allocation of the linear channel.
pub(crate) linear_floor: f64,
/// Mandatory minimum allocation of the trace channel.
pub(crate) trace_floor: f64,
/// Latest sensitivity estimate of the inner channel.
pub(crate) s_inner: f64,
/// Latest sensitivity estimate of the linear channel.
pub(crate) s_linear: f64,
/// Latest sensitivity estimate of the trace channel.
pub(crate) s_trace: f64,
/// Recent evaluations, oldest first; `push` caps it at `HGB_HISTORY_CAP`.
pub(crate) history: VecDeque<HyperGradHistoryEntry>,
/// Recent `[inner, linear, trace]` sensitivities, capped at `HGB_WARMUP_ITERS_MIN`.
pub(crate) sensitivity_history: VecDeque<[f64; 3]>,
/// Starts as `false`; NOTE(review): it is toggled outside the code visible here.
pub(crate) warmup_engaged: bool,
}
impl HyperGradientBudget {
/// Creates a budget with a zero target, the default floors and initial
/// sensitivities, and empty (pre-sized) histories.
pub(crate) fn new() -> Self {
    let history = VecDeque::with_capacity(HGB_HISTORY_CAP);
    let sensitivity_history = VecDeque::with_capacity(HGB_WARMUP_ITERS_MIN);
    Self {
        target_mse: 0.0,
        inner_floor: HGB_INNER_FLOOR,
        linear_floor: HGB_LINEAR_FLOOR,
        trace_floor: HGB_TRACE_FLOOR,
        s_inner: S_INNER_INIT,
        s_linear: S_LINEAR_INIT,
        s_trace: S_TRACE_INIT,
        history,
        sensitivity_history,
        warmup_engaged: false,
    }
}
/// Appends `entry` to the history and evicts the oldest entries so that at
/// most `HGB_HISTORY_CAP` remain.
pub(crate) fn push(&mut self, entry: HyperGradHistoryEntry) {
    self.history.push_back(entry);
    let overflow = self.history.len().saturating_sub(HGB_HISTORY_CAP);
    for _ in 0..overflow {
        self.history.pop_front();
    }
}
/// Euclidean norm of the second-most-recent outer gradient.
///
/// Falls back to the only entry when the history holds a single one, and to
/// `0.0` when the history is empty or the norm is not finite.
pub(crate) fn previous_gradient_norm(&self) -> f64 {
    let recorded = self.history.len();
    let reference = if recorded >= 2 {
        self.history.get(recorded - 2)
    } else {
        self.history.back()
    };
    match reference.map(|entry| l2_norm(&entry.g_outer)) {
        Some(norm) if norm.is_finite() => norm,
        _ => 0.0,
    }
}
/// Re-estimates the three channel sensitivities from the recorded history.
///
/// Returns `None`, leaving all state untouched, when fewer than
/// `HGB_MIN_PAIRS_FOR_SENSITIVITY` secant pairs are available or when any of
/// the three estimators yields nothing. Otherwise each estimate is clamped
/// from below at `HGB_SENS_FLOOR`, stored in `s_inner` / `s_linear` /
/// `s_trace`, appended to `sensitivity_history` (kept to the last
/// `HGB_WARMUP_ITERS_MIN` entries) and returned as `[inner, linear, trace]`.
pub(crate) fn reestimate_sensitivities(&mut self) -> Option<[f64; 3]> {
    let pairs = self.secant_gradient_pairs();
    let available = pairs.len();
    if available < HGB_MIN_PAIRS_FOR_SENSITIVITY {
        log::info!(
            "[HGB] small-sample fallback to defaults: pairs={}, threshold={}",
            available,
            HGB_MIN_PAIRS_FOR_SENSITIVITY
        );
        return None;
    }

    // Evaluated strictly in this order; the first missing estimate aborts.
    let raw = [
        self.estimate_energy_sensitivity(&pairs, |entry| entry.e_inner)?,
        self.estimate_energy_sensitivity(&pairs, |entry| entry.e_linear)?,
        self.estimate_trace_sensitivity()?,
    ];
    let floored = raw.map(|estimate| estimate.max(HGB_SENS_FLOOR));
    let [inner, linear, trace] = floored;
    self.s_inner = inner;
    self.s_linear = linear;
    self.s_trace = trace;

    self.sensitivity_history.push_back(floored);
    let overflow = self
        .sensitivity_history
        .len()
        .saturating_sub(HGB_WARMUP_ITERS_MIN);
    for _ in 0..overflow {
        self.sensitivity_history.pop_front();
    }
    Some(floored)
}
/// Leave-one-out estimate of how strongly an energy term perturbs the outer
/// gradient.
///
/// For every secant pair `i`, a ridge-regularised least-squares map from
/// `drho` to `dg` is fitted on all *other* pairs of matching dimensions, the
/// held-out `dg_i` is predicted from `drho_i`, and the norm of the prediction
/// residual is divided by `sqrt(2 * max(e0, e1, 1e-300))`, where `e0` / `e1`
/// are `energy` evaluated on the two history entries that formed the pair.
/// The result is the mean of all finite, positive per-pair estimates, or
/// `None` when there are none.
///
/// `pairs` is expected to come from `secant_gradient_pairs` on the current
/// history: the third tuple element is used to index `self.history` directly
/// (`left_idx` and `left_idx + 1`), which panics if it is out of range.
pub(crate) fn estimate_energy_sensitivity<F>(
&self,
pairs: &[(Array1<f64>, Array1<f64>, usize)],
energy: F,
) -> Option<f64>
where
F: Fn(&HyperGradHistoryEntry) -> f64,
{
let mut estimates = Vec::new();
for i in 0..pairs.len() {
let (drho_i, dg_i, left_idx) = &pairs[i];
let rho_dim = drho_i.len();
let grad_dim = dg_i.len();
if rho_dim == 0 || grad_dim == 0 {
continue;
}
// Normal equations X^T X (seeded with the ridge on the diagonal) and X^T Y.
let mut xtx = Array2::<f64>::zeros((rho_dim, rho_dim));
for d in 0..rho_dim {
xtx[[d, d]] = HGB_REGRESSION_RIDGE;
}
let mut xty = Array2::<f64>::zeros((rho_dim, grad_dim));
let mut fit_pairs = 0usize;
for (j, (drho_j, dg_j, _)) in pairs.iter().enumerate() {
// Skip the held-out pair and any pair with different dimensions.
if i == j || drho_j.len() != rho_dim || dg_j.len() != grad_dim {
continue;
}
if drho_j.iter().any(|v| !v.is_finite()) || dg_j.iter().any(|v| !v.is_finite()) {
continue;
}
for row in 0..rho_dim {
for col in 0..rho_dim {
xtx[[row, col]] += drho_j[row] * drho_j[col];
}
for grad in 0..grad_dim {
xty[[row, grad]] += drho_j[row] * dg_j[grad];
}
}
fit_pairs += 1;
}
if fit_pairs == 0 {
continue;
}
// A failed factorisation simply drops this held-out pair.
let Ok(chol) = xtx.cholesky(Side::Lower) else {
continue;
};
// After the solve `xty` holds the fitted coefficient matrix.
chol.solve_mat_in_place(&mut xty);
let mut predicted = Array1::<f64>::zeros(grad_dim);
for grad in 0..grad_dim {
let mut value = 0.0;
for rho in 0..rho_dim {
value += drho_i[rho] * xty[[rho, grad]];
}
predicted[grad] = value;
}
let residual = dg_i - &predicted;
let e0 = energy(&self.history[*left_idx]);
let e1 = energy(&self.history[*left_idx + 1]);
// The 1e-300 floor keeps the denominator strictly positive.
let denom_energy = e0.max(e1).max(1e-300);
if !denom_energy.is_finite() || denom_energy < 0.0 {
continue;
}
let estimate = l2_norm(&residual) / (2.0 * denom_energy).sqrt();
if estimate.is_finite() && estimate > 0.0 {
estimates.push(estimate);
}
}
mean_positive(&estimates)
}
/// Builds secant pairs `(drho, dg, left_index)` from consecutive history
/// entries.
///
/// A pair is kept only when both entries have matching dimensions, both
/// differences are entirely finite, and the squared step length `|drho|^2`
/// lies in `(0, HGB_SECANT_DRHO_MAX_SQUARED]`. `left_index` is the position
/// of the earlier entry in `history`.
pub(crate) fn secant_gradient_pairs(&self) -> Vec<(Array1<f64>, Array1<f64>, usize)> {
    let successors = self.history.iter().skip(1);
    let mut secants = Vec::new();
    for (left_idx, (earlier, later)) in self.history.iter().zip(successors).enumerate() {
        let same_shape = earlier.rho.len() == later.rho.len()
            && earlier.g_outer.len() == later.g_outer.len();
        if !same_shape {
            continue;
        }
        let drho = &later.rho - &earlier.rho;
        let dg = &later.g_outer - &earlier.g_outer;
        let step_squared = drho.dot(&drho);
        let all_finite =
            drho.iter().all(|v| v.is_finite()) && dg.iter().all(|v| v.is_finite());
        let step_in_range = step_squared > 0.0 && step_squared <= HGB_SECANT_DRHO_MAX_SQUARED;
        if all_finite && step_in_range {
            secants.push((drho, dg, left_idx));
        }
    }
    secants
}
/// Estimates the trace-channel sensitivity as the pooled standard deviation
/// of recent outer gradients.
///
/// Only the trailing run of history entries sharing the latest entry's `k` is
/// used. Returns `None` when the history is empty, that `k` is zero, the run
/// is shorter than `HGB_WARMUP_ITERS_MIN`, any `sigma_sq` in the run is
/// non-finite or negative, the gradients are empty or differ in length, or
/// the resulting standard deviation is not finite and positive.
pub(crate) fn estimate_trace_sensitivity(&self) -> Option<f64> {
let last_k = self.history.back()?.k;
if last_k == 0 {
return None;
}
// Newest-first run of entries recorded with the same `k`.
let fixed: Vec<&HyperGradHistoryEntry> = self
.history
.iter()
.rev()
.take_while(|entry| entry.k == last_k)
.collect();
if fixed.len() < HGB_WARMUP_ITERS_MIN {
return None;
}
if fixed
.iter()
.any(|entry| !entry.sigma_sq.is_finite() || entry.sigma_sq < 0.0)
{
return None;
}
let dim = fixed[0].g_outer.len();
if dim == 0 || fixed.iter().any(|entry| entry.g_outer.len() != dim) {
return None;
}
// Component-wise mean gradient over the run.
let mut means = Array1::<f64>::zeros(dim);
for entry in fixed.iter() {
means += &entry.g_outer;
}
means /= fixed.len() as f64;
let mut variance_sum = 0.0;
for entry in fixed.iter() {
let diff = &entry.g_outer - &means;
variance_sum += diff.dot(&diff);
}
// Sample variance pooled over all components: (n - 1) * dim degrees of freedom.
let denom = ((fixed.len() - 1) * dim) as f64;
let std = (variance_sum / denom).max(0.0).sqrt();
(std.is_finite() && std > 0.0).then_some(std)
}
/// Splits `target_mse` across the inner, linear and trace channels.
///
/// Every channel first receives its mandatory floor; the remainder is shared
/// in proportion to the squared sensitivities (the weight sum is clamped from
/// below at `HGB_SENS_FLOOR^2`). When nothing usable remains — the remainder
/// is non-positive or non-finite — a warning is logged and each channel gets
/// exactly its floor.
///
/// Returns `(inner, linear, trace, at_floor)`, where `at_floor` is all `true`
/// in the fallback case and all `false` otherwise.
pub(crate) fn allocate_with_sensitivities(
    &self,
    s_inner: f64,
    s_linear: f64,
    s_trace: f64,
) -> (f64, f64, f64, [bool; 3]) {
    let floors = self.inner_floor + self.linear_floor + self.trace_floor;
    let usable = self.target_mse - floors;
    let has_headroom = usable > 0.0 && usable.is_finite();
    if !has_headroom {
        log::warn!(
            "[HGB] target_mse below mandatory floors; target_mse={:.3e} floors={:.3e}",
            self.target_mse,
            floors
        );
        return (
            self.inner_floor,
            self.linear_floor,
            self.trace_floor,
            [true; 3],
        );
    }

    let weight_inner = s_inner * s_inner;
    let weight_linear = s_linear * s_linear;
    let weight_trace = s_trace * s_trace;
    let weight_total =
        (weight_inner + weight_linear + weight_trace).max(HGB_SENS_FLOOR * HGB_SENS_FLOOR);

    let inner = self.inner_floor + usable * weight_inner / weight_total;
    let linear = self.linear_floor + usable * weight_linear / weight_total;
    let trace = self.trace_floor + usable * weight_trace / weight_total;
    (inner, linear, trace, [false; 3])
}
/// Reports whether the recent sensitivity estimates have settled.
///
/// Requires at least `HGB_WARMUP_ITERS_MIN` recorded triples. For each of the
/// three channels every recorded value must be finite and positive, and the
/// ratio of the largest to the smallest value must stay below
/// `HGB_SENS_STABILITY_RATIO`.
pub(crate) fn sensitivities_stable(&self) -> bool {
    if self.sensitivity_history.len() < HGB_WARMUP_ITERS_MIN {
        return false;
    }
    (0..3).all(|channel| {
        let mut smallest = f64::INFINITY;
        let mut largest = 0.0_f64;
        for triple in &self.sensitivity_history {
            let value = triple[channel];
            let usable = value.is_finite() && value > 0.0;
            if !usable {
                return false;
            }
            smallest = smallest.min(value);
            largest = largest.max(value);
        }
        let spread_too_wide = largest / smallest >= HGB_SENS_STABILITY_RATIO;
        !spread_too_wide
    })
}
}
/// Mutable hyper-gradient state stored per `usize` key in `HYPERGRADIENT_BUDGETS`.
pub(crate) struct HyperGradientRuntimeState {
/// Error budget and its history.
pub(crate) budget: HyperGradientBudget,
/// Optional override, `None` initially (presumably an adaptive KKT tolerance — TODO confirm).
pub(crate) adaptive_kkt_override: Option<f64>,
/// Shared, lock-protected stochastic trace state.
pub(crate) trace_state: Arc<Mutex<super::reml_outer_engine::StochasticTraceState>>,
}
impl HyperGradientRuntimeState {
    /// Creates a state with a fresh budget, no KKT override and a default
    /// stochastic trace state behind a new `Arc<Mutex<_>>`.
    pub(crate) fn new() -> Self {
        let budget = HyperGradientBudget::new();
        let initial_trace = super::reml_outer_engine::StochasticTraceState::default();
        let trace_state = Arc::new(Mutex::new(initial_trace));
        Self {
            budget,
            adaptive_kkt_override: None,
            trace_state,
        }
    }
}
/// Lazily created process-wide map of hyper-gradient runtime states.
///
/// NOTE(review): the meaning of the `usize` key is decided by callers not
/// visible here.
pub(crate) static HYPERGRADIENT_BUDGETS: OnceLock<
    Mutex<HashMap<usize, HyperGradientRuntimeState>>,
> = OnceLock::new();

/// Returns the shared hyper-gradient state map, creating it empty on first access.
pub(crate) fn hypergradient_budgets() -> &'static Mutex<HashMap<usize, HyperGradientRuntimeState>> {
    HYPERGRADIENT_BUDGETS.get_or_init(Default::default)
}
/// Per-key IFT quality tracking state stored in `IFT_QUALITY_STATES`.
///
/// NOTE(review): the logic that updates these fields is not visible here;
/// remarks are inferred from names — confirm.
#[derive(Default)]
pub(crate) struct IftQualityRuntimeState {
/// Recent quality measurements.
pub(crate) quality_history: Vec<f64>,
/// Step cap to apply next, if any.
pub(crate) next_step_cap: Option<f64>,
/// Presumably requests a flat fallback on the next step — TODO confirm.
pub(crate) fallback_next_flat: bool,
}
/// Lazily created process-wide map of IFT quality states keyed by `usize`.
pub(crate) static IFT_QUALITY_STATES: OnceLock<Mutex<HashMap<usize, IftQualityRuntimeState>>> =
    OnceLock::new();

/// Returns the shared IFT quality map, creating it empty on first access.
pub(crate) fn ift_quality_states() -> &'static Mutex<HashMap<usize, IftQualityRuntimeState>> {
    IFT_QUALITY_STATES.get_or_init(Default::default)
}
/// Cached IFT mode-response data stored in `IFT_MODE_RESPONSE_CACHES`.
///
/// NOTE(review): producers and consumers are not visible here; remarks are
/// inferred from names — confirm.
#[derive(Clone)]
pub(crate) struct IftModeResponseRuntimeCache {
/// Rho vector associated with the cached responses.
pub(crate) rho: Array1<f64>,
/// Mode-response columns for the rho coordinates, when available.
pub(crate) rho_mode_response_cols: Option<Array2<f64>>,
/// Mode-response columns for the remaining ("ext") coordinates, when available.
pub(crate) ext_mode_response_cols: Option<Array2<f64>>,
}
/// Lazily created process-wide map of IFT mode-response caches keyed by `usize`.
pub(crate) static IFT_MODE_RESPONSE_CACHES: OnceLock<
    Mutex<HashMap<usize, IftModeResponseRuntimeCache>>,
> = OnceLock::new();

/// Returns the shared mode-response cache map, creating it empty on first access.
pub(crate) fn ift_mode_response_caches()
-> &'static Mutex<HashMap<usize, IftModeResponseRuntimeCache>> {
    IFT_MODE_RESPONSE_CACHES.get_or_init(Default::default)
}
/// Cached joint IFT mode-response data stored in `IFT_JOINT_MODE_RESPONSE_CACHES`.
#[derive(Clone)]
pub(crate) struct IftJointModeResponseRuntimeCache {
/// Full outer vector of the cached evaluation; `joint_ift_cache_matches_theta`
/// compares its entries from index `rho_dim` onwards.
pub(crate) theta: Array1<f64>,
/// Number of leading entries of `theta` that are rho coordinates.
pub(crate) rho_dim: usize,
/// NOTE(review): presumably the coefficients in the original basis — confirm.
pub(crate) beta_original: Array1<f64>,
/// Cached mode-response columns.
pub(crate) mode_response_cols: Array2<f64>,
/// NOTE(review): presumably whether constraints were active for the cached fit — confirm.
pub(crate) active_constraints: bool,
}
/// Lazily created process-wide map of joint IFT mode-response caches keyed by `usize`.
pub(crate) static IFT_JOINT_MODE_RESPONSE_CACHES: OnceLock<
    Mutex<HashMap<usize, IftJointModeResponseRuntimeCache>>,
> = OnceLock::new();

/// Returns the shared joint mode-response cache map, creating it empty on first access.
pub(crate) fn ift_joint_mode_response_caches()
-> &'static Mutex<HashMap<usize, IftJointModeResponseRuntimeCache>> {
    IFT_JOINT_MODE_RESPONSE_CACHES.get_or_init(Default::default)
}
/// Decides whether a cached joint mode response applies to `theta`.
///
/// The cache matches when it has at least one coordinate beyond its
/// `rho_dim` rho entries, `theta` has the cached length, `new_rho` has
/// `rho_dim` entries, the leading `rho_dim` entries of `theta` are
/// bit-identical to `new_rho`, and the remaining entries of `theta` are
/// bit-identical to the cached `theta`.
pub(crate) fn joint_ift_cache_matches_theta(
    cache: &IftJointModeResponseRuntimeCache,
    theta: &Array1<f64>,
    new_rho: &Array1<f64>,
) -> bool {
    let rho_dim = cache.rho_dim;
    let total = cache.theta.len();
    let layout_ok = total > rho_dim && theta.len() == total && new_rho.len() == rho_dim;
    if !layout_ok {
        return false;
    }
    let same_bits = |left: f64, right: f64| left.to_bits() == right.to_bits();
    let rho_part_matches = (0..rho_dim).all(|i| same_bits(theta[i], new_rho[i]));
    rho_part_matches && (rho_dim..total).all(|i| same_bits(theta[i], cache.theta[i]))
}
// Per-thread slots holding the most recently recorded outer theta vector and
// rho upper bounds. Both start as `None`; they are written by the
// `record_current_outer_*_for_ift` functions and read by the `latest_outer_*_for_ift`
// functions in this file.
thread_local! {
pub(crate) static IFT_LATEST_OUTER_THETA: std::cell::RefCell<Option<Array1<f64>>> =
const { std::cell::RefCell::new(None) };
pub(crate) static IFT_LATEST_OUTER_RHO_UPPER_BOUNDS: std::cell::RefCell<Option<Array1<f64>>> =
const { std::cell::RefCell::new(None) };
}
/// Remembers `theta` as this thread's latest outer parameter vector.
///
/// An empty vector, or one containing a non-finite entry, clears the slot
/// instead.
pub(crate) fn record_current_outer_theta_for_ift(theta: &Array1<f64>) {
    let usable = !theta.is_empty() && theta.iter().all(|v| v.is_finite());
    let snapshot = usable.then(|| theta.clone());
    IFT_LATEST_OUTER_THETA.with_borrow_mut(|slot| *slot = snapshot);
}
/// Remembers `upper` as this thread's latest rho upper bounds.
///
/// An empty vector, or one containing a non-finite entry, clears the slot
/// instead.
pub(crate) fn record_current_outer_rho_upper_bounds_for_ift(upper: &Array1<f64>) {
    let usable = !upper.is_empty() && upper.iter().all(|v| v.is_finite());
    let snapshot = usable.then(|| upper.clone());
    IFT_LATEST_OUTER_RHO_UPPER_BOUNDS.with_borrow_mut(|slot| *slot = snapshot);
}
/// Returns a copy of this thread's most recently recorded rho upper bounds,
/// or `None` when nothing usable has been recorded.
pub(crate) fn latest_outer_rho_upper_bounds_for_ift() -> Option<Array1<f64>> {
    IFT_LATEST_OUTER_RHO_UPPER_BOUNDS.with_borrow(|latest| latest.clone())
}
/// Returns a copy of this thread's most recently recorded outer theta, or
/// `None` when nothing usable has been recorded.
pub(crate) fn latest_outer_theta_for_ift() -> Option<Array1<f64>> {
    IFT_LATEST_OUTER_THETA.with_borrow(|latest| latest.clone())
}
/// Euclidean norm of `values`: the square root of the sum of squared entries.
pub(crate) fn l2_norm(values: &Array1<f64>) -> f64 {
    let squared_total: f64 = values.iter().map(|&entry| entry * entry).sum();
    squared_total.sqrt()
}
/// Arithmetic mean of the finite, strictly positive entries of `values`.
///
/// Returns `None` when no entry qualifies (including for an empty slice).
pub(crate) fn mean_positive(values: &[f64]) -> Option<f64> {
    let (total, kept) = values
        .iter()
        .copied()
        .filter(|value| value.is_finite() && *value > 0.0)
        .fold((0.0_f64, 0usize), |(total, kept), value| {
            (total + value, kept + 1)
        });
    (kept > 0).then_some(total / kept as f64)
}
/// State held by `EFS_SINGLE_LOOP_BIAS_GUARD`.
///
/// NOTE(review): the guard logic is not visible here; remarks are inferred
/// from names — confirm.
#[derive(Default)]
pub(crate) struct EfsSingleLoopBiasGuardState {
/// Identifier of the current owner of the guard.
pub(crate) owner: usize,
/// Presumably a count of consecutive bias-threshold violations — TODO confirm.
pub(crate) consecutive: usize,
}
/// Process-wide, lazily initialised bias-guard state behind a mutex; it starts
/// from `EfsSingleLoopBiasGuardState::default()`.
pub(crate) static EFS_SINGLE_LOOP_BIAS_GUARD: LazyLock<Mutex<EfsSingleLoopBiasGuardState>> =
LazyLock::new(|| Mutex::new(EfsSingleLoopBiasGuardState::default()));
/// Reports whether a gradient is needed for `mode`: true for every mode
/// except `EvalMode::ValueOnly`.
#[inline]
pub(crate) fn compute_gradient_for_tk(mode: super::reml_outer_engine::EvalMode) -> bool {
    let value_only = super::reml_outer_engine::EvalMode::ValueOnly;
    mode != value_only
}
/// Encodes the single-loop sweep count as a cap value: the sentinel offset
/// plus `EFS_SINGLE_LOOP_PIRLS_SWEEPS`. `decode_efs_single_loop_cap` reverses
/// the encoding.
#[inline]
pub(crate) fn efs_single_loop_encoded_cap() -> usize {
    let sentinel = EFS_SINGLE_LOOP_PIRLS_CAP_SENTINEL;
    let sweeps = EFS_SINGLE_LOOP_PIRLS_SWEEPS;
    sentinel + sweeps
}
#[inline]
pub(crate) fn decode_efs_single_loop_cap(raw_cap: usize) -> Option<usize> {
(raw_cap >= EFS_SINGLE_LOOP_PIRLS_CAP_SENTINEL)
.then(|| raw_cap - EFS_SINGLE_LOOP_PIRLS_CAP_SENTINEL)
.filter(|cap| *cap > 0)
}
/// Forwards `cost` and the inner fit's status to
/// `failed_inner_residual_barrier_cost`.
///
/// The helper receives the cost, whether the PIRLS run failed by hitting its
/// iteration limit, and the run's relative gradient norm.
#[inline]
pub(crate) fn screening_residual_penalty(cost: f64, pr: &PirlsResult) -> f64 {
    let hit_iteration_limit = pr.status.is_failed_max_iterations();
    let relative_gradient = pr.relative_gradient_norm();
    crate::solver::objective_base::failed_inner_residual_barrier_cost(
        cost,
        hit_iteration_limit,
        relative_gradient,
    )
}
/// Feeds a 1-D view into `hasher`: first its length, then every element in order.
pub(crate) fn hash_array_view(hasher: &mut Fingerprinter, values: ndarray::ArrayView1<'_, f64>) {
    hasher.write_usize(values.len());
    values.iter().for_each(|&entry| {
        hasher.write_f64(entry);
    });
}
/// Feeds a 2-D array into `hasher`: row count, column count, then every
/// element in the array's logical iteration order.
pub(crate) fn hash_array2(hasher: &mut Fingerprinter, values: &Array2<f64>) {
    let (rows, cols) = values.dim();
    hasher.write_usize(rows);
    hasher.write_usize(cols);
    values.iter().for_each(|&entry| {
        hasher.write_f64(entry);
    });
}
/// Feeds an auxiliary-prior strength into `hasher`: the tag `"auto"`, or the
/// tag `"fixed"` followed by the fixed value.
pub(crate) fn hash_aux_prior_strength(
    hasher: &mut Fingerprinter,
    strength: crate::terms::latent::AuxPriorStrength,
) {
    use crate::terms::latent::AuxPriorStrength;
    let (tag, fixed_value) = match strength {
        AuxPriorStrength::Fixed(value) => ("fixed", Some(value)),
        AuxPriorStrength::Auto => ("auto", None),
    };
    hasher.write_str(tag);
    if let Some(value) = fixed_value {
        hasher.write_f64(value);
    }
}
/// Computes a 64-bit fingerprint of a latent identification mode for cache keying.
///
/// The hash starts with the version tag `"latent-id-mode-cache-v1"`, followed
/// by a variant tag and the variant's hashed payload. Fields skipped with
/// `..` in the patterns below do not contribute to the fingerprint.
pub(in crate::solver::estimate) fn latent_id_mode_cache_fingerprint(
    id_mode: &crate::terms::latent::LatentIdMode,
) -> u64 {
    use crate::terms::latent::{AuxPriorFamily, LatentIdMode};

    // Stable textual tag for the auxiliary-prior family.
    let family_tag = |family: &AuxPriorFamily| match family {
        AuxPriorFamily::Ridge => "ridge",
        AuxPriorFamily::Linear => "linear",
    };

    let mut hasher = Fingerprinter::new();
    hasher.write_str("latent-id-mode-cache-v1");
    match id_mode {
        LatentIdMode::None => hasher.write_str("none"),
        LatentIdMode::DimSelection { .. } => hasher.write_str("dim-selection"),
        LatentIdMode::AuxPrior {
            u,
            family,
            strength,
        } => {
            hasher.write_str("aux-prior");
            hash_array2(&mut hasher, u);
            hasher.write_str(family_tag(family));
            hash_aux_prior_strength(&mut hasher, *strength);
        }
        LatentIdMode::AuxPriorDimSelection {
            u,
            family,
            strength,
            ..
        } => {
            hasher.write_str("aux-prior-dim-selection");
            hash_array2(&mut hasher, u);
            hasher.write_str(family_tag(family));
            hash_aux_prior_strength(&mut hasher, *strength);
        }
        LatentIdMode::IsometryToReference {
            reference,
            strength,
        } => {
            hasher.write_str("isometry-to-reference");
            hash_array2(&mut hasher, reference);
            hash_aux_prior_strength(&mut hasher, *strength);
        }
        LatentIdMode::AuxOutcome { head, .. } => {
            use crate::terms::decoders::behavioral_head::AuxOutcomeFamily;
            hasher.write_str("aux-outcome");
            match head.family() {
                AuxOutcomeFamily::Binomial => hasher.write_str("binomial"),
                AuxOutcomeFamily::Multinomial { n_classes } => {
                    hasher.write_str("multinomial");
                    hasher.write_usize(n_classes);
                }
            }
            hasher.write_usize(head.n_obs());
            hasher.write_f64(head.effective_labeled_count());
        }
    }
    hasher.finish_u64()
}
/// Feeds a 3-D array into `hasher`: its three axis lengths, then every
/// element in the array's logical iteration order.
pub(crate) fn hash_array3(hasher: &mut Fingerprinter, values: &ndarray::Array3<f64>) {
    for &extent in values.shape() {
        hasher.write_usize(extent);
    }
    values.iter().for_each(|&entry| {
        hasher.write_f64(entry);
    });
}
/// Feeds a `PsiSlice` into `hasher`: range start, range end, a presence flag
/// for `latent_dim`, and the latent dimension itself when present.
pub(crate) fn hash_psi_slice(
    hasher: &mut Fingerprinter,
    target: &crate::terms::analytic_penalties::PsiSlice,
) {
    let span = &target.range;
    hasher.write_usize(span.start);
    hasher.write_usize(span.end);
    hasher.write_bool(target.latent_dim.is_some());
    if let Some(latent_dim) = target.latent_dim {
        hasher.write_usize(latent_dim);
    }
}
/// Feeds a scalar weight schedule into `hasher`: start weight, end weight,
/// the schedule-kind tag with its parameter (if any), and the iteration count.
pub(crate) fn hash_scalar_weight_schedule(
    hasher: &mut Fingerprinter,
    schedule: &crate::terms::analytic_penalties::ScalarWeightSchedule,
) {
    use crate::terms::sae::manifold::ScheduleKind;
    hasher.write_f64(schedule.w_start);
    hasher.write_f64(schedule.w_end);

    let kind = &schedule.kind;
    let tag = match kind {
        ScheduleKind::Geometric { .. } => "geometric",
        ScheduleKind::Linear { .. } => "linear",
        ScheduleKind::ReciprocalIter => "reciprocal-iter",
    };
    hasher.write_str(tag);
    match kind {
        ScheduleKind::Geometric { rate } => {
            hasher.write_f64(*rate);
        }
        ScheduleKind::Linear { steps } => {
            hasher.write_usize(*steps);
        }
        ScheduleKind::ReciprocalIter => {}
    }

    hasher.write_usize(schedule.iter_count);
}
/// Feeds an optional scalar weight schedule into `hasher`: a presence flag,
/// followed by the schedule itself when present.
pub(crate) fn hash_weight_schedule_option(
    hasher: &mut Fingerprinter,
    schedule: &Option<crate::terms::analytic_penalties::ScalarWeightSchedule>,
) {
    hasher.write_bool(schedule.is_some());
    if let Some(present) = schedule {
        hash_scalar_weight_schedule(hasher, present);
    }
}
/// Feeds a Gumbel temperature schedule into `hasher`: start temperature,
/// minimum temperature, the decay-kind tag with its parameter (if any), and
/// the iteration count.
pub(crate) fn hash_gumbel_temperature_schedule(
    hasher: &mut Fingerprinter,
    schedule: &crate::terms::sae::manifold::GumbelTemperatureSchedule,
) {
    use crate::terms::sae::manifold::ScheduleKind;
    hasher.write_f64(schedule.tau_start);
    hasher.write_f64(schedule.tau_min);

    let decay = &schedule.decay;
    let tag = match decay {
        ScheduleKind::Geometric { .. } => "geometric",
        ScheduleKind::Linear { .. } => "linear",
        ScheduleKind::ReciprocalIter => "reciprocal-iter",
    };
    hasher.write_str(tag);
    match decay {
        ScheduleKind::Geometric { rate } => {
            hasher.write_f64(*rate);
        }
        ScheduleKind::Linear { steps } => {
            hasher.write_usize(*steps);
        }
        ScheduleKind::ReciprocalIter => {}
    }

    hasher.write_usize(schedule.iter_count);
}
/// Feeds an optional Gumbel temperature schedule into `hasher`: a presence
/// flag, followed by the schedule itself when present.
pub(crate) fn hash_gumbel_schedule_option(
    hasher: &mut Fingerprinter,
    schedule: &Option<crate::terms::sae::manifold::GumbelTemperatureSchedule>,
) {
    hasher.write_bool(schedule.is_some());
    if let Some(present) = schedule {
        hash_gumbel_temperature_schedule(hasher, present);
    }
}
/// Feeds an isometry reference into `hasher`: the tag `"euclidean"`, or the
/// tag `"user-supplied"` followed by the supplied matrix.
pub(crate) fn hash_isometry_reference(
    hasher: &mut Fingerprinter,
    reference: &crate::terms::analytic_penalties::IsometryReference,
) {
    use crate::terms::analytic_penalties::IsometryReference;
    match reference {
        IsometryReference::UserSupplied(matrix) => {
            hasher.write_str("user-supplied");
            hash_array2(hasher, matrix.as_ref());
        }
        IsometryReference::Euclidean => {
            hasher.write_str("euclidean");
        }
    }
}
/// Feeds a weight field into `hasher`: the tag `"identity"`, or the tag
/// `"factored"` followed by the factor matrix, its rank and the output dimension.
pub(crate) fn hash_weight_field(
    hasher: &mut Fingerprinter,
    field: &crate::terms::analytic_penalties::WeightField,
) {
    use crate::terms::analytic_penalties::WeightField;
    match field {
        WeightField::Factored { u, rank, p_out } => {
            hasher.write_str("factored");
            hash_array2(hasher, u.as_ref());
            hasher.write_usize(*rank);
            hasher.write_usize(*p_out);
        }
        WeightField::Identity => {
            hasher.write_str("identity");
        }
    }
}
/// Feeds a sparsity kind into `hasher`: a variant tag, followed by the
/// variant's single parameter where it has one (`eps` for the smoothed L1,
/// `delta` for the log penalty).
pub(crate) fn hash_sparsity_kind(
    hasher: &mut Fingerprinter,
    kind: crate::terms::analytic_penalties::SparsityKind,
) {
    use crate::terms::analytic_penalties::SparsityKind;
    let (tag, parameter) = match kind {
        SparsityKind::SmoothedL1 { eps } => ("smoothed-l1", Some(eps)),
        SparsityKind::Hoyer => ("hoyer", None),
        SparsityKind::Log { delta } => ("log", Some(delta)),
    };
    hasher.write_str(tag);
    if let Some(value) = parameter {
        hasher.write_f64(value);
    }
}
/// Feeds a difference-operator kind into `hasher`: the tag
/// `"forward-diff-1d"`, or the tag `"graph-edges"` followed by the edge count
/// and every `(from, to)` pair in order.
pub(crate) fn hash_difference_op_kind(
    hasher: &mut Fingerprinter,
    kind: &crate::terms::analytic_penalties::DifferenceOpKind,
) {
    use crate::terms::analytic_penalties::DifferenceOpKind;
    match kind {
        DifferenceOpKind::GraphEdges(edges) => {
            hasher.write_str("graph-edges");
            hasher.write_usize(edges.len());
            edges.iter().for_each(|&(from, to)| {
                hasher.write_usize(from);
                hasher.write_usize(to);
            });
        }
        DifferenceOpKind::ForwardDiff1D => {
            hasher.write_str("forward-diff-1d");
        }
    }
}
/// Feeds a list of index groups into `hasher`: the number of groups, then for
/// each group its length followed by its members in order.
pub(crate) fn hash_groups(hasher: &mut Fingerprinter, groups: &[Vec<usize>]) {
    hasher.write_usize(groups.len());
    groups.iter().for_each(|members| {
        hasher.write_usize(members.len());
        members.iter().for_each(|&axis| {
            hasher.write_usize(axis);
        });
    });
}
/// Feeds one analytic penalty into `hasher`.
///
/// Every penalty contributes its `name()`, the `Debug` rendering of its
/// `tier()`, and its `rho_count()`, followed by a variant tag and the
/// variant-specific fields in the order written below. The order of the writes
/// is part of the fingerprint: reordering or adding writes changes every
/// resulting hash.
pub(crate) fn hash_analytic_penalty_kind(
hasher: &mut Fingerprinter,
penalty: &crate::terms::analytic_penalties::AnalyticPenaltyKind,
) {
use crate::terms::analytic_penalties::{AnalyticPenaltyKind, PenaltyConcavity};
// Header common to all variants.
hasher.write_str(penalty.name());
hasher.write_str(&format!("{:?}", penalty.tier()));
hasher.write_usize(penalty.rho_count());
match penalty {
AnalyticPenaltyKind::Isometry(p) => {
hasher.write_str("isometry");
hash_psi_slice(hasher, &p.target);
hash_isometry_reference(hasher, &p.reference);
hasher.write_usize(p.rho_index);
hasher.write_usize(p.p_out);
hash_weight_field(hasher, &p.weight);
hasher.write_f64(p.scalar_weight);
hash_weight_schedule_option(hasher, &p.weight_schedule);
// Optional radial source: presence flag first, then its contents.
match p.duchon_radial_source.as_ref() {
Some(source) => {
hasher.write_bool(true);
hash_array2(hasher, source.centers.as_ref());
hash_array2(hasher, source.radial_coefficients.as_ref());
match source.length_scale {
Some(length_scale) => {
hasher.write_bool(true);
hasher.write_f64(length_scale);
}
None => hasher.write_bool(false),
}
hasher.write_str(&format!("{:?}", source.nullspace_order));
}
None => hasher.write_bool(false),
}
}
AnalyticPenaltyKind::Sparsity(p) => {
hasher.write_str("sparsity");
hasher.write_str(&format!("{:?}", p.target_tier));
hash_sparsity_kind(hasher, p.kind);
hasher.write_f64(p.weight);
hash_weight_schedule_option(hasher, &p.weight_schedule);
hasher.write_usize(p.strength_rho_index);
match p.eps_rho_index {
Some(idx) => {
hasher.write_bool(true);
hasher.write_usize(idx);
}
None => hasher.write_bool(false),
}
}
AnalyticPenaltyKind::SoftmaxAssignmentSparsity(p) => {
hasher.write_str("softmax-assignment-sparsity");
hasher.write_usize(p.k_atoms);
hasher.write_f64(p.temperature);
hasher.write_f64(p.weight);
hash_weight_schedule_option(hasher, &p.weight_schedule);
}
AnalyticPenaltyKind::IBPAssignment(p) => {
hasher.write_str("ibp-assignment");
hasher.write_usize(p.k_max);
hasher.write_f64(p.alpha);
hasher.write_f64(p.tau);
hash_gumbel_schedule_option(hasher, &p.temperature_schedule);
hasher.write_bool(p.learnable_alpha);
hasher.write_f64(p.weight);
hash_weight_schedule_option(hasher, &p.weight_schedule);
}
AnalyticPenaltyKind::Ard(p) => {
hasher.write_str("ard");
hash_psi_slice(hasher, &p.target);
hasher.write_usize(p.latent_dim);
hasher.write_f64(p.weight);
hash_weight_schedule_option(hasher, &p.weight_schedule);
hasher.write_usize(p.rho_indices.len());
for &idx in &p.rho_indices {
hasher.write_usize(idx);
}
hasher.write_f64(p.n_eff);
}
AnalyticPenaltyKind::TopKActivation(p) => {
hasher.write_str("topk-activation");
hash_psi_slice(hasher, &p.target);
hasher.write_usize(p.k);
hasher.write_usize(p.latent_dim);
hasher.write_f64(p.weight);
hash_weight_schedule_option(hasher, &p.weight_schedule);
}
AnalyticPenaltyKind::JumpReLU(p) => {
hasher.write_str("jumprelu");
hash_psi_slice(hasher, &p.target);
hasher.write_usize(p.latent_dim);
hash_array_view(hasher, p.thresholds.view());
hasher.write_f64(p.weight);
hasher.write_f64(p.smoothing_eps);
hash_weight_schedule_option(hasher, &p.weight_schedule);
}
AnalyticPenaltyKind::TotalVariation(p) => {
hasher.write_str("total-variation");
hasher.write_f64(p.weight);
hasher.write_usize(p.n_eff);
hash_difference_op_kind(hasher, &p.difference_op);
hasher.write_f64(p.smoothing_eps);
hasher.write_bool(p.learnable_weight);
hasher.write_usize(p.rho_index);
hash_weight_schedule_option(hasher, &p.weight_schedule);
}
AnalyticPenaltyKind::NuclearNorm(p) => {
hasher.write_str("nuclear-norm");
hash_psi_slice(hasher, &p.target);
hasher.write_f64(p.weight);
hasher.write_usize(p.n_eff);
hasher.write_f64(p.smoothing_eps);
match p.max_rank {
Some(max_rank) => {
hasher.write_bool(true);
hasher.write_usize(max_rank);
}
None => hasher.write_bool(false),
}
hasher.write_bool(p.learnable_weight);
hasher.write_usize(p.rho_index);
hash_weight_schedule_option(hasher, &p.weight_schedule);
}
AnalyticPenaltyKind::BlockSparsity(p) => {
hasher.write_str("block-sparsity");
hash_psi_slice(hasher, &p.target);
hash_groups(hasher, &p.groups);
hasher.write_f64(p.weight);
hasher.write_usize(p.n_eff);
hasher.write_f64(p.smoothing_eps);
hasher.write_bool(p.learnable_weight);
hasher.write_usize(p.rho_index);
hash_weight_schedule_option(hasher, &p.weight_schedule);
}
AnalyticPenaltyKind::MechanismSparsity(p) => {
hasher.write_str("mechanism-sparsity");
hash_psi_slice(hasher, &p.target);
hash_groups(hasher, &p.feature_groups);
hasher.write_f64(p.weight);
hasher.write_f64(p.smoothing_eps);
hasher.write_f64(p.n_eff);
hasher.write_bool(p.learnable_weight);
hasher.write_usize(p.rho_index);
// The schedule is hashed inline here (via `as_ref()`) rather than through
// `hash_weight_schedule_option`, but with the same flag-then-payload layout.
match &p.weight_schedule {
Some(schedule) => {
hasher.write_bool(true);
hash_scalar_weight_schedule(hasher, schedule.as_ref());
}
None => hasher.write_bool(false),
}
}
AnalyticPenaltyKind::RowPrecisionPrior(p) => {
hasher.write_str("row-precision-prior");
hash_array3(hasher, &p.lambda_per_row);
hasher.write_f64(p.weight);
hasher.write_usize(p.n_eff);
hasher.write_bool(p.learnable_weight);
hasher.write_usize(p.rho_index);
hash_psi_slice(hasher, &p.target);
hash_weight_schedule_option(hasher, &p.weight_schedule);
}
AnalyticPenaltyKind::IvaeRidgeMeanGauge(p) => {
hasher.write_str("ivae-ridge-mean-gauge");
hash_array2(hasher, &p.aux);
hash_array2(hasher, &p.ridge_inv);
hasher.write_f64(p.ridge_eps);
hasher.write_f64(p.weight);
hasher.write_usize(p.n_eff);
hasher.write_bool(p.learnable_weight);
hasher.write_usize(p.rho_index);
hash_psi_slice(hasher, &p.target);
hash_weight_schedule_option(hasher, &p.weight_schedule);
}
AnalyticPenaltyKind::ParametricRowPrecisionPrior(p) => {
hasher.write_str("parametric-row-precision-prior");
hash_array2(hasher, &p.aux);
hash_array_view(hasher, p.log_alpha.view());
hash_array_view(hasher, p.raw_beta.view());
hash_array2(hasher, &p.mu);
hasher.write_f64(p.weight);
hasher.write_usize(p.n_eff);
hasher.write_bool(p.learnable_weight);
hash_psi_slice(hasher, &p.target);
hash_weight_schedule_option(hasher, &p.weight_schedule);
}
AnalyticPenaltyKind::ScadMcp(p) => {
hasher.write_str("scad-mcp");
hash_psi_slice(hasher, &p.target);
hasher.write_f64(p.weight);
hasher.write_usize(p.n_eff);
hasher.write_f64(p.gamma);
hasher.write_f64(p.smoothing_eps);
match p.variant {
PenaltyConcavity::Mcp => hasher.write_str("mcp"),
PenaltyConcavity::Scad => hasher.write_str("scad"),
}
hasher.write_bool(p.learnable_weight);
hasher.write_usize(p.rho_index);
hash_weight_schedule_option(hasher, &p.weight_schedule);
}
AnalyticPenaltyKind::BlockOrthogonality(p) => {
hasher.write_str("block-orthogonality");
hash_psi_slice(hasher, &p.target);
hash_groups(hasher, &p.groups);
hasher.write_f64(p.weight);
hasher.write_usize(p.n_eff);
hasher.write_bool(p.learnable_weight);
hasher.write_usize(p.rho_index);
hash_weight_schedule_option(hasher, &p.weight_schedule);
}
AnalyticPenaltyKind::DecoderIncoherence(p) => {
hasher.write_str("decoder-incoherence");
hash_psi_slice(hasher, &p.target);
hasher.write_usize(p.block_sizes.len());
for &m in &p.block_sizes {
hasher.write_usize(m);
}
hasher.write_usize(p.p_out);
hash_array2(hasher, &p.coactivation);
hasher.write_f64(p.weight);
hasher.write_bool(p.learnable_weight);
hasher.write_usize(p.rho_index);
hash_weight_schedule_option(hasher, &p.weight_schedule);
}
AnalyticPenaltyKind::Orthogonality(p) => {
hasher.write_str("orthogonality");
hash_psi_slice(hasher, &p.target);
hasher.write_usize(p.latent_dim);
hasher.write_f64(p.weight);
hasher.write_usize(p.n_eff);
hasher.write_bool(p.learnable_weight);
hasher.write_usize(p.rho_index);
hash_weight_schedule_option(hasher, &p.weight_schedule);
}
AnalyticPenaltyKind::NestedPrefix(p) => {
hasher.write_str("nested-prefix");
hash_psi_slice(hasher, &p.target);
hasher.write_str(&format!("{:?}", p.target_tier));
hasher.write_usize(p.prefix_sizes.len());
for &m in &p.prefix_sizes {
hasher.write_usize(m);
}
hasher.write_usize(p.shell_weights.len());
for &w in &p.shell_weights {
hasher.write_f64(w);
}
hasher.write_f64(p.eps);
hasher.write_usize(p.rho_indices.len());
for &idx in &p.rho_indices {
hasher.write_usize(idx);
}
hash_weight_schedule_option(hasher, &p.weight_schedule);
}
AnalyticPenaltyKind::Monotonicity(p) => {
hasher.write_str("monotonicity");
hasher.write_f64(p.weight);
hasher.write_usize(p.n_eff);
hasher.write_f64(p.direction);
hasher.write_f64(p.smoothing_eps);
hasher.write_bool(p.learnable_weight);
hasher.write_usize(p.rho_index);
hash_weight_schedule_option(hasher, &p.weight_schedule);
}
// NOTE(review): only the weight and the stalk dimensions are hashed for this
// variant; confirm that no other state of the penalty needs to be part of the key.
AnalyticPenaltyKind::SheafConsistency(p) => {
hasher.write_str("sheaf-consistency");
hasher.write_f64(p.weight());
let dims = p.stalk_dims();
hasher.write_usize(dims.len());
for &d in dims {
hasher.write_usize(d);
}
}
}
}
/// Computes a 64-bit fingerprint of an analytic-penalty registry.
///
/// The hash is seeded with the version tag `"analytic-penalty-registry-v1"`,
/// followed by the number of registered penalties and then every penalty in
/// registry order (each one fed through `hash_analytic_penalty_kind`), so both
/// the contents and the ordering of the registry affect the result.
pub(crate) fn analytic_penalty_registry_fingerprint(
    registry: &crate::terms::analytic_penalties::AnalyticPenaltyRegistry,
) -> u64 {
    let mut hasher = Fingerprinter::new();
    hasher.write_str("analytic-penalty-registry-v1");
    // Writing the length first keeps registries of different sizes from
    // producing the same byte stream.
    hasher.write_usize(registry.penalties.len());
    for penalty in &registry.penalties {
        hash_analytic_penalty_kind(&mut hasher, penalty);
    }
    hasher.finish_u64()
}
/// Feeds the shape and the contents of `design` into `hasher`.
///
/// The matrix is read in consecutive row chunks obtained from
/// `DesignMatrix::try_row_chunk`; each chunk is hashed with `hash_array2`.
/// The chunk height targets roughly 8 MiB of `f64` data per chunk and is
/// clamped to the range `1..=4096` rows.
///
/// # Errors
///
/// Returns a `String` describing the failure when a row chunk cannot be
/// produced.
pub(crate) fn hash_design_matrix(
    hasher: &mut Fingerprinter,
    design: &DesignMatrix,
) -> Result<(), String> {
    const HASH_CHUNK_TARGET_BYTES: usize = 8 * 1024 * 1024;
    const HASH_CHUNK_MIN_ROWS: usize = 1;
    const HASH_CHUNK_MAX_ROWS: usize = 4096;

    let row_count = design.nrows();
    let col_count = design.ncols();
    hasher.write_usize(row_count);
    hasher.write_usize(col_count);

    // `.max(1)` guards the division below against a zero-column design.
    let row_bytes = col_count.saturating_mul(std::mem::size_of::<f64>()).max(1);
    let rows_per_chunk =
        (HASH_CHUNK_TARGET_BYTES / row_bytes).clamp(HASH_CHUNK_MIN_ROWS, HASH_CHUNK_MAX_ROWS);

    let mut cursor = 0;
    while cursor < row_count {
        let stop = (cursor + rows_per_chunk).min(row_count);
        let rows = design
            .try_row_chunk(cursor..stop)
            .map_err(|e| format!("persistent warm-start design hash failed: {e}"))?;
        hash_array2(hasher, &rows);
        cursor = stop;
    }
    Ok(())
}
/// Feeds a slice of canonical penalties into `hasher`.
///
/// The slice length is written first; then, for each penalty in order, the
/// column range bounds, `total_dim`, `nullity`, the `root` and `local`
/// matrices, the prior mean, the positive eigenvalues (length-prefixed) and
/// finally a flag recording whether `op` is present. The contents of `op`
/// itself are not hashed.
pub(crate) fn hash_canonical_penalties(
    hasher: &mut Fingerprinter,
    penalties: &[crate::construction::CanonicalPenalty],
) {
    hasher.write_usize(penalties.len());
    for canon in penalties.iter() {
        // Scalar layout information.
        hasher.write_usize(canon.col_range.start);
        hasher.write_usize(canon.col_range.end);
        hasher.write_usize(canon.total_dim);
        hasher.write_usize(canon.nullity);

        // Dense payloads.
        hash_array2(hasher, &canon.root);
        hash_array2(hasher, &canon.local);
        hash_array_view(hasher, canon.prior_mean.view());

        // Length-prefixed spectrum.
        hasher.write_usize(canon.positive_eigenvalues.len());
        for eigenvalue in canon.positive_eigenvalues.iter().copied() {
            hasher.write_f64(eigenvalue);
        }

        hasher.write_bool(canon.op.is_some());
    }
}
/// Decodes an `f64` from its raw bit pattern, accepting only finite, strictly
/// positive values.
///
/// The all-zero pattern is rejected up front; any other pattern that decodes
/// to a NaN, an infinity, zero or a negative number yields `None` as well.
pub(crate) fn finite_positive_from_bits(bits: u64) -> Option<f64> {
    match bits {
        0 => None,
        raw => Some(f64::from_bits(raw)).filter(|v| v.is_finite() && *v > 0.0),
    }
}
/// Decodes an `f64` from its raw bit pattern, accepting only finite values
/// that compare `>= 0.0`.
///
/// Negative zero compares equal to zero and is therefore accepted; NaN,
/// infinities and negative numbers yield `None`.
pub(crate) fn finite_nonnegative_from_bits(bits: u64) -> Option<f64> {
    match f64::from_bits(bits) {
        decoded if decoded.is_finite() && decoded >= 0.0 => Some(decoded),
        _ => None,
    }
}
/// Encodes an optional value as raw `f64` bits.
///
/// A present value that is finite and compares `>= 0.0` is returned as its
/// bit pattern; `None`, NaN, infinities and negative values all map to
/// `IFT_RESIDUAL_NO_SIGNAL_BITS`.
pub(crate) fn finite_nonnegative_bits_or_no_signal(value: Option<f64>) -> u64 {
    match value {
        Some(v) if v.is_finite() && v >= 0.0 => v.to_bits(),
        _ => IFT_RESIDUAL_NO_SIGNAL_BITS,
    }
}
/// A scalar correction value bundled with optional derivative arrays.
///
/// NOTE(review): "Tk" presumably stands for Tierney–Kadane; confirm against
/// the code that constructs this type.
#[derive(Clone)]
pub(crate) struct TkCorrectionTerms {
/// Scalar value of the correction.
pub(crate) value: f64,
/// Gradient vector, when one was computed. The differentiation variables are
/// not visible in this part of the file.
pub(crate) gradient: Option<Array1<f64>>,
/// Hessian matrix, when one was computed.
pub(crate) hessian: Option<Array2<f64>>,
}
/// Three vectors plus a list of [`TkActiveBlock`]s grouped into one value.
///
/// NOTE(review): the name suggests these are intermediates shared between
/// several Tk-correction computations; the producers and consumers are not
/// visible here, so the exact meaning of each field needs confirming.
pub(crate) struct TkSharedIntermediates {
// presumably a diagonal of some matrix `H` — TODO confirm
pub(crate) h_diag: Array1<f64>,
pub(crate) x_m: Array1<f64>,
pub(crate) y: Array1<f64>,
// Blocks described by `TkActiveBlock` (a `start..end` span plus entries).
pub(crate) active_blocks: Vec<TkActiveBlock>,
}
/// A span delimited by `start` and `end` together with a list of
/// `(index, value)` pairs.
///
/// NOTE(review): whether `end` is exclusive and what the pair index refers to
/// (row, column, observation) is not visible here — confirm at the use site.
pub(crate) struct TkActiveBlock {
pub(crate) start: usize,
pub(crate) end: usize,
pub(crate) entries: Vec<(usize, f64)>,
}
/// Bundle of the pieces needed when evaluating derivatives in the REML outer
/// engine: a Hessian-derivative provider, the dispersion handling mode, the
/// log-likelihood value, and optional Firth and barrier components.
pub(crate) struct DerivativeContext {
/// Boxed provider of Hessian derivatives.
pub(crate) deriv_provider: Box<dyn super::reml_outer_engine::HessianDerivativeProvider>,
/// How dispersion is handled by the outer engine.
pub(crate) dispersion: super::reml_outer_engine::DispersionHandling,
/// Log-likelihood value carried alongside the derivative provider.
pub(crate) log_likelihood: f64,
/// Shared dense Firth operator, if any.
pub(crate) firth_op: Option<std::sync::Arc<super::FirthDenseOperator>>,
/// Barrier configuration, if any.
pub(crate) barrier_config: Option<super::reml_outer_engine::BarrierConfig>,
}
/// Returns an owned copy of the `LikelihoodSpec` stored in `likelihood.spec`.
#[inline]
pub(crate) fn reml_spec(likelihood: &GlmLikelihoodSpec) -> LikelihoodSpec {
likelihood.spec.clone()
}
/// Reports whether the likelihood's spec answers `true` to
/// `is_gaussian_identity()`.
#[inline]
pub(crate) fn reml_is_gaussian_identity(likelihood: &GlmLikelihoodSpec) -> bool {
    let spec = reml_spec(likelihood);
    spec.is_gaussian_identity()
}
/// Returns the inverse link of a binomial likelihood when that link passes
/// `inverse_link_has_fisher_weight_jet`.
///
/// Any non-binomial response, or a binomial response whose link fails that
/// check, yields `None`.
#[inline]
pub(crate) fn reml_jeffreys_supported_link(likelihood: &GlmLikelihoodSpec) -> Option<InverseLink> {
    let spec = reml_spec(likelihood);
    let is_binomial = matches!(spec.response, ResponseFamily::Binomial);
    // Short-circuit keeps the link check from running for other families.
    if is_binomial && inverse_link_has_fisher_weight_jet(&spec.link) {
        Some(spec.link.clone())
    } else {
        None
    }
}
/// Returns the Jeffreys-supported link for `config.likelihood`, but only when
/// `config.firth_bias_reduction` is enabled; otherwise `None`.
#[inline]
pub(crate) fn reml_robust_jeffreys_link(config: &RemlConfig) -> Option<InverseLink> {
    if config.firth_bias_reduction {
        reml_jeffreys_supported_link(&config.likelihood)
    } else {
        None
    }
}
/// `upper` parameter of the penalized-complexity prior built by
/// `firth_default_pc_prior`.
pub(crate) const FIRTH_DEFAULT_PC_UPPER: f64 = 10.0;
/// `tail_prob` parameter of the penalized-complexity prior built by
/// `firth_default_pc_prior`.
pub(crate) const FIRTH_DEFAULT_PC_TAIL_PROB: f64 = 0.01;
/// Builds the default penalized-complexity prior:
/// `RhoPrior::PenalizedComplexity` with `upper = FIRTH_DEFAULT_PC_UPPER` and
/// `tail_prob = FIRTH_DEFAULT_PC_TAIL_PROB`.
#[inline]
pub(crate) fn firth_default_pc_prior() -> RhoPrior {
RhoPrior::PenalizedComplexity {
upper: FIRTH_DEFAULT_PC_UPPER,
tail_prob: FIRTH_DEFAULT_PC_TAIL_PROB,
}
}
/// Builds a per-coordinate mask of length `len` marking which coordinates of
/// the configured prior are `RhoPrior::Flat`.
///
/// * `Flat` marks every coordinate.
/// * `Independent` with exactly `len` entries marks the entries that are
///   `Flat`.
/// * Anything else (including an `Independent` list of a different length)
///   marks no coordinate.
pub(crate) fn firth_default_coord_mask(configured: &RhoPrior, len: usize) -> Vec<bool> {
    match configured {
        RhoPrior::Independent(priors) if priors.len() == len => priors
            .iter()
            .map(|prior| matches!(prior, RhoPrior::Flat))
            .collect(),
        // Uniform mask: all `true` for a flat prior, all `false` otherwise.
        other => vec![matches!(other, RhoPrior::Flat); len],
    }
}
/// Resolves the prior that is actually used, substituting the default
/// penalized-complexity prior (`firth_default_pc_prior`) for flat entries.
///
/// * `Flat` becomes the default prior.
/// * `Independent` containing at least one `Flat` entry becomes a new
///   `Independent` list in which each `Flat` entry is replaced by the default
///   prior and every other entry is cloned.
/// * Everything else is borrowed unchanged, so no allocation happens.
pub(crate) fn resolve_effective_rho_prior(configured: &RhoPrior) -> std::borrow::Cow<'_, RhoPrior> {
    use std::borrow::Cow;

    if matches!(configured, RhoPrior::Flat) {
        return Cow::Owned(firth_default_pc_prior());
    }

    if let RhoPrior::Independent(priors) = configured {
        if priors.iter().any(|prior| matches!(prior, RhoPrior::Flat)) {
            let filled = priors
                .iter()
                .map(|prior| {
                    if matches!(prior, RhoPrior::Flat) {
                        firth_default_pc_prior()
                    } else {
                        prior.clone()
                    }
                })
                .collect();
            return Cow::Owned(RhoPrior::Independent(filled));
        }
    }

    Cow::Borrowed(configured)
}
/// Returns the fixed dispersion value associated with a GLM likelihood.
///
/// * Beta: the `phi` stored in the response family.
/// * Negative binomial: always `1.0`.
/// * Every other family: `likelihood.fixed_phi()`, defaulting to `1.0` when
///   no fixed value is set.
///
/// The match deliberately lists every variant instead of using a wildcard so
/// that adding a response family forces this function to be revisited.
#[inline]
pub(crate) fn reml_fixed_glm_dispersion(likelihood: &GlmLikelihoodSpec) -> f64 {
    let spec = reml_spec(likelihood);
    match &spec.response {
        ResponseFamily::Beta { phi } => *phi,
        ResponseFamily::NegativeBinomial { .. } => 1.0,
        ResponseFamily::Tweedie { .. }
        | ResponseFamily::Gaussian
        | ResponseFamily::Binomial
        | ResponseFamily::Poisson
        | ResponseFamily::Gamma
        | ResponseFamily::RoystonParmar => likelihood.fixed_phi().unwrap_or(1.0),
    }
}
// NOTE(review): presumably the minimum acceptable effective-sample-size
// fraction for an importance-sampling estimate; the use site is not visible
// in this part of the file — confirm there.
pub(crate) const MIN_IMPORTANCE_ESS_FRACTION: f64 = 0.10;
/// Data needed to evaluate a block "excess" target and its score vectors; it
/// implements `crate::inference::hmc::BlockExcessTarget`.
///
/// Field descriptions below reflect how the `impl` blocks of this type use
/// each field.
pub(crate) struct Gam784BlockTarget<'t> {
/// Borrowed matrix multiplied by the coefficient displacement to obtain the
/// linear-predictor displacement.
pub(crate) x_transformed: &'t Array2<f64>,
/// Matrix mapping block coordinates `t` to a coefficient displacement
/// (`block_vecs.dot(t)`).
pub(crate) block_vecs: Array2<f64>,
/// Returned by `block_curvatures`; its length is the block dimension.
pub(crate) block_lambdas: Array1<f64>,
/// Base linear predictor to which displacements are added.
pub(crate) eta_hat: Array1<f64>,
/// Per-observation weights of the quadratic term subtracted in the excess.
pub(crate) weights_obs: Array1<f64>,
/// Response values.
pub(crate) y: Array1<f64>,
/// Prior weights passed to the deviance and used in the score.
pub(crate) prior_weights: Array1<f64>,
/// Likelihood specification (family, optional fixed dispersion).
pub(crate) likelihood: GlmLikelihoodSpec,
/// Inverse link evaluated at each displaced linear predictor.
pub(crate) inverse_link: InverseLink,
/// Scale by which deviance differences and scores are divided (as `2 * phi`).
pub(crate) phi: f64,
/// One vector per penalty, dotted with the coefficient displacement.
pub(crate) penalty_scores: Arc<Vec<Array1<f64>>>,
/// Multipliers paired with `penalty_scores`; its length is the rho dimension.
pub(crate) lambdas: Vec<f64>,
/// Deviance subtracted from the displaced deviance in the excess.
pub(crate) base_deviance: f64,
}
impl Gam784BlockTarget<'_> {
    /// Maps block coordinates `t` to a pair `(delta, s)`: the coefficient
    /// displacement `block_vecs · t` and the resulting linear-predictor
    /// displacement `x_transformed · delta`.
    pub(crate) fn displacement(&self, t: &Array1<f64>) -> (Array1<f64>, Array1<f64>) {
        let coef_shift = self.block_vecs.dot(t);
        let eta_shift = crate::faer_ndarray::fast_av(self.x_transformed, &coef_shift);
        (coef_shift, eta_shift)
    }

    /// Evaluates, per observation, the derivative of `deviance / (2 * phi)`
    /// with respect to the linear predictor at `eta`.
    ///
    /// The mean is clamped before the variance function is evaluated (into
    /// `[1e-12, 1 - 1e-12]` for binomial responses, floored at `1e-10`
    /// otherwise). Observations whose inverse-link jet cannot be evaluated,
    /// or whose variance is not finite and strictly positive, contribute `0`.
    pub(crate) fn neg_score_at(&self, eta: &Array1<f64>) -> Array1<f64> {
        const BINOMIAL_MU_EPS: f64 = 1e-12;
        const MU_FLOOR: f64 = 1e-10;

        let response = reml_spec(&self.likelihood).response.clone();
        let family = pirls::weight_family_for_glm_likelihood(&self.likelihood);
        let is_binomial = matches!(response, ResponseFamily::Binomial);
        // Gaussian and Tweedie scores carry an extra 1 / fixed_phi factor.
        let fam_scale = if matches!(
            response,
            ResponseFamily::Gaussian | ResponseFamily::Tweedie { .. }
        ) {
            1.0 / self.likelihood.fixed_phi().unwrap_or(1.0)
        } else {
            1.0
        };

        let mut out = Array1::<f64>::zeros(eta.len());
        for (i, &eta_i) in eta.iter().enumerate() {
            let jet = match crate::mixture_link::inverse_link_jet_for_inverse_link(
                &self.inverse_link,
                eta_i,
            ) {
                Ok(jet) => jet,
                Err(_) => continue,
            };
            let mu_c = if is_binomial {
                jet.mu.clamp(BINOMIAL_MU_EPS, 1.0 - BINOMIAL_MU_EPS)
            } else {
                jet.mu.max(MU_FLOOR)
            };
            let v = pirls::variance_jet_for_weight_family(family, mu_c).v;
            if v.is_finite() && v > 0.0 {
                let d_dev_d_mu =
                    -2.0 * self.prior_weights[i] * (self.y[i] - mu_c) / v * fam_scale;
                out[i] = d_dev_d_mu * jet.d1 / (2.0 * self.phi);
            }
        }
        out
    }
}
// `BlockExcessTarget` implementation. The "excess" evaluated here is
//   (deviance(eta_hat + s) - base_deviance) / (2 * phi)
//     + sum_k lambdas[k] * penalty_scores[k] · delta
//     - 0.5 * sum_i weights_obs[i] * s[i]^2
// where `delta = block_vecs · t` and `s = x_transformed · delta`.
impl crate::inference::hmc::BlockExcessTarget for Gam784BlockTarget<'_> {
/// Dimension of the block coordinates: the length of `block_lambdas`.
fn block_dim(&self) -> usize {
self.block_lambdas.len()
}
/// Number of `lambdas` entries.
fn rho_dim(&self) -> usize {
self.lambdas.len()
}
/// Borrows `block_lambdas`.
fn block_curvatures(&self) -> &Array1<f64> {
&self.block_lambdas
}
/// Evaluates the excess at block coordinates `t`.
///
/// Returns `f64::INFINITY` when any inverse-link jet evaluation fails or the
/// displaced deviance is not finite.
fn excess(&self, t: &Array1<f64>) -> f64 {
let (delta, s) = self.displacement(t);
// Mean at the displaced linear predictor; the unclamped `jet.mu` is what
// gets passed to the deviance.
let mut mu_disp = Array1::<f64>::zeros(self.eta_hat.len());
for i in 0..self.eta_hat.len() {
let eta_i = self.eta_hat[i] + s[i];
match crate::mixture_link::inverse_link_jet_for_inverse_link(&self.inverse_link, eta_i)
{
Ok(jet) => mu_disp[i] = jet.mu,
Err(_) => return f64::INFINITY,
}
}
let dev_disp = crate::pirls::calculate_deviance(
self.y.view(),
&mu_disp,
&self.likelihood,
self.prior_weights.view(),
);
if !dev_disp.is_finite() {
return f64::INFINITY;
}
let neg_loglik_diff = (dev_disp - self.base_deviance) / (2.0 * self.phi);
// Penalty contribution, linear in `delta`. `zip` stops at the shorter of
// `penalty_scores` and `lambdas`.
let mut penalty_term = 0.0_f64;
for (score, &lam) in self.penalty_scores.iter().zip(self.lambdas.iter()) {
penalty_term += lam * score.dot(&delta);
}
// Weighted quadratic form in the linear-predictor displacement.
let mut curv = 0.0_f64;
for i in 0..s.len() {
curv += self.weights_obs[i] * s[i] * s[i];
}
neg_loglik_diff + penalty_term - 0.5 * curv
}
/// Returns a vector of length `lambdas.len()` whose `k`-th entry is
/// `lambdas[k] * penalty_scores[k] · delta`, i.e. the per-penalty pieces of
/// the penalty term used in `excess`.
fn excess_rho_gradient(&self, t: &Array1<f64>) -> Array1<f64> {
let delta = self.block_vecs.dot(t);
let mut grad = Array1::<f64>::zeros(self.lambdas.len());
for (k, (score, &lam)) in self
.penalty_scores
.iter()
.zip(self.lambdas.iter())
.enumerate()
{
grad[k] = lam * score.dot(&delta);
}
grad
}
/// `neg_score_at` evaluated at the displaced linear predictor `eta_hat + s`.
fn displaced_neg_score(&self, t: &Array1<f64>) -> Array1<f64> {
let (_delta, s) = self.displacement(t);
self.neg_score_at(&(&self.eta_hat + &s))
}
/// `neg_score_at` evaluated at `eta_hat`.
fn base_neg_score(&self) -> Array1<f64> {
self.neg_score_at(&self.eta_hat)
}
/// Computes the excess and the displaced score in a single pass over the
/// observations.
///
/// Returns `(f64::INFINITY, None)` when an inverse-link jet evaluation fails
/// or the displaced deviance is not finite, and `(excess, None)` when the
/// final excess itself is not finite.
fn excess_with_displaced_neg_score(&self, t: &Array1<f64>) -> (f64, Option<Array1<f64>>) {
let (delta, s) = self.displacement(t);
let n = self.eta_hat.len();
let spec_response = reml_spec(&self.likelihood).response.clone();
let family = pirls::weight_family_for_glm_likelihood(&self.likelihood);
// Gaussian and Tweedie scores carry an extra 1 / fixed_phi factor.
let fam_scale = match &spec_response {
ResponseFamily::Gaussian | ResponseFamily::Tweedie { .. } => {
1.0 / self.likelihood.fixed_phi().unwrap_or(1.0)
}
_ => 1.0,
};
const BINOMIAL_MU_EPS: f64 = 1e-12;
const MU_FLOOR: f64 = 1e-10;
let is_binomial = matches!(spec_response, ResponseFamily::Binomial);
let mut mu_disp = Array1::<f64>::zeros(n);
let mut ngs = Array1::<f64>::zeros(n);
for i in 0..n {
let eta_i = self.eta_hat[i] + s[i];
let jet = match crate::mixture_link::inverse_link_jet_for_inverse_link(
&self.inverse_link,
eta_i,
) {
Ok(jet) => jet,
Err(_) => return (f64::INFINITY, None),
};
// The deviance uses the raw mean; the score uses the clamped mean.
mu_disp[i] = jet.mu;
let mu_c = if is_binomial {
jet.mu.clamp(BINOMIAL_MU_EPS, 1.0 - BINOMIAL_MU_EPS)
} else {
jet.mu.max(MU_FLOOR)
};
let v = pirls::variance_jet_for_weight_family(family, mu_c).v;
// A non-finite or non-positive variance leaves the score entry at zero.
if v.is_finite() && v > 0.0 {
let d_dev_d_mu = -2.0 * self.prior_weights[i] * (self.y[i] - mu_c) / v * fam_scale;
ngs[i] = d_dev_d_mu * jet.d1 / (2.0 * self.phi);
}
}
let dev_disp = crate::pirls::calculate_deviance(
self.y.view(),
&mu_disp,
&self.likelihood,
self.prior_weights.view(),
);
if !dev_disp.is_finite() {
return (f64::INFINITY, None);
}
let neg_loglik_diff = (dev_disp - self.base_deviance) / (2.0 * self.phi);
let mut penalty_term = 0.0_f64;
for (score, &lam) in self.penalty_scores.iter().zip(self.lambdas.iter()) {
penalty_term += lam * score.dot(&delta);
}
let mut curv = 0.0_f64;
for i in 0..s.len() {
curv += self.weights_obs[i] * s[i] * s[i];
}
let excess = neg_loglik_diff + penalty_term - 0.5 * curv;
if excess.is_finite() {
(excess, Some(ngs))
} else {
(excess, None)
}
}
/// Batched form of `excess_with_displaced_neg_score`: each column of `draws`
/// is one set of block coordinates, and the result holds one entry per
/// column, in column order.
///
/// # Panics
///
/// Panics when `draws.nrows()` differs from `block_lambdas.len()`.
fn excess_with_displaced_neg_score_batch(
&self,
draws: &Array2<f64>,
) -> Vec<(f64, Option<Array1<f64>>)> {
let m = self.block_lambdas.len();
let n = self.eta_hat.len();
let n_draws = draws.ncols();
assert_eq!(
draws.nrows(),
m,
"posterior displacement draw rows must match smoothing block count"
);
// All displacements at once: column `j` of `delta_all` / `s_all` belongs
// to draw `j`.
let delta_all = crate::faer_ndarray::fast_ab(&self.block_vecs, draws);
let s_all = crate::faer_ndarray::fast_ab(self.x_transformed, &delta_all);
let spec_response = reml_spec(&self.likelihood).response.clone();
let family = pirls::weight_family_for_glm_likelihood(&self.likelihood);
let fam_scale = match &spec_response {
ResponseFamily::Gaussian | ResponseFamily::Tweedie { .. } => {
1.0 / self.likelihood.fixed_phi().unwrap_or(1.0)
}
_ => 1.0,
};
const BINOMIAL_MU_EPS: f64 = 1e-12;
const MU_FLOOR: f64 = 1e-10;
let is_binomial = matches!(spec_response, ResponseFamily::Binomial);
let mut out = Vec::with_capacity(n_draws);
// Scratch buffers reused across draws.
let mut mu_disp = Array1::<f64>::zeros(n);
let mut ngs = Array1::<f64>::zeros(n);
let mut delta = Array1::<f64>::zeros(self.block_vecs.nrows());
'draw: for sidx in 0..n_draws {
let s_col = s_all.column(sidx);
// Entries skipped by the variance guard below must read as zero, so the
// score buffer is cleared for every draw. `mu_disp` needs no reset: it
// is fully overwritten before it is read.
ngs.fill(0.0);
for i in 0..n {
let eta_i = self.eta_hat[i] + s_col[i];
let jet = match crate::mixture_link::inverse_link_jet_for_inverse_link(
&self.inverse_link,
eta_i,
) {
Ok(jet) => jet,
Err(_) => {
// Record the failed draw and move on to the next column.
out.push((f64::INFINITY, None));
continue 'draw;
}
};
mu_disp[i] = jet.mu;
let mu_c = if is_binomial {
jet.mu.clamp(BINOMIAL_MU_EPS, 1.0 - BINOMIAL_MU_EPS)
} else {
jet.mu.max(MU_FLOOR)
};
let v = pirls::variance_jet_for_weight_family(family, mu_c).v;
if v.is_finite() && v > 0.0 {
let d_dev_d_mu =
-2.0 * self.prior_weights[i] * (self.y[i] - mu_c) / v * fam_scale;
ngs[i] = d_dev_d_mu * jet.d1 / (2.0 * self.phi);
}
}
let dev_disp = crate::pirls::calculate_deviance(
self.y.view(),
&mu_disp,
&self.likelihood,
self.prior_weights.view(),
);
if !dev_disp.is_finite() {
out.push((f64::INFINITY, None));
continue;
}
let neg_loglik_diff = (dev_disp - self.base_deviance) / (2.0 * self.phi);
// Copy this draw's coefficient displacement into the owned buffer used
// for the dot products below.
delta.assign(&delta_all.column(sidx));
let mut penalty_term = 0.0_f64;
for (score, &lam) in self.penalty_scores.iter().zip(self.lambdas.iter()) {
penalty_term += lam * score.dot(&delta);
}
let mut curv = 0.0_f64;
for i in 0..n {
curv += self.weights_obs[i] * s_col[i] * s_col[i];
}
let excess = neg_loglik_diff + penalty_term - 0.5 * curv;
if excess.is_finite() {
// The buffer is reused for the next draw, so hand out a copy.
out.push((excess, Some(ngs.clone())));
} else {
out.push((excess, None));
}
}
out
}
}