use super::family::clamp_bernoulli_link_probability;
use super::*;
use crate::matrix::{LinearOperator, SignedWeightsView};
use gam_math::jet_tower::Tower4;
/// Validate latent scores `z` against `weights` and apply the identification
/// normalization selected by `policy`.
///
/// Processing order:
/// 1. reject a length mismatch, a non-positive or non-finite total weight (or
///    sum of squared weights), and an effective sample size
///    `n_eff = (Σw)² / Σw²` that is not greater than one;
/// 2. compute the weighted mean and weighted standard deviation of the raw
///    `z` (variance divided by `Σw`), rejecting a non-finite sd or one not
///    above `BMS_VARIANCE_FLOOR`;
/// 3. pick the normalization from `policy.normalization`: `None` is the
///    identity (mean 0, sd 1), `FitWeighted` uses the moments just computed,
///    `Frozen` uses the stored values;
/// 4. compare the raw mean/sd with 0/1 using tolerances
///    `mean_tol_multiplier / sqrt(n_eff)` and
///    `sd_tol_multiplier / sqrt(2 * max(n_eff - 1, 1))`, and on failure error,
///    warn, or do nothing according to `policy.check_mode`;
/// 5. apply the normalization, then compare the weighted raw third moment and
///    raw fourth moment minus 3 of the normalized scores against
///    `policy.max_abs_skew` / `policy.max_abs_excess_kurtosis`, again
///    honoring `policy.check_mode`;
/// 6. regardless of `check_mode`, log a warning when |skewness| > 0.75 or
///    |excess kurtosis| > 2.0.
///
/// Returns the normalized scores and the normalization that produced them.
///
/// # Errors
/// All messages are prefixed with `context`: input rejections from steps 1-2,
/// `Strict`-mode failures from steps 4-5, and any error returned by
/// `LatentZNormalization::apply`.
pub(crate) fn standardize_latent_z_with_policy(
z: &Array1<f64>,
weights: &Array1<f64>,
context: &str,
policy: &LatentZPolicy,
) -> Result<(Array1<f64>, LatentZNormalization), String> {
if z.len() != weights.len() {
return Err(format!(
"{context} latent-score normalization length mismatch: z={}, weights={}",
z.len(),
weights.len()
));
}
let weight_sum = weights.iter().copied().sum::<f64>();
let weight_sq_sum = weights.iter().map(|&w| w * w).sum::<f64>();
if !(weight_sum.is_finite()
&& weight_sum > 0.0
&& weight_sq_sum.is_finite()
&& weight_sq_sum > 0.0)
{
return Err(format!("{context} requires positive finite total weight"));
}
// Kish effective sample size; drives the sampling tolerances below.
let effective_n = weight_sum * weight_sum / weight_sq_sum;
if !(effective_n.is_finite() && effective_n > 1.0) {
return Err(format!(
"{context} requires at least two effective observations for latent-score normalization"
));
}
let mean = z
.iter()
.zip(weights.iter())
.map(|(&zi, &wi)| wi * zi)
.sum::<f64>()
/ weight_sum;
// Population-form weighted variance (divided by the total weight).
let var = z
.iter()
.zip(weights.iter())
.map(|(&zi, &wi)| wi * (zi - mean) * (zi - mean))
.sum::<f64>()
/ weight_sum;
let sd = var.sqrt();
if !(sd.is_finite() && sd > BMS_VARIANCE_FLOOR) {
return Err(format!(
"{context} requires z with positive finite weighted standard deviation"
));
}
let target_norm = match policy.normalization {
LatentZNormalizationMode::None => LatentZNormalization { mean: 0.0, sd: 1.0 },
LatentZNormalizationMode::FitWeighted => LatentZNormalization { mean, sd },
LatentZNormalizationMode::Frozen {
mean: frozen_mean,
sd: frozen_sd,
} => LatentZNormalization {
mean: frozen_mean,
sd: frozen_sd,
},
};
// Tolerances scale like the sampling standard errors of a mean (1/sqrt(n))
// and of a standard deviation (1/sqrt(2(n-1))) under N(0,1).
let mean_tol = policy.mean_tol_multiplier / effective_n.sqrt();
let sd_tol = policy.sd_tol_multiplier / (2.0 * (effective_n - 1.0).max(1.0)).sqrt();
let check_msg = || {
format!(
"{context} requires z to already be approximately latent N(0,1) before identification normalization; got mean={mean:.6e}, sd={sd:.6e}, effective_n={effective_n:.1}, allowed_mean={mean_tol:.3e}, allowed_sd={sd_tol:.3e}"
)
};
// The check is on the RAW moments, independent of the normalization mode.
if mean.abs() > mean_tol || (sd - 1.0).abs() > sd_tol {
match policy.check_mode {
LatentZCheckMode::Strict => return Err(check_msg()),
LatentZCheckMode::WarnOnly => log::warn!("{}", check_msg()),
LatentZCheckMode::Off => {}
}
}
let normalization = target_norm;
let z_std = normalization.apply(z, context)?;
// Shape diagnostics use raw (non-recentred) moments of the normalized scores.
let skew = z_std
.iter()
.zip(weights.iter())
.map(|(&zi, &wi)| wi * zi.powi(3))
.sum::<f64>()
/ weight_sum;
let kurt = z_std
.iter()
.zip(weights.iter())
.map(|(&zi, &wi)| wi * zi.powi(4))
.sum::<f64>()
/ weight_sum
- 3.0;
if skew.abs() > policy.max_abs_skew || kurt.abs() > policy.max_abs_excess_kurtosis {
let msg = format!(
"{context} requires z to be approximately Gaussian after identification normalization; got skewness={skew:.3}, excess_kurtosis={kurt:.3}"
);
match policy.check_mode {
LatentZCheckMode::Strict => return Err(msg),
LatentZCheckMode::WarnOnly => log::warn!("{}", msg),
LatentZCheckMode::Off => {}
}
}
// Fixed advisory thresholds: logged even when check_mode is Off.
if skew.abs() > 0.75 || kurt.abs() > 2.0 {
log::warn!(
"{context}: z has skewness={skew:.3} and excess kurtosis={kurt:.3}; latent-measure auto-selection will use empirical calibration unless stricter diagnostics pass"
);
}
Ok((z_std, normalization))
}
/// Append two padding points to a deviation seed so that its range extends
/// past the observed extremes.
///
/// With fewer than four values the seed is returned unchanged. Otherwise the
/// spread is estimated from the order statistics at positions `n / 4` and
/// `3n / 4` (floored at `min_iqr`). `pad_fraction * spread` is then subtracted
/// from the minimum and added to the maximum. The original values keep their
/// order, and the two padded extremes are appended after them.
///
/// Ordering uses [`f64::total_cmp`], a total order, so the sort is well
/// defined and cannot panic even if the seed contains NaN. Positive NaNs sort
/// last.
pub fn padded_deviation_seed(seed: &Array1<f64>, min_iqr: f64, pad_fraction: f64) -> Array1<f64> {
    let n = seed.len();
    // Too few points for a quartile estimate: skip the copy and the sort.
    if n < 4 {
        return seed.clone();
    }
    let mut sorted = seed.to_vec();
    sorted.sort_unstable_by(f64::total_cmp);
    let q1 = sorted[n / 4];
    let q3 = sorted[3 * n / 4];
    let iqr = (q3 - q1).max(min_iqr);
    let pad = pad_fraction * iqr;
    let mut out = Vec::with_capacity(n + 2);
    out.extend(seed.iter().copied());
    out.push(sorted[0] - pad);
    out.push(sorted[n - 1] + pad);
    Array1::from_vec(out)
}
/// Maximum number of Newton iterations in `pooled_probit_baseline`.
const POOLED_PILOT_MAX_NEWTON_ITERS: usize = 50;
/// Initial diagonal ridge added to the 2x2 pilot Hessian before solving.
pub(crate) const POOLED_PILOT_RIDGE_INIT: f64 = 1e-8;
/// Smallest |determinant| of the ridged 2x2 Hessian accepted as solvable.
pub(crate) const POOLED_PILOT_DET_FLOOR: f64 = 1e-18;
/// Factor by which the ridge grows after a rejected Hessian solve.
pub(crate) const POOLED_PILOT_RIDGE_GROWTH: f64 = 10.0;
/// Ridge value above which the pilot Hessian solve is declared failed.
pub(crate) const POOLED_PILOT_RIDGE_MAX: f64 = 1e6;
/// Maximum number of step shrinks in the pilot backtracking line search.
const POOLED_PILOT_MAX_BACKTRACKS: usize = 25;
/// Step-length multiplier applied on each backtrack.
pub(crate) const POOLED_PILOT_BACKTRACK_SHRINK: f64 = 0.5;
/// Objective-change tolerance under which a failed line search is treated as a stall.
pub(crate) const POOLED_PILOT_STALL_TOL: f64 = 1e-10;
/// Smallest |slope| returned by the pilot; smaller slopes are pushed out to this
/// magnitude, keeping their sign.
pub(crate) const POOLED_PILOT_MIN_ABS_SLOPE: f64 = 1e-6;
/// Fit a pooled probit regression `P(y = 1) = Φ(a + b z)` by damped Newton
/// iterations and return `(a / sqrt(1 + b²), b)`.
///
/// * Start: `a = Φ⁻¹(prevalence)`, where the weighted prevalence is clamped to
///   `[1e-6, 1 - 1e-6]`. The slope starts at `b = Cov_w(y, z) / Var_w(z)`, or
///   `0` when `Var_w(z) <= BMS_VARIANCE_FLOOR`.
/// * Each iteration solves the 2x2 Newton system on the weighted negative
///   log-likelihood. A diagonal ridge starts at `POOLED_PILOT_RIDGE_INIT` and
///   grows by `POOLED_PILOT_RIDGE_GROWTH` until the solve is finite and the
///   determinant clears `POOLED_PILOT_DET_FLOOR`.
/// * The step is then backtracked (shrink `POOLED_PILOT_BACKTRACK_SHRINK`,
///   at most `POOLED_PILOT_MAX_BACKTRACKS` tries) until the objective does
///   not increase.
/// * Iteration stops when the largest absolute gradient entry is below
///   `BMS_DERIV_TOL`, after `POOLED_PILOT_MAX_NEWTON_ITERS` iterations, or
///   on a stalled line search (see the note at the stall check).
/// * If `|b| < POOLED_PILOT_MIN_ABS_SLOPE`, `b` is replaced by that magnitude
///   with `b`'s sign (including the sign of a signed zero).
///
/// # Errors
/// Returns an error for any of the following:
/// * a length mismatch;
/// * a non-positive or non-finite total weight;
/// * a failed quantile seed;
/// * a non-finite objective or gradient;
/// * a Hessian solve that needs a ridge above `POOLED_PILOT_RIDGE_MAX`;
/// * a line search that fails before any step has been accepted.
pub(super) fn pooled_probit_baseline(
y: &Array1<f64>,
z: &Array1<f64>,
weights: &Array1<f64>,
) -> Result<(f64, f64), String> {
if y.len() != z.len() || y.len() != weights.len() {
return Err(format!(
"pooled bernoulli-marginal-slope pilot length mismatch: y={}, z={}, weights={}",
y.len(),
z.len(),
weights.len()
));
}
let weight_sum = weights.iter().copied().sum::<f64>();
if !weight_sum.is_finite() || weight_sum <= 0.0 {
return Err(
"pooled bernoulli-marginal-slope pilot requires positive finite total weight"
.to_string(),
);
}
let prevalence = y
.iter()
.zip(weights.iter())
.map(|(&yi, &wi)| yi * wi)
.sum::<f64>()
/ weight_sum;
// Keep the quantile seed finite for all-zero / all-one responses.
let prevalence = prevalence.clamp(1e-6, 1.0 - 1e-6);
let z_mean = z
.iter()
.zip(weights.iter())
.map(|(&zi, &wi)| zi * wi)
.sum::<f64>()
/ weight_sum;
let z_var = z
.iter()
.zip(weights.iter())
.map(|(&zi, &wi)| wi * (zi - z_mean) * (zi - z_mean))
.sum::<f64>()
/ weight_sum;
let yz_cov = y
.iter()
.zip(z.iter())
.zip(weights.iter())
.map(|((&yi, &zi), &wi)| wi * (yi - prevalence) * (zi - z_mean))
.sum::<f64>()
/ weight_sum;
let mut beta0 = standard_normal_quantile(prevalence).map_err(|e| {
format!("failed to initialize pooled bernoulli-marginal-slope pilot intercept: {e}")
})?;
// Least-squares slope of y on z as a rough starting value.
let mut beta1 = if z_var > BMS_VARIANCE_FLOOR {
yz_cov / z_var
} else {
0.0
};
// Weighted negative log-likelihood with gradient and Hessian in
// (intercept, slope). With s = 2y - 1 and margin m = s*eta:
// d/deta = -w s λ(m), d²/deta² = w λ(m)(m + λ(m)), λ = inverse Mills ratio.
let objective_grad_hess =
|intercept: f64, slope: f64| -> Result<(f64, f64, f64, f64, f64, f64), String> {
let mut obj = 0.0;
let mut g0 = 0.0;
let mut g1 = 0.0;
let mut h00 = 0.0;
let mut h01 = 0.0;
let mut h11 = 0.0;
for ((&yi, &zi), &wi) in y.iter().zip(z.iter()).zip(weights.iter()) {
if wi == 0.0 {
continue;
}
let eta = intercept + slope * zi;
let s = 2.0 * yi - 1.0;
let margin = s * eta;
let (logcdf, lambda) = signed_probit_logcdf_and_mills_ratio(margin);
let g_eta = -wi * s * lambda;
let h_eta = wi * lambda * (margin + lambda);
obj -= wi * logcdf;
g0 += g_eta;
g1 += g_eta * zi;
h00 += h_eta;
h01 += h_eta * zi;
h11 += h_eta * zi * zi;
}
Ok((obj, g0, g1, h00, h01, h11))
};
let mut obj_prev = f64::INFINITY;
for _ in 0..POOLED_PILOT_MAX_NEWTON_ITERS {
let (obj, g0, g1, h00, h01, h11) = objective_grad_hess(beta0, beta1)?;
if !obj.is_finite() || !g0.is_finite() || !g1.is_finite() {
return Err(
"pooled bernoulli-marginal-slope pilot produced non-finite objective or gradient"
.to_string(),
);
}
let grad_max = g0.abs().max(g1.abs());
if grad_max < BMS_DERIV_TOL {
break;
}
// Closed-form 2x2 solve of (H + ridge I) step = g, growing the ridge
// until the system is well-conditioned enough to give a finite step.
let mut ridge = POOLED_PILOT_RIDGE_INIT;
let (step0, step1) = loop {
let h00_r = h00 + ridge;
let h11_r = h11 + ridge;
let det = h00_r * h11_r - h01 * h01;
if det.is_finite() && det.abs() > POOLED_PILOT_DET_FLOOR {
let s0 = (h11_r * g0 - h01 * g1) / det;
let s1 = (-h01 * g0 + h00_r * g1) / det;
if s0.is_finite() && s1.is_finite() {
break (s0, s1);
}
}
ridge *= POOLED_PILOT_RIDGE_GROWTH;
if ridge > POOLED_PILOT_RIDGE_MAX {
return Err(
"pooled bernoulli-marginal-slope pilot Hessian solve failed".to_string()
);
}
};
// Backtracking: accept the first scaled step that does not increase the
// objective.
let mut accepted = false;
let mut step_scale = 1.0;
for _ in 0..POOLED_PILOT_MAX_BACKTRACKS {
let cand0 = beta0 - step_scale * step0;
let cand1 = beta1 - step_scale * step1;
let (cand_obj, _, _, _, _, _) = objective_grad_hess(cand0, cand1)?;
if cand_obj.is_finite() && cand_obj <= obj {
beta0 = cand0;
beta1 = cand1;
obj_prev = cand_obj;
accepted = true;
break;
}
step_scale *= POOLED_PILOT_BACKTRACK_SHRINK;
}
if !accepted {
// NOTE(review): `obj_prev` is the objective of the last accepted
// point, which is exactly the current `obj` (same deterministic
// evaluation). So this test passes whenever any earlier step was
// accepted, and fails only on the first iteration (obj_prev = INF).
// It behaves as "stop once progress has been made", not as a
// true stall test. Confirm that this is intended.
if (obj_prev - obj).abs() < POOLED_PILOT_STALL_TOL {
break;
}
return Err("pooled bernoulli-marginal-slope pilot line search failed".to_string());
}
}
let a = beta0;
// Keep the slope away from zero, preserving its sign.
let b = if beta1.abs() < POOLED_PILOT_MIN_ABS_SLOPE {
if beta1.is_sign_negative() {
-POOLED_PILOT_MIN_ABS_SLOPE
} else {
POOLED_PILOT_MIN_ABS_SLOPE
}
} else {
beta1
};
// Map the fitted intercept to the marginal scale: a = q * sqrt(1 + b²).
Ok((a / (1.0 + b * b).sqrt(), b))
}
/// Per-row probit IRLS (expected-information) weights evaluated at `eta_pilot`.
///
/// Row `i` receives `sample_weights[i] * φ(η)² / (μ(1 - μ))`, where
/// `μ = Φ(η)` is passed through `clamp_bernoulli_link_probability`. Both
/// `φ(η)` and `μ(1 - μ)` are floored at `1e-300` so the ratio stays finite in
/// the tails.
///
/// The result has `eta_pilot.len()` entries. `sample_weights` is indexed row
/// by row and must be at least that long.
pub(super) fn pilot_irls_hessian_row_metric_at_eta(
    eta_pilot: &Array1<f64>,
    sample_weights: &Array1<f64>,
) -> Array1<f64> {
    Array1::from_shape_fn(eta_pilot.len(), |row| {
        let eta = eta_pilot[row];
        let density = normal_pdf(eta).max(1e-300);
        let mean = clamp_bernoulli_link_probability(normal_cdf(eta));
        let bernoulli_var = (mean * (1.0 - mean)).max(1e-300);
        sample_weights[row] * (density * density) / bernoulli_var
    })
}
/// Observed probit linear predictor for each row under the rigid model.
///
/// Row `i` works in three steps:
/// * maps `baseline_marginal + marginal_offset[i]` through
///   `bernoulli_marginal_link_map` to obtain the marginal `q`;
/// * forms the log-slope `baseline_logslope + logslope_offset[i]`;
/// * returns `rigid_observed_eta(q, log-slope, z[i], probit_scale)`.
///
/// # Errors
/// Returns an error if either offset vector does not have exactly one entry
/// per element of `z` (previously this panicked on an out-of-bounds index).
/// Also returns an error if the marginal link map fails for some row.
pub(super) fn rigid_pooled_probit_pilot_eta(
    base_link: &InverseLink,
    z: &Array1<f64>,
    marginal_offset: &Array1<f64>,
    logslope_offset: &Array1<f64>,
    baseline_marginal: f64,
    baseline_logslope: f64,
    probit_scale: f64,
) -> Result<Array1<f64>, String> {
    let n = z.len();
    if marginal_offset.len() != n || logslope_offset.len() != n {
        return Err(format!(
            "rigid_pooled_probit_pilot_eta length mismatch: z={n}, marginal_offset={}, logslope_offset={}",
            marginal_offset.len(),
            logslope_offset.len()
        ));
    }
    let mut out = Array1::<f64>::zeros(n);
    for i in 0..n {
        let a_pre = baseline_marginal + marginal_offset[i];
        let b_pre = baseline_logslope + logslope_offset[i];
        let q_marg = bernoulli_marginal_link_map(base_link, a_pre)
            .map_err(|e| format!("rigid_pooled_probit_pilot_eta marginal link map: {e}"))?
            .q;
        out[i] = rigid_observed_eta(q_marg, b_pre, z[i], probit_scale);
    }
    Ok(out)
}
/// Ridge added to the diagonal of `XᵀWX` in
/// `pilot_eta_for_link_dev_orthogonalisation`, expressed as a fraction of the
/// mean diagonal entry (itself floored at `PILOT_RIDGE_DIAG_FLOOR`).
pub(crate) const PILOT_RIDGE_DIAG_FRACTION: f64 = 1e-6;
/// Lower bound on the mean diagonal entry before it is scaled by
/// `PILOT_RIDGE_DIAG_FRACTION`.
pub(crate) const PILOT_RIDGE_DIAG_FLOOR: f64 = 1e-12;
/// Rigid pilot linear predictor, refined by one ridge-stabilised IRLS
/// (Fisher-scoring) step on the marginal design.
///
/// For each row the rigid predictor `η_i` is built exactly as in
/// `rigid_pooled_probit_pilot_eta`. The step uses the following quantities:
/// * `μ_i = Φ(η_i)`, clamped by `clamp_bernoulli_link_probability`;
/// * `φ_i = φ(η_i)` and the variance `μ_i(1 - μ_i)`, both floored at `1e-300`;
/// * IRLS weights `W_i = weights_i φ_i² / (μ_i(1 - μ_i))`;
/// * working residuals `r_i = (y_i - μ_i) / φ_i`.
///
/// For the marginal design `X` with `p` columns, the step is
/// `δ = (XᵀWX + ridge·I)⁻¹ XᵀW r`, where
/// `ridge = max(trace(XᵀWX) / p, PILOT_RIDGE_DIAG_FLOOR) * PILOT_RIDGE_DIAG_FRACTION`.
/// The result is `η + Xδ`. With an empty design (`p = 0`), `η` is returned
/// unchanged.
///
/// # Errors
/// Returns an error in any of these cases:
/// * the design row count or the length of `z`, `weights`,
///   `marginal_offset` or `logslope_offset` differs from `y.len()`;
/// * the marginal link map fails;
/// * a design product fails;
/// * the Cholesky factorisation fails.
pub(super) fn pilot_eta_for_link_dev_orthogonalisation(
    base_link: &InverseLink,
    y: &Array1<f64>,
    z: &Array1<f64>,
    weights: &Array1<f64>,
    marginal_design: &DesignMatrix,
    marginal_offset: &Array1<f64>,
    logslope_offset: &Array1<f64>,
    baseline_marginal: f64,
    baseline_logslope: f64,
    probit_scale: f64,
) -> Result<Array1<f64>, String> {
    use crate::faer_ndarray::FaerCholesky;
    let n = y.len();
    if marginal_design.nrows() != n {
        return Err(format!(
            "pilot_eta_for_link_dev_orthogonalisation: marginal design has {} rows, expected {}",
            marginal_design.nrows(),
            n,
        ));
    }
    // Every per-row input is indexed by row below. Reject mismatches up front
    // instead of panicking on an out-of-bounds index.
    for (label, len) in [
        ("z", z.len()),
        ("weights", weights.len()),
        ("marginal_offset", marginal_offset.len()),
        ("logslope_offset", logslope_offset.len()),
    ] {
        if len != n {
            return Err(format!(
                "pilot_eta_for_link_dev_orthogonalisation: {label} has {len} entries, expected {n}"
            ));
        }
    }
    let mut working_eta = Array1::<f64>::zeros(n);
    let mut w_irls = Array1::<f64>::zeros(n);
    let mut residual = Array1::<f64>::zeros(n);
    for i in 0..n {
        let a_pre = baseline_marginal + marginal_offset[i];
        let b_pre = baseline_logslope + logslope_offset[i];
        let q_marg = bernoulli_marginal_link_map(base_link, a_pre)
            .map_err(|e| {
                format!("pilot_eta_for_link_dev_orthogonalisation marginal link map: {e}")
            })?
            .q;
        let eta = rigid_observed_eta(q_marg, b_pre, z[i], probit_scale);
        working_eta[i] = eta;
        let mu = clamp_bernoulli_link_probability(normal_cdf(eta));
        let phi = normal_pdf(eta).max(1e-300);
        let var = (mu * (1.0 - mu)).max(1e-300);
        w_irls[i] = weights[i] * (phi * phi) / var;
        // Working residual on the linear-predictor scale: (y - μ) / (dμ/dη).
        residual[i] = (y[i] - mu) / phi;
    }
    let p_marg = marginal_design.ncols();
    if p_marg == 0 {
        return Ok(working_eta);
    }
    let xtwr = marginal_design.compute_xtwy(&w_irls, &residual)?;
    let mut xtwx = marginal_design.xt_diag_x_signed_op(SignedWeightsView::from_array(&w_irls))?;
    // Scale-aware ridge: a tiny fraction of the mean diagonal entry.
    let trace_diag: f64 = (0..p_marg).map(|i| xtwx[[i, i]]).sum();
    let ridge =
        (trace_diag / p_marg as f64).max(PILOT_RIDGE_DIAG_FLOOR) * PILOT_RIDGE_DIAG_FRACTION;
    for i in 0..p_marg {
        xtwx[[i, i]] += ridge;
    }
    let factor = xtwx
        .cholesky(faer::Side::Lower)
        .map_err(|e| format!("pilot_eta_for_link_dev_orthogonalisation Cholesky failed: {e}"))?;
    let delta_beta_marg = factor.solvevec(&xtwr);
    let marg_contrib = marginal_design.dot(&delta_beta_marg);
    Ok(&working_eta + &marg_contrib)
}
/// Build the starting point and box constraints for the exact joint
/// hyper-parameter search over the marginal and log-slope blocks.
///
/// Smoothing coordinates (`rho`) are laid out as
/// `[marginal penalties | log-slope penalties | extra_rho0]`. The penalty
/// entries start at `0.0`, the trailing entries are copied from `extra_rho0`,
/// and every coordinate is bounded to `[-12, 12]`.
///
/// Spatial log-kappa coordinates list the marginal terms first and the
/// log-slope terms second. Each block's start is built from its spec's length
/// scales and then reseeded from `data`. Its bounds come from the data-driven
/// lower/upper helpers. The concatenated start is clamped into the
/// concatenated bounds.
pub(super) fn joint_setup(
    data: ArrayView2<'_, f64>,
    marginalspec: &TermCollectionSpec,
    logslopespec: &TermCollectionSpec,
    marginal_penalties: usize,
    logslope_penalties: usize,
    extra_rho0: &[f64],
    kappa_options: &SpatialLengthScaleOptimizationOptions,
) -> ExactJointHyperSetup {
    let marginal_terms = spatial_length_scale_term_indices(marginalspec);
    let logslope_terms = spatial_length_scale_term_indices(logslopespec);

    // rho: zero-initialised penalty coordinates followed by the extra seeds.
    let penalty_count = marginal_penalties + logslope_penalties;
    let rho_dim = penalty_count + extra_rho0.len();
    let mut rho0vec = Array1::<f64>::zeros(rho_dim);
    for (slot, &seed) in rho0vec.iter_mut().skip(penalty_count).zip(extra_rho0) {
        *slot = seed;
    }
    let rho_lower = Array1::<f64>::from_elem(rho_dim, -12.0);
    let rho_upper = Array1::<f64>::from_elem(rho_dim, 12.0);

    // Per-block starting log-kappa, reseeded from the data.
    let marginal_start = SpatialLogKappaCoords::from_length_scales_aniso(
        marginalspec,
        &marginal_terms,
        kappa_options,
    )
    .reseed_from_data(data, marginalspec, &marginal_terms, kappa_options);
    let logslope_start = SpatialLogKappaCoords::from_length_scales_aniso(
        logslopespec,
        &logslope_terms,
        kappa_options,
    )
    .reseed_from_data(data, logslopespec, &logslope_terms, kappa_options);
    let marginal_dims = marginal_start.dims_per_term().to_vec();
    let logslope_dims = logslope_start.dims_per_term().to_vec();
    let joint_dims = [marginal_dims.as_slice(), logslope_dims.as_slice()].concat();
    let start_values = marginal_start
        .as_array()
        .iter()
        .chain(logslope_start.as_array().iter())
        .copied()
        .collect::<Vec<_>>();

    // Data-driven per-block bounds, concatenated in the same block order.
    let marginal_lower = SpatialLogKappaCoords::lower_bounds_aniso_from_data(
        data,
        marginalspec,
        &marginal_terms,
        &marginal_dims,
        kappa_options,
    );
    let logslope_lower = SpatialLogKappaCoords::lower_bounds_aniso_from_data(
        data,
        logslopespec,
        &logslope_terms,
        &logslope_dims,
        kappa_options,
    );
    let lower_values = marginal_lower
        .as_array()
        .iter()
        .chain(logslope_lower.as_array().iter())
        .copied()
        .collect::<Vec<_>>();
    let marginal_upper = SpatialLogKappaCoords::upper_bounds_aniso_from_data(
        data,
        marginalspec,
        &marginal_terms,
        &marginal_dims,
        kappa_options,
    );
    let logslope_upper = SpatialLogKappaCoords::upper_bounds_aniso_from_data(
        data,
        logslopespec,
        &logslope_terms,
        &logslope_dims,
        kappa_options,
    );
    let upper_values = marginal_upper
        .as_array()
        .iter()
        .chain(logslope_upper.as_array().iter())
        .copied()
        .collect::<Vec<_>>();

    let log_kappa_lower =
        SpatialLogKappaCoords::new_with_dims(Array1::from_vec(lower_values), joint_dims.clone());
    let log_kappa_upper =
        SpatialLogKappaCoords::new_with_dims(Array1::from_vec(upper_values), joint_dims.clone());
    let log_kappa0 = SpatialLogKappaCoords::new_with_dims(Array1::from_vec(start_values), joint_dims)
        .clamp_to_bounds(&log_kappa_lower, &log_kappa_upper);

    ExactJointHyperSetup::new(
        rho0vec,
        rho_lower,
        rho_upper,
        log_kappa0,
        log_kappa_lower,
        log_kappa_upper,
    )
}
/// Weighted derivatives 1..=4 of `-ln Φ(m)` with respect to the signed probit
/// margin `m`.
///
/// With the inverse Mills ratio `λ = φ(m) / Φ(m)`, the unweighted derivatives
/// are:
/// * `k1 = -λ`
/// * `k2 = λ(m + λ)`
/// * `k3 = λ(1 - m² - 3mλ - 2λ²)`
/// * `k4 = λ((m³ - 3m) + (7m² - 4)λ + 12mλ² + 6λ³)`
///
/// Each is multiplied by `weight`. Special inputs bypass the Mills ratio:
/// * a zero weight or `m = +∞` gives zeros;
/// * `m = -∞` gives the limits `(-∞, weight, 0, 0)`;
/// * a NaN margin gives NaNs.
#[inline]
pub(crate) fn signed_probit_neglog_derivatives_up_to_fourth_numeric(
    signed_margin: f64,
    weight: f64,
) -> (f64, f64, f64, f64) {
    match signed_margin {
        m if weight == 0.0 || m == f64::INFINITY => (0.0, 0.0, 0.0, 0.0),
        m if m == f64::NEG_INFINITY => (f64::NEG_INFINITY, weight, 0.0, 0.0),
        m if m.is_nan() => (f64::NAN, f64::NAN, f64::NAN, f64::NAN),
        m => {
            let (_, lambda) = signed_probit_logcdf_and_mills_ratio(m);
            let k1 = -lambda;
            let k2 = lambda * (m + lambda);
            let k3 = lambda * (1.0 - m * m - 3.0 * m * lambda - 2.0 * lambda * lambda);
            let k4 = lambda
                * ((m.powi(3) - 3.0 * m)
                    + (7.0 * m * m - 4.0) * lambda
                    + 12.0 * m * lambda * lambda
                    + 6.0 * lambda.powi(3));
            (weight * k1, weight * k2, weight * k3, weight * k4)
        }
    }
}
/// Checked form of `signed_probit_neglog_derivatives_up_to_fourth_numeric`.
///
/// A zero weight or a `+∞` margin yields zeros. Any other non-finite margin
/// (`-∞` or NaN) is reported as an error instead of being mapped to its
/// limiting or NaN values.
pub(crate) fn signed_probit_neglog_derivatives_up_to_fourth(
    signed_margin: f64,
    weight: f64,
) -> Result<(f64, f64, f64, f64), String> {
    let contributes = weight != 0.0 && signed_margin != f64::INFINITY;
    if contributes && !signed_margin.is_finite() {
        return Err(format!(
            "non-finite signed margin in exact probit derivative helper: {signed_margin}"
        ));
    }
    // The numeric kernel already returns zeros for the non-contributing case.
    Ok(signed_probit_neglog_derivatives_up_to_fourth_numeric(
        signed_margin,
        weight,
    ))
}
/// Value and first four derivatives of `-weight * ln Φ(m)` in the signed
/// margin `m`, packed as `[value, d1, d2, d3, d4]` for `compose_unary`.
///
/// The derivative formulas are the same as in
/// `signed_probit_neglog_derivatives_up_to_fourth_numeric`. Special inputs:
/// * a zero weight or `m = +∞` gives all zeros;
/// * `m = -∞` gives `[+∞, -∞, weight, 0, 0]`;
/// * a NaN margin gives all NaN.
#[inline]
pub(crate) fn signed_probit_neglog_unary_stack(signed_margin: f64, weight: f64) -> [f64; 5] {
    let m = signed_margin;
    if weight == 0.0 || m == f64::INFINITY {
        return [0.0; 5];
    }
    if m == f64::NEG_INFINITY {
        return [f64::INFINITY, f64::NEG_INFINITY, weight, 0.0, 0.0];
    }
    if m.is_nan() {
        return [f64::NAN; 5];
    }
    let (log_cdf, lam) = signed_probit_logcdf_and_mills_ratio(m);
    let unweighted = [
        -lam,
        lam * (m + lam),
        lam * (1.0 - m * m - 3.0 * m * lam - 2.0 * lam * lam),
        lam * ((m * m * m - 3.0 * m)
            + (7.0 * m * m - 4.0) * lam
            + 12.0 * m * lam * lam
            + 6.0 * lam * lam * lam),
    ];
    let mut stack = [-weight * log_cdf, 0.0, 0.0, 0.0, 0.0];
    for (slot, k) in stack[1..].iter_mut().zip(unweighted) {
        *slot = weight * k;
    }
    stack
}
/// Log-slope rescaled onto the observed probit scale (`logslope * probit_scale`).
#[inline]
pub(super) fn rigid_observed_logslope(logslope: f64, probit_scale: f64) -> f64 {
    logslope * probit_scale
}
/// Marginal-preserving scale `sqrt(1 + b²)`, where `b` is the observed
/// log-slope `probit_scale * logslope`.
#[inline]
pub(super) fn rigid_observed_scale(logslope: f64, probit_scale: f64) -> f64 {
    let b = rigid_observed_logslope(logslope, probit_scale);
    let b_sq = b * b;
    (1.0 + b_sq).sqrt()
}
/// Observed-scale intercept that reproduces `marginal_eta` after averaging
/// over `z`: the marginal value times `rigid_observed_scale`.
#[inline]
pub(super) fn rigid_intercept_from_marginal(
    marginal_eta: f64,
    logslope: f64,
    probit_scale: f64,
) -> f64 {
    let scale = rigid_observed_scale(logslope, probit_scale);
    marginal_eta * scale
}
/// `rigid_intercept_from_marginal`, divided back by `probit_scale`.
///
/// NOTE(review): `probit_scale == 0` yields a non-finite result; presumably
/// callers guarantee a non-zero scale.
#[inline]
pub(super) fn rigid_prescale_intercept_from_marginal(
    marginal_eta: f64,
    logslope: f64,
    probit_scale: f64,
) -> f64 {
    let observed_intercept = rigid_intercept_from_marginal(marginal_eta, logslope, probit_scale);
    observed_intercept / probit_scale
}
/// Returns `probit_scale * φ(marginal_eta) / rigid_observed_scale(logslope, probit_scale)`.
///
/// NOTE(review): despite the `_abs` suffix, no absolute value is taken. A
/// negative `probit_scale` gives a negative result. Confirm that callers only
/// pass positive scales.
#[inline]
pub(super) fn rigid_prescale_intercept_derivative_abs(
    marginal_eta: f64,
    logslope: f64,
    probit_scale: f64,
) -> f64 {
    let scale = rigid_observed_scale(logslope, probit_scale);
    let density = normal_pdf(marginal_eta);
    probit_scale * density / scale
}
#[inline]
pub(super) fn rigid_observed_eta(
marginal_eta: f64,
logslope: f64,
z: f64,
probit_scale: f64,
) -> f64 {
marginal_slope_standard_normal_scalar_eta(marginal_eta, logslope, z, probit_scale)
}
/// Scalar marginal-slope predictor `q * sqrt(1 + b²) + b z`, where
/// `b = probit_scale * slope`.
#[inline]
pub(super) fn marginal_slope_standard_normal_scalar_eta(
    q: f64,
    slope: f64,
    z: f64,
    probit_scale: f64,
) -> f64 {
    let b = rigid_observed_logslope(slope, probit_scale);
    let preserving_scale = (1.0 + b * b).sqrt();
    q * preserving_scale + b * z
}
/// `[Φ(x), Φ'(x), Φ''(x), Φ'''(x), Φ''''(x)]` for the standard normal CDF.
///
/// The derivatives are `φ(x)` times `1`, `-x`, `x² - 1` and `-x³ + 3x`.
pub(super) fn unary_derivatives_normal_cdf(x: f64) -> [f64; 5] {
    let density = normal_pdf(x);
    let x_sq = x * x;
    let cubic = -x.powi(3) + 3.0 * x;
    [
        normal_cdf(x),
        density,
        -x * density,
        (x_sq - 1.0) * density,
        cubic * density,
    ]
}
/// `[φ(x), φ'(x), φ''(x), φ'''(x), φ''''(x)]` for the standard normal density.
///
/// Each derivative is `φ(x)` times `(-1)^n He_n(x)`, where `He_n` is the
/// probabilists' Hermite polynomial.
pub(super) fn unary_derivatives_normal_pdf(x: f64) -> [f64; 5] {
    let density = normal_pdf(x);
    let signed_hermite = [
        1.0,
        -x,
        x * x - 1.0,
        -x.powi(3) + 3.0 * x,
        x.powi(4) - 6.0 * x * x + 3.0,
    ];
    signed_hermite.map(|h| h * density)
}
/// Add `exp(log_term)` to a running log-sum-exp accumulator.
///
/// The accumulated total is `exp(*log_max) * *sum`. Start with
/// `log_max = -∞` and `sum = 0`. When a new term exceeds the current maximum,
/// the sum is rescaled to the new maximum. Non-finite terms (NaN or ±∞) are
/// ignored.
#[inline]
pub(super) fn lse_accumulate(log_max: &mut f64, sum: &mut f64, log_term: f64) {
    if !log_term.is_finite() {
        return;
    }
    match log_term.partial_cmp(&*log_max) {
        Some(std::cmp::Ordering::Greater) => {
            *sum = if log_max.is_finite() {
                *sum * (*log_max - log_term).exp() + 1.0
            } else {
                // First finite term: the accumulator was empty.
                1.0
            };
            *log_max = log_term;
        }
        _ => *sum += (log_term - *log_max).exp(),
    }
}
/// Storage variant of a [`MarginalSlopeCovariance`], as reported by
/// `MarginalSlopeCovariance::shape`.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum MarginalSlopeCovarianceShape {
/// Per-coordinate variances only.
Diagonal,
/// Dense square matrix.
Full,
/// Factor `F` whose implied covariance is `F Fᵀ`.
LowRank,
}
/// Covariance of the latent score vector, in one of three representations.
#[derive(Clone, Debug, PartialEq)]
pub enum MarginalSlopeCovariance {
/// Per-coordinate variances, with zero covariance between coordinates.
/// `validate` requires a non-empty vector of finite, non-negative entries.
Diagonal(Array1<f64>),
/// Dense covariance matrix. `validate` requires it to be non-empty, square,
/// finite, and symmetric to within `1e-10 * (1 + max(|a_ij|, |a_ji|))`.
Full(Array2<f64>),
/// Factor `F` with one row per coordinate. `quadratic_form` evaluates
/// `vᵀ F Fᵀ v`, so the implied covariance is `F Fᵀ`.
LowRank(Array2<f64>),
}
/// Most negative quadratic-form value accepted by
/// `MarginalSlopeCovariance::quadratic_form` (absorbing round-off). Accepted
/// values below zero are clamped to `0.0`.
pub(crate) const COVARIANCE_QUADRATIC_FORM_PSD_TOL: f64 = -1e-10;
impl MarginalSlopeCovariance {
    /// Which storage variant this covariance uses.
    pub fn shape(&self) -> MarginalSlopeCovarianceShape {
        match self {
            Self::Diagonal(_) => MarginalSlopeCovarianceShape::Diagonal,
            Self::Full(_) => MarginalSlopeCovarianceShape::Full,
            Self::LowRank(_) => MarginalSlopeCovarianceShape::LowRank,
        }
    }

    /// Number of score coordinates the covariance describes.
    pub fn dim(&self) -> usize {
        match self {
            Self::Diagonal(variances) => variances.len(),
            Self::Full(matrix) => matrix.nrows(),
            Self::LowRank(factor) => factor.nrows(),
        }
    }

    /// Check structural validity and report the first offending entry.
    ///
    /// * `Diagonal`: non-empty, every entry finite and `>= 0`.
    /// * `Full`: non-empty and square. Entries are visited in row-major order,
    ///   and each must be finite and within
    ///   `1e-10 * (1 + max(|a_ij|, |a_ji|))` of its mirror entry.
    /// * `LowRank`: at least one row, every entry finite.
    ///
    /// # Errors
    /// A message prefixed with `context` describing the first violation.
    pub fn validate(&self, context: &str) -> Result<(), String> {
        match self {
            Self::Diagonal(diag) => {
                if diag.is_empty() {
                    return Err(format!("{context} diagonal covariance is empty"));
                }
                if let Some((idx, &value)) = diag
                    .iter()
                    .enumerate()
                    .find(|&(_, &v)| !(v.is_finite() && v >= 0.0))
                {
                    return Err(format!(
                        "{context} diagonal covariance entry {idx} must be finite and non-negative, got {value}"
                    ));
                }
                Ok(())
            }
            Self::Full(cov) => {
                let (rows, cols) = cov.dim();
                if rows == 0 || rows != cols {
                    return Err(format!(
                        "{context} full covariance must be non-empty and square, got {}x{}",
                        rows, cols
                    ));
                }
                for ((i, j), &value) in cov.indexed_iter() {
                    if !value.is_finite() {
                        return Err(format!(
                            "{context} full covariance entry ({i},{j}) is non-finite"
                        ));
                    }
                    let mirror = cov[[j, i]];
                    if (value - mirror).abs() > 1e-10 * (1.0 + value.abs().max(mirror.abs())) {
                        return Err(format!(
                            "{context} full covariance must be symmetric at ({i},{j})"
                        ));
                    }
                }
                Ok(())
            }
            Self::LowRank(factor) => {
                if factor.nrows() == 0 {
                    return Err(format!(
                        "{context} low-rank covariance factor has zero rows"
                    ));
                }
                match factor.indexed_iter().find(|(_, value)| !value.is_finite()) {
                    Some(((i, j), _)) => Err(format!(
                        "{context} low-rank covariance factor entry ({i},{j}) is non-finite"
                    )),
                    None => Ok(()),
                }
            }
        }
    }

    /// Evaluate `vᵀ Σ v` for the represented covariance `Σ`.
    ///
    /// The covariance is validated first. Results in
    /// `[COVARIANCE_QUADRATIC_FORM_PSD_TOL, 0)` are clamped to `0.0`.
    ///
    /// # Errors
    /// Returns an error if the covariance is invalid, if `vector` has the
    /// wrong length or a non-finite entry, or if the result is non-finite or
    /// below the PSD tolerance.
    pub fn quadratic_form(&self, vector: &[f64]) -> Result<f64, String> {
        self.validate("marginal-slope covariance")?;
        let dim = self.dim();
        if vector.len() != dim {
            return Err(format!(
                "marginal-slope covariance dimension mismatch: vector={}, covariance={}",
                vector.len(),
                dim
            ));
        }
        if !vector.iter().all(|v| v.is_finite()) {
            return Err("marginal-slope covariance vector contains non-finite values".to_string());
        }
        let value = match self {
            Self::Diagonal(diag) => vector
                .iter()
                .zip(diag.iter())
                .map(|(&v, &sigma)| v * v * sigma)
                .sum::<f64>(),
            // Sequential row-by-row accumulation (no reassociation).
            Self::Full(cov) => cov.outer_iter().zip(vector).fold(0.0, |total, (row, &v_i)| {
                let row_dot = row
                    .iter()
                    .zip(vector)
                    .fold(0.0, |acc, (&c_ij, &v_j)| acc + c_ij * v_j);
                total + v_i * row_dot
            }),
            // Sum of squared projections onto each factor column: vᵀ F Fᵀ v.
            Self::LowRank(factor) => factor.t().outer_iter().fold(0.0, |total, column| {
                let projection = column
                    .iter()
                    .zip(vector)
                    .fold(0.0, |acc, (&f_k, &v_k)| acc + f_k * v_k);
                total + projection * projection
            }),
        };
        if !(value.is_finite() && value >= COVARIANCE_QUADRATIC_FORM_PSD_TOL) {
            return Err(format!(
                "marginal-slope covariance quadratic form must be non-negative, got {value}"
            ));
        }
        Ok(value.max(0.0))
    }
}
/// Estimate the weighted covariance of the score columns and return it in the
/// most compact representation the data supports.
///
/// The covariance uses the weighted mean and divides by the total weight
/// (population form). With `k` columns, the representation is chosen as
/// follows:
/// * `k == 1`: `Diagonal`.
/// * Every off-diagonal is within `numerical_floor = 1e-10 * (1 + max |diag|)`
///   of zero: `Diagonal`.
/// * Otherwise the symmetric eigendecomposition is taken. If any eigenvalue
///   is at or below `1e-10 * max(max |eigenvalue|, 1)`, the result is
///   `LowRank` with one factor column `sqrt(λ) v` per retained eigenpair.
/// * For a full-rank matrix, every off-diagonal is compared against
///   `max(numerical_floor, 4 * sqrt(var_a var_b / n_eff))`, using the Kish
///   effective size `n_eff = (Σw)² / Σw²`.
///   * All within the threshold: `Diagonal` (the off-diagonals are dropped).
///   * Otherwise: `Full`.
///
/// # Errors
/// Returns an error in any of these cases:
/// * zero columns;
/// * a weight length mismatch;
/// * a non-positive or non-finite total weight;
/// * a negative or non-finite weight;
/// * a non-finite score;
/// * a failed eigendecomposition.
pub fn marginal_slope_covariance_from_scores(
scores: ArrayView2<'_, f64>,
weights: &Array1<f64>,
) -> Result<MarginalSlopeCovariance, String> {
let (n, k) = scores.dim();
if k == 0 {
return Err("marginal-slope score matrix must have at least one column".to_string());
}
if weights.len() != n {
return Err(format!(
"marginal-slope covariance weight length mismatch: weights={}, rows={n}",
weights.len()
));
}
let total_weight = weights.iter().copied().sum::<f64>();
if !(total_weight.is_finite() && total_weight > 0.0) {
return Err("marginal-slope covariance needs positive finite total weight".to_string());
}
// Weighted column means; per-entry validation happens in the same pass.
let mut mean = Array1::<f64>::zeros(k);
for i in 0..n {
let weight = weights[i];
if !(weight.is_finite() && weight >= 0.0) {
return Err(format!(
"marginal-slope covariance weight {i} must be finite and non-negative, got {weight}"
));
}
for j in 0..k {
let score = scores[[i, j]];
if !score.is_finite() {
return Err(format!(
"marginal-slope covariance score ({i},{j}) is non-finite"
));
}
mean[j] += weight * score;
}
}
mean.mapv_inplace(|value| value / total_weight);
// Lower triangle is accumulated and mirrored to keep the matrix exactly symmetric.
let mut cov = Array2::<f64>::zeros((k, k));
for i in 0..n {
let weight = weights[i];
for a in 0..k {
let da = scores[[i, a]] - mean[a];
for b in 0..=a {
let value = weight * da * (scores[[i, b]] - mean[b]) / total_weight;
cov[[a, b]] += value;
if a != b {
cov[[b, a]] += value;
}
}
}
}
if k == 1 {
return Ok(MarginalSlopeCovariance::Diagonal(cov.diag().to_owned()));
}
let diag: Vec<f64> = (0..k).map(|i| cov[[i, i]]).collect();
let diag_max = diag.iter().fold(0.0_f64, |acc, &v| acc.max(v.abs()));
// Off-diagonals below this are treated as numerically zero.
let numerical_floor = 1e-10 * (1.0 + diag_max);
let mut is_strict_diagonal = true;
'strict: for a in 0..k {
for b in (a + 1)..k {
if cov[[a, b]].abs() > numerical_floor {
is_strict_diagonal = false;
break 'strict;
}
}
}
if is_strict_diagonal {
return Ok(MarginalSlopeCovariance::Diagonal(cov.diag().to_owned()));
}
use crate::faer_ndarray::FaerEigh;
let (evals, evecs) = cov
.eigh(faer::Side::Lower)
.map_err(|err| format!("marginal-slope covariance eigendecomposition failed: {err}"))?;
let max_eval = evals
.iter()
.fold(0.0_f64, |acc, &value| acc.max(value.abs()));
// Rank tolerance has an absolute floor of 1e-10 (max_eval clamped to >= 1).
let rank_tol = 1e-10 * max_eval.max(1.0);
let positive: Vec<(usize, f64)> = evals
.iter()
.enumerate()
.filter_map(|(idx, &value)| (value > rank_tol).then_some((idx, value)))
.collect();
// Rank-deficient: keep only the retained eigenpairs as a factor F, Σ ≈ F Fᵀ.
if positive.len() < k {
let mut factor = Array2::<f64>::zeros((k, positive.len()));
for (col, (idx, value)) in positive.iter().enumerate() {
let scale = value.sqrt();
for row in 0..k {
factor[[row, col]] = evecs[[row, *idx]] * scale;
}
}
return Ok(MarginalSlopeCovariance::LowRank(factor));
}
// Kish effective sample size for the statistical off-diagonal test.
let sum_w_sq = weights.iter().map(|&w| w * w).sum::<f64>();
let n_eff = if sum_w_sq > 0.0 {
(total_weight * total_weight) / sum_w_sq
} else {
1.0
};
// An off-diagonal counts as significant when it exceeds 4 approximate
// standard errors (sqrt(var_a var_b / n_eff)) or the numerical floor.
const OFFDIAG_Z_THRESHOLD: f64 = 4.0;
let mut is_stat_diagonal = true;
'stat: for a in 0..k {
for b in (a + 1)..k {
let stat_se = (diag[a].max(0.0) * diag[b].max(0.0) / n_eff)
.max(0.0)
.sqrt();
let threshold = numerical_floor.max(OFFDIAG_Z_THRESHOLD * stat_se);
if cov[[a, b]].abs() > threshold {
is_stat_diagonal = false;
break 'stat;
}
}
}
if is_stat_diagonal {
Ok(MarginalSlopeCovariance::Diagonal(cov.diag().to_owned()))
} else {
Ok(MarginalSlopeCovariance::Full(cov))
}
}
pub fn marginal_slope_preserving_scale(
slopes: &[f64],
covariance: &MarginalSlopeCovariance,
probit_scale: f64,
) -> Result<f64, String> {
if !probit_scale.is_finite() {
return Err(format!(
"marginal-slope probit scale must be finite, got {probit_scale}"
));
}
let observed_slopes = slopes
.iter()
.map(|&slope| probit_scale * slope)
.collect::<Vec<_>>();
let variance = covariance.quadratic_form(&observed_slopes)?;
Ok((1.0 + variance).sqrt())
}
/// Multi-score marginal-slope predictor `q * sqrt(1 + bᵀ Σ b) + bᵀ z`.
///
/// Here `b = probit_scale * slopes`.
///
/// # Errors
/// Returns an error in any of these cases:
/// * `z` and `slopes` differ in length;
/// * `slopes` does not match the covariance dimension;
/// * `q` or any score is non-finite;
/// * `marginal_slope_preserving_scale` fails.
pub fn marginal_slope_probit_eta(
    q: f64,
    z: &[f64],
    slopes: &[f64],
    covariance: &MarginalSlopeCovariance,
    probit_scale: f64,
) -> Result<f64, String> {
    let slope_count = slopes.len();
    if z.len() != slope_count {
        return Err(format!(
            "marginal-slope score/slope dimension mismatch: z={}, slopes={}",
            z.len(),
            slopes.len()
        ));
    }
    if covariance.dim() != slope_count {
        return Err(format!(
            "marginal-slope covariance dimension mismatch: slopes={}, covariance={}",
            slopes.len(),
            covariance.dim()
        ));
    }
    let inputs_finite = q.is_finite() && z.iter().all(|score| score.is_finite());
    if !inputs_finite {
        return Err("marginal-slope probit eta inputs must be finite".to_string());
    }
    let scale = marginal_slope_preserving_scale(slopes, covariance, probit_scale)?;
    let linear: f64 = z
        .iter()
        .zip(slopes)
        .map(|(&score, &slope)| probit_scale * slope * score)
        .sum();
    Ok(q * scale + linear)
}
/// Evaluate the log-scale calibration residual and its first two
/// derivatives with respect to `intercept`, for an empirical latent measure
/// given by `(nodes, weights)`.
///
/// Definitions:
/// * `b = probit_scale * slope` and `η_k = intercept + b * node_k`;
/// * sums run over nodes with finite, strictly positive weight:
///   `S_Φ = Σ w_k Φ(η_k)`, `S_φ = Σ w_k φ(η_k)`, `S_ηφ = Σ w_k η_k φ(η_k)`.
///
/// The function computes:
/// * `f = ln S_Φ - log_target_mu`;
/// * `f' = S_φ / S_Φ`, replaced by `f64::MIN_POSITIVE` when `ln f' <= -740`;
/// * `f'' = -S_ηφ / S_Φ - f'²`, using `φ'(η) = -η φ(η)`.
///
/// All sums are accumulated in log space with `lse_accumulate`. `S_ηφ` is
/// split into positive-η and negative-η parts so that each part is the log of
/// a positive quantity.
///
/// NOTE(review): `nodes` and `weights` are paired with `zip`, so extra
/// entries in the longer slice are silently ignored.
///
/// Returns `(f, f', f'')`.
///
/// # Errors
/// Returns an error in any of these cases:
/// * a non-finite intercept or η;
/// * a failed log-space accumulation, for example when no node has a
///   positive finite weight;
/// * a non-finite final state.
pub(super) fn empirical_rigid_calibration_eval(
intercept: f64,
log_target_mu: f64,
slope: f64,
probit_scale: f64,
nodes: &[f64],
weights: &[f64],
) -> Result<(f64, f64, f64), String> {
if !intercept.is_finite() {
return Err(format!(
"empirical latent calibration: non-finite intercept {intercept}"
));
}
let observed_slope = rigid_observed_logslope(slope, probit_scale);
// ln(sqrt(2π)), for the log standard-normal density.
const HALF_LOG_2PI: f64 = 0.918_938_533_204_672_8;
let mut log_max_phi = f64::NEG_INFINITY;
let mut sum_phi = 0.0_f64;
let mut log_max_cdf = f64::NEG_INFINITY;
let mut sum_cdf = 0.0_f64;
let mut log_max_pos = f64::NEG_INFINITY;
let mut sum_pos = 0.0_f64;
let mut log_max_neg = f64::NEG_INFINITY;
let mut sum_neg = 0.0_f64;
for (&node, &weight) in nodes.iter().zip(weights.iter()) {
if !(weight.is_finite() && weight > 0.0) {
continue;
}
let eta = intercept + observed_slope * node;
if !eta.is_finite() {
return Err(format!(
"empirical latent calibration: non-finite η at intercept={intercept}, slope={slope}, node={node}"
));
}
let log_w = weight.ln();
let log_phi = -0.5 * eta * eta - HALF_LOG_2PI;
let log_term_phi = log_w + log_phi;
let log_term_cdf = log_w + normal_logcdf(eta);
lse_accumulate(&mut log_max_phi, &mut sum_phi, log_term_phi);
lse_accumulate(&mut log_max_cdf, &mut sum_cdf, log_term_cdf);
// η·φ(η) changes sign with η, so accumulate |η|·φ(η) into separate
// positive and negative log-space sums.
if eta != 0.0 {
let log_term_eta_phi = log_term_phi + eta.abs().ln();
if eta > 0.0 {
lse_accumulate(&mut log_max_pos, &mut sum_pos, log_term_eta_phi);
} else {
lse_accumulate(&mut log_max_neg, &mut sum_neg, log_term_eta_phi);
}
}
}
if !(sum_phi.is_finite() && sum_cdf.is_finite() && sum_phi > 0.0 && sum_cdf > 0.0) {
return Err(format!(
"empirical latent calibration: log-space accumulation failed (sum_phi={sum_phi}, sum_cdf={sum_cdf}, intercept={intercept})"
));
}
let log_s_phi = log_max_phi + sum_phi.ln();
let log_s_cdf = log_max_cdf + sum_cdf.ln();
let f = log_s_cdf - log_target_mu;
let log_f_prime = log_s_phi - log_s_cdf;
// Keep f' strictly positive (the residual is increasing in the intercept)
// even when exp would underflow.
let f_prime = if log_f_prime > -740.0 {
log_f_prime.exp()
} else {
f64::MIN_POSITIVE
};
let exp_safe = |log_x: f64| -> f64 { if log_x > -740.0 { log_x.exp() } else { 0.0 } };
let pos_over_cdf = if sum_pos > 0.0 {
exp_safe(log_max_pos + sum_pos.ln() - log_s_cdf)
} else {
0.0
};
let neg_over_cdf = if sum_neg > 0.0 {
exp_safe(log_max_neg + sum_neg.ln() - log_s_cdf)
} else {
0.0
};
let s_etaphi_over_s_cdf = pos_over_cdf - neg_over_cdf;
let f_double_prime = -s_etaphi_over_s_cdf - f_prime * f_prime;
if !(f.is_finite() && f_prime.is_finite() && f_prime > 0.0 && f_double_prime.is_finite()) {
return Err(format!(
"empirical latent calibration: non-finite log-space state f={f}, f'={f_prime}, f''={f_double_prime} at intercept={intercept}"
));
}
Ok((f, f_prime, f_double_prime))
}
/// Solve for the observed-scale intercept whose empirical latent average
/// `Σ w_k Φ(a + b node_k)` equals `target_mu`.
///
/// The solve works on log scale with `solve_monotone_root`.
/// * It starts from `initial` when given, otherwise from the closed-form
///   standard-normal intercept `rigid_intercept_from_marginal(target_q, ..)`.
/// * If a caller-supplied seed fails, the solve is retried once from the
///   closed-form seed.
///
/// # Errors
/// Returns an error in any of these cases:
/// * `target_mu` is not in `(0, 1)`;
/// * every attempted solve fails (the messages are chained);
/// * the best log-residual exceeds the absolute tolerance.
pub(crate) fn empirical_intercept_from_marginal(
    target_mu: f64,
    target_q: f64,
    slope: f64,
    probit_scale: f64,
    nodes: &[f64],
    weights: &[f64],
    initial: Option<f64>,
) -> Result<f64, String> {
    if !(target_mu.is_finite() && target_mu > 0.0 && target_mu < 1.0) {
        return Err(format!(
            "empirical latent calibration requires target mu in (0,1), got {target_mu}"
        ));
    }
    let log_target_mu = target_mu.ln();
    let closed_form_seed = rigid_intercept_from_marginal(target_q, slope, probit_scale);
    let seed = initial.unwrap_or(closed_form_seed);
    let residual = |a: f64| {
        empirical_rigid_calibration_eval(a, log_target_mu, slope, probit_scale, nodes, weights)
    };
    let abs_tol = 1e-13_f64.max(4.0 * f64::EPSILON);
    let solve_from = |start: f64| {
        crate::families::monotone_root::solve_monotone_root(
            residual,
            start,
            "empirical latent intercept",
            abs_tol,
            64,
            48,
        )
        .map_err(|e| e.to_string())
    };
    let (root, _, f_best) = solve_from(seed).or_else(|first_err| {
        // Retrying from the same point would only repeat the failure.
        if seed == closed_form_seed {
            return Err(first_err);
        }
        solve_from(closed_form_seed).map_err(|retry_err| {
            format!("{first_err}; closed-form retry from a={closed_form_seed:.6}: {retry_err}")
        })
    })?;
    if f_best.abs() > abs_tol {
        return Err(format!(
            "empirical latent intercept solve failed: log-residual={f_best:.3e} at a={root:.6}, target mu={target_mu:.6}"
        ));
    }
    Ok(root)
}
/// Weighted negative log-likelihood `-w ln Φ((2y - 1) η)` of one row.
///
/// Here `η = marginal_slope_standard_normal_scalar_eta(q, g, z, probit_scale)`.
///
/// # Errors
/// Returns an error when `ln Φ` of the signed margin is not finite.
#[inline]
pub(super) fn rigid_standard_normal_neglog_only(
    q: f64,
    g: f64,
    z: f64,
    y: f64,
    w: f64,
    probit_scale: f64,
) -> Result<f64, String> {
    let eta = marginal_slope_standard_normal_scalar_eta(q, g, z, probit_scale);
    let sign = 2.0 * y - 1.0;
    let (logcdf, _) = signed_probit_logcdf_and_mills_ratio(sign * eta);
    if logcdf.is_finite() {
        Ok(-w * logcdf)
    } else {
        Err(format!(
            "rigid probit neglog_only: non-finite log Φ at q={q}, g={g}, z={z}, y={y}"
        ))
    }
}
/// Row negative log-likelihood as a jet scalar in the primaries
/// `(marginal eta, slope)`.
///
/// The signed-margin jet is composed with
/// `signed_probit_neglog_unary_stack`.
///
/// # Errors
/// Returns an error when the signed margin is NaN or `-∞`. A `+∞` margin is
/// allowed and contributes zero.
#[inline]
pub(crate) fn rigid_standard_normal_row_nll_generic<S: gam_math::jet_scalar::JetScalar<2>>(
    p: &[S; 2],
    marginal: BernoulliMarginalLinkMap,
    z: f64,
    y: f64,
    w: f64,
    probit_scale: f64,
) -> Result<S, String> {
    let signed = rigid_standard_normal_signed_margin(p, marginal, z, y, probit_scale);
    match signed.value() {
        m if m.is_nan() || m == f64::NEG_INFINITY => Err(format!(
            "non-finite signed margin in rigid probit row NLL: {m}"
        )),
        m => Ok(signed.compose_unary(signed_probit_neglog_unary_stack(m, w))),
    }
}
/// Signed probit margin `(2y - 1) * (q * sqrt(1 + b²) + b z)` as a jet scalar.
///
/// The inputs are:
/// * `q` is `p[0]` composed with the marginal link map's value and
///   derivatives `[q, q1, q2, q3, q4]`;
/// * `b = probit_scale * p[1]`.
#[inline]
pub(crate) fn rigid_standard_normal_signed_margin<S: gam_math::jet_scalar::JetScalar<2>>(
    p: &[S; 2],
    marginal: BernoulliMarginalLinkMap,
    z: f64,
    y: f64,
    probit_scale: f64,
) -> S {
    let response_sign = 2.0 * y - 1.0;
    let marginal_q = p[0].compose_unary([
        marginal.q,
        marginal.q1,
        marginal.q2,
        marginal.q3,
        marginal.q4,
    ]);
    let observed_slope = p[1].scale(probit_scale);
    let preserving_scale = observed_slope
        .mul(&observed_slope)
        .add(&S::constant(1.0))
        .sqrt();
    let slope_term = observed_slope.scale(z);
    marginal_q
        .mul(&preserving_scale)
        .add(&slope_term)
        .scale(response_sign)
}
/// One observation of the rigid standard-normal probit model, exposed to the
/// jet-tower machinery as a single-row `RowNllProgramGeneric<2>` program.
pub(crate) struct RigidStandardNormalRow {
/// Marginal link map: supplies the first primary's value (`eta_value`) and
/// the `q` derivatives composed onto it.
pub(crate) marginal: BernoulliMarginalLinkMap,
/// Value of the second primary (the slope, before `probit_scale`).
pub(crate) g: f64,
/// Latent score of the observation.
pub(crate) z: f64,
/// Response; it enters through the sign `2y - 1`.
pub(crate) y: f64,
/// Observation weight multiplying the negative log-likelihood.
pub(crate) w: f64,
/// Scale applied to the slope before forming the observed predictor.
pub(crate) probit_scale: f64,
}
impl gam_math::jet_tower::RowNllProgramGeneric<2> for RigidStandardNormalRow {
    /// Always one row: each instance describes exactly one observation.
    fn n_rows(&self) -> usize {
        1
    }

    /// Primary values `[marginal eta, slope]` for row 0. Any other row is an error.
    fn primaries(&self, row: usize) -> Result<[f64; 2], String> {
        match row {
            0 => Ok([self.marginal.eta_value(), self.g]),
            _ => Err(format!("RigidStandardNormalRow: row {row} out of range")),
        }
    }

    /// Row negative log-likelihood in the jet scalar `S` for row 0.
    fn row_nll_generic<S: gam_math::jet_scalar::JetScalar<2>>(
        &self,
        row: usize,
        p: &[S; 2],
    ) -> Result<S, String> {
        match row {
            0 => rigid_standard_normal_row_nll_generic(
                p,
                self.marginal,
                self.z,
                self.y,
                self.w,
                self.probit_scale,
            ),
            _ => Err(format!("RigidStandardNormalRow: row {row} out of range")),
        }
    }
}
/// Full fourth-order jet tower of one rigid probit row NLL in the primaries
/// `[η, g]`.
///
/// The tower carries the value, gradient, Hessian, third tensor and fourth
/// tensor.
#[inline]
pub(crate) fn rigid_standard_normal_tower(
    marginal: BernoulliMarginalLinkMap,
    g: f64,
    z: f64,
    y: f64,
    w: f64,
    probit_scale: f64,
) -> Result<Tower4<2>, String> {
    use gam_math::jet_tower::generic_full_tower;
    generic_full_tower(
        &RigidStandardNormalRow {
            marginal,
            g,
            z,
            y,
            w,
            probit_scale,
        },
        0,
    )
}
/// Signed margin of one row as a `Tower4<2>`.
///
/// The tower is seeded with `marginal.eta_value()` as variable 0 and `g` as
/// variable 1.
#[inline]
fn rigid_standard_normal_signed_jet(
    marginal: BernoulliMarginalLinkMap,
    g: f64,
    z: f64,
    y: f64,
    probit_scale: f64,
) -> Tower4<2> {
    let eta_seed = Tower4::<2>::variable(marginal.eta_value(), 0);
    let slope_seed = Tower4::<2>::variable(g, 1);
    rigid_standard_normal_signed_margin(&[eta_seed, slope_seed], marginal, z, y, probit_scale)
}
/// Batched rigid probit NLL towers.
///
/// For every row `i`, this builds the signed-margin jet and composes it with
/// the weighted `-log Φ` stack. It then stores `fill(&tower)` into `out[i]`.
///
/// Every margin is validated before `fill` runs for any row, so `out` is
/// left untouched when a margin is rejected. A margin of `+∞` is allowed;
/// `NaN` and `-∞` are rejected.
///
/// # Errors
/// Returns an error on any slice-length mismatch or an invalid margin. Also
/// returns the first error produced by `fill`.
#[inline]
pub(super) fn rigid_standard_normal_towers_batch<T>(
    marginals: &[BernoulliMarginalLinkMap],
    slopes: &[f64],
    zs: &[f64],
    ys: &[f64],
    weights: &[f64],
    probit_scale: f64,
    out: &mut [T],
    mut fill: impl FnMut(&Tower4<2>) -> Result<T, String>,
) -> Result<(), String> {
    let chunk = marginals.len();
    let lengths_agree = [slopes.len(), zs.len(), ys.len(), weights.len(), out.len()]
        .iter()
        .all(|&len| len == chunk);
    if !lengths_agree {
        return Err(format!(
            "rigid_standard_normal_towers_batch length mismatch: marginals={chunk}, \
             slopes={}, zs={}, ys={}, weights={}, out={}",
            slopes.len(),
            zs.len(),
            ys.len(),
            weights.len(),
            out.len()
        ));
    }

    // Phase 1: build and validate every (jet, neglog stack) pair.
    let mut prepared: Vec<(Tower4<2>, [f64; 5])> = Vec::with_capacity(chunk);
    for i in 0..chunk {
        let jet =
            rigid_standard_normal_signed_jet(marginals[i], slopes[i], zs[i], ys[i], probit_scale);
        let m = jet.v;
        if m.is_nan() || m == f64::NEG_INFINITY {
            return Err(format!(
                "non-finite signed margin in rigid probit tower batch: {m}"
            ));
        }
        prepared.push((jet, signed_probit_neglog_unary_stack(m, weights[i])));
    }

    // Phase 2: hand composed towers to `fill` only after all rows passed.
    for (slot, (jet, stack)) in out.iter_mut().zip(prepared.iter()) {
        let tower = jet.compose_unary(*stack);
        *slot = fill(&tower)?;
    }
    Ok(())
}
/// Value, gradient and Hessian of one rigid probit row NLL in the primaries
/// `[η, g]`, computed by the generic second-order row kernel.
#[inline]
pub(super) fn rigid_standard_normal_row_kernel(
    marginal: BernoulliMarginalLinkMap,
    g: f64,
    z: f64,
    y: f64,
    w: f64,
    probit_scale: f64,
) -> Result<(f64, [f64; 2], [[f64; 2]; 2]), String> {
    use gam_math::jet_tower::generic_row_kernel;
    let single_row = RigidStandardNormalRow {
        marginal,
        g,
        z,
        y,
        w,
        probit_scale,
    };
    generic_row_kernel(&single_row, 0)
}
/// Mixed second derivatives of one rigid probit row's log-likelihood with
/// respect to `z`.
///
/// Returns `[∂²(log L)/∂η∂z, ∂²(log L)/∂g∂z]`.
///
/// A `Tower2<3>` over `(η, g, z)` is built. The marginal quantile carries its
/// first and second η-derivatives from `marginal`, while `g` and `z` are
/// seeded as variables 1 and 2. The tower of the weighted `-log Φ` NLL is
/// negated to obtain log-likelihood sensitivities.
///
/// # Errors
/// Returns an error for a `NaN`/`-∞` signed margin, a non-finite NLL value, or
/// non-finite mixed derivatives.
#[inline]
pub(super) fn rigid_standard_normal_mixed_z_sensitivity(
    marginal: BernoulliMarginalLinkMap,
    g: f64,
    z: f64,
    y: f64,
    w: f64,
    probit_scale: f64,
) -> Result<[f64; 2], String> {
    use gam_math::jet_tower::Tower2;

    // η enters only through q(η), so its derivative slots are seeded by hand.
    let mut q_jet = Tower2::<3>::constant(marginal.q);
    q_jet.g[0] = marginal.q1;
    q_jet.h[0][0] = marginal.q2;
    let slope_jet = Tower2::<3>::variable(g, 1);
    let z_jet = Tower2::<3>::variable(z, 2);

    let scaled_slope = slope_jet * probit_scale;
    let root = (scaled_slope * scaled_slope + 1.0).sqrt();
    let eta = q_jet * root + slope_jet * (z_jet * probit_scale);
    let margin = eta * (2.0 * y - 1.0);

    let m = margin.v;
    if m.is_nan() || m == f64::NEG_INFINITY {
        return Err(format!(
            "rigid probit mixed-z sensitivity: non-finite signed margin {} at q={}, g={g}, z={z}, y={y}",
            m, marginal.q
        ));
    }
    let stack = signed_probit_neglog_unary_stack(m, w);
    if !stack[0].is_finite() {
        return Err(format!(
            "rigid probit mixed-z sensitivity: non-finite log Φ at q={}, g={g}, z={z}, y={y}",
            marginal.q
        ));
    }

    // Only a second-order tower is needed, so the first three stack entries suffice.
    let nll = margin.compose_unary([stack[0], stack[1], stack[2]]);
    let sensitivity = [-nll.h[0][2], -nll.h[1][2]];
    if sensitivity.iter().all(|v| v.is_finite()) {
        Ok(sensitivity)
    } else {
        let [s_q, s_g] = sensitivity;
        Err(format!(
            "rigid probit mixed-z sensitivity: non-finite ∂²(log L)/∂(q,g)∂z = [{s_q}, {s_g}] at q={}, g={g}, z={z}",
            marginal.q
        ))
    }
}
/// Builds the `n × p_beta` matrix of mixed z-sensitivities of the rigid probit
/// score.
///
/// For observation `i`, `rigid_standard_normal_mixed_z_sensitivity` yields
/// `[s_q, s_g]`. Row `i` then holds `s_q · marginal_design[i, ·]` in columns
/// `0..p_m` and `s_g · logslope_design[i, ·]` in columns `p_m..p_m + r`. All
/// remaining columns up to `p_beta` stay zero.
///
/// # Errors
/// Returns an error on a row-count mismatch or when `p_m + r > p_beta`. Also
/// propagates failures from the link map or the per-row sensitivity.
pub(super) fn rigid_standard_normal_score_zeta_sensitivity(
    base_link: &InverseLink,
    marginal_eta: &Array1<f64>,
    slope_eta: &Array1<f64>,
    z: &Array1<f64>,
    y: &Array1<f64>,
    weights: &Array1<f64>,
    probit_scale: f64,
    marginal_design: ArrayView2<'_, f64>,
    logslope_design: ArrayView2<'_, f64>,
    p_beta: usize,
) -> Result<Array2<f64>, String> {
    let n = marginal_eta.len();
    let p_m = marginal_design.ncols();
    let r = logslope_design.ncols();
    let rows_agree = slope_eta.len() == n
        && z.len() == n
        && y.len() == n
        && weights.len() == n
        && marginal_design.nrows() == n
        && logslope_design.nrows() == n;
    if !rows_agree {
        return Err(format!(
            "score_zeta_sensitivity row mismatch: marginal_eta={n}, slope_eta={}, z={}, y={}, \
             weights={}, marginal_design rows={}, logslope_design rows={}",
            slope_eta.len(),
            z.len(),
            y.len(),
            weights.len(),
            marginal_design.nrows(),
            logslope_design.nrows()
        ));
    }
    if p_m + r > p_beta {
        return Err(format!(
            "score_zeta_sensitivity width overflow: marginal({p_m}) + logslope({r}) > p_beta({p_beta})"
        ));
    }

    let mut sensitivity = Array2::<f64>::zeros((n, p_beta));
    for obs in 0..n {
        let link = bernoulli_marginal_link_map(base_link, marginal_eta[obs])?;
        let [s_q, s_g] = rigid_standard_normal_mixed_z_sensitivity(
            link,
            slope_eta[obs],
            z[obs],
            y[obs],
            weights[obs],
            probit_scale,
        )?;
        // A zero factor leaves the pre-zeroed entries as they are.
        if s_q != 0.0 {
            for (col, &x) in marginal_design.row(obs).iter().enumerate() {
                sensitivity[[obs, col]] = s_q * x;
            }
        }
        if s_g != 0.0 {
            for (col, &x) in logslope_design.row(obs).iter().enumerate() {
                sensitivity[[obs, p_m + col]] = s_g * x;
            }
        }
    }
    Ok(sensitivity)
}
/// Full `2 × 2 × 2` third-derivative tensor, in `[η, g]`, of one rigid probit
/// row NLL.
///
/// The tensor is taken from `rigid_standard_normal_tower`.
#[inline]
pub(super) fn rigid_standard_normal_third_full(
    marginal: BernoulliMarginalLinkMap,
    g: f64,
    z: f64,
    y: f64,
    w: f64,
    probit_scale: f64,
) -> Result<[[[f64; 2]; 2]; 2], String> {
    rigid_standard_normal_tower(marginal, g, z, y, w, probit_scale).map(|tower| tower.t3)
}
/// Contracts the last index of a `2 × 2 × 2` third-derivative tensor with the
/// direction `(d_eta, d_g)`.
///
/// Computes `out[a][b] = t[a][b][0] · d_eta + t[a][b][1] · d_g`.
#[inline]
pub(super) fn contract_third_full(t: &[[[f64; 2]; 2]; 2], d_eta: f64, d_g: f64) -> [[f64; 2]; 2] {
    std::array::from_fn(|a| {
        std::array::from_fn(|b| {
            let fiber = &t[a][b];
            fiber[0] * d_eta + fiber[1] * d_g
        })
    })
}
/// Full `2 × 2 × 2 × 2` fourth-derivative tensor, in `[η, g]`, of one rigid
/// probit row NLL.
///
/// The tensor is taken from `rigid_standard_normal_tower`.
#[inline]
pub(super) fn rigid_standard_normal_fourth_full(
    marginal: BernoulliMarginalLinkMap,
    g: f64,
    z: f64,
    y: f64,
    w: f64,
    probit_scale: f64,
) -> Result<[[[[f64; 2]; 2]; 2]; 2], String> {
    rigid_standard_normal_tower(marginal, g, z, y, w, probit_scale).map(|tower| tower.t4)
}
/// Contracts the last two indices of a `2⁴` fourth-derivative tensor with the
/// directions `u = (u_eta, u_g)` and `v = (v_eta, v_g)`.
///
/// Computes `out[a][b] = Σ_{c,d} t[a][b][c][d] · u[c] · v[d]`. Terms are
/// accumulated from `0.0` in `(c, d)` order `(0,0), (0,1), (1,0), (1,1)`.
#[inline]
pub(super) fn contract_fourth_full(
    t: &[[[[f64; 2]; 2]; 2]; 2],
    u_eta: f64,
    u_g: f64,
    v_eta: f64,
    v_g: f64,
) -> [[f64; 2]; 2] {
    let u = [u_eta, u_g];
    let v = [v_eta, v_g];
    std::array::from_fn(|a| {
        std::array::from_fn(|b| {
            (0..2)
                .flat_map(|c| (0..2).map(move |d| (c, d)))
                .fold(0.0, |acc, (c, d)| acc + t[a][b][c][d] * u[c] * v[d])
        })
    })
}
/// Verifies that every entry of a warmed third-derivative cache row is finite.
///
/// # Errors
/// Returns an error prefixed with `context` if any entry is `NaN` or infinite.
pub(super) fn ensure_finite_third_full_cache_row(
    t: &[[[f64; 2]; 2]; 2],
    context: &str,
) -> Result<(), String> {
    let has_non_finite = t
        .iter()
        .flat_map(|slab| slab.iter())
        .flat_map(|fiber| fiber.iter())
        .any(|entry| !entry.is_finite());
    if has_non_finite {
        return Err(format!(
            "{context}: warmed third-derivative cache row contains a non-finite value"
        ));
    }
    Ok(())
}
/// Verifies that every entry of a warmed fourth-derivative cache row is finite.
///
/// # Errors
/// Returns an error prefixed with `context` if any entry is `NaN` or infinite.
pub(super) fn ensure_finite_fourth_full_cache_row(
    t: &[[[[f64; 2]; 2]; 2]; 2],
    context: &str,
) -> Result<(), String> {
    let has_non_finite = t
        .iter()
        .flat_map(|block| block.iter())
        .flat_map(|slab| slab.iter())
        .flat_map(|fiber| fiber.iter())
        .any(|entry| !entry.is_finite());
    if has_non_finite {
        return Err(format!(
            "{context}: warmed fourth-derivative cache row contains a non-finite value"
        ));
    }
    Ok(())
}
/// Value and first four derivatives of `√x`, as a unary stack for jet
/// composition.
///
/// Returns `[√x, ½·x^(-1/2), -¼·x^(-3/2), ⅜·x^(-5/2), -15/16·x^(-7/2)]`.
///
/// Arguments below `1e-300`, including non-positive ones, are clamped to
/// `1e-300` before evaluation. A `NaN` argument is propagated, so every
/// entry of the result is `NaN` and downstream finiteness checks can see it.
pub(crate) fn unary_derivatives_sqrt(x: f64) -> [f64; 5] {
    const FLOOR: f64 = 1e-300;
    // `f64::max` returns its non-NaN operand, so clamping a NaN with it would
    // silently turn the NaN into the finite floor. Pass NaN through instead.
    let x1 = if x.is_nan() { x } else { x.max(FLOOR) };
    let s = x1.sqrt();
    let x2 = x1 * x1;
    let x3 = x2 * x1;
    [
        s,
        0.5 / s,
        -0.25 / (x1 * s),
        3.0 / (8.0 * x2 * s),
        -15.0 / (16.0 * x3 * s),
    ]
}
/// Unary stack (value and first four derivatives) of the weighted probit
/// negative log-CDF at `x`.
///
/// This is a named alias that forwards directly to
/// `signed_probit_neglog_unary_stack(x, weight)`.
pub(crate) fn unary_derivatives_neglog_phi(x: f64, weight: f64) -> [f64; 5] {
signed_probit_neglog_unary_stack(x, weight)
}
/// Value and first four derivatives of `ln x`, as a unary stack for jet
/// composition.
///
/// Returns `[ln x, 1/x, -1/x², 2/x³, -6/x⁴]`. No domain guard is applied, so
/// non-positive `x` yields the IEEE results of `ln` and division.
pub(crate) fn unary_derivatives_log(x: f64) -> [f64; 5] {
    let square = x * x;
    let cube = square * x;
    let quartic = cube * x;
    let value = x.ln();
    [value, 1.0 / x, -1.0 / square, 2.0 / cube, -6.0 / quartic]
}
/// Value and first four derivatives of the standard-normal log-density
/// `log φ(x) = -x²/2 - ½·ln(2π)`.
///
/// Returns `[log φ(x), -x, -1, 0, 0]`.
pub(crate) fn unary_derivatives_log_normal_pdf(x: f64) -> [f64; 5] {
    let half_log_two_pi = 0.5 * (2.0 * std::f64::consts::PI).ln();
    let value = -0.5 * x * x - half_log_two_pi;
    [value, -x, -1.0, 0.0, 0.0]
}
// Oracle tests. The production rigid probit kernels are cross-checked against
// an independently written `Tower4` reference program (operator arithmetic
// plus `unary_derivatives_*` stacks) and against finite differences of a
// plain-`f64` NLL.
#[cfg(test)]
mod jet_tower_oracle_tests {
    use super::*;
    /// Third and fourth derivative tensors read from one shared
    /// `rigid_standard_normal_tower` evaluation.
    fn rigid_standard_normal_third_and_fourth_full(
        marginal: BernoulliMarginalLinkMap,
        g: f64,
        z: f64,
        y: f64,
        w: f64,
        probit_scale: f64,
    ) -> Result<([[[f64; 2]; 2]; 2], [[[[f64; 2]; 2]; 2]; 2]), String> {
        let tower = rigid_standard_normal_tower(marginal, g, z, y, w, probit_scale)?;
        Ok((tower.t3, tower.t4))
    }
    use gam_math::jet_tower::{
        KernelChannels, RowNllProgram, evaluate_program, verify_kernel_channels,
    };
    /// Multi-row reference NLL program with a probit base link, written
    /// directly in `Tower4` operator arithmetic.
    struct BernoulliRigidStandardNormalNllProgram {
        primaries: Vec<[f64; 2]>,
        z: Vec<f64>,
        y: Vec<f64>,
        w: Vec<f64>,
        probit_scale: f64,
    }
    impl RowNllProgram<2> for BernoulliRigidStandardNormalNllProgram {
        fn n_rows(&self) -> usize {
            self.primaries.len()
        }
        fn primaries(&self, row: usize) -> Result<[f64; 2], String> {
            self.primaries
                .get(row)
                .copied()
                .ok_or_else(|| format!("bernoulli rigid nll program: row {row} out of range"))
        }
        fn row_nll(&self, row: usize, p: &[Tower4<2>; 2]) -> Result<Tower4<2>, String> {
            let z = self.z[row];
            let y = self.y[row];
            let w = self.w[row];
            let s = self.probit_scale;
            let eta_marginal = p[0];
            // The link map is evaluated at the primary's value; its q..q4 stack
            // then lifts the η jet to a quantile jet.
            let link = bernoulli_marginal_link_map(
                &InverseLink::Standard(crate::types::StandardLink::Probit),
                eta_marginal.v,
            )?;
            let q = eta_marginal.compose_unary([link.q, link.q1, link.q2, link.q3, link.q4]);
            let g = p[1];
            let observed_slope = g * s;
            // √(1 + b²) through an explicit unary stack, independent of the
            // production path.
            let c = (observed_slope * observed_slope + 1.0).compose_unary(unary_derivatives_sqrt(
                observed_slope.v * observed_slope.v + 1.0,
            ));
            let eta = q * c + observed_slope * z;
            let signed = eta * (2.0 * y - 1.0);
            Ok(signed.compose_unary(unary_derivatives_neglog_phi(signed.v, w)))
        }
    }
    /// Plain-`f64` rigid probit NLL `-w · ln Φ((2y − 1) · η)`. Φ is computed
    /// via `erfc` and floored at `1e-300`; this is the finite-difference
    /// witness.
    fn scalar_nll(eta_marginal: f64, g: f64, z: f64, y: f64, w: f64, s: f64) -> f64 {
        let link = bernoulli_marginal_link_map(
            &InverseLink::Standard(crate::types::StandardLink::Probit),
            eta_marginal,
        )
        .unwrap();
        let observed_slope = g * s;
        let c = (observed_slope * observed_slope + 1.0).sqrt();
        let eta = link.q * c + observed_slope * z;
        let signed = (2.0 * y - 1.0) * eta;
        let cdf = 0.5 * libm::erfc(-signed / std::f64::consts::SQRT_2);
        -w * cdf.max(1e-300).ln()
    }
    /// Runs over seven fixed rows and probit scales `1.0` and `0.8`. The
    /// production value, gradient, Hessian and contracted third/fourth
    /// derivatives must match the reference program within `1e-9`. The
    /// reference tower's value and gradient must also match the scalar NLL
    /// and its finite differences.
    #[test]
    fn rigid_bernoulli_row_kernel_agrees_with_jet_tower_program_all_channels() {
        let eta = [0.3_f64, -0.7, 0.05, 0.9, -1.2, 2.1, -2.4];
        let g = [0.2_f64, -0.5, 0.35, -0.15, 0.6, 0.45, -0.55];
        let z = [0.4_f64, -1.1, 0.0, 0.7, -0.3, 1.6, -1.4];
        let y = [1.0_f64, 0.0, 0.0, 1.0, 1.0, 0.0, 1.0];
        let w = [1.0_f64, 0.8, 1.3, 0.9, 1.1, 0.7, 1.4];
        let n = eta.len();
        let dirs: [[f64; 2]; 3] = [[0.7, -1.3], [-0.4, 0.6], [1.2, 0.2]];
        for &probit_scale in &[1.0_f64, 0.8] {
            let program = BernoulliRigidStandardNormalNllProgram {
                primaries: (0..n).map(|r| [eta[r], g[r]]).collect(),
                z: z.to_vec(),
                y: y.to_vec(),
                w: w.to_vec(),
                probit_scale,
            };
            for row in 0..n {
                let tower = evaluate_program(&program, row).expect("tower evaluation");
                let marginal = bernoulli_marginal_link_map(
                    &InverseLink::Standard(crate::types::StandardLink::Probit),
                    eta[row],
                )
                .expect("link map");
                let (value, gradient, hessian) = rigid_standard_normal_row_kernel(
                    marginal,
                    g[row],
                    z[row],
                    y[row],
                    w[row],
                    probit_scale,
                )
                .expect("production row kernel");
                let (third_full, fourth_full) = rigid_standard_normal_third_and_fourth_full(
                    marginal,
                    g[row],
                    z[row],
                    y[row],
                    w[row],
                    probit_scale,
                )
                .expect("production third+fourth");
                let third: Vec<([f64; 2], [[f64; 2]; 2])> = dirs
                    .iter()
                    .map(|d| (*d, contract_third_full(&third_full, d[0], d[1])))
                    .collect();
                // Each direction is paired with the next one (cyclically) for
                // the fourth-order contraction.
                let fourth: Vec<([f64; 2], [f64; 2], [[f64; 2]; 2])> = dirs
                    .iter()
                    .enumerate()
                    .map(|(i, u)| {
                        let v = dirs[(i + 1) % dirs.len()];
                        (
                            *u,
                            v,
                            contract_fourth_full(&fourth_full, u[0], u[1], v[0], v[1]),
                        )
                    })
                    .collect();
                let claims = KernelChannels {
                    value,
                    gradient,
                    hessian,
                    third,
                    fourth,
                };
                verify_kernel_channels(&tower, &claims, 1e-9).unwrap_or_else(|e| {
                    panic!(
                        "probit_scale {probit_scale} row {row}: production rigid Bernoulli \
                         RowKernel disagrees with #932 jet-tower truth: {e}"
                    )
                });
                // Independent witness: scalar NLL value plus five-point
                // central differences (step h) for each gradient component.
                let h = 1e-3;
                let f = |de: f64, dg: f64| {
                    scalar_nll(
                        eta[row] + de,
                        g[row] + dg,
                        z[row],
                        y[row],
                        w[row],
                        probit_scale,
                    )
                };
                let f0 = f(0.0, 0.0);
                assert!(
                    (f0 - tower.v).abs() <= 1e-9 * f0.abs().max(1.0),
                    "row {row}: independent scalar NLL {f0:+.12e} != tower value {:+.12e}",
                    tower.v
                );
                let g_eta = (f(-2.0 * h, 0.0) - 8.0 * f(-h, 0.0) + 8.0 * f(h, 0.0)
                    - f(2.0 * h, 0.0))
                    / (12.0 * h);
                let g_g = (f(0.0, -2.0 * h) - 8.0 * f(0.0, -h) + 8.0 * f(0.0, h) - f(0.0, 2.0 * h))
                    / (12.0 * h);
                for (label, fd, ad) in [("∂η", g_eta, tower.g[0]), ("∂g", g_g, tower.g[1])] {
                    assert!(
                        (fd - ad).abs() <= 1e-5 * ad.abs().max(1.0),
                        "row {row} {label}: FD witness {fd:+.6e} != tower grad {ad:+.6e}"
                    );
                }
            }
        }
    }
    /// Third and fourth tensors from two separate tower evaluations must be
    /// bit-identical to those read from a single shared tower.
    #[test]
    fn rigid_third_and_fourth_full_shares_one_tower_bit_identical() {
        let eta = [0.3_f64, -0.7, 0.05, 0.9, -1.2, 2.1, -2.4];
        let g = [0.2_f64, -0.5, 0.35, -0.15, 0.6, 0.45, -0.55];
        let z = [0.4_f64, -1.1, 0.0, 0.7, -0.3, 1.6, -1.4];
        let y = [1.0_f64, 0.0, 0.0, 1.0, 1.0, 0.0, 1.0];
        let w = [1.0_f64, 0.8, 1.3, 0.9, 1.1, 0.7, 1.4];
        for &probit_scale in &[1.0_f64, 0.8] {
            for r in 0..eta.len() {
                let marginal = bernoulli_marginal_link_map(
                    &InverseLink::Standard(crate::types::StandardLink::Probit),
                    eta[r],
                )
                .expect("link map");
                let t3_sep = rigid_standard_normal_third_full(
                    marginal,
                    g[r],
                    z[r],
                    y[r],
                    w[r],
                    probit_scale,
                )
                .expect("separate third");
                let t4_sep = rigid_standard_normal_fourth_full(
                    marginal,
                    g[r],
                    z[r],
                    y[r],
                    w[r],
                    probit_scale,
                )
                .expect("separate fourth");
                let (t3_comb, t4_comb) = rigid_standard_normal_third_and_fourth_full(
                    marginal,
                    g[r],
                    z[r],
                    y[r],
                    w[r],
                    probit_scale,
                )
                .expect("combined third+fourth");
                for a in 0..2 {
                    for b in 0..2 {
                        for c in 0..2 {
                            assert_eq!(
                                t3_comb[a][b][c], t3_sep[a][b][c],
                                "t3[{a}][{b}][{c}] row {r} scale {probit_scale} not bit-identical"
                            );
                            for d in 0..2 {
                                assert_eq!(
                                    t4_comb[a][b][c][d], t4_sep[a][b][c][d],
                                    "t4[{a}][{b}][{c}][{d}] row {r} scale {probit_scale} not bit-identical"
                                );
                            }
                        }
                    }
                }
            }
        }
    }
    /// The generic `RigidStandardNormalRow` program is checked against the
    /// reference `Tower4` program. This covers the full tower, the order-2
    /// kernel, and the contracted third and fourth derivatives. Agreement
    /// must hold within `1e-12` absolute plus `1e-12` relative.
    #[test]
    fn rigid_bernoulli_generic_program_matches_tower4_program_all_channels() {
        use gam_math::jet_tower::{
            generic_fourth_contracted, generic_full_tower, generic_row_kernel,
            generic_third_contracted,
        };
        let eta = [0.3_f64, -0.7, 0.05, 0.9, -1.2, 2.1, -2.4];
        let g = [0.2_f64, -0.5, 0.35, -0.15, 0.6, 0.45, -0.55];
        let z = [0.4_f64, -1.1, 0.0, 0.7, -0.3, 1.6, -1.4];
        let y = [1.0_f64, 0.0, 0.0, 1.0, 1.0, 0.0, 1.0];
        let w = [1.0_f64, 0.8, 1.3, 0.9, 1.1, 0.7, 1.4];
        let n = eta.len();
        let dirs: [[f64; 2]; 3] = [[0.7, -1.3], [-0.4, 0.6], [1.2, 0.2]];
        let close = |a: f64, b: f64, label: &str| {
            let band = 1e-12 + 1e-12 * a.abs().max(b.abs());
            assert!(
                (a - b).abs() <= band,
                "{label}: generic {a:+.15e} vs Tower4-program {b:+.15e} (band {band:.3e})"
            );
        };
        for &probit_scale in &[1.0_f64, 0.8] {
            let tower_program = BernoulliRigidStandardNormalNllProgram {
                primaries: (0..n).map(|r| [eta[r], g[r]]).collect(),
                z: z.to_vec(),
                y: y.to_vec(),
                w: w.to_vec(),
                probit_scale,
            };
            for row in 0..n {
                let truth = evaluate_program(&tower_program, row).expect("Tower4 program tower");
                let marginal = bernoulli_marginal_link_map(
                    &InverseLink::Standard(crate::types::StandardLink::Probit),
                    eta[row],
                )
                .expect("link map");
                let program = RigidStandardNormalRow {
                    marginal,
                    g: g[row],
                    z: z[row],
                    y: y[row],
                    w: w[row],
                    probit_scale,
                };
                let full = generic_full_tower(&program, 0).expect("generic full tower");
                close(full.v, truth.v, "full value");
                for a in 0..2 {
                    close(full.g[a], truth.g[a], "full grad");
                    for b in 0..2 {
                        close(full.h[a][b], truth.h[a][b], "full hess");
                        for c in 0..2 {
                            close(full.t3[a][b][c], truth.t3[a][b][c], "full t3");
                            for d in 0..2 {
                                close(full.t4[a][b][c][d], truth.t4[a][b][c][d], "full t4");
                            }
                        }
                    }
                }
                let (val, grad, hess) =
                    generic_row_kernel(&program, 0).expect("generic row kernel");
                close(val, truth.v, "order2 value");
                for a in 0..2 {
                    close(grad[a], truth.g[a], "order2 grad");
                    for b in 0..2 {
                        close(hess[a][b], truth.h[a][b], "order2 hess");
                    }
                }
                for dir in &dirs {
                    let third = generic_third_contracted(&program, 0, dir)
                        .expect("generic third contracted");
                    let truth3 = truth.third_contracted(dir);
                    for a in 0..2 {
                        for b in 0..2 {
                            close(third[a][b], truth3[a][b], "third contracted");
                        }
                    }
                }
                for (i, u) in dirs.iter().enumerate() {
                    let v = dirs[(i + 1) % dirs.len()];
                    let fourth = generic_fourth_contracted(&program, 0, u, &v)
                        .expect("generic fourth contracted");
                    let truth4 = truth.fourth_contracted(u, &v);
                    for a in 0..2 {
                        for b in 0..2 {
                            close(fourth[a][b], truth4[a][b], "fourth contracted");
                        }
                    }
                }
            }
        }
    }
}
// Finite-difference oracle tests for the flexible (score-warp + link-deviation)
// marginal-slope family. The analytic row primary Hessian is compared against
// central differences of the analytic row primary gradient.
#[cfg(test)]
mod flex_primary_hessian_oracle_tests {
    use super::*;
    use super::family::*;
    use crate::matrix::DenseDesignMatrix;
    use ndarray::Array1;
    use ndarray::Array2;
    use std::sync::Arc;
    use std::sync::Mutex;
    /// Builds an `n`-row fixture family and its four parameter-block states.
    ///
    /// The family has a standard-normal latent measure, probit base link,
    /// Gaussian frailty sd `0.15`, and score-warp and link-deviation blocks
    /// (3 internal knots each) built from linspace seeds. Responses, weights,
    /// `z` and the two-column marginal/logslope designs are deterministic
    /// arithmetic patterns of the row index. States are ordered marginal,
    /// logslope, score-warp (`h`) and link-deviation (`w`).
    fn make_flex_oracle_family(
        n: usize,
    ) -> (BernoulliMarginalSlopeFamily, Vec<ParameterBlockState>) {
        let score_seed = Array1::linspace(-2.0, 2.0, n.max(6));
        let link_seed = Array1::linspace(-1.8, 1.8, n.max(6));
        let cfg = DeviationBlockConfig {
            num_internal_knots: 3,
            ..DeviationBlockConfig::default()
        };
        let score_prepared = build_score_warp_deviation_block_from_seed(&score_seed, &cfg)
            .expect("build score warp block");
        let link_prepared = build_link_deviation_block_from_knots_design_seed_and_weights(
            &link_seed, &link_seed, &cfg,
        )
        .expect("build link deviation block");
        let y: Array1<f64> =
            Array1::from_iter((0..n).map(|i| if (i * 17 + 3) % 7 >= 4 { 1.0 } else { 0.0 }));
        let weights: Array1<f64> =
            Array1::from_iter((0..n).map(|i| 0.75 + ((i * 11 + 5) % 5) as f64 * 0.05));
        let z: Array1<f64> =
            Array1::from_iter((0..n).map(|i| -1.7 + 3.4 * (i as f64 + 0.5) / n as f64));
        let marginal_x = Array2::from_shape_fn((n, 2), |(i, j)| {
            if j == 0 {
                1.0
            } else {
                -0.4 + 0.8 * ((i * 19 + 7) % n) as f64 / n as f64
            }
        });
        let logslope_x = Array2::from_shape_fn((n, 2), |(i, j)| {
            if j == 0 {
                1.0
            } else {
                0.3 - 0.6 * ((i * 23 + 11) % n) as f64 / n as f64
            }
        });
        let family = BernoulliMarginalSlopeFamily {
            y: Arc::new(y),
            weights: Arc::new(weights),
            z: Arc::new(z.clone()),
            latent_measure: LatentMeasureKind::StandardNormal,
            gaussian_frailty_sd: Some(0.15),
            base_link: InverseLink::Standard(crate::types::StandardLink::Probit),
            marginal_design: DesignMatrix::Dense(DenseDesignMatrix::from(marginal_x.clone())),
            logslope_design: DesignMatrix::Dense(DenseDesignMatrix::from(logslope_x.clone())),
            score_warp: Some(score_prepared.runtime.clone()),
            link_dev: Some(link_prepared.runtime.clone()),
            policy: gam_runtime::resource::ResourcePolicy::default_library(),
            cell_moment_lru: Arc::new(exact_kernel::CellMomentLruCache::new(1024)),
            cell_moment_cache_stats: Arc::new(exact_kernel::CellMomentCacheStats::default()),
            intercept_warm_starts: None,
            auto_subsample_phase_counter: Arc::new(std::sync::atomic::AtomicUsize::new(0)),
            auto_subsample_last_rho: Arc::new(Mutex::new(None)),
        };
        let beta_m = Array1::from_vec(vec![0.12, -0.04]);
        let beta_g = Array1::from_vec(vec![0.35, 0.03]);
        // Small, distinct deviation coefficients keep the warps non-trivial
        // but close to the identity.
        let beta_h = Array1::from_iter(
            (0..score_prepared.runtime.basis_dim()).map(|idx| 0.0015 * (idx as f64 + 1.0)),
        );
        let beta_w = Array1::from_iter(
            (0..link_prepared.runtime.basis_dim()).map(|idx| -0.001 * (idx as f64 + 1.0)),
        );
        let states = vec![
            ParameterBlockState {
                eta: marginal_x.dot(&beta_m),
                beta: beta_m,
            },
            ParameterBlockState {
                eta: logslope_x.dot(&beta_g),
                beta: beta_g,
            },
            ParameterBlockState {
                beta: beta_h,
                eta: Array1::zeros(z.len()),
            },
            ParameterBlockState {
                beta: beta_w,
                eta: Array1::zeros(z.len()),
            },
        ];
        (family, states)
    }
    /// Analytic row primary gradient after shifting primary coordinate `u`
    /// by `delta`.
    ///
    /// The `q` and `logslope` coordinates shift that row's η in states 0 and
    /// 1. Score-warp and link-deviation coordinates shift the matching β
    /// entry in states 2 and 3. The row context is rebuilt from the perturbed
    /// states. Panics if `u` maps to no known primary slice.
    fn flex_gradient_at_perturbed(
        family: &BernoulliMarginalSlopeFamily,
        states: &[ParameterBlockState],
        primary: &super::super::hessian_paths::PrimarySlices,
        row: usize,
        u: usize,
        delta: f64,
    ) -> Array1<f64> {
        let mut states = states.to_vec();
        if u == primary.q {
            states[0].eta[row] += delta;
        } else if u == primary.logslope {
            states[1].eta[row] += delta;
        } else if let Some(h_range) = primary.h.as_ref()
            && h_range.contains(&u)
        {
            states[2].beta[u - h_range.start] += delta;
        } else if let Some(w_range) = primary.w.as_ref()
            && w_range.contains(&u)
        {
            states[3].beta[u - w_range.start] += delta;
        } else {
            panic!("primary coordinate {u} out of range for flex oracle");
        }
        let row_ctx = family
            .build_row_exact_context_with_stats_and_cell_cache(row, &states, None, false)
            .expect("perturbed row context");
        let (_neglog, grad, _hess) = family
            .compute_row_primary_gradient_hessian(row, &states, primary, &row_ctx)
            .expect("perturbed gradient");
        grad
    }
    /// For rows 2, 5 and 8, every analytic Hessian entry `H[v][u]` must match
    /// the central difference (step `1e-4`) of gradient entry `v` along
    /// coordinate `u`. The relative gap, using denominator
    /// `1 + max(|analytic|, |fd|)`, must be at most `1e-6`.
    #[test]
    fn flex_primary_hessian_matches_central_fd_of_gradient() {
        let n = 12usize;
        let (family, states) = make_flex_oracle_family(n);
        let cache = family
            .build_exact_eval_cache(&states)
            .expect("flex exact eval cache");
        let primary = &cache.primary;
        let r = primary.total;
        assert!(
            r >= 4,
            "flex fixture must carry q + logslope + deviation blocks"
        );
        let h = 1e-4;
        let mut max_rel = 0.0_f64;
        for &row in &[2usize, 5, 8] {
            let row_ctx = BernoulliMarginalSlopeFamily::row_ctx(&cache, row);
            let (_neglog, _grad, analytic_hess) = family
                .compute_row_primary_gradient_hessian(row, &states, primary, row_ctx)
                .expect("analytic flex gradient + hessian");
            for u in 0..r {
                let grad_plus = flex_gradient_at_perturbed(&family, &states, primary, row, u, h);
                let grad_minus = flex_gradient_at_perturbed(&family, &states, primary, row, u, -h);
                for v in 0..r {
                    let fd = (grad_plus[v] - grad_minus[v]) / (2.0 * h);
                    let analytic = analytic_hess[[v, u]];
                    let denom = 1.0 + analytic.abs().max(fd.abs());
                    let rel = (analytic - fd).abs() / denom;
                    max_rel = max_rel.max(rel);
                    assert!(
                        rel <= 1e-6,
                        "flex hand Hessian H[{v}][{u}] = {analytic:.6e} disagrees with central \
                         FD of the gradient {fd:.6e} at row {row} (rel {rel:.3e}); a product-rule \
                         term is dropped or mis-signed"
                    );
                }
            }
        }
        // Summary check; each entry was already asserted individually above.
        assert!(
            max_rel <= 1e-6,
            "flex Hessian FD oracle max rel {max_rel:.3e}"
        );
    }
    /// Step-scaling diagnostic for the `(q, q)` Hessian entry at row 2.
    ///
    /// Central FDs are taken at `h = 1e-3`, `h/2` and `h/4` and recorded with
    /// their gap ratios. The Richardson extrapolation of the first two must
    /// agree with the analytic entry to `1e-7` relative to
    /// `max(|analytic|, 1)`.
    #[test]
    fn arbiter_flex_hessian_h00_fd_step_scaling() {
        let n = 12usize;
        let (family, states) = make_flex_oracle_family(n);
        let cache = family
            .build_exact_eval_cache(&states)
            .expect("flex exact eval cache");
        let primary = &cache.primary;
        let row = 2usize;
        let u = primary.q; let v = primary.q;
        let row_ctx = BernoulliMarginalSlopeFamily::row_ctx(&cache, row);
        let (_neglog, _grad, analytic_hess) = family
            .compute_row_primary_gradient_hessian(row, &states, primary, row_ctx)
            .expect("analytic flex gradient + hessian");
        let analytic = analytic_hess[[v, u]];
        let fd_at = |h: f64| -> f64 {
            let gp = flex_gradient_at_perturbed(&family, &states, primary, row, u, h);
            let gm = flex_gradient_at_perturbed(&family, &states, primary, row, u, -h);
            (gp[v] - gm[v]) / (2.0 * h)
        };
        let h = 1e-3_f64;
        let fd_h = fd_at(h);
        let fd_half = fd_at(h * 0.5);
        let fd_quarter = fd_at(h * 0.25);
        let gap_h = (analytic - fd_h).abs();
        let gap_half = (analytic - fd_half).abs();
        let gap_quarter = (analytic - fd_quarter).abs();
        // Central differences carry an O(h²) leading error term, which this
        // Richardson combination of steps h and h/2 cancels.
        let rich = (4.0 * fd_half - fd_h) / 3.0;
        let rich_gap = (analytic - rich).abs();
        let denom = analytic.abs().max(1.0);
        let record = format!(
            "FLEX H[0][0] ARBITER row 2: analytic={analytic:+.12e} \
             fd(h)={fd_h:+.12e} fd(h/2)={fd_half:+.12e} fd(h/4)={fd_quarter:+.12e} \
             gap(h)={gap_h:.3e} gap(h/2)={gap_half:.3e} gap(h/4)={gap_quarter:.3e} \
             ratio_h_over_half={:.3} ratio_half_over_quarter={:.3} \
             richardson={rich:+.12e} richardson_gap={rich_gap:.3e} (rich_rel={:.3e})",
            gap_h / gap_half.max(f64::MIN_POSITIVE),
            gap_half / gap_quarter.max(f64::MIN_POSITIVE),
            rich_gap / denom,
        );
        assert!(
            rich_gap / denom <= 1e-7,
            "{record}\nVERDICT: Richardson residual exceeds the FD-truncation floor — \
             the hand H[0][0] genuinely diverges (real dropped/mis-signed term), NOT FD noise"
        );
    }
}