Skip to main content

gam_models/bms/
mod.rs

1use crate::custom_family::{
2    BatchedOuterGradientTerms, BlockEffectiveJacobian, BlockWorkingSet, BlockwiseFitOptions,
3    CustomFamily, CustomFamilyWarmStart, ExactNewtonJointGradientEvaluation,
4    ExactNewtonJointHessianWorkspace, ExactNewtonJointPsiSecondOrderTerms,
5    ExactNewtonJointPsiTerms, ExactNewtonJointPsiWorkspace, FamilyEvaluation,
6    FamilyLinearizationState, ParameterBlockSpec, ParameterBlockState, PenaltyMatrix,
7    custom_family_outer_derivatives, evaluate_custom_family_joint_hyper_efs_shared,
8    evaluate_custom_family_joint_hyper_shared, fit_custom_family,
9    joint_hyper_options_for_outer_tolerance,
10};
11use gam_solve::estimate::reml::reml_outer_engine::{DenseSpectralOperator, HessianOperator};
12use crate::cubic_cell_kernel as exact_kernel;
13use crate::marginal_slope_shared::{
14    CoeffSupport, DirectionalScaleJets, ObservedDenestedCellPartials, SparsePrimaryCoeffJetView,
15    add_optional_matrix, add_optional_vector, add_two_surface_psi_outer,
16    build_denested_partition_cells as shared_denested_partition_cells, chunked_row_reduction,
17    directional_obj_grad_hess, eval_coeff4_at, is_sigma_aux_index as shared_is_sigma_aux_index,
18    observed_denested_cell_partials as shared_observed_denested_cell_partials, outer_row_indices,
19    outer_weighted_rows, parameter_block_specs_match_rows, probit_frailty_scale,
20    probit_frailty_scale_multi_dir_jet, psi_derivative_location, scale_coeff4,
21};
22use crate::parameter_block::ParameterBlockInput;
23use crate::row_kernel::{
24    RowKernel, RowKernelHessianWorkspace, build_row_kernel_cache, row_kernel_gradient,
25    row_kernel_hessian_dense, row_kernel_log_likelihood,
26};
27use crate::spatial_psi_bridge::build_block_spatial_psi_derivatives;
28use crate::survival::lognormal_kernel::FrailtySpec;
29use crate::wiggle::initializewiggle_knots_from_seed;
30use gam_linalg::matrix::{DesignMatrix, SymmetricMatrix};
31use crate::model_types::UnifiedFitResult;
32use crate::outer_subsample::WeightedOuterRow;
33use gam_solve::pirls::LinearInequalityConstraints;
34use crate::probability::{
35    normal_cdf, normal_logcdf, normal_pdf, signed_probit_logcdf_and_mills_ratio,
36    standard_normal_quantile,
37};
38use gam_terms::smooth::{
39    SpatialLengthScaleOptimizationOptions, SpatialLogKappaCoords, TermCollectionDesign,
40    TermCollectionSpec,
41};
42use crate::fit_orchestration::drivers::{
43    ExactJointHyperSetup, apply_spatial_anisotropy_pilot_initializer,
44    build_term_collection_designs_and_freeze_joint, optimize_spatial_length_scale_exact_joint,
45    spatial_length_scale_term_indices,
46};
47use gam_problem::{InverseLink, StandardLink, WigglePenaltyConfig};
48use gam_math::jet_partitions::MultiDirJet;
49use gam_problem::HyperOperator;
50use ndarray::{Array1, Array2, ArrayView1, ArrayView2, ArrayViewMut1, s};
51use rayon::iter::{IntoParallelIterator, IntoParallelRefIterator, ParallelIterator};
52use serde::{Deserialize, Serialize};
53use std::cell::RefCell;
54use std::collections::HashMap;
55use std::sync::atomic::{AtomicU64, AtomicUsize, Ordering};
56use std::sync::{Arc, Mutex, OnceLock};
57
58pub mod deviation_runtime;
59pub mod gpu;
60pub use deviation_runtime::DeviationRuntime;
61pub use deviation_runtime::ParametricAnchorBlock;
62
/// Row-count cutoff for the FLEX spatial length-scale optimizer.
///
/// Fits with more rows than this use the pilot geometry initializer and skip
/// the iterative joint κ/ψ outer loop. The cutoff is a spatial-optimizer
/// policy and nothing else: it must never be used to gate exact outer Hessian
/// capability or row-cell moment materialization.
pub(crate) const BMS_FLEX_SPATIAL_OUTER_PILOT_ROW_THRESHOLD: usize = 50_000;
68
/// Configuration of one deviation (score-warp / link-deviation) block.
///
/// Every field except `penalty_order` is copied verbatim from a
/// `WigglePenaltyConfig` by the `From` conversion in this module.
#[derive(Debug, Clone)]
pub struct DeviationBlockConfig {
    /// Degree of the deviation basis (presumably the spline degree — confirm
    /// against the wiggle basis builder).
    pub degree: usize,
    /// Number of internal knots requested for the basis.
    pub num_internal_knots: usize,
    /// Single summary penalty order; the `From<WigglePenaltyConfig>`
    /// conversion sets it to the maximum of `penalty_orders` (2 when that
    /// list is empty).
    pub penalty_order: usize,
    /// Full list of penalty orders carried over from the wiggle config.
    pub penalty_orders: Vec<usize>,
    /// Double-penalty switch carried over from the wiggle config.
    pub double_penalty: bool,
    /// Monotonicity tolerance carried over from the wiggle config
    /// (NOTE(review): exact use is defined by the consumer, not visible here).
    pub monotonicity_eps: f64,
}
78
79impl Default for DeviationBlockConfig {
80    fn default() -> Self {
81        WigglePenaltyConfig::cubic_triple_operator_default().into()
82    }
83}
84
85impl DeviationBlockConfig {
86    pub fn triple_penalty_default() -> Self {
87        Self::default()
88    }
89}
90
91impl From<WigglePenaltyConfig> for DeviationBlockConfig {
92    fn from(cfg: WigglePenaltyConfig) -> Self {
93        let penalty_order = *cfg.penalty_orders.iter().max().unwrap_or(&2);
94        Self {
95            degree: cfg.degree,
96            num_internal_knots: cfg.num_internal_knots,
97            penalty_order,
98            penalty_orders: cfg.penalty_orders,
99            double_penalty: cfg.double_penalty,
100            monotonicity_eps: cfg.monotonicity_eps,
101        }
102    }
103}
104
105#[derive(Clone)]
106pub(crate) struct DeviationPrepared {
107    pub(crate) block: ParameterBlockInput,
108    pub(crate) runtime: DeviationRuntime,
109}
110
111impl std::fmt::Debug for DeviationPrepared {
112    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
113        f.debug_struct("DeviationPrepared").finish_non_exhaustive()
114    }
115}
116
117#[derive(Clone)]
118pub struct BernoulliMarginalSlopeTermSpec {
119    pub y: Array1<f64>,
120    pub weights: Array1<f64>,
121    pub z: Array1<f64>,
122    pub base_link: InverseLink,
123    pub marginalspec: TermCollectionSpec,
124    pub logslopespec: TermCollectionSpec,
125    pub marginal_offset: Array1<f64>,
126    pub logslope_offset: Array1<f64>,
127    /// GaussianShift frailty on the final probit index: U ~ N(0, σ²) added
128    /// to the scalar argument of Φ.  This is exact because the sextic
129    /// microcell kernel is preserved — the Gaussian-decoupling identity
130    /// E[Φ(η + U)] = Φ(η / √(1+σ²)) rescales the index by 1/τ where
131    /// τ = √(1+σ²), and every derivative chain rule factor is polynomial
132    /// in τ, so all six kernel derivatives remain closed-form.
133    ///
134    /// **HazardMultiplier frailty is NOT supported in this family.**
135    /// HazardMultiplier frailty + score_warp/linkwiggle cubic marginal-slope
136    /// is not finite-state exact.  For hazard-multiplier frailty, use the
137    /// standalone LatentCloglogBinomial / LatentSurvival families instead.
138    pub frailty: FrailtySpec,
139    pub score_warp: Option<DeviationBlockConfig>,
140    pub link_dev: Option<DeviationBlockConfig>,
141    pub latent_z_policy: LatentZPolicy,
142    /// Out-of-fold Stage-1 score-influence Jacobian `J = ∂z/∂θ₁` (n × p₁)
143    /// from cross-fitting a CTN transformation-normal Stage-1 model (#461).
144    /// When `Some`, the realized leakage directions `Z_infl = diag(s_f·β̂₀)·J`
145    /// are absorbed as a null-penalized block so the joint solve makes the
146    /// β estimating equation orthogonal to `span(Z_infl)` — the x-dependent
147    /// realization of `ψ − Π_η[ψ]`. `None` ⇒ raw `--z-column` with no CTN
148    /// Stage-1, in which case the free 1-D `score_warp` spline is the
149    /// fallback basis (it spans only the x-free leakage column).
150    pub score_influence_jacobian: Option<Array2<f64>>,
151}
152
153pub struct BernoulliMarginalSlopeFitResult {
154    pub fit: UnifiedFitResult,
155    pub marginalspec_resolved: TermCollectionSpec,
156    pub logslopespec_resolved: TermCollectionSpec,
157    pub marginal_design: TermCollectionDesign,
158    pub logslope_design: TermCollectionDesign,
159    pub baseline_marginal: f64,
160    pub baseline_logslope: f64,
161    pub z_normalization: LatentZNormalization,
162    pub latent_measure: LatentMeasureKind,
163    pub score_warp_runtime: Option<DeviationRuntime>,
164    pub link_dev_runtime: Option<DeviationRuntime>,
165    /// Learned or fixed Gaussian-shift frailty SD.  `None` = no frailty.
166    pub gaussian_frailty_sd: Option<f64>,
167    /// Structured warnings emitted during fit-time setup when a flex
168    /// block was fully aliased by its anchor union and got dropped. The
169    /// fit proceeds without the dropped block (its contribution to the
170    /// joint design was numerically reproducible by the anchor span, so
171    /// keeping it would leave the joint Hessian rank-deficient). Empty
172    /// for fits where every flex block carried independent directions.
173    pub cross_block_warnings: Vec<CrossBlockIdentifiabilityWarning>,
174    /// Optional weighted rank inverse-normal (Blom rankit) calibration
175    /// installed at fit time when the auto latent-z normality check
176    /// failed. `Some(_)` ⇒ the training z was transformed in place via
177    /// [`LatentZRankIntCalibration::apply_to_training`] before any
178    /// downstream consumer (pooled probit baseline, term-collection
179    /// designs, family PIRLS loops) saw it, and the rigid kernel
180    /// routes through the standard-normal closed-form path on the
181    /// calibrated scale. `None` ⇒ no calibration was applied (training
182    /// z already passed the standard-normal diagnostics, or the caller
183    /// explicitly selected a non-Auto `LatentMeasureSpec`).
184    ///
185    /// Persisted to disk so prediction applies the same monotone map
186    /// via [`LatentZRankIntCalibration::apply_at_predict`] to incoming
187    /// z before the standard-normal kernel runs. The public field name
188    /// is `latent_z_rank_int_calibration` — Agent D's persistence
189    /// pipeline reads it under that exact identifier.
190    pub latent_z_rank_int_calibration: Option<LatentZRankIntCalibration>,
191    /// Optional conditional location-scale calibration of the latent score
192    /// (#905). `Some(_)` ⇒ the Auto path's conditional `E[z|C]`/`Var(z|C)` Rao
193    /// gate detected PC/grouping-dependence that the pooled-marginal gate
194    /// cannot see, so the training z was replaced in place by
195    /// `ζ = (z − m(C))/√v(C)` (via [`LatentZConditionalCalibration::apply`])
196    /// before any downstream consumer saw it. Mutually exclusive with
197    /// `latent_z_rank_int_calibration`: rank-INT fixes a pooled-marginal
198    /// defect, the conditional correction fixes a conditional-shift defect that
199    /// rank-INT provably cannot. Persisted so prediction rebuilds `a(C)` from
200    /// the (reproducible) marginal design and applies the identical map.
201    pub latent_z_conditional_calibration: Option<LatentZConditionalCalibration>,
202}
203
/// How violations of the latent-z distribution checks are handled.
#[derive(Debug, Clone)]
pub enum LatentZCheckMode {
    /// Strict enforcement of the checks.
    Strict,
    /// Emit a warning so the deviation is visible, without aborting the fit.
    WarnOnly,
    /// Presumably disables the checks entirely — confirm against the checker.
    Off,
}
210
/// Selects how the latent score is normalized before fitting.
#[derive(Debug, Clone)]
pub enum LatentZNormalizationMode {
    /// No normalization requested.
    None,
    /// Normalization derived at fit time (presumably from the weighted
    /// training sample — confirm against the normalization routine).
    FitWeighted,
    /// Use the supplied, fixed location and scale.
    Frozen { mean: f64, sd: f64 },
}
217
/// Grid size used by [`LatentMeasureSpec::auto_default`].
pub const DEFAULT_EMPIRICAL_LATENT_GRID_SIZE: usize = 65;
/// Skewness tolerance of the auto latent-z normality gate (presumably a bound
/// on |skew| — confirm against the gate implementation).
pub(crate) const AUTO_Z_NORMAL_SKEW_TOL: f64 = 0.10;
/// Kurtosis tolerance of the auto latent-z normality gate (presumably a bound
/// on |excess kurtosis| — confirm against the gate implementation).
pub(crate) const AUTO_Z_NORMAL_KURT_TOL: f64 = 0.25;
/// Kolmogorov–Smirnov-style tolerance of the auto normality gate (name-based
/// reading — confirm against the gate implementation).
pub(crate) const AUTO_Z_NORMAL_KS_TOL: f64 = 0.025;
/// Hard bound on `max |z|` in the auto normality gate.
pub(crate) const AUTO_Z_NORMAL_MAX_ABS: f64 = 8.0;
/// Inner σ level of the tail-mass test in the auto normality gate: the
/// empirical tail mass of latent z beyond this level is compared with the
/// standard normal's theoretical two-sided tail. It sits well inside
/// `AUTO_Z_NORMAL_MAX_ABS`, so a fat inner tail is detected before any single
/// observation trips the hard `max |z|` bound.
pub(crate) const AUTO_Z_NORMAL_TAIL_SIGMA_INNER: f64 = 4.0;
/// Outer σ level of the same tail-mass test; it catches heavier far-tail
/// excess that the inner level can miss.
pub(crate) const AUTO_Z_NORMAL_TAIL_SIGMA_OUTER: f64 = 6.0;
/// Slack factor on the Gaussian tail mass: the empirical tail may be up to
/// this many times the theoretical tail at the same σ before the gate fails,
/// which leaves room for finite-sample noise.
pub(crate) const AUTO_Z_NORMAL_TAIL_MASS_SLACK: f64 = 2.0;
/// Absolute additive floor for the inner-σ tail comparison, so that round-off
/// cannot fail the gate when the Gaussian tail itself is already tiny.
pub(crate) const AUTO_Z_NORMAL_TAIL_FLOOR_INNER: f64 = 1e-5;
/// Absolute additive floor for the outer-σ tail comparison. It is smaller
/// than the inner floor because the 6σ Gaussian tail is many orders of
/// magnitude below the 4σ tail.
pub(crate) const AUTO_Z_NORMAL_TAIL_FLOOR_OUTER: f64 = 1e-8;
/// Significance level of the conditional `E[z|C]` / `Var(z|C)` Rao gate in
/// the core Auto path (#905). When the conditional mean or variance of the
/// latent score on the marginal-index span `a(C)` is significant at this
/// level, the Auto path escalates from the pooled-marginal rank-INT to a
/// conditional location-scale correction. The level is deliberately small
/// (0.1%) so that escalation happens only on clear conditional structure and
/// not on finite-sample noise; because the gate runs once over the entire
/// training sample, even a per-test α this tight retains ample power against
/// the grouping mean-shift named in the issue.
pub(crate) const AUTO_Z_CONDITIONAL_RAO_ALPHA: f64 = 1e-3;
/// Relative ridge added to the weighted normal equations when the latent
/// score is regressed on the marginal-index span for the conditional
/// correction. It stabilizes the solve when `a(C)` is rank-deficient or
/// collinear (routine for penalized spline marginal indices) without
/// materially biasing the fitted conditional mean/variance.
pub(crate) const AUTO_Z_CONDITIONAL_RIDGE_REL: f64 = 1e-8;
/// Lower bound on the fitted conditional variance `v(C)`, expressed as a
/// fraction of the global weighted variance of the latent score. It keeps
/// `ζ = (z−m)/√v` finite and well-scaled wherever the linear variance model
/// would otherwise produce a non-positive or vanishing conditional variance.
pub(crate) const AUTO_Z_CONDITIONAL_VAR_FLOOR_FRAC: f64 = 1e-3;
261
/// Caller-facing request for the latent measure to use for `z`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum LatentMeasureSpec {
    /// Let the fit decide; `grid_size` is the empirical grid size to use
    /// should an empirical grid be needed (NOTE(review): decision logic lives
    /// outside this chunk).
    Auto { grid_size: usize },
    /// Treat the latent score as standard normal.
    StandardNormal,
    /// Use a single global empirical grid with `grid_size` nodes.
    GlobalEmpirical { grid_size: usize },
}
268
269impl LatentMeasureSpec {
270    pub fn auto_default() -> Self {
271        Self::Auto {
272            grid_size: DEFAULT_EMPIRICAL_LATENT_GRID_SIZE,
273        }
274    }
275}
276
277impl Default for LatentMeasureSpec {
278    fn default() -> Self {
279        Self::auto_default()
280    }
281}
282
283#[derive(Clone, Debug, PartialEq, Serialize, Deserialize)]
284pub struct EmpiricalZGrid {
285    pub nodes: Vec<f64>,
286    pub weights: Vec<f64>,
287}
288
289impl EmpiricalZGrid {
290    /// Construct a grid whose node/weight invariants (equal length ≥ 2, finite
291    /// nodes, finite positive weights, weights summing to 1 within 1e-8) are
292    /// enforced up-front. Prefer this over building the struct literally;
293    /// every code path that goes through `new` is guaranteed to satisfy the
294    /// same contract that `validate_empirical_z_grid` checks on read.
295    pub fn new(nodes: Vec<f64>, weights: Vec<f64>, context: &str) -> Result<Self, String> {
296        validate_empirical_z_grid(&nodes, &weights, context)?;
297        Ok(Self { nodes, weights })
298    }
299
300    /// Iterate over co-indexed `(node, weight)` pairs. Use this instead of
301    /// reading `.nodes`/`.weights` separately whenever a loop wants both
302    /// arrays in lockstep — eliminates the chance of mismatched indexing.
303    #[inline]
304    pub fn pairs(&self) -> impl Iterator<Item = (f64, f64)> + '_ {
305        self.nodes.iter().copied().zip(self.weights.iter().copied())
306    }
307}
308
309#[derive(Clone, Debug, PartialEq, Serialize, Deserialize)]
310#[serde(tag = "kind", rename_all = "kebab-case")]
311#[derive(Default)]
312pub enum LatentMeasureKind {
313    #[default]
314    StandardNormal,
315    GlobalEmpirical {
316        grid: EmpiricalZGrid,
317    },
318    LocalEmpirical {
319        feature_cols: Vec<usize>,
320        #[serde(default)]
321        input_scales: Option<Vec<f64>>,
322        centers: Vec<Vec<f64>>,
323        grids: Vec<EmpiricalZGrid>,
324        top_k: usize,
325        bandwidth: f64,
326        #[serde(skip)]
327        train_row_mixtures: Arc<Vec<Vec<(usize, f64)>>>,
328    },
329}
330
331impl LatentMeasureKind {
332    pub fn validate(&self, context: &str) -> Result<(), String> {
333        match self {
334            Self::StandardNormal => Ok(()),
335            Self::GlobalEmpirical { grid } => {
336                validate_empirical_z_grid(&grid.nodes, &grid.weights, context)
337            }
338            Self::LocalEmpirical {
339                feature_cols,
340                input_scales,
341                centers,
342                grids,
343                top_k,
344                bandwidth,
345                ..
346            } => {
347                if feature_cols.is_empty() {
348                    return Err(format!(
349                        "{context} local empirical latent measure needs feature columns"
350                    ));
351                }
352                if centers.is_empty() {
353                    return Err(format!(
354                        "{context} local empirical latent measure needs centers"
355                    ));
356                }
357                if centers.len() != grids.len() {
358                    return Err(format!(
359                        "{context} local empirical latent measure center/grid length mismatch: centers={}, grids={}",
360                        centers.len(),
361                        grids.len()
362                    ));
363                }
364                if *top_k == 0 || *top_k > centers.len() {
365                    return Err(format!(
366                        "{context} local empirical latent measure top_k must be in 1..={}, got {top_k}",
367                        centers.len()
368                    ));
369                }
370                if !(*bandwidth).is_finite() || *bandwidth <= 0.0 {
371                    return Err(format!(
372                        "{context} local empirical latent measure bandwidth must be finite and positive, got {bandwidth}"
373                    ));
374                }
375                if let Some(scales) = input_scales.as_ref() {
376                    if scales.len() != feature_cols.len() {
377                        return Err(format!(
378                            "{context} local empirical latent measure input scale dimension mismatch: scales={}, features={}",
379                            scales.len(),
380                            feature_cols.len()
381                        ));
382                    }
383                    for (scale_idx, scale) in scales.iter().enumerate() {
384                        if !(scale.is_finite() && *scale > 0.0) {
385                            return Err(format!(
386                                "{context} local empirical latent measure input scale {scale_idx} must be finite and positive, got {scale}"
387                            ));
388                        }
389                    }
390                }
391                for (center_idx, center) in centers.iter().enumerate() {
392                    if center.len() != feature_cols.len() {
393                        return Err(format!(
394                            "{context} local empirical latent center {center_idx} dimension mismatch: got {}, expected {}",
395                            center.len(),
396                            feature_cols.len()
397                        ));
398                    }
399                    if center.iter().any(|value| !value.is_finite()) {
400                        return Err(format!(
401                            "{context} local empirical latent center {center_idx} has non-finite coordinates"
402                        ));
403                    }
404                }
405                for (grid_idx, grid) in grids.iter().enumerate() {
406                    validate_empirical_z_grid(
407                        &grid.nodes,
408                        &grid.weights,
409                        &format!("{context} local empirical grid {grid_idx}"),
410                    )?;
411                }
412                Ok(())
413            }
414        }
415    }
416
417    pub(crate) fn is_empirical(&self) -> bool {
418        matches!(
419            self,
420            Self::GlobalEmpirical { .. } | Self::LocalEmpirical { .. }
421        )
422    }
423
424    /// Per-row empirical latent grid, borrowed where possible. This sits in
425    /// the innermost per-row loops of every criterion/gradient/Hessian
426    /// evaluation, so the global grid MUST come back as a borrow — the old
427    /// `grid.clone()` here allocated two `grid_size`-length vectors per row
428    /// per evaluation across the whole fit. Only the local-mixture path,
429    /// which genuinely synthesizes a new grid per row, returns an owned
430    /// value.
431    pub(crate) fn empirical_grid_for_training_row(
432        &self,
433        row: usize,
434    ) -> Result<Option<std::borrow::Cow<'_, EmpiricalZGrid>>, String> {
435        match self {
436            Self::StandardNormal => Ok(None),
437            Self::GlobalEmpirical { grid } => Ok(Some(std::borrow::Cow::Borrowed(grid))),
438            Self::LocalEmpirical {
439                grids,
440                train_row_mixtures,
441                ..
442            } => {
443                let mixture = train_row_mixtures.get(row).ok_or_else(|| {
444                    format!(
445                        "local empirical latent measure is missing training mixture for row {row}"
446                    )
447                })?;
448                Ok(Some(std::borrow::Cow::Owned(combine_empirical_grids(
449                    grids, mixture,
450                )?)))
451            }
452        }
453    }
454}
455
/// Validates the contract of an empirical latent grid.
///
/// Checks, in order: `nodes` and `weights` have equal length; there are at
/// least two nodes; every node is finite and every weight is finite and
/// strictly positive (first offender in index order is reported); and the
/// weights sum to 1 within an absolute tolerance of 1e-8.
///
/// # Errors
/// Returns a message prefixed with `context` describing the first violation.
pub(crate) fn validate_empirical_z_grid(
    nodes: &[f64],
    weights: &[f64],
    context: &str,
) -> Result<(), String> {
    if weights.len() != nodes.len() {
        return Err(format!(
            "{context} empirical latent measure node/weight length mismatch: nodes={}, weights={}",
            nodes.len(),
            weights.len()
        ));
    }
    if nodes.len() < 2 {
        return Err(format!(
            "{context} empirical latent measure requires at least two nodes"
        ));
    }

    // Accumulate the total mass left-to-right while validating each entry.
    let total = nodes.iter().zip(weights).enumerate().try_fold(
        0.0_f64,
        |running, (idx, (&node, &weight))| -> Result<f64, String> {
            if !node.is_finite() {
                return Err(format!(
                    "{context} empirical latent measure node {idx} is non-finite ({node})"
                ));
            }
            if !weight.is_finite() || weight <= 0.0 {
                return Err(format!(
                    "{context} empirical latent measure weight {idx} must be finite and positive, got {weight}"
                ));
            }
            Ok(running + weight)
        },
    )?;

    if total.is_finite() && (total - 1.0).abs() <= 1e-8 {
        Ok(())
    } else {
        Err(format!(
            "{context} empirical latent measure weights must sum to 1, got {total}"
        ))
    }
}
494
495pub(crate) fn combine_empirical_grids(
496    grids: &[EmpiricalZGrid],
497    mixture: &[(usize, f64)],
498) -> Result<EmpiricalZGrid, String> {
499    if mixture.is_empty() {
500        return Err("local empirical latent measure row mixture is empty".to_string());
501    }
502    let mut nodes = Vec::new();
503    let mut weights = Vec::new();
504    for &(grid_idx, grid_weight) in mixture {
505        if !grid_weight.is_finite() || grid_weight <= 0.0 {
506            return Err(format!(
507                "local empirical latent mixture weight must be finite and positive, got {grid_weight}"
508            ));
509        }
510        let grid = grids.get(grid_idx).ok_or_else(|| {
511            format!("local empirical latent mixture references missing grid {grid_idx}")
512        })?;
513        for (node, weight) in grid.pairs() {
514            nodes.push(node);
515            weights.push(grid_weight * weight);
516        }
517    }
518    let total = weights.iter().copied().sum::<f64>();
519    if !(total.is_finite() && total > 0.0) {
520        return Err(
521            "local empirical latent combined grid has non-positive total weight".to_string(),
522        );
523    }
524    for weight in &mut weights {
525        *weight /= total;
526    }
527    EmpiricalZGrid::new(nodes, weights, "local empirical latent combined grid")
528}
529
530#[derive(Clone, Debug)]
531pub struct LatentZPolicy {
532    pub check_mode: LatentZCheckMode,
533    pub normalization: LatentZNormalizationMode,
534    pub latent_measure: LatentMeasureSpec,
535    pub mean_tol_multiplier: f64,
536    pub sd_tol_multiplier: f64,
537    pub max_abs_skew: f64,
538    pub max_abs_excess_kurtosis: f64,
539}
540
541impl LatentZPolicy {
542    pub fn frozen_transformation_normal() -> Self {
543        // Defaults relaxed to `WarnOnly` with the same thresholds the
544        // exploratory-weighted preset uses (skew ≤ 4.0, |excess kurt| ≤ 20.0).
545        // Rationale: the upstream conditional transformation-normal
546        // preprocessor may be fit isotropically (no per-axis κ). At large-scale
547        // dimensionality (16 PCs, 15 ancestries) an isotropic fit can leave
548        // the global latent-z distribution mildly heavy-tailed (skew ≈ 4,
549        // excess kurt ≈ 30–40 in synthetic studies) without violating per-
550        // grouping mean/variance calibration. The downstream marginal-slope
551        // model still uses the latent-Gaussian probit/score-warp link; the
552        // emitted warning makes the deviation visible without aborting the
553        // fit. Callers that need strict enforcement can construct a custom
554        // `LatentZPolicy` with `check_mode: LatentZCheckMode::Strict`.
555        Self {
556            check_mode: LatentZCheckMode::WarnOnly,
557            normalization: LatentZNormalizationMode::Frozen { mean: 0.0, sd: 1.0 },
558            latent_measure: LatentMeasureSpec::auto_default(),
559            mean_tol_multiplier: 4.0,
560            sd_tol_multiplier: 4.0,
561            max_abs_skew: 4.0,
562            max_abs_excess_kurtosis: 20.0,
563        }
564    }
565
566    pub fn exploratory_fit_weighted() -> Self {
567        Self {
568            check_mode: LatentZCheckMode::WarnOnly,
569            normalization: LatentZNormalizationMode::FitWeighted,
570            latent_measure: LatentMeasureSpec::auto_default(),
571            mean_tol_multiplier: 8.0,
572            sd_tol_multiplier: 8.0,
573            max_abs_skew: 4.0,
574            max_abs_excess_kurtosis: 20.0,
575        }
576    }
577}
578
579impl Default for LatentZPolicy {
580    fn default() -> Self {
581        Self::frozen_transformation_normal()
582    }
583}
584
/// Affine normalization of the latent score: `z ↦ (z − mean) / sd`.
#[derive(Debug, Clone, Copy, PartialEq)]
pub struct LatentZNormalization {
    /// Location subtracted from every z.
    pub mean: f64,
    /// Scale every centered z is divided by.
    pub sd: f64,
}
590
591impl LatentZNormalization {
592    pub fn apply(&self, z: &Array1<f64>, context: &str) -> Result<Array1<f64>, String> {
593        if !(self.mean.is_finite() && self.sd.is_finite() && self.sd > BMS_VARIANCE_FLOOR) {
594            return Err(format!(
595                "{context} requires finite latent z normalization with sd > {BMS_VARIANCE_FLOOR:e}; got mean={} sd={}",
596                self.mean, self.sd
597            ));
598        }
599        if z.iter().any(|value| !value.is_finite()) {
600            return Err(format!("{context} requires finite z values"));
601        }
602        Ok(z.mapv(|zi| (zi - self.mean) / self.sd))
603    }
604}
605
606/// Blom-rankit weighted rank inverse-normal transform for the latent
607/// score.
608///
609/// When the latent z fails the standard-normal auto-detection
610/// ([`latent_z_is_standard_normal_enough`]), the BMS family applied to
611/// pretend the score is N(0,1) anyway would distort the closed-form
612/// probit log-CDF kernel. The historical fallback (local- or
613/// global-empirical latent measure) is *mathematically correct* but
614/// triggers the per-row intercept Newton solve in the empirical-grid
615/// closed-form kernels (`empirical_rigid_primary_grad_hess_closed_form`
616/// and its higher-order siblings); at large scale that is the dominant
617/// cost.
618///
619/// **Rank-INT is exact under monotone re-parameterisation.** The Blom rankit assigns
620/// each sorted training z the rank-probability
621/// `(W_i − 0.375) / (W_total + 0.25)`, then maps that probability
622/// through `Φ⁻¹`. The transform is **strictly monotone** on the
623/// observed support, so the BMS likelihood is invariant up to a
624/// re-parameterisation (the model is a transformation-equivariant
625/// family on the latent axis). The transformed sample is *exactly*
626/// N(0,1) by construction, so the standard-normal closed-form kernel
627/// is **exact** on the calibrated scale. The kept work is the same
628/// closed-form `signed_probit_logcdf_and_mills_ratio` evaluation as
629/// the no-calibration path; the dropped work is the empirical-grid
630/// jet machinery. Persisted to disk so prediction applies the same
631/// monotone map to incoming z and re-routes through the closed-form
632/// kernel.
633#[derive(Clone, Debug, PartialEq, Serialize, Deserialize)]
634pub struct LatentZRankIntCalibration {
635    /// Sorted unique z values seen during training, ascending. Knot table
636    /// for `apply_to_training` / `apply_at_predict`.
637    pub sorted_z: Vec<f64>,
638    /// Weighted cumulative-distribution-function values at each
639    /// `sorted_z` knot, in `[eps, 1 - eps]` with
640    /// `eps = 0.5 / W_total`. Strictly increasing.
641    pub weighted_cdf: Vec<f64>,
642    /// Weighted mean of the calibrated training sample. Used as a
643    /// sanity-check value on `fit`; should be very close to zero.
644    pub post_mean: f64,
645    /// Weighted SD of the calibrated training sample. Used as a
646    /// sanity-check value on `fit`; should be very close to one.
647    pub post_sd: f64,
648}
649
650impl LatentZRankIntCalibration {
651    /// Fit the weighted rank-INT calibration from training z and weights.
652    ///
653    /// Algorithm:
654    /// 1. Sort rows by ascending z.
655    /// 2. Compute cumulative weight `W_i` at each sorted row.
656    /// 3. Blom-rankit cumulative probability:
657    ///    `p_i = (W_i − 0.375) / (W_total + 0.25)`.
658    /// 4. Clip to `[eps, 1 − eps]` with `eps = 0.5 / W_total`.
659    /// 5. Store `(sorted_z, weighted_cdf = p_i)`.
660    ///
661    /// Returns the calibration plus the post-transform sample's weighted
662    /// mean / SD for sanity-check logging.
663    pub fn fit(z: &Array1<f64>, weights: &Array1<f64>) -> Result<Self, String> {
664        if z.len() != weights.len() {
665            return Err(format!(
666                "rank-INT calibration: z length {} != weights length {}",
667                z.len(),
668                weights.len()
669            ));
670        }
671        if z.is_empty() {
672            return Err("rank-INT calibration requires at least one observation".to_string());
673        }
674        let w_total = weights.iter().copied().sum::<f64>();
675        if !(w_total.is_finite() && w_total > 0.0) {
676            return Err(format!(
677                "rank-INT calibration requires positive finite total weight, got {w_total}"
678            ));
679        }
680        for (idx, value) in z.iter().enumerate() {
681            if !value.is_finite() {
682                return Err(format!(
683                    "rank-INT calibration: z[{idx}] = {value} not finite"
684                ));
685            }
686        }
687        for (idx, weight) in weights.iter().enumerate() {
688            if !(weight.is_finite() && *weight >= 0.0) {
689                return Err(format!(
690                    "rank-INT calibration: weight[{idx}] = {weight} not finite/non-negative"
691                ));
692            }
693        }
694        let mut order: Vec<usize> = (0..z.len()).collect();
695        order.sort_by(|&a, &b| z[a].partial_cmp(&z[b]).unwrap_or(std::cmp::Ordering::Equal));
696
697        let mut sorted_z: Vec<f64> = Vec::with_capacity(z.len());
698        let mut weighted_cdf: Vec<f64> = Vec::with_capacity(z.len());
699        let denom = w_total + 0.25;
700        let eps = 0.5 / w_total.max(1.0);
701        let mut cum_w = 0.0_f64;
702        let mut last_z: Option<f64> = None;
703        for &idx in &order {
704            cum_w += weights[idx];
705            let zi = z[idx];
706            // Collapse ties: store one knot per unique z (last cumulative).
707            if let Some(prev) = last_z
708                && zi == prev
709            {
710                if let Some(slot) = weighted_cdf.last_mut() {
711                    let p = ((cum_w - 0.375) / denom).clamp(eps, 1.0 - eps);
712                    *slot = p;
713                }
714                continue;
715            }
716            let p = ((cum_w - 0.375) / denom).clamp(eps, 1.0 - eps);
717            sorted_z.push(zi);
718            weighted_cdf.push(p);
719            last_z = Some(zi);
720        }
721
722        // Compute sanity-check post-mean and post-sd on the transformed
723        // sample, weighted by the original weights.
724        let mut sum_wz = 0.0_f64;
725        let mut sum_w = 0.0_f64;
726        for &idx in &order {
727            let zi = z[idx];
728            let calibrated = Self::apply_with_knots(zi, &sorted_z, &weighted_cdf);
729            sum_wz += weights[idx] * calibrated;
730            sum_w += weights[idx];
731        }
732        let post_mean = if sum_w > 0.0 { sum_wz / sum_w } else { 0.0 };
733        let mut sum_w_dev = 0.0_f64;
734        for &idx in &order {
735            let zi = z[idx];
736            let calibrated = Self::apply_with_knots(zi, &sorted_z, &weighted_cdf);
737            let d = calibrated - post_mean;
738            sum_w_dev += weights[idx] * d * d;
739        }
740        let post_sd = if sum_w > 0.0 {
741            (sum_w_dev / sum_w).sqrt()
742        } else {
743            1.0
744        };
745
746        Ok(Self {
747            sorted_z,
748            weighted_cdf,
749            post_mean,
750            post_sd,
751        })
752    }
753
754    /// Apply the calibration to the full training z vector, returning the
755    /// calibrated sample. Equivalent to mapping each row's z through
756    /// [`Self::apply_at_predict`], but vectorised.
757    pub fn apply_to_training(&self, z: &Array1<f64>) -> Result<Array1<f64>, String> {
758        if self.sorted_z.is_empty() {
759            return Err("rank-INT calibration has no knots".to_string());
760        }
761        let mut out = Array1::<f64>::zeros(z.len());
762        for (idx, &zi) in z.iter().enumerate() {
763            if !zi.is_finite() {
764                return Err(format!(
765                    "rank-INT calibration apply: z[{idx}] = {zi} not finite"
766                ));
767            }
768            out[idx] = self.apply_at_predict(zi);
769        }
770        Ok(out)
771    }
772
773    /// Apply the calibration to a single z at predict time.
774    ///
775    /// Linear interpolation on `(sorted_z, weighted_cdf)` to obtain
776    /// `p ∈ [eps, 1 − eps]`, then `Φ⁻¹(p)` via
777    /// [`standard_normal_quantile`]. Out-of-range z's clip to the
778    /// boundary CDF before the quantile, so the calibration extrapolates
779    /// monotonically beyond the training support.
780    pub fn apply_at_predict(&self, z: f64) -> f64 {
781        Self::apply_with_knots(z, &self.sorted_z, &self.weighted_cdf)
782    }
783
784    pub(crate) fn apply_with_knots(z: f64, sorted_z: &[f64], weighted_cdf: &[f64]) -> f64 {
785        assert_eq!(sorted_z.len(), weighted_cdf.len());
786        assert!(!sorted_z.is_empty());
787        let n = sorted_z.len();
788        let p = if z <= sorted_z[0] {
789            weighted_cdf[0]
790        } else if z >= sorted_z[n - 1] {
791            weighted_cdf[n - 1]
792        } else {
793            // Binary search for the right knot.
794            let mut lo = 0usize;
795            let mut hi = n - 1;
796            while hi - lo > 1 {
797                let mid = (lo + hi) / 2;
798                if sorted_z[mid] <= z {
799                    lo = mid;
800                } else {
801                    hi = mid;
802                }
803            }
804            let z_lo = sorted_z[lo];
805            let z_hi = sorted_z[hi];
806            let p_lo = weighted_cdf[lo];
807            let p_hi = weighted_cdf[hi];
808            if z_hi == z_lo {
809                p_hi
810            } else {
811                let t = (z - z_lo) / (z_hi - z_lo);
812                p_lo + t * (p_hi - p_lo)
813            }
814        };
815        // Φ⁻¹(p); clip away from {0, 1} to keep the quantile finite.
816        standard_normal_quantile(p).unwrap_or_else(|_| if p < 0.5 { -8.0 } else { 8.0 })
817    }
818}
819
820/// Optional calibration applied to the latent score before the BMS
821/// kernel runs. When `RankInverseNormal`, both the training and predict
822/// paths route the input z through [`LatentZRankIntCalibration::apply_*`]
823/// before the standard-normal closed-form kernel is invoked.
824#[derive(Clone, Debug)]
825pub enum LatentMeasureCalibration {
826    None,
827    RankInverseNormal(LatentZRankIntCalibration),
828    ConditionalLocationScale(LatentZConditionalCalibration),
829}
830
831/// Conditional location-scale calibration of the latent score (#905).
832///
833/// The marginal-slope Auto trigger's pooled-z gate (KS / skewness / kurtosis +
834/// the rank inverse-normal transform) only inspects the **marginal** law of
835/// `z`. A conditional shift `E[z | C] = m(C) ≠ 0` — the allele-frequency-driven
836/// grouping mean shift — passes the marginal gate while leaving `z | C`
837/// off-center, so the slope contribution `b(C)·m(C)` leaks into the influence
838/// channel `q`. Rank-INT provably cannot fix this: no transform `T` depending
839/// only on the marginal `F_Z` can enforce `E[T(Z) | C] ≡ const` for all joint
840/// laws.
841///
842/// The unique Fisher-orthogonal location-scale correction (for the Gaussian
843/// working metric the closed-form probit kernel assumes) is
844/// `ζ = (z − m(C)) / √v(C)`, where `m(C) = E[z|C]` and `v(C) = Var(z|C)` are
845/// estimated by weighted ridge regression of `z` (and its squared residual) on
846/// the marginal-index span `a(C) = [1 | X_marginal]`. The corrected `ζ` is
847/// conditionally centered (and homoskedastic when the variance block is
848/// active) by construction, so the `b(C)·m(C)` leakage vanishes and the
849/// standard-normal closed-form kernel is exact on `ζ`. Persisted so prediction
850/// rebuilds `a(C)` from the (reproducible) marginal design and applies the
851/// identical map to incoming z.
852#[derive(Clone, Debug, PartialEq, Serialize, Deserialize)]
853pub struct LatentZConditionalCalibration {
854    /// Coefficients for the conditional mean `m(C) = β_m·[1 | a(C)]` over the
855    /// basis `[1 | marginal-design row]`. Length `1 + basis_ncols` (leading
856    /// entry is the intercept).
857    pub mean_coeffs: Vec<f64>,
858    /// Coefficients for the conditional variance
859    /// `v(C) = max(β_v·[1 | a(C)], var_floor)`. Length `1 + basis_ncols`, or
860    /// empty when the conditional-variance block of the Rao gate was not
861    /// significant (mean-only correction); then `v(C) ≡ global_var`.
862    pub var_coeffs: Vec<f64>,
863    /// Number of marginal-design columns in the basis (excludes the leading
864    /// intercept). The predict-time marginal design must present exactly this
865    /// many columns.
866    pub basis_ncols: usize,
867    /// Floor on the fitted conditional variance, in the (normalized)
868    /// latent-score scale (= `AUTO_Z_CONDITIONAL_VAR_FLOOR_FRAC · global_var`).
869    pub var_floor: f64,
870    /// Global weighted variance of the (normalized) training latent score. Used
871    /// as `v(C)` when `var_coeffs` is empty.
872    pub global_var: f64,
873    /// Weighted mean of the calibrated training sample (sanity-check, ≈ 0).
874    pub post_mean: f64,
875    /// Weighted SD of the calibrated training sample (sanity-check, ≈ 1).
876    pub post_sd: f64,
877    /// First-stage (generated-regressor) sandwich covariance of `mean_coeffs`,
878    /// `V₁ᵐ = M⁻¹ (Σ_i w_i² û_i² A_i A_iᵀ) M⁻¹` with `A = [1 | a(C)]`,
879    /// `M = AᵀWA + λR` (the same weighted-ridge normal matrix that produced
880    /// `mean_coeffs`), `û_i = z_i − m̂(C_i)` the HC0 mean residual, and
881    /// `W = diag(w_i)`. Shape `(1+basis_ncols) × (1+basis_ncols)`. This is the
882    /// closed-form estimation uncertainty of `m(C)` that the second stage
883    /// (Murphy–Topel) needs; see [`Self::generated_regressor_term`].
884    pub mean_cov: Array2<f64>,
885    /// First-stage sandwich covariance of `var_coeffs`, computed identically on
886    /// the squared-mean-residual response. Empty (`0 × 0`) exactly when
887    /// `var_coeffs` is empty (mean-only correction; `v(C) ≡ global_var` is a
888    /// constant carrying no first-stage slope uncertainty).
889    pub var_cov: Array2<f64>,
890}
891
892impl LatentZConditionalCalibration {
893    #[inline]
894    pub(crate) fn affine(coeffs: &[f64], a_row: ArrayView1<'_, f64>) -> f64 {
895        let mut acc = coeffs[0];
896        for (c, &x) in coeffs[1..].iter().zip(a_row.iter()) {
897            acc += c * x;
898        }
899        acc
900    }
901
902    pub(crate) fn conditional_mean(&self, a_row: ArrayView1<'_, f64>) -> f64 {
903        Self::affine(&self.mean_coeffs, a_row)
904    }
905
906    pub(crate) fn conditional_var(&self, a_row: ArrayView1<'_, f64>) -> f64 {
907        if self.var_coeffs.is_empty() {
908            self.global_var.max(self.var_floor)
909        } else {
910            Self::affine(&self.var_coeffs, a_row).max(self.var_floor)
911        }
912    }
913
914    /// Apply `ζ = (z − m(C))/√v(C)` to a batch. `a_block` is the marginal
915    /// design (`n × basis_ncols`); `z` is the (normalized) latent score. Used
916    /// at both training and predict time, so the map is identical.
917    pub fn apply(
918        &self,
919        z: ArrayView1<'_, f64>,
920        a_block: ArrayView2<'_, f64>,
921    ) -> Result<Array1<f64>, String> {
922        if a_block.ncols() != self.basis_ncols {
923            return Err(format!(
924                "conditional latent calibration expects {} basis columns, got {}",
925                self.basis_ncols,
926                a_block.ncols()
927            ));
928        }
929        if a_block.nrows() != z.len() {
930            return Err(format!(
931                "conditional latent calibration row mismatch: z={}, basis rows={}",
932                z.len(),
933                a_block.nrows()
934            ));
935        }
936        if self.mean_coeffs.len() != self.basis_ncols + 1 {
937            return Err(format!(
938                "conditional latent calibration mean coefficient length {} != basis_ncols+1 ({})",
939                self.mean_coeffs.len(),
940                self.basis_ncols + 1
941            ));
942        }
943        let mut out = Array1::<f64>::zeros(z.len());
944        for i in 0..z.len() {
945            let a_row = a_block.row(i);
946            if !z[i].is_finite() {
947                return Err(format!(
948                    "conditional latent calibration: z[{i}] = {} not finite",
949                    z[i]
950                ));
951            }
952            let m = self.conditional_mean(a_row);
953            let v = self.conditional_var(a_row);
954            if !(v.is_finite() && v > 0.0) {
955                return Err(format!(
956                    "conditional latent calibration produced non-positive variance {v} at row {i}"
957                ));
958            }
959            let zeta = (z[i] - m) / v.sqrt();
960            if !zeta.is_finite() {
961                return Err(format!(
962                    "conditional latent calibration produced non-finite zeta at row {i}"
963                ));
964            }
965            out[i] = zeta;
966        }
967        Ok(out)
968    }
969
970    /// Dimension of the first-stage parameter vector `θ₁ = (mean_coeffs,
971    /// var_coeffs)` whose estimation uncertainty the generated-regressor
972    /// correction propagates. Equals `len(mean_coeffs)` when the variance block
973    /// is inactive, otherwise `len(mean_coeffs) + len(var_coeffs)`.
974    pub fn theta1_dim(&self) -> usize {
975        self.mean_coeffs.len() + self.var_coeffs.len()
976    }
977
978    /// Per-row sensitivity `∂ζ_i/∂θ₁` of the calibrated score to the first-stage
979    /// calibration coefficients, stacked as `[∂ζ/∂mean_coeffs | ∂ζ/∂var_coeffs]`
980    /// (length [`Self::theta1_dim`]). With `ζ = (z − m(C))/√v(C)`,
981    /// `A_i = [1 | a(C_i)]`, `m = A_iᵀ·mean_coeffs`, `v = A_iᵀ·var_coeffs`:
982    ///
983    ///   `∂ζ/∂m = −1/√v`,  `∂ζ/∂v = −(z − m)/(2 v^{3/2}) = −ζ/(2v)`,
984    ///
985    /// and by the chain rule through the affine basis
986    /// `∂ζ/∂mean_coeffs = (∂ζ/∂m)·A_i`, `∂ζ/∂var_coeffs = (∂ζ/∂v)·A_i`. The
987    /// variance block contributes only when `var_coeffs` is active AND the
988    /// fitted `v(C_i)` is above the floor (a floored row has `∂v/∂var_coeffs = 0`
989    /// in the applied map). `z` is the (normalized) raw latent score at this row.
990    pub fn zeta_theta1_jacobian_row(&self, z: f64, a_row: ArrayView1<'_, f64>) -> Vec<f64> {
991        let m = self.conditional_mean(a_row);
992        let v = self.conditional_var(a_row);
993        let inv_sqrt_v = 1.0 / v.sqrt();
994        // Intercept-augmented basis row A_i = [1 | a(C_i)].
995        let mut out = Vec::with_capacity(self.theta1_dim());
996        let dzeta_dm = -inv_sqrt_v;
997        out.push(dzeta_dm); // intercept column of A
998        for &x in a_row.iter() {
999            out.push(dzeta_dm * x);
1000        }
1001        if !self.var_coeffs.is_empty() {
1002            // ∂ζ/∂v active only off the floor; on the floor the applied v(C) is
1003            // constant in var_coeffs, so the variance sensitivity is exactly 0.
1004            let raw_v = Self::affine(&self.var_coeffs, a_row);
1005            let dzeta_dv = if raw_v > self.var_floor {
1006                let zeta = (z - m) * inv_sqrt_v;
1007                -zeta / (2.0 * v)
1008            } else {
1009                0.0
1010            };
1011            out.push(dzeta_dv);
1012            for &x in a_row.iter() {
1013                out.push(dzeta_dv * x);
1014            }
1015        }
1016        out
1017    }
1018
1019    /// Block-diagonal first-stage covariance `V₁ = blkdiag(mean_cov, var_cov)`
1020    /// of `θ₁`, ordered to match [`Self::zeta_theta1_jacobian_row`]. The two
1021    /// stages are fit on (asymptotically) uncorrelated estimating equations
1022    /// (the mean score `Σ w û A` and the Breusch–Pagan variance score
1023    /// `Σ w (û² − v) A` are orthogonal under the Gaussian working model), so the
1024    /// joint first-stage covariance is block-diagonal to first order — the same
1025    /// approximation the Rao gate above uses.
1026    pub fn theta1_covariance(&self) -> Array2<f64> {
1027        let dm = self.mean_coeffs.len();
1028        let dv = self.var_coeffs.len();
1029        let mut v1 = Array2::<f64>::zeros((dm + dv, dm + dv));
1030        v1.slice_mut(s![..dm, ..dm]).assign(&self.mean_cov);
1031        if dv > 0 {
1032            v1.slice_mut(s![dm.., dm..]).assign(&self.var_cov);
1033        }
1034        v1
1035    }
1036
1037    /// Murphy–Topel generated-regressor correction term for the second-stage
1038    /// slope covariance. Given the second-stage information `H_β` (the penalized
1039    /// joint Hessian of the slope fit, whose inverse is the naive `V_β`) and the
1040    /// cross-derivative `G = ∂(score_β)/∂θ₁` (`p_β × dim θ₁`), the corrected
1041    /// covariance is
1042    ///
1043    ///   `V_β = V_β^naive + (H_β⁻¹ G) V₁ (H_β⁻¹ G)ᵀ`.
1044    ///
1045    /// This returns the additive rank-`dim θ₁` term `(H_β⁻¹ G) V₁ (H_β⁻¹ G)ᵀ`
1046    /// given the already-formed `hbeta_inv_g = H_β⁻¹ G` (`p_β × dim θ₁`). The
1047    /// caller forms `G` by accumulating the per-row slope-score sensitivity to
1048    /// `ζ_i` times [`Self::zeta_theta1_jacobian_row`] (chain rule
1049    /// `∂score_β/∂θ₁ = Σ_i (∂score_β/∂ζ_i) (∂ζ_i/∂θ₁)`).
1050    pub fn generated_regressor_term(&self, hbeta_inv_g: ArrayView2<'_, f64>) -> Array2<f64> {
1051        let v1 = self.theta1_covariance();
1052        hbeta_inv_g.dot(&v1).dot(&hbeta_inv_g.t())
1053    }
1054
1055    /// Assemble the full Murphy–Topel generated-regressor correction
1056    /// `(Vb·G)·V₁·(Vb·G)ᵀ` for the second-stage slope covariance, given the ONE
1057    /// engine-side quantity it cannot reconstruct post-fit: the per-row
1058    /// reduced-frame slope-score sensitivity to the calibrated score,
1059    /// `s_i = ∂score_β,i/∂ζ_i` (a `p_β`-vector in the joint flat-β reduced frame
1060    /// `solved_fit.beta_covariance()` lives in). With `score_β,i = ∂ℓ_i/∂β`,
1061    /// `s_i = ∂²ℓ_i/∂β∂ζ_i = J_iᵀ·(∂²ℓ_i/∂η_i∂ζ_i)` is the mixed `(β, ζ)`
1062    /// second derivative of the warped row kernel contracted through the slope
1063    /// design Jacobian `J_i` — exactly the #932 RowNllProgram/Tower4 z-jet
1064    /// channel (`z` is already a row-program input; one extra mixed `(β, z)` jet
1065    /// channel reads off `∂²ℓ/∂β∂z`). It must be evaluated at the converged `β̂`
1066    /// in the SAME reduced frame as `vb`.
1067    ///
1068    /// Everything else is built here from the stored first-stage quantities and
1069    /// the second-stage fit, dissolving the post-fit-reconstruction blocker:
1070    ///   - `G = Σ_i s_i · (∂ζ_i/∂θ₁)ᵀ` (`p_β × dim θ₁`), the chain-rule outer
1071    ///     product accumulated row-by-row with `∂ζ_i/∂θ₁ =
1072    ///     `[`Self::zeta_theta1_jacobian_row`]`(z_i, a_row_i)` (exact-zero on
1073    ///     floored rows, so floored rows contribute nothing — `G`'s support is
1074    ///     the gate-fired rows);
1075    ///   - `Vb·G = vb·G` since the naive second-stage covariance `vb` IS
1076    ///     `H_β⁻¹` (the coordinator's `H_β⁻¹ G = Vb.dot(G)`);
1077    ///   - the term `(Vb·G)·V₁·(Vb·G)ᵀ` via [`Self::generated_regressor_term`].
1078    ///
1079    /// `score_zeta_sensitivity` is `n × p_β` (row `i` = `s_i`); `z` is the
1080    /// per-row normalized latent score (`n`); `a_block` is the marginal design
1081    /// `n × basis_ncols` whose rows feed `zeta_theta1_jacobian_row`; `vb` is the
1082    /// naive reduced-frame slope covariance `n_β × n_β`. The returned term is
1083    /// PSD (a congruence of the PSD `V₁`), so adding it to `vb` makes the
1084    /// corrected slope SE strictly ≥ the naive SE whenever the gate fires
1085    /// (`G ≠ 0`) and exactly equal when every row is floored (`G = 0`).
1086    pub fn generated_regressor_correction(
1087        &self,
1088        score_zeta_sensitivity: ArrayView2<'_, f64>,
1089        z: ArrayView1<'_, f64>,
1090        a_block: ArrayView2<'_, f64>,
1091        vb: ArrayView2<'_, f64>,
1092    ) -> Result<Array2<f64>, String> {
1093        let n = score_zeta_sensitivity.nrows();
1094        let p_beta = score_zeta_sensitivity.ncols();
1095        if z.len() != n || a_block.nrows() != n {
1096            return Err(format!(
1097                "generated_regressor_correction row mismatch: score_zeta_sensitivity rows={n}, \
1098                 z={}, a_block rows={}",
1099                z.len(),
1100                a_block.nrows()
1101            ));
1102        }
1103        if a_block.ncols() != self.basis_ncols {
1104            return Err(format!(
1105                "generated_regressor_correction expects {} basis columns, got {}",
1106                self.basis_ncols,
1107                a_block.ncols()
1108            ));
1109        }
1110        if vb.nrows() != p_beta || vb.ncols() != p_beta {
1111            return Err(format!(
1112                "generated_regressor_correction: vb must be {p_beta}×{p_beta}, got {}×{}",
1113                vb.nrows(),
1114                vb.ncols()
1115            ));
1116        }
1117        // G = Σ_i s_i ⊗ (∂ζ_i/∂θ₁)  (p_β × dim θ₁). Each row contributes the
1118        // rank-1 outer product `s_i ⊗ J_zeta_i`, so summed over the n rows this
1119        // is exactly the cross product `G = Sᵀ·J` of the score-sensitivity
1120        // matrix `S` (`n × p_β`, supplied) and the per-row ζ-Jacobian matrix
1121        // `J` (`n × dim θ₁`). Forming `J` row-by-row is O(n·dim θ₁); the cross
1122        // product is then a single BLAS-3 GEMM rather than the O(n·p_β·dim θ₁)
1123        // scalar triple loop (≈1.5e9 FMA at biobank scale, n≈194k, the dominant
1124        // ~13s/disease cost of the SE correction). Floored rows yield an exact
1125        // all-zero `J` row, so they contribute zero to the GEMM — bit-identical
1126        // to skipping them, no approximation.
1127        let j_mat = self.build_zeta_theta1_jacobian(z, a_block);
1128        let vb_g = self.beta_theta1_sensitivity(score_zeta_sensitivity, j_mat.view(), vb)?;
1129        Ok(self.generated_regressor_term(vb_g.view()))
1130    }
1131
1132    /// Per-row ζ-Jacobian matrix `J` (`n × dim θ₁`, row `i` = `∂ζ_i/∂θ₁`) built
1133    /// row-by-row from [`Self::zeta_theta1_jacobian_row`]. Floored rows yield an
1134    /// exact all-zero row, so they contribute nothing to the `G = Sᵀ·J` cross
1135    /// product (bit-identical to skipping them).
1136    fn build_zeta_theta1_jacobian(
1137        &self,
1138        z: ArrayView1<'_, f64>,
1139        a_block: ArrayView2<'_, f64>,
1140    ) -> Array2<f64> {
1141        let n = a_block.nrows();
1142        let dim_theta1 = self.theta1_dim();
1143        let mut j_mat = Array2::<f64>::zeros((n, dim_theta1));
1144        for i in 0..n {
1145            let j_zeta_row = self.zeta_theta1_jacobian_row(z[i], a_block.row(i));
1146            assert_eq!(
1147                j_zeta_row.len(),
1148                dim_theta1,
1149                "J_zeta row width must match the first-stage hyperparameter dimension"
1150            );
1151            let mut dst = j_mat.row_mut(i);
1152            for (slot, jz) in dst.iter_mut().zip(j_zeta_row.into_iter()) {
1153                *slot = jz;
1154            }
1155        }
1156        j_mat
1157    }
1158
1159    /// Signed first-order sensitivity `∂β̂/∂θ₁ = Vb·G` (`p_β × dim θ₁`) of the
1160    /// converged second-stage slope to the first-stage calibration parameters,
1161    /// the SIGNED quantity the Murphy–Topel correction is built from.
1162    ///
1163    /// `G = Sᵀ·J = Σ_i s_i ⊗ (∂ζ_i/∂θ₁)` with `s_i = ∂score_β,i/∂ζ_i` the
1164    /// LOG-LIKELIHOOD-score sensitivity (the sign convention #1131 fixes at the
1165    /// source in [`gradient_paths::rigid_standard_normal_mixed_z_sensitivity`]),
1166    /// and `Vb = H_β⁻¹` the NLL-Hessian inverse. Under this convention the
1167    /// implicit-function theorem on `∂(log L)/∂β = 0` gives
1168    /// `∂β̂/∂θ₁ = +H_β⁻¹·G = +Vb·G`, so the returned matrix matches the finite
1169    /// difference of the refit slope in θ₁ in BOTH sign and magnitude — unlike
1170    /// the PSD correction term [`Self::generated_regressor_correction`], which is
1171    /// invariant to this sign. `j_zeta` is the per-row ζ-Jacobian matrix
1172    /// (`n × dim θ₁`, row `i` = `∂ζ_i/∂θ₁`).
1173    fn beta_theta1_sensitivity(
1174        &self,
1175        score_zeta_sensitivity: ArrayView2<'_, f64>,
1176        j_zeta: ArrayView2<'_, f64>,
1177        vb: ArrayView2<'_, f64>,
1178    ) -> Result<Array2<f64>, String> {
1179        // G = Sᵀ·J (p_β × dim θ₁) via the SIMD/GPU-routed cross product.
1180        let g = gam_linalg::faer_ndarray::fast_atb(&score_zeta_sensitivity, &j_zeta);
1181        // Vb·G = H_β⁻¹·G (vb is the naive reduced-frame covariance the fit
1182        // already produced — reused, never recomputed).
1183        Ok(vb.dot(&g))
1184    }
1185}
1186
1187/// First-stage robust (HC0) sandwich covariance of a weighted-ridge coefficient
1188/// vector: `V₁ = M⁺ (Σ_i w_i² û_i² A_i A_iᵀ) M⁺` with `M = AᵀWA + λR` the
1189/// ridge normal matrix that produced the coefficients, `W = diag(weights)`,
1190/// `û_i` the per-row residual, and `A` the regression basis (here `[1 | a(C)]`).
1191/// `M⁺` is the Moore–Penrose pseudo-inverse via eigendecomposition with a
1192/// relative tolerance: identifiable directions get the usual `(λ_eff)⁻¹` weight,
1193/// and rank-deficient directions (where some θ₁ components are not identified
1194/// by `A`) are zeroed — they carry no asymptotic distribution, so V₁ in those
1195/// directions is zero, and the Murphy–Topel propagation through identifiable
1196/// functionals of β remains finite and consistent. Using the ordinary inverse
1197/// here let the unregularized direction's `1/ε` blow `M⁻¹·meat·M⁻¹` through
1198/// the f64 range whenever the wide marginal-index span had a near-null
1199/// direction (the bug behind "conditional latent calibration sandwich
1200/// covariance is non-finite" on wide rank-deficient duchon/spline conditioning).
1201/// The meat is formed as `BᵀB` with `B_i = w_i û_i A_iᵀ` (signed) so the
1202/// fused-multiply GEMM is the same SIMD path used everywhere else in the
1203/// codebase, instead of a hand-rolled triple loop whose partial sums could
1204/// overflow on a single pathological row of the basis.
1205pub(crate) fn weighted_ridge_sandwich_cov(
1206    basis: ArrayView2<'_, f64>,
1207    residuals: &[f64],
1208    weights: ArrayView1<'_, f64>,
1209    normal_matrix: &Array2<f64>,
1210) -> Result<Array2<f64>, String> {
1211    let n = basis.nrows();
1212    let p = basis.ncols();
1213    if residuals.len() != n || weights.len() != n {
1214        return Err(format!(
1215            "weighted ridge sandwich length mismatch: rows={n}, residuals={}, weights={}",
1216            residuals.len(),
1217            weights.len()
1218        ));
1219    }
1220    if normal_matrix.nrows() != p || normal_matrix.ncols() != p {
1221        return Err(format!(
1222            "weighted ridge sandwich normal-matrix shape mismatch: basis cols={p}, normal {}x{}",
1223            normal_matrix.nrows(),
1224            normal_matrix.ncols()
1225        ));
1226    }
1227    // Robust HC0 meat as a Gram: build `B` with `B_i = (w_i û_i) A_iᵀ` (rows of
1228    // basis scaled by `w_i û_i`, sign carried), so `meat = BᵀB = Σ_i w_i² û_i²
1229    // A_i A_iᵀ` from one BLAS Gramian. Identical math to the per-row outer-
1230    // product accumulation, but the GEMM path keeps partial sums vectorized
1231    // and is less sensitive to a single pathological row producing an
1232    // intermediate that overflows f64 before the column-wise reduction cancels.
1233    let mut b = basis.to_owned();
1234    for i in 0..n {
1235        let wi = weights[i];
1236        let ri = residuals[i];
1237        let scale = wi * ri;
1238        if scale == 0.0 {
1239            b.row_mut(i).fill(0.0);
1240            continue;
1241        }
1242        b.row_mut(i).iter_mut().for_each(|value| *value *= scale);
1243    }
1244    let meat = gam_linalg::faer_ndarray::fast_ata(&b);
1245    // SPD pseudo-inverse of `M = AᵀWA + λR` via eigendecomposition with a
1246    // relative tolerance; symmetrize first to absorb floating-point asymmetry
1247    // accumulated in the AᵀWA assembly.
1248    let mut m_sym = normal_matrix.clone();
1249    gam_linalg::matrix::symmetrize_in_place(&mut m_sym);
1250    // Jacobi (symmetric diagonal) preconditioning. When the conditioning basis
1251    // spans many orders of magnitude — a power-9 Duchon RBF over 16 standardized
1252    // PCs produces columns differing by ~30 decades — `M` and `meat` live on
1253    // wildly different per-column scales, and the eigendecomposition behind
1254    // `M⁺ meat M⁺` loses all accuracy: the relative truncation tolerance is set
1255    // by `λ_max(M)` (dominated by the largest-scale column), so a genuinely
1256    // identified small-scale direction can be dropped while a near-null one is
1257    // kept, and the surviving `1/λ` then multiplies the huge `meat` straight
1258    // through the f64 range. Precondition by `D = diag(√M_jj)`. Because the ridge
1259    // penalty diagonal is built as the weighted Gram diagonal itself
1260    // (`penalty_jj = Σ_i w_i a_ij² = (AᵀWA)_jj`), `M_jj = (1+ρ)(AᵀWA)_jj`, so
1261    // `M̃ = D⁻¹ M D⁻¹` has EXACT unit diagonal and `M̃ = C + (ρ/(1+ρ))·I` with
1262    // `C` the basis correlation matrix (PSD). Hence `λ_min(M̃) ≥ ρ/(1+ρ) ≈ 1e-8`
1263    // even for a fully collinear basis, which clears the pseudo-inverse's
1264    // relative tolerance `≈ 1e-10·λ_max(M̃)` for the conditioning widths that
1265    // occur here: no direction is spuriously dropped, so `M̃⁺ = M̃⁻¹ = D M⁻¹ D`
1266    // and `cov = D⁻¹ (M̃⁻¹ meat̃ M̃⁻¹) D⁻¹ = M⁻¹ meat M⁻¹` EXACTLY — the scaling
1267    // cancels, this is the same sandwich, only computed on a well-conditioned
1268    // matrix. (Should a pure-ridge direction ever fall under tolerance at very
1269    // large width, dropping it is the correct scale-invariant identifiability
1270    // call.) `meat̃ = D⁻¹ meat D⁻¹`; `M_jj > 0` (Gram diagonal floored positive)
1271    // so `D` is always finite and invertible.
1272    let scale: Vec<f64> = (0..p)
1273        .map(|j| 1.0 / m_sym[[j, j]].max(f64::MIN_POSITIVE).sqrt())
1274        .collect();
1275    let mut m_scaled = m_sym;
1276    let mut meat_scaled = meat;
1277    for i in 0..p {
1278        for j in 0..p {
1279            let s = scale[i] * scale[j];
1280            m_scaled[[i, j]] *= s;
1281            meat_scaled[[i, j]] *= s;
1282        }
1283    }
1284    let (_rank, m_pinv) =
1285        gam_linalg::utils::block_penalty_rank_and_pinv(&m_scaled).map_err(|e| {
1286            format!("conditional latent calibration sandwich pseudo-inverse failed: {e}")
1287        })?;
1288    let mut cov = m_pinv.dot(&meat_scaled).dot(&m_pinv);
1289    // Undo the symmetric scaling: cov_raw = D⁻¹ cov_scaled D⁻¹.
1290    for i in 0..p {
1291        for j in 0..p {
1292            cov[[i, j]] *= scale[i] * scale[j];
1293        }
1294    }
1295    if cov.iter().any(|v| !v.is_finite()) {
1296        return Err("conditional latent calibration sandwich covariance is non-finite".to_string());
1297    }
1298    Ok(cov)
1299}
1300
1301/// Weighted mean of a slice of values.
1302pub(crate) fn weighted_mean(
1303    values: &[f64],
1304    weights: ArrayView1<'_, f64>,
1305    total_weight: f64,
1306) -> f64 {
1307    values
1308        .iter()
1309        .zip(weights.iter())
1310        .map(|(&v, &w)| w * v)
1311        .sum::<f64>()
1312        / total_weight
1313}
1314
1315/// Robust (heteroskedasticity-consistent) Rao/LM score-test p-value for the
1316/// null that the centered basis columns `ã(C)` carry no information about the
1317/// centered response `u`. This is the LAN locally-optimal statistic the issue
1318/// names: `s = Σ_i w_i u_i ã(C_i)`, `Ω̂ = Σ_i w_i² u_i² ã(C_i)ã(C_i)ᵀ`,
1319/// `D = sᵀ Ω̂⁺ s ⟶ χ²_{rank Ω̂}`. Both the conditional-mean test
1320/// (`u_i = z_i − z̄`) and the conditional-variance / Breusch-Pagan test
1321/// (`u_i = (z_i − z̄)² − σ̂²`) are this statistic with the same centered basis.
1322///
1323/// Returns `None` when the test is degenerate (no usable basis directions),
1324/// otherwise the asymptotic p-value.
1325pub(crate) fn robust_conditional_score_pvalue(
1326    a_centered: ArrayView2<'_, f64>,
1327    u: &[f64],
1328    weights: ArrayView1<'_, f64>,
1329) -> Result<Option<f64>, String> {
1330    let n = a_centered.nrows();
1331    let r = a_centered.ncols();
1332    if r == 0 || n == 0 {
1333        return Ok(None);
1334    }
1335    if u.len() != n || weights.len() != n {
1336        return Err(format!(
1337            "conditional score test length mismatch: rows={n}, u={}, weights={}",
1338            u.len(),
1339            weights.len()
1340        ));
1341    }
1342    // Build the per-row scaled basis `B` with `B_i = (w_i u_i) ã_i` once, then
1343    // recover both the score and the HC0 robust meat from it with two BLAS-3
1344    // GEMMs over chunked row-blocks instead of an `O(n · r²)` per-row scatter:
1345    //   • score  `s   = ãᵀ (w ∘ u) = Bᵀ 1`     (column sums of `B`),
1346    //   • meat   `Ω̂  = Σ_i w_i² u_i² ã_i ã_iᵀ = BᵀB` since `(w_i u_i)² = w_i² u_i²`.
1347    // A non-positive weight zeroes that row of `B` (its score and meat
1348    // contributions both vanish), reproducing the `wi <= 0.0` skip EXACTLY.
1349    // `fast_ata` is the same parallel Gramian the second-stage sandwich uses, so
1350    // the statistic is numerically identical to the row-accumulated form up to
1351    // the deterministic GEMM reduction order.
1352    let mut b = a_centered.to_owned();
1353    for i in 0..n {
1354        let wi = weights[i];
1355        let scale = if wi > 0.0 { wi * u[i] } else { 0.0 };
1356        if scale == 0.0 {
1357            b.row_mut(i).fill(0.0);
1358            continue;
1359        }
1360        b.row_mut(i).iter_mut().for_each(|value| *value *= scale);
1361    }
1362    let s = b.sum_axis(ndarray::Axis(0));
1363    let omega = gam_linalg::faer_ndarray::fast_ata(&b);
1364    if !s.iter().all(|v| v.is_finite()) || !omega.iter().all(|v| v.is_finite()) {
1365        return Ok(None);
1366    }
1367    let (rank, omega_pinv) = gam_linalg::utils::block_penalty_rank_and_pinv(&omega)
1368        .map_err(|e| format!("conditional score test pseudo-inverse failed: {e}"))?;
1369    if rank == 0 {
1370        return Ok(None);
1371    }
1372    let d_stat = s.dot(&omega_pinv.dot(&s));
1373    if !(d_stat.is_finite() && d_stat >= 0.0) {
1374        return Ok(None);
1375    }
1376    // p = 1 − CDF_{χ²_rank}(D) = 1 − P(rank/2, D/2) (regularized lower gamma).
1377    let p_lower = statrs::function::gamma::gamma_lr(rank as f64 / 2.0, d_stat / 2.0);
1378    let p_value = (1.0 - p_lower).clamp(0.0, 1.0);
1379    Ok(Some(p_value))
1380}
1381
1382/// Fit the conditional location-scale calibration (#905) if the conditional
1383/// `E[z|C]`/`Var(z|C)` Rao gate fires on the marginal-index basis `a_block`.
1384///
1385/// Returns `None` when there is no conditional structure to correct (the gate
1386/// does not fire, or the basis is degenerate) — in that case the caller falls
1387/// back to the existing pooled-marginal gate (rank-INT or no calibration).
1388pub(crate) fn fit_conditional_latent_calibration_if_needed(
1389    z: &Array1<f64>,
1390    weights: &Array1<f64>,
1391    a_block: ArrayView2<'_, f64>,
1392) -> Result<Option<LatentZConditionalCalibration>, String> {
1393    let n = z.len();
1394    let p = a_block.ncols();
1395    if n != weights.len() {
1396        return Err(format!(
1397            "conditional latent gate length mismatch: z={n}, weights={}",
1398            weights.len()
1399        ));
1400    }
1401    if a_block.nrows() != n {
1402        return Err(format!(
1403            "conditional latent gate row mismatch: z={n}, basis rows={}",
1404            a_block.nrows()
1405        ));
1406    }
1407    if p == 0 {
1408        return Ok(None);
1409    }
1410    let total_weight = weights.iter().copied().sum::<f64>();
1411    if !(total_weight.is_finite() && total_weight > 0.0) {
1412        return Ok(None);
1413    }
1414    if z.iter().any(|v| !v.is_finite()) || a_block.iter().any(|v| !v.is_finite()) {
1415        return Ok(None);
1416    }
1417
1418    let z_mean = z
1419        .iter()
1420        .zip(weights.iter())
1421        .map(|(&zi, &wi)| wi * zi)
1422        .sum::<f64>()
1423        / total_weight;
1424    let global_var = z
1425        .iter()
1426        .zip(weights.iter())
1427        .map(|(&zi, &wi)| wi * (zi - z_mean) * (zi - z_mean))
1428        .sum::<f64>()
1429        / total_weight;
1430    if !(global_var.is_finite() && global_var > 0.0) {
1431        return Ok(None);
1432    }
1433
1434    // Center each basis column by its weighted mean so the score test is about
1435    // conditional structure *beyond* the global level (the intercept nuisance).
1436    // A constant marginal-design column collapses to ~0 and is dropped by the
1437    // pseudo-inverse rank, so an intercept already present in a(C) is harmless.
1438    let mut a_centered = a_block.to_owned();
1439    for j in 0..p {
1440        let col = a_block.column(j);
1441        let col_mean = col
1442            .iter()
1443            .zip(weights.iter())
1444            .map(|(&v, &w)| w * v)
1445            .sum::<f64>()
1446            / total_weight;
1447        a_centered.column_mut(j).mapv_inplace(|v| v - col_mean);
1448    }
1449
1450    // Conditional-mean Rao test: u = z − z̄.
1451    let u_mean: Vec<f64> = z.iter().map(|&zi| zi - z_mean).collect();
1452    let p_mean = robust_conditional_score_pvalue(a_centered.view(), &u_mean, weights.view())?;
1453    // Conditional-variance (Breusch-Pagan) Rao test: u = (z − z̄)² − σ̂².
1454    let u_var: Vec<f64> = u_mean.iter().map(|&e| e * e - global_var).collect();
1455    let p_var = robust_conditional_score_pvalue(a_centered.view(), &u_var, weights.view())?;
1456
1457    let mean_fires = p_mean.is_some_and(|p| p < AUTO_Z_CONDITIONAL_RAO_ALPHA);
1458    let var_fires = p_var.is_some_and(|p| p < AUTO_Z_CONDITIONAL_RAO_ALPHA);
1459    if !mean_fires && !var_fires {
1460        return Ok(None);
1461    }
1462
1463    // Escalation fires. Fit the conditional mean over the full basis
1464    // [1 | a(C)] via a weighted ridge (the ridge stabilizes a rank-deficient
1465    // marginal-index span; it does not meaningfully shrink the few directions
1466    // that triggered the gate). The conditional-mean correction is applied
1467    // whenever the gate fires (a pure-variance trigger leaves the C-slopes of
1468    // m(C) ≈ 0, so it reduces to harmless global centering).
1469    let basis = build_intercept_basis(a_block);
1470    // Per-column Tikhonov penalty scaled by the weighted Gram diagonal, so the
1471    // ridge is *relative* to each column's scale (a 1e-8 absolute ridge would
1472    // be negligible against an O(n) Gram and would not stabilize a
1473    // rank-deficient penalized-spline marginal index). `diag_jj = Σ_i w_i a_ij²`;
1474    // floored positive so the all-zero (already-dropped) directions still
1475    // receive a finite ridge and the factorization cannot fail.
1476    let mut penalty = Array2::<f64>::zeros((basis.ncols(), basis.ncols()));
1477    for j in 0..basis.ncols() {
1478        let diag_jj = basis
1479            .column(j)
1480            .iter()
1481            .zip(weights.iter())
1482            .map(|(&x, &w)| w * x * x)
1483            .sum::<f64>()
1484            .max(f64::MIN_POSITIVE);
1485        penalty[[j, j]] = diag_jj;
1486    }
1487    let z_col = z.view().insert_axis(ndarray::Axis(1));
1488    let (mean_coeffs_mat, mean_fitted) = gam_linalg::utils::gaussian_weighted_ridge(
1489        basis.view(),
1490        z_col,
1491        penalty.view(),
1492        weights.view(),
1493        AUTO_Z_CONDITIONAL_RIDGE_REL,
1494    )?;
1495    let mean_coeffs: Vec<f64> = mean_coeffs_mat.column(0).to_vec();
1496
1497    // First-stage (generated-regressor) normal matrix `M = AᵀWA + λR`, the same
1498    // weighted-ridge system `gaussian_weighted_ridge` factorizes internally;
1499    // rebuilt here so its inverse can form the closed-form coefficient sandwich
1500    // `V₁` that the second-stage Murphy–Topel correction consumes. `p` is the
1501    // marginal-index width (small), so this is a cheap dense `(p+1)²` form.
1502    let normal_matrix = {
1503        let mut wa = basis.to_owned();
1504        for i in 0..wa.nrows() {
1505            let wi = weights[i];
1506            wa.row_mut(i).iter_mut().for_each(|value| *value *= wi);
1507        }
1508        let mut m = basis.t().dot(&wa);
1509        m += &(penalty.to_owned() * AUTO_Z_CONDITIONAL_RIDGE_REL);
1510        m
1511    };
1512    let mean_residuals: Vec<f64> = z
1513        .iter()
1514        .zip(mean_fitted.column(0).iter())
1515        .map(|(&zi, &mi)| zi - mi)
1516        .collect();
1517    let mean_cov = weighted_ridge_sandwich_cov(
1518        basis.view(),
1519        &mean_residuals,
1520        weights.view(),
1521        &normal_matrix,
1522    )?;
1523
1524    let var_floor = (AUTO_Z_CONDITIONAL_VAR_FLOOR_FRAC * global_var).max(f64::MIN_POSITIVE);
1525    let (var_coeffs, var_cov): (Vec<f64>, Array2<f64>) = if var_fires {
1526        // Conditional-variance correction: regress the squared mean-residual on
1527        // the same basis. Fitted values are floored at `var_floor` when applied.
1528        let resid_sq: Array1<f64> = mean_residuals.iter().map(|&e| e * e).collect();
1529        let resid_col = resid_sq.view().insert_axis(ndarray::Axis(1));
1530        let (var_coeffs_mat, var_fitted) = gam_linalg::utils::gaussian_weighted_ridge(
1531            basis.view(),
1532            resid_col,
1533            penalty.view(),
1534            weights.view(),
1535            AUTO_Z_CONDITIONAL_RIDGE_REL,
1536        )?;
1537        // First-stage sandwich for the variance coefficients on the same ridge
1538        // normal matrix `M` (the basis and weights are identical; only the
1539        // response — and hence the residual — differs). `û_i = (z−m̂)²_i − v̂_i`
1540        // is the Breusch–Pagan residual.
1541        let var_residuals: Vec<f64> = resid_sq
1542            .iter()
1543            .zip(var_fitted.column(0).iter())
1544            .map(|(&si, &vi)| si - vi)
1545            .collect();
1546        let cov = weighted_ridge_sandwich_cov(
1547            basis.view(),
1548            &var_residuals,
1549            weights.view(),
1550            &normal_matrix,
1551        )?;
1552        (var_coeffs_mat.column(0).to_vec(), cov)
1553    } else {
1554        (Vec::new(), Array2::<f64>::zeros((0, 0)))
1555    };
1556
1557    let mut calibration = LatentZConditionalCalibration {
1558        mean_coeffs,
1559        var_coeffs,
1560        basis_ncols: p,
1561        var_floor,
1562        global_var,
1563        post_mean: 0.0,
1564        post_sd: 1.0,
1565        mean_cov,
1566        var_cov,
1567    };
1568
1569    // Sanity-check post-correction moments on the training sample.
1570    let calibrated = calibration.apply(z.view(), a_block)?;
1571    let post_mean = weighted_mean(calibrated.as_slice().unwrap(), weights.view(), total_weight);
1572    let post_var = calibrated
1573        .iter()
1574        .zip(weights.iter())
1575        .map(|(&zi, &wi)| wi * (zi - post_mean) * (zi - post_mean))
1576        .sum::<f64>()
1577        / total_weight;
1578    calibration.post_mean = post_mean;
1579    calibration.post_sd = post_var.max(0.0).sqrt();
1580
1581    Ok(Some(calibration))
1582}
1583
1584/// Prepend a column of ones to `a_block`, producing the `[1 | a(C)]` regression
1585/// basis used by the conditional location-scale fit.
1586pub(crate) fn build_intercept_basis(a_block: ArrayView2<'_, f64>) -> Array2<f64> {
1587    let n = a_block.nrows();
1588    let p = a_block.ncols();
1589    let mut basis = Array2::<f64>::ones((n, p + 1));
1590    basis.slice_mut(s![.., 1..]).assign(&a_block);
1591    basis
1592}
1593
1594pub(crate) fn build_latent_measure_with_geometry(
1595    z: &Array1<f64>,
1596    weights: &Array1<f64>,
1597    policy: &LatentZPolicy,
1598    conditioning: Option<ArrayView2<'_, f64>>,
1599) -> Result<(LatentMeasureKind, LatentMeasureCalibration), String> {
1600    match policy.latent_measure {
1601        LatentMeasureSpec::Auto { grid_size: _ } => {
1602            // #905: conditional `E[z|C]`/`Var(z|C)` Rao gate. Inspect the latent
1603            // score's conditional moments on the marginal-index span a(C)
1604            // BEFORE the pooled-marginal gate. A significant conditional shift
1605            // is the `b(C)·m(C)` leakage the pooled gate cannot see and that
1606            // rank-INT provably cannot fix, so it takes precedence: route to the
1607            // conditional location-scale correction `ζ = (z−m(C))/√v(C)`.
1608            if let Some(a_block) = conditioning
1609                && let Some(cal) =
1610                    fit_conditional_latent_calibration_if_needed(z, weights, a_block)?
1611            {
1612                log::info!(
1613                    "[BMS latent-z] conditional location-scale calibrated: basis_ncols={} var_active={} post_mean={:.3e} post_sd={:.3e} (E[z|C]/Var(z|C) Rao gate fired)",
1614                    cal.basis_ncols,
1615                    !cal.var_coeffs.is_empty(),
1616                    cal.post_mean,
1617                    cal.post_sd,
1618                );
1619                return Ok((
1620                    LatentMeasureKind::StandardNormal,
1621                    LatentMeasureCalibration::ConditionalLocationScale(cal),
1622                ));
1623            }
1624            if latent_z_is_standard_normal_enough(z, weights, policy)? {
1625                Ok((
1626                    LatentMeasureKind::StandardNormal,
1627                    LatentMeasureCalibration::None,
1628                ))
1629            } else {
1630                // P4: route bad-normal latent z through a Blom-rankit
1631                // weighted rank inverse-normal transform. The transformed
1632                // sample is exactly N(0,1) by construction, so the
1633                // standard-normal closed-form rigid kernel is exact on the
1634                // calibrated scale. This replaces the heavyweight
1635                // local-/global-empirical paths at the construction site;
1636                // the calibration is persisted so prediction applies the
1637                // identical map.
1638                let calibration = LatentZRankIntCalibration::fit(z, weights)?;
1639                log::info!(
1640                    "[BMS latent-z] rank-INT calibrated: post_mean={:.3e} post_sd={:.3e} knots={}",
1641                    calibration.post_mean,
1642                    calibration.post_sd,
1643                    calibration.sorted_z.len(),
1644                );
1645                Ok((
1646                    LatentMeasureKind::StandardNormal,
1647                    LatentMeasureCalibration::RankInverseNormal(calibration),
1648                ))
1649            }
1650        }
1651        LatentMeasureSpec::StandardNormal => Ok((
1652            LatentMeasureKind::StandardNormal,
1653            LatentMeasureCalibration::None,
1654        )),
1655        LatentMeasureSpec::GlobalEmpirical { grid_size } => {
1656            let kind = build_global_empirical_latent_measure(z, weights, grid_size)?;
1657            Ok((kind, LatentMeasureCalibration::None))
1658        }
1659    }
1660}
1661
1662pub(crate) fn latent_z_is_standard_normal_enough(
1663    z: &Array1<f64>,
1664    weights: &Array1<f64>,
1665    policy: &LatentZPolicy,
1666) -> Result<bool, String> {
1667    if z.len() != weights.len() {
1668        return Err(format!(
1669            "latent-measure auto-detection length mismatch: z={}, weights={}",
1670            z.len(),
1671            weights.len()
1672        ));
1673    }
1674    let weight_sum = weights.iter().copied().sum::<f64>();
1675    let weight_sq_sum = weights.iter().map(|&w| w * w).sum::<f64>();
1676    if !(weight_sum.is_finite()
1677        && weight_sum > 0.0
1678        && weight_sq_sum.is_finite()
1679        && weight_sq_sum > 0.0)
1680    {
1681        return Err("latent-measure auto-detection requires positive finite weights".to_string());
1682    }
1683    let effective_n = weight_sum * weight_sum / weight_sq_sum;
1684    if !(effective_n.is_finite() && effective_n > 1.0) {
1685        return Err(
1686            "latent-measure auto-detection requires at least two effective observations"
1687                .to_string(),
1688        );
1689    }
1690    let mean = z
1691        .iter()
1692        .zip(weights.iter())
1693        .map(|(&zi, &wi)| wi * zi)
1694        .sum::<f64>()
1695        / weight_sum;
1696    let var = z
1697        .iter()
1698        .zip(weights.iter())
1699        .map(|(&zi, &wi)| wi * (zi - mean) * (zi - mean))
1700        .sum::<f64>()
1701        / weight_sum;
1702    let sd = var.sqrt();
1703    if !(mean.is_finite() && sd.is_finite() && sd > 0.0) {
1704        return Ok(false);
1705    }
1706    let skew = z
1707        .iter()
1708        .zip(weights.iter())
1709        .map(|(&zi, &wi)| {
1710            let centered = (zi - mean) / sd;
1711            wi * centered.powi(3)
1712        })
1713        .sum::<f64>()
1714        / weight_sum;
1715    let excess_kurtosis = z
1716        .iter()
1717        .zip(weights.iter())
1718        .map(|(&zi, &wi)| {
1719            let centered = (zi - mean) / sd;
1720            wi * centered.powi(4)
1721        })
1722        .sum::<f64>()
1723        / weight_sum
1724        - 3.0;
1725    let mean_tol = policy.mean_tol_multiplier / effective_n.sqrt();
1726    let sd_tol = policy.sd_tol_multiplier / (2.0 * (effective_n - 1.0).max(1.0)).sqrt();
1727    let ks_to_normal = weighted_ks_to_standard_normal(z, weights, weight_sum)?;
1728    let tail_mass_4 = weighted_tail_mass(z, weights, weight_sum, AUTO_Z_NORMAL_TAIL_SIGMA_INNER);
1729    let tail_mass_6 = weighted_tail_mass(z, weights, weight_sum, AUTO_Z_NORMAL_TAIL_SIGMA_OUTER);
1730    let max_abs_z = z.iter().fold(0.0_f64, |acc, &zi| acc.max(zi.abs()));
1731    let normal_tail_4 = 2.0 * (1.0 - normal_cdf(AUTO_Z_NORMAL_TAIL_SIGMA_INNER));
1732    let normal_tail_6 = 2.0 * (1.0 - normal_cdf(AUTO_Z_NORMAL_TAIL_SIGMA_OUTER));
1733    Ok(mean.abs() <= mean_tol
1734        && (sd - 1.0).abs() <= sd_tol
1735        && skew.is_finite()
1736        && skew.abs() <= policy.max_abs_skew.min(AUTO_Z_NORMAL_SKEW_TOL)
1737        && excess_kurtosis.is_finite()
1738        && excess_kurtosis.abs() <= policy.max_abs_excess_kurtosis.min(AUTO_Z_NORMAL_KURT_TOL)
1739        && ks_to_normal.is_finite()
1740        && ks_to_normal <= AUTO_Z_NORMAL_KS_TOL
1741        && tail_mass_4
1742            <= AUTO_Z_NORMAL_TAIL_MASS_SLACK * normal_tail_4 + AUTO_Z_NORMAL_TAIL_FLOOR_INNER
1743        && tail_mass_6
1744            <= AUTO_Z_NORMAL_TAIL_MASS_SLACK * normal_tail_6 + AUTO_Z_NORMAL_TAIL_FLOOR_OUTER
1745        && max_abs_z < AUTO_Z_NORMAL_MAX_ABS)
1746}
1747
1748pub(crate) fn build_global_empirical_latent_measure(
1749    z: &Array1<f64>,
1750    weights: &Array1<f64>,
1751    grid_size: usize,
1752) -> Result<LatentMeasureKind, String> {
1753    let grid = build_empirical_z_grid(z, weights, grid_size, "empirical latent measure")?;
1754    let measure = LatentMeasureKind::GlobalEmpirical { grid };
1755    measure.validate("empirical latent measure")?;
1756    Ok(measure)
1757}
1758
1759pub(crate) fn weighted_ks_to_standard_normal(
1760    z: &Array1<f64>,
1761    weights: &Array1<f64>,
1762    total_weight: f64,
1763) -> Result<f64, String> {
1764    let mut pairs = Vec::<(f64, f64)>::with_capacity(z.len());
1765    for (&zi, &wi) in z.iter().zip(weights.iter()) {
1766        if !zi.is_finite() || !wi.is_finite() || wi < 0.0 {
1767            return Err(
1768                "latent-measure KS diagnostic requires finite z and non-negative finite weights"
1769                    .to_string(),
1770            );
1771        }
1772        if wi > 0.0 {
1773            pairs.push((zi, wi));
1774        }
1775    }
1776    pairs.sort_by(|left, right| {
1777        left.0
1778            .partial_cmp(&right.0)
1779            .expect("validated latent z values are finite")
1780    });
1781    let mut prev = 0.0;
1782    let mut ks = 0.0_f64;
1783    for (zi, wi) in pairs {
1784        let cdf = normal_cdf(zi);
1785        let next = prev + wi / total_weight;
1786        ks = ks.max((cdf - prev).abs()).max((cdf - next).abs());
1787        prev = next;
1788    }
1789    Ok(ks)
1790}
1791
1792pub(crate) fn weighted_tail_mass(
1793    z: &Array1<f64>,
1794    weights: &Array1<f64>,
1795    total_weight: f64,
1796    cutoff: f64,
1797) -> f64 {
1798    z.iter()
1799        .zip(weights.iter())
1800        .filter(|&(&zi, _)| zi.abs() > cutoff)
1801        .map(|(_, &wi)| wi)
1802        .sum::<f64>()
1803        / total_weight
1804}
1805
1806pub(crate) fn build_empirical_z_grid(
1807    z: &Array1<f64>,
1808    weights: &Array1<f64>,
1809    grid_size: usize,
1810    context: &str,
1811) -> Result<EmpiricalZGrid, String> {
1812    if grid_size < 3 {
1813        return Err(format!(
1814            "empirical latent measure grid_size must be at least 3, got {grid_size}"
1815        ));
1816    }
1817    if z.len() != weights.len() {
1818        return Err(format!(
1819            "{context} length mismatch: z={}, weights={}",
1820            z.len(),
1821            weights.len()
1822        ));
1823    }
1824    let mut pairs = Vec::<(f64, f64)>::with_capacity(z.len());
1825    for (idx, (&zi, &wi)) in z.iter().zip(weights.iter()).enumerate() {
1826        if !zi.is_finite() {
1827            return Err(format!(
1828                "{context} z value at row {idx} is non-finite ({zi})"
1829            ));
1830        }
1831        if !wi.is_finite() || wi < 0.0 {
1832            return Err(format!(
1833                "{context} weight at row {idx} must be finite and non-negative, got {wi}"
1834            ));
1835        }
1836        if wi > 0.0 {
1837            pairs.push((zi, wi));
1838        }
1839    }
1840    if pairs.len() < 2 {
1841        return Err(format!(
1842            "{context} requires at least two positive-weight rows"
1843        ));
1844    }
1845    pairs.sort_by(|left, right| {
1846        left.0
1847            .partial_cmp(&right.0)
1848            .expect("validated empirical latent z values are finite")
1849    });
1850    let total_weight = pairs.iter().map(|(_, weight)| *weight).sum::<f64>();
1851    if !(total_weight.is_finite() && total_weight > 0.0) {
1852        return Err(format!("{context} requires positive finite total weight"));
1853    }
1854
1855    let m = grid_size.min(pairs.len());
1856    let mut nodes = Vec::with_capacity(m);
1857    let mut out_weights = Vec::with_capacity(m);
1858    let bin_weight_target = total_weight / (m as f64);
1859    let mut cursor = 0usize;
1860    let mut remaining = pairs[0].1;
1861    for _ in 0..m {
1862        let mut need = bin_weight_target;
1863        let mut bin_weight = 0.0;
1864        let mut bin_sum = 0.0;
1865        while need > EMPIRICAL_GRID_WEIGHT_EXHAUSTED_REL_TOL * bin_weight_target
1866            && cursor < pairs.len()
1867        {
1868            let take = remaining.min(need);
1869            bin_sum += take * pairs[cursor].0;
1870            bin_weight += take;
1871            need -= take;
1872            remaining -= take;
1873            if remaining <= EMPIRICAL_GRID_WEIGHT_EXHAUSTED_REL_TOL * pairs[cursor].1 {
1874                cursor += 1;
1875                if cursor < pairs.len() {
1876                    remaining = pairs[cursor].1;
1877                }
1878            }
1879        }
1880        if bin_weight > 0.0 {
1881            nodes.push(bin_sum / bin_weight);
1882            out_weights.push(bin_weight / total_weight);
1883        }
1884    }
1885    if nodes.len() < 2 {
1886        return Err(format!(
1887            "{context} compression produced fewer than two nodes"
1888        ));
1889    }
1890    recenter_rescale_empirical_grid(&mut nodes, &out_weights);
1891    let total = out_weights.iter().sum::<f64>();
1892    if total.is_finite() && total > 0.0 {
1893        for weight in &mut out_weights {
1894            *weight /= total;
1895        }
1896    }
1897    validate_empirical_z_grid(&nodes, &out_weights, context)?;
1898    Ok(EmpiricalZGrid {
1899        nodes,
1900        weights: out_weights,
1901    })
1902}
1903
1904pub(crate) fn recenter_rescale_empirical_grid(nodes: &mut [f64], weights: &[f64]) {
1905    let total = weights.iter().sum::<f64>();
1906    if !(total.is_finite() && total > 0.0) {
1907        return;
1908    }
1909    let mean = nodes
1910        .iter()
1911        .zip(weights.iter())
1912        .map(|(&node, &weight)| weight * node)
1913        .sum::<f64>()
1914        / total;
1915    let var = nodes
1916        .iter()
1917        .zip(weights.iter())
1918        .map(|(&node, &weight)| weight * (node - mean).powi(2))
1919        .sum::<f64>()
1920        / total;
1921    let sd = var.sqrt();
1922    if sd.is_finite() && sd > BMS_VARIANCE_FLOOR {
1923        for node in nodes {
1924            *node = (*node - mean) / sd;
1925        }
1926    }
1927}
1928
1929// ---------------------------------------------------------------------------
1930// Cross-module constants — declared here so all submodules can reach them
1931// via `use super::*` without promoting implementation details to pub(crate).
1932// ---------------------------------------------------------------------------
/// NOTE(review): no consumer is visible in this part of the file; presumably a
/// budget for the first phase of automatic subsampling — confirm the unit at
/// the use site.
pub(super) const BMS_AUTO_SUBSAMPLE_PHASE1_BUDGET: usize = 12;
/// NOTE(review): no consumer is visible in this part of the file; presumably
/// the margin that keeps Bernoulli link probabilities away from 0 and 1 —
/// confirm at the use site.
pub(super) const BERNOULLI_LINK_PROBABILITY_EPS: f64 = 1e-12;
/// Degeneracy threshold for spread. In `recenter_rescale_empirical_grid` the
/// weighted *standard deviation* of the grid nodes must exceed this value for
/// the nodes to be standardized; otherwise they are left unchanged.
pub(super) const BMS_VARIANCE_FLOOR: f64 = 1e-12;
/// NOTE(review): no consumer is visible in this part of the file; presumably a
/// tolerance applied to derivative values or derivative checks — confirm at
/// the use site.
pub(super) const BMS_DERIV_TOL: f64 = 1e-8;
/// Relative tolerance below which a residual weight is treated as exhausted in
/// the equal-mass empirical-grid compression loop. Used both for the per-bin
/// "need" remaining (relative to the target bin weight) and for the per-pair
/// remainder (relative to that pair's weight), so a pair/bin that is filled to
/// within a few ulps advances the cursor instead of spinning on round-off.
pub(super) const EMPIRICAL_GRID_WEIGHT_EXHAUSTED_REL_TOL: f64 = 1e-14;
/// Upper bound (and large-`n` default) for rows-per-chunk in the parallel
/// row-accumulation phases.
///
/// This is also a hard *ceiling* the pool-aware [`bms_row_chunk_size`] must
/// respect: several per-chunk fast paths (block-Hessian / block-gradient
/// assembly) allocate fixed `[0.0f64; ROW_CHUNK_SIZE]` stack buffers and index
/// them by the chunk's local row position, so a chunk may never carry more than
/// `ROW_CHUNK_SIZE` rows.
pub(super) const ROW_CHUNK_SIZE: usize = 1024;
/// Floor for rows-per-chunk: below it the per-chunk scratch allocation +
/// scheduler hand-off cost dominates the row arithmetic. Small enough that a
/// moderate `n` on a many-core box still carves several chunks per worker.
pub(super) const ROW_CHUNK_MIN: usize = 64;
/// Target number of row-chunks per rayon worker for the BMS exact-Newton
/// row-fan-out phases (gradient / HVP / diagonal directional-derivative sweeps).
///
/// Several chunks per worker keeps the pool load-balanced across the uneven
/// per-row cost tail (work-stealing moves whole chunks, never partial sums) so
/// the heavy coord-corrections / row-stream phases saturate the cores instead
/// of stranding the tail on one worker.
pub(super) const ROW_CHUNKS_PER_WORKER: usize = 4;
1964
1965/// Pool-aware rows-per-chunk for the BMS exact-Newton row fan-outs.
1966///
1967/// A *fixed* `ROW_CHUNK_SIZE` divisor makes the chunk **count** scale with `n`,
1968/// so at moderate `n` (e.g. `n = 10·ROW_CHUNK_SIZE` on a 64-core box) the
1969/// `into_par_iter` over `⌈n/ROW_CHUNK_SIZE⌉` chunks has far fewer tasks than
1970/// workers and most cores idle — the measured ~30-90% core utilization on the
1971/// biobank coord-corrections / row-stream phases. This sizes the chunk so the
1972/// chunk count targets `ROW_CHUNKS_PER_WORKER × worker_count` (the same policy
1973/// `chunked_row_reduction` uses), clamped to `[ROW_CHUNK_MIN, ROW_CHUNK_SIZE]`:
1974///
1975/// * the `ROW_CHUNK_SIZE` ceiling is mandatory — the block-assembly fast paths
1976///   index fixed `[…; ROW_CHUNK_SIZE]` stack buffers by local row, so a chunk
1977///   can never exceed it. At large `n` the per-1024-row count already exceeds
1978///   the worker count, so the clamp costs nothing there;
1979/// * the `ROW_CHUNK_MIN` floor stops sub-floor fan-out at tiny `n`.
1980///
1981/// The worker count is fixed for the process (one global pool; gam owns its
1982/// threads), so for a given `n` the returned chunk size — and therefore the
1983/// chunk boundaries `chunk_idx·chunk → (chunk_idx+1)·chunk` — is stable across
1984/// calls regardless of rayon work-stealing. The `try_fold`/`try_reduce` callers
1985/// already round-trip through these fixed boundaries, so swapping the divisor
1986/// changes only the chunk *count*, never how a chunk's rows are summed; any
1987/// bit-for-bit reduction-order property they had (same `n` ⇒ same boundaries ⇒
1988/// same tree) is preserved.
1989#[inline]
1990pub(super) fn bms_row_chunk_size(n: usize) -> usize {
1991    if n == 0 {
1992        return ROW_CHUNK_SIZE;
1993    }
1994    let workers = rayon::current_num_threads().max(1);
1995    let target_chunks = workers.saturating_mul(ROW_CHUNKS_PER_WORKER).max(1);
1996    // Rows per chunk that yields ≈ `target_chunks` chunks, clamped into
1997    // `[ROW_CHUNK_MIN, ROW_CHUNK_SIZE]`.
1998    n.div_ceil(target_chunks)
1999        .clamp(ROW_CHUNK_MIN, ROW_CHUNK_SIZE)
2000}
// Further cross-module tuning constants. None of them is referenced in this
// part of the file, so the notes below are inferred from the names only.
// NOTE(review): confirm each meaning at its use site before relying on it.
//
// Presumably the minimum row count at which exact-work logging is emitted.
pub(super) const EXACT_WORK_LOG_MIN_ROWS: usize = 50_000;
// Presumably reuse-pass counts and a tile height (in rows) for the row-primary
// Hessian path.
pub(super) const BMS_ROW_PRIMARY_HESSIAN_EXPECTED_REUSE_PASSES: usize = 3;
pub(super) const BMS_ROW_PRIMARY_HESSIAN_MIN_REUSE_PASSES: usize = 2;
pub(super) const BMS_ROW_PRIMARY_HESSIAN_TILE_ROWS: usize = 8192;
// Two NUM/DEN pairs that read as the rational fractions 1/4 ("single") and
// 1/2 ("global"); what quantity each fraction is taken of is not visible here.
pub(super) const BMS_ROW_PRIMARY_HESSIAN_SINGLE_FRACTION_NUM: u64 = 1;
pub(super) const BMS_ROW_PRIMARY_HESSIAN_SINGLE_FRACTION_DEN: u64 = 4;
pub(super) const BMS_ROW_PRIMARY_HESSIAN_GLOBAL_FRACTION_NUM: u64 = 1;
pub(super) const BMS_ROW_PRIMARY_HESSIAN_GLOBAL_FRACTION_DEN: u64 = 2;
// Presumably the row-chunk length between early-exit checks in the Bernoulli
// marginal-slope line search.
pub(super) const BERNOULLI_MARGSLOPE_LINE_SEARCH_EARLY_EXIT_CHUNK_ROWS: usize = 10_000;
2010
2011// ---------------------------------------------------------------------------
2012// Submodule declarations
2013// ---------------------------------------------------------------------------
2014pub(crate) mod block_specs;
2015pub(crate) mod exact_eval_cache;
2016pub(crate) mod family;
2017pub(crate) mod gradient_paths;
2018pub(crate) mod hessian_paths;
2019pub(crate) mod install_flex;
2020pub(crate) mod row_kernel;
// Test-only module. The test sources are kept under a `tests/src_modules/`
// directory four levels above this file and are spliced in textually by
// `include!` (paths are relative to this file), so they are compiled as part
// of this `tests` module.
#[cfg(test)]
mod tests {
    include!("../../../../tests/src_modules/misc/families_bms_identifiability_rigid_tests.rs");
    include!(
        "../../../../tests/src_modules/optimization/families_bms_joint_hessian_hvp_correction_tests.rs"
    );
}
2028pub(crate) mod axis_direction_search;
2029pub(crate) mod cell_moment_assembly;
2030pub(crate) mod custom_family_impl;
2031pub(crate) mod row_primary_hessian;
2032
2033pub use block_specs::fit_bernoulli_marginal_slope_terms;
2034pub use gradient_paths::{
2035    MarginalSlopeCovariance, MarginalSlopeCovarianceShape, marginal_slope_covariance_from_scores,
2036    marginal_slope_preserving_scale, marginal_slope_probit_eta, padded_deviation_seed,
2037};
2038pub use install_flex::CrossBlockIdentifiabilityWarning;
2039pub(crate) use install_flex::FlexCompileOutcome;
2040
2041// pub(crate) re-exports for internal callers:
2042pub(crate) use block_specs::push_deviation_aux_blockspecs;
2043pub use block_specs::{BmsLogslopeJacobian, BmsMarginalJacobian};
2044pub(crate) use family::{
2045    BernoulliMarginalLinkMap, bernoulli_marginal_link_map,
2046    build_link_deviation_block_from_knots_design_seed_and_weights,
2047    build_score_warp_deviation_block_from_seed,
2048};
2049pub(crate) use gradient_paths::standardize_latent_z_with_policy;
2050pub(crate) use gradient_paths::{
2051    empirical_intercept_from_marginal, signed_probit_neglog_derivatives_up_to_fourth,
2052    unary_derivatives_log, unary_derivatives_log_normal_pdf, unary_derivatives_neglog_phi,
2053    unary_derivatives_sqrt,
2054};
2055pub(crate) use install_flex::{
2056    install_compiled_flex_block_into_runtime, project_monotone_feasible_beta,
2057    validate_monotone_structural_feasible,
2058};