// gam_models/bms/gradient_paths.rs

1use super::family::clamp_bernoulli_link_probability;
2use super::*;
3use gam_linalg::matrix::{LinearOperator, SignedWeightsView};
4use gam_math::jet_tower::Tower4;
5
6pub(crate) fn standardize_latent_z_with_policy(
7    z: &Array1<f64>,
8    weights: &Array1<f64>,
9    context: &str,
10    policy: &LatentZPolicy,
11) -> Result<(Array1<f64>, LatentZNormalization), String> {
12    if z.len() != weights.len() {
13        return Err(format!(
14            "{context} latent-score normalization length mismatch: z={}, weights={}",
15            z.len(),
16            weights.len()
17        ));
18    }
19    let weight_sum = weights.iter().copied().sum::<f64>();
20    let weight_sq_sum = weights.iter().map(|&w| w * w).sum::<f64>();
21    if !(weight_sum.is_finite()
22        && weight_sum > 0.0
23        && weight_sq_sum.is_finite()
24        && weight_sq_sum > 0.0)
25    {
26        return Err(format!("{context} requires positive finite total weight"));
27    }
28    let effective_n = weight_sum * weight_sum / weight_sq_sum;
29    if !(effective_n.is_finite() && effective_n > 1.0) {
30        return Err(format!(
31            "{context} requires at least two effective observations for latent-score normalization"
32        ));
33    }
34    let mean = z
35        .iter()
36        .zip(weights.iter())
37        .map(|(&zi, &wi)| wi * zi)
38        .sum::<f64>()
39        / weight_sum;
40    let var = z
41        .iter()
42        .zip(weights.iter())
43        .map(|(&zi, &wi)| wi * (zi - mean) * (zi - mean))
44        .sum::<f64>()
45        / weight_sum;
46    let sd = var.sqrt();
47    if !(sd.is_finite() && sd > BMS_VARIANCE_FLOOR) {
48        return Err(format!(
49            "{context} requires z with positive finite weighted standard deviation"
50        ));
51    }
52    let target_norm = match policy.normalization {
53        LatentZNormalizationMode::None => LatentZNormalization { mean: 0.0, sd: 1.0 },
54        LatentZNormalizationMode::FitWeighted => LatentZNormalization { mean, sd },
55        LatentZNormalizationMode::Frozen {
56            mean: frozen_mean,
57            sd: frozen_sd,
58        } => LatentZNormalization {
59            mean: frozen_mean,
60            sd: frozen_sd,
61        },
62    };
63    let mean_tol = policy.mean_tol_multiplier / effective_n.sqrt();
64    let sd_tol = policy.sd_tol_multiplier / (2.0 * (effective_n - 1.0).max(1.0)).sqrt();
65    let check_msg = || {
66        format!(
67            "{context} requires z to already be approximately latent N(0,1) before identification normalization; got mean={mean:.6e}, sd={sd:.6e}, effective_n={effective_n:.1}, allowed_mean={mean_tol:.3e}, allowed_sd={sd_tol:.3e}"
68        )
69    };
70    if mean.abs() > mean_tol || (sd - 1.0).abs() > sd_tol {
71        match policy.check_mode {
72            LatentZCheckMode::Strict => return Err(check_msg()),
73            LatentZCheckMode::WarnOnly => log::warn!("{}", check_msg()),
74            LatentZCheckMode::Off => {}
75        }
76    }
77
78    let normalization = target_norm;
79    let z_std = normalization.apply(z, context)?;
80    let skew = z_std
81        .iter()
82        .zip(weights.iter())
83        .map(|(&zi, &wi)| wi * zi.powi(3))
84        .sum::<f64>()
85        / weight_sum;
86    let kurt = z_std
87        .iter()
88        .zip(weights.iter())
89        .map(|(&zi, &wi)| wi * zi.powi(4))
90        .sum::<f64>()
91        / weight_sum
92        - 3.0;
93    if skew.abs() > policy.max_abs_skew || kurt.abs() > policy.max_abs_excess_kurtosis {
94        let msg = format!(
95            "{context} requires z to be approximately Gaussian after identification normalization; got skewness={skew:.3}, excess_kurtosis={kurt:.3}"
96        );
97        match policy.check_mode {
98            LatentZCheckMode::Strict => return Err(msg),
99            LatentZCheckMode::WarnOnly => log::warn!("{}", msg),
100            LatentZCheckMode::Off => {}
101        }
102    }
103    if skew.abs() > 0.75 || kurt.abs() > 2.0 {
104        log::warn!(
105            "{context}: z has skewness={skew:.3} and excess kurtosis={kurt:.3}; latent-measure auto-selection will use empirical calibration unless stricter diagnostics pass"
106        );
107    }
108    Ok((z_std, normalization))
109}
110
111pub fn padded_deviation_seed(seed: &Array1<f64>, min_iqr: f64, pad_fraction: f64) -> Array1<f64> {
112    let mut sorted = seed.to_vec();
113    sorted.sort_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal));
114
115    if sorted.len() < 4 {
116        return seed.clone();
117    }
118
119    let n = sorted.len();
120    let q1 = sorted[n / 4];
121    let q3 = sorted[3 * n / 4];
122    let iqr = (q3 - q1).max(min_iqr);
123    let pad = pad_fraction * iqr;
124
125    let mut out = seed.to_vec();
126    out.push(sorted[0] - pad);
127    out.push(sorted[n - 1] + pad);
128    Array1::from_vec(out)
129}
130
// ── Pooled 2-D probit pilot Newton solver tuning ─────────────────────────────
//
// Knobs for the damped Newton solve in `pooled_probit_baseline`, which fits a
// penalised two-parameter (intercept, slope) probit. The pilot only
// warm-starts the full fit, so these settings favour robustness over speed.

/// Upper bound on outer Newton iterations for the pooled probit pilot. The
/// strictly convex 2-D probit converges far sooner; the cap only guards
/// against pathological (e.g. non-finite) data.
const POOLED_PILOT_MAX_NEWTON_ITERS: usize = 50;
/// Starting Levenberg ridge added to both diagonal entries of the 2×2 Hessian.
pub(crate) const POOLED_PILOT_RIDGE_INIT: f64 = 1.0e-8;
/// A ridged 2×2 system whose |determinant| does not exceed this value is
/// treated as singular, which triggers ridge escalation.
pub(crate) const POOLED_PILOT_DET_FLOOR: f64 = 1.0e-18;
/// Multiplicative growth applied to the ridge after each singular attempt.
pub(crate) const POOLED_PILOT_RIDGE_GROWTH: f64 = 10.0;
/// Largest ridge tried. Beyond it the Hessian is considered unusable and the
/// pilot fails rather than returning a meaningless step.
pub(crate) const POOLED_PILOT_RIDGE_MAX: f64 = 1.0e6;
/// Cap on step-halving trials in each backtracking line search.
const POOLED_PILOT_MAX_BACKTRACKS: usize = 25;
/// Factor applied to the step length after each rejected line-search trial.
pub(crate) const POOLED_PILOT_BACKTRACK_SHRINK: f64 = 0.5;
/// Objective-change tolerance used to decide whether a stalled (fully
/// rejected) line search counts as convergence rather than an error.
pub(crate) const POOLED_PILOT_STALL_TOL: f64 = 1.0e-10;
/// Smallest slope magnitude the pilot returns (sign preserved). This keeps the
/// downstream `b/√(1+b²)` rigid seed from ever being an exactly flat link.
pub(crate) const POOLED_PILOT_MIN_ABS_SLOPE: f64 = 1.0e-6;
162
163pub(super) fn pooled_probit_baseline(
164    y: &Array1<f64>,
165    z: &Array1<f64>,
166    weights: &Array1<f64>,
167) -> Result<(f64, f64), String> {
168    if y.len() != z.len() || y.len() != weights.len() {
169        return Err(format!(
170            "pooled bernoulli-marginal-slope pilot length mismatch: y={}, z={}, weights={}",
171            y.len(),
172            z.len(),
173            weights.len()
174        ));
175    }
176    let weight_sum = weights.iter().copied().sum::<f64>();
177    if !weight_sum.is_finite() || weight_sum <= 0.0 {
178        return Err(
179            "pooled bernoulli-marginal-slope pilot requires positive finite total weight"
180                .to_string(),
181        );
182    }
183    let prevalence = y
184        .iter()
185        .zip(weights.iter())
186        .map(|(&yi, &wi)| yi * wi)
187        .sum::<f64>()
188        / weight_sum;
189    let prevalence = prevalence.clamp(1e-6, 1.0 - 1e-6);
190    let z_mean = z
191        .iter()
192        .zip(weights.iter())
193        .map(|(&zi, &wi)| zi * wi)
194        .sum::<f64>()
195        / weight_sum;
196    let z_var = z
197        .iter()
198        .zip(weights.iter())
199        .map(|(&zi, &wi)| wi * (zi - z_mean) * (zi - z_mean))
200        .sum::<f64>()
201        / weight_sum;
202    let yz_cov = y
203        .iter()
204        .zip(z.iter())
205        .zip(weights.iter())
206        .map(|((&yi, &zi), &wi)| wi * (yi - prevalence) * (zi - z_mean))
207        .sum::<f64>()
208        / weight_sum;
209    let mut beta0 = standard_normal_quantile(prevalence).map_err(|e| {
210        format!("failed to initialize pooled bernoulli-marginal-slope pilot intercept: {e}")
211    })?;
212    let mut beta1 = if z_var > BMS_VARIANCE_FLOOR {
213        yz_cov / z_var
214    } else {
215        0.0
216    };
217
218    let objective_grad_hess =
219        |intercept: f64, slope: f64| -> Result<(f64, f64, f64, f64, f64, f64), String> {
220            let mut obj = 0.0;
221            let mut g0 = 0.0;
222            let mut g1 = 0.0;
223            let mut h00 = 0.0;
224            let mut h01 = 0.0;
225            let mut h11 = 0.0;
226            for ((&yi, &zi), &wi) in y.iter().zip(z.iter()).zip(weights.iter()) {
227                if wi == 0.0 {
228                    continue;
229                }
230                let eta = intercept + slope * zi;
231                let s = 2.0 * yi - 1.0;
232                let margin = s * eta;
233                let (logcdf, lambda) = signed_probit_logcdf_and_mills_ratio(margin);
234                let g_eta = -wi * s * lambda;
235                let h_eta = wi * lambda * (margin + lambda);
236                obj -= wi * logcdf;
237                g0 += g_eta;
238                g1 += g_eta * zi;
239                h00 += h_eta;
240                h01 += h_eta * zi;
241                h11 += h_eta * zi * zi;
242            }
243            Ok((obj, g0, g1, h00, h01, h11))
244        };
245
246    let mut obj_prev = f64::INFINITY;
247    for _ in 0..POOLED_PILOT_MAX_NEWTON_ITERS {
248        let (obj, g0, g1, h00, h01, h11) = objective_grad_hess(beta0, beta1)?;
249        if !obj.is_finite() || !g0.is_finite() || !g1.is_finite() {
250            return Err(
251                "pooled bernoulli-marginal-slope pilot produced non-finite objective or gradient"
252                    .to_string(),
253            );
254        }
255        let grad_max = g0.abs().max(g1.abs());
256        if grad_max < BMS_DERIV_TOL {
257            break;
258        }
259        let mut ridge = POOLED_PILOT_RIDGE_INIT;
260        let (step0, step1) = loop {
261            let h00_r = h00 + ridge;
262            let h11_r = h11 + ridge;
263            let det = h00_r * h11_r - h01 * h01;
264            if det.is_finite() && det.abs() > POOLED_PILOT_DET_FLOOR {
265                let s0 = (h11_r * g0 - h01 * g1) / det;
266                let s1 = (-h01 * g0 + h00_r * g1) / det;
267                if s0.is_finite() && s1.is_finite() {
268                    break (s0, s1);
269                }
270            }
271            ridge *= POOLED_PILOT_RIDGE_GROWTH;
272            if ridge > POOLED_PILOT_RIDGE_MAX {
273                return Err(
274                    "pooled bernoulli-marginal-slope pilot Hessian solve failed".to_string()
275                );
276            }
277        };
278        let mut accepted = false;
279        let mut step_scale = 1.0;
280        for _ in 0..POOLED_PILOT_MAX_BACKTRACKS {
281            let cand0 = beta0 - step_scale * step0;
282            let cand1 = beta1 - step_scale * step1;
283            let (cand_obj, _, _, _, _, _) = objective_grad_hess(cand0, cand1)?;
284            if cand_obj.is_finite() && cand_obj <= obj {
285                beta0 = cand0;
286                beta1 = cand1;
287                obj_prev = cand_obj;
288                accepted = true;
289                break;
290            }
291            step_scale *= POOLED_PILOT_BACKTRACK_SHRINK;
292        }
293        if !accepted {
294            if (obj_prev - obj).abs() < POOLED_PILOT_STALL_TOL {
295                break;
296            }
297            return Err("pooled bernoulli-marginal-slope pilot line search failed".to_string());
298        }
299    }
300    let a = beta0;
301    // Signed slope: preserve direction from pilot probit.
302    let b = if beta1.abs() < POOLED_PILOT_MIN_ABS_SLOPE {
303        if beta1.is_sign_negative() {
304            -POOLED_PILOT_MIN_ABS_SLOPE
305        } else {
306            POOLED_PILOT_MIN_ABS_SLOPE
307        }
308    } else {
309        beta1
310    };
311    Ok((a / (1.0 + b * b).sqrt(), b))
312}
313
// Background for `pilot_eta_for_link_dev_orthogonalisation` (defined further
// below): computing a non-degenerate pilot η for the link-deviation
// cross-block identifiability orthogonalisation.
316//
317// The rigid pooled probit pilot from `pooled_probit_baseline` is a scalar
318// pair `(a₀, b₀)`, so the rigid observed-scale linear predictor
319// `η_rigid[i] = a₀·√(1 + (s_f·b₀)²) + s_f·b₀·z[i]` is **exactly affine in z**
320// when the per-row offsets are zero. A degree-3 I-spline of an affine
321// function of `z` spans the same column space at training rows as a
322// degree-3 I-spline of `z` directly, so evaluating the link-deviation basis
323// at `η_rigid` and orthogonalising it against the score-warp basis (built
324// on `z`) produces a structurally singular cross-Gram — the candidate is
325// fully aliased even though at PIRLS time the link-deviation runtime is
326// re-evaluated at the current β-dependent η which carries genuine PC / age
327// structure that the score-warp cannot represent.
328//
329// One probit Gauss-Newton step from the rigid pilot, projected onto the
330// full marginal design at the W-IRLS working response, picks up that PC /
331// age structure cheaply (one `p_marg × p_marg` Cholesky plus a few matvecs
332// — `<<1 s` at large scale because `p_marg` is `O(10²)` whereas the
333// PIRLS dense Hessian build is `O(n·p²)` per cycle). The resulting
334// `η_pilot[i]` has the same row-by-row variation pattern PIRLS will see at
335// any non-degenerate β, so the orthogonalisation transform `T` drops only
336// the directions that are aliased *across all* β, not those that are
337// aliased only at the rigid (rank-1-in-z) pilot.
338/// IRLS Hessian row metric for the probit-style data Hessian at a fixed
339/// linear predictor `eta`: `w[i] = sample_weights[i] · φ(η_i)² / (μ_i·(1−μ_i))`.
340///
341/// This is the canonical row metric that the joint penalised Hessian sees
342/// during PIRLS for a probit GLM (and the dominant term for
343/// BernoulliMarginalSlope's data Hessian). Cross-block orthogonalisation
344/// against parametric anchors must use **this** metric — not a uniform
345/// W=spec.weights — for the joint Hessian to be block-orthogonal between
346/// parametric and flex spans. With a uniform W the orthogonalisation only
347/// kills the Euclidean alias; at PIRLS time `Aᵀ W_pirls C̃ ≠ 0` and the
348/// joint Hessian carries a near-null direction along the W-metric alias,
349/// which REML can drive to arbitrarily small eigenvalue by shrinking the
350/// flex block's smoothing parameter — β then runs away along the alias
351/// (the failure mode that manifests as `rho≈2.0`, constant `step_inf`,
352/// and `beta_inf` growing without bound during PIRLS).
353pub(super) fn pilot_irls_hessian_row_metric_at_eta(
354    eta_pilot: &Array1<f64>,
355    sample_weights: &Array1<f64>,
356) -> Array1<f64> {
357    let n = eta_pilot.len();
358    let mut w = Array1::<f64>::zeros(n);
359    for i in 0..n {
360        let eta = eta_pilot[i];
361        let mu = clamp_bernoulli_link_probability(normal_cdf(eta));
362        let phi = normal_pdf(eta).max(1e-300);
363        let var = (mu * (1.0 - mu)).max(1e-300);
364        w[i] = sample_weights[i] * (phi * phi) / var;
365    }
366    w
367}
368
369/// Per-row rigid pooled-probit pilot η used to seed the IRLS Hessian
370/// metric for score-warp cross-block orthogonalisation. Score-warp's
371/// basis is evaluated at `z` (β-independent) so there is no GN-stepped
372/// pilot to share with the link-deviation path; the rigid pooled-probit
373/// pilot is a sensible β-independent reference at which to evaluate
374/// `W = p(1−p)·spec.weights` for the W-metric orthogonalisation.
375pub(super) fn rigid_pooled_probit_pilot_eta(
376    base_link: &InverseLink,
377    z: &Array1<f64>,
378    marginal_offset: &Array1<f64>,
379    logslope_offset: &Array1<f64>,
380    baseline_marginal: f64,
381    baseline_logslope: f64,
382    probit_scale: f64,
383) -> Result<Array1<f64>, String> {
384    let n = z.len();
385    let mut out = Array1::<f64>::zeros(n);
386    for i in 0..n {
387        let a_pre = baseline_marginal + marginal_offset[i];
388        let b_pre = baseline_logslope + logslope_offset[i];
389        let q_marg = bernoulli_marginal_link_map(base_link, a_pre)
390            .map_err(|e| format!("rigid_pooled_probit_pilot_eta marginal link map: {e}"))?
391            .q;
392        out[i] = rigid_observed_eta(q_marg, b_pre, z[i], probit_scale);
393    }
394    Ok(out)
395}
396
/// Tikhonov ridge for the pilot IRLS marginal solve, expressed as a fraction
/// of the mean diagonal of `XᵀWX`:
/// `ridge = PILOT_RIDGE_DIAG_FRACTION · max(mean_diag, PILOT_RIDGE_DIAG_FLOOR)`.
/// Tying the ridge to the diagonal keeps it scale-invariant. The fraction is
/// negligible for a well-conditioned design yet still regularises a
/// near-singular pilot Gram matrix.
pub(crate) const PILOT_RIDGE_DIAG_FRACTION: f64 = 1.0e-6;
/// Positivity floor on the mean diagonal used to scale the pilot ridge, so
/// that even an all-zero diagonal receives a small ridge.
pub(crate) const PILOT_RIDGE_DIAG_FLOOR: f64 = 1.0e-12;
406
407pub(super) fn pilot_eta_for_link_dev_orthogonalisation(
408    base_link: &InverseLink,
409    y: &Array1<f64>,
410    z: &Array1<f64>,
411    weights: &Array1<f64>,
412    marginal_design: &DesignMatrix,
413    marginal_offset: &Array1<f64>,
414    logslope_offset: &Array1<f64>,
415    baseline_marginal: f64,
416    baseline_logslope: f64,
417    probit_scale: f64,
418) -> Result<Array1<f64>, String> {
419    use gam_linalg::faer_ndarray::FaerCholesky;
420
421    let n = y.len();
422    if marginal_design.nrows() != n {
423        return Err(format!(
424            "pilot_eta_for_link_dev_orthogonalisation: marginal design has {} rows, expected {}",
425            marginal_design.nrows(),
426            n,
427        ));
428    }
429    let mut working_eta = Array1::<f64>::zeros(n);
430    let mut w_irls = Array1::<f64>::zeros(n);
431    let mut residual = Array1::<f64>::zeros(n);
432    for i in 0..n {
433        let a_pre = baseline_marginal + marginal_offset[i];
434        let b_pre = baseline_logslope + logslope_offset[i];
435        let q_marg = bernoulli_marginal_link_map(base_link, a_pre)
436            .map_err(|e| {
437                format!("pilot_eta_for_link_dev_orthogonalisation marginal link map: {e}")
438            })?
439            .q;
440        let eta = rigid_observed_eta(q_marg, b_pre, z[i], probit_scale);
441        working_eta[i] = eta;
442        let mu = clamp_bernoulli_link_probability(normal_cdf(eta));
443        let phi = normal_pdf(eta).max(1e-300);
444        let var = (mu * (1.0 - mu)).max(1e-300);
445        w_irls[i] = weights[i] * (phi * phi) / var;
446        residual[i] = (y[i] - mu) / phi;
447    }
448    let p_marg = marginal_design.ncols();
449    if p_marg == 0 {
450        return Ok(working_eta);
451    }
452    let xtwr = marginal_design.compute_xtwy(&w_irls, &residual)?;
453    let mut xtwx = marginal_design.xt_diag_x_signed_op(SignedWeightsView::from_array(&w_irls))?;
454    let trace_diag: f64 = (0..p_marg).map(|i| xtwx[[i, i]]).sum();
455    let ridge =
456        (trace_diag / p_marg as f64).max(PILOT_RIDGE_DIAG_FLOOR) * PILOT_RIDGE_DIAG_FRACTION;
457    for i in 0..p_marg {
458        xtwx[[i, i]] += ridge;
459    }
460    let factor = xtwx
461        .cholesky(faer::Side::Lower)
462        .map_err(|e| format!("pilot_eta_for_link_dev_orthogonalisation Cholesky failed: {e}"))?;
463    let delta_beta_marg = factor.solvevec(&xtwr);
464    let marg_contrib = marginal_design.dot(&delta_beta_marg);
465    Ok(&working_eta + &marg_contrib)
466}
467
468pub(super) fn joint_setup(
469    data: ArrayView2<'_, f64>,
470    marginalspec: &TermCollectionSpec,
471    logslopespec: &TermCollectionSpec,
472    marginal_penalties: usize,
473    logslope_penalties: usize,
474    extra_rho0: &[f64],
475    kappa_options: &SpatialLengthScaleOptimizationOptions,
476) -> ExactJointHyperSetup {
477    let marginal_terms = spatial_length_scale_term_indices(marginalspec);
478    let logslope_terms = spatial_length_scale_term_indices(logslopespec);
479    let rho_dim = marginal_penalties + logslope_penalties + extra_rho0.len();
480    let mut rho0vec = Array1::<f64>::zeros(rho_dim);
481    for (idx, &value) in extra_rho0.iter().enumerate() {
482        rho0vec[marginal_penalties + logslope_penalties + idx] = value;
483    }
484    let rho_lower = Array1::<f64>::from_elem(rho_dim, -12.0);
485    let rho_upper = Array1::<f64>::from_elem(rho_dim, 12.0);
486    let marginal_kappa = SpatialLogKappaCoords::from_length_scales_aniso(
487        marginalspec,
488        &marginal_terms,
489        kappa_options,
490    )
491    .reseed_from_data(data, marginalspec, &marginal_terms, kappa_options);
492    let logslope_kappa = SpatialLogKappaCoords::from_length_scales_aniso(
493        logslopespec,
494        &logslope_terms,
495        kappa_options,
496    )
497    .reseed_from_data(data, logslopespec, &logslope_terms, kappa_options);
498    let mut values = marginal_kappa.as_array().to_vec();
499    values.extend(logslope_kappa.as_array().iter());
500    let marginal_dims = marginal_kappa.dims_per_term().to_vec();
501    let logslope_dims = logslope_kappa.dims_per_term().to_vec();
502    let mut dims = marginal_dims.clone();
503    dims.extend(logslope_dims.iter().copied());
504    let log_kappa0 = SpatialLogKappaCoords::new_with_dims(Array1::from_vec(values), dims.clone());
505    // Bounds: concatenate per-block data-aware bounds in the same order.
506    let marginal_lower = SpatialLogKappaCoords::lower_bounds_aniso_from_data(
507        data,
508        marginalspec,
509        &marginal_terms,
510        &marginal_dims,
511        kappa_options,
512    );
513    let logslope_lower = SpatialLogKappaCoords::lower_bounds_aniso_from_data(
514        data,
515        logslopespec,
516        &logslope_terms,
517        &logslope_dims,
518        kappa_options,
519    );
520    let mut lower_vals = marginal_lower.as_array().to_vec();
521    lower_vals.extend(logslope_lower.as_array().iter());
522    let log_kappa_lower =
523        SpatialLogKappaCoords::new_with_dims(Array1::from_vec(lower_vals), dims.clone());
524    let marginal_upper = SpatialLogKappaCoords::upper_bounds_aniso_from_data(
525        data,
526        marginalspec,
527        &marginal_terms,
528        &marginal_dims,
529        kappa_options,
530    );
531    let logslope_upper = SpatialLogKappaCoords::upper_bounds_aniso_from_data(
532        data,
533        logslopespec,
534        &logslope_terms,
535        &logslope_dims,
536        kappa_options,
537    );
538    let mut upper_vals = marginal_upper.as_array().to_vec();
539    upper_vals.extend(logslope_upper.as_array().iter());
540    let log_kappa_upper = SpatialLogKappaCoords::new_with_dims(Array1::from_vec(upper_vals), dims);
541    // Project seed onto bounds in case a user-provided spec.length_scale falls
542    // outside the data-derived ψ window; seed was a hint, not a hard constraint.
543    let log_kappa0 = log_kappa0.clamp_to_bounds(&log_kappa_lower, &log_kappa_upper);
544    ExactJointHyperSetup::new(
545        rho0vec,
546        rho_lower,
547        rho_upper,
548        log_kappa0,
549        log_kappa_lower,
550        log_kappa_upper,
551    )
552}
553
554#[inline]
555pub(crate) fn signed_probit_neglog_derivatives_up_to_fourth_numeric(
556    signed_margin: f64,
557    weight: f64,
558) -> (f64, f64, f64, f64) {
559    if weight == 0.0 || signed_margin == f64::INFINITY {
560        return (0.0, 0.0, 0.0, 0.0);
561    }
562    if signed_margin == f64::NEG_INFINITY {
563        return (f64::NEG_INFINITY, weight, 0.0, 0.0);
564    }
565    if signed_margin.is_nan() {
566        return (f64::NAN, f64::NAN, f64::NAN, f64::NAN);
567    }
568    let (_, lambda) = signed_probit_logcdf_and_mills_ratio(signed_margin);
569    let k1 = -lambda;
570    let k2 = lambda * (signed_margin + lambda);
571    let k3 = lambda
572        * (1.0
573            - signed_margin * signed_margin
574            - 3.0 * signed_margin * lambda
575            - 2.0 * lambda * lambda);
576    let k4 = lambda
577        * ((signed_margin.powi(3) - 3.0 * signed_margin)
578            + (7.0 * signed_margin * signed_margin - 4.0) * lambda
579            + 12.0 * signed_margin * lambda * lambda
580            + 6.0 * lambda.powi(3));
581    (weight * k1, weight * k2, weight * k3, weight * k4)
582}
583
584/// Exact probit derivative helper used by analytic jet code paths.
585///
586/// `+inf` is the saturated zero tail and is allowed. `-inf` and `NaN` are
587/// rejected instead of being silently collapsed, so exact callers fail fast
588/// rather than erasing curvature or domain errors. Numeric boundary behavior
589/// that needs to preserve `-inf` / `NaN` values lives in
590/// `signed_probit_neglog_derivatives_up_to_fourth_numeric`.
591pub(crate) fn signed_probit_neglog_derivatives_up_to_fourth(
592    signed_margin: f64,
593    weight: f64,
594) -> Result<(f64, f64, f64, f64), String> {
595    if weight == 0.0 || signed_margin == f64::INFINITY {
596        return Ok((0.0, 0.0, 0.0, 0.0));
597    }
598    if !signed_margin.is_finite() {
599        return Err(format!(
600            "non-finite signed margin in exact probit derivative helper: {signed_margin}"
601        ));
602    }
603    Ok(signed_probit_neglog_derivatives_up_to_fourth_numeric(
604        signed_margin,
605        weight,
606    ))
607}
608
609/// Fused exact value+derivative stack for the signed-probit negative-log
610/// kernel: returns `[-w·logΦ(m), w·k1, w·k2, w·k3, w·k4]` in the `[f64; 5]`
611/// shape [`Tower4::compose_unary`] consumes.
612///
613/// This is the single-source replacement for the two-call pattern
614///
615/// ```ignore
616/// let (logcdf, _) = signed_probit_logcdf_and_mills_ratio(m);
617/// let (k1, k2, k3, k4) = signed_probit_neglog_derivatives_up_to_fourth(m, w)?;
618/// // → [-w*logcdf, k1, k2, k3, k4]
619/// ```
620///
621/// which evaluated `signed_probit_logcdf_and_mills_ratio` TWICE on the same
622/// `m` (once for `logΦ`, once again — discarding `logΦ` — for the Mills ratio
623/// `λ` that drives `k1..k4`). On the rigid standard-normal BMS path that pair
624/// of `erfcx`/`erfc` transcendentals is the dominant per-row arithmetic across
625/// all `n ≈ 356k` rows, so collapsing it to ONE call halves the transcendental
626/// budget of the jet build. The result is bit-identical: `logΦ` and `λ` are the
627/// exact same values the two-call form produced (same branch, same `ex`), and
628/// `k1..k4` are the same polynomials in `(m, λ)`.
629///
630/// Boundary semantics match [`unary_derivatives_neglog_phi`] (the prior
631/// two-call form): `+∞` is the saturated zero tail (all zero); `−∞` returns the
632/// `[+∞, −w, w·0, 0, 0]` limit (value `−w·logΦ(−∞)=+∞`, `k1=−λ→−∞` scaled by the
633/// `w` already folded by the numeric derivative helper); `NaN` propagates.
634#[inline]
635pub(crate) fn signed_probit_neglog_unary_stack(signed_margin: f64, weight: f64) -> [f64; 5] {
636    if weight == 0.0 || signed_margin == f64::INFINITY {
637        return [0.0; 5];
638    }
639    if signed_margin == f64::NEG_INFINITY {
640        // logΦ(−∞) = −∞ ⇒ value −w·(−∞) = +∞; the derivative helper's −∞ limit
641        // is (−∞, w, 0, 0) for (k1, k2, k3, k4) before the weight fold below.
642        return [f64::INFINITY, f64::NEG_INFINITY, weight, 0.0, 0.0];
643    }
644    if signed_margin.is_nan() {
645        return [f64::NAN; 5];
646    }
647    // ONE transcendental evaluation feeds both the value (logΦ) and every
648    // derivative (through the Mills ratio λ).
649    let (logcdf, lambda) = signed_probit_logcdf_and_mills_ratio(signed_margin);
650    let m = signed_margin;
651    let k1 = -lambda;
652    let k2 = lambda * (m + lambda);
653    let k3 = lambda * (1.0 - m * m - 3.0 * m * lambda - 2.0 * lambda * lambda);
654    let k4 = lambda
655        * ((m * m * m - 3.0 * m)
656            + (7.0 * m * m - 4.0) * lambda
657            + 12.0 * m * lambda * lambda
658            + 6.0 * lambda * lambda * lambda);
659    [
660        -weight * logcdf,
661        weight * k1,
662        weight * k2,
663        weight * k3,
664        weight * k4,
665    ]
666}
667
668#[inline]
669pub(super) fn rigid_observed_logslope(logslope: f64, probit_scale: f64) -> f64 {
670    probit_scale * logslope
671}
672
673#[inline]
674pub(super) fn rigid_observed_scale(logslope: f64, probit_scale: f64) -> f64 {
675    let observed_logslope = rigid_observed_logslope(logslope, probit_scale);
676    (1.0 + observed_logslope * observed_logslope).sqrt()
677}
678
679#[inline]
680pub(super) fn rigid_intercept_from_marginal(
681    marginal_eta: f64,
682    logslope: f64,
683    probit_scale: f64,
684) -> f64 {
685    marginal_eta * rigid_observed_scale(logslope, probit_scale)
686}
687
688#[inline]
689pub(super) fn rigid_prescale_intercept_from_marginal(
690    marginal_eta: f64,
691    logslope: f64,
692    probit_scale: f64,
693) -> f64 {
694    rigid_intercept_from_marginal(marginal_eta, logslope, probit_scale) / probit_scale
695}
696
697#[inline]
698pub(super) fn rigid_prescale_intercept_derivative_abs(
699    marginal_eta: f64,
700    logslope: f64,
701    probit_scale: f64,
702) -> f64 {
703    let c = rigid_observed_scale(logslope, probit_scale);
704    probit_scale * normal_pdf(marginal_eta) / c
705}
706
/// Observed-scale rigid linear predictor for one row:
/// `η = marginal_eta · √(1 + b²) + b · z`, with `b = probit_scale · logslope`.
/// This is a named alias that forwards directly to
/// `marginal_slope_standard_normal_scalar_eta`.
#[inline]
pub(super) fn rigid_observed_eta(
    marginal_eta: f64,
    logslope: f64,
    z: f64,
    probit_scale: f64,
) -> f64 {
    marginal_slope_standard_normal_scalar_eta(marginal_eta, logslope, z, probit_scale)
}
716
717#[inline]
718pub(super) fn marginal_slope_standard_normal_scalar_eta(
719    q: f64,
720    slope: f64,
721    z: f64,
722    probit_scale: f64,
723) -> f64 {
724    let observed_slope = rigid_observed_logslope(slope, probit_scale);
725    q * (1.0 + observed_slope * observed_slope).sqrt() + observed_slope * z
726}
727
728pub(super) fn unary_derivatives_normal_cdf(x: f64) -> [f64; 5] {
729    let pdf = normal_pdf(x);
730    [
731        normal_cdf(x),
732        pdf,
733        -x * pdf,
734        (x * x - 1.0) * pdf,
735        (-x.powi(3) + 3.0 * x) * pdf,
736    ]
737}
738
739pub(super) fn unary_derivatives_normal_pdf(x: f64) -> [f64; 5] {
740    let pdf = normal_pdf(x);
741    [
742        pdf,
743        -x * pdf,
744        (x * x - 1.0) * pdf,
745        (-x.powi(3) + 3.0 * x) * pdf,
746        (x.powi(4) - 6.0 * x * x + 3.0) * pdf,
747    ]
748}
749
750/// Streaming log-sum-exp update: accumulate `exp(log_term)` into a running
751/// `(log_max, sum)` pair representing `Σ exp(log_term_i) = exp(log_max) · sum`.
752///
753/// When `log_term` exceeds the running max, the partial sum is rescaled in
754/// place so the new max becomes the reference point. This keeps everything
755/// inside the dynamic range of f64 with no allocation.
756#[inline]
757pub(super) fn lse_accumulate(log_max: &mut f64, sum: &mut f64, log_term: f64) {
758    if !log_term.is_finite() {
759        return;
760    }
761    if log_term > *log_max {
762        if log_max.is_finite() {
763            *sum = *sum * (*log_max - log_term).exp() + 1.0;
764        } else {
765            *sum = 1.0;
766        }
767        *log_max = log_term;
768    } else {
769        *sum += (log_term - *log_max).exp();
770    }
771}
772
/// Data-free tag naming the storage layout of a [`MarginalSlopeCovariance`].
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum MarginalSlopeCovarianceShape {
    /// Independent latent scores: only per-dimension variances are stored.
    Diagonal,
    /// Dense symmetric `K x K` covariance matrix.
    Full,
    /// Covariance stored as a factor `L` with `Sigma = L Lᵀ`.
    LowRank,
}
779
780#[derive(Clone, Debug, PartialEq)]
781pub enum MarginalSlopeCovariance {
782    Diagonal(Array1<f64>),
783    Full(Array2<f64>),
784    /// Low-rank factor L with Sigma = L L^T.
785    LowRank(Array2<f64>),
786}
787
/// Lower bound accepted for the covariance quadratic form `rᵀΣr`.
///
/// The form is non-negative in exact arithmetic. The diagonal and low-rank
/// evaluations are sums of non-negative terms. The dense evaluation can land
/// slightly below zero through rounding when the true value is zero. Values in
/// `[COVARIANCE_QUADRATIC_FORM_PSD_TOL, 0)` are clamped to zero, and anything
/// below it is reported as an error. This is an absolute tolerance, not one
/// scaled to the magnitude of Σ.
pub(crate) const COVARIANCE_QUADRATIC_FORM_PSD_TOL: f64 = -1.0e-10;
793
794impl MarginalSlopeCovariance {
795    pub fn shape(&self) -> MarginalSlopeCovarianceShape {
796        match self {
797            Self::Diagonal(_) => MarginalSlopeCovarianceShape::Diagonal,
798            Self::Full(_) => MarginalSlopeCovarianceShape::Full,
799            Self::LowRank(_) => MarginalSlopeCovarianceShape::LowRank,
800        }
801    }
802
803    pub fn dim(&self) -> usize {
804        match self {
805            Self::Diagonal(diag) => diag.len(),
806            Self::Full(cov) => cov.nrows(),
807            Self::LowRank(factor) => factor.nrows(),
808        }
809    }
810
811    pub fn validate(&self, context: &str) -> Result<(), String> {
812        match self {
813            Self::Diagonal(diag) => {
814                if diag.is_empty() {
815                    return Err(format!("{context} diagonal covariance is empty"));
816                }
817                for (idx, &value) in diag.iter().enumerate() {
818                    if !(value.is_finite() && value >= 0.0) {
819                        return Err(format!(
820                            "{context} diagonal covariance entry {idx} must be finite and non-negative, got {value}"
821                        ));
822                    }
823                }
824            }
825            Self::Full(cov) => {
826                if cov.nrows() == 0 || cov.nrows() != cov.ncols() {
827                    return Err(format!(
828                        "{context} full covariance must be non-empty and square, got {}x{}",
829                        cov.nrows(),
830                        cov.ncols()
831                    ));
832                }
833                for i in 0..cov.nrows() {
834                    for j in 0..cov.ncols() {
835                        let value = cov[[i, j]];
836                        if !value.is_finite() {
837                            return Err(format!(
838                                "{context} full covariance entry ({i},{j}) is non-finite"
839                            ));
840                        }
841                        if (value - cov[[j, i]]).abs()
842                            > 1e-10 * (1.0 + value.abs().max(cov[[j, i]].abs()))
843                        {
844                            return Err(format!(
845                                "{context} full covariance must be symmetric at ({i},{j})"
846                            ));
847                        }
848                    }
849                }
850            }
851            Self::LowRank(factor) => {
852                if factor.nrows() == 0 {
853                    return Err(format!(
854                        "{context} low-rank covariance factor has zero rows"
855                    ));
856                }
857                for ((i, j), &value) in factor.indexed_iter() {
858                    if !value.is_finite() {
859                        return Err(format!(
860                            "{context} low-rank covariance factor entry ({i},{j}) is non-finite"
861                        ));
862                    }
863                }
864            }
865        }
866        Ok(())
867    }
868
869    pub fn quadratic_form(&self, vector: &[f64]) -> Result<f64, String> {
870        self.validate("marginal-slope covariance")?;
871        if vector.len() != self.dim() {
872            return Err(format!(
873                "marginal-slope covariance dimension mismatch: vector={}, covariance={}",
874                vector.len(),
875                self.dim()
876            ));
877        }
878        if vector.iter().any(|value| !value.is_finite()) {
879            return Err("marginal-slope covariance vector contains non-finite values".to_string());
880        }
881        let value = match self {
882            Self::Diagonal(diag) => vector
883                .iter()
884                .zip(diag.iter())
885                .map(|(&v, &sigma)| v * v * sigma)
886                .sum::<f64>(),
887            Self::Full(cov) => {
888                let mut total = 0.0;
889                for i in 0..cov.nrows() {
890                    let mut row_dot = 0.0;
891                    for j in 0..cov.ncols() {
892                        row_dot += cov[[i, j]] * vector[j];
893                    }
894                    total += vector[i] * row_dot;
895                }
896                total
897            }
898            Self::LowRank(factor) => {
899                // Sigma = L L'. The Gaussian-probit scale only needs
900                // r' Sigma r = ||L' r||^2. Equivalently,
901                // det(I + L' r r' L) = 1 + ||L' r||^2 by the matrix
902                // determinant lemma, so the low-rank path never builds
903                // the full K x K covariance.
904                let mut total = 0.0;
905                for r in 0..factor.ncols() {
906                    let mut projection = 0.0;
907                    for k in 0..factor.nrows() {
908                        projection += factor[[k, r]] * vector[k];
909                    }
910                    total += projection * projection;
911                }
912                total
913            }
914        };
915        if value.is_finite() && value >= COVARIANCE_QUADRATIC_FORM_PSD_TOL {
916            Ok(value.max(0.0))
917        } else {
918            Err(format!(
919                "marginal-slope covariance quadratic form must be non-negative, got {value}"
920            ))
921        }
922    }
923}
924
// Marginal-slope probit identity.
//
// For a row with latent scores z | a ~ N(0, Sigma(a)) and probit index
//
//     eta = c(a) q(t, a) + r(a)' z,
//
// the preservation target is
//
//     E_z[Phi(-eta) | a] = Phi(-q(t, a)).
//
// If X = r' z is N(0, v) with v = r' Sigma r, then for independent
// E ~ N(0, 1),
//
//     E[Phi(-(c q + X))]
//       = P(E <= -c q - X)
//       = P(E + X <= -c q)
//       = Phi(-c q / sqrt(1 + v)).
//
// Thus the target holds for every q exactly when
//
//     c(a) = sqrt(1 + r(a)' Sigma(a) r(a)).
//
// `probit_scale` maps the raw log-slope surface to the observed probit
// gradient r(a). K=1 with diagonal variance 1 gives the original scalar
// formula sqrt(1 + r^2); full and low-rank covariances differ only in the
// shape-specific evaluation of the same quadratic form.
/// Estimates the weighted covariance of the `n x k` latent-score matrix
/// `scores`. The estimate divides by the total weight, with no small-sample
/// correction. The result is returned in the [`MarginalSlopeCovariance`]
/// layout chosen by the shape classification in the body.
///
/// # Errors
/// Fails when `k == 0`, when `weights.len() != n`, when the total weight is not
/// positive and finite, when any weight is negative or non-finite, when any
/// score is non-finite, or when the eigendecomposition fails.
pub fn marginal_slope_covariance_from_scores(
    scores: ArrayView2<'_, f64>,
    weights: &Array1<f64>,
) -> Result<MarginalSlopeCovariance, String> {
    let (n, k) = scores.dim();
    if k == 0 {
        return Err("marginal-slope score matrix must have at least one column".to_string());
    }
    if weights.len() != n {
        return Err(format!(
            "marginal-slope covariance weight length mismatch: weights={}, rows={n}",
            weights.len()
        ));
    }
    // NOTE(review): the total is formed before the per-weight checks below, so
    // a NaN or infinite weight surfaces as this "total weight" error rather
    // than the index-specific one.
    let total_weight = weights.iter().copied().sum::<f64>();
    if !(total_weight.is_finite() && total_weight > 0.0) {
        return Err("marginal-slope covariance needs positive finite total weight".to_string());
    }
    // Weighted column means. Each weight and score is validated in this same
    // pass, so the first offending row/column is the one reported.
    let mut mean = Array1::<f64>::zeros(k);
    for i in 0..n {
        let weight = weights[i];
        if !(weight.is_finite() && weight >= 0.0) {
            return Err(format!(
                "marginal-slope covariance weight {i} must be finite and non-negative, got {weight}"
            ));
        }
        for j in 0..k {
            let score = scores[[i, j]];
            if !score.is_finite() {
                return Err(format!(
                    "marginal-slope covariance score ({i},{j}) is non-finite"
                ));
            }
            mean[j] += weight * score;
        }
    }
    mean.mapv_inplace(|value| value / total_weight);

    // cov[a][b] = Σ_i w_i (s_ia − m_a)(s_ib − m_b) / Σw. Only b <= a is
    // computed; each off-diagonal contribution is mirrored into [b][a], so the
    // matrix is exactly symmetric.
    let mut cov = Array2::<f64>::zeros((k, k));
    for i in 0..n {
        let weight = weights[i];
        for a in 0..k {
            let da = scores[[i, a]] - mean[a];
            for b in 0..=a {
                let value = weight * da * (scores[[i, b]] - mean[b]) / total_weight;
                cov[[a, b]] += value;
                if a != b {
                    cov[[b, a]] += value;
                }
            }
        }
    }

    // ── Shape classification ──
    //
    // Pick the cheapest representation that preserves r'Σr for arbitrary r.
    //
    //   * K = 1: always Diagonal — LowRank/Full distinctions are meaningless.
    //
    //   * STRICT NUMERICAL DIAGONAL: if every off-diagonal is at machine
    //     precision relative to the diagonal scale, return Diagonal.  This
    //     catches both structurally-orthogonal inputs (post-orthogonalised
    //     production paths) AND degenerate cases like a column of all
    //     zeros (rank-deficient but truly diagonal).
    //
    //   * Otherwise eigendecompose.  positive.len() < K ⇒ the rank
    //     deficiency comes from collinear columns (off-diagonals are
    //     non-trivial) — Diagonal would drop the coupling and break r'Σr
    //     ⇒ LowRank.
    //
    //   * Full rank: apply a 4σ statistical off-diagonal test.  Under H0
    //     (independent population columns) the asymptotic SE of an
    //     off-diagonal sample covariance is √(σ_aa σ_bb / N_eff) with
    //     N_eff = (Σw)² / Σw² (Kish).  Pass ⇒ Diagonal (sample noise was
    //     not real correlation), fail ⇒ Full.  Note the statistical
    //     threshold 4·√(σ_aa σ_bb / N_eff) only falls below the numerical
    //     floor 1e-10·(1 + max σ) when N_eff is astronomically large
    //     (≈ 4e20 for unit variances).  At realistic sample sizes the 4σ
    //     threshold is therefore the active one.  Off-diagonals whose
    //     sample correlation is below roughly 4/√N_eff are treated as
    //     noise and dropped, and the result is Diagonal rather than Full.
    //     This is a deliberate approximation of r'Σr, not an exact
    //     representation.
    if k == 1 {
        return Ok(MarginalSlopeCovariance::Diagonal(cov.diag().to_owned()));
    }

    // Absolute-plus-relative floor: 1e-10 scaled by (1 + largest |variance|).
    let diag: Vec<f64> = (0..k).map(|i| cov[[i, i]]).collect();
    let diag_max = diag.iter().fold(0.0_f64, |acc, &v| acc.max(v.abs()));
    let numerical_floor = 1e-10 * (1.0 + diag_max);

    // Only the upper triangle is scanned; cov is exactly symmetric by
    // construction above.
    let mut is_strict_diagonal = true;
    'strict: for a in 0..k {
        for b in (a + 1)..k {
            if cov[[a, b]].abs() > numerical_floor {
                is_strict_diagonal = false;
                break 'strict;
            }
        }
    }
    if is_strict_diagonal {
        return Ok(MarginalSlopeCovariance::Diagonal(cov.diag().to_owned()));
    }

    use gam_linalg::faer_ndarray::FaerEigh;
    let (evals, evecs) = cov
        .eigh(faer::Side::Lower)
        .map_err(|err| format!("marginal-slope covariance eigendecomposition failed: {err}"))?;
    // Rank tolerance is relative to the largest |eigenvalue| but never smaller
    // than 1e-10 in absolute terms.
    let max_eval = evals
        .iter()
        .fold(0.0_f64, |acc, &value| acc.max(value.abs()));
    let rank_tol = 1e-10 * max_eval.max(1.0);
    let positive: Vec<(usize, f64)> = evals
        .iter()
        .enumerate()
        .filter_map(|(idx, &value)| (value > rank_tol).then_some((idx, value)))
        .collect();

    if positive.len() < k {
        // Rank deficiency with non-trivial off-diagonals ⇒ collinear
        // columns; Diagonal would lose the coupling.
        // Each retained eigenpair (λ, v) contributes the column √λ·v, so
        // factor·factorᵀ reproduces cov on its numerically positive spectrum.
        let mut factor = Array2::<f64>::zeros((k, positive.len()));
        for (col, (idx, value)) in positive.iter().enumerate() {
            let scale = value.sqrt();
            for row in 0..k {
                factor[[row, col]] = evecs[[row, *idx]] * scale;
            }
        }
        return Ok(MarginalSlopeCovariance::LowRank(factor));
    }

    // Full rank.  4σ statistical off-diagonal test.
    // sum_w_sq is positive whenever total_weight > 0 (some weight is positive),
    // so the 1.0 fallback is purely defensive.
    let sum_w_sq = weights.iter().map(|&w| w * w).sum::<f64>();
    let n_eff = if sum_w_sq > 0.0 {
        (total_weight * total_weight) / sum_w_sq
    } else {
        1.0
    };
    const OFFDIAG_Z_THRESHOLD: f64 = 4.0;
    let mut is_stat_diagonal = true;
    'stat: for a in 0..k {
        for b in (a + 1)..k {
            let stat_se = (diag[a].max(0.0) * diag[b].max(0.0) / n_eff)
                .max(0.0)
                .sqrt();
            let threshold = numerical_floor.max(OFFDIAG_Z_THRESHOLD * stat_se);
            if cov[[a, b]].abs() > threshold {
                is_stat_diagonal = false;
                break 'stat;
            }
        }
    }
    if is_stat_diagonal {
        Ok(MarginalSlopeCovariance::Diagonal(cov.diag().to_owned()))
    } else {
        Ok(MarginalSlopeCovariance::Full(cov))
    }
}
1103
1104pub fn marginal_slope_preserving_scale(
1105    slopes: &[f64],
1106    covariance: &MarginalSlopeCovariance,
1107    probit_scale: f64,
1108) -> Result<f64, String> {
1109    if !probit_scale.is_finite() {
1110        return Err(format!(
1111            "marginal-slope probit scale must be finite, got {probit_scale}"
1112        ));
1113    }
1114    let observed_slopes = slopes
1115        .iter()
1116        .map(|&slope| probit_scale * slope)
1117        .collect::<Vec<_>>();
1118    let variance = covariance.quadratic_form(&observed_slopes)?;
1119    Ok((1.0 + variance).sqrt())
1120}
1121
1122pub fn marginal_slope_probit_eta(
1123    q: f64,
1124    z: &[f64],
1125    slopes: &[f64],
1126    covariance: &MarginalSlopeCovariance,
1127    probit_scale: f64,
1128) -> Result<f64, String> {
1129    if z.len() != slopes.len() {
1130        return Err(format!(
1131            "marginal-slope score/slope dimension mismatch: z={}, slopes={}",
1132            z.len(),
1133            slopes.len()
1134        ));
1135    }
1136    if slopes.len() != covariance.dim() {
1137        return Err(format!(
1138            "marginal-slope covariance dimension mismatch: slopes={}, covariance={}",
1139            slopes.len(),
1140            covariance.dim()
1141        ));
1142    }
1143    if !q.is_finite() || z.iter().any(|value| !value.is_finite()) {
1144        return Err("marginal-slope probit eta inputs must be finite".to_string());
1145    }
1146    let scale = marginal_slope_preserving_scale(slopes, covariance, probit_scale)?;
1147    let linear = z
1148        .iter()
1149        .zip(slopes.iter())
1150        .map(|(&score, &slope)| probit_scale * slope * score)
1151        .sum::<f64>();
1152    Ok(q * scale + linear)
1153}
1154
1155/// Log-space residual evaluator for the empirical-frailty intercept calibration.
1156///
1157/// Solves, in log-space, the strictly-increasing equation
1158///
1159///   F(a) = log Σᵢ wᵢ Φ(a + b·zᵢ) − log μ★ = 0,
1160///
1161/// where `b = rigid_observed_logslope(slope, probit_scale)` and `(zᵢ, wᵢ)` are
1162/// the supplied quadrature nodes and (positive) weights.
1163///
1164/// Mathematical structure of `F`:
1165///   • `F ∈ C^∞(ℝ)`.
1166///   • `F` is strictly increasing: `F'(a) = (Σ wᵢ φᵢ) / (Σ wᵢ Φᵢ) > 0` everywhere.
1167///   • `F(a) → −∞` as `a → −∞`; `F(a) → log(Σ wᵢ) − log μ★ ≥ 0` as `a → +∞`.
1168///   • Unique root `a★ ∈ ℝ` exists for every `μ★ ∈ (0, 1)`.
1169///
1170/// Why log-space: the linear-space residual `Σ wᵢ Φᵢ − μ★` and its derivative
1171/// `Σ wᵢ φᵢ` are sums of strictly-positive `exp(−η²/2)`-scaled terms. When the
1172/// seed `a` puts every quadrature node `ηᵢ = a + b·zᵢ` into the deep tail
1173/// (|ηᵢ| ≳ 38), every term rounds to 0.0 in IEEE-754 and the derivative
1174/// underflows to exactly zero — destroying Newton's update direction.  The
1175/// log-space formulation evaluates `log φ(η) = −η²/2 − ½ log 2π` (always finite
1176/// for any finite η) and `log Φ(η)` via the `erfcx`-based `normal_logcdf`
1177/// (also always finite for any finite η).  All sums are accumulated by
1178/// streaming log-sum-exp, so `F`, `F'`, and `F''` are finite for every finite
1179/// `a` and the global Newton/Halley iteration converges from any seed.
1180///
1181/// Returns `(F, F', F'')`.  In the deep left tail Newton converges linearly
1182/// (Mills ratio: `F'(a) ≈ |a|`, step ≈ `|a|/2`); near the root convergence is
1183/// quadratic with Newton or cubic with Halley.
1184pub(super) fn empirical_rigid_calibration_eval(
1185    intercept: f64,
1186    log_target_mu: f64,
1187    slope: f64,
1188    probit_scale: f64,
1189    nodes: &[f64],
1190    weights: &[f64],
1191) -> Result<(f64, f64, f64), String> {
1192    if !intercept.is_finite() {
1193        return Err(format!(
1194            "empirical latent calibration: non-finite intercept {intercept}"
1195        ));
1196    }
1197    let observed_slope = rigid_observed_logslope(slope, probit_scale);
1198    const HALF_LOG_2PI: f64 = 0.918_938_533_204_672_8; // 0.5 * ln(2π)
1199
1200    // Streaming LSE accumulators for log Σ wᵢ φᵢ and log Σ wᵢ Φᵢ.
1201    let mut log_max_phi = f64::NEG_INFINITY;
1202    let mut sum_phi = 0.0_f64;
1203    let mut log_max_cdf = f64::NEG_INFINITY;
1204    let mut sum_cdf = 0.0_f64;
1205
1206    // Streaming signed LSE for Σ wᵢ ηᵢ φᵢ, split into positive and negative
1207    // legs so the cancellation `pos − neg` happens once at the end on a
1208    // finite, well-scaled remainder.
1209    let mut log_max_pos = f64::NEG_INFINITY;
1210    let mut sum_pos = 0.0_f64;
1211    let mut log_max_neg = f64::NEG_INFINITY;
1212    let mut sum_neg = 0.0_f64;
1213
1214    for (&node, &weight) in nodes.iter().zip(weights.iter()) {
1215        if !(weight.is_finite() && weight > 0.0) {
1216            continue;
1217        }
1218        let eta = intercept + observed_slope * node;
1219        if !eta.is_finite() {
1220            return Err(format!(
1221                "empirical latent calibration: non-finite η at intercept={intercept}, slope={slope}, node={node}"
1222            ));
1223        }
1224        let log_w = weight.ln();
1225        let log_phi = -0.5 * eta * eta - HALF_LOG_2PI;
1226        let log_term_phi = log_w + log_phi;
1227        let log_term_cdf = log_w + normal_logcdf(eta);
1228
1229        lse_accumulate(&mut log_max_phi, &mut sum_phi, log_term_phi);
1230        lse_accumulate(&mut log_max_cdf, &mut sum_cdf, log_term_cdf);
1231
1232        if eta != 0.0 {
1233            let log_term_eta_phi = log_term_phi + eta.abs().ln();
1234            if eta > 0.0 {
1235                lse_accumulate(&mut log_max_pos, &mut sum_pos, log_term_eta_phi);
1236            } else {
1237                lse_accumulate(&mut log_max_neg, &mut sum_neg, log_term_eta_phi);
1238            }
1239        }
1240    }
1241
1242    if !(sum_phi.is_finite() && sum_cdf.is_finite() && sum_phi > 0.0 && sum_cdf > 0.0) {
1243        return Err(format!(
1244            "empirical latent calibration: log-space accumulation failed (sum_phi={sum_phi}, sum_cdf={sum_cdf}, intercept={intercept})"
1245        ));
1246    }
1247
1248    let log_s_phi = log_max_phi + sum_phi.ln();
1249    let log_s_cdf = log_max_cdf + sum_cdf.ln();
1250
1251    // F = log Σ wᵢ Φᵢ − log μ★
1252    let f = log_s_cdf - log_target_mu;
1253    // F' = exp(log Σ wᵢ φᵢ − log Σ wᵢ Φᵢ).
1254    //
1255    // F' is mathematically strictly positive everywhere — `Σ wᵢ φᵢ` and
1256    // `Σ wᵢ Φᵢ` are both sums of strictly-positive terms with positive weights.
1257    // In the far right tail, Mills ratio gives `φᵢ/Φᵢ → 0` exponentially, so
1258    // `log F' → −∞` and `(log F').exp()` IEEE-underflows to 0.0. Mathematically
1259    // it is a tiny positive number; floor it at `f64::MIN_POSITIVE` so the
1260    // monotone-root solver sees a strictly-positive derivative and routes
1261    // through its bracket-by-doubling phase (which only needs the *sign* of
1262    // `F'`, not its magnitude). Newton would propose `Δa = −F/F' = ±∞`, the
1263    // solver detects that and falls through to bracketing automatically.
1264    let log_f_prime = log_s_phi - log_s_cdf;
1265    let f_prime = if log_f_prime > -740.0 {
1266        log_f_prime.exp()
1267    } else {
1268        f64::MIN_POSITIVE
1269    };
1270
1271    // F'' = (d/da)(S_φ/S_Φ) = (S_φ' S_Φ − S_φ²)/S_Φ²
1272    //     = −(Σ wᵢ ηᵢ φᵢ)/S_Φ − (F')²
1273    // The η-weighted sum is cancellation-prone; combine its positive and
1274    // negative legs against the same `log_s_cdf` reference so the subtraction
1275    // happens on dimensionless quantities of bounded magnitude. When the ratio
1276    // also underflows (deep tail), the result is a clean numerical zero —
1277    // Halley reduces to Newton, which is what the solver does anyway.
1278    let exp_safe = |log_x: f64| -> f64 { if log_x > -740.0 { log_x.exp() } else { 0.0 } };
1279    let pos_over_cdf = if sum_pos > 0.0 {
1280        exp_safe(log_max_pos + sum_pos.ln() - log_s_cdf)
1281    } else {
1282        0.0
1283    };
1284    let neg_over_cdf = if sum_neg > 0.0 {
1285        exp_safe(log_max_neg + sum_neg.ln() - log_s_cdf)
1286    } else {
1287        0.0
1288    };
1289    let s_etaphi_over_s_cdf = pos_over_cdf - neg_over_cdf;
1290    let f_double_prime = -s_etaphi_over_s_cdf - f_prime * f_prime;
1291
1292    if !(f.is_finite() && f_prime.is_finite() && f_prime > 0.0 && f_double_prime.is_finite()) {
1293        return Err(format!(
1294            "empirical latent calibration: non-finite log-space state f={f}, f'={f_prime}, f''={f_double_prime} at intercept={intercept}"
1295        ));
1296    }
1297    Ok((f, f_prime, f_double_prime))
1298}
1299
1300pub(crate) fn empirical_intercept_from_marginal(
1301    target_mu: f64,
1302    target_q: f64,
1303    slope: f64,
1304    probit_scale: f64,
1305    nodes: &[f64],
1306    weights: &[f64],
1307    initial: Option<f64>,
1308) -> Result<f64, String> {
1309    if !(target_mu.is_finite() && target_mu > 0.0 && target_mu < 1.0) {
1310        return Err(format!(
1311            "empirical latent calibration requires target mu in (0,1), got {target_mu}"
1312        ));
1313    }
1314    let log_target_mu = target_mu.ln();
1315    let closed_form_seed = rigid_intercept_from_marginal(target_q, slope, probit_scale);
1316    let seed = initial.unwrap_or(closed_form_seed);
1317    let eval = |a: f64| {
1318        empirical_rigid_calibration_eval(a, log_target_mu, slope, probit_scale, nodes, weights)
1319    };
1320    // Convergence is on the log-space residual |F| = |log Σ wᵢ Φᵢ − log μ★|.
1321    // Near the root this is the relative error in the calibrated probability,
1322    // so 1e-13 in log-space corresponds to absolute residual μ★ · 1e-13 in
1323    // linear space — strictly tighter than the legacy 1e-13 absolute tolerance
1324    // for every μ★ ∈ (0, 1). The 4·ε floor keeps the contract meaningful when
1325    // μ★ approaches 1 (where log Σ Φᵢ approaches 0).
1326    let abs_tol = 1e-13_f64.max(4.0 * f64::EPSILON);
1327    let solve_from = |s: f64| {
1328        crate::monotone_root::solve_monotone_root(
1329            eval,
1330            s,
1331            "empirical latent intercept",
1332            abs_tol,
1333            64,
1334            48,
1335        )
1336        // Enclosing fn emits its own format!() rejection errors as String,
1337        // so the public return type stays Result<_, String>.
1338        .map_err(|e| e.to_string())
1339    };
1340    // A cached warm start can be poisoned across iterations: the per-row
1341    // `intercept_warm_starts` slot is shared by reference across line-search
1342    // trials and across outer-search seed validations, and is written after
1343    // every successful row-solve — including from rejected line-search trials
1344    // whose β/slope was wild. When that stale `a` is paired with the current
1345    // (much smaller) slope, the bracket-by-doubling phase can exhaust its
1346    // budget without crossing zero. Fall back to the deterministic
1347    // closed-form seed, which depends only on the current `(target_q, slope)`
1348    // and is bounded by the analytic rigid-probit geometry, so the cache
1349    // remains a pure speedup that cannot poison correctness.
1350    let (root, _, f_best) = match solve_from(seed) {
1351        Ok(v) => v,
1352        Err(first_err) => {
1353            if seed == closed_form_seed {
1354                return Err(first_err);
1355            }
1356            solve_from(closed_form_seed).map_err(|retry_err| {
1357                format!("{first_err}; closed-form retry from a={closed_form_seed:.6}: {retry_err}")
1358            })?
1359        }
1360    };
1361    if f_best.abs() > abs_tol {
1362        return Err(format!(
1363            "empirical latent intercept solve failed: log-residual={f_best:.3e} at a={root:.6}, target mu={target_mu:.6}"
1364        ));
1365    }
1366    Ok(root)
1367}
1368
1369#[inline]
1370pub(super) fn rigid_standard_normal_neglog_only(
1371    q: f64,
1372    g: f64,
1373    z: f64,
1374    y: f64,
1375    w: f64,
1376    probit_scale: f64,
1377) -> Result<f64, String> {
1378    let s = 2.0 * y - 1.0;
1379    let eta = marginal_slope_standard_normal_scalar_eta(q, g, z, probit_scale);
1380    let m = s * eta;
1381    let (logcdf, _) = signed_probit_logcdf_and_mills_ratio(m);
1382    if !logcdf.is_finite() {
1383        return Err(format!(
1384            "rigid probit neglog_only: non-finite log Φ at q={q}, g={g}, z={z}, y={y}"
1385        ));
1386    }
1387    Ok(-w * logcdf)
1388}
1389
1390/// The rigid standard-normal Bernoulli row negative log-likelihood, written
1391/// ONCE over the generic [`JetScalar`] interface (#932 scalar cutover).
1392///
1393/// Primaries `p = [q_eta = marginal η, g = slope]`. The body is exactly the
1394/// production likelihood — `ℓ = −w·logΦ((2y−1)·η)`, `η = q(η_marg)·√(1+(s·g)²)
1395/// + (s·g)·z` — composed with ONLY [`JetScalar`] ops, so it re-instantiates at
1396/// whatever order / representation a consumer needs:
1397///
1398/// * [`Order2`](super::super::jet_scalar::Order2) → `(v, g, H)`
1399///   ([`rigid_standard_normal_row_kernel`], the inner-Newton path);
1400/// * [`OneSeed`](super::super::jet_scalar::OneSeed) → contracted third
1401///   `Σ_c ℓ_{abc} dir_c` without materialising `t3` (the directional gate);
1402/// * [`TwoSeed`](super::super::jet_scalar::TwoSeed) → contracted fourth
1403///   `Σ_{cd} ℓ_{abcd} u_c v_d` without materialising `t4`;
1404/// * full [`Tower4`] → every uncontracted channel
1405///   ([`rigid_standard_normal_tower`], feeding the `third_full` / `fourth_full`
1406///   caches).
1407///
1408/// Every consumer derives from THIS one expression, so the value channel and
1409/// every derivative channel cannot desync (the #736 / #948 bug genus).
1410///
1411/// The marginal index `q(η_marg)` enters by composing the hand-certified link
1412/// derivative stack `[q, q1, q2, q3, q4]` onto the η primary (slot 0); the
1413/// margin transcendental enters by composing the certified
1414/// [`signed_probit_neglog_unary_stack`] onto the assembled signed margin — the
1415/// stability discipline of #932 (humans own primitive stability, the algebra
1416/// owns combinatorics). The caller MUST guard the signed-margin value against a
1417/// non-finite (non-`+∞`-excluded) NaN before calling; the seeded-evaluation
1418/// wrappers below do that.
1419#[inline]
1420pub(crate) fn rigid_standard_normal_row_nll_generic<S: gam_math::jet_scalar::JetScalar<2>>(
1421    p: &[S; 2],
1422    marginal: BernoulliMarginalLinkMap,
1423    z: f64,
1424    y: f64,
1425    w: f64,
1426    probit_scale: f64,
1427) -> Result<S, String> {
1428    // The order-≤4 signed observed margin `m = (2y−1)·η`, written ONCE in
1429    // `rigid_standard_normal_signed_margin` over `S: JetScalar<2>` and shared
1430    // verbatim with the batched builder's Pass-A jet (#932 single source).
1431    let signed = rigid_standard_normal_signed_margin(p, marginal, z, y, probit_scale);
1432    // Preserve the production fail-fast: a NaN (non-`+∞`) signed margin is an
1433    // upstream domain failure, not a tail saturation.
1434    let m = signed.value();
1435    if !(m.is_finite() || m == f64::INFINITY) {
1436        return Err(format!(
1437            "non-finite signed margin in rigid probit row NLL: {m}"
1438        ));
1439    }
1440    // NLL = −w·logΦ(m) via the fused single-Mills-ratio probit neglog stack.
1441    Ok(signed.compose_unary(signed_probit_neglog_unary_stack(m, w)))
1442}
1443
1444/// The order-≤4 signed observed margin `m = (2y−1)·η` of one rigid
1445/// standard-normal Bernoulli row, written ONCE over `S: JetScalar<2>`:
1446/// `q(η_marg)` composed onto the η primary, observed slope `b = s·g`, scale
1447/// `c = √(1 + b²)`, `η = q·c + b·z`. This is the polynomial part shared by
1448/// every channel consumer — the per-row / contracted / full-tower generic NLL
1449/// ([`rigid_standard_normal_row_nll_generic`]) composes the probit-neglog
1450/// transcendental onto it, and the batched builder's Pass-A jet
1451/// ([`rigid_standard_normal_signed_jet`]) evaluates it at `Tower4<2>` — so the
1452/// signed margin has a single source (#932), with no second hand-packed jet.
1453#[inline]
1454pub(crate) fn rigid_standard_normal_signed_margin<S: gam_math::jet_scalar::JetScalar<2>>(
1455    p: &[S; 2],
1456    marginal: BernoulliMarginalLinkMap,
1457    z: f64,
1458    y: f64,
1459    probit_scale: f64,
1460) -> S {
1461    // q(η_marg): compose the link's q-as-function-of-η stack onto the η primary.
1462    let q = p[0].compose_unary([
1463        marginal.q,
1464        marginal.q1,
1465        marginal.q2,
1466        marginal.q3,
1467        marginal.q4,
1468    ]);
1469    let slope = p[1];
1470    // observed slope b = s·g, scale c = √(1 + b²).
1471    let observed_slope = slope.scale(probit_scale);
1472    let b2 = observed_slope.mul(&observed_slope);
1473    let c = b2.add(&S::constant(1.0)).sqrt();
1474    // η = q·c + (s·g)·z, signed margin m = (2y−1)·η.
1475    let eta = q.mul(&c).add(&observed_slope.scale(z));
1476    eta.scale(2.0 * y - 1.0)
1477}
1478
1479/// One row of rigid standard-normal Bernoulli data as a generic
1480/// [`RowNllProgramGeneric<2>`] (#932 production wiring).
1481///
1482/// This is the genuine production consumer of the generic program seam: the row
1483/// NLL is written ONCE in [`rigid_standard_normal_row_nll_generic`] over
1484/// `S: JetScalar<2>`, and this single-row program routes it through the
1485/// [`gam_math::jet_tower`] `generic_*` evaluators
1486/// ([`generic_full_tower`](gam_math::jet_tower::generic_full_tower) for
1487/// the uncontracted tensors, and the cheap order-2 / contracted scalars for the
1488/// value/grad/Hessian and directional channels). Primaries are
1489/// `[marginal η, slope g]`; the marginal link map and per-row data
1490/// `(z, y, w, probit_scale)` enter as constants on the body.
1491pub(crate) struct RigidStandardNormalRow {
1492    pub(crate) marginal: BernoulliMarginalLinkMap,
1493    pub(crate) g: f64,
1494    pub(crate) z: f64,
1495    pub(crate) y: f64,
1496    pub(crate) w: f64,
1497    pub(crate) probit_scale: f64,
1498}
1499
1500impl gam_math::jet_tower::RowNllProgramGeneric<2> for RigidStandardNormalRow {
1501    fn n_rows(&self) -> usize {
1502        1
1503    }
1504
1505    fn primaries(&self, row: usize) -> Result<[f64; 2], String> {
1506        if row != 0 {
1507            return Err(format!("RigidStandardNormalRow: row {row} out of range"));
1508        }
1509        Ok([self.marginal.eta_value(), self.g])
1510    }
1511
1512    fn row_nll_generic<S: gam_math::jet_scalar::JetScalar<2>>(
1513        &self,
1514        row: usize,
1515        p: &[S; 2],
1516    ) -> Result<S, String> {
1517        if row != 0 {
1518            return Err(format!("RigidStandardNormalRow: row {row} out of range"));
1519        }
1520        rigid_standard_normal_row_nll_generic(
1521            p,
1522            self.marginal,
1523            self.z,
1524            self.y,
1525            self.w,
1526            self.probit_scale,
1527        )
1528    }
1529}
1530
1531#[inline]
1532pub(crate) fn rigid_standard_normal_tower(
1533    marginal: BernoulliMarginalLinkMap,
1534    g: f64,
1535    z: f64,
1536    y: f64,
1537    w: f64,
1538    probit_scale: f64,
1539) -> Result<Tower4<2>, String> {
1540    // #932 cutover: the full uncontracted tower comes from the SAME single
1541    // generic row-NLL expression every other channel consumer derives from,
1542    // routed through the generic program seam evaluated at the all-channels
1543    // `Tower4` scalar. `generic_full_tower` seeds `[marginal η, g]` exactly as
1544    // the previous inline `Tower4::variable` form did, so this is bit-identical
1545    // while giving `RowNllProgramGeneric` a genuine production consumer.
1546    let program = RigidStandardNormalRow {
1547        marginal,
1548        g,
1549        z,
1550        y,
1551        w,
1552        probit_scale,
1553    };
1554    gam_math::jet_tower::generic_full_tower(&program, 0)
1555}
1556
1557/// Branch-free `signed`-margin jet for the rigid standard-normal row kernel.
1558///
1559/// This is the order-≤4 polynomial part of [`rigid_standard_normal_tower`]
1560/// *before* the single transcendental compose: it builds the `Tower4<2>` of the
1561/// signed observed index `signed = (2y−1)·η`, `η = q·c(g) + g·(s·z)`,
1562/// `c(g) = √(1 + (s·g)²)`, with no `erfc`/`exp`/`ln` call. Splitting this off
1563/// lets the batched builder run all the cheap, branch-free jet products in one
1564/// SIMD-friendly pass and isolate the (branchy, transcendental) Mills-ratio
1565/// composition into its own tight pass. The returned jet is the SAME expression
1566/// (`rigid_standard_normal_signed_margin`) the per-row `rigid_standard_normal_tower`
1567/// signed margin evaluates, here at `Tower4<2>` — bit-identical by construction,
1568/// not a parallel hand-packed jet (#932 single source).
1569#[inline]
1570fn rigid_standard_normal_signed_jet(
1571    marginal: BernoulliMarginalLinkMap,
1572    g: f64,
1573    z: f64,
1574    y: f64,
1575    probit_scale: f64,
1576) -> Tower4<2> {
1577    // Seed `[marginal η, g]` exactly as the generic program's `primaries()`, then
1578    // evaluate the one shared signed-margin expression at the all-channels scalar.
1579    let p = [
1580        Tower4::<2>::variable(marginal.eta_value(), 0),
1581        Tower4::<2>::variable(g, 1),
1582    ];
1583    rigid_standard_normal_signed_margin(&p, marginal, z, y, probit_scale)
1584}
1585
1586/// Batched, two-pass builder of the rigid standard-normal row `Tower4<2>` jets
1587/// for a contiguous chunk of rows, written for the auto-vectorizer.
1588///
1589/// Per row the production path ([`rigid_standard_normal_tower`]) interleaves
1590/// (1) cheap branch-free jet products to form the `signed` margin jet, (2) ONE
1591/// branchy transcendental (`erfcx`/`exp`/`ln` via
1592/// [`signed_probit_neglog_unary_stack`]) that dominates the per-row scalar-ALU
1593/// budget across all `n ≈ 356k` rows, and (3) the branch-free Faà-di-Bruno
1594/// `compose_unary` tensor assembly. The interleaving keeps the compiler from
1595/// vectorizing the loop body because the transcendental's internal branches sit
1596/// between the two pure-FMA blocks.
1597///
1598/// This builder runs the same work as three *separate* loops over the chunk:
1599///
1600/// * Pass A — build every `signed` jet (branch-free, [`rigid_standard_normal_signed_jet`]),
1601///   spilling `signed.v` into a contiguous `margins` scratch buffer.
1602/// * Pass B — fill the per-row unary derivative stack `[d0..d4]` from
1603///   `margins`/`weights` (the transcendental, now back-to-back over a flat
1604///   `&[f64]` so branch prediction and the polynomial `k1..k4` portion stream).
1605/// * Pass C — `compose_unary` each `signed` jet against its stack (branch-free,
1606///   pure FMA over the dense tensors → the vectorizable hot block).
1607///
1608/// Every scalar operation, and its order, is identical to the per-row path, so
1609/// the produced jets are bit-for-bit equal; the win is making the n-row build
1610/// memory-bandwidth-bound rather than scalar-ALU/branch-bound. The `fill`
1611/// callback writes the consumer's per-row payload (e.g. `.t3` or `.t4`) from the
1612/// finished jet, so neither tensor is materialized into an intermediate `Vec`.
1613#[inline]
1614pub(super) fn rigid_standard_normal_towers_batch<T>(
1615    marginals: &[BernoulliMarginalLinkMap],
1616    slopes: &[f64],
1617    zs: &[f64],
1618    ys: &[f64],
1619    weights: &[f64],
1620    probit_scale: f64,
1621    out: &mut [T],
1622    mut fill: impl FnMut(&Tower4<2>) -> Result<T, String>,
1623) -> Result<(), String> {
1624    let chunk = marginals.len();
1625    if slopes.len() != chunk
1626        || zs.len() != chunk
1627        || ys.len() != chunk
1628        || weights.len() != chunk
1629        || out.len() != chunk
1630    {
1631        return Err(format!(
1632            "rigid_standard_normal_towers_batch length mismatch: marginals={chunk}, \
1633             slopes={}, zs={}, ys={}, weights={}, out={}",
1634            slopes.len(),
1635            zs.len(),
1636            ys.len(),
1637            weights.len(),
1638            out.len()
1639        ));
1640    }
1641
1642    // Pass A: branch-free signed-margin jets + flat margin scratch.
1643    let mut signed: Vec<Tower4<2>> = Vec::with_capacity(chunk);
1644    let mut margins: Vec<f64> = Vec::with_capacity(chunk);
1645    for i in 0..chunk {
1646        let jet =
1647            rigid_standard_normal_signed_jet(marginals[i], slopes[i], zs[i], ys[i], probit_scale);
1648        margins.push(jet.v);
1649        signed.push(jet);
1650    }
1651
1652    // Pass B: the transcendental, isolated over a flat margin slice. Each entry
1653    // is the exact `[d0..d4]` `compose_unary` consumes; the production path's
1654    // fail-fast on a non-finite (non-`+∞`) margin is preserved here.
1655    let mut stacks: Vec<[f64; 5]> = Vec::with_capacity(chunk);
1656    for i in 0..chunk {
1657        let m = margins[i];
1658        if !(m.is_finite() || m == f64::INFINITY) {
1659            return Err(format!(
1660                "non-finite signed margin in rigid probit tower batch: {m}"
1661            ));
1662        }
1663        stacks.push(signed_probit_neglog_unary_stack(m, weights[i]));
1664    }
1665
1666    // Pass C: branch-free dense compose + consumer fill.
1667    for i in 0..chunk {
1668        let tower = signed[i].compose_unary(stacks[i]);
1669        out[i] = fill(&tower)?;
1670    }
1671    Ok(())
1672}
1673
1674#[inline]
1675pub(super) fn rigid_standard_normal_row_kernel(
1676    marginal: BernoulliMarginalLinkMap,
1677    g: f64,
1678    z: f64,
1679    y: f64,
1680    w: f64,
1681    probit_scale: f64,
1682) -> Result<(f64, [f64; 2], [[f64; 2]; 2]), String> {
1683    // #932 cutover: value/gradient/Hessian derive from the SAME single generic
1684    // row-NLL expression (`rigid_standard_normal_row_nll_generic`) every other
1685    // channel consumer uses, routed through the `RowNllProgramGeneric` seam at the
1686    // packed `Order2<2>` scalar — there is no longer a hand-assembled `Tower2<2>`
1687    // here. Seeds `[marginal η, g]` exactly as the deleted inline form did, so it
1688    // is bit-identical (the `rigid_bernoulli_*_agrees_with_jet_tower_program_all_channels`
1689    // oracle pins v/g/H ≤ 1e-12), while sharing one definition with the third/
1690    // fourth/full-tower channels.
1691    let program = RigidStandardNormalRow {
1692        marginal,
1693        g,
1694        z,
1695        y,
1696        w,
1697        probit_scale,
1698    };
1699    gam_math::jet_tower::generic_row_kernel(&program, 0)
1700}
1701
1702/// Mixed `(primary, z)` second derivative of the rigid standard-normal row
1703/// LOG-LIKELIHOOD score: the per-row 2-vector
1704/// `[∂²(log L)/∂q∂z, ∂²(log L)/∂g∂z]` in the primary coordinates `(q = marginal η,
1705/// g = slope)`, evaluated at this row's converged `(q, g)` and calibrated
1706/// latent score `z = ζ`.
1707///
1708/// SIGN CONVENTION (#1131). This returns the mixed partial of the
1709/// LOG-LIKELIHOOD score `score_β,i = ∂(log L_i)/∂β`, NOT of the negative
1710/// log-likelihood `ℓ = −log L`. Concretely the row jet evaluates the NLL
1711/// `ℓ = −w·log Φ(sign·η)` and we NEGATE its mixed `(primary, z)` Hessian entries,
1712/// so the returned 2-vector is `+∂²(log L_i)/∂(q,g)∂ζ_i = −∂²ℓ_i/∂(q,g)∂ζ_i`.
1713/// This is the convention under which the Murphy–Topel chain
1714/// `G = Σ_i s_i·(∂ζ_i/∂θ₁)` with `s_i = ∂score_β,i/∂ζ_i` and `Vb = H_β⁻¹`
1715/// (the NLL-Hessian inverse) gives the SIGNED sensitivity with the right sign:
1716/// the implicit-function theorem on the stationarity `∂(log L)/∂β = 0` yields
1717/// `∂β̂/∂θ₁ = −(∂²log L/∂β²)⁻¹·∂²(log L)/∂β∂θ₁ = +H_β⁻¹·G = +Vb·G`. (Had we
1718/// returned the NLL mixed partial instead, `Vb·G` would equal `−∂β̂/∂θ₁` — a
1719/// benign sign flip for the PSD quadratic SE `(Vb·G)V₁(Vb·G)ᵀ`, but wrong for
1720/// any signed consumer of the sensitivity.)
1721///
1722/// This is the #1028 Murphy–Topel generated-regressor channel: `score_β,i =
1723/// ∂(log L_i)/∂β = J_iᵀ·(∂(log L_i)/∂(q,g))`, so the per-row slope-score
1724/// sensitivity to the calibrated score is
1725/// `s_i = ∂score_β,i/∂ζ_i = J_iᵀ·(∂²(log L_i)/∂(q,g)∂ζ_i)`, and the primary
1726/// 2-vector returned here is exactly `∂²(log L_i)/∂(q,g)∂ζ_i`. The block-level
1727/// contraction `J_iᵀ` (marginal+logslope design rows) is applied by the caller.
1728///
1729/// It is computed by seeding `z` as a THIRD jet variable (index 2) in the SAME
1730/// order-≤2 jet algebra the value/gradient/Hessian path uses, carried by the
1731/// packed `Order2<3>`/`Tower2<3>` scalar rather than a dense `Tower4<3>`
1732/// (#932 row-jet machinery, packed-scalar perf cutover): the
1733/// rigid standard-normal observed index is `η = q·c(g) + g·(s·z)` with
1734/// `c(g) = √(1 + (s·g)²)`, `s = probit_scale`, and `ℓ = −w·log Φ(sign·η)`. The
1735/// converged-frame mixed partials of the NLL are the off-diagonal Hessian
1736/// entries `tower.h[q][z]` and `tower.h[g][z]`, read off in one composition and
1737/// NEGATED to the log-likelihood-score convention — the only extra cost over the
1738/// production `Tower4<2>` evaluation is the third jet axis.
1739#[inline]
1740pub(super) fn rigid_standard_normal_mixed_z_sensitivity(
1741    marginal: BernoulliMarginalLinkMap,
1742    g: f64,
1743    z: f64,
1744    y: f64,
1745    w: f64,
1746    probit_scale: f64,
1747) -> Result<[f64; 2], String> {
1748    // Three jet axes: q = marginal η (0), g = slope (1), z = latent score (2).
1749    //
1750    // #932 perf: this consumer reads ONLY the two mixed Hessian channels
1751    // `h[0][2]`/`h[1][2]`, so it needs only the value/gradient/Hessian stack —
1752    // the packed `Order2<3>` scalar (operating on its inner `Tower2<3>`), NOT a
1753    // dense `Tower4<3>` that would materialise the unused `K³`/`K⁴` `t3`/`t4`
1754    // tensors. The order-≤2 channels are bit-identical to the dense tower
1755    // (`Tower2::mul`/`compose_unary` match `Tower4` term-for-term), so the read
1756    // entries are unchanged; the `q3`/`q4` marginal-link channels are dropped
1757    // because no order-≤2 channel of the composed jet reads them.
1758    use gam_math::jet_tower::Tower2;
1759    let mut q = Tower2::<3>::constant(marginal.q);
1760    q.g[0] = marginal.q1;
1761    q.h[0][0] = marginal.q2;
1762    let slope = Tower2::<3>::variable(g, 1);
1763    let z_var = Tower2::<3>::variable(z, 2);
1764    let observed_logslope = slope * probit_scale;
1765    let c = (observed_logslope * observed_logslope + 1.0).sqrt();
1766    // η = q·c + g·(s·z): z enters linearly through the slope×z product, so the
1767    // mixed (q,z)/(g,z) curvature is carried entirely by the unary NLL chain and
1768    // the η-bilinear, exactly as in the Tower4<2> production path.
1769    let eta = q * c + slope * (z_var * probit_scale);
1770    let signed = eta * (2.0 * y - 1.0);
1771    // ONE transcendental per row (see `rigid_standard_normal_tower`).
1772    if !(signed.v.is_finite() || signed.v == f64::INFINITY) {
1773        return Err(format!(
1774            "rigid probit mixed-z sensitivity: non-finite signed margin {} at q={}, g={g}, z={z}, y={y}",
1775            signed.v, marginal.q
1776        ));
1777    }
1778    let stack = signed_probit_neglog_unary_stack(signed.v, w);
1779    if !stack[0].is_finite() {
1780        return Err(format!(
1781            "rigid probit mixed-z sensitivity: non-finite log Φ at q={}, g={g}, z={z}, y={y}",
1782            marginal.q
1783        ));
1784    }
1785    // Order-≤2 composition consumes only the leading `[f, f', f'']` of the
1786    // certified `[f64; 5]` derivative stack.
1787    let tower = signed.compose_unary([stack[0], stack[1], stack[2]]);
1788    // #1131: `tower` is the NLL `ℓ = −w·log Φ`, so `tower.h[·][z]` is the mixed
1789    // partial of the NLL. Negate to the LOG-LIKELIHOOD-score convention
1790    // `s = ∂²(log L)/∂(primary)∂z = −∂²ℓ/∂(primary)∂z`, under which the
1791    // downstream Murphy–Topel chain `Vb·G = +∂β̂/∂θ₁` carries the correct sign
1792    // (see the function doc). The SE is the PSD quadratic `(Vb·G)V₁(Vb·G)ᵀ` and
1793    // is invariant to this sign, so the reported standard errors are unchanged.
1794    let s_q = -tower.h[0][2];
1795    let s_g = -tower.h[1][2];
1796    if !(s_q.is_finite() && s_g.is_finite()) {
1797        return Err(format!(
1798            "rigid probit mixed-z sensitivity: non-finite ∂²(log L)/∂(q,g)∂z = [{s_q}, {s_g}] at q={}, g={g}, z={z}",
1799            marginal.q
1800        ));
1801    }
1802    Ok([s_q, s_g])
1803}
1804
1805/// Assemble the #1028 Murphy–Topel slope-score sensitivity matrix
1806/// `score_zeta_sensitivity` (`n × p_β`, row `i` = `s_i = ∂score_β,i/∂ζ_i`) for
1807/// the rigid standard-normal BMS kernel — the kernel the conditional
1808/// location-scale gate ALWAYS selects (`LatentMeasureKind::StandardNormal`).
1809///
1810/// where `s_i = ∂score_β,i/∂ζ_i` is the LOG-LIKELIHOOD-score sensitivity (see
1811/// the sign convention in [`rigid_standard_normal_mixed_z_sensitivity`], #1131).
1812/// For each row `i` the primary 2-vector `∂²(log L_i)/∂(q,g)∂ζ_i` is read off the
1813/// z-augmented row jet ([`rigid_standard_normal_mixed_z_sensitivity`]) at the
1814/// converged marginal index `q_i` (`marginal_eta[i]`) and slope `g_i`
1815/// (`slope_eta[i]`) and calibrated score `ζ_i` (`z[i]`), then contracted through
1816/// the block Jacobian `J_iᵀ` (the same marginal+logslope design-row scatter the
1817/// row kernel exposes via `jacobian_transpose_action`):
1818///
1819/// ```text
1820///   s_i[marginal_range]  = (∂²(log L_i)/∂q∂ζ_i) · marginal_design.row(i)
1821///   s_i[logslope_range]  = (∂²(log L_i)/∂g∂ζ_i) · logslope_design.row(i)
1822/// ```
1823///
1824/// `logslope_design` MUST be the reduced-basis design `G·T` actually fitted
1825/// (so `p_β = p_marginal + r` matches the reduced-frame `covariance_conditional`
1826/// the correction inflates). The aux deviation blocks (score_warp / link_dev),
1827/// when present, occupy the trailing columns of `p_beta` and are left zero here:
1828/// the rigid standard-normal kernel carries no deviation z-dependence, and the
1829/// conditional gate's canonical (non-flex) kernel has no such blocks — the
1830/// caller wires the correction only when `p_beta == p_marginal + p_logslope`.
1831pub(super) fn rigid_standard_normal_score_zeta_sensitivity(
1832    base_link: &InverseLink,
1833    marginal_eta: &Array1<f64>,
1834    slope_eta: &Array1<f64>,
1835    z: &Array1<f64>,
1836    y: &Array1<f64>,
1837    weights: &Array1<f64>,
1838    probit_scale: f64,
1839    marginal_design: ArrayView2<'_, f64>,
1840    logslope_design: ArrayView2<'_, f64>,
1841    p_beta: usize,
1842) -> Result<Array2<f64>, String> {
1843    let n = marginal_eta.len();
1844    let p_m = marginal_design.ncols();
1845    let r = logslope_design.ncols();
1846    if slope_eta.len() != n
1847        || z.len() != n
1848        || y.len() != n
1849        || weights.len() != n
1850        || marginal_design.nrows() != n
1851        || logslope_design.nrows() != n
1852    {
1853        return Err(format!(
1854            "score_zeta_sensitivity row mismatch: marginal_eta={n}, slope_eta={}, z={}, y={}, \
1855             weights={}, marginal_design rows={}, logslope_design rows={}",
1856            slope_eta.len(),
1857            z.len(),
1858            y.len(),
1859            weights.len(),
1860            marginal_design.nrows(),
1861            logslope_design.nrows()
1862        ));
1863    }
1864    if p_m + r > p_beta {
1865        return Err(format!(
1866            "score_zeta_sensitivity width overflow: marginal({p_m}) + logslope({r}) > p_beta({p_beta})"
1867        ));
1868    }
1869    let mut s = Array2::<f64>::zeros((n, p_beta));
1870    for i in 0..n {
1871        let marginal = bernoulli_marginal_link_map(base_link, marginal_eta[i])?;
1872        let [s_q, s_g] = rigid_standard_normal_mixed_z_sensitivity(
1873            marginal,
1874            slope_eta[i],
1875            z[i],
1876            y[i],
1877            weights[i],
1878            probit_scale,
1879        )?;
1880        // J_iᵀ scatter into the reduced-frame coordinates: marginal block first,
1881        // then the reduced logslope block.
1882        if s_q != 0.0 {
1883            let m_row = marginal_design.row(i);
1884            for (j, &mij) in m_row.iter().enumerate() {
1885                s[[i, j]] = s_q * mij;
1886            }
1887        }
1888        if s_g != 0.0 {
1889            let g_row = logslope_design.row(i);
1890            for (j, &gij) in g_row.iter().enumerate() {
1891                s[[i, p_m + j]] = s_g * gij;
1892            }
1893        }
1894    }
1895    Ok(s)
1896}
1897
1898#[inline]
1899pub(super) fn rigid_standard_normal_third_full(
1900    marginal: BernoulliMarginalLinkMap,
1901    g: f64,
1902    z: f64,
1903    y: f64,
1904    w: f64,
1905    probit_scale: f64,
1906) -> Result<[[[f64; 2]; 2]; 2], String> {
1907    Ok(rigid_standard_normal_tower(marginal, g, z, y, w, probit_scale)?.t3)
1908}
1909
1910/// Contract a symmetric 3-tensor on its third index with a primary-space
1911/// direction `d = (d_eta, d_g)`, producing the symmetric 2×2 contracted
1912/// matrix the outer-derivative pipeline consumes:
1913///   `M[a][b] = Σ_c T[a][b][c] · d[c]`.
1914#[inline]
1915pub(super) fn contract_third_full(t: &[[[f64; 2]; 2]; 2], d_eta: f64, d_g: f64) -> [[f64; 2]; 2] {
1916    [
1917        [
1918            t[0][0][0] * d_eta + t[0][0][1] * d_g,
1919            t[0][1][0] * d_eta + t[0][1][1] * d_g,
1920        ],
1921        [
1922            t[1][0][0] * d_eta + t[1][0][1] * d_g,
1923            t[1][1][0] * d_eta + t[1][1][1] * d_g,
1924        ],
1925    ]
1926}
1927
1928#[inline]
1929pub(super) fn rigid_standard_normal_fourth_full(
1930    marginal: BernoulliMarginalLinkMap,
1931    g: f64,
1932    z: f64,
1933    y: f64,
1934    w: f64,
1935    probit_scale: f64,
1936) -> Result<[[[[f64; 2]; 2]; 2]; 2], String> {
1937    // #932 single-sourcing: the full uncontracted fourth-order primary tensor is
1938    // the `.t4` channel of the SAME `Tower4<2>` row jet the value/gradient/Hessian
1939    // and the third-order tensor (`rigid_standard_normal_third_full` → `.t3`) are
1940    // read from. The marginal latent-coordinate chain `q(η)` is already seeded
1941    // into axis 0 of the tower (`q.g[0]=q1, q.h[0][0]=q2, q.t3[0][0][0]=q3,
1942    // q.t4[0][0][0][0]=q4` in `rigid_standard_normal_signed_jet`), so `.t4` is
1943    // delivered directly in the production `(η, g)` primary space — no separate
1944    // Faà-di-Bruno q-chain reassembly. This replaces the former hand-written
1945    // fourth-derivative chain rule with the mechanically-derived tower output,
1946    // exactly mirroring how `.t3` is consumed;
1947    // it is cross-checked term-for-term against the independent
1948    // `HandRigidProbitKernel` witness in
1949    // `rigid_standard_normal_tower_path_matches_hand_chain_witness`.
1950    Ok(rigid_standard_normal_tower(marginal, g, z, y, w, probit_scale)?.t4)
1951}
1952
1953/// Combined uncontracted THIRD **and** FOURTH primary tensors for one rigid
1954/// standard-normal row, read off a SINGLE shared `Tower4<2>` jet.
1955///
1956/// `rigid_standard_normal_third_full` (→ `.t3`) and
1957/// `rigid_standard_normal_fourth_full` (→ `.t4`) each build a full
1958/// `rigid_standard_normal_tower` and discard the OTHER tensor — so a consumer
1959/// that needs both for the same `(row, β)` point (the outer Jeffreys/REML
1960/// derivative path warms both the `rigid_third_full` and `rigid_fourth_full`
1961/// caches in the same fit; see the paired `rigid_{third,fourth}_full_cached`
1962/// warm-up) pays the per-row Mills-ratio transcendental
1963/// (`signed_probit_neglog_unary_stack`, ~88% of the per-row scalar cost) TWICE
1964/// where ONCE suffices. The two tensors are the `.t3` / `.t4` channels of the
1965/// same tower, so this builder evaluates that tower ONCE and returns both.
1966///
1967/// Contract a symmetric 4-tensor on its last two indices with two
1968/// primary-space directions `u = (u_eta, u_g)` and `v = (v_eta, v_g)`,
1969/// producing the symmetric 2×2 matrix the outer-Hessian pipeline expects:
1970///   `M[a][b] = Σ_{c,d} T[a][b][c][d] · u[c] · v[d]`.
1971#[inline]
1972pub(super) fn contract_fourth_full(
1973    t: &[[[[f64; 2]; 2]; 2]; 2],
1974    u_eta: f64,
1975    u_g: f64,
1976    v_eta: f64,
1977    v_g: f64,
1978) -> [[f64; 2]; 2] {
1979    let mut out = [[0.0; 2]; 2];
1980    for a in 0..2 {
1981        for b in 0..2 {
1982            let mut sum = 0.0;
1983            sum += t[a][b][0][0] * u_eta * v_eta;
1984            sum += t[a][b][0][1] * u_eta * v_g;
1985            sum += t[a][b][1][0] * u_g * v_eta;
1986            sum += t[a][b][1][1] * u_g * v_g;
1987            out[a][b] = sum;
1988        }
1989    }
1990    out
1991}
1992
1993pub(super) fn ensure_finite_third_full_cache_row(
1994    t: &[[[f64; 2]; 2]; 2],
1995    context: &str,
1996) -> Result<(), String> {
1997    if t.iter().flatten().flatten().all(|value| value.is_finite()) {
1998        Ok(())
1999    } else {
2000        Err(format!(
2001            "{context}: warmed third-derivative cache row contains a non-finite value"
2002        ))
2003    }
2004}
2005
2006pub(super) fn ensure_finite_fourth_full_cache_row(
2007    t: &[[[[f64; 2]; 2]; 2]; 2],
2008    context: &str,
2009) -> Result<(), String> {
2010    if t.iter()
2011        .flatten()
2012        .flatten()
2013        .flatten()
2014        .all(|value| value.is_finite())
2015    {
2016        Ok(())
2017    } else {
2018        Err(format!(
2019            "{context}: warmed fourth-derivative cache row contains a non-finite value"
2020        ))
2021    }
2022}
2023
2024pub(crate) fn unary_derivatives_sqrt(x: f64) -> [f64; 5] {
2025    let s = x.max(1e-300).sqrt();
2026    let x1 = x.max(1e-300);
2027    let x2 = x1 * x1;
2028    let x3 = x2 * x1;
2029    [
2030        s,
2031        0.5 / s,
2032        -0.25 / (x1 * s),
2033        3.0 / (8.0 * x2 * s),
2034        -15.0 / (16.0 * x3 * s),
2035    ]
2036}
2037pub(crate) fn unary_derivatives_neglog_phi(x: f64, weight: f64) -> [f64; 5] {
2038    // Single source of truth for the signed-probit value+derivative stack:
2039    // one Mills-ratio transcendental feeds both logΦ and k1..k4 (the prior
2040    // body evaluated `signed_probit_logcdf_and_mills_ratio` twice). The
2041    // ±∞/NaN/zero-weight boundary limits are handled identically inside.
2042    signed_probit_neglog_unary_stack(x, weight)
2043}
2044
2045/// Derivatives of `log(x)` through 4th order.
2046///
2047/// # Contract
2048///
2049/// `x` must be strictly positive. `log` and its derivatives are undefined at
2050/// and below the boundary, so this function does NOT clamp: a previous version
2051/// silently replaced `x` by `x.max(1e-300)`, which fabricated enormous finite
2052/// derivatives (`1/1e-300` etc.) that are the derivatives of neither `log(x)`
2053/// nor `log(max(x, floor))`. Such a non-positive argument signals an upstream
2054/// domain failure (e.g. a monotonicity violation) that must surface, not be
2055/// masked. Every caller guarantees `x > 0` before invoking this:
2056/// the survival marginal-slope kernels evaluate `log` of the transformed time
2057/// derivative `q'(t)·√(1+b²)` only after passing `survival_derivative_guard`
2058/// (`q'(t) >= derivative_guard > 0`, `√(1+b²) > 0`). A non-positive `x`
2059/// therefore never reaches here on any supported path; were one to, the
2060/// function returns the honest IEEE result (`-inf`/`NaN`) — identical in debug
2061/// and release — rather than a finite fabrication.
2062pub(crate) fn unary_derivatives_log(x: f64) -> [f64; 5] {
2063    let x2 = x * x;
2064    let x3 = x2 * x;
2065    let x4 = x3 * x;
2066    [x.ln(), 1.0 / x, -1.0 / x2, 2.0 / x3, -6.0 / x4]
2067}
2068
2069/// Derivatives of log φ(x) = -½x² - ½ln(2π) through 4th order.
2070pub(crate) fn unary_derivatives_log_normal_pdf(x: f64) -> [f64; 5] {
2071    let c = 0.5 * (2.0 * std::f64::consts::PI).ln();
2072    [-0.5 * x * x - c, -x, -1.0, 0.0, 0.0]
2073}
2074
2075#[cfg(test)]
2076mod jet_tower_oracle_tests {
2077    //! #932 deployment step 2 for the BMS rigid Bernoulli `RowKernel<2>`.
2078    //!
2079    //! The production rigid standard-normal row kernel
2080    //! ([`rigid_standard_normal_row_kernel`] / `_third_full` / `_fourth_full`)
2081    //! reads value/grad/Hessian/third/fourth straight off ONE
2082    //! [`rigid_standard_normal_tower`] `Tower4<2>` — the strongest #932 form,
2083    //! where the production kernel literally *is* the single-expression jet.
2084    //! What was missing (unlike the two survival `RowKernel` families, which
2085    //! already carry `verify_kernel_channels` oracles) is an INDEPENDENT
2086    //! cross-check that this production tower is correct. This module adds it:
2087    //!
2088    //! * an independent [`RowNllProgram<2>`] that writes the row NLL
2089    //!   `ℓ = −w·logΦ((2y−1)·η)`, `η = q·√(1+(s·g)²) + s·g·z` ONCE over generic
2090    //!   `Tower4` arithmetic (a different composition order than the fused
2091    //!   production `signed` jet → exercises the Leibniz/Faà-di-Bruno layer
2092    //!   where the #736 cross-block sign-flip bug genus lives), and
2093    //! * a special-function-independent central-FD witness of the value channel
2094    //!   that re-derives `logΦ` from `libm::erfc`, pinning the probit derivative
2095    //!   stack itself (so the oracle does not merely re-use the production
2096    //!   transcendental).
2097
2098    use super::*;
2099
2100    /// #932 combined third+fourth primary tensors read off ONE shared
2101    /// `rigid_standard_normal_tower` jet (the redundancy-free form of the
2102    /// separate `_third_full` / `_fourth_full` builds, bit-identical to them).
2103    /// Lives in this `#[cfg(test)]` module — its only consumers are the
2104    /// bit-identity checks below — so it is not a production `src` item with no
2105    /// production caller (production reads the separate builders) and is not dead
2106    /// code in the non-test lib build.
2107    fn rigid_standard_normal_third_and_fourth_full(
2108        marginal: BernoulliMarginalLinkMap,
2109        g: f64,
2110        z: f64,
2111        y: f64,
2112        w: f64,
2113        probit_scale: f64,
2114    ) -> Result<([[[f64; 2]; 2]; 2], [[[[f64; 2]; 2]; 2]; 2]), String> {
2115        let tower = rigid_standard_normal_tower(marginal, g, z, y, w, probit_scale)?;
2116        Ok((tower.t3, tower.t4))
2117    }
2118    use gam_math::jet_tower::{
2119        KernelChannels, RowNllProgram, evaluate_program, verify_kernel_channels,
2120    };
2121
2122    /// Independent single-expression row NLL for the rigid standard-normal
2123    /// Bernoulli kernel, primaries `(q_eta = marginal η, g = slope)`.
2124    struct BernoulliRigidStandardNormalNllProgram {
2125        /// `(marginal η, slope g)` per row.
2126        primaries: Vec<[f64; 2]>,
2127        /// Per-row `(z latent score, y in {0,1}, w weight)`.
2128        z: Vec<f64>,
2129        y: Vec<f64>,
2130        w: Vec<f64>,
2131        probit_scale: f64,
2132    }
2133
2134    impl RowNllProgram<2> for BernoulliRigidStandardNormalNllProgram {
2135        fn n_rows(&self) -> usize {
2136            self.primaries.len()
2137        }
2138
2139        fn primaries(&self, row: usize) -> Result<[f64; 2], String> {
2140            self.primaries
2141                .get(row)
2142                .copied()
2143                .ok_or_else(|| format!("bernoulli rigid nll program: row {row} out of range"))
2144        }
2145
2146        fn row_nll(&self, row: usize, p: &[Tower4<2>; 2]) -> Result<Tower4<2>, String> {
2147            let z = self.z[row];
2148            let y = self.y[row];
2149            let w = self.w[row];
2150            let s = self.probit_scale;
2151            // q(η) via the family's own marginal link-map derivative stack,
2152            // composed through generic Leibniz on the η primary (independent of
2153            // the production signed-jet, which seeds the q tensor slots directly).
2154            let eta_marginal = p[0];
2155            let link = bernoulli_marginal_link_map(
2156                &InverseLink::Standard(gam_problem::StandardLink::Probit),
2157                eta_marginal.v,
2158            )?;
2159            let q = eta_marginal.compose_unary([link.q, link.q1, link.q2, link.q3, link.q4]);
2160            let g = p[1];
2161            // observed slope b = s·g, scale c = √(1 + b²).
2162            let observed_slope = g * s;
2163            let c = (observed_slope * observed_slope + 1.0).compose_unary(unary_derivatives_sqrt(
2164                observed_slope.v * observed_slope.v + 1.0,
2165            ));
2166            // η = q·c + b·z, signed margin m = (2y−1)·η.
2167            let eta = q * c + observed_slope * z;
2168            let signed = eta * (2.0 * y - 1.0);
2169            // NLL = −w·logΦ(m) via the documented probit neglog stack.
2170            Ok(signed.compose_unary(unary_derivatives_neglog_phi(signed.v, w)))
2171        }
2172    }
2173
2174    /// Special-function-independent scalar row NLL `ℓ(q_eta, g)` using
2175    /// `libm::erfc`, for the central-FD value-channel witness.
2176    fn scalar_nll(eta_marginal: f64, g: f64, z: f64, y: f64, w: f64, s: f64) -> f64 {
2177        let link = bernoulli_marginal_link_map(
2178            &InverseLink::Standard(gam_problem::StandardLink::Probit),
2179            eta_marginal,
2180        )
2181        .unwrap();
2182        let observed_slope = g * s;
2183        let c = (observed_slope * observed_slope + 1.0).sqrt();
2184        let eta = link.q * c + observed_slope * z;
2185        let signed = (2.0 * y - 1.0) * eta;
2186        let cdf = 0.5 * libm::erfc(-signed / std::f64::consts::SQRT_2);
2187        -w * cdf.max(1e-300).ln()
2188    }
2189
2190    #[test]
2191    fn rigid_bernoulli_row_kernel_agrees_with_jet_tower_program_all_channels() {
2192        // Mixed responses, weights, latent scores, and slope regimes; the last
2193        // rows push the marginal index toward the normal tails while staying
2194        // finite. Probit marginal link, standard-normal latent measure.
2195        let eta = [0.3_f64, -0.7, 0.05, 0.9, -1.2, 2.1, -2.4];
2196        let g = [0.2_f64, -0.5, 0.35, -0.15, 0.6, 0.45, -0.55];
2197        let z = [0.4_f64, -1.1, 0.0, 0.7, -0.3, 1.6, -1.4];
2198        let y = [1.0_f64, 0.0, 0.0, 1.0, 1.0, 0.0, 1.0];
2199        let w = [1.0_f64, 0.8, 1.3, 0.9, 1.1, 0.7, 1.4];
2200        let n = eta.len();
2201
2202        // Deterministic direction vectors (no RNG dependency).
2203        let dirs: [[f64; 2]; 3] = [[0.7, -1.3], [-0.4, 0.6], [1.2, 0.2]];
2204
2205        for &probit_scale in &[1.0_f64, 0.8] {
2206            let program = BernoulliRigidStandardNormalNllProgram {
2207                primaries: (0..n).map(|r| [eta[r], g[r]]).collect(),
2208                z: z.to_vec(),
2209                y: y.to_vec(),
2210                w: w.to_vec(),
2211                probit_scale,
2212            };
2213
2214            for row in 0..n {
2215                let tower = evaluate_program(&program, row).expect("tower evaluation");
2216
2217                // Production scalar kernel channels (the hand path under audit).
2218                let marginal = bernoulli_marginal_link_map(
2219                    &InverseLink::Standard(gam_problem::StandardLink::Probit),
2220                    eta[row],
2221                )
2222                .expect("link map");
2223                let (value, gradient, hessian) = rigid_standard_normal_row_kernel(
2224                    marginal,
2225                    g[row],
2226                    z[row],
2227                    y[row],
2228                    w[row],
2229                    probit_scale,
2230                )
2231                .expect("production row kernel");
2232
2233                // One shared tower for BOTH the third and fourth tensors (the
2234                // #932 transcendental-de-dup builder): this is the redundancy-
2235                // free form of the former two separate
2236                // `rigid_standard_normal_{third,fourth}_full` calls, and it is
2237                // pinned bit-identically against them in
2238                // `rigid_third_and_fourth_full_shares_one_tower_bit_identical`.
2239                let (third_full, fourth_full) = rigid_standard_normal_third_and_fourth_full(
2240                    marginal,
2241                    g[row],
2242                    z[row],
2243                    y[row],
2244                    w[row],
2245                    probit_scale,
2246                )
2247                .expect("production third+fourth");
2248                let third: Vec<([f64; 2], [[f64; 2]; 2])> = dirs
2249                    .iter()
2250                    .map(|d| (*d, contract_third_full(&third_full, d[0], d[1])))
2251                    .collect();
2252
2253                let fourth: Vec<([f64; 2], [f64; 2], [[f64; 2]; 2])> = dirs
2254                    .iter()
2255                    .enumerate()
2256                    .map(|(i, u)| {
2257                        let v = dirs[(i + 1) % dirs.len()];
2258                        (
2259                            *u,
2260                            v,
2261                            contract_fourth_full(&fourth_full, u[0], u[1], v[0], v[1]),
2262                        )
2263                    })
2264                    .collect();
2265
2266                let claims = KernelChannels {
2267                    value,
2268                    gradient,
2269                    hessian,
2270                    third,
2271                    fourth,
2272                };
2273
2274                verify_kernel_channels(&tower, &claims, 1e-9).unwrap_or_else(|e| {
2275                    panic!(
2276                        "probit_scale {probit_scale} row {row}: production rigid Bernoulli \
2277                         RowKernel disagrees with #932 jet-tower truth: {e}"
2278                    )
2279                });
2280
2281                // Special-function-independent FD witness of the value channel:
2282                // re-derives logΦ from `libm::erfc`, pinning the probit derivative
2283                // stack rather than re-using the production one.
2284                let h = 1e-3;
2285                let f = |de: f64, dg: f64| {
2286                    scalar_nll(
2287                        eta[row] + de,
2288                        g[row] + dg,
2289                        z[row],
2290                        y[row],
2291                        w[row],
2292                        probit_scale,
2293                    )
2294                };
2295                let f0 = f(0.0, 0.0);
2296                assert!(
2297                    (f0 - tower.v).abs() <= 1e-9 * f0.abs().max(1.0),
2298                    "row {row}: independent scalar NLL {f0:+.12e} != tower value {:+.12e}",
2299                    tower.v
2300                );
2301                // 5-point first-derivative stencils.
2302                let g_eta = (f(-2.0 * h, 0.0) - 8.0 * f(-h, 0.0) + 8.0 * f(h, 0.0)
2303                    - f(2.0 * h, 0.0))
2304                    / (12.0 * h);
2305                let g_g = (f(0.0, -2.0 * h) - 8.0 * f(0.0, -h) + 8.0 * f(0.0, h) - f(0.0, 2.0 * h))
2306                    / (12.0 * h);
2307                for (label, fd, ad) in [("∂η", g_eta, tower.g[0]), ("∂g", g_g, tower.g[1])] {
2308                    assert!(
2309                        (fd - ad).abs() <= 1e-5 * ad.abs().max(1.0),
2310                        "row {row} {label}: FD witness {fd:+.6e} != tower grad {ad:+.6e}"
2311                    );
2312                }
2313            }
2314        }
2315    }
2316
2317    /// #932 transcendental de-duplication: the combined
2318    /// [`rigid_standard_normal_third_and_fourth_full`] builder reads BOTH the
2319    /// third and fourth uncontracted tensors off ONE shared
2320    /// `rigid_standard_normal_tower` (one Mills-ratio transcendental per row),
2321    /// and must be BIT-IDENTICAL to the two separate single-tensor builders
2322    /// (`rigid_standard_normal_third_full` + `rigid_standard_normal_fourth_full`,
2323    /// two transcendentals). This pins the exactness of the redundancy
2324    /// elimination: `==`, max diff exactly 0.0 — same tower, no accuracy or
2325    /// generality change, only the redundant second transcendental removed.
2326    #[test]
2327    fn rigid_third_and_fourth_full_shares_one_tower_bit_identical() {
2328        let eta = [0.3_f64, -0.7, 0.05, 0.9, -1.2, 2.1, -2.4];
2329        let g = [0.2_f64, -0.5, 0.35, -0.15, 0.6, 0.45, -0.55];
2330        let z = [0.4_f64, -1.1, 0.0, 0.7, -0.3, 1.6, -1.4];
2331        let y = [1.0_f64, 0.0, 0.0, 1.0, 1.0, 0.0, 1.0];
2332        let w = [1.0_f64, 0.8, 1.3, 0.9, 1.1, 0.7, 1.4];
2333        for &probit_scale in &[1.0_f64, 0.8] {
2334            for r in 0..eta.len() {
2335                let marginal = bernoulli_marginal_link_map(
2336                    &InverseLink::Standard(gam_problem::StandardLink::Probit),
2337                    eta[r],
2338                )
2339                .expect("link map");
2340                let t3_sep = rigid_standard_normal_third_full(
2341                    marginal,
2342                    g[r],
2343                    z[r],
2344                    y[r],
2345                    w[r],
2346                    probit_scale,
2347                )
2348                .expect("separate third");
2349                let t4_sep = rigid_standard_normal_fourth_full(
2350                    marginal,
2351                    g[r],
2352                    z[r],
2353                    y[r],
2354                    w[r],
2355                    probit_scale,
2356                )
2357                .expect("separate fourth");
2358                let (t3_comb, t4_comb) = rigid_standard_normal_third_and_fourth_full(
2359                    marginal,
2360                    g[r],
2361                    z[r],
2362                    y[r],
2363                    w[r],
2364                    probit_scale,
2365                )
2366                .expect("combined third+fourth");
2367                // Exact bitwise equality (same tower) — no tolerance.
2368                for a in 0..2 {
2369                    for b in 0..2 {
2370                        for c in 0..2 {
2371                            assert_eq!(
2372                                t3_comb[a][b][c], t3_sep[a][b][c],
2373                                "t3[{a}][{b}][{c}] row {r} scale {probit_scale} not bit-identical"
2374                            );
2375                            for d in 0..2 {
2376                                assert_eq!(
2377                                    t4_comb[a][b][c][d], t4_sep[a][b][c][d],
2378                                    "t4[{a}][{b}][{c}][{d}] row {r} scale {probit_scale} not bit-identical"
2379                                );
2380                            }
2381                        }
2382                    }
2383                }
2384            }
2385        }
2386    }
2387
2388    /// #932 production wiring: the rigid Bernoulli row, routed through the
2389    /// generic [`RowNllProgramGeneric<2>`] program seam and its cheap
2390    /// order-2 / contracted scalar evaluators (`generic_row_kernel`,
2391    /// `generic_third_contracted`, `generic_fourth_contracted`,
2392    /// `generic_full_tower`), must agree BIT-FOR-BIT with the dense
2393    /// `Tower4`-only [`RowNllProgram`] path (`evaluate_program`). Both write the
2394    /// same single-expression NLL — the contracted scalars fold the direction
2395    /// into the differentiation, so this pins that the packed channels equal the
2396    /// corresponding contractions of the dense tower truth, exercising every
2397    /// `generic_*` evaluator end-to-end through a real production consumer.
2398    #[test]
2399    fn rigid_bernoulli_generic_program_matches_tower4_program_all_channels() {
2400        use gam_math::jet_tower::{
2401            generic_fourth_contracted, generic_full_tower, generic_row_kernel,
2402            generic_third_contracted,
2403        };
2404
2405        let eta = [0.3_f64, -0.7, 0.05, 0.9, -1.2, 2.1, -2.4];
2406        let g = [0.2_f64, -0.5, 0.35, -0.15, 0.6, 0.45, -0.55];
2407        let z = [0.4_f64, -1.1, 0.0, 0.7, -0.3, 1.6, -1.4];
2408        let y = [1.0_f64, 0.0, 0.0, 1.0, 1.0, 0.0, 1.0];
2409        let w = [1.0_f64, 0.8, 1.3, 0.9, 1.1, 0.7, 1.4];
2410        let n = eta.len();
2411        let dirs: [[f64; 2]; 3] = [[0.7, -1.3], [-0.4, 0.6], [1.2, 0.2]];
2412
2413        let close = |a: f64, b: f64, label: &str| {
2414            let band = 1e-12 + 1e-12 * a.abs().max(b.abs());
2415            assert!(
2416                (a - b).abs() <= band,
2417                "{label}: generic {a:+.15e} vs Tower4-program {b:+.15e} (band {band:.3e})"
2418            );
2419        };
2420
2421        for &probit_scale in &[1.0_f64, 0.8] {
2422            // The dense Tower4-only program over all rows (independent path).
2423            let tower_program = BernoulliRigidStandardNormalNllProgram {
2424                primaries: (0..n).map(|r| [eta[r], g[r]]).collect(),
2425                z: z.to_vec(),
2426                y: y.to_vec(),
2427                w: w.to_vec(),
2428                probit_scale,
2429            };
2430
2431            for row in 0..n {
2432                let truth = evaluate_program(&tower_program, row).expect("Tower4 program tower");
2433
2434                let marginal = bernoulli_marginal_link_map(
2435                    &InverseLink::Standard(gam_problem::StandardLink::Probit),
2436                    eta[row],
2437                )
2438                .expect("link map");
2439                let program = RigidStandardNormalRow {
2440                    marginal,
2441                    g: g[row],
2442                    z: z[row],
2443                    y: y[row],
2444                    w: w[row],
2445                    probit_scale,
2446                };
2447
2448                // generic_full_tower must reproduce the dense tower in EVERY
2449                // channel (v, g, H, t3, t4).
2450                let full = generic_full_tower(&program, 0).expect("generic full tower");
2451                close(full.v, truth.v, "full value");
2452                for a in 0..2 {
2453                    close(full.g[a], truth.g[a], "full grad");
2454                    for b in 0..2 {
2455                        close(full.h[a][b], truth.h[a][b], "full hess");
2456                        for c in 0..2 {
2457                            close(full.t3[a][b][c], truth.t3[a][b][c], "full t3");
2458                            for d in 0..2 {
2459                                close(full.t4[a][b][c][d], truth.t4[a][b][c][d], "full t4");
2460                            }
2461                        }
2462                    }
2463                }
2464
2465                // generic_row_kernel (Order2) must equal the tower's (v, g, H).
2466                let (val, grad, hess) =
2467                    generic_row_kernel(&program, 0).expect("generic row kernel");
2468                close(val, truth.v, "order2 value");
2469                for a in 0..2 {
2470                    close(grad[a], truth.g[a], "order2 grad");
2471                    for b in 0..2 {
2472                        close(hess[a][b], truth.h[a][b], "order2 hess");
2473                    }
2474                }
2475
2476                // generic_third_contracted (OneSeed) must equal the dense
2477                // tower's third contraction for each direction.
2478                for dir in &dirs {
2479                    let third = generic_third_contracted(&program, 0, dir)
2480                        .expect("generic third contracted");
2481                    let truth3 = truth.third_contracted(dir);
2482                    for a in 0..2 {
2483                        for b in 0..2 {
2484                            close(third[a][b], truth3[a][b], "third contracted");
2485                        }
2486                    }
2487                }
2488
2489                // generic_fourth_contracted (TwoSeed) must equal the dense
2490                // tower's fourth contraction for each direction pair.
2491                for (i, u) in dirs.iter().enumerate() {
2492                    let v = dirs[(i + 1) % dirs.len()];
2493                    let fourth = generic_fourth_contracted(&program, 0, u, &v)
2494                        .expect("generic fourth contracted");
2495                    let truth4 = truth.fourth_contracted(u, &v);
2496                    for a in 0..2 {
2497                        for b in 0..2 {
2498                            close(fourth[a][b], truth4[a][b], "fourth contracted");
2499                        }
2500                    }
2501                }
2502            }
2503        }
2504    }
2505}
2506
2507#[cfg(test)]
2508mod flex_primary_hessian_oracle_tests {
2509    //! #932 correctness gate for the BMS-FLEX per-row primary Hessian assembled
2510    //! by hand product-rule in
2511    //! [`super::super::row_primary_hessian::BernoulliMarginalSlopeFamily::compute_row_analytic_flex_from_parts_into`]
2512    //! (`f_aa += w·φ·(η_aa − η·η_a·η_a)`, the `f_au`/`f_uv`/`a_uv` chain, and the
2513    //! final `d2_m·η_u·η_v + d1_m·s_y·η_uv` contraction).
2514    //!
2515    //! A prior audit found this hand Hessian had NO INDEPENDENT oracle: the only
2516    //! covering test (`families_bms_joint_hessian_hvp_correction_tests.rs`)
2517    //! asserts batched-vs-nonbatched self-consistency using the SAME hand code on
2518    //! both sides, so a dropped product-rule term would pass undetected. This
2519    //! module closes that gap with a finite-difference witness that NEVER runs the
2520    //! Hessian-assembly branch: it central-differences the flex GRADIENT — which
2521    //! is produced by an entirely separate code path (the `need_hessian = false`
2522    //! value/`eta_u`-scaling lines, none of which read the `f_aa`/`f_au`/`f_uv`
2523    //! product-rule accumulators) — and pins the analytic Hessian against it.
2524    //!
2525    //! The gradient itself is FD-validated transitively: it is the analytic
2526    //! gradient of the same per-row NLL, evaluated at the converged intercept,
2527    //! and the FD perturbation re-solves the intercept root per perturbed point
2528    //! (rebuilding the row context), so the difference quotient is the true
2529    //! mixed/second partial of the row negative log-likelihood — the independent
2530    //! truth the hand Hessian must reproduce.
2531
2532    use super::*;
2533    // `BernoulliMarginalSlopeFamily` (and the flex block-config helpers) live in
2534    // the sibling `super::family` module and are `pub(super)`; this oracle test
2535    // module's `use super::*` does not re-export them, so import the family
2536    // namespace explicitly. Mirrors `cell_moment_assembly.rs`'s
2537    // `use super::family::*`. Without this the flex oracle fixture fails to
2538    // resolve the family type (E0422/E0425/E0433) and blocks the whole lib build.
2539    use super::family::*;
2540    use gam_linalg::matrix::DenseDesignMatrix;
2541    use ndarray::Array1;
2542    use ndarray::Array2;
2543    use std::sync::Arc;
2544    use std::sync::Mutex;
2545
2546    /// Port of the integration-test flex fixture
2547    /// (`make_flex_hvp_cache_test_family`), kept in-crate so the oracle can run
2548    /// without the test crate (the family struct is `pub(super)`). Builds a small
2549    /// flex BMS family with both a score-warp and a link-deviation block so the
2550    /// flex Hessian assembly exercises every primary block (q, logslope, h, w).
2551    fn make_flex_oracle_family(
2552        n: usize,
2553    ) -> (BernoulliMarginalSlopeFamily, Vec<ParameterBlockState>) {
2554        let score_seed = Array1::linspace(-2.0, 2.0, n.max(6));
2555        let link_seed = Array1::linspace(-1.8, 1.8, n.max(6));
2556        let cfg = DeviationBlockConfig {
2557            num_internal_knots: 3,
2558            ..DeviationBlockConfig::default()
2559        };
2560        let score_prepared = build_score_warp_deviation_block_from_seed(&score_seed, &cfg)
2561            .expect("build score warp block");
2562        let link_prepared = build_link_deviation_block_from_knots_design_seed_and_weights(
2563            &link_seed, &link_seed, &cfg,
2564        )
2565        .expect("build link deviation block");
2566
2567        let y: Array1<f64> =
2568            Array1::from_iter((0..n).map(|i| if (i * 17 + 3) % 7 >= 4 { 1.0 } else { 0.0 }));
2569        let weights: Array1<f64> =
2570            Array1::from_iter((0..n).map(|i| 0.75 + ((i * 11 + 5) % 5) as f64 * 0.05));
2571        let z: Array1<f64> =
2572            Array1::from_iter((0..n).map(|i| -1.7 + 3.4 * (i as f64 + 0.5) / n as f64));
2573        let marginal_x = Array2::from_shape_fn((n, 2), |(i, j)| {
2574            if j == 0 {
2575                1.0
2576            } else {
2577                -0.4 + 0.8 * ((i * 19 + 7) % n) as f64 / n as f64
2578            }
2579        });
2580        let logslope_x = Array2::from_shape_fn((n, 2), |(i, j)| {
2581            if j == 0 {
2582                1.0
2583            } else {
2584                0.3 - 0.6 * ((i * 23 + 11) % n) as f64 / n as f64
2585            }
2586        });
2587
2588        let family = BernoulliMarginalSlopeFamily {
2589            y: Arc::new(y),
2590            weights: Arc::new(weights),
2591            z: Arc::new(z.clone()),
2592            latent_measure: LatentMeasureKind::StandardNormal,
2593            gaussian_frailty_sd: Some(0.15),
2594            base_link: InverseLink::Standard(gam_problem::StandardLink::Probit),
2595            marginal_design: DesignMatrix::Dense(DenseDesignMatrix::from(marginal_x.clone())),
2596            logslope_design: DesignMatrix::Dense(DenseDesignMatrix::from(logslope_x.clone())),
2597            score_warp: Some(score_prepared.runtime.clone()),
2598            link_dev: Some(link_prepared.runtime.clone()),
2599            policy: gam_runtime::resource::ResourcePolicy::default_library(),
2600            cell_moment_lru: Arc::new(exact_kernel::CellMomentLruCache::new(1024)),
2601            cell_moment_cache_stats: Arc::new(exact_kernel::CellMomentCacheStats::default()),
2602            intercept_warm_starts: None,
2603            auto_subsample_phase_counter: Arc::new(std::sync::atomic::AtomicUsize::new(0)),
2604            auto_subsample_last_rho: Arc::new(Mutex::new(None)),
2605        };
2606
2607        let beta_m = Array1::from_vec(vec![0.12, -0.04]);
2608        let beta_g = Array1::from_vec(vec![0.35, 0.03]);
2609        let beta_h = Array1::from_iter(
2610            (0..score_prepared.runtime.basis_dim()).map(|idx| 0.0015 * (idx as f64 + 1.0)),
2611        );
2612        let beta_w = Array1::from_iter(
2613            (0..link_prepared.runtime.basis_dim()).map(|idx| -0.001 * (idx as f64 + 1.0)),
2614        );
2615        let states = vec![
2616            ParameterBlockState {
2617                eta: marginal_x.dot(&beta_m),
2618                beta: beta_m,
2619            },
2620            ParameterBlockState {
2621                eta: logslope_x.dot(&beta_g),
2622                beta: beta_g,
2623            },
2624            ParameterBlockState {
2625                beta: beta_h,
2626                eta: Array1::zeros(z.len()),
2627            },
2628            ParameterBlockState {
2629                beta: beta_w,
2630                eta: Array1::zeros(z.len()),
2631            },
2632        ];
2633        (family, states)
2634    }
2635
2636    /// The flex primary gradient at a perturbed primary point. Perturbs primary
2637    /// coordinate `u` by `delta` (mutating the relevant block state — the
2638    /// marginal/logslope row η or a deviation β plus its design contribution
2639    /// where applicable), rebuilds the row context FRESH (re-solving the
2640    /// calibration intercept root at the perturbed point), and returns the
2641    /// analytic gradient. The Hessian-assembly branch is never run, so this is a
2642    /// genuinely independent witness for that branch.
2643    fn flex_gradient_at_perturbed(
2644        family: &BernoulliMarginalSlopeFamily,
2645        states: &[ParameterBlockState],
2646        primary: &super::super::hessian_paths::PrimarySlices,
2647        row: usize,
2648        u: usize,
2649        delta: f64,
2650    ) -> Array1<f64> {
2651        let mut states = states.to_vec();
2652        // Map the primary coordinate `u` onto the parameter that controls it.
2653        // q / logslope live in the per-row η of blocks 0 / 1; the deviation
2654        // bases live in the β of blocks 2 (score-warp) / 3 (link-wiggle), which
2655        // the row context reads via `score_beta` / `link_beta` (their η rows are
2656        // unused on the flex per-row path, so only β need move).
2657        if u == primary.q {
2658            states[0].eta[row] += delta;
2659        } else if u == primary.logslope {
2660            states[1].eta[row] += delta;
2661        } else if let Some(h_range) = primary.h.as_ref()
2662            && h_range.contains(&u)
2663        {
2664            states[2].beta[u - h_range.start] += delta;
2665        } else if let Some(w_range) = primary.w.as_ref()
2666            && w_range.contains(&u)
2667        {
2668            states[3].beta[u - w_range.start] += delta;
2669        } else {
2670            panic!("primary coordinate {u} out of range for flex oracle");
2671        }
2672        let row_ctx = family
2673            .build_row_exact_context_with_stats_and_cell_cache(row, &states, None, false)
2674            .expect("perturbed row context");
2675        let (_neglog, grad, _hess) = family
2676            .compute_row_primary_gradient_hessian(row, &states, primary, &row_ctx)
2677            .expect("perturbed gradient");
2678        grad
2679    }
2680
2681    /// The hand-assembled BMS-FLEX per-row primary Hessian must equal the
2682    /// central finite difference of the flex gradient at every fixture row.
2683    #[test]
2684    fn flex_primary_hessian_matches_central_fd_of_gradient() {
2685        let n = 12usize;
2686        let (family, states) = make_flex_oracle_family(n);
2687        let cache = family
2688            .build_exact_eval_cache(&states)
2689            .expect("flex exact eval cache");
2690        let primary = &cache.primary;
2691        let r = primary.total;
2692        assert!(
2693            r >= 4,
2694            "flex fixture must carry q + logslope + deviation blocks"
2695        );
2696
2697        // Central-difference step. The flex gradient is smooth in every primary
2698        // coordinate; 1e-4 balances truncation (O(h^2)) against the cancellation
2699        // floor of the per-perturbation intercept re-solve (~1e-12).
2700        let h = 1e-4;
2701        let mut max_rel = 0.0_f64;
2702
2703        // A handful of interior rows (avoid the strongest-tail endpoints where
2704        // the FD floor is loosest). Every primary coordinate is differenced.
2705        for &row in &[2usize, 5, 8] {
2706            let row_ctx = BernoulliMarginalSlopeFamily::row_ctx(&cache, row);
2707            let (_neglog, _grad, analytic_hess) = family
2708                .compute_row_primary_gradient_hessian(row, &states, primary, row_ctx)
2709                .expect("analytic flex gradient + hessian");
2710
2711            for u in 0..r {
2712                let grad_plus = flex_gradient_at_perturbed(&family, &states, primary, row, u, h);
2713                let grad_minus = flex_gradient_at_perturbed(&family, &states, primary, row, u, -h);
2714                for v in 0..r {
2715                    let fd = (grad_plus[v] - grad_minus[v]) / (2.0 * h);
2716                    let analytic = analytic_hess[[v, u]];
2717                    let denom = 1.0 + analytic.abs().max(fd.abs());
2718                    let rel = (analytic - fd).abs() / denom;
2719                    max_rel = max_rel.max(rel);
2720                    assert!(
2721                        rel <= 1e-6,
2722                        "flex hand Hessian H[{v}][{u}] = {analytic:.6e} disagrees with central \
2723                         FD of the gradient {fd:.6e} at row {row} (rel {rel:.3e}); a product-rule \
2724                         term is dropped or mis-signed"
2725                    );
2726                }
2727            }
2728        }
2729        // Surface the achieved tightness for the record.
2730        assert!(
2731            max_rel <= 1e-6,
2732            "flex Hessian FD oracle max rel {max_rel:.3e}"
2733        );
2734    }
2735
2736    /// ARBITER (diagnostic): is the H[0][0] flex-Hessian vs FD-of-gradient gap a
2737    /// REAL hand-derivation bug or just FD-truncation / intercept-re-solve noise
2738    /// in the witness? Sweep the central-difference step `h` on the worst entry
2739    /// (row 2, [q][q]); if the gap scales ~h^2 it is FD truncation (the analytic
2740    /// Hessian is right, the witness bound is just too tight); if it stays flat
2741    /// as h shrinks it is a genuine dropped/mis-signed term. Richardson-cancel
2742    /// the O(h^2) term and report the residual. Panics with the table so the
2743    /// harness surfaces the numbers (stdout is otherwise suppressed).
2744    #[test]
2745    fn arbiter_flex_hessian_h00_fd_step_scaling() {
2746        let n = 12usize;
2747        let (family, states) = make_flex_oracle_family(n);
2748        let cache = family
2749            .build_exact_eval_cache(&states)
2750            .expect("flex exact eval cache");
2751        let primary = &cache.primary;
2752        let row = 2usize;
2753        let u = primary.q; // intercept / q axis => H[0][0]
2754        let v = primary.q;
2755
2756        let row_ctx = BernoulliMarginalSlopeFamily::row_ctx(&cache, row);
2757        let (_neglog, _grad, analytic_hess) = family
2758            .compute_row_primary_gradient_hessian(row, &states, primary, row_ctx)
2759            .expect("analytic flex gradient + hessian");
2760        let analytic = analytic_hess[[v, u]];
2761
2762        let fd_at = |h: f64| -> f64 {
2763            let gp = flex_gradient_at_perturbed(&family, &states, primary, row, u, h);
2764            let gm = flex_gradient_at_perturbed(&family, &states, primary, row, u, -h);
2765            (gp[v] - gm[v]) / (2.0 * h)
2766        };
2767
2768        // Coarse and fine central-difference steps. If the analytic Hessian is
2769        // CORRECT and the witness gap is pure O(h^2) FD truncation, halving h
2770        // quarters the gap; the Richardson combination cancels that O(h^2) term
2771        // and lands on the analytic value to the intercept-re-solve floor
2772        // (~1e-9). If instead a hand product-rule term is dropped, the gap is
2773        // h-INDEPENDENT and the Richardson residual stays at the bug magnitude.
2774        let h = 1e-3_f64;
2775        let fd_h = fd_at(h);
2776        let fd_half = fd_at(h * 0.5);
2777        let fd_quarter = fd_at(h * 0.25);
2778        let gap_h = (analytic - fd_h).abs();
2779        let gap_half = (analytic - fd_half).abs();
2780        let gap_quarter = (analytic - fd_quarter).abs();
2781        let rich = (4.0 * fd_half - fd_h) / 3.0;
2782        let rich_gap = (analytic - rich).abs();
2783        let denom = analytic.abs().max(1.0);
2784
2785        // DIAGNOSTIC RECORD (shown on failure; this is the dispositive table):
2786        let record = format!(
2787            "FLEX H[0][0] ARBITER row 2: analytic={analytic:+.12e} \
2788             fd(h)={fd_h:+.12e} fd(h/2)={fd_half:+.12e} fd(h/4)={fd_quarter:+.12e} \
2789             gap(h)={gap_h:.3e} gap(h/2)={gap_half:.3e} gap(h/4)={gap_quarter:.3e} \
2790             ratio_h_over_half={:.3} ratio_half_over_quarter={:.3} \
2791             richardson={rich:+.12e} richardson_gap={rich_gap:.3e} (rich_rel={:.3e})",
2792            gap_h / gap_half.max(f64::MIN_POSITIVE),
2793            gap_half / gap_quarter.max(f64::MIN_POSITIVE),
2794            rich_gap / denom,
2795        );
2796
2797        // VERDICT: the analytic Hessian is correct iff the FD gap is O(h^2) — i.e.
2798        // the Richardson-extrapolated second derivative (truncation-cancelled)
2799        // matches it to the intercept-solve floor. A genuine dropped term leaves
2800        // a Richardson residual at the bug scale (~1e-5), failing this with the
2801        // record above so the harness surfaces the numbers.
2802        assert!(
2803            rich_gap / denom <= 1e-7,
2804            "{record}\nVERDICT: Richardson residual exceeds the FD-truncation floor — \
2805             the hand H[0][0] genuinely diverges (real dropped/mis-signed term), NOT FD noise"
2806        );
2807    }
2808}