Skip to main content

gam/inference/
predict.rs

1use crate::estimate::{BlockRole, EstimationError, FittedLinkState, UnifiedFitResult};
2use crate::families::bernoulli_marginal_slope::{
3    EmpiricalZGrid, LatentMeasureKind, bernoulli_marginal_link_map,
4    empirical_intercept_from_marginal,
5};
6use crate::families::lognormal_kernel::FrailtySpec;
7use crate::families::marginal_slope_shared::{
8    ObservedDenestedCellPartials, eval_coeff4_at,
9    probit_frailty_scale as marginal_slope_probit_frailty_scale, scale_coeff4,
10};
11use crate::families::strategy::{FamilyStrategy, strategy_for_family, strategy_from_fit};
12use crate::inference::model::{
13    SavedAnchoredDeviationRuntime, SavedLatentZNormalization, SavedLinkWiggleRuntime,
14};
15use crate::inference::prediction_linalg::{
16    PredictionCovarianceBackend, design_row_chunk, prediction_chunk_rows,
17    rowwise_local_covariances_parallel,
18};
19use crate::linalg::utils::predict_gam_dimension_mismatch_message;
20use crate::matrix::{DesignMatrix, SymmetricMatrix};
21use crate::mixture_link::{
22    InverseLinkJet, beta_logistic_inverse_link_jetwith_param_partials,
23    mixture_inverse_link_jetwith_rho_partials_into, sas_inverse_link_jetwith_param_partials,
24};
25use crate::probability::{normal_cdf, normal_pdf, standard_normal_quantile};
26use crate::quadrature::QuadratureContext;
27use crate::types::{InverseLink, LikelihoodFamily};
28use ndarray::{Array1, Array2, ArrayView1, ArrayView2};
29use rayon::iter::{IntoParallelIterator, ParallelIterator};
30
31thread_local! {
32    static PREDICT_QUADRATURE_CONTEXT: QuadratureContext = QuadratureContext::new();
33}
34
35/// Compute standard errors from a covariance matrix (sqrt of diagonal).
36pub fn se_from_covariance(cov: &Array2<f64>) -> Array1<f64> {
37    Array1::from_iter(cov.diag().iter().map(|&v| v.max(0.0).sqrt()))
38}
39
40fn apply_family_inverse_link(
41    eta: &Array1<f64>,
42    family: crate::types::LikelihoodFamily,
43    link_kind: Option<&InverseLink>,
44) -> Result<Array1<f64>, EstimationError> {
45    strategy_for_family(family, link_kind).inverse_link_array(eta.view())
46}
47
48fn local_covariances_with_backend<F>(
49    backend: &PredictionCovarianceBackend<'_>,
50    n_rows: usize,
51    local_dim: usize,
52    build_chunk: F,
53) -> Result<Vec<Vec<Array1<f64>>>, EstimationError>
54where
55    F: Fn(std::ops::Range<usize>) -> Result<Vec<Array2<f64>>, String> + Sync,
56{
57    rowwise_local_covariances_parallel(backend, n_rows, local_dim, build_chunk)
58        .map_err(EstimationError::InvalidInput)
59}
60
61fn usable_penalized_hessian<'a>(
62    fit: &'a UnifiedFitResult,
63    expected_dim: usize,
64    label: &str,
65) -> Option<&'a Array2<f64>> {
66    let hessian = fit.penalized_hessian()?;
67    if hessian.nrows() != expected_dim || hessian.ncols() != expected_dim {
68        log::warn!(
69            "{label}: ignoring penalized Hessian with shape {}x{}; expected {}x{}",
70            hessian.nrows(),
71            hessian.ncols(),
72            expected_dim,
73            expected_dim
74        );
75        return None;
76    }
77    if !hessian.iter().any(|value| value.abs() > 0.0) {
78        log::warn!("{label}: ignoring zero penalized Hessian placeholder");
79        return None;
80    }
81    Some(hessian)
82}
83
84fn conditional_prediction_backend<'a>(
85    fit: &'a UnifiedFitResult,
86    expected_dim: usize,
87    label: &str,
88) -> Option<PredictionCovarianceBackend<'a>> {
89    // The canonical conditional covariance is whatever the fitter exposes via
90    // `beta_covariance` (which is `Cov(β̂ | λ̂)` after any final reparameter
91    // alignment the fitter performed). The penalized Hessian is the precision
92    // matrix the fitter used to *derive* that covariance, but for the
93    // prediction path the dense covariance is the source of truth — using it
94    // directly avoids re-factorizing `H` and avoids silent disagreement when
95    // the stored covariance and Hessian were produced by different
96    // reparameterization stages of the fit.
97    //
98    // We fall back to factorizing the penalized Hessian only when no stored
99    // covariance is available. This keeps the conditional-covariance
100    // semantics in `predict_gam_with_uncertainty` consistent with
101    // `posterior_mean_backend_or_warn`, which already prefers
102    // `fit.beta_covariance()` over any indirect derivation.
103    if let Some(covariance) = fit.beta_covariance() {
104        if covariance.nrows() == expected_dim && covariance.ncols() == expected_dim {
105            return Some(PredictionCovarianceBackend::from_dense(covariance.view()));
106        }
107        log::warn!(
108            "{label}: ignoring conditional covariance with shape {}x{}; expected {}x{}",
109            covariance.nrows(),
110            covariance.ncols(),
111            expected_dim,
112            expected_dim
113        );
114    }
115    if let Some(hessian) = usable_penalized_hessian(fit, expected_dim, label) {
116        match PredictionCovarianceBackend::from_factorized_hessian(SymmetricMatrix::Dense(
117            hessian.clone(),
118        )) {
119            Ok(backend) => return Some(backend),
120            Err(err) => {
121                log::warn!(
122                    "{label}: failed to build factorized prediction precision backend: {err}"
123                );
124            }
125        }
126    }
127    None
128}
129
130fn selected_uncertainty_backend<'a>(
131    fit: &'a UnifiedFitResult,
132    expected_dim: usize,
133    requested_mode: InferenceCovarianceMode,
134    label: &str,
135) -> Result<(PredictionCovarianceBackend<'a>, bool), EstimationError> {
136    match requested_mode {
137        InferenceCovarianceMode::Conditional => {
138            conditional_prediction_backend(fit, expected_dim, label)
139                .map(|backend| (backend, false))
140                .ok_or_else(|| {
141                    EstimationError::InvalidInput(
142                "fit result does not contain conditional covariance or a usable penalized Hessian"
143                    .to_string(),
144            )
145                })
146        }
147        InferenceCovarianceMode::ConditionalPlusSmoothingPreferred => {
148            if let Some(covariance) = fit.beta_covariance_corrected() {
149                if covariance.nrows() != expected_dim || covariance.ncols() != expected_dim {
150                    return Err(EstimationError::InvalidInput(format!(
151                        "{label}: corrected covariance dimension mismatch: expected {}x{}, got {}x{}",
152                        expected_dim,
153                        expected_dim,
154                        covariance.nrows(),
155                        covariance.ncols()
156                    )));
157                }
158                Ok((
159                    PredictionCovarianceBackend::from_dense(covariance.view()),
160                    true,
161                ))
162            } else {
163                selected_uncertainty_backend(
164                    fit,
165                    expected_dim,
166                    InferenceCovarianceMode::Conditional,
167                    label,
168                )
169            }
170        }
171        InferenceCovarianceMode::ConditionalPlusSmoothingRequired => {
172            let covariance = fit.beta_covariance_corrected().ok_or_else(|| {
173                EstimationError::InvalidInput(
174                    "fit result does not contain smoothing-corrected covariance".to_string(),
175                )
176            })?;
177            if covariance.nrows() != expected_dim || covariance.ncols() != expected_dim {
178                return Err(EstimationError::InvalidInput(format!(
179                    "{label}: corrected covariance dimension mismatch: expected {}x{}, got {}x{}",
180                    expected_dim,
181                    expected_dim,
182                    covariance.nrows(),
183                    covariance.ncols()
184                )));
185            }
186            Ok((
187                PredictionCovarianceBackend::from_dense(covariance.view()),
188                true,
189            ))
190        }
191    }
192}
193
194/// Symmetric quadratic form `g' · C · g` for an SPD posterior covariance `C`.
195///
196/// Math-equivalent to the naïve double loop, but exploits symmetry of `C`:
197///   `g' C g = Σ_i g_i² C_ii + 2 Σ_{i<j} g_i g_j C_ij`.
198/// This halves the multiplications and reads each off-diagonal entry only
199/// once, while pulling each row out as a contiguous slice (`Array2` is
200/// row-major) so the inner accumulator vectorizes.
201#[inline]
202fn quadratic_form(cov: &Array2<f64>, grad: &[f64]) -> Result<f64, EstimationError> {
203    let m = grad.len();
204    if cov.nrows() != m || cov.ncols() != m {
205        return Err(EstimationError::InvalidInput(format!(
206            "covariance/gradient dimension mismatch: covariance is {}x{}, gradient length is {}",
207            cov.nrows(),
208            cov.ncols(),
209            m
210        )));
211    }
212    let mut diag_acc = 0.0_f64;
213    let mut off_acc = 0.0_f64;
214    for i in 0..m {
215        let row = cov.row(i);
216        let row_slice = row.as_slice().expect("Array2 row is contiguous");
217        let gi = grad[i];
218        // Diagonal term g_i² C_ii.
219        diag_acc += gi * gi * row_slice[i];
220        // Strict upper triangle Σ_{j>i} g_i g_j C_ij; doubled below by symmetry.
221        let mut row_off = 0.0_f64;
222        for j in (i + 1)..m {
223            row_off += grad[j] * row_slice[j];
224        }
225        off_acc += gi * row_off;
226    }
227    Ok((diag_acc + 2.0 * off_acc).max(0.0))
228}
229
230/// Symmetric quadratic form for the mixture-link `∂μ/∂θ` row, exploiting the
231/// same `C = Cᵀ` symmetry as [`quadratic_form`]; see that function for the
232/// algebraic identity. Avoids materializing a separate `Vec<f64>` of `.mu`s.
233#[inline]
234fn quadratic_form_from_jetmu(
235    cov: &Array2<f64>,
236    partials: &[InverseLinkJet],
237) -> Result<f64, EstimationError> {
238    let m = partials.len();
239    if cov.nrows() != m || cov.ncols() != m {
240        return Err(EstimationError::InvalidInput(format!(
241            "covariance/mixture-gradient dimension mismatch: covariance is {}x{}, mixture gradient length is {}",
242            cov.nrows(),
243            cov.ncols(),
244            m
245        )));
246    }
247    let mut diag_acc = 0.0_f64;
248    let mut off_acc = 0.0_f64;
249    for i in 0..m {
250        let row = cov.row(i);
251        let row_slice = row.as_slice().expect("Array2 row is contiguous");
252        let gi = partials[i].mu;
253        diag_acc += gi * gi * row_slice[i];
254        let mut row_off = 0.0_f64;
255        for j in (i + 1)..m {
256            row_off += partials[j].mu * row_slice[j];
257        }
258        off_acc += gi * row_off;
259    }
260    Ok((diag_acc + 2.0 * off_acc).max(0.0))
261}
262
263fn linear_predictorvariance_from_backend(
264    x: &DesignMatrix,
265    backend: &PredictionCovarianceBackend<'_>,
266) -> Result<Array1<f64>, EstimationError> {
267    let local = local_covariances_with_backend(backend, x.nrows(), 1, |rows| {
268        Ok(vec![design_row_chunk(x, rows)?])
269    })?;
270    Ok(local[0][0].mapv(|v| v.max(0.0)))
271}
272
/// Variance at or below which a projected predictor is treated as a point mass
/// (used e.g. by `projected_bivariate_posterior_mean_result`).
const POSTERIOR_MEAN_VARIANCE_TOL: f64 = 1.0e-10;
/// Absolute cross-covariance at or below which two projected predictors are
/// treated as uncorrelated (used e.g. by `projected_bivariate_posterior_mean_result`).
const POSTERIOR_MEAN_CROSS_TOL: f64 = 1.0e-10;
275
276fn posterior_mean_backend_or_warn<'a>(
277    fit: &'a UnifiedFitResult,
278    fallback: Option<&'a Array2<f64>>,
279    expected_dim: usize,
280    label: &str,
281) -> Option<PredictionCovarianceBackend<'a>> {
282    for (source, covariance) in [
283        ("fit result", fit.beta_covariance()),
284        ("predictor state", fallback),
285    ] {
286        let Some(covariance) = covariance else {
287            continue;
288        };
289        if covariance.nrows() == expected_dim && covariance.ncols() == expected_dim {
290            return Some(PredictionCovarianceBackend::from_dense(covariance.view()));
291        }
292        log::warn!(
293            "{label}: ignoring {source} covariance with shape {}x{}; expected {}x{}",
294            covariance.nrows(),
295            covariance.ncols(),
296            expected_dim,
297            expected_dim
298        );
299    }
300    if let Some(backend) = conditional_prediction_backend(fit, expected_dim, label) {
301        return Some(backend);
302    }
303    log::warn!(
304        "{label}: covariance/precision unavailable; falling back to plug-in point prediction"
305    );
306    None
307}
308
309fn require_posterior_mean_backend<'a>(
310    fit: &'a UnifiedFitResult,
311    fallback: Option<&'a Array2<f64>>,
312    expected_dim: usize,
313    label: &str,
314) -> Result<PredictionCovarianceBackend<'a>, EstimationError> {
315    posterior_mean_backend_or_warn(fit, fallback, expected_dim, label).ok_or_else(|| {
316        EstimationError::InvalidInput(format!(
317            "{label} requires covariance or penalized Hessian for posterior-mean prediction"
318        ))
319    })
320}
321
322fn project_two_block_linear_predictor_covariance(
323    design_first: &DesignMatrix,
324    design_second: &DesignMatrix,
325    backend: &PredictionCovarianceBackend<'_>,
326    p_first: usize,
327    p_second: usize,
328    label: &str,
329) -> Result<(Array1<f64>, Array1<f64>, Array1<f64>), EstimationError> {
330    let p_total = p_first + p_second;
331    if backend.nrows() != p_total {
332        return Err(EstimationError::InvalidInput(format!(
333            "{label} covariance dimension mismatch: expected parameter dimension {}, got {}",
334            p_total,
335            backend.nrows()
336        )));
337    }
338    if design_first.ncols() != p_first || design_second.ncols() != p_second {
339        return Err(EstimationError::InvalidInput(format!(
340            "{label} design dimension mismatch: threshold/location design has {} columns (expected {}), scale design has {} columns (expected {})",
341            design_first.ncols(),
342            p_first,
343            design_second.ncols(),
344            p_second
345        )));
346    }
347    let local = local_covariances_with_backend(backend, design_first.nrows(), 2, |rows| {
348        let x_first = design_row_chunk(design_first, rows.clone())?;
349        let x_second = design_row_chunk(design_second, rows.clone())?;
350        let rows_in_chunk = rows.end - rows.start;
351        let mut first = Array2::<f64>::zeros((rows_in_chunk, p_total));
352        let mut second = Array2::<f64>::zeros((rows_in_chunk, p_total));
353        first
354            .slice_mut(ndarray::s![.., 0..p_first])
355            .assign(&x_first);
356        second
357            .slice_mut(ndarray::s![.., p_first..p_total])
358            .assign(&x_second);
359        Ok(vec![first, second])
360    })?;
361    Ok((
362        local[0][0].mapv(|v| v.max(0.0)),
363        local[1][1].mapv(|v| v.max(0.0)),
364        local[0][1].clone(),
365    ))
366}
367
368fn linear_predictor_se_from_backend<F>(
369    backend: &PredictionCovarianceBackend<'_>,
370    n_rows: usize,
371    build_chunk: F,
372) -> Result<Array1<f64>, EstimationError>
373where
374    F: Fn(std::ops::Range<usize>) -> Result<Vec<Array2<f64>>, String> + Sync,
375{
376    let local = local_covariances_with_backend(backend, n_rows, 1, build_chunk)?;
377    Ok(local[0][0].mapv(|v| v.max(0.0).sqrt()))
378}
379
380fn padded_design_standard_errors_from_backend(
381    design: &DesignMatrix,
382    backend: &PredictionCovarianceBackend<'_>,
383    leading_zeros: usize,
384    trailing_zeros: usize,
385    label: &str,
386) -> Result<Array1<f64>, EstimationError> {
387    let p_design = design.ncols();
388    let p_total = leading_zeros + p_design + trailing_zeros;
389    if backend.nrows() != p_total {
390        return Err(EstimationError::InvalidInput(format!(
391            "{label} covariance dimension mismatch: expected parameter dimension {p_total}, got {}",
392            backend.nrows()
393        )));
394    }
395    linear_predictor_se_from_backend(backend, design.nrows(), |rows| {
396        let x = design_row_chunk(design, rows)?;
397        let rows_in_chunk = x.nrows();
398        let mut grad = Array2::<f64>::zeros((rows_in_chunk, p_total));
399        grad.slice_mut(ndarray::s![.., leading_zeros..leading_zeros + p_design])
400            .assign(&x);
401        Ok(vec![grad])
402    })
403}
404
405fn projected_bivariate_posterior_mean_result<F>(
406    quadctx: &crate::quadrature::QuadratureContext,
407    mu: [f64; 2],
408    cov: [[f64; 2]; 2],
409    integrand: F,
410) -> Result<f64, EstimationError>
411where
412    F: Fn(f64, f64) -> Result<f64, EstimationError>,
413{
414    let var0 = cov[0][0].max(0.0);
415    let var1 = cov[1][1].max(0.0);
416    let cov01 = cov[0][1];
417
418    if var0 <= POSTERIOR_MEAN_VARIANCE_TOL && var1 <= POSTERIOR_MEAN_VARIANCE_TOL {
419        return integrand(mu[0], mu[1]);
420    }
421    if var0 <= POSTERIOR_MEAN_VARIANCE_TOL && cov01.abs() <= POSTERIOR_MEAN_CROSS_TOL {
422        return crate::quadrature::normal_expectation_nd_adaptive_result::<1, _, _, EstimationError>(
423            quadctx,
424            [mu[1]],
425            [[var1]],
426            21,
427            |x| integrand(mu[0], x[0]),
428        );
429    }
430    if var1 <= POSTERIOR_MEAN_VARIANCE_TOL && cov01.abs() <= POSTERIOR_MEAN_CROSS_TOL {
431        return crate::quadrature::normal_expectation_nd_adaptive_result::<1, _, _, EstimationError>(
432            quadctx,
433            [mu[0]],
434            [[var0]],
435            21,
436            |x| integrand(x[0], mu[1]),
437        );
438    }
439    crate::quadrature::normal_expectation_2d_adaptive_result(quadctx, mu, cov, integrand)
440}
441
442pub struct PredictResult {
443    pub eta: Array1<f64>,
444    pub mean: Array1<f64>,
445}
446
447// ═══════════════════════════════════════════════════════════════════════════
448//  PredictableModel trait — uniform prediction interface for all model types
449// ═══════════════════════════════════════════════════════════════════════════
450
451/// Input to the prediction trait. Contains the design matrix and metadata
452/// needed for point prediction + uncertainty quantification.
453pub struct PredictInput {
454    /// Design matrix for the primary (mean/location) block.
455    pub design: DesignMatrix,
456    /// Offset vector for the primary block.
457    pub offset: Array1<f64>,
458    /// Optional design matrix for the noise/scale block (GAMLSS/survival).
459    pub design_noise: Option<DesignMatrix>,
460    /// Optional offset vector for the noise/scale block.
461    pub offset_noise: Option<Array1<f64>>,
462    /// Optional auxiliary scalar covariate used by specialized predictors.
463    pub auxiliary_scalar: Option<Array1<f64>>,
464    /// Optional auxiliary matrix used by specialized predictors.
465    pub auxiliary_matrix: Option<Array2<f64>>,
466}
467
468fn slice_predict_input(
469    input: &PredictInput,
470    rows: std::ops::Range<usize>,
471) -> Result<PredictInput, EstimationError> {
472    Ok(PredictInput {
473        design: DesignMatrix::Dense(crate::matrix::DenseDesignMatrix::from(
474            design_row_chunk(&input.design, rows.clone()).map_err(EstimationError::InvalidInput)?,
475        )),
476        offset: input.offset.slice(ndarray::s![rows.clone()]).to_owned(),
477        design_noise: input
478            .design_noise
479            .as_ref()
480            .map(|design| {
481                design_row_chunk(design, rows.clone())
482                    .map(|d| DesignMatrix::Dense(crate::matrix::DenseDesignMatrix::from(d)))
483                    .map_err(EstimationError::InvalidInput)
484            })
485            .transpose()?,
486        offset_noise: input
487            .offset_noise
488            .as_ref()
489            .map(|offset| offset.slice(ndarray::s![rows.clone()]).to_owned()),
490        auxiliary_scalar: input
491            .auxiliary_scalar
492            .as_ref()
493            .map(|values| values.slice(ndarray::s![rows.clone()]).to_owned()),
494        auxiliary_matrix: input
495            .auxiliary_matrix
496            .as_ref()
497            .map(|values| values.slice(ndarray::s![rows, ..]).to_owned()),
498    })
499}
500
501/// Point prediction with optional standard errors on the linear predictor.
502pub struct PredictionWithSE {
503    /// Linear predictor η = Xβ + offset.
504    pub eta: Array1<f64>,
505    /// Response-scale prediction g⁻¹(η).
506    pub mean: Array1<f64>,
507    /// Standard error of η (if covariance available).
508    pub eta_se: Option<Array1<f64>>,
509    /// Standard error of the mean (delta-method, if covariance available).
510    pub mean_se: Option<Array1<f64>>,
511}
512
513/// Trait for models that can produce predictions from new data.
514///
515/// Implemented by each model class (standard, GAMLSS, survival) to provide
516/// a uniform prediction interface. Eliminates the match-dispatch pattern in
517/// main.rs for predict, NUTS, and summary commands.
518pub trait PredictableModel {
519    /// Response-scale plug-in prediction at the fitted parameter value.
520    fn predict_plugin_response(
521        &self,
522        input: &PredictInput,
523    ) -> Result<PredictResult, EstimationError>;
524
525    /// Primary linear-predictor output.
526    fn predict_linear_predictor(
527        &self,
528        input: &PredictInput,
529    ) -> Result<Array1<f64>, EstimationError> {
530        self.predict_plugin_response(input).map(|pred| pred.eta)
531    }
532
533    /// Prediction with uncertainty quantification (SE on eta and mean scales).
534    fn predict_with_uncertainty(
535        &self,
536        input: &PredictInput,
537    ) -> Result<PredictionWithSE, EstimationError>;
538
539    /// Optional model-specific scale/noise parameter on the response side.
540    ///
541    /// This is distinct from estimator uncertainty. Models that expose a
542    /// per-observation distribution scale (for example Gaussian
543    /// location-scale `sigma`) override this and return it explicitly instead
544    /// of smuggling it through `PredictionWithSE`.
545    fn predict_noise_scale(
546        &self,
547        input: &PredictInput,
548    ) -> Result<Option<Array1<f64>>, EstimationError>;
549
550    /// Full prediction with confidence/observation intervals.
551    ///
552    /// Delegates to `predict_gamwith_uncertainty` for standard models.
553    /// Survival and location-scale models will override with domain-specific
554    /// interval construction.
555    fn predict_full_uncertainty(
556        &self,
557        input: &PredictInput,
558        fit: &UnifiedFitResult,
559        options: &PredictUncertaintyOptions,
560    ) -> Result<PredictUncertaintyResult, EstimationError>;
561
562    /// Posterior-mean prediction with coefficient uncertainty propagation.
563    ///
564    /// This is the canonical response-scale prediction path for nonlinear
565    /// models and the default semantics exposed by the CLI.
566    ///
567    /// When `confidence_level` is `Some(α)` with α ∈ (0, 1), the result
568    /// includes `mean_lower` / `mean_upper` confidence bounds.  Each predictor
569    /// computes bounds using the method natural to its parameterisation
570    /// (TransformEta for eta-scale SE, response-scale Delta for probability-
571    /// scale SE).
572    fn predict_posterior_mean(
573        &self,
574        input: &PredictInput,
575        fit: &UnifiedFitResult,
576        confidence_level: Option<f64>,
577    ) -> Result<PredictPosteriorMeanResult, EstimationError>;
578
579    /// Number of coefficient blocks in the model.
580    fn n_blocks(&self) -> usize;
581
582    /// Roles of each block.
583    fn block_roles(&self) -> Vec<BlockRole>;
584}
585
586/// Standard (single-block) GAM predictor.
587pub struct StandardPredictor {
588    pub beta: Array1<f64>,
589    pub family: crate::types::LikelihoodFamily,
590    pub link_kind: Option<InverseLink>,
591    pub covariance: Option<Array2<f64>>,
592    pub link_wiggle: Option<SavedLinkWiggleRuntime>,
593}
594
595impl StandardPredictor {
596    /// Build a `StandardPredictor` from a `UnifiedFitResult`, extracting beta
597    /// from the first block and covariance from the unified result.
598    pub(crate) fn from_unified(
599        unified: &UnifiedFitResult,
600        family: crate::types::LikelihoodFamily,
601        link_kind: Option<InverseLink>,
602        link_wiggle: Option<SavedLinkWiggleRuntime>,
603    ) -> Result<Self, String> {
604        let expected_linkwiggle = link_wiggle.is_some();
605        if !expected_linkwiggle
606            && (unified.n_blocks() != 1 || unified.block_by_role(BlockRole::LinkWiggle).is_some())
607        {
608            return Err(
609                "StandardPredictor only supports single-block standard fits without link wiggles"
610                    .to_string(),
611            );
612        }
613        let beta = if expected_linkwiggle {
614            unified
615                .block_by_role(BlockRole::Mean)
616                .map(|b| b.beta.clone())
617                .ok_or_else(|| {
618                    "standard link-wiggle unified fit is missing Mean coefficient block".to_string()
619                })?
620        } else {
621            unified
622                .blocks
623                .first()
624                .map(|b| b.beta.clone())
625                .ok_or_else(|| {
626                    "standard unified fit is missing its sole coefficient block".to_string()
627                })?
628        };
629        let covariance = unified.covariance_conditional.clone();
630        Ok(Self {
631            beta,
632            family,
633            link_kind,
634            covariance,
635            link_wiggle,
636        })
637    }
638}
639
640impl PredictableModel for StandardPredictor {
641    fn predict_plugin_response(
642        &self,
643        input: &PredictInput,
644    ) -> Result<PredictResult, EstimationError> {
645        let eta_base = input.design.dot(&self.beta) + &input.offset;
646        let eta = if let Some(runtime) = self.link_wiggle.as_ref() {
647            runtime
648                .apply(&eta_base)
649                .map_err(EstimationError::InvalidInput)?
650        } else {
651            eta_base
652        };
653        let strategy = strategy_for_family(self.family, self.link_kind.as_ref());
654        let mean = strategy.inverse_link_array(eta.view())?;
655        Ok(PredictResult { eta, mean })
656    }
657
658    fn predict_with_uncertainty(
659        &self,
660        input: &PredictInput,
661    ) -> Result<PredictionWithSE, EstimationError> {
662        let result = self.predict_plugin_response(input)?;
663        let eta_base = input.design.dot(&self.beta) + &input.offset;
664        let (eta_se, mean_se) = if let Some(ref cov) = self.covariance {
665            let backend = PredictionCovarianceBackend::from_dense(cov.view());
666            let se = if let Some(runtime) = self.link_wiggle.as_ref() {
667                let p_main = self.beta.len();
668                let p_w = runtime.beta.len();
669                let p_total = p_main + p_w;
670                if backend.nrows() != p_total {
671                    return Err(EstimationError::InvalidInput(format!(
672                        "standard link-wiggle covariance dimension mismatch: expected parameter dimension {}, got {}",
673                        p_total,
674                        backend.nrows()
675                    )));
676                }
677                linear_predictor_se_from_backend(&backend, result.eta.len(), |rows| {
678                    let q0_chunk = eta_base.slice(ndarray::s![rows.clone()]).to_owned();
679                    let x_main = design_row_chunk(&input.design, rows.clone())?;
680                    let wiggle_design = runtime.design(&q0_chunk)?;
681                    let dq_dq0 = runtime.derivative_q0(&q0_chunk)?;
682                    let rows_in_chunk = q0_chunk.len();
683                    let mut grad = Array2::<f64>::zeros((rows_in_chunk, p_total));
684                    for i in 0..rows_in_chunk {
685                        for j in 0..p_main {
686                            grad[[i, j]] = dq_dq0[i] * x_main[[i, j]];
687                        }
688                    }
689                    grad.slice_mut(ndarray::s![.., p_main..p_total])
690                        .assign(&wiggle_design);
691                    Ok(vec![grad])
692                })?
693            } else {
694                eta_standard_errors_from_backend(&input.design, &backend)?
695            };
696            let strategy = strategy_for_family(self.family, self.link_kind.as_ref());
697            let mean_se = delta_method_mean_se(&result.eta, &se, &strategy)?;
698            (Some(se), Some(mean_se))
699        } else {
700            (None, None)
701        };
702        Ok(PredictionWithSE {
703            eta: result.eta,
704            mean: result.mean,
705            eta_se,
706            mean_se,
707        })
708    }
709
710    fn predict_noise_scale(
711        &self,
712        _: &PredictInput,
713    ) -> Result<Option<Array1<f64>>, EstimationError> {
714        Ok(None)
715    }
716
717    fn predict_full_uncertainty(
718        &self,
719        input: &PredictInput,
720        fit: &UnifiedFitResult,
721        options: &PredictUncertaintyOptions,
722    ) -> Result<PredictUncertaintyResult, EstimationError> {
723        if self.link_wiggle.is_none() {
724            return predict_gamwith_uncertainty(
725                input.design.clone(),
726                self.beta.view(),
727                input.offset.view(),
728                self.family,
729                fit,
730                options,
731            );
732        }
733        let pred = self.predict_with_uncertainty(input)?;
734        let eta_se = pred.eta_se.clone().ok_or_else(|| {
735            EstimationError::InvalidInput(
736                "standard link-wiggle uncertainty requires covariance".to_string(),
737            )
738        })?;
739        let mean_se = pred.mean_se.clone().ok_or_else(|| {
740            EstimationError::InvalidInput(
741                "standard link-wiggle uncertainty requires covariance".to_string(),
742            )
743        })?;
744        let z = crate::probability::standard_normal_quantile(0.5 + options.confidence_level * 0.5)
745            .map_err(EstimationError::InvalidInput)?;
746        let eta_lower = &pred.eta - &eta_se.mapv(|s| z * s);
747        let eta_upper = &pred.eta + &eta_se.mapv(|s| z * s);
748        let mut mean_lower = &pred.mean - &mean_se.mapv(|s| z * s);
749        let mut mean_upper = &pred.mean + &mean_se.mapv(|s| z * s);
750        let (lo, hi) = match self.family {
751            crate::types::LikelihoodFamily::GaussianIdentity => (f64::NEG_INFINITY, f64::INFINITY),
752            crate::types::LikelihoodFamily::PoissonLog
753            | crate::types::LikelihoodFamily::GammaLog => (0.0, f64::INFINITY),
754            _ => (1e-10, 1.0 - 1e-10),
755        };
756        mean_lower.mapv_inplace(|v| v.clamp(lo, hi));
757        mean_upper.mapv_inplace(|v| v.clamp(lo, hi));
758        Ok(PredictUncertaintyResult {
759            eta: pred.eta,
760            mean: pred.mean,
761            eta_standard_error: eta_se,
762            mean_standard_error: mean_se,
763            eta_lower,
764            eta_upper,
765            mean_lower,
766            mean_upper,
767            observation_lower: None,
768            observation_upper: None,
769            covariance_mode_requested: options.covariance_mode,
770            covariance_corrected_used: false,
771        })
772    }
773
774    fn predict_posterior_mean(
775        &self,
776        input: &PredictInput,
777        fit: &UnifiedFitResult,
778        confidence_level: Option<f64>,
779    ) -> Result<PredictPosteriorMeanResult, EstimationError> {
780        let mut result = if self.link_wiggle.is_none() {
781            let backend = posterior_mean_backend_or_warn(
782                fit,
783                self.covariance.as_ref(),
784                self.beta.len(),
785                "standard posterior mean",
786            )
787            .ok_or_else(|| {
788                EstimationError::InvalidInput(
789                    "posterior-mean prediction requires beta covariance or penalized Hessian"
790                        .to_string(),
791                )
792            })?;
793            let strategy = strategy_from_fit(self.family, fit)?;
794            predict_gam_posterior_mean_from_backendwith_bc(
795                input.design.clone(),
796                self.beta.view(),
797                input.offset.view(),
798                &backend,
799                &strategy,
800                "standard posterior mean",
801                fit.bias_correction_beta().map(|b| b.view()),
802            )?
803        } else {
804            let runtime = self.link_wiggle.as_ref().expect("checked above");
805            let plugin = self.predict_plugin_response(input)?;
806            let eta_base = input.design.dot(&self.beta) + &input.offset;
807            let backend = posterior_mean_backend_or_warn(
808                fit,
809                self.covariance.as_ref(),
810                self.beta.len() + runtime.beta.len(),
811                "standard link-wiggle posterior mean",
812            )
813            .ok_or_else(|| {
814                EstimationError::InvalidInput(
815                    "posterior-mean prediction requires beta covariance or penalized Hessian"
816                        .to_string(),
817                )
818            })?;
819            let p_main = self.beta.len();
820            let p_w = runtime.beta.len();
821            let p_total = p_main + p_w;
822            if backend.nrows() != p_total {
823                return Err(EstimationError::InvalidInput(format!(
824                    "standard link-wiggle posterior mean covariance mismatch: expected parameter dimension {}, got {}",
825                    p_total,
826                    backend.nrows()
827                )));
828            }
829            let eta_se = linear_predictor_se_from_backend(&backend, plugin.eta.len(), |rows| {
830                let q0_chunk = eta_base.slice(ndarray::s![rows.clone()]).to_owned();
831                let x_main = design_row_chunk(&input.design, rows.clone())?;
832                let wiggle_design = runtime.design(&q0_chunk)?;
833                let dq_dq0 = runtime.derivative_q0(&q0_chunk)?;
834                let rows_in_chunk = q0_chunk.len();
835                let mut grad = Array2::<f64>::zeros((rows_in_chunk, p_total));
836                for i in 0..rows_in_chunk {
837                    for j in 0..p_main {
838                        grad[[i, j]] = dq_dq0[i] * x_main[[i, j]];
839                    }
840                }
841                grad.slice_mut(ndarray::s![.., p_main..p_total])
842                    .assign(&wiggle_design);
843                Ok(vec![grad])
844            })?;
845            let strategy = strategy_for_family(self.family, self.link_kind.as_ref());
846            let quadctx = crate::quadrature::QuadratureContext::new();
847            let mean = plugin
848                .eta
849                .iter()
850                .zip(eta_se.iter())
851                .map(|(&e, &se)| strategy.posterior_mean(&quadctx, e, se))
852                .collect::<Result<Array1<f64>, _>>()?;
853            PredictPosteriorMeanResult {
854                eta: plugin.eta,
855                eta_standard_error: eta_se,
856                mean,
857                mean_lower: None,
858                mean_upper: None,
859            }
860        };
861        if let Some(level) = confidence_level {
862            enrich_posterior_mean_bounds(&mut result, level, self.family, self.link_kind.as_ref())?;
863        }
864        Ok(result)
865    }
866
867    fn n_blocks(&self) -> usize {
868        if self.link_wiggle.is_some() { 2 } else { 1 }
869    }
870
871    fn block_roles(&self) -> Vec<BlockRole> {
872        if self.link_wiggle.is_some() {
873            vec![BlockRole::Mean, BlockRole::LinkWiggle]
874        } else {
875            vec![BlockRole::Mean]
876        }
877    }
878}
879
880pub struct BernoulliMarginalSlopePredictor {
881    pub beta_marginal: Array1<f64>,
882    pub beta_logslope: Array1<f64>,
883    pub beta_score_warp: Option<Array1<f64>>,
884    pub beta_link_dev: Option<Array1<f64>>,
885    pub base_link: InverseLink,
886    pub z_column: String,
887    pub latent_z_normalization: SavedLatentZNormalization,
888    pub latent_measure: LatentMeasureKind,
889    pub baseline_marginal: f64,
890    pub baseline_logslope: f64,
891    pub covariance: Option<Array2<f64>>,
892    pub score_warp_runtime: Option<SavedAnchoredDeviationRuntime>,
893    pub link_deviation_runtime: Option<SavedAnchoredDeviationRuntime>,
894    pub gaussian_frailty_sd: Option<f64>,
895    /// Optional rank-INT latent-z calibration. When `Some`, every
896    /// predict-time z (after `latent_z_normalization`) is routed through
897    /// [`LatentZRankIntCalibration::apply_at_predict`] before entering
898    /// the standard-normal closed-form rigid kernel — mirroring the
899    /// fit-time transform applied to training z. The map is the exact
900    /// monotone, invertible (up to empirical-CDF resolution) piecewise-
901    /// linear interpolation on (sorted_z, weighted_cdf) followed by Φ⁻¹,
902    /// so the calibrated sample is N(0,1) by construction. `None` means
903    /// training-time z already passed the strict normality check and no
904    /// transform was applied.
905    pub(crate) latent_z_calibration:
906        Option<crate::families::bernoulli_marginal_slope::LatentZRankIntCalibration>,
907}
908
909/// Per-runtime predict-time anchor correction matrices.
910///
911/// Built once per top-level predict call from the marginal + logslope
912/// designs at the prediction rows. Each `Array2<f64>` is shaped
913/// `n_predict × runtime.basis_dim` and holds `n_row(i) · M` for every
914/// prediction row, where `n_row(i)` is the concatenation of the marginal
915/// and logslope design rows in the runtime's anchor component order.
916///
917/// At any `local_cubic_at` / `basis_cubic_at` / `design` call site we
918/// subtract the appropriate slice of these matrices from the raw cubic
919/// output to apply the cross-block residual `n_row · M` correction.
920///
921/// `n_anchor_rows` is the underlying `n × d` parametric anchor stack
922/// (concatenated marginal + logslope). It is shared between the two
923/// runtimes' correction computations and reused by `design_with_anchor_rows`
924/// for the full-matrix evaluation sites.
925#[derive(Default)]
926struct BmsAnchorCorrections {
927    n_anchor_rows: Option<Array2<f64>>,
928    score_warp: Option<Array2<f64>>,
929    link_dev: Option<Array2<f64>>,
930}
931
932impl BmsAnchorCorrections {
933    fn score_warp_row(&self, row: usize) -> Option<ndarray::ArrayView1<'_, f64>> {
934        self.score_warp.as_ref().map(|m| m.row(row))
935    }
936
937    fn link_dev_row(&self, row: usize) -> Option<ndarray::ArrayView1<'_, f64>> {
938        self.link_dev.as_ref().map(|m| m.row(row))
939    }
940
941    fn n_anchor_rows_view(&self) -> Option<ndarray::ArrayView2<'_, f64>> {
942        self.n_anchor_rows.as_ref().map(|m| m.view())
943    }
944}
945
946impl BernoulliMarginalSlopePredictor {
947    /// Build the anchor correction matrices for a given predict-input batch.
948    ///
949    /// Returns an empty bundle (all `None`) when neither runtime carries
950    /// an anchor residual — this is the fast path for fits without
951    /// cross-block residualisation. When at least one runtime has a
952    /// residual, materialises the marginal + logslope designs at the
953    /// predict rows once and computes the per-runtime correction matrices
954    /// against each runtime's stored `M`.
955    fn build_anchor_correction_matrices(
956        &self,
957        input: &PredictInput,
958        design_logslope: &DesignMatrix,
959    ) -> Result<BmsAnchorCorrections, EstimationError> {
960        let needs_score = self
961            .score_warp_runtime
962            .as_ref()
963            .map_or(false, |r| r.anchor_residual_coefficients.is_some());
964        let needs_link = self
965            .link_deviation_runtime
966            .as_ref()
967            .map_or(false, |r| r.anchor_residual_coefficients.is_some());
968        if !needs_score && !needs_link {
969            return Ok(BmsAnchorCorrections::default());
970        }
971        // Materialise the marginal + logslope designs at predict rows.
972        // For biobank-scale predict batches the caller already chunks via
973        // `prediction_chunk_rows`, so this densification is bounded per
974        // chunk by `chunk_size × (p_marginal + p_logslope)`.
975        let marginal_dense = input
976            .design
977            .try_to_dense_arc(
978                "bernoulli marginal-slope predict-time marginal anchor materialisation",
979            )
980            .map_err(EstimationError::InvalidInput)?;
981        let logslope_dense = design_logslope
982            .try_to_dense_arc(
983                "bernoulli marginal-slope predict-time logslope anchor materialisation",
984            )
985            .map_err(EstimationError::InvalidInput)?;
986        let n_rows = marginal_dense.nrows();
987        if logslope_dense.nrows() != n_rows {
988            return Err(EstimationError::InvalidInput(format!(
989                "bernoulli marginal-slope predict anchor materialisation row mismatch: marginal {} vs logslope {}",
990                n_rows,
991                logslope_dense.nrows()
992            )));
993        }
994        let p_marginal = marginal_dense.ncols();
995        let p_logslope = logslope_dense.ncols();
996        let d = p_marginal + p_logslope;
997        let mut n_anchor_rows = Array2::<f64>::zeros((n_rows, d));
998        n_anchor_rows
999            .slice_mut(ndarray::s![.., 0..p_marginal])
1000            .assign(&marginal_dense.view());
1001        n_anchor_rows
1002            .slice_mut(ndarray::s![.., p_marginal..d])
1003            .assign(&logslope_dense.view());
1004        let score_warp = if needs_score {
1005            self.score_warp_runtime
1006                .as_ref()
1007                .unwrap()
1008                .anchor_correction_matrix(n_anchor_rows.view())
1009                .map_err(EstimationError::InvalidInput)?
1010        } else {
1011            None
1012        };
1013        let link_dev = if needs_link {
1014            self.link_deviation_runtime
1015                .as_ref()
1016                .unwrap()
1017                .anchor_correction_matrix(n_anchor_rows.view())
1018                .map_err(EstimationError::InvalidInput)?
1019        } else {
1020            None
1021        };
1022        Ok(BmsAnchorCorrections {
1023            n_anchor_rows: Some(n_anchor_rows),
1024            score_warp,
1025            link_dev,
1026        })
1027    }
1028
1029    fn likelihood_family(&self) -> LikelihoodFamily {
1030        LikelihoodFamily::BinomialProbit
1031    }
1032
1033    fn mean_from_eta(&self, eta: &Array1<f64>) -> Result<Array1<f64>, EstimationError> {
1034        Ok(eta.mapv(normal_cdf))
1035    }
1036
1037    fn mean_derivative_from_eta(&self, eta: &Array1<f64>) -> Result<Array1<f64>, EstimationError> {
1038        Ok(eta.mapv(normal_pdf))
1039    }
1040
1041    fn probit_frailty_scale(&self) -> f64 {
1042        marginal_slope_probit_frailty_scale(self.gaussian_frailty_sd)
1043    }
1044
1045    /// Apply the (optional) rank-INT latent-z calibration to a batch of
1046    /// normalized predict-time z values.
1047    ///
1048    /// The calibration was fit on the training z + weights as a Blom-
1049    /// rankit weighted rank inverse-normal transform; the calibrated
1050    /// sample is N(0, 1) by construction (exact, not approximate), which
1051    /// is why the BMS standard-normal closed-form kernel is correct on
1052    /// the calibrated scale. At predict time, every z that flows into a
1053    /// kernel evaluation site (`final_eta_and_gradient_from_theta`,
1054    /// `predict_eta_and_q_chain`, and indirectly the per-row `solve_intercept_scalar`
1055    /// / `evaluate_prediction_calibration` / `observed_denested_cell_partials_at_z`
1056    /// helpers that consume per-row scalar z values from the closure-
1057    /// captured `z` array) must be routed through the same monotone
1058    /// transform. When `latent_z_calibration` is `None`, this returns
1059    /// the input unchanged — that case corresponds to training-time z
1060    /// having passed the strict normality check, so no transform was
1061    /// applied at fit time either.
1062    fn apply_latent_z_calibration(&self, z: &Array1<f64>) -> Array1<f64> {
1063        match &self.latent_z_calibration {
1064            Some(cal) => Array1::from_iter(z.iter().map(|&zi| cal.apply_at_predict(zi))),
1065            None => z.clone(),
1066        }
1067    }
1068
1069    fn rigid_intercept_from_marginal(&self, marginal_eta: f64, slope: f64) -> f64 {
1070        let probit_scale = self.probit_frailty_scale();
1071        marginal_eta * (1.0 + (probit_scale * slope).powi(2)).sqrt() / probit_scale
1072    }
1073
1074    fn empirical_rigid_intercept_and_gradient(
1075        &self,
1076        marginal_eta: f64,
1077        slope: f64,
1078        nodes: &[f64],
1079        weights: &[f64],
1080    ) -> Result<(f64, f64, f64), EstimationError> {
1081        let marginal = bernoulli_marginal_link_map(&self.base_link, marginal_eta)
1082            .map_err(EstimationError::InvalidInput)?;
1083        let scale = self.probit_frailty_scale();
1084        let intercept = empirical_intercept_from_marginal(
1085            marginal.mu,
1086            marginal.q,
1087            slope,
1088            scale,
1089            nodes,
1090            weights,
1091            None,
1092        )
1093        .map_err(EstimationError::InvalidInput)?;
1094        let observed_slope = scale * slope;
1095        let mut f_a = 0.0;
1096        let mut f_b = 0.0;
1097        for (&node, &weight) in nodes.iter().zip(weights.iter()) {
1098            let eta = intercept + observed_slope * node;
1099            let pdf = normal_pdf(eta);
1100            f_a += weight * pdf;
1101            f_b += weight * pdf * scale * node;
1102        }
1103        if !(f_a.is_finite() && f_a > 0.0 && f_b.is_finite()) {
1104            return Err(EstimationError::InvalidInput(format!(
1105                "empirical latent prediction calibration derivative is invalid: F_a={f_a}, F_b={f_b}"
1106            )));
1107        }
1108        let a_marginal_eta = marginal.mu1 / f_a;
1109        let a_slope = -f_b / f_a;
1110        Ok((intercept, a_marginal_eta, a_slope))
1111    }
1112
1113    fn local_empirical_mixture_for_point(
1114        point: &[f64],
1115        centers: &[Vec<f64>],
1116        top_k: usize,
1117        bandwidth: f64,
1118    ) -> Result<Vec<(usize, f64)>, EstimationError> {
1119        if centers.is_empty() {
1120            return Err(EstimationError::InvalidInput(
1121                "local empirical latent prediction has no centers".to_string(),
1122            ));
1123        }
1124        if top_k == 0 {
1125            return Err(EstimationError::InvalidInput(
1126                "local empirical latent prediction top_k must be positive".to_string(),
1127            ));
1128        }
1129        if !(bandwidth.is_finite() && bandwidth > 0.0) {
1130            return Err(EstimationError::InvalidInput(format!(
1131                "local empirical latent prediction bandwidth must be finite and positive, got {bandwidth}"
1132            )));
1133        }
1134        let bw2 = bandwidth * bandwidth;
1135        let mut distances = Vec::<(usize, f64)>::with_capacity(centers.len());
1136        for (idx, center) in centers.iter().enumerate() {
1137            if center.len() != point.len() {
1138                return Err(EstimationError::InvalidInput(format!(
1139                    "local empirical latent prediction center {idx} dimension mismatch: center={}, point={}",
1140                    center.len(),
1141                    point.len()
1142                )));
1143            }
1144            let d2 = center
1145                .iter()
1146                .zip(point.iter())
1147                .map(|(&c, &x)| {
1148                    let delta = x - c;
1149                    delta * delta
1150                })
1151                .sum::<f64>();
1152            if !d2.is_finite() {
1153                return Err(EstimationError::InvalidInput(
1154                    "local empirical latent prediction distance is non-finite".to_string(),
1155                ));
1156            }
1157            distances.push((idx, d2));
1158        }
1159        distances.sort_by(|left, right| {
1160            left.1
1161                .partial_cmp(&right.1)
1162                .expect("validated local empirical distances are finite")
1163        });
1164        let k = top_k.min(distances.len());
1165        let mut mixture = Vec::with_capacity(k);
1166        let mut total = 0.0;
1167        for &(idx, d2) in distances.iter().take(k) {
1168            let weight = (-0.5 * d2 / bw2).exp().max(1e-300);
1169            mixture.push((idx, weight));
1170            total += weight;
1171        }
1172        if !(total.is_finite() && total > 0.0) {
1173            return Err(EstimationError::InvalidInput(
1174                "local empirical latent prediction mixture has non-positive total weight"
1175                    .to_string(),
1176            ));
1177        }
1178        for (_, weight) in &mut mixture {
1179            *weight /= total;
1180        }
1181        Ok(mixture)
1182    }
1183
1184    fn combine_empirical_grids(
1185        grids: &[EmpiricalZGrid],
1186        mixture: &[(usize, f64)],
1187    ) -> Result<EmpiricalZGrid, EstimationError> {
1188        let total_len = mixture
1189            .iter()
1190            .map(|&(idx, _)| grids.get(idx).map_or(0, |grid| grid.nodes.len()))
1191            .sum::<usize>();
1192        let mut nodes = Vec::with_capacity(total_len);
1193        let mut weights = Vec::with_capacity(total_len);
1194        let mut total_weight = 0.0;
1195        for &(grid_idx, grid_weight) in mixture {
1196            if !(grid_weight.is_finite() && grid_weight >= 0.0) {
1197                return Err(EstimationError::InvalidInput(format!(
1198                    "local empirical latent prediction mixture weight must be finite and non-negative, got {grid_weight}"
1199                )));
1200            }
1201            let grid = grids.get(grid_idx).ok_or_else(|| {
1202                EstimationError::InvalidInput(format!(
1203                    "local empirical latent prediction grid index {grid_idx} is out of bounds for {} grids",
1204                    grids.len()
1205                ))
1206            })?;
1207            if grid.nodes.len() != grid.weights.len() || grid.nodes.is_empty() {
1208                return Err(EstimationError::InvalidInput(format!(
1209                    "local empirical latent prediction grid {grid_idx} is invalid: nodes={}, weights={}",
1210                    grid.nodes.len(),
1211                    grid.weights.len()
1212                )));
1213            }
1214            for (&node, &weight) in grid.nodes.iter().zip(grid.weights.iter()) {
1215                let combined_weight = grid_weight * weight;
1216                if !(node.is_finite() && combined_weight.is_finite() && combined_weight >= 0.0) {
1217                    return Err(EstimationError::InvalidInput(
1218                        "local empirical latent prediction grid contains invalid node/weight"
1219                            .to_string(),
1220                    ));
1221                }
1222                nodes.push(node);
1223                weights.push(combined_weight);
1224                total_weight += combined_weight;
1225            }
1226        }
1227        if !(total_weight.is_finite() && total_weight > 0.0) {
1228            return Err(EstimationError::InvalidInput(
1229                "local empirical latent prediction combined grid has non-positive total weight"
1230                    .to_string(),
1231            ));
1232        }
1233        for weight in &mut weights {
1234            *weight /= total_weight;
1235        }
1236        Ok(EmpiricalZGrid { nodes, weights })
1237    }
1238
1239    fn empirical_grid_for_prediction_row(
1240        &self,
1241        input: &PredictInput,
1242        row: usize,
1243    ) -> Result<Option<EmpiricalZGrid>, EstimationError> {
1244        match &self.latent_measure {
1245            LatentMeasureKind::StandardNormal => Ok(None),
1246            LatentMeasureKind::GlobalEmpirical { nodes, weights } => Ok(Some(EmpiricalZGrid {
1247                nodes: nodes.clone(),
1248                weights: weights.clone(),
1249            })),
1250            LatentMeasureKind::LocalEmpirical {
1251                centers,
1252                grids,
1253                top_k,
1254                bandwidth,
1255                ..
1256            } => {
1257                let conditioning = input.auxiliary_matrix.as_ref().ok_or_else(|| {
1258                    EstimationError::InvalidInput(
1259                        "bernoulli marginal-slope local empirical prediction requires auxiliary conditioning matrix"
1260                            .to_string(),
1261                    )
1262                })?;
1263                if row >= conditioning.nrows() {
1264                    return Err(EstimationError::InvalidInput(format!(
1265                        "local empirical latent prediction row {row} is out of bounds for {} conditioning rows",
1266                        conditioning.nrows()
1267                    )));
1268                }
1269                let expected_dim = centers.first().map_or(0, Vec::len);
1270                if conditioning.ncols() != expected_dim {
1271                    return Err(EstimationError::InvalidInput(format!(
1272                        "local empirical latent prediction conditioning dimension mismatch: got {}, expected {expected_dim}",
1273                        conditioning.ncols()
1274                    )));
1275                }
1276                let point = conditioning.row(row).to_vec();
1277                let mixture =
1278                    Self::local_empirical_mixture_for_point(&point, centers, *top_k, *bandwidth)?;
1279                Self::combine_empirical_grids(grids, &mixture).map(Some)
1280            }
1281        }
1282    }
1283
1284    fn transform_internal_eta_to_base_scale(
1285        &self,
1286        internal_eta: Array1<f64>,
1287        internal_grad: Option<Array2<f64>>,
1288    ) -> Result<(Array1<f64>, Option<Array2<f64>>), EstimationError> {
1289        Ok((internal_eta, internal_grad))
1290    }
1291
1292    fn link_terms_value_d1(
1293        &self,
1294        eta0: &Array1<f64>,
1295        beta_link_dev: Option<&Array1<f64>>,
1296        link_dev_correction_for_row: Option<ndarray::ArrayView1<'_, f64>>,
1297    ) -> Result<(Array1<f64>, Array1<f64>), EstimationError> {
1298        if let (Some(runtime), Some(beta)) = (&self.link_deviation_runtime, beta_link_dev) {
1299            // When the runtime carries a cross-block anchor residual, every
1300            // raw-design row needs `n_row · M` subtracted. `correction_for_row`
1301            // already holds the precomputed `n_row · M` for this predict row
1302            // (length basis_dim), so the corrected basis contribution to η
1303            // is `basis · beta - correction.dot(beta)` for every eta0 entry.
1304            // Derivative paths are unaffected (the anchor argument is a
1305            // different scalar than eta0).
1306            let basis = runtime
1307                .design_uncorrected(eta0)
1308                .map_err(EstimationError::InvalidInput)?;
1309            let mut value = &basis.dot(beta) + eta0;
1310            if let Some(corr) = link_dev_correction_for_row {
1311                let offset = corr.dot(beta);
1312                for v in value.iter_mut() {
1313                    *v -= offset;
1314                }
1315            } else if runtime.anchor_residual_coefficients.is_some() {
1316                return Err(EstimationError::InvalidInput(
1317                    "bernoulli marginal-slope link-deviation runtime has an anchor residual but \
1318                     no per-row correction was supplied to link_terms_value_d1"
1319                        .to_string(),
1320                ));
1321            }
1322            let d1 = runtime
1323                .first_derivative_design(eta0)
1324                .map_err(EstimationError::InvalidInput)?;
1325            Ok((value, d1.dot(beta) + 1.0))
1326        } else {
1327            Ok((eta0.clone(), Array1::ones(eta0.len())))
1328        }
1329    }
1330
1331    fn denested_partition_cells(
1332        &self,
1333        a: f64,
1334        b: f64,
1335        beta_score_warp: Option<&Array1<f64>>,
1336        beta_link_dev: Option<&Array1<f64>>,
1337        score_warp_correction_for_row: Option<ndarray::ArrayView1<'_, f64>>,
1338        link_dev_correction_for_row: Option<ndarray::ArrayView1<'_, f64>>,
1339    ) -> Result<
1340        Vec<crate::families::bernoulli_marginal_slope::exact_kernel::DenestedPartitionCell>,
1341        EstimationError,
1342    > {
1343        let score_breaks = if let Some(runtime) = self.score_warp_runtime.as_ref() {
1344            runtime
1345                .breakpoints()
1346                .map_err(EstimationError::InvalidInput)?
1347        } else {
1348            Vec::new()
1349        };
1350        let link_breaks = if let Some(runtime) = self.link_deviation_runtime.as_ref() {
1351            runtime
1352                .breakpoints()
1353                .map_err(EstimationError::InvalidInput)?
1354        } else {
1355            Vec::new()
1356        };
1357        let mut cells = crate::families::bernoulli_marginal_slope::exact_kernel::build_denested_partition_cells_with_tails(
1358            a,
1359            b,
1360            &score_breaks,
1361            &link_breaks,
1362            |z| {
1363                if let (Some(runtime), Some(beta)) =
1364                    (self.score_warp_runtime.as_ref(), beta_score_warp)
1365                {
1366                    let mut span = runtime.local_cubic_at(beta, z)?;
1367                    // `local_cubic_at`'s c0 is `Σ_j basis_c0[span][j] · beta[j]`.
1368                    // The cross-block residual replaces basis_c0 by
1369                    // basis_c0 − n_row · M, contributing a row-constant
1370                    // `correction.dot(beta)` to c0. Higher coefficients
1371                    // (c1..c3) depend on derivatives of the basis w.r.t.
1372                    // its own argument and are untouched.
1373                    if let Some(corr) = score_warp_correction_for_row {
1374                        span.c0 -= corr.dot(beta);
1375                    }
1376                    Ok(span)
1377                } else {
1378                    Ok(crate::families::bernoulli_marginal_slope::exact_kernel::LocalSpanCubic {
1379                        left: 0.0,
1380                        right: 1.0,
1381                        c0: 0.0,
1382                        c1: 0.0,
1383                        c2: 0.0,
1384                        c3: 0.0,
1385                    })
1386                }
1387            },
1388            |u| {
1389                if let (Some(runtime), Some(beta)) =
1390                    (self.link_deviation_runtime.as_ref(), beta_link_dev)
1391                {
1392                    let mut span = runtime.local_cubic_at(beta, u)?;
1393                    if let Some(corr) = link_dev_correction_for_row {
1394                        span.c0 -= corr.dot(beta);
1395                    }
1396                    Ok(span)
1397                } else {
1398                    Ok(crate::families::bernoulli_marginal_slope::exact_kernel::LocalSpanCubic {
1399                        left: 0.0,
1400                        right: 1.0,
1401                        c0: 0.0,
1402                        c1: 0.0,
1403                        c2: 0.0,
1404                        c3: 0.0,
1405                    })
1406                }
1407            },
1408        )
1409        .map_err(EstimationError::InvalidInput)?;
1410        let scale = self.probit_frailty_scale();
1411        if scale != 1.0 {
1412            for partition_cell in &mut cells {
1413                partition_cell.cell.c0 *= scale;
1414                partition_cell.cell.c1 *= scale;
1415                partition_cell.cell.c2 *= scale;
1416                partition_cell.cell.c3 *= scale;
1417            }
1418        }
1419        Ok(cells)
1420    }
1421
1422    fn evaluate_denested_calibration(
1423        &self,
1424        a: f64,
1425        marginal_eta: f64,
1426        slope: f64,
1427        beta_score_warp: Option<&Array1<f64>>,
1428        beta_link_dev: Option<&Array1<f64>>,
1429        score_warp_correction_for_row: Option<ndarray::ArrayView1<'_, f64>>,
1430        link_dev_correction_for_row: Option<ndarray::ArrayView1<'_, f64>>,
1431    ) -> Result<(f64, f64, f64), EstimationError> {
1432        let marginal = bernoulli_marginal_link_map(&self.base_link, marginal_eta)
1433            .map_err(EstimationError::InvalidInput)?;
1434        let cells = self.denested_partition_cells(
1435            a,
1436            slope,
1437            beta_score_warp,
1438            beta_link_dev,
1439            score_warp_correction_for_row,
1440            link_dev_correction_for_row,
1441        )?;
1442        let scale = self.probit_frailty_scale();
1443        let mut f = -marginal.mu;
1444        let mut f_a = 0.0;
1445        let mut f_aa = 0.0;
1446        for partition_cell in cells {
1447            let cell = partition_cell.cell;
1448            let state =
1449                crate::families::bernoulli_marginal_slope::exact_kernel::evaluate_cell_moments(
1450                    cell, 7,
1451                )
1452                .map_err(EstimationError::InvalidInput)?;
1453            f += state.value;
1454            let (dc_da_raw, _) = crate::families::bernoulli_marginal_slope::exact_kernel::denested_cell_coefficient_partials(
1455                partition_cell.score_span,
1456                partition_cell.link_span,
1457                a,
1458                slope,
1459            );
1460            let (d2c_da2_raw, _, _) = crate::families::bernoulli_marginal_slope::exact_kernel::denested_cell_second_partials(
1461                partition_cell.score_span,
1462                partition_cell.link_span,
1463                a,
1464                slope,
1465            );
1466            let dc_da = scale_coeff4(dc_da_raw, scale);
1467            let d2c_da2 = scale_coeff4(d2c_da2_raw, scale);
1468            f_a += crate::families::bernoulli_marginal_slope::exact_kernel::cell_first_derivative_from_moments(
1469                &dc_da,
1470                &state.moments,
1471            )
1472            .map_err(EstimationError::InvalidInput)?;
1473            f_aa += crate::families::bernoulli_marginal_slope::exact_kernel::cell_second_derivative_from_moments(
1474                cell,
1475                &dc_da,
1476                &dc_da,
1477                &d2c_da2,
1478                &state.moments,
1479            )
1480            .map_err(EstimationError::InvalidInput)?;
1481        }
1482        Ok((f, f_a, f_aa))
1483    }
1484
1485    fn observed_denested_cell_partials_at_z(
1486        &self,
1487        z_value: f64,
1488        a: f64,
1489        b: f64,
1490        beta_score_warp: Option<&Array1<f64>>,
1491        beta_link_dev: Option<&Array1<f64>>,
1492        score_warp_correction_for_row: Option<ndarray::ArrayView1<'_, f64>>,
1493        link_dev_correction_for_row: Option<ndarray::ArrayView1<'_, f64>>,
1494    ) -> Result<ObservedDenestedCellPartials, EstimationError> {
1495        use crate::families::bernoulli_marginal_slope::exact_kernel as exact;
1496
1497        let zero_span = exact::LocalSpanCubic {
1498            left: 0.0,
1499            right: 1.0,
1500            c0: 0.0,
1501            c1: 0.0,
1502            c2: 0.0,
1503            c3: 0.0,
1504        };
1505        let u_value = a + b * z_value;
1506        let score_span = if let (Some(runtime), Some(beta)) =
1507            (self.score_warp_runtime.as_ref(), beta_score_warp)
1508        {
1509            let mut span = runtime
1510                .local_cubic_at(beta, z_value)
1511                .map_err(EstimationError::InvalidInput)?;
1512            if let Some(corr) = score_warp_correction_for_row {
1513                span.c0 -= corr.dot(beta);
1514            }
1515            span
1516        } else {
1517            zero_span
1518        };
1519        let link_span = if let (Some(runtime), Some(beta)) =
1520            (self.link_deviation_runtime.as_ref(), beta_link_dev)
1521        {
1522            let mut span = runtime
1523                .local_cubic_at(beta, u_value)
1524                .map_err(EstimationError::InvalidInput)?;
1525            if let Some(corr) = link_dev_correction_for_row {
1526                span.c0 -= corr.dot(beta);
1527            }
1528            span
1529        } else {
1530            zero_span
1531        };
1532        let scale = self.probit_frailty_scale();
1533        let coeff = scale_coeff4(
1534            exact::denested_cell_coefficients(score_span, link_span, a, b),
1535            scale,
1536        );
1537        let (dc_da_raw, dc_db_raw) =
1538            exact::denested_cell_coefficient_partials(score_span, link_span, a, b);
1539        let (dc_daa_raw, dc_dab_raw, dc_dbb_raw) =
1540            exact::denested_cell_second_partials(score_span, link_span, a, b);
1541        let (dc_daaa, dc_daab, dc_dabb, dc_dbbb) = exact::denested_cell_third_partials(link_span);
1542        Ok(ObservedDenestedCellPartials {
1543            coeff,
1544            dc_da: scale_coeff4(dc_da_raw, scale),
1545            dc_db: scale_coeff4(dc_db_raw, scale),
1546            dc_daa: scale_coeff4(dc_daa_raw, scale),
1547            dc_dab: scale_coeff4(dc_dab_raw, scale),
1548            dc_dbb: scale_coeff4(dc_dbb_raw, scale),
1549            dc_daaa: scale_coeff4(dc_daaa, scale),
1550            dc_daab: scale_coeff4(dc_daab, scale),
1551            dc_dabb: scale_coeff4(dc_dabb, scale),
1552            dc_dbbb: scale_coeff4(dc_dbbb, scale),
1553        })
1554    }
1555
1556    fn evaluate_empirical_denested_calibration(
1557        &self,
1558        a: f64,
1559        marginal_eta: f64,
1560        slope: f64,
1561        beta_score_warp: Option<&Array1<f64>>,
1562        beta_link_dev: Option<&Array1<f64>>,
1563        grid: &EmpiricalZGrid,
1564        score_warp_correction_for_row: Option<ndarray::ArrayView1<'_, f64>>,
1565        link_dev_correction_for_row: Option<ndarray::ArrayView1<'_, f64>>,
1566    ) -> Result<(f64, f64, f64), EstimationError> {
1567        let marginal = bernoulli_marginal_link_map(&self.base_link, marginal_eta)
1568            .map_err(EstimationError::InvalidInput)?;
1569        let mut f = -marginal.mu;
1570        let mut f_a = 0.0;
1571        let mut f_aa = 0.0;
1572        for (&node, &weight) in grid.nodes.iter().zip(grid.weights.iter()) {
1573            let obs = self.observed_denested_cell_partials_at_z(
1574                node,
1575                a,
1576                slope,
1577                beta_score_warp,
1578                beta_link_dev,
1579                score_warp_correction_for_row,
1580                link_dev_correction_for_row,
1581            )?;
1582            let eta = eval_coeff4_at(&obs.coeff, node);
1583            let eta_a = eval_coeff4_at(&obs.dc_da, node);
1584            let eta_aa = eval_coeff4_at(&obs.dc_daa, node);
1585            let pdf = normal_pdf(eta);
1586            f += weight * normal_cdf(eta);
1587            f_a += weight * pdf * eta_a;
1588            f_aa += weight * pdf * (eta_aa - eta * eta_a * eta_a);
1589        }
1590        Ok((f, f_a, f_aa))
1591    }
1592
1593    fn evaluate_prediction_calibration(
1594        &self,
1595        a: f64,
1596        marginal_eta: f64,
1597        slope: f64,
1598        beta_score_warp: Option<&Array1<f64>>,
1599        beta_link_dev: Option<&Array1<f64>>,
1600        empirical_grid: Option<&EmpiricalZGrid>,
1601        score_warp_correction_for_row: Option<ndarray::ArrayView1<'_, f64>>,
1602        link_dev_correction_for_row: Option<ndarray::ArrayView1<'_, f64>>,
1603    ) -> Result<(f64, f64, f64), EstimationError> {
1604        if let Some(grid) = empirical_grid {
1605            self.evaluate_empirical_denested_calibration(
1606                a,
1607                marginal_eta,
1608                slope,
1609                beta_score_warp,
1610                beta_link_dev,
1611                grid,
1612                score_warp_correction_for_row,
1613                link_dev_correction_for_row,
1614            )
1615        } else {
1616            self.evaluate_denested_calibration(
1617                a,
1618                marginal_eta,
1619                slope,
1620                beta_score_warp,
1621                beta_link_dev,
1622                score_warp_correction_for_row,
1623                link_dev_correction_for_row,
1624            )
1625        }
1626    }
1627
1628    pub fn from_unified(
1629        unified: &UnifiedFitResult,
1630        z_column: String,
1631        latent_z_normalization: SavedLatentZNormalization,
1632        latent_measure: LatentMeasureKind,
1633        baseline_marginal: f64,
1634        baseline_logslope: f64,
1635        base_link: InverseLink,
1636        frailty: FrailtySpec,
1637        score_warp_runtime: Option<SavedAnchoredDeviationRuntime>,
1638        link_deviation_runtime: Option<SavedAnchoredDeviationRuntime>,
1639        latent_z_calibration: Option<
1640            crate::families::bernoulli_marginal_slope::LatentZRankIntCalibration,
1641        >,
1642    ) -> Result<Self, String> {
1643        let gaussian_frailty_sd = match frailty {
1644            FrailtySpec::None => None,
1645            FrailtySpec::GaussianShift {
1646                sigma_fixed: Some(sigma),
1647            } => Some(sigma),
1648            FrailtySpec::GaussianShift { sigma_fixed: None } => {
1649                return Err(
1650                    "bernoulli marginal-slope predictor requires a fixed GaussianShift sigma"
1651                        .to_string(),
1652                );
1653            }
1654            FrailtySpec::HazardMultiplier { .. } => {
1655                return Err(
1656                    "bernoulli marginal-slope predictor does not support HazardMultiplier frailty"
1657                        .to_string(),
1658                );
1659            }
1660        };
1661        if !matches!(
1662            base_link,
1663            InverseLink::Standard(crate::types::LinkFunction::Probit)
1664        ) {
1665            return Err(
1666                "bernoulli marginal-slope predictor requires link(type=probit); saved non-probit marginal-slope models must be refit"
1667                    .to_string(),
1668            );
1669        }
1670        if let Some(runtime) = score_warp_runtime.as_ref() {
1671            runtime.validate_exact_replay_contract().map_err(|e| {
1672                format!("bernoulli marginal-slope score-warp runtime is invalid: {e}")
1673            })?;
1674        }
1675        if let Some(runtime) = link_deviation_runtime.as_ref() {
1676            runtime.validate_exact_replay_contract().map_err(|e| {
1677                format!("bernoulli marginal-slope link-deviation runtime is invalid: {e}")
1678            })?;
1679        }
1680        // Cross-block anchor residuals on either runtime are now applied
1681        // per-row by every predict-time `local_cubic_at` / `basis_cubic_at`
1682        // / `design` call site via `build_anchor_correction_matrices`.
1683        latent_z_normalization
1684            .validate("bernoulli marginal-slope predictor")
1685            .map_err(|e| {
1686                format!("bernoulli marginal-slope predictor latent z normalization is invalid: {e}")
1687            })?;
1688        latent_measure
1689            .validate("bernoulli marginal-slope predictor latent measure")
1690            .map_err(|e| {
1691                format!("bernoulli marginal-slope predictor latent measure is invalid: {e}")
1692            })?;
1693        let blocks = &unified.blocks;
1694        let expected_blocks = 2
1695            + usize::from(score_warp_runtime.is_some())
1696            + usize::from(link_deviation_runtime.is_some());
1697        if blocks.len() != expected_blocks {
1698            return Err(format!(
1699                "bernoulli marginal-slope predictor requires exactly {expected_blocks} coefficient blocks under the current exact de-nested semantics, got {}",
1700                blocks.len()
1701            ));
1702        }
1703        let mut cursor = 2usize;
1704        let beta_score_warp = if score_warp_runtime.is_some() {
1705            let beta = blocks
1706                .get(cursor)
1707                .ok_or_else(|| "missing score-warp coefficient block".to_string())?
1708                .beta
1709                .clone();
1710            cursor += 1;
1711            Some(beta)
1712        } else {
1713            None
1714        };
1715        let beta_link_dev = if link_deviation_runtime.is_some() {
1716            Some(
1717                blocks
1718                    .get(cursor)
1719                    .ok_or_else(|| "missing link-deviation coefficient block".to_string())?
1720                    .beta
1721                    .clone(),
1722            )
1723        } else {
1724            None
1725        };
1726        Ok(Self {
1727            beta_marginal: blocks[0].beta.clone(),
1728            beta_logslope: blocks[1].beta.clone(),
1729            beta_score_warp,
1730            beta_link_dev,
1731            base_link,
1732            z_column,
1733            latent_z_normalization,
1734            latent_measure,
1735            baseline_marginal,
1736            baseline_logslope,
1737            covariance: unified.beta_covariance().cloned(),
1738            score_warp_runtime,
1739            link_deviation_runtime,
1740            gaussian_frailty_sd,
1741            latent_z_calibration,
1742        })
1743    }
1744
1745    fn theta(&self) -> Array1<f64> {
1746        let total = self.beta_marginal.len()
1747            + self.beta_logslope.len()
1748            + self.beta_score_warp.as_ref().map_or(0, |b| b.len())
1749            + self.beta_link_dev.as_ref().map_or(0, |b| b.len());
1750        let mut theta = Array1::<f64>::zeros(total);
1751        let mut cursor = 0usize;
1752        theta
1753            .slice_mut(ndarray::s![cursor..cursor + self.beta_marginal.len()])
1754            .assign(&self.beta_marginal);
1755        cursor += self.beta_marginal.len();
1756        theta
1757            .slice_mut(ndarray::s![cursor..cursor + self.beta_logslope.len()])
1758            .assign(&self.beta_logslope);
1759        cursor += self.beta_logslope.len();
1760        if let Some(beta) = self.beta_score_warp.as_ref() {
1761            theta
1762                .slice_mut(ndarray::s![cursor..cursor + beta.len()])
1763                .assign(beta);
1764            cursor += beta.len();
1765        }
1766        if let Some(beta) = self.beta_link_dev.as_ref() {
1767            theta
1768                .slice_mut(ndarray::s![cursor..cursor + beta.len()])
1769                .assign(beta);
1770        }
1771        theta
1772    }
1773
1774    fn split_theta<'a>(
1775        &'a self,
1776        theta: &'a Array1<f64>,
1777    ) -> Result<
1778        (
1779            ArrayView1<'a, f64>,
1780            ArrayView1<'a, f64>,
1781            Option<ArrayView1<'a, f64>>,
1782            Option<ArrayView1<'a, f64>>,
1783        ),
1784        EstimationError,
1785    > {
1786        let expected = self.theta().len();
1787        if theta.len() != expected {
1788            return Err(EstimationError::InvalidInput(format!(
1789                "bernoulli marginal-slope theta length mismatch: expected {expected}, got {}",
1790                theta.len()
1791            )));
1792        }
1793        let mut cursor = 0usize;
1794        let marginal = theta.slice(ndarray::s![cursor..cursor + self.beta_marginal.len()]);
1795        cursor += self.beta_marginal.len();
1796        let logslope = theta.slice(ndarray::s![cursor..cursor + self.beta_logslope.len()]);
1797        cursor += self.beta_logslope.len();
1798        let score_warp = self.beta_score_warp.as_ref().map(|beta| {
1799            let view = theta.slice(ndarray::s![cursor..cursor + beta.len()]);
1800            cursor += beta.len();
1801            view
1802        });
1803        let link_dev = self
1804            .beta_link_dev
1805            .as_ref()
1806            .map(|beta| theta.slice(ndarray::s![cursor..cursor + beta.len()]));
1807        Ok((marginal, logslope, score_warp, link_dev))
1808    }
1809
1810    /// Safeguarded monotone root solve for the marginal intercept under the
1811    /// de-nested flexible model
1812    ///   η(z) = a + b z + b Δ_h(z) + Δ_w(a + b z).
1813    fn solve_intercept_scalar(
1814        &self,
1815        marginal_eta: f64,
1816        slope: f64,
1817        link_dev_beta: Option<&Array1<f64>>,
1818        score_warp_beta: Option<&Array1<f64>>,
1819        empirical_grid: Option<&EmpiricalZGrid>,
1820        warm_start_buf: &mut Array1<f64>,
1821        score_warp_correction_for_row: Option<ndarray::ArrayView1<'_, f64>>,
1822        link_dev_correction_for_row: Option<ndarray::ArrayView1<'_, f64>>,
1823    ) -> Result<f64, EstimationError> {
1824        let marginal = bernoulli_marginal_link_map(&self.base_link, marginal_eta)
1825            .map_err(EstimationError::InvalidInput)?;
1826        let eval = |a: f64| -> Result<(f64, f64, f64), String> {
1827            self.evaluate_prediction_calibration(
1828                a,
1829                marginal_eta,
1830                slope,
1831                score_warp_beta,
1832                link_dev_beta,
1833                empirical_grid,
1834                score_warp_correction_for_row,
1835                link_dev_correction_for_row,
1836            )
1837            .map_err(|err| err.to_string())
1838        };
1839
1840        let probit_scale = self.probit_frailty_scale();
1841        let a_rigid = self.rigid_intercept_from_marginal(marginal.q, slope);
1842        let mut intercept = a_rigid;
1843        if let (Some(_), Some(beta)) = (self.link_deviation_runtime.as_ref(), link_dev_beta) {
1844            warm_start_buf[0] = a_rigid;
1845            let one_pt = warm_start_buf.slice(ndarray::s![0..1]).to_owned();
1846            let (l_val, l_d1) =
1847                self.link_terms_value_d1(&one_pt, Some(beta), link_dev_correction_for_row)?;
1848            let ell1 = l_d1[0];
1849            if ell1 > 1e-8 {
1850                let ell0 = l_val[0] - ell1 * a_rigid;
1851                let observed_logslope = probit_scale * ell1 * slope;
1852                intercept = (marginal.q * (1.0 + observed_logslope * observed_logslope).sqrt()
1853                    / probit_scale
1854                    - ell0)
1855                    / ell1;
1856            }
1857        }
1858
1859        // Same adaptive tolerance the acceptance check below uses; passing
1860        // a tighter `convergence_tol` would just iterate past what we accept.
1861        let target = marginal.mu;
1862        let abs_tol = 1e-8_f64.max(1e-4 * target.abs());
1863
1864        let (root, _, f_best) = crate::families::monotone_root::solve_monotone_root(
1865            eval,
1866            intercept,
1867            "saved bernoulli intercept",
1868            abs_tol,
1869            64,
1870            48,
1871        )
1872        .map_err(EstimationError::InvalidInput)?;
1873
1874        if f_best.abs() > abs_tol {
1875            return Err(EstimationError::InvalidInput(format!(
1876                "saved bernoulli marginal-slope intercept solve failed: residual={f_best:.3e} at a={root:.6}, target mu={target:.6}"
1877            )));
1878        }
1879        Ok(root)
1880    }
1881
1882    fn final_eta_and_gradient_from_theta(
1883        &self,
1884        input: &PredictInput,
1885        theta: &Array1<f64>,
1886        need_gradient: bool,
1887    ) -> Result<(Array1<f64>, Option<Array2<f64>>), EstimationError> {
1888        let z_raw = input.auxiliary_scalar.as_ref().ok_or_else(|| {
1889            EstimationError::InvalidInput(format!(
1890                "bernoulli marginal-slope prediction requires auxiliary z column '{}'",
1891                self.z_column
1892            ))
1893        })?;
1894        let z_normalized = self
1895            .latent_z_normalization
1896            .apply(z_raw, "bernoulli marginal-slope prediction")
1897            .map_err(EstimationError::InvalidInput)?;
1898        // P4: when training applied a rank-INT calibration to the latent
1899        // z (so the BMS rigid kernel could use the closed-form
1900        // standard-normal path), the predictor MUST apply the same
1901        // monotone transform to predict-time z before any kernel
1902        // evaluation. The transform is mathematically exact: piecewise-
1903        // linear interpolation on (sorted_z, weighted_cdf) followed by
1904        // Φ⁻¹, both strictly monotone and invertible up to the empirical
1905        // CDF resolution. `None` ⇒ training-time z passed the strict
1906        // normality check, no transform was applied, leave z unchanged.
1907        let z = self.apply_latent_z_calibration(&z_normalized);
1908        let design_logslope = input.design_noise.as_ref().ok_or_else(|| {
1909            EstimationError::InvalidInput(
1910                "bernoulli marginal-slope prediction requires logslope design".to_string(),
1911            )
1912        })?;
1913        let (beta_marginal, beta_logslope, beta_score_warp, beta_link_dev) =
1914            self.split_theta(theta)?;
1915        if self.score_warp_runtime.is_some() != beta_score_warp.is_some() {
1916            return Err(EstimationError::InvalidInput(
1917                "bernoulli marginal-slope saved score-warp runtime/coefficients are inconsistent"
1918                    .to_string(),
1919            ));
1920        }
1921        if self.link_deviation_runtime.is_some() != beta_link_dev.is_some() {
1922            return Err(EstimationError::InvalidInput(
1923                "bernoulli marginal-slope saved link-deviation runtime/coefficients are inconsistent"
1924                    .to_string(),
1925            ));
1926        }
1927        let n = z.len();
1928        if input.offset.len() != n {
1929            return Err(EstimationError::InvalidInput(format!(
1930                "bernoulli marginal-slope prediction primary offset length mismatch: rows={n}, offset={}",
1931                input.offset.len()
1932            )));
1933        }
1934        let logslope_offset = input
1935            .offset_noise
1936            .as_ref()
1937            .map_or_else(|| Array1::zeros(n), Clone::clone);
1938        if logslope_offset.len() != n {
1939            return Err(EstimationError::InvalidInput(format!(
1940                "bernoulli marginal-slope prediction logslope offset length mismatch: rows={n}, offset_noise={}",
1941                logslope_offset.len()
1942            )));
1943        }
1944        let marginal_eta = input
1945            .design
1946            .dot(&beta_marginal.to_owned())
1947            .mapv(|v| v + self.baseline_marginal)
1948            + &input.offset;
1949        let logslope_eta = design_logslope
1950            .dot(&beta_logslope.to_owned())
1951            .mapv(|v| v + self.baseline_logslope)
1952            + &logslope_offset;
1953        let flex_active =
1954            self.score_warp_runtime.is_some() || self.link_deviation_runtime.is_some();
1955        let marginal_dim = self.beta_marginal.len();
1956        let logslope_dim = self.beta_logslope.len();
1957        let score_warp_dim = self.beta_score_warp.as_ref().map_or(0, Array1::len);
1958        let link_dev_dim = self.beta_link_dev.as_ref().map_or(0, Array1::len);
1959        let logslope_offset = marginal_dim;
1960        let score_warp_offset = logslope_offset + logslope_dim;
1961        let link_dev_offset = score_warp_offset + score_warp_dim;
1962        let chunk_size = prediction_chunk_rows(theta.len(), 1, n);
1963        let num_chunks = (n + chunk_size - 1) / chunk_size;
1964        let scale = self.probit_frailty_scale();
1965        // Cross-block anchor corrections: when either runtime carries an
1966        // anchor residual, precompute the per-row correction matrices
1967        // (n × runtime_basis_dim) once. Each subsequent per-row evaluation
1968        // subtracts the corresponding row of these matrices from the raw
1969        // cubic-span basis output. When neither runtime has a residual,
1970        // the returned bundle is empty and threading is a no-op.
1971        let anchor_corrections = self.build_anchor_correction_matrices(input, design_logslope)?;
1972        let marginal_map = marginal_eta
1973            .iter()
1974            .map(|&eta| {
1975                bernoulli_marginal_link_map(&self.base_link, eta)
1976                    .map_err(EstimationError::InvalidInput)
1977            })
1978            .collect::<Result<Vec<_>, _>>()?;
1979
1980        if !flex_active {
1981            let (final_eta_internal, marginal_scales, logslope_scales) = match &self.latent_measure
1982            {
1983                LatentMeasureKind::StandardNormal => {
1984                    let sb_vec = logslope_eta.mapv(|b| scale * b);
1985                    let c_vec = sb_vec.mapv(|sb| (1.0 + sb * sb).sqrt());
1986                    let final_eta_internal = Array1::from_iter(
1987                        (0..n).map(|i| c_vec[i] * marginal_eta[i] + sb_vec[i] * z[i]),
1988                    );
1989                    let marginal_scales = c_vec;
1990                    let logslope_scales = Array1::from_iter((0..n).map(|i| {
1991                        marginal_eta[i] * (scale * scale) * logslope_eta[i] / marginal_scales[i]
1992                            + scale * z[i]
1993                    }));
1994                    (final_eta_internal, marginal_scales, logslope_scales)
1995                }
1996                LatentMeasureKind::GlobalEmpirical { nodes, weights } => {
1997                    let mut final_eta = Array1::<f64>::zeros(n);
1998                    let mut marginal_scales = Array1::<f64>::zeros(n);
1999                    let mut logslope_scales = Array1::<f64>::zeros(n);
2000                    for i in 0..n {
2001                        let (intercept, a_marginal, a_slope) = self
2002                            .empirical_rigid_intercept_and_gradient(
2003                                marginal_eta[i],
2004                                logslope_eta[i],
2005                                nodes,
2006                                weights,
2007                            )?;
2008                        final_eta[i] = intercept + scale * logslope_eta[i] * z[i];
2009                        marginal_scales[i] = a_marginal;
2010                        logslope_scales[i] = a_slope + scale * z[i];
2011                    }
2012                    (final_eta, marginal_scales, logslope_scales)
2013                }
2014                LatentMeasureKind::LocalEmpirical { .. } => {
2015                    let mut final_eta = Array1::<f64>::zeros(n);
2016                    let mut marginal_scales = Array1::<f64>::zeros(n);
2017                    let mut logslope_scales = Array1::<f64>::zeros(n);
2018                    for i in 0..n {
2019                        let grid = self
2020                            .empirical_grid_for_prediction_row(input, i)?
2021                            .ok_or_else(|| {
2022                                EstimationError::InvalidInput(
2023                                    "local empirical latent prediction did not produce a row grid"
2024                                        .to_string(),
2025                                )
2026                            })?;
2027                        let (intercept, a_marginal, a_slope) = self
2028                            .empirical_rigid_intercept_and_gradient(
2029                                marginal_eta[i],
2030                                logslope_eta[i],
2031                                &grid.nodes,
2032                                &grid.weights,
2033                            )?;
2034                        final_eta[i] = intercept + scale * logslope_eta[i] * z[i];
2035                        marginal_scales[i] = a_marginal;
2036                        logslope_scales[i] = a_slope + scale * z[i];
2037                    }
2038                    (final_eta, marginal_scales, logslope_scales)
2039                }
2040            };
2041
2042            if !need_gradient {
2043                return self.transform_internal_eta_to_base_scale(final_eta_internal, None);
2044            }
2045
2046            // Chunk Jacobian: one pass per row fills both blocks.
2047            let mut grad_internal = Array2::<f64>::zeros((n, theta.len()));
2048            let mut start = 0usize;
2049            while start < n {
2050                let end = (start + chunk_size).min(n);
2051                let mc = input
2052                    .design
2053                    .try_row_chunk(start..end)
2054                    .map_err(|e| EstimationError::InvalidInput(e.to_string()))?;
2055                let lc = design_logslope
2056                    .try_row_chunk(start..end)
2057                    .map_err(|e| EstimationError::InvalidInput(e.to_string()))?;
2058
2059                for li in 0..(end - start) {
2060                    let i = start + li;
2061                    let c = marginal_scales[i];
2062                    let g_scale = logslope_scales[i];
2063                    let mut row = grad_internal.row_mut(i);
2064                    for j in 0..marginal_dim {
2065                        row[j] = c * mc[[li, j]];
2066                    }
2067                    for j in 0..logslope_dim {
2068                        row[logslope_offset + j] = g_scale * lc[[li, j]];
2069                    }
2070                }
2071
2072                start = end;
2073            }
2074            return self
2075                .transform_internal_eta_to_base_scale(final_eta_internal, Some(grad_internal));
2076        }
2077
2078        // ── Flexible path: per-row intercept solve, chunked Jacobians ──
2079        let score_warp_obs_design = self
2080            .score_warp_runtime
2081            .as_ref()
2082            .map(|runtime| {
2083                if runtime.anchor_residual_coefficients.is_some() {
2084                    let anchor_rows = anchor_corrections.n_anchor_rows_view().ok_or_else(|| {
2085                        EstimationError::InvalidInput(
2086                            "bernoulli marginal-slope score-warp anchor residual present but \
2087                             anchor_corrections bundle is missing the parametric anchor rows"
2088                                .to_string(),
2089                        )
2090                    })?;
2091                    runtime
2092                        .design_with_anchor_rows(&z, anchor_rows)
2093                        .map_err(EstimationError::InvalidInput)
2094                } else {
2095                    runtime.design(&z).map_err(EstimationError::InvalidInput)
2096                }
2097            })
2098            .transpose()?;
2099        let score_dev_obs = if let (Some(design), Some(beta)) =
2100            (score_warp_obs_design.as_ref(), beta_score_warp.clone())
2101        {
2102            design.dot(&beta.to_owned())
2103        } else {
2104            Array1::zeros(n)
2105        };
2106
2107        // Solve intercepts and (when gradient needed) IFT scalars in chunk-parallel passes.
2108        let score_warp_beta_owned = beta_score_warp.as_ref().map(|v| v.to_owned());
2109        let link_dev_beta_owned = beta_link_dev.as_ref().map(|v| v.to_owned());
2110        struct FlexSolveChunk {
2111            start: usize,
2112            end: usize,
2113            intercepts: Array1<f64>,
2114            a_q: Option<Array1<f64>>,
2115            a_b: Option<Array1<f64>>,
2116            a_h: Option<Array2<f64>>,
2117            a_w: Option<Array2<f64>>,
2118        }
2119        let solve_chunks = (0..num_chunks)
2120            .into_par_iter()
2121            .map(|chunk_idx| -> Result<FlexSolveChunk, EstimationError> {
2122                let start = chunk_idx * chunk_size;
2123                let end = (start + chunk_size).min(n);
2124                let rows = end - start;
2125                let mut intercepts = Array1::<f64>::zeros(rows);
2126                let mut a_q = need_gradient.then(|| Array1::<f64>::zeros(rows));
2127                let mut a_b = need_gradient.then(|| Array1::<f64>::zeros(rows));
2128                let mut a_h = if need_gradient && score_warp_dim > 0 {
2129                    Some(Array2::<f64>::zeros((rows, score_warp_dim)))
2130                } else {
2131                    None
2132                };
2133                let mut a_w = if need_gradient && link_dev_dim > 0 {
2134                    Some(Array2::<f64>::zeros((rows, link_dev_dim)))
2135                } else {
2136                    None
2137                };
2138                let mut warm_start_buf = Array1::<f64>::zeros(1);
2139
2140                for local_row in 0..rows {
2141                    let i = start + local_row;
2142                    let slope = logslope_eta[i];
2143                    let q = marginal_eta[i];
2144                    let empirical_grid = self.empirical_grid_for_prediction_row(input, i)?;
2145                    let score_corr_row = anchor_corrections.score_warp_row(i);
2146                    let link_corr_row = anchor_corrections.link_dev_row(i);
2147                    intercepts[local_row] = self.solve_intercept_scalar(
2148                        q,
2149                        slope,
2150                        link_dev_beta_owned.as_ref(),
2151                        score_warp_beta_owned.as_ref(),
2152                        empirical_grid.as_ref(),
2153                        &mut warm_start_buf,
2154                        score_corr_row,
2155                        link_corr_row,
2156                    )?;
2157
2158                    if !need_gradient {
2159                        continue;
2160                    }
2161
2162                    let intercept = intercepts[local_row];
2163                    let (_, m_a_raw, _) = self.evaluate_prediction_calibration(
2164                        intercept,
2165                        q,
2166                        slope,
2167                        score_warp_beta_owned.as_ref(),
2168                        link_dev_beta_owned.as_ref(),
2169                        empirical_grid.as_ref(),
2170                        score_corr_row,
2171                        link_corr_row,
2172                    )?;
2173                    let m_a = m_a_raw.max(1e-12);
2174                    a_q.as_mut().unwrap()[local_row] = marginal_map[i].mu1 / m_a;
2175                    let mut f_b = 0.0;
2176                    let mut f_h_row = vec![0.0; score_warp_dim];
2177                    let mut f_w_row = vec![0.0; link_dev_dim];
2178                    if let Some(grid) = empirical_grid.as_ref() {
2179                        for (&node, &weight) in grid.nodes.iter().zip(grid.weights.iter()) {
2180                            let obs = self.observed_denested_cell_partials_at_z(
2181                                node,
2182                                intercept,
2183                                slope,
2184                                score_warp_beta_owned.as_ref(),
2185                                link_dev_beta_owned.as_ref(),
2186                                score_corr_row,
2187                                link_corr_row,
2188                            )?;
2189                            let eta = eval_coeff4_at(&obs.coeff, node);
2190                            let pdf = normal_pdf(eta);
2191                            f_b += weight * pdf * eval_coeff4_at(&obs.dc_db, node);
2192
2193                            if let Some(runtime) = self.score_warp_runtime.as_ref() {
2194                                for j in 0..score_warp_dim {
2195                                    let mut basis_span = runtime
2196                                        .basis_cubic_at(j, node)
2197                                        .map_err(EstimationError::InvalidInput)?;
2198                                    // `basis_cubic_at` returns the j-th basis
2199                                    // function's local cubic; the residual
2200                                    // subtracts `correction[j]` from the
2201                                    // constant term (row-constant, basis-
2202                                    // function-specific). Higher span
2203                                    // coefficients are unaffected.
2204                                    if let Some(corr) = score_corr_row {
2205                                        basis_span.c0 -= corr[j];
2206                                    }
2207                                    let coeffs = crate::families::bernoulli_marginal_slope::exact_kernel::score_basis_cell_coefficients(
2208                                        basis_span,
2209                                        slope,
2210                                    );
2211                                    let coeffs = scale_coeff4(coeffs, scale);
2212                                    f_h_row[j] += weight * pdf * eval_coeff4_at(&coeffs, node);
2213                                }
2214                            }
2215
2216                            if let Some(runtime) = self.link_deviation_runtime.as_ref() {
2217                                for j in 0..link_dev_dim {
2218                                    let mut basis_span = runtime
2219                                        .basis_cubic_at(j, intercept + slope * node)
2220                                        .map_err(EstimationError::InvalidInput)?;
2221                                    if let Some(corr) = link_corr_row {
2222                                        basis_span.c0 -= corr[j];
2223                                    }
2224                                    let coeffs = crate::families::bernoulli_marginal_slope::exact_kernel::link_basis_cell_coefficients(
2225                                        basis_span,
2226                                        intercept,
2227                                        slope,
2228                                    );
2229                                    let coeffs = scale_coeff4(coeffs, scale);
2230                                    f_w_row[j] += weight * pdf * eval_coeff4_at(&coeffs, node);
2231                                }
2232                            }
2233                        }
2234                    } else {
2235                        let cells = self.denested_partition_cells(
2236                            intercept,
2237                            slope,
2238                            score_warp_beta_owned.as_ref(),
2239                            link_dev_beta_owned.as_ref(),
2240                            score_corr_row,
2241                            link_corr_row,
2242                        )?;
2243                        for partition_cell in cells {
2244                            let cell = partition_cell.cell;
2245                            let state =
2246                                crate::families::bernoulli_marginal_slope::exact_kernel::evaluate_cell_moments(
2247                                    cell, 9,
2248                                )
2249                                .map_err(EstimationError::InvalidInput)?;
2250                            let (_, dc_db_raw) = crate::families::bernoulli_marginal_slope::exact_kernel::denested_cell_coefficient_partials(
2251                                partition_cell.score_span,
2252                                partition_cell.link_span,
2253                                intercept,
2254                                slope,
2255                            );
2256                            // `denested_partition_cells` scales the cell itself for
2257                            // Gaussian frailty, so every coefficient partial of
2258                            // F(a, theta) must carry the same probit scale as F_a.
2259                            let dc_db = scale_coeff4(dc_db_raw, scale);
2260                            f_b += crate::families::bernoulli_marginal_slope::exact_kernel::cell_first_derivative_from_moments(
2261                                &dc_db,
2262                                &state.moments,
2263                            )
2264                            .map_err(EstimationError::InvalidInput)?;
2265
2266                            let mid = 0.5 * (cell.left + cell.right);
2267                            if let Some(runtime) = self.score_warp_runtime.as_ref() {
2268                                for j in 0..score_warp_dim {
2269                                    let mut basis_span = runtime
2270                                        .basis_cubic_at(j, mid)
2271                                        .map_err(EstimationError::InvalidInput)?;
2272                                    if let Some(corr) = score_corr_row {
2273                                        basis_span.c0 -= corr[j];
2274                                    }
2275                                    let coeffs = crate::families::bernoulli_marginal_slope::exact_kernel::score_basis_cell_coefficients(
2276                                        basis_span, slope,
2277                                    );
2278                                    let coeffs = scale_coeff4(coeffs, scale);
2279                                    f_h_row[j] += crate::families::bernoulli_marginal_slope::exact_kernel::cell_first_derivative_from_moments(
2280                                        &coeffs,
2281                                        &state.moments,
2282                                    )
2283                                    .map_err(EstimationError::InvalidInput)?;
2284                                }
2285                            }
2286
2287                            if let Some(runtime) = self.link_deviation_runtime.as_ref() {
2288                                for j in 0..link_dev_dim {
2289                                    let mut basis_span = runtime
2290                                        .basis_cubic_at(j, intercept + slope * mid)
2291                                        .map_err(EstimationError::InvalidInput)?;
2292                                    if let Some(corr) = link_corr_row {
2293                                        basis_span.c0 -= corr[j];
2294                                    }
2295                                    let coeffs = crate::families::bernoulli_marginal_slope::exact_kernel::link_basis_cell_coefficients(
2296                                        basis_span,
2297                                        intercept,
2298                                        slope,
2299                                    );
2300                                    let coeffs = scale_coeff4(coeffs, scale);
2301                                    f_w_row[j] += crate::families::bernoulli_marginal_slope::exact_kernel::cell_first_derivative_from_moments(
2302                                        &coeffs,
2303                                        &state.moments,
2304                                    )
2305                                    .map_err(EstimationError::InvalidInput)?;
2306                                }
2307                            }
2308                        }
2309                    }
2310                    if let Some(a_h) = a_h.as_mut() {
2311                        let factor = -1.0 / m_a;
2312                        for j in 0..score_warp_dim {
2313                            a_h[[local_row, j]] = factor * f_h_row[j];
2314                        }
2315                    }
2316                    if let Some(a_w) = a_w.as_mut() {
2317                        let factor = -1.0 / m_a;
2318                        for j in 0..link_dev_dim {
2319                            a_w[[local_row, j]] = factor * f_w_row[j];
2320                        }
2321                    }
2322                    a_b.as_mut().unwrap()[local_row] = -f_b / m_a;
2323                }
2324
2325                Ok(FlexSolveChunk {
2326                    start,
2327                    end,
2328                    intercepts,
2329                    a_q,
2330                    a_b,
2331                    a_h,
2332                    a_w,
2333                })
2334            })
2335            .collect::<Vec<_>>();
2336
2337        let mut intercepts = Array1::<f64>::zeros(n);
2338        let mut a_q_vec = need_gradient.then(|| Array1::<f64>::zeros(n));
2339        let mut a_b_vec = need_gradient.then(|| Array1::<f64>::zeros(n));
2340        let mut a_h_rows = if need_gradient && score_warp_dim > 0 {
2341            Some(Array2::<f64>::zeros((n, score_warp_dim)))
2342        } else {
2343            None
2344        };
2345        let mut a_w_rows = if need_gradient && link_dev_dim > 0 {
2346            Some(Array2::<f64>::zeros((n, link_dev_dim)))
2347        } else {
2348            None
2349        };
2350
2351        for solve_chunk in solve_chunks {
2352            let chunk = solve_chunk?;
2353            intercepts
2354                .slice_mut(ndarray::s![chunk.start..chunk.end])
2355                .assign(&chunk.intercepts);
2356            if let (Some(target), Some(source)) = (a_q_vec.as_mut(), chunk.a_q.as_ref()) {
2357                target
2358                    .slice_mut(ndarray::s![chunk.start..chunk.end])
2359                    .assign(source);
2360            }
2361            if let (Some(target), Some(source)) = (a_b_vec.as_mut(), chunk.a_b.as_ref()) {
2362                target
2363                    .slice_mut(ndarray::s![chunk.start..chunk.end])
2364                    .assign(source);
2365            }
2366            if let (Some(target), Some(source)) = (a_h_rows.as_mut(), chunk.a_h.as_ref()) {
2367                target
2368                    .slice_mut(ndarray::s![chunk.start..chunk.end, ..])
2369                    .assign(source);
2370            }
2371            if let (Some(target), Some(source)) = (a_w_rows.as_mut(), chunk.a_w.as_ref()) {
2372                target
2373                    .slice_mut(ndarray::s![chunk.start..chunk.end, ..])
2374                    .assign(source);
2375            }
2376        }
2377
2378        let eta_base = &intercepts + &(&logslope_eta * &z);
2379
2380        let mut link_c_obs: Option<Array1<f64>> = None;
2381        let mut link_basis_obs: Option<Array2<f64>> = None;
2382        let link_dev_obs = if let (Some(runtime), Some(beta_owned)) = (
2383            self.link_deviation_runtime.as_ref(),
2384            link_dev_beta_owned.as_ref(),
2385        ) {
2386            let basis = if runtime.anchor_residual_coefficients.is_some() {
2387                let anchor_rows = anchor_corrections.n_anchor_rows_view().ok_or_else(|| {
2388                    EstimationError::InvalidInput(
2389                        "bernoulli marginal-slope link-deviation anchor residual present but \
2390                         anchor_corrections bundle is missing the parametric anchor rows"
2391                            .to_string(),
2392                    )
2393                })?;
2394                runtime
2395                    .design_with_anchor_rows(&eta_base, anchor_rows)
2396                    .map_err(EstimationError::InvalidInput)?
2397            } else {
2398                runtime
2399                    .design(&eta_base)
2400                    .map_err(EstimationError::InvalidInput)?
2401            };
2402            let dev = basis.dot(beta_owned);
2403            if need_gradient {
2404                let d1 = runtime
2405                    .first_derivative_design(&eta_base)
2406                    .map_err(EstimationError::InvalidInput)?;
2407                let mut c_obs = d1.dot(beta_owned);
2408                c_obs.mapv_inplace(|v| v + 1.0);
2409                link_c_obs = Some(c_obs);
2410                link_basis_obs = Some(basis);
2411            }
2412            dev
2413        } else {
2414            Array1::zeros(n)
2415        };
2416        let final_eta_internal =
2417            (&eta_base + &(&logslope_eta * &score_dev_obs) + &link_dev_obs).mapv(|v| scale * v);
2418
2419        if !need_gradient {
2420            return self.transform_internal_eta_to_base_scale(final_eta_internal, None);
2421        }
2422
2423        let a_q_vec = a_q_vec.unwrap();
2424        let a_b_vec = a_b_vec.unwrap();
2425
2426        // Emit chunk Jacobians using precomputed scalars.
2427        struct FlexGradientChunk {
2428            start: usize,
2429            end: usize,
2430            grad: Array2<f64>,
2431        }
2432        let grad_chunks = (0..num_chunks)
2433            .into_par_iter()
2434            .map(|chunk_idx| -> Result<FlexGradientChunk, String> {
2435                let start = chunk_idx * chunk_size;
2436                let end = (start + chunk_size).min(n);
2437                let mc = input
2438                    .design
2439                    .try_row_chunk(start..end)
2440                    .map_err(|e| e.to_string())?;
2441                let lc = design_logslope
2442                    .try_row_chunk(start..end)
2443                    .map_err(|e| e.to_string())?;
2444                let rows = end - start;
2445                let mut grad = Array2::<f64>::zeros((rows, theta.len()));
2446
2447                for li in 0..rows {
2448                    let i = start + li;
2449                    let mut row = grad.row_mut(li);
2450
2451                    let a_q = a_q_vec[i];
2452                    for j in 0..marginal_dim {
2453                        row[j] = a_q * mc[[li, j]];
2454                    }
2455
2456                    let base_multiplier = link_c_obs.as_ref().map_or(1.0, |c| c[i]);
2457                    let g_scale = base_multiplier * (a_b_vec[i] + z[i]) + score_dev_obs[i];
2458                    for j in 0..logslope_dim {
2459                        row[logslope_offset + j] = g_scale * lc[[li, j]];
2460                    }
2461
2462                    if let (Some(a_h_rows), Some(obs_design)) =
2463                        (a_h_rows.as_ref(), score_warp_obs_design.as_ref())
2464                    {
2465                        let slope = logslope_eta[i];
2466                        for j in 0..score_warp_dim {
2467                            row[score_warp_offset + j] =
2468                                base_multiplier * a_h_rows[[i, j]] + slope * obs_design[[i, j]];
2469                        }
2470                    }
2471
2472                    if let Some(a_w_rows) = a_w_rows.as_ref() {
2473                        for j in 0..link_dev_dim {
2474                            row[link_dev_offset + j] = a_w_rows[[i, j]];
2475                        }
2476                    }
2477
2478                    if let (Some(link_c), Some(link_basis)) =
2479                        (link_c_obs.as_ref(), link_basis_obs.as_ref())
2480                    {
2481                        let c = link_c[i];
2482                        for j in 0..marginal_dim {
2483                            row[j] *= c;
2484                        }
2485                        for j in 0..link_dev_dim {
2486                            row[link_dev_offset + j] =
2487                                c * row[link_dev_offset + j] + link_basis[[i, j]];
2488                        }
2489                    }
2490                }
2491
2492                Ok(FlexGradientChunk { start, end, grad })
2493            })
2494            .collect::<Result<Vec<_>, String>>()
2495            .map_err(EstimationError::InvalidInput)?;
2496        let mut grad = Array2::<f64>::zeros((n, theta.len()));
2497        for chunk in grad_chunks {
2498            grad.slice_mut(ndarray::s![chunk.start..chunk.end, ..])
2499                .assign(&chunk.grad);
2500        }
2501        if scale != 1.0 {
2502            grad.mapv_inplace(|v| scale * v);
2503        }
2504        self.transform_internal_eta_to_base_scale(final_eta_internal, Some(grad))
2505    }
2506
2507    fn final_eta_from_theta(
2508        &self,
2509        input: &PredictInput,
2510        theta: &Array1<f64>,
2511    ) -> Result<Array1<f64>, EstimationError> {
2512        let (eta, _) = self.final_eta_and_gradient_from_theta(input, theta, false)?;
2513        Ok(eta)
2514    }
2515
2516    fn eta_standard_error_from_covariance(
2517        &self,
2518        input: &PredictInput,
2519        covariance: &Array2<f64>,
2520    ) -> Result<Array1<f64>, EstimationError> {
2521        let theta = self.theta();
2522        let backend = PredictionCovarianceBackend::from_dense(covariance.view());
2523        linear_predictor_se_from_backend(&backend, input.design.nrows(), |rows| {
2524            let chunk_input = slice_predict_input(input, rows).map_err(|e| e.to_string())?;
2525            let (_, grad) = self
2526                .final_eta_and_gradient_from_theta(&chunk_input, &theta, true)
2527                .map_err(|e| e.to_string())?;
2528            let grad = grad.ok_or_else(|| {
2529                "bernoulli marginal-slope analytic predictor gradient was not produced".to_string()
2530            })?;
2531            Ok(vec![grad])
2532        })
2533    }
2534
2535    fn eta_standard_error(
2536        &self,
2537        input: &PredictInput,
2538        fit: &UnifiedFitResult,
2539    ) -> Result<Array1<f64>, EstimationError> {
2540        let theta = self.theta();
2541        let backend = require_posterior_mean_backend(
2542            fit,
2543            self.covariance.as_ref(),
2544            theta.len(),
2545            "bernoulli marginal-slope posterior mean",
2546        )?;
2547        linear_predictor_se_from_backend(&backend, input.design.nrows(), |rows| {
2548            let chunk_input = slice_predict_input(input, rows).map_err(|e| e.to_string())?;
2549            let (_, grad) = self
2550                .final_eta_and_gradient_from_theta(&chunk_input, &theta, true)
2551                .map_err(|e| e.to_string())?;
2552            let grad = grad.ok_or_else(|| {
2553                "bernoulli marginal-slope analytic predictor gradient was not produced".to_string()
2554            })?;
2555            Ok(vec![grad])
2556        })
2557    }
2558
2559    /// Per-row `(eta, ∂eta/∂q_marginal)` under the exact IFT pull-back.
2560    ///
2561    /// Returns the same `eta` as `predict_plugin_response`/`predict_linear_predictor`
2562    /// plus the analytic derivative of the internal probit index with respect to
2563    /// the per-row marginal q (the linear predictor before the de-nested
2564    /// calibration). Survival prediction multiplies the second component by the
2565    /// per-row `dq/dt` to obtain the exact hazard time derivative under
2566    /// score-warp / link-deviation flex blocks.
2567    ///
2568    /// Rigid path (no flex blocks): `∂eta/∂q = c = sqrt(1 + (s b)^2)`, recovering
2569    /// the rigid-path probit-frailty composition. Flex path: `∂eta/∂q =
2570    /// scale · link_c_obs · a_q` where `link_c_obs = 1 + Δ_w'(eta_base)` is the
2571    /// link-deviation slope at the observed `eta_base = a + b z` and `a_q =
2572    /// φ(q) / |F_a|` is the implicit-function derivative of the calibration
2573    /// intercept (mirrors the bernoulli `final_eta_and_gradient_from_theta`
2574    /// flex branch lines 1399-1593).
2575    pub fn predict_eta_and_q_chain(
2576        &self,
2577        input: &PredictInput,
2578    ) -> Result<(Array1<f64>, Array1<f64>), EstimationError> {
2579        let z_raw = input.auxiliary_scalar.as_ref().ok_or_else(|| {
2580            EstimationError::InvalidInput(format!(
2581                "bernoulli marginal-slope prediction requires auxiliary z column '{}'",
2582                self.z_column
2583            ))
2584        })?;
2585        let z_normalized = self
2586            .latent_z_normalization
2587            .apply(z_raw, "bernoulli marginal-slope prediction")
2588            .map_err(EstimationError::InvalidInput)?;
2589        // P4: see `final_eta_and_gradient_from_theta` for the rationale.
2590        // The rank-INT calibration is a mathematically exact monotone
2591        // transform; both the rigid standard-normal kernel and the
2592        // implicit-function chain rule consume the calibrated z, never
2593        // the raw normalized z, exactly mirroring fit-time semantics.
2594        let z = self.apply_latent_z_calibration(&z_normalized);
2595        let design_logslope = input.design_noise.as_ref().ok_or_else(|| {
2596            EstimationError::InvalidInput(
2597                "bernoulli marginal-slope prediction requires logslope design".to_string(),
2598            )
2599        })?;
2600        let n = z.len();
2601        if input.offset.len() != n {
2602            return Err(EstimationError::InvalidInput(format!(
2603                "bernoulli marginal-slope prediction primary offset length mismatch: rows={n}, offset={}",
2604                input.offset.len()
2605            )));
2606        }
2607        let logslope_offset = input
2608            .offset_noise
2609            .as_ref()
2610            .map_or_else(|| Array1::zeros(n), Clone::clone);
2611        if logslope_offset.len() != n {
2612            return Err(EstimationError::InvalidInput(format!(
2613                "bernoulli marginal-slope prediction logslope offset length mismatch: rows={n}, offset_noise={}",
2614                logslope_offset.len()
2615            )));
2616        }
2617        let marginal_eta = input
2618            .design
2619            .dot(&self.beta_marginal)
2620            .mapv(|v| v + self.baseline_marginal)
2621            + &input.offset;
2622        let logslope_eta = design_logslope
2623            .dot(&self.beta_logslope)
2624            .mapv(|v| v + self.baseline_logslope)
2625            + &logslope_offset;
2626        let scale = self.probit_frailty_scale();
2627        let flex_active =
2628            self.score_warp_runtime.is_some() || self.link_deviation_runtime.is_some();
2629
2630        // Rigid path mirrors `final_eta_and_gradient_from_theta` lines 1342-1383:
2631        //   eta = c·q + s·b·z,  ∂eta/∂q = c.
2632        if !flex_active {
2633            match &self.latent_measure {
2634                LatentMeasureKind::StandardNormal => {
2635                    // Vectorize: sb = scale·logslope, c = sqrt(1 + sb²),
2636                    // eta = c·marginal_eta + sb·z, ∂eta/∂q = c.
2637                    let sb = logslope_eta.mapv(|x| scale * x);
2638                    let deta_dq = sb.mapv(|s| (1.0 + s * s).sqrt());
2639                    let eta = &deta_dq * marginal_eta + &sb * z;
2640                    return Ok((eta, deta_dq));
2641                }
2642                _ => {
2643                    let mut eta = Array1::<f64>::zeros(n);
2644                    let mut deta_dq = Array1::<f64>::zeros(n);
2645                    for i in 0..n {
2646                        let grid = self
2647                            .empirical_grid_for_prediction_row(input, i)?
2648                            .ok_or_else(|| {
2649                                EstimationError::InvalidInput(
2650                                    "empirical latent prediction did not produce a row grid"
2651                                        .to_string(),
2652                                )
2653                            })?;
2654                        let (intercept, a_marginal, _) = self
2655                            .empirical_rigid_intercept_and_gradient(
2656                                marginal_eta[i],
2657                                logslope_eta[i],
2658                                &grid.nodes,
2659                                &grid.weights,
2660                            )?;
2661                        eta[i] = intercept + scale * logslope_eta[i] * z[i];
2662                        deta_dq[i] = a_marginal;
2663                    }
2664                    return Ok((eta, deta_dq));
2665                }
2666            }
2667        }
2668
2669        // Flex path: solve the per-row intercept, then evaluate
2670        //   eta = scale · (a + b·z + b·Δ_h(z) + Δ_w(a + b·z))
2671        //   ∂eta/∂q = scale · (1 + Δ_w'(a + b·z)) · ∂a/∂q,
2672        //   ∂a/∂q   = φ(q) / |F_a|         (IFT, marginal_link is probit so mu1 = φ(q))
2673        // Mirrors `final_eta_and_gradient_from_theta` lines 1385-1621.
2674        let marginal_map = marginal_eta
2675            .iter()
2676            .map(|&eta_marg| {
2677                bernoulli_marginal_link_map(&self.base_link, eta_marg)
2678                    .map_err(EstimationError::InvalidInput)
2679            })
2680            .collect::<Result<Vec<_>, _>>()?;
2681        // Cross-block anchor corrections (see final_eta_and_gradient_from_theta
2682        // for the design); precompute once before the per-row loop.
2683        let anchor_corrections = self.build_anchor_correction_matrices(input, design_logslope)?;
2684        // Per-row: solve intercept scalar, evaluate denested calibration,
2685        // record (intercept, a_q). The `warm_start_buf` is just per-call
2686        // scratch — give each rayon worker its own buffer via fold init.
2687        use rayon::iter::{IntoParallelIterator, ParallelIterator};
2688        let pairs: Result<Vec<(f64, f64)>, EstimationError> = (0..n)
2689            .into_par_iter()
2690            .map_init(
2691                || Array1::<f64>::zeros(1),
2692                |warm_start_buf, i| {
2693                    let q = marginal_eta[i];
2694                    let slope = logslope_eta[i];
2695                    let empirical_grid = self.empirical_grid_for_prediction_row(input, i)?;
2696                    let score_corr_row = anchor_corrections.score_warp_row(i);
2697                    let link_corr_row = anchor_corrections.link_dev_row(i);
2698                    let intercept = self.solve_intercept_scalar(
2699                        q,
2700                        slope,
2701                        self.beta_link_dev.as_ref(),
2702                        self.beta_score_warp.as_ref(),
2703                        empirical_grid.as_ref(),
2704                        warm_start_buf,
2705                        score_corr_row,
2706                        link_corr_row,
2707                    )?;
2708                    let (_, m_a_raw, _) = self.evaluate_prediction_calibration(
2709                        intercept,
2710                        q,
2711                        slope,
2712                        self.beta_score_warp.as_ref(),
2713                        self.beta_link_dev.as_ref(),
2714                        empirical_grid.as_ref(),
2715                        score_corr_row,
2716                        link_corr_row,
2717                    )?;
2718                    let m_a = m_a_raw.max(1e-12);
2719                    Ok((intercept, marginal_map[i].mu1 / m_a))
2720                },
2721            )
2722            .collect();
2723        let pairs = pairs?;
2724        let mut intercepts = Array1::<f64>::zeros(n);
2725        let mut a_q = Array1::<f64>::zeros(n);
2726        for (i, (intercept, a)) in pairs.into_iter().enumerate() {
2727            intercepts[i] = intercept;
2728            a_q[i] = a;
2729        }
2730
2731        let score_dev_obs = if let (Some(runtime), Some(beta)) = (
2732            self.score_warp_runtime.as_ref(),
2733            self.beta_score_warp.as_ref(),
2734        ) {
2735            let design = if runtime.anchor_residual_coefficients.is_some() {
2736                let anchor_rows = anchor_corrections.n_anchor_rows_view().ok_or_else(|| {
2737                    EstimationError::InvalidInput(
2738                        "bernoulli marginal-slope score-warp anchor residual present but \
2739                         anchor_corrections bundle is missing the parametric anchor rows"
2740                            .to_string(),
2741                    )
2742                })?;
2743                runtime
2744                    .design_with_anchor_rows(&z, anchor_rows)
2745                    .map_err(EstimationError::InvalidInput)?
2746            } else {
2747                runtime.design(&z).map_err(EstimationError::InvalidInput)?
2748            };
2749            design.dot(beta)
2750        } else {
2751            Array1::zeros(n)
2752        };
2753        let eta_base = &intercepts + &(&logslope_eta * &z);
2754        let (link_dev_obs, link_c_obs) = if let (Some(runtime), Some(beta)) = (
2755            self.link_deviation_runtime.as_ref(),
2756            self.beta_link_dev.as_ref(),
2757        ) {
2758            let basis = if runtime.anchor_residual_coefficients.is_some() {
2759                let anchor_rows = anchor_corrections.n_anchor_rows_view().ok_or_else(|| {
2760                    EstimationError::InvalidInput(
2761                        "bernoulli marginal-slope link-deviation anchor residual present but \
2762                         anchor_corrections bundle is missing the parametric anchor rows"
2763                            .to_string(),
2764                    )
2765                })?;
2766                runtime
2767                    .design_with_anchor_rows(&eta_base, anchor_rows)
2768                    .map_err(EstimationError::InvalidInput)?
2769            } else {
2770                runtime
2771                    .design(&eta_base)
2772                    .map_err(EstimationError::InvalidInput)?
2773            };
2774            let dev = basis.dot(beta);
2775            let d1 = runtime
2776                .first_derivative_design(&eta_base)
2777                .map_err(EstimationError::InvalidInput)?;
2778            let mut c_obs = d1.dot(beta);
2779            c_obs.mapv_inplace(|v| v + 1.0);
2780            (dev, c_obs)
2781        } else {
2782            (Array1::zeros(n), Array1::ones(n))
2783        };
2784        let final_eta_internal =
2785            (&eta_base + &(&logslope_eta * &score_dev_obs) + &link_dev_obs).mapv(|v| scale * v);
2786        let deta_dq = (&link_c_obs * &a_q).mapv(|v| scale * v);
2787        Ok((final_eta_internal, deta_dq))
2788    }
2789}
2790
2791impl PredictableModel for BernoulliMarginalSlopePredictor {
2792    fn predict_plugin_response(
2793        &self,
2794        input: &PredictInput,
2795    ) -> Result<PredictResult, EstimationError> {
2796        let eta = self.final_eta_from_theta(input, &self.theta())?;
2797        let mean = self.mean_from_eta(&eta)?;
2798        Ok(PredictResult { eta, mean })
2799    }
2800
2801    fn predict_with_uncertainty(
2802        &self,
2803        input: &PredictInput,
2804    ) -> Result<PredictionWithSE, EstimationError> {
2805        let plugin = self.predict_plugin_response(input)?;
2806        let (eta_se, mean_se) = if let Some(covariance) = self.covariance.as_ref() {
2807            let theta = self.theta();
2808            if covariance.nrows() != theta.len() || covariance.ncols() != theta.len() {
2809                return Err(EstimationError::InvalidInput(format!(
2810                    "bernoulli marginal-slope covariance dimension mismatch: expected {}x{}, got {}x{}",
2811                    theta.len(),
2812                    theta.len(),
2813                    covariance.nrows(),
2814                    covariance.ncols()
2815                )));
2816            }
2817            let eta_se = self.eta_standard_error_from_covariance(input, covariance)?;
2818            let mean_se = eta_se.clone() * self.mean_derivative_from_eta(&plugin.eta)?;
2819            (Some(eta_se), Some(mean_se))
2820        } else {
2821            (None, None)
2822        };
2823        Ok(PredictionWithSE {
2824            eta: plugin.eta,
2825            mean: plugin.mean,
2826            eta_se,
2827            mean_se,
2828        })
2829    }
2830
2831    fn predict_noise_scale(
2832        &self,
2833        _: &PredictInput,
2834    ) -> Result<Option<Array1<f64>>, EstimationError> {
2835        Ok(None)
2836    }
2837
2838    fn predict_full_uncertainty(
2839        &self,
2840        input: &PredictInput,
2841        fit: &UnifiedFitResult,
2842        options: &PredictUncertaintyOptions,
2843    ) -> Result<PredictUncertaintyResult, EstimationError> {
2844        let plugin = self.predict_plugin_response(input)?;
2845        let eta_se = self.eta_standard_error(input, fit)?;
2846        let zcrit = standard_normal_quantile(0.5 + options.confidence_level * 0.5)
2847            .map_err(EstimationError::InvalidInput)?;
2848        let eta_lower = &plugin.eta - &eta_se.mapv(|s| zcrit * s);
2849        let eta_upper = &plugin.eta + &eta_se.mapv(|s| zcrit * s);
2850        let mean_lower = self.mean_from_eta(&eta_lower)?;
2851        let mean_upper = self.mean_from_eta(&eta_upper)?;
2852        let mean_se = eta_se.clone() * self.mean_derivative_from_eta(&plugin.eta)?;
2853        Ok(PredictUncertaintyResult {
2854            eta: plugin.eta,
2855            mean: plugin.mean,
2856            eta_standard_error: eta_se.clone(),
2857            mean_standard_error: mean_se,
2858            eta_lower,
2859            eta_upper,
2860            mean_lower,
2861            mean_upper,
2862            observation_lower: None,
2863            observation_upper: None,
2864            covariance_mode_requested: options.covariance_mode,
2865            covariance_corrected_used: false,
2866        })
2867    }
2868
2869    fn predict_posterior_mean(
2870        &self,
2871        input: &PredictInput,
2872        fit: &UnifiedFitResult,
2873        confidence_level: Option<f64>,
2874    ) -> Result<PredictPosteriorMeanResult, EstimationError> {
2875        let plugin = self.predict_plugin_response(input)?;
2876        let eta_se = self.eta_standard_error(input, fit)?;
2877        let strategy = strategy_for_family(self.likelihood_family(), Some(&self.base_link));
2878        let quadctx = crate::quadrature::QuadratureContext::new();
2879        let mean = Array1::from_iter(
2880            plugin
2881                .eta
2882                .iter()
2883                .zip(eta_se.iter())
2884                .map(|(&eta, &se)| strategy.posterior_mean(&quadctx, eta, se))
2885                .collect::<Result<Vec<_>, _>>()?,
2886        );
2887        let (mean_lower, mean_upper) = if let Some(level) = confidence_level {
2888            let z = standard_normal_quantile(0.5 + 0.5 * level)
2889                .map_err(EstimationError::InvalidInput)?;
2890            let eta_lower = &plugin.eta - &eta_se.mapv(|s| z * s);
2891            let eta_upper = &plugin.eta + &eta_se.mapv(|s| z * s);
2892            (
2893                Some(self.mean_from_eta(&eta_lower)?),
2894                Some(self.mean_from_eta(&eta_upper)?),
2895            )
2896        } else {
2897            (None, None)
2898        };
2899        Ok(PredictPosteriorMeanResult {
2900            eta: plugin.eta,
2901            eta_standard_error: eta_se,
2902            mean,
2903            mean_lower,
2904            mean_upper,
2905        })
2906    }
2907
2908    fn n_blocks(&self) -> usize {
2909        2 + usize::from(self.beta_score_warp.is_some()) + usize::from(self.beta_link_dev.is_some())
2910    }
2911
2912    fn block_roles(&self) -> Vec<BlockRole> {
2913        let mut roles = vec![BlockRole::Location, BlockRole::Scale];
2914        if self.beta_score_warp.is_some() {
2915            roles.push(BlockRole::Mean);
2916        }
2917        if self.beta_link_dev.is_some() {
2918            roles.push(BlockRole::LinkWiggle);
2919        }
2920        roles
2921    }
2922}
2923
2924/// Gaussian location-scale predictor: two blocks (mean + log-sigma).
2925///
2926/// Predicts `mean = X_mu @ beta_mu` (identity link on mean) and
2927/// `sigma = (LOGB_SIGMA_FLOOR + exp(X_noise @ beta_noise + offset_noise)) * response_scale`.
2928pub struct GaussianLocationScalePredictor {
2929    pub beta_mu: Array1<f64>,
2930    pub beta_noise: Array1<f64>,
2931    pub response_scale: f64,
2932    pub covariance: Option<Array2<f64>>,
2933    pub link_wiggle: Option<SavedLinkWiggleRuntime>,
2934}
2935
2936impl GaussianLocationScalePredictor {
2937    /// Compute σ = (LOGB_SIGMA_FLOOR + exp(η_noise + offset_noise)) · response_scale.
2938    /// The logb link bounds σ ≥ LOGB_SIGMA_FLOOR · response_scale > 0 in
2939    /// response units, matching the fit-time parameterization in
2940    /// `gaussian_diagonal_row_kernel`. The previous `clamp(-500, 500)` on η
2941    /// was a defensive guard against `exp` underflow with the pure-exp link;
2942    /// it is unnecessary here because the floor keeps σ representable for any
2943    /// finite η.
2944    fn compute_sigma(
2945        &self,
2946        design_noise: &DesignMatrix,
2947        offset_noise: Option<&Array1<f64>>,
2948    ) -> Result<Array1<f64>, EstimationError> {
2949        let mut eta_noise = design_noise.dot(&self.beta_noise);
2950        if let Some(offset_noise) = offset_noise {
2951            if offset_noise.len() != eta_noise.len() {
2952                return Err(EstimationError::InvalidInput(format!(
2953                    "gaussian location-scale noise offset length mismatch: expected {}, got {}",
2954                    eta_noise.len(),
2955                    offset_noise.len()
2956                )));
2957            }
2958            eta_noise += offset_noise;
2959        }
2960        let scale = self.response_scale;
2961        Ok(eta_noise
2962            .mapv(|eta| crate::families::sigma_link::logb_sigma_from_eta_scalar(eta) * scale))
2963    }
2964
2965    fn eta_standard_error(
2966        &self,
2967        input: &PredictInput,
2968        fit: &UnifiedFitResult,
2969        eta_len: usize,
2970    ) -> Result<Array1<f64>, EstimationError> {
2971        let backend = require_posterior_mean_backend(
2972            fit,
2973            self.covariance.as_ref(),
2974            self.beta_mu.len()
2975                + self.beta_noise.len()
2976                + self.link_wiggle.as_ref().map_or(0, |w| w.beta.len()),
2977            "gaussian location-scale posterior mean",
2978        )?;
2979        let p_mu = self.beta_mu.len();
2980        let p_sigma = self.beta_noise.len();
2981        let p_w = self.link_wiggle.as_ref().map_or(0, |w| w.beta.len());
2982        let p_total = p_mu + p_sigma + p_w;
2983        if backend.nrows() != p_total {
2984            return Err(EstimationError::InvalidInput(format!(
2985                "gaussian location-scale covariance mismatch: expected parameter dimension {}, got {}",
2986                p_total,
2987                backend.nrows()
2988            )));
2989        }
2990        self.eta_standard_error_from_backend(input, &backend, eta_len, p_mu, p_sigma, p_w)
2991    }
2992
2993    fn eta_standard_error_from_backend(
2994        &self,
2995        input: &PredictInput,
2996        backend: &PredictionCovarianceBackend<'_>,
2997        eta_len: usize,
2998        p_mu: usize,
2999        p_sigma: usize,
3000        p_w: usize,
3001    ) -> Result<Array1<f64>, EstimationError> {
3002        let p_total = p_mu + p_sigma + p_w;
3003        if backend.nrows() != p_total {
3004            return Err(EstimationError::InvalidInput(format!(
3005                "gaussian location-scale covariance mismatch: expected parameter dimension {}, got {}",
3006                p_total,
3007                backend.nrows()
3008            )));
3009        }
3010        if let Some(runtime) = self.link_wiggle.as_ref() {
3011            let eta_base = input.design.dot(&self.beta_mu) + &input.offset;
3012            linear_predictor_se_from_backend(&backend, eta_len, |rows| {
3013                let q0_chunk = eta_base.slice(ndarray::s![rows.clone()]).to_owned();
3014                let x_mu = design_row_chunk(&input.design, rows.clone())?;
3015                let wiggle_design = runtime.design(&q0_chunk)?;
3016                let dq_dq0 = runtime.derivative_q0(&q0_chunk)?;
3017                let rows_in_chunk = q0_chunk.len();
3018                let mut grad = Array2::<f64>::zeros((rows_in_chunk, p_total));
3019                for i in 0..rows_in_chunk {
3020                    for j in 0..p_mu {
3021                        grad[[i, j]] = dq_dq0[i] * x_mu[[i, j]];
3022                    }
3023                }
3024                grad.slice_mut(ndarray::s![.., p_mu + p_sigma..p_total])
3025                    .assign(&wiggle_design);
3026                Ok(vec![grad])
3027            })
3028        } else {
3029            padded_design_standard_errors_from_backend(
3030                &input.design,
3031                &backend,
3032                0,
3033                p_sigma + p_w,
3034                "gaussian location-scale posterior mean",
3035            )
3036        }
3037    }
3038}
3039
3040impl PredictableModel for GaussianLocationScalePredictor {
3041    fn predict_plugin_response(
3042        &self,
3043        input: &PredictInput,
3044    ) -> Result<PredictResult, EstimationError> {
3045        let eta_base = input.design.dot(&self.beta_mu) + &input.offset;
3046        let eta = if let Some(runtime) = self.link_wiggle.as_ref() {
3047            runtime
3048                .apply(&eta_base)
3049                .map_err(EstimationError::InvalidInput)?
3050        } else {
3051            eta_base
3052        };
3053        // Gaussian identity link: mean = eta.
3054        let mean = eta.clone();
3055        Ok(PredictResult { eta, mean })
3056    }
3057
3058    fn predict_with_uncertainty(
3059        &self,
3060        input: &PredictInput,
3061    ) -> Result<PredictionWithSE, EstimationError> {
3062        let result = self.predict_plugin_response(input)?;
3063        let (eta_se, mean_se) = if let Some(covariance) = self.covariance.as_ref() {
3064            let p_mu = self.beta_mu.len();
3065            let p_sigma = self.beta_noise.len();
3066            let p_w = self.link_wiggle.as_ref().map_or(0, |w| w.beta.len());
3067            let backend = PredictionCovarianceBackend::from_dense(covariance.view());
3068            let eta_se = self.eta_standard_error_from_backend(
3069                input,
3070                &backend,
3071                result.eta.len(),
3072                p_mu,
3073                p_sigma,
3074                p_w,
3075            )?;
3076            (Some(eta_se.clone()), Some(eta_se))
3077        } else {
3078            (None, None)
3079        };
3080        Ok(PredictionWithSE {
3081            eta: result.eta,
3082            mean: result.mean,
3083            eta_se,
3084            mean_se,
3085        })
3086    }
3087
3088    fn predict_noise_scale(
3089        &self,
3090        input: &PredictInput,
3091    ) -> Result<Option<Array1<f64>>, EstimationError> {
3092        let design_noise = input.design_noise.as_ref().ok_or_else(|| {
3093            EstimationError::InvalidInput(
3094                "Gaussian location-scale prediction requires noise design matrix".to_string(),
3095            )
3096        })?;
3097        self.compute_sigma(design_noise, input.offset_noise.as_ref())
3098            .map(Some)
3099    }
3100
3101    fn predict_full_uncertainty(
3102        &self,
3103        input: &PredictInput,
3104        fit: &UnifiedFitResult,
3105        options: &PredictUncertaintyOptions,
3106    ) -> Result<PredictUncertaintyResult, EstimationError> {
3107        let pred = self.predict_plugin_response(input)?;
3108        let design_noise = input.design_noise.as_ref().ok_or_else(|| {
3109            EstimationError::InvalidInput(
3110                "Gaussian location-scale prediction requires noise design matrix".to_string(),
3111            )
3112        })?;
3113        let sigma = self.compute_sigma(design_noise, input.offset_noise.as_ref())?;
3114        let eta_se = self.eta_standard_error(input, fit, pred.eta.len())?;
3115        let z = crate::probability::standard_normal_quantile(0.5 + options.confidence_level * 0.5)
3116            .map_err(|e| EstimationError::InvalidInput(e))?;
3117        let eta_lower = &pred.eta - &eta_se.mapv(|s| z * s);
3118        let eta_upper = &pred.eta + &eta_se.mapv(|s| z * s);
3119        Ok(PredictUncertaintyResult {
3120            eta: pred.eta.clone(),
3121            mean: pred.mean.clone(),
3122            eta_standard_error: eta_se.clone(),
3123            mean_standard_error: eta_se.clone(),
3124            eta_lower: eta_lower.clone(),
3125            eta_upper: eta_upper.clone(),
3126            mean_lower: eta_lower,
3127            mean_upper: eta_upper,
3128            observation_lower: options
3129                .includeobservation_interval
3130                .then(|| &pred.mean - &sigma.mapv(|s| z * s)),
3131            observation_upper: options
3132                .includeobservation_interval
3133                .then(|| &pred.mean + &sigma.mapv(|s| z * s)),
3134            covariance_mode_requested: options.covariance_mode,
3135            covariance_corrected_used: false,
3136        })
3137    }
3138
3139    fn predict_posterior_mean(
3140        &self,
3141        input: &PredictInput,
3142        fit: &UnifiedFitResult,
3143        confidence_level: Option<f64>,
3144    ) -> Result<PredictPosteriorMeanResult, EstimationError> {
3145        let result = self.predict_plugin_response(input)?;
3146        let eta_se = self.eta_standard_error(input, fit, result.eta.len())?;
3147        // Gaussian identity link: mean == eta, so bounds are eta ± z·se.
3148        let (mean_lower, mean_upper) = if let Some(level) = confidence_level {
3149            let z = standard_normal_quantile(0.5 + 0.5 * level)
3150                .map_err(EstimationError::InvalidInput)?;
3151            (
3152                Some(&result.eta - &eta_se.mapv(|s| z * s)),
3153                Some(&result.eta + &eta_se.mapv(|s| z * s)),
3154            )
3155        } else {
3156            (None, None)
3157        };
3158        Ok(PredictPosteriorMeanResult {
3159            eta: result.eta,
3160            eta_standard_error: eta_se,
3161            mean: result.mean,
3162            mean_lower,
3163            mean_upper,
3164        })
3165    }
3166
3167    fn n_blocks(&self) -> usize {
3168        if self.link_wiggle.is_some() { 3 } else { 2 }
3169    }
3170
3171    fn block_roles(&self) -> Vec<BlockRole> {
3172        if self.link_wiggle.is_some() {
3173            vec![BlockRole::Location, BlockRole::Scale, BlockRole::LinkWiggle]
3174        } else {
3175            vec![BlockRole::Location, BlockRole::Scale]
3176        }
3177    }
3178}
3179
3180/// Binomial location-scale predictor: two blocks (threshold + log-sigma).
3181///
3182/// Predicts probabilities through the threshold-scale parameterisation:
3183///   eta_t = X_threshold @ beta_threshold + offset
3184///   eta_s = X_noise @ beta_noise + offset_noise
3185///   sigma = exp(eta_s)
3186///   q0    = -eta_t / sigma
3187///   prob  = inverse_link(q0)
3188///
3189/// Delta-method SEs propagate through the chain rule of q0 w.r.t. both
3190/// linear predictors.
3191pub struct BinomialLocationScalePredictor {
3192    pub beta_threshold: Array1<f64>,
3193    pub beta_noise: Array1<f64>,
3194    pub covariance: Option<Array2<f64>>,
3195    pub inverse_link: InverseLink,
3196    pub link_wiggle: Option<SavedLinkWiggleRuntime>,
3197}
3198
3199impl BinomialLocationScalePredictor {
3200    /// Compute q0 = -eta_t * exp(-eta_s) for each observation, where
3201    /// eta_t is the threshold linear predictor and sigma = exp(eta_s).
3202    ///
3203    /// Returns (q0_base, sigma, eta_t).
3204    fn compute_q0_and_sigma(
3205        &self,
3206        input: &PredictInput,
3207    ) -> Result<(Array1<f64>, Array1<f64>, Array1<f64>), EstimationError> {
3208        let eta_t = input.design.dot(&self.beta_threshold) + &input.offset;
3209        let design_noise = input.design_noise.as_ref().ok_or_else(|| {
3210            EstimationError::InvalidInput(
3211                "Binomial location-scale prediction requires noise design matrix".to_string(),
3212            )
3213        })?;
3214        let offset_noise = input
3215            .offset_noise
3216            .as_ref()
3217            .map_or_else(|| Array1::zeros(design_noise.nrows()), |o| o.clone());
3218        let eta_s = design_noise.dot(&self.beta_noise) + &offset_noise;
3219        // Floor sigma to prevent division by zero when eta_s underflows.
3220        let sigma = eta_s.mapv(|v| v.exp().max(f64::MIN_POSITIVE));
3221        let q0 = Array1::from_shape_fn(eta_t.len(), |i| (-eta_t[i] / sigma[i]).clamp(-1e6, 1e6));
3222        Ok((q0, sigma, eta_t))
3223    }
3224
3225    /// Apply the saved wiggle (if present) and then the inverse link to q0.
3226    fn apply_link(&self, q0: &Array1<f64>) -> Result<(Array1<f64>, Array1<f64>), EstimationError> {
3227        let eta = if let Some(runtime) = self.link_wiggle.as_ref() {
3228            runtime.apply(q0).map_err(EstimationError::InvalidInput)?
3229        } else {
3230            q0.clone()
3231        };
3232        use rayon::iter::{IntoParallelIterator, ParallelIterator};
3233        let n = eta.len();
3234        let prob_vec: Result<Vec<f64>, EstimationError> = (0..n)
3235            .into_par_iter()
3236            .map(|i| {
3237                let jet = crate::solver::mixture_link::inverse_link_jet_for_inverse_link(
3238                    &self.inverse_link,
3239                    eta[i],
3240                )?;
3241                Ok(jet.mu.clamp(0.0, 1.0))
3242            })
3243            .collect();
3244        let prob = Array1::from_vec(prob_vec?);
3245        Ok((eta, prob))
3246    }
3247}
3248
3249impl PredictableModel for BinomialLocationScalePredictor {
3250    fn predict_plugin_response(
3251        &self,
3252        input: &PredictInput,
3253    ) -> Result<PredictResult, EstimationError> {
3254        let (q0_base, _, _) = self.compute_q0_and_sigma(input)?;
3255        let (eta, prob) = self.apply_link(&q0_base)?;
3256        Ok(PredictResult { eta, mean: prob })
3257    }
3258
3259    fn predict_with_uncertainty(
3260        &self,
3261        input: &PredictInput,
3262    ) -> Result<PredictionWithSE, EstimationError> {
3263        let (q0_base, sigma, eta_t) = self.compute_q0_and_sigma(input)?;
3264        let (eta, prob) = self.apply_link(&q0_base)?;
3265
3266        let mean_se = if let Some(ref cov) = self.covariance {
3267            let n = eta_t.len();
3268            let p_t = self.beta_threshold.len();
3269            let p_s = self.beta_noise.len();
3270            let p_w = self.link_wiggle.as_ref().map_or(0, |w| w.beta.len());
3271            let p_total = p_t + p_s + p_w;
3272            let backend = PredictionCovarianceBackend::from_dense(cov.view());
3273            if backend.nrows() != p_total {
3274                return Err(EstimationError::InvalidInput(format!(
3275                    "covariance dimension mismatch for binomial LS: expected parameter dimension {}, got {}",
3276                    p_total,
3277                    backend.nrows()
3278                )));
3279            }
3280
3281            let design_noise = input.design_noise.as_ref().ok_or_else(|| {
3282                EstimationError::InvalidInput(
3283                    "binomial location-scale uncertainty requires noise design matrix".to_string(),
3284                )
3285            })?;
3286            Some(linear_predictor_se_from_backend(&backend, n, |rows| {
3287                let x_t = design_row_chunk(&input.design, rows.clone())?;
3288                let x_s = design_row_chunk(design_noise, rows.clone())?;
3289                let eta_chunk = eta.slice(ndarray::s![rows.clone()]).to_owned();
3290                let q0_chunk = q0_base.slice(ndarray::s![rows.clone()]).to_owned();
3291                let sigma_chunk = sigma.slice(ndarray::s![rows.clone()]).to_owned();
3292                let eta_t_chunk = eta_t.slice(ndarray::s![rows.clone()]).to_owned();
3293                let wiggle_design = if let Some(runtime) = self.link_wiggle.as_ref() {
3294                    Some(runtime.design(&q0_chunk)?)
3295                } else {
3296                    None
3297                };
3298                let dq_dq0 = if let Some(runtime) = self.link_wiggle.as_ref() {
3299                    runtime.derivative_q0(&q0_chunk)?
3300                } else {
3301                    Array1::ones(q0_chunk.len())
3302                };
3303                let rows_in_chunk = q0_chunk.len();
3304                let mut grad = Array2::<f64>::zeros((rows_in_chunk, p_total));
3305                for i in 0..rows_in_chunk {
3306                    let jet = crate::solver::mixture_link::inverse_link_jet_for_inverse_link(
3307                        &self.inverse_link,
3308                        eta_chunk[i],
3309                    )
3310                    .map_err(|e| e.to_string())?;
3311                    let dphi = jet.d1;
3312                    let scale = dq_dq0[i];
3313                    let dprob_deta_t = dphi * scale * (-1.0 / sigma_chunk[i]);
3314                    // dq/dη_ls = eta_t / σ for the exact exp link.
3315                    let dprob_deta_s = dphi * scale * (eta_t_chunk[i] / sigma_chunk[i]);
3316                    for j in 0..p_t {
3317                        grad[[i, j]] = dprob_deta_t * x_t[[i, j]];
3318                    }
3319                    for j in 0..p_s {
3320                        grad[[i, p_t + j]] = dprob_deta_s * x_s[[i, j]];
3321                    }
3322                    if let Some(wd) = wiggle_design.as_ref() {
3323                        for j in 0..p_w {
3324                            grad[[i, p_t + p_s + j]] = dphi * wd[[i, j]];
3325                        }
3326                    }
3327                }
3328                Ok(vec![grad])
3329            })?)
3330        } else {
3331            None
3332        };
3333
3334        Ok(PredictionWithSE {
3335            eta,
3336            mean: prob,
3337            eta_se: None,
3338            mean_se,
3339        })
3340    }
3341
3342    fn predict_noise_scale(
3343        &self,
3344        _: &PredictInput,
3345    ) -> Result<Option<Array1<f64>>, EstimationError> {
3346        Ok(None)
3347    }
3348
3349    fn predict_full_uncertainty(
3350        &self,
3351        input: &PredictInput,
3352        _: &UnifiedFitResult,
3353        options: &PredictUncertaintyOptions,
3354    ) -> Result<PredictUncertaintyResult, EstimationError> {
3355        let pred = self.predict_with_uncertainty(input)?;
3356        let z = standard_normal_quantile(0.5 + options.confidence_level * 0.5)
3357            .map_err(EstimationError::InvalidInput)?;
3358
3359        let mean_se = pred
3360            .mean_se
3361            .as_ref()
3362            .cloned()
3363            .unwrap_or_else(|| Array1::zeros(pred.mean.len()));
3364
3365        let mut mean_lower = &pred.mean - &mean_se.mapv(|s| z * s);
3366        let mut mean_upper = &pred.mean + &mean_se.mapv(|s| z * s);
3367        // Clamp probabilities to [0, 1].
3368        mean_lower.mapv_inplace(|v| v.clamp(0.0, 1.0));
3369        mean_upper.mapv_inplace(|v| v.clamp(0.0, 1.0));
3370
3371        // For binomial LS, eta intervals on the threshold predictor are not
3372        // directly meaningful for response-scale inference. Provide the
3373        // response-scale SE as the primary uncertainty measure.
3374        Ok(PredictUncertaintyResult {
3375            eta: pred.eta.clone(),
3376            mean: pred.mean.clone(),
3377            eta_standard_error: mean_se.clone(),
3378            mean_standard_error: mean_se,
3379            eta_lower: pred.eta.clone(),
3380            eta_upper: pred.eta,
3381            mean_lower,
3382            mean_upper,
3383            observation_lower: None,
3384            observation_upper: None,
3385            covariance_mode_requested: options.covariance_mode,
3386            covariance_corrected_used: false,
3387        })
3388    }
3389
3390    fn predict_posterior_mean(
3391        &self,
3392        input: &PredictInput,
3393        fit: &UnifiedFitResult,
3394        confidence_level: Option<f64>,
3395    ) -> Result<PredictPosteriorMeanResult, EstimationError> {
3396        // Validation target for this projected 2D GHQ path:
3397        // compare against 100K Monte Carlo draws under strong threshold/scale
3398        // posterior correlation and require agreement within ~0.01; as
3399        // covariance -> 0, the integrated mean must converge to the plug-in
3400        // point prediction row-wise.
3401        let (q0_base, sigma, eta_t) = self.compute_q0_and_sigma(input)?;
3402        let design_noise = input.design_noise.as_ref().ok_or_else(|| {
3403            EstimationError::InvalidInput(
3404                "Binomial location-scale posterior mean requires noise design matrix".to_string(),
3405            )
3406        })?;
3407        let offset_noise = input
3408            .offset_noise
3409            .as_ref()
3410            .map_or_else(|| Array1::zeros(design_noise.nrows()), |o| o.clone());
3411        let eta_s = design_noise.dot(&self.beta_noise) + &offset_noise;
3412        let (eta, _) = self.apply_link(&q0_base)?;
3413        let p_t = self.beta_threshold.len();
3414        let p_s = self.beta_noise.len();
3415        let p_w = self.link_wiggle.as_ref().map_or(0, |w| w.beta.len());
3416        let p_total = p_t + p_s + p_w;
3417        let backend = require_posterior_mean_backend(
3418            fit,
3419            self.covariance.as_ref(),
3420            p_total,
3421            "binomial location-scale posterior mean",
3422        )?;
3423
3424        let eta_se = linear_predictor_se_from_backend(&backend, eta_t.len(), |rows| {
3425            let x_t = design_row_chunk(&input.design, rows.clone())?;
3426            let x_s = design_row_chunk(design_noise, rows.clone())?;
3427            let eta_chunk = eta.slice(ndarray::s![rows.clone()]).to_owned();
3428            let q0_chunk = q0_base.slice(ndarray::s![rows.clone()]).to_owned();
3429            let sigma_chunk = sigma.slice(ndarray::s![rows.clone()]).to_owned();
3430            let eta_t_chunk = eta_t.slice(ndarray::s![rows.clone()]).to_owned();
3431            let wiggle_design = if let Some(runtime) = self.link_wiggle.as_ref() {
3432                Some(runtime.design(&q0_chunk)?)
3433            } else {
3434                None
3435            };
3436            let dq_dq0 = if let Some(runtime) = self.link_wiggle.as_ref() {
3437                runtime.derivative_q0(&q0_chunk)?
3438            } else {
3439                Array1::ones(q0_chunk.len())
3440            };
3441            let rows_in_chunk = q0_chunk.len();
3442            let row_gradients: Result<Vec<Vec<f64>>, String> = (0..rows_in_chunk)
3443                .into_par_iter()
3444                .map(|i| {
3445                    let jet = crate::solver::mixture_link::inverse_link_jet_for_inverse_link(
3446                        &self.inverse_link,
3447                        eta_chunk[i],
3448                    )
3449                    .map_err(|e| e.to_string())?;
3450                    let dphi = jet.d1;
3451                    let scale = dq_dq0[i];
3452                    let dprob_deta_t = dphi * scale * (-1.0 / sigma_chunk[i]);
3453                    let dprob_deta_s = dphi * scale * (eta_t_chunk[i] / sigma_chunk[i]);
3454                    let mut row = vec![0.0; p_total];
3455                    for j in 0..p_t {
3456                        row[j] = dprob_deta_t * x_t[[i, j]];
3457                    }
3458                    for j in 0..p_s {
3459                        row[p_t + j] = dprob_deta_s * x_s[[i, j]];
3460                    }
3461                    if let Some(wd) = wiggle_design.as_ref() {
3462                        for j in 0..p_w {
3463                            row[p_t + p_s + j] = dphi * wd[[i, j]];
3464                        }
3465                    }
3466                    Ok(row)
3467                })
3468                .collect();
3469            let mut grad = Array2::<f64>::zeros((rows_in_chunk, p_total));
3470            for (i, row) in row_gradients?.into_iter().enumerate() {
3471                for (j, value) in row.into_iter().enumerate() {
3472                    grad[[i, j]] = value;
3473                }
3474            }
3475            Ok(vec![grad])
3476        })?;
3477
3478        let mean = if self.link_wiggle.is_none() {
3479            let (var_t, var_s, cov_ts) = project_two_block_linear_predictor_covariance(
3480                &input.design,
3481                design_noise,
3482                &backend,
3483                p_t,
3484                p_s,
3485                "binomial location-scale posterior mean",
3486            )?;
3487            let values: Result<Vec<_>, _> = (0..eta_t.len())
3488                .into_par_iter()
3489                .map(|i| {
3490                    PREDICT_QUADRATURE_CONTEXT.with(|quadctx| {
3491                        projected_bivariate_posterior_mean_result(
3492                            quadctx,
3493                            [eta_t[i], eta_s[i]],
3494                            [
3495                                [var_t[i].max(0.0), cov_ts[i]],
3496                                [cov_ts[i], var_s[i].max(0.0)],
3497                            ],
3498                            |eta_threshold, eta_log_sigma| {
3499                                let q0 = -eta_threshold * (-eta_log_sigma).exp();
3500                                let jet =
3501                                    crate::solver::mixture_link::inverse_link_jet_for_inverse_link(
3502                                        &self.inverse_link,
3503                                        q0,
3504                                    )?;
3505                                Ok(jet.mu.clamp(0.0, 1.0))
3506                            },
3507                        )
3508                    })
3509                })
3510                .collect();
3511            Array1::from_vec(values?)
3512        } else {
3513            let runtime = self.link_wiggle.as_ref().expect("checked above");
3514            let betaw = Array1::from_vec(runtime.beta.clone());
3515            let mut wiggle_basis_rhs = Array2::<f64>::zeros((p_total, p_w));
3516            for j in 0..p_w {
3517                wiggle_basis_rhs[[p_t + p_s + j, j]] = 1.0;
3518            }
3519            let covww = backend
3520                .apply_columns(&wiggle_basis_rhs)
3521                .map_err(EstimationError::InvalidInput)?
3522                .slice(ndarray::s![p_t + p_s..p_total, ..])
3523                .to_owned();
3524            let mut out = Array1::<f64>::zeros(eta.len());
3525            let chunk_rows = prediction_chunk_rows(p_total, 2, eta.len());
3526            let mut start = 0usize;
3527            while start < eta.len() {
3528                let end = (start + chunk_rows).min(eta.len());
3529                let rows = start..end;
3530                let rows_in_chunk = end - start;
3531                let x_t = design_row_chunk(&input.design, rows.clone())
3532                    .map_err(EstimationError::InvalidInput)?;
3533                let x_ls = design_row_chunk(design_noise, rows.clone())
3534                    .map_err(EstimationError::InvalidInput)?;
3535                let mut rhs = Array2::<f64>::zeros((p_total, rows_in_chunk * 2));
3536                rhs.slice_mut(ndarray::s![0..p_t, 0..rows_in_chunk])
3537                    .assign(&x_t.t());
3538                rhs.slice_mut(ndarray::s![
3539                    p_t..p_t + p_s,
3540                    rows_in_chunk..2 * rows_in_chunk
3541                ])
3542                .assign(&x_ls.t());
3543                let solved = backend
3544                    .apply_columns(&rhs)
3545                    .map_err(EstimationError::InvalidInput)?;
3546                let compute_chunk_row = |quadctx: &QuadratureContext, local_row: usize| {
3547                    let i = start + local_row;
3548                    let solved_t = solved.slice(ndarray::s![.., local_row]);
3549                    let solved_ls = solved.slice(ndarray::s![.., rows_in_chunk + local_row]);
3550                    let var_t = x_t
3551                        .row(local_row)
3552                        .dot(&solved_t.slice(ndarray::s![0..p_t]))
3553                        .max(0.0);
3554                    let var_ls = x_ls
3555                        .row(local_row)
3556                        .dot(&solved_ls.slice(ndarray::s![p_t..p_t + p_s]))
3557                        .max(0.0);
3558                    let cov_tls_t = x_t
3559                        .row(local_row)
3560                        .dot(&solved_ls.slice(ndarray::s![0..p_t]));
3561                    let cov_tls_ls = x_ls
3562                        .row(local_row)
3563                        .dot(&solved_t.slice(ndarray::s![p_t..p_t + p_s]));
3564                    let cov_tls = 0.5 * (cov_tls_t + cov_tls_ls);
3565                    let suv_t = solved_t.slice(ndarray::s![p_t + p_s..p_total]);
3566                    let suv_ls = solved_ls.slice(ndarray::s![p_t + p_s..p_total]);
3567                    let det = (var_t * var_ls - cov_tls * cov_tls).max(1e-12);
3568                    let inv_uu = [
3569                        [var_ls / det, -cov_tls / det],
3570                        [-cov_tls / det, var_t / det],
3571                    ];
3572                    let mut k0 = Array1::<f64>::zeros(p_w);
3573                    let mut k1 = Array1::<f64>::zeros(p_w);
3574                    for j in 0..p_w {
3575                        k0[j] = suv_t[j] * inv_uu[0][0] + suv_ls[j] * inv_uu[1][0];
3576                        k1[j] = suv_t[j] * inv_uu[0][1] + suv_ls[j] * inv_uu[1][1];
3577                    }
3578                    let mut covw_cond = covww.clone();
3579                    for r in 0..p_w {
3580                        for c in 0..p_w {
3581                            covw_cond[[r, c]] -= k0[r] * suv_t[c] + k1[r] * suv_ls[c];
3582                        }
3583                    }
3584                    crate::quadrature::normal_expectation_2d_adaptive_result(
3585                        quadctx,
3586                        [eta_t[i], eta_s[i]],
3587                        [[var_t, cov_tls], [cov_tls, var_ls]],
3588                        |t, ls| {
3589                            let q0 = -t * (-ls).exp();
3590                            let xw = runtime
3591                                .basis_row_scalar(q0)
3592                                .map_err(EstimationError::InvalidInput)?;
3593                            let dt = t - eta_t[i];
3594                            let dls = ls - eta_s[i];
3595                            let meanw = q0 + xw.dot(&betaw) + dt * xw.dot(&k0) + dls * xw.dot(&k1);
3596                            let mut varw = 0.0;
3597                            for r in 0..p_w {
3598                                let xr = xw[r];
3599                                for c in 0..p_w {
3600                                    varw += xr * covw_cond[[r, c]] * xw[c];
3601                                }
3602                            }
3603                            let jet = crate::quadrature::integrated_inverse_link_jetwith_state(
3604                                quadctx,
3605                                self.inverse_link.link_function(),
3606                                meanw,
3607                                varw.max(0.0).sqrt(),
3608                                self.inverse_link.mixture_state(),
3609                                self.inverse_link.sas_state(),
3610                            )?;
3611                            Ok::<f64, EstimationError>(jet.mean.clamp(0.0, 1.0))
3612                        },
3613                    )
3614                };
3615                let chunk_values: Result<Vec<f64>, EstimationError> = (0..rows_in_chunk)
3616                    .into_par_iter()
3617                    .map(|local_row| {
3618                        PREDICT_QUADRATURE_CONTEXT
3619                            .with(|quadctx| compute_chunk_row(quadctx, local_row))
3620                    })
3621                    .collect();
3622                for (local_row, value) in chunk_values?.into_iter().enumerate() {
3623                    out[start + local_row] = value;
3624                }
3625                start = end;
3626            }
3627            out
3628        };
3629        // Binomial location-scale eta_se is response-scale (dprob/dθ chain
3630        // rule), so bounds are mean ± z·se clamped to [0, 1].
3631        let (mean_lower, mean_upper) = if let Some(level) = confidence_level {
3632            let z = standard_normal_quantile(0.5 + 0.5 * level)
3633                .map_err(EstimationError::InvalidInput)?;
3634            (
3635                Some((&mean - &eta_se.mapv(|s| z * s)).mapv(|v| v.clamp(0.0, 1.0))),
3636                Some((&mean + &eta_se.mapv(|s| z * s)).mapv(|v| v.clamp(0.0, 1.0))),
3637            )
3638        } else {
3639            (None, None)
3640        };
3641        Ok(PredictPosteriorMeanResult {
3642            eta,
3643            eta_standard_error: eta_se,
3644            mean,
3645            mean_lower,
3646            mean_upper,
3647        })
3648    }
3649
3650    fn n_blocks(&self) -> usize {
3651        if self.link_wiggle.is_some() { 3 } else { 2 }
3652    }
3653
3654    fn block_roles(&self) -> Vec<BlockRole> {
3655        if self.link_wiggle.is_some() {
3656            vec![BlockRole::Location, BlockRole::Scale, BlockRole::LinkWiggle]
3657        } else {
3658            vec![BlockRole::Location, BlockRole::Scale]
3659        }
3660    }
3661}
3662
// Survival location-scale predictor: two blocks (threshold + log-sigma).
//
// Predicts survival probability via:
//   q0 = -eta_threshold * exp(-eta_log_sigma)
//   survival_prob = 1 - inverse_link(q0)
//
// The "design" in `PredictInput` is the threshold design matrix, and
// "design_noise" is the log-sigma design matrix. The time dimension
// (x_time_exit) is handled externally and is not part of this predictor.

/// Upper cap on the exponent `-eta_log_sigma` passed to `exp` by the survival
/// helpers, and on the log-magnitude `ln|eta_threshold| - eta_log_sigma` of
/// `q0` before it is saturated to `±f64::MAX`.
///
/// `exp(500) ≈ 1.4e217` is still finite, so capping the exponent here keeps
/// `1 / sigma` finite even for extreme log-sigma predictions.
const SURVIVAL_EXP_NEG_STABLE_MAX_ARG: f64 = 500.0;
3673
3674#[inline]
3675fn survival_inverse_sigma_from_eta_log_sigma(eta_log_sigma: f64) -> f64 {
3676    (-eta_log_sigma).min(SURVIVAL_EXP_NEG_STABLE_MAX_ARG).exp()
3677}
3678
3679#[inline]
3680fn survival_q0_and_inverse_sigma(eta_threshold: f64, eta_log_sigma: f64) -> (f64, f64) {
3681    let inv_sigma = survival_inverse_sigma_from_eta_log_sigma(eta_log_sigma);
3682    if eta_threshold == 0.0 {
3683        return (0.0, inv_sigma);
3684    }
3685    let log_abs = eta_threshold.abs().ln() + (-eta_log_sigma).min(SURVIVAL_EXP_NEG_STABLE_MAX_ARG);
3686    let q0 = if log_abs > SURVIVAL_EXP_NEG_STABLE_MAX_ARG {
3687        if eta_threshold > 0.0 {
3688            -f64::MAX
3689        } else {
3690            f64::MAX
3691        }
3692    } else {
3693        -eta_threshold * inv_sigma
3694    };
3695    (q0, inv_sigma)
3696}
3697
3698#[inline]
3699fn survival_tail_value_from_failure_jet(
3700    inverse_link: &InverseLink,
3701    eta: f64,
3702    failure_jet: &InverseLinkJet,
3703) -> f64 {
3704    match inverse_link {
3705        InverseLink::Standard(crate::types::LinkFunction::Probit) => {
3706            if eta.is_nan() {
3707                f64::NAN
3708            } else if eta == f64::INFINITY {
3709                0.0
3710            } else if eta == f64::NEG_INFINITY {
3711                1.0
3712            } else {
3713                0.5 * statrs::function::erf::erfc(eta / std::f64::consts::SQRT_2)
3714            }
3715        }
3716        InverseLink::Standard(crate::types::LinkFunction::Logit) => 1.0 / (1.0 + eta.exp()),
3717        InverseLink::Standard(crate::types::LinkFunction::CLogLog) => (-(eta.exp())).exp(),
3718        _ => (1.0 - failure_jet.mu).clamp(0.0, 1.0),
3719    }
3720}
3721
3722#[inline]
3723fn inverse_link_survival_tail_value_and_failure_density(
3724    inverse_link: &InverseLink,
3725    eta: f64,
3726) -> Result<(f64, f64), EstimationError> {
3727    let failure_jet =
3728        crate::solver::mixture_link::inverse_link_jet_for_inverse_link(inverse_link, eta)?;
3729    Ok((
3730        survival_tail_value_from_failure_jet(inverse_link, eta, &failure_jet).clamp(0.0, 1.0),
3731        failure_jet.d1,
3732    ))
3733}
3734
3735pub struct SurvivalPredictor {
3736    pub beta_threshold: Array1<f64>,
3737    pub beta_log_sigma: Array1<f64>,
3738    pub covariance: Option<Array2<f64>>,
3739    pub inverse_link: InverseLink,
3740}
3741
3742impl SurvivalPredictor {
3743    /// Build a `SurvivalPredictor` from a `UnifiedFitResult`, extracting betas
3744    /// from blocks by role: Threshold (or legacy Location/Mean) ->
3745    /// beta_threshold, Scale -> beta_log_sigma.
3746    pub(crate) fn from_unified(
3747        unified: &UnifiedFitResult,
3748        inverse_link: InverseLink,
3749    ) -> Result<Self, EstimationError> {
3750        let beta_threshold = unified
3751            .block_by_role(BlockRole::Threshold)
3752            .or_else(|| unified.block_by_role(BlockRole::Location))
3753            .or_else(|| unified.block_by_role(BlockRole::Mean))
3754            .map(|b| b.beta.clone())
3755            .ok_or_else(|| {
3756                EstimationError::InvalidInput("Survival model missing threshold block".to_string())
3757            })?;
3758        let beta_log_sigma = unified
3759            .block_by_role(BlockRole::Scale)
3760            .map(|b| b.beta.clone())
3761            .ok_or_else(|| {
3762                EstimationError::InvalidInput(
3763                    "Survival model missing scale (log-sigma) block".to_string(),
3764                )
3765            })?;
3766        Ok(Self {
3767            beta_threshold,
3768            beta_log_sigma,
3769            covariance: unified.covariance_conditional.clone(),
3770            inverse_link,
3771        })
3772    }
3773
3774    /// Compute q0 = -eta_threshold * exp(-eta_log_sigma) and survival_prob = 1 - F(q0).
3775    fn compute_survival(
3776        &self,
3777        eta_threshold: &Array1<f64>,
3778        eta_log_sigma: &Array1<f64>,
3779    ) -> Result<Array1<f64>, EstimationError> {
3780        use rayon::iter::{IntoParallelIterator, ParallelIterator};
3781        let n = eta_threshold.len();
3782        let survival_prob: Result<Vec<f64>, EstimationError> = (0..n)
3783            .into_par_iter()
3784            .map(|i| {
3785                let (q0, _) = survival_q0_and_inverse_sigma(eta_threshold[i], eta_log_sigma[i]);
3786                let (survival, _) =
3787                    inverse_link_survival_tail_value_and_failure_density(&self.inverse_link, q0)?;
3788                Ok(survival)
3789            })
3790            .collect();
3791        Ok(Array1::from_vec(survival_prob?))
3792    }
3793}
3794
3795impl PredictableModel for SurvivalPredictor {
3796    fn predict_plugin_response(
3797        &self,
3798        input: &PredictInput,
3799    ) -> Result<PredictResult, EstimationError> {
3800        let eta_threshold = input.design.dot(&self.beta_threshold) + &input.offset;
3801        let design_noise = input.design_noise.as_ref().ok_or_else(|| {
3802            EstimationError::InvalidInput(
3803                "Survival prediction requires noise (log-sigma) design matrix".to_string(),
3804            )
3805        })?;
3806        let offset_noise = input.offset_noise.as_ref().ok_or_else(|| {
3807            EstimationError::InvalidInput(
3808                "Survival prediction requires noise (log-sigma) offset".to_string(),
3809            )
3810        })?;
3811        let eta_log_sigma = design_noise.dot(&self.beta_log_sigma) + offset_noise;
3812        let survival_prob = self.compute_survival(&eta_threshold, &eta_log_sigma)?;
3813        Ok(PredictResult {
3814            eta: eta_threshold,
3815            mean: survival_prob,
3816        })
3817    }
3818
3819    fn predict_with_uncertainty(
3820        &self,
3821        input: &PredictInput,
3822    ) -> Result<PredictionWithSE, EstimationError> {
3823        let eta_threshold = input.design.dot(&self.beta_threshold) + &input.offset;
3824        let design_noise = input.design_noise.as_ref().ok_or_else(|| {
3825            EstimationError::InvalidInput(
3826                "Survival prediction requires noise (log-sigma) design matrix".to_string(),
3827            )
3828        })?;
3829        let offset_noise = input.offset_noise.as_ref().ok_or_else(|| {
3830            EstimationError::InvalidInput(
3831                "Survival prediction requires noise (log-sigma) offset".to_string(),
3832            )
3833        })?;
3834        let eta_log_sigma = design_noise.dot(&self.beta_log_sigma) + offset_noise;
3835        let survival_prob = self.compute_survival(&eta_threshold, &eta_log_sigma)?;
3836
3837        let (eta_se, mean_se) = if let Some(ref cov) = self.covariance {
3838            let n = eta_threshold.len();
3839            let p_t = self.beta_threshold.len();
3840            let p_s = self.beta_log_sigma.len();
3841            let backend = PredictionCovarianceBackend::from_dense(cov.view());
3842
3843            let eta_se = padded_design_standard_errors_from_backend(
3844                &input.design,
3845                &backend,
3846                0,
3847                p_s,
3848                "survival threshold uncertainty",
3849            )?;
3850
3851            // Delta-method SE for survival probability.
3852            let mean_se_vec = linear_predictor_se_from_backend(&backend, n, |rows| {
3853                let x_t = design_row_chunk(&input.design, rows.clone())?;
3854                let x_s = design_row_chunk(design_noise, rows.clone())?;
3855                let eta_t_chunk = eta_threshold.slice(ndarray::s![rows.clone()]).to_owned();
3856                let eta_ls_chunk = eta_log_sigma.slice(ndarray::s![rows.clone()]).to_owned();
3857                let rows_in_chunk = eta_t_chunk.len();
3858                let mut grad = Array2::<f64>::zeros((rows_in_chunk, p_t + p_s));
3859                for i in 0..rows_in_chunk {
3860                    let (q0, inv_sigma) =
3861                        survival_q0_and_inverse_sigma(eta_t_chunk[i], eta_ls_chunk[i]);
3862                    let (_, failure_density) =
3863                        inverse_link_survival_tail_value_and_failure_density(
3864                            &self.inverse_link,
3865                            q0,
3866                        )
3867                        .map_err(|e| e.to_string())?;
3868                    let dsurv_deta_t = failure_density * inv_sigma;
3869                    let dsurv_deta_s = failure_density * q0;
3870                    for j in 0..p_t {
3871                        grad[[i, j]] = dsurv_deta_t * x_t[[i, j]];
3872                    }
3873                    for j in 0..p_s {
3874                        grad[[i, p_t + j]] = dsurv_deta_s * x_s[[i, j]];
3875                    }
3876                }
3877                Ok(vec![grad])
3878            })?;
3879            (Some(eta_se), Some(mean_se_vec))
3880        } else {
3881            (None, None)
3882        };
3883
3884        Ok(PredictionWithSE {
3885            eta: eta_threshold,
3886            mean: survival_prob,
3887            eta_se,
3888            mean_se,
3889        })
3890    }
3891
3892    fn predict_noise_scale(
3893        &self,
3894        _: &PredictInput,
3895    ) -> Result<Option<Array1<f64>>, EstimationError> {
3896        Ok(None)
3897    }
3898
3899    fn predict_full_uncertainty(
3900        &self,
3901        input: &PredictInput,
3902        _: &UnifiedFitResult,
3903        options: &PredictUncertaintyOptions,
3904    ) -> Result<PredictUncertaintyResult, EstimationError> {
3905        let pred = self.predict_with_uncertainty(input)?;
3906        let z = crate::probability::standard_normal_quantile(0.5 + options.confidence_level * 0.5)
3907            .map_err(|e| EstimationError::InvalidInput(e))?;
3908
3909        let eta_se = pred.eta_se.as_ref().ok_or_else(|| {
3910            EstimationError::InvalidInput(
3911                "Survival full uncertainty requires covariance (eta_se unavailable)".to_string(),
3912            )
3913        })?;
3914        let mean_se = pred.mean_se.as_ref().ok_or_else(|| {
3915            EstimationError::InvalidInput(
3916                "Survival full uncertainty requires covariance (mean_se unavailable)".to_string(),
3917            )
3918        })?;
3919
3920        let eta_lower = &pred.eta - &eta_se.mapv(|s| z * s);
3921        let eta_upper = &pred.eta + &eta_se.mapv(|s| z * s);
3922        let mut mean_lower = &pred.mean - &mean_se.mapv(|s| z * s);
3923        let mut mean_upper = &pred.mean + &mean_se.mapv(|s| z * s);
3924        // Clamp survival probabilities to [0, 1].
3925        mean_lower.mapv_inplace(|v| v.clamp(0.0, 1.0));
3926        mean_upper.mapv_inplace(|v| v.clamp(0.0, 1.0));
3927
3928        Ok(PredictUncertaintyResult {
3929            eta: pred.eta,
3930            mean: pred.mean,
3931            eta_standard_error: eta_se.clone(),
3932            mean_standard_error: mean_se.clone(),
3933            eta_lower,
3934            eta_upper,
3935            mean_lower,
3936            mean_upper,
3937            observation_lower: None,
3938            observation_upper: None,
3939            covariance_mode_requested: options.covariance_mode,
3940            covariance_corrected_used: false,
3941        })
3942    }
3943
3944    fn predict_posterior_mean(
3945        &self,
3946        input: &PredictInput,
3947        fit: &UnifiedFitResult,
3948        confidence_level: Option<f64>,
3949    ) -> Result<PredictPosteriorMeanResult, EstimationError> {
3950        // The eta_se here covers only the threshold block. Response-scale
3951        // survival intervals also need sigma uncertainty, which is propagated
3952        // by the caller when it requests full interval output.
3953        //
3954        // Validation target for this survival posterior-mean path:
3955        // compare against 50K Monte Carlo draws from N(beta_hat, V) for a
3956        // simple Weibull-style location-scale survival fit and require
3957        // agreement within ~0.005; as covariance -> 0, the integrated mean
3958        // must collapse to the point prediction.
3959        let eta_threshold = input.design.dot(&self.beta_threshold) + &input.offset;
3960        let design_noise = input.design_noise.as_ref().ok_or_else(|| {
3961            EstimationError::InvalidInput(
3962                "Survival posterior mean requires noise (log-sigma) design matrix".to_string(),
3963            )
3964        })?;
3965        let offset_noise = input.offset_noise.as_ref().ok_or_else(|| {
3966            EstimationError::InvalidInput(
3967                "Survival posterior mean requires noise (log-sigma) offset".to_string(),
3968            )
3969        })?;
3970        let eta_log_sigma = design_noise.dot(&self.beta_log_sigma) + offset_noise;
3971        let p_t = self.beta_threshold.len();
3972        let p_s = self.beta_log_sigma.len();
3973        let p_total = p_t + p_s;
3974        let backend = require_posterior_mean_backend(
3975            fit,
3976            self.covariance.as_ref(),
3977            p_total,
3978            "survival posterior mean",
3979        )?;
3980
3981        let eta_se = padded_design_standard_errors_from_backend(
3982            &input.design,
3983            &backend,
3984            0,
3985            p_s,
3986            "survival posterior mean",
3987        )?;
3988        let (var_t, var_s, cov_ts) = project_two_block_linear_predictor_covariance(
3989            &input.design,
3990            design_noise,
3991            &backend,
3992            p_t,
3993            p_s,
3994            "survival posterior mean",
3995        )?;
3996        let quadctx = crate::quadrature::QuadratureContext::new();
3997        let mean = Array1::from_vec(
3998            (0..eta_threshold.len())
3999                .map(|i| {
4000                    projected_bivariate_posterior_mean_result(
4001                        &quadctx,
4002                        [eta_threshold[i], eta_log_sigma[i]],
4003                        [
4004                            [var_t[i].max(0.0), cov_ts[i]],
4005                            [cov_ts[i], var_s[i].max(0.0)],
4006                        ],
4007                        |threshold, log_sigma| {
4008                            let (q0, _) = survival_q0_and_inverse_sigma(threshold, log_sigma);
4009                            let (survival, _) =
4010                                inverse_link_survival_tail_value_and_failure_density(
4011                                    &self.inverse_link,
4012                                    q0,
4013                                )?;
4014                            Ok(survival)
4015                        },
4016                    )
4017                })
4018                .collect::<Result<Vec<_>, _>>()?,
4019        );
4020        let (mean_lower, mean_upper) = if let Some(level) = confidence_level {
4021            let z = crate::probability::standard_normal_quantile(0.5 + 0.5 * level).unwrap_or(1.96);
4022            let lo = (&mean - &eta_se.mapv(|s| z * s)).mapv(|v| v.clamp(0.0, 1.0));
4023            let hi = (&mean + &eta_se.mapv(|s| z * s)).mapv(|v| v.clamp(0.0, 1.0));
4024            (Some(lo), Some(hi))
4025        } else {
4026            (None, None)
4027        };
4028        Ok(PredictPosteriorMeanResult {
4029            eta: eta_threshold,
4030            eta_standard_error: eta_se,
4031            mean,
4032            mean_lower,
4033            mean_upper,
4034        })
4035    }
4036
4037    fn n_blocks(&self) -> usize {
4038        2
4039    }
4040
4041    fn block_roles(&self) -> Vec<BlockRole> {
4042        vec![BlockRole::Threshold, BlockRole::Scale]
4043    }
4044}
4045
4046/// Predictor for transformation-normal (PIT) models.
4047///
4048/// The PIT-transformed values h(y|x) are precomputed in
4049/// `build_predict_input_for_model` and stored in the PredictInput offset.
4050/// This predictor passes them through as the prediction: eta = h, mean = h.
4051pub struct TransformationNormalPredictor {
4052    pub covariance: Option<Array2<f64>>,
4053}
4054
4055impl PredictableModel for TransformationNormalPredictor {
4056    fn predict_plugin_response(
4057        &self,
4058        input: &PredictInput,
4059    ) -> Result<PredictResult, EstimationError> {
4060        let h = input.offset.clone();
4061        Ok(PredictResult {
4062            eta: h.clone(),
4063            mean: h,
4064        })
4065    }
4066
4067    fn predict_with_uncertainty(
4068        &self,
4069        input: &PredictInput,
4070    ) -> Result<PredictionWithSE, EstimationError> {
4071        let h = input.offset.clone();
4072        Ok(PredictionWithSE {
4073            eta: h.clone(),
4074            mean: h,
4075            eta_se: None,
4076            mean_se: None,
4077        })
4078    }
4079
4080    fn predict_noise_scale(
4081        &self,
4082        _: &PredictInput,
4083    ) -> Result<Option<Array1<f64>>, EstimationError> {
4084        Ok(None)
4085    }
4086
4087    fn predict_full_uncertainty(
4088        &self,
4089        input: &PredictInput,
4090        fit: &UnifiedFitResult,
4091        options: &PredictUncertaintyOptions,
4092    ) -> Result<PredictUncertaintyResult, EstimationError> {
4093        let h = input.offset.clone();
4094        let n = h.len();
4095        let zeros = Array1::zeros(n);
4096        Ok(PredictUncertaintyResult {
4097            eta: h.clone(),
4098            mean: h.clone(),
4099            eta_standard_error: zeros.clone(),
4100            mean_standard_error: zeros,
4101            eta_lower: h.clone(),
4102            eta_upper: h.clone(),
4103            mean_lower: h.clone(),
4104            mean_upper: h,
4105            observation_lower: None,
4106            observation_upper: None,
4107            covariance_mode_requested: options.covariance_mode,
4108            covariance_corrected_used: fit.covariance_corrected.is_some(),
4109        })
4110    }
4111
4112    fn predict_posterior_mean(
4113        &self,
4114        input: &PredictInput,
4115        fit: &UnifiedFitResult,
4116        confidence_level: Option<f64>,
4117    ) -> Result<PredictPosteriorMeanResult, EstimationError> {
4118        let h = input.offset.clone();
4119        let n = h.len();
4120        let has_fit_covariance =
4121            fit.covariance_corrected.is_some() || fit.covariance_conditional.is_some();
4122        let (mean_lower, mean_upper) = if confidence_level.is_some() && has_fit_covariance {
4123            (Some(h.clone()), Some(h.clone()))
4124        } else {
4125            (None, None)
4126        };
4127        Ok(PredictPosteriorMeanResult {
4128            eta: h.clone(),
4129            eta_standard_error: Array1::zeros(n),
4130            mean: h,
4131            mean_lower,
4132            mean_upper,
4133        })
4134    }
4135
4136    fn n_blocks(&self) -> usize {
4137        1
4138    }
4139    fn block_roles(&self) -> Vec<BlockRole> {
4140        vec![BlockRole::Mean]
4141    }
4142}
4143
4144/// Compute eta standard errors from a design matrix and covariance/precision backend.
4145fn eta_standard_errors_from_backend(
4146    x: &DesignMatrix,
4147    backend: &PredictionCovarianceBackend<'_>,
4148) -> Result<Array1<f64>, EstimationError> {
4149    let vars = linear_predictorvariance_from_backend(x, backend)?;
4150    Ok(vars.mapv(|v| v.max(0.0).sqrt()))
4151}
4152
4153/// Delta-method standard errors on the mean scale.
4154fn delta_method_mean_se(
4155    eta: &Array1<f64>,
4156    eta_se: &Array1<f64>,
4157    strategy: &(dyn FamilyStrategy + Sync),
4158) -> Result<Array1<f64>, EstimationError> {
4159    use rayon::iter::{IntoParallelIterator, ParallelIterator};
4160    let n = eta.len();
4161    let values: Result<Vec<f64>, EstimationError> = (0..n)
4162        .into_par_iter()
4163        .map(|i| {
4164            let jet = strategy.inverse_link_jet(eta[i])?;
4165            Ok((jet.d1 * eta_se[i]).abs())
4166        })
4167        .collect();
4168    Ok(Array1::from_vec(values?))
4169}
4170
4171pub struct PredictPosteriorMeanResult {
4172    pub eta: Array1<f64>,
4173    pub eta_standard_error: Array1<f64>,
4174    pub mean: Array1<f64>,
4175    /// Response-scale lower confidence bound (set by
4176    /// [`enrich_posterior_mean_bounds`]).
4177    pub mean_lower: Option<Array1<f64>>,
4178    /// Response-scale upper confidence bound (set by
4179    /// [`enrich_posterior_mean_bounds`]).
4180    pub mean_upper: Option<Array1<f64>>,
4181}
4182
4183/// Compute and attach TransformEta confidence bounds to a posterior-mean result.
4184///
4185/// This mirrors the bound construction in [`predict_gamwith_uncertainty`] using
4186/// the `TransformEta` method: transform `eta ± z * eta_se` through the inverse
4187/// link, then clamp to [0, 1] for bounded-response families.
4188///
4189/// Call this after [`PredictableModel::predict_posterior_mean`] whenever a
4190/// confidence level is available so that `mean_lower` / `mean_upper` are
4191/// always populated alongside `eta_standard_error`.
4192pub fn enrich_posterior_mean_bounds(
4193    result: &mut PredictPosteriorMeanResult,
4194    confidence_level: f64,
4195    family: crate::types::LikelihoodFamily,
4196    link_kind: Option<&InverseLink>,
4197) -> Result<(), EstimationError> {
4198    if !(confidence_level.is_finite() && confidence_level > 0.0 && confidence_level < 1.0) {
4199        return Err(EstimationError::InvalidInput(format!(
4200            "confidence_level must be in (0,1), got {confidence_level}"
4201        )));
4202    }
4203    let z = crate::probability::standard_normal_quantile(0.5 + 0.5 * confidence_level)
4204        .map_err(EstimationError::InvalidInput)?;
4205
4206    let eta_lower = &result.eta - &result.eta_standard_error.mapv(|s| z * s);
4207    let eta_upper = &result.eta + &result.eta_standard_error.mapv(|s| z * s);
4208
4209    let transformed_lower = apply_family_inverse_link(&eta_lower, family, link_kind)?;
4210    let transformed_upper = apply_family_inverse_link(&eta_upper, family, link_kind)?;
4211
4212    // Handle potentially non-monotone transforms (e.g. survival).
4213    let mut mean_lower = Array1::from_iter(
4214        transformed_lower
4215            .iter()
4216            .zip(transformed_upper.iter())
4217            .map(|(&lo, &hi)| lo.min(hi)),
4218    );
4219    let mut mean_upper = Array1::from_iter(
4220        transformed_lower
4221            .iter()
4222            .zip(transformed_upper.iter())
4223            .map(|(&lo, &hi)| lo.max(hi)),
4224    );
4225
4226    // Clamp bounded-response families to [0, 1].
4227    if matches!(
4228        family,
4229        crate::types::LikelihoodFamily::BinomialLogit
4230            | crate::types::LikelihoodFamily::BinomialProbit
4231            | crate::types::LikelihoodFamily::BinomialCLogLog
4232            | crate::types::LikelihoodFamily::BinomialSas
4233            | crate::types::LikelihoodFamily::BinomialBetaLogistic
4234            | crate::types::LikelihoodFamily::BinomialMixture
4235            | crate::types::LikelihoodFamily::RoystonParmar
4236    ) {
4237        mean_lower.mapv_inplace(|v| v.clamp(0.0, 1.0));
4238        mean_upper.mapv_inplace(|v| v.clamp(0.0, 1.0));
4239    }
4240
4241    result.mean_lower = Some(mean_lower);
4242    result.mean_upper = Some(mean_upper);
4243    Ok(())
4244}
4245
4246#[derive(Clone, Copy, Debug, Eq, PartialEq)]
4247pub enum InferenceCovarianceMode {
4248    /// Use conditional posterior covariance only:
4249    ///   Var(beta | lambda_hat) ~= H_{rho_hat}^{-1}.
4250    Conditional,
4251    /// Prefer first-order smoothing-corrected covariance when available:
4252    ///   Var(beta) ~= H_{rho_hat}^{-1} + J Var(rho_hat) J^T.
4253    /// Falls back to conditional if correction is unavailable.
4254    ConditionalPlusSmoothingPreferred,
4255    /// Require the first-order smoothing-corrected covariance; error if unavailable.
4256    ConditionalPlusSmoothingRequired,
4257}
4258
4259/// Per-axis training support range used by boundary and OOD corrections.
4260/// For each predictor axis we record the empirical [min, max] from training.
4261/// Boundary correction inflates variance for x_i within a small fraction of
4262/// the range from either edge; OOD inflation inflates variance for x_i
4263/// outside [min, max] proportional to (excess / range).
4264#[derive(Clone, Debug)]
4265pub struct TrainingSupport {
4266    /// Axis-wise minimum across the training rows; length = number of input
4267    /// columns the design treats as continuous predictors. The order must
4268    /// match `predictor_x` rows passed in `PredictUncertaintyOptions::
4269    /// predictor_x_for_corrections` (see helper below); a length of zero
4270    /// disables both boundary and OOD corrections.
4271    pub axis_min: Array1<f64>,
4272    /// Axis-wise maximum, paired with `axis_min`.
4273    pub axis_max: Array1<f64>,
4274}
4275
4276impl TrainingSupport {
4277    /// Convenience constructor from raw training rows. Computes per-axis
4278    /// min/max in a single pass.
4279    pub fn from_training_rows(rows: ArrayView2<'_, f64>) -> Self {
4280        let d = rows.ncols();
4281        if rows.nrows() == 0 || d == 0 {
4282            return Self {
4283                axis_min: Array1::zeros(0),
4284                axis_max: Array1::zeros(0),
4285            };
4286        }
4287        let mut axis_min = Array1::from_elem(d, f64::INFINITY);
4288        let mut axis_max = Array1::from_elem(d, f64::NEG_INFINITY);
4289        for row in rows.outer_iter() {
4290            for k in 0..d {
4291                let v = row[k];
4292                if v < axis_min[k] {
4293                    axis_min[k] = v;
4294                }
4295                if v > axis_max[k] {
4296                    axis_max[k] = v;
4297                }
4298            }
4299        }
4300        Self { axis_min, axis_max }
4301    }
4302}
4303
4304pub struct PredictUncertaintyOptions {
4305    /// Central interval level in (0, 1), e.g. 0.95.
4306    pub confidence_level: f64,
4307    /// Covariance mode used for eta/mean intervals.
4308    pub covariance_mode: InferenceCovarianceMode,
4309    /// Mean-scale interval construction method.
4310    pub mean_interval_method: MeanIntervalMethod,
4311    /// For Gaussian identity, also return observation intervals using
4312    /// Var(y_new | x) = Var(eta_hat) + sigma^2.
4313    pub includeobservation_interval: bool,
4314    /// Apply the O(n⁻¹) frequentist bias correction at prediction time.
4315    /// When enabled (default), η̂_BC(x) = η̂(x) + s_*(x)^T H⁻¹ S(λ̂) β̂
4316    /// is reported instead of the raw plug-in η̂(x), restoring the OLS-style
4317    /// predictor at the cost of slightly higher variance. Standard errors
4318    /// are unaffected at first order. Requires `fit.bias_correction_beta()`
4319    /// to be available; silently falls back to the raw predictor otherwise.
4320    pub apply_bias_correction: bool,
4321    /// Edgeworth expansion correction for one-sided tail coverage. When ON
4322    /// (default), the per-row z-multiplier is replaced by the Cornish–Fisher
4323    /// expansion z + (z² − 1)·κ₃ / 6 + … using a per-row skewness estimate
4324    /// derived from `eta` and `eta_standard_error`. The result is an
4325    /// asymmetric (lower, upper) multiplier pair that preserves the central
4326    /// confidence level while adjusting tail rates separately. Requires
4327    /// `eta_skewness_for_corrections` if a non-zero skew estimate is to be
4328    /// used; otherwise this reduces to the standard symmetric interval.
4329    pub edgeworth_one_sided: bool,
4330    /// Inflate variance near the support boundary. When ON (default),
4331    /// requires both `predictor_x_for_corrections` and `training_support`;
4332    /// otherwise behaves as a no-op. The inflation factor is
4333    /// `1 + α · max(0, 1 − d_edge / (β · range))²` per axis, with
4334    /// α = `boundary_alpha` and β = `boundary_band_fraction`. d_edge is the
4335    /// minimum of (x − min, max − x) per axis.
4336    pub boundary_correction: bool,
4337    /// Inflate variance for predictions outside the per-axis training
4338    /// range. When ON (default OFF), requires both
4339    /// `predictor_x_for_corrections` and `training_support`. Factor is
4340    /// `1 + γ · Σ_k (excess_k / range_k)²`, with γ = `ood_gamma`.
4341    pub ood_inflation: bool,
4342    /// Joint coverage adjustment over a query batch. When ON (default
4343    /// OFF) the per-row z multiplier is increased so the family-wise
4344    /// coverage of the returned intervals matches `confidence_level`.
4345    /// Uses Bonferroni: `z_joint = standard_normal_quantile(
4346    /// 0.5 + 0.5·(1 − (1 − level) / m))` where m is the joint query count
4347    /// (defaults to the prediction batch size when `joint_query_count` is
4348    /// None).
4349    pub multi_point_joint: bool,
4350    /// Predictor rows aligned with the prediction batch, used by boundary
4351    /// and OOD corrections. Number of columns must match
4352    /// `training_support.axis_min.len()`. When None, both corrections
4353    /// silently no-op even if their flags are set.
4354    pub predictor_x_for_corrections: Option<Array2<f64>>,
4355    /// Per-axis training support, paired with `predictor_x_for_corrections`.
4356    pub training_support: Option<TrainingSupport>,
4357    /// Per-row Edgeworth skewness κ₃ estimate (length = batch size). When
4358    /// None, Edgeworth correction reduces to the standard symmetric
4359    /// quantile (no-op).
4360    pub eta_skewness_for_corrections: Option<Array1<f64>>,
4361    /// Joint query count m for the multi-point adjustment. When None the
4362    /// prediction batch size is used.
4363    pub joint_query_count: Option<usize>,
4364    /// Boundary correction strength α (multiplier on the squared shortfall).
4365    /// Default 0.25. Larger ⇒ more inflation near the edge.
4366    pub boundary_alpha: f64,
4367    /// Boundary correction band β (fraction of range that counts as "near"
4368    /// the edge). Default 0.05. Inside this band the inflation factor
4369    /// grows quadratically as x → edge.
4370    pub boundary_band_fraction: f64,
4371    /// OOD inflation strength γ (multiplier on the squared per-axis
4372    /// overshoot fraction). Default 1.0.
4373    pub ood_gamma: f64,
4374}
4375
4376impl Default for PredictUncertaintyOptions {
4377    fn default() -> Self {
4378        Self {
4379            confidence_level: 0.95,
4380            covariance_mode: InferenceCovarianceMode::ConditionalPlusSmoothingPreferred,
4381            mean_interval_method: MeanIntervalMethod::TransformEta,
4382            includeobservation_interval: true,
4383            apply_bias_correction: true,
4384            edgeworth_one_sided: true,
4385            boundary_correction: true,
4386            ood_inflation: false,
4387            multi_point_joint: false,
4388            predictor_x_for_corrections: None,
4389            training_support: None,
4390            eta_skewness_for_corrections: None,
4391            joint_query_count: None,
4392            boundary_alpha: 0.25,
4393            boundary_band_fraction: 0.05,
4394            ood_gamma: 1.0,
4395        }
4396    }
4397}
4398
4399/// Asymmetric (lower, upper) z-multiplier produced by the Edgeworth
4400/// one-sided correction. With κ₃ = 0 both entries equal the standard
4401/// symmetric `z_{(1+level)/2}` quantile.
4402#[derive(Clone, Copy, Debug)]
4403pub(crate) struct EdgeworthZ {
4404    pub z_lower: f64,
4405    pub z_upper: f64,
4406}
4407
4408/// One-sided Edgeworth expansion (Cornish–Fisher to first non-Gaussian
4409/// order) for a coverage level on each tail. Given a per-row skewness
4410/// estimate κ₃, returns (z_lower, z_upper) such that
4411///
4412///   eta_lower = eta − z_lower · se,   eta_upper = eta + z_upper · se,
4413///
4414/// with the lower-tail probability Φ(−z_lower) ≈ α/2 and the upper-tail
4415/// probability 1 − Φ(z_upper) ≈ α/2 to O(κ₃). The expansion is
4416///   z_p ≈ z + (z² − 1) · κ₃ / 6
4417/// applied with sign-symmetric z at the two tails. With κ₃ = 0 this
4418/// reduces to the symmetric interval z_lower = z_upper = z.
4419pub(crate) fn edgeworth_one_sided_quantile(z: f64, skew_kappa3: f64) -> EdgeworthZ {
4420    // Cornish–Fisher: q_α = z_α + (z_α² − 1) κ₃ / 6.
4421    // For the upper tail use +z, for the lower tail use −z (in the
4422    // standardized scale), then negate. Net effect:
4423    //   z_upper_eta = z + (z² − 1) κ₃ / 6
4424    //   z_lower_eta = z − (z² − 1) κ₃ / 6
4425    let bump = (z * z - 1.0) * skew_kappa3 / 6.0;
4426    EdgeworthZ {
4427        z_lower: (z - bump).max(0.0),
4428        z_upper: (z + bump).max(0.0),
4429    }
4430}
4431
4432/// Per-row variance-inflation factor for the boundary correction. Returns
4433/// 1 if no axis is inside the boundary band, otherwise
4434/// `1 + α · Σ_k max(0, 1 − d_k / (β · range_k))²` summed over axes.
4435/// When `range_k = 0` (degenerate axis) the contribution is skipped.
4436pub(crate) fn boundary_variance_inflation_factor(
4437    x_row: ArrayView1<'_, f64>,
4438    axis_min: ArrayView1<'_, f64>,
4439    axis_max: ArrayView1<'_, f64>,
4440    alpha: f64,
4441    band_fraction: f64,
4442) -> f64 {
4443    let d = x_row.len();
4444    if d == 0 || axis_min.len() != d || axis_max.len() != d || band_fraction <= 0.0 {
4445        return 1.0;
4446    }
4447    let mut excess = 0.0_f64;
4448    for k in 0..d {
4449        let lo = axis_min[k];
4450        let hi = axis_max[k];
4451        let range = hi - lo;
4452        if !(range > 0.0) {
4453            continue;
4454        }
4455        let x = x_row[k];
4456        // Closest-edge distance, clamped to interior.
4457        let d_edge = (x - lo).min(hi - x);
4458        if !d_edge.is_finite() || d_edge >= band_fraction * range {
4459            continue;
4460        }
4461        // Inside the band (or beyond on the wrong side; we only inflate
4462        // for interior-near-edge here, OOD case is the other helper).
4463        if d_edge <= 0.0 {
4464            // Exactly on or just past the boundary: full band shortfall.
4465            excess += 1.0;
4466        } else {
4467            let shortfall = 1.0 - d_edge / (band_fraction * range);
4468            excess += shortfall * shortfall;
4469        }
4470    }
4471    (1.0 + alpha * excess).max(1.0)
4472}
4473
4474/// Per-row variance-inflation factor for an out-of-distribution prediction.
4475/// Returns `1 + γ · Σ_k (excess_k / range_k)²` where excess_k = max(0,
4476/// max(lo − x, x − hi)) per axis, range_k = hi − lo. Always ≥ 1; equal to
4477/// 1 when x is inside the bounding box on every axis.
4478pub(crate) fn ood_variance_inflation_factor(
4479    x_row: ArrayView1<'_, f64>,
4480    axis_min: ArrayView1<'_, f64>,
4481    axis_max: ArrayView1<'_, f64>,
4482    gamma: f64,
4483) -> f64 {
4484    let d = x_row.len();
4485    if d == 0 || axis_min.len() != d || axis_max.len() != d {
4486        return 1.0;
4487    }
4488    let mut sq_excess = 0.0_f64;
4489    for k in 0..d {
4490        let lo = axis_min[k];
4491        let hi = axis_max[k];
4492        let range = hi - lo;
4493        if !(range > 0.0) {
4494            continue;
4495        }
4496        let x = x_row[k];
4497        let excess = if x < lo {
4498            lo - x
4499        } else if x > hi {
4500            x - hi
4501        } else {
4502            0.0
4503        };
4504        let frac = excess / range;
4505        sq_excess += frac * frac;
4506    }
4507    (1.0 + gamma * sq_excess).max(1.0)
4508}
4509
4510/// Bonferroni-adjusted z multiplier for joint coverage of `m` query
4511/// rows at central level `level`. The per-row tail probability is
4512/// `(1 − level) / m` (split equally across both tails), giving a
4513/// per-row central level of `1 − (1 − level) / m`. Returns the
4514/// corresponding standard-normal quantile, or the un-adjusted z if
4515/// m ≤ 1 or inputs are degenerate.
4516pub(crate) fn multi_point_joint_z(level: f64, m: usize) -> Result<f64, String> {
4517    if m <= 1 || !(level.is_finite() && level > 0.0 && level < 1.0) {
4518        return standard_normal_quantile(0.5 + 0.5 * level);
4519    }
4520    let alpha = 1.0 - level;
4521    let per_row_alpha = alpha / (m as f64);
4522    let per_row_level = 1.0 - per_row_alpha;
4523    standard_normal_quantile(0.5 + 0.5 * per_row_level)
4524}
4525
4526#[derive(Clone, Copy, Debug, Eq, PartialEq)]
4527pub enum MeanIntervalMethod {
4528    /// Interval on mean scale from delta-method SEs.
4529    Delta,
4530    /// Transform eta interval endpoints through inverse link.
4531    /// This is usually better behaved for nonlinear links.
4532    TransformEta,
4533}
4534
4535pub struct PredictUncertaintyResult {
4536    pub eta: Array1<f64>,
4537    pub mean: Array1<f64>,
4538    pub eta_standard_error: Array1<f64>,
4539    pub mean_standard_error: Array1<f64>,
4540    pub eta_lower: Array1<f64>,
4541    pub eta_upper: Array1<f64>,
4542    pub mean_lower: Array1<f64>,
4543    pub mean_upper: Array1<f64>,
4544    /// Optional Gaussian observation interval bounds.
4545    pub observation_lower: Option<Array1<f64>>,
4546    pub observation_upper: Option<Array1<f64>>,
4547    /// Covariance mode requested by caller.
4548    pub covariance_mode_requested: InferenceCovarianceMode,
4549    /// True if smoothing-corrected covariance was used.
4550    pub covariance_corrected_used: bool,
4551}
4552
4553fn predict_gam_posterior_mean_from_backend(
4554    x: DesignMatrix,
4555    beta: ArrayView1<'_, f64>,
4556    offset: ArrayView1<'_, f64>,
4557    backend: &PredictionCovarianceBackend<'_>,
4558    strategy: &(dyn FamilyStrategy + Sync),
4559    label: &str,
4560) -> Result<PredictPosteriorMeanResult, EstimationError> {
4561    predict_gam_posterior_mean_from_backendwith_bc(x, beta, offset, backend, strategy, label, None)
4562}
4563
4564fn predict_gam_posterior_mean_from_backendwith_bc(
4565    x: DesignMatrix,
4566    beta: ArrayView1<'_, f64>,
4567    offset: ArrayView1<'_, f64>,
4568    backend: &PredictionCovarianceBackend<'_>,
4569    strategy: &(dyn FamilyStrategy + Sync),
4570    label: &str,
4571    bias_correction_beta: Option<ArrayView1<'_, f64>>,
4572) -> Result<PredictPosteriorMeanResult, EstimationError> {
4573    if x.ncols() != beta.len() {
4574        return Err(EstimationError::InvalidInput(format!(
4575            "{label} dimension mismatch: X has {} columns but beta has length {}",
4576            x.ncols(),
4577            beta.len()
4578        )));
4579    }
4580    if x.nrows() != offset.len() {
4581        return Err(EstimationError::InvalidInput(format!(
4582            "{label} dimension mismatch: X has {} rows but offset has length {}",
4583            x.nrows(),
4584            offset.len()
4585        )));
4586    }
4587    if backend.nrows() != beta.len() {
4588        return Err(EstimationError::InvalidInput(format!(
4589            "{label} covariance/backend dimension mismatch: expected parameter dimension {}, got {}",
4590            beta.len(),
4591            backend.nrows()
4592        )));
4593    }
4594
4595    let mut eta = x.matrixvectormultiply(&beta.to_owned());
4596    eta += &offset;
4597    if let Some(bc) = bias_correction_beta {
4598        if bc.len() != beta.len() {
4599            return Err(EstimationError::InvalidInput(format!(
4600                "{label} bias-correction dimension mismatch: beta has length {} but bias_correction_beta has length {}",
4601                beta.len(),
4602                bc.len()
4603            )));
4604        }
4605        let bc_owned = bc.to_owned();
4606        let delta = x.matrixvectormultiply(&bc_owned);
4607        eta += &delta;
4608    }
4609    let etavar = linear_predictorvariance_from_backend(&x, backend)?;
4610    let eta_standard_error = etavar.mapv(|v| v.max(0.0).sqrt());
4611    let quadctx = crate::quadrature::QuadratureContext::new();
4612    let means: Result<Vec<f64>, EstimationError> = (0..eta.len())
4613        .into_par_iter()
4614        .map(|i| strategy.posterior_mean(&quadctx, eta[i], eta_standard_error[i]))
4615        .collect();
4616
4617    Ok(PredictPosteriorMeanResult {
4618        eta,
4619        eta_standard_error,
4620        mean: Array1::from_vec(means?),
4621        mean_lower: None,
4622        mean_upper: None,
4623    })
4624}
4625
4626pub struct CoefficientUncertaintyResult {
4627    pub estimate: Array1<f64>,
4628    pub standard_error: Array1<f64>,
4629    pub lower: Array1<f64>,
4630    pub upper: Array1<f64>,
4631    pub corrected: bool,
4632    pub covariance_mode_requested: InferenceCovarianceMode,
4633}
4634
4635/// Generic engine prediction for external designs.
4636/// This API is domain-agnostic: callers provide only design matrix, coefficients, offset, and family.
4637///
4638/// For `RoystonParmar`, callers must supply the exit-side cumulative-hazard
4639/// design and offset so that `eta = log(H(t))`; the response-scale prediction is
4640/// the survival probability `exp(-exp(eta))`.
4641pub fn predict_gam<X>(
4642    x: X,
4643    beta: ArrayView1<'_, f64>,
4644    offset: ArrayView1<'_, f64>,
4645    family: crate::types::LikelihoodFamily,
4646) -> Result<PredictResult, EstimationError>
4647where
4648    X: Into<DesignMatrix>,
4649{
4650    let x = x.into();
4651    if let Some(message) =
4652        predict_gam_dimension_mismatch_message(x.nrows(), x.ncols(), beta.len(), offset.len())
4653    {
4654        return Err(EstimationError::InvalidInput(message));
4655    }
4656
4657    let mut eta = x.matrixvectormultiply(&beta.to_owned());
4658    eta += &offset;
4659
4660    let mean = apply_family_inverse_link(&eta, family, None)?;
4661
4662    Ok(PredictResult { eta, mean })
4663}
4664
4665/// Nonlinear posterior-mean prediction with coefficient uncertainty propagation.
4666///
4667/// For nonlinear links, returns E[g^{-1}(eta_tilde)] where eta_tilde ~ N(eta_hat, se_eta^2).
4668/// For Gaussian identity, this equals the standard plug-in mean.
4669pub fn predict_gam_posterior_mean<X>(
4670    x: X,
4671    beta: ArrayView1<'_, f64>,
4672    offset: ArrayView1<'_, f64>,
4673    family: crate::types::LikelihoodFamily,
4674    covariance: ArrayView2<'_, f64>,
4675) -> Result<PredictPosteriorMeanResult, EstimationError>
4676where
4677    X: Into<DesignMatrix>,
4678{
4679    let x = x.into();
4680    let backend = PredictionCovarianceBackend::from_dense(covariance.view());
4681    let strategy = strategy_for_family(family, None);
4682    predict_gam_posterior_mean_from_backend(
4683        x,
4684        beta,
4685        offset,
4686        &backend,
4687        &strategy,
4688        "predict_gam_posterior_mean",
4689    )
4690}
4691
4692pub fn predict_gam_posterior_meanwith_backend<X>(
4693    x: X,
4694    beta: ArrayView1<'_, f64>,
4695    offset: ArrayView1<'_, f64>,
4696    family: crate::types::LikelihoodFamily,
4697    backend: &PredictionCovarianceBackend<'_>,
4698) -> Result<PredictPosteriorMeanResult, EstimationError>
4699where
4700    X: Into<DesignMatrix>,
4701{
4702    let x = x.into();
4703    let strategy = strategy_for_family(family, None);
4704    predict_gam_posterior_mean_from_backend(
4705        x,
4706        beta,
4707        offset,
4708        backend,
4709        &strategy,
4710        "predict_gam_posterior_meanwith_backend",
4711    )
4712}
4713
4714/// Nonlinear posterior-mean prediction with link-state support for SAS/mixture families.
4715///
4716/// This mirrors `predict_gam_posterior_mean`, but also uses `fit` metadata for
4717/// link families that require extra state (`BinomialSas`, `BinomialMixture`).
4718pub fn predict_gam_posterior_meanwith_fit<X>(
4719    x: X,
4720    beta: ArrayView1<'_, f64>,
4721    offset: ArrayView1<'_, f64>,
4722    family: crate::types::LikelihoodFamily,
4723    covariance: ArrayView2<'_, f64>,
4724    fit: &UnifiedFitResult,
4725) -> Result<PredictPosteriorMeanResult, EstimationError>
4726where
4727    X: Into<DesignMatrix>,
4728{
4729    let x = x.into();
4730    let backend = PredictionCovarianceBackend::from_dense(covariance.view());
4731    let strategy = strategy_from_fit(family, fit)?;
4732    predict_gam_posterior_mean_from_backend(
4733        x,
4734        beta,
4735        offset,
4736        &backend,
4737        &strategy,
4738        "predict_gam_posterior_meanwith_fit",
4739    )
4740}
4741
4742/// Prediction with coefficient uncertainty propagation.
4743///
4744/// The linear predictor variance uses:
4745/// Var(η_i) = x_i^T Var(β) x_i
4746///
4747/// Mean-scale SEs are delta-method approximations:
4748/// Var(μ_i) ≈ (dμ/dη)^2 Var(η_i)
4749///
4750/// Math note (logit family, Gaussian η posterior):
4751///
4752/// If η_i | D ≈ N(m_i, v_i), then the exact posterior predictive mean on the
4753/// probability scale is the logistic-normal integral
4754///
4755///   E[sigmoid(η_i)] = ∫ sigmoid(x) N(x; m_i, v_i) dx.
4756///
4757/// This does not reduce to an elementary closed form. Two exact representations
4758/// often used in the literature are:
4759///
4760/// 1) Theta/Appell-Lerch style representations (via Poisson summation / Mordell integrals).
4761/// 2) Absolutely convergent complex-error-function (Faddeeva) series obtained from
4762///    partial-fraction expansions of tanh/logistic.
4763///
4764/// A practical exact series form is:
4765///
4766///   E[sigmoid(η)] = 1/2
4767///                   - (sqrt(2π)/σ) * Σ_{n>=1} Im[ w((i a_n - μ)/(sqrt(2)σ)) ],
4768///   where a_n = (2n-1)π, σ = sqrt(v), and w is the Faddeeva function
4769///   w(z) = exp(-z^2) erfc(-i z).
4770///
4771/// The formulas above define the exact logistic-normal target moments under
4772/// Gaussian η uncertainty.
4773///
4774/// CLogLog note (exact target):
4775/// If p = 1 - exp(-exp(η)) and η ~ N(μ,σ²), then
4776///   E[p] = 1 - I(1),  E[p²] = 1 - 2I(1) + I(2),  Var(p) = I(2) - I(1)²
4777/// where I(λ) = E[exp(-λ exp(η))] is the lognormal Laplace transform.
4778/// This identity is exact, and highlights that the moments are determined by
4779/// the lognormal Laplace transform values at λ=1 and λ=2.
4780///
4781/// Exact analytic representation (Mellin-Barnes) for I(λ):
4782///   I(λ) = (1/(2πi)) ∫_{c-i∞}^{c+i∞} Γ(z) λ^{-z} exp(-μ z + 0.5 σ² z²) dz, c>0.
4783/// This Mellin-Barnes integral is mathematically exact.
4784pub fn predict_gamwith_uncertainty<X>(
4785    x: X,
4786    beta: ArrayView1<'_, f64>,
4787    offset: ArrayView1<'_, f64>,
4788    family: crate::types::LikelihoodFamily,
4789    fit: &UnifiedFitResult,
4790    options: &PredictUncertaintyOptions,
4791) -> Result<PredictUncertaintyResult, EstimationError>
4792where
4793    X: Into<DesignMatrix>,
4794{
4795    let x = x.into();
4796    if x.ncols() != beta.len() {
4797        return Err(EstimationError::InvalidInput(format!(
4798            "predict_gamwith_uncertainty dimension mismatch: X has {} columns but beta has length {}",
4799            x.ncols(),
4800            beta.len()
4801        )));
4802    }
4803    if x.nrows() != offset.len() {
4804        return Err(EstimationError::InvalidInput(format!(
4805            "predict_gamwith_uncertainty dimension mismatch: X has {} rows but offset has length {}",
4806            x.nrows(),
4807            offset.len()
4808        )));
4809    }
4810    if !(options.confidence_level.is_finite()
4811        && options.confidence_level > 0.0
4812        && options.confidence_level < 1.0)
4813    {
4814        return Err(EstimationError::InvalidInput(format!(
4815            "confidence_level must be in (0,1), got {}",
4816            options.confidence_level
4817        )));
4818    }
4819
4820    let requested_mode = options.covariance_mode;
4821    let (backend, covariance_corrected_used) = selected_uncertainty_backend(
4822        fit,
4823        beta.len(),
4824        requested_mode,
4825        "predict_gamwith_uncertainty",
4826    )?;
4827
4828    let mut eta = x.matrixvectormultiply(&beta.to_owned());
4829    eta += &offset;
4830    if options.apply_bias_correction
4831        && let Some(bc) = fit.bias_correction_beta()
4832    {
4833        if bc.len() == beta.len() {
4834            let delta = x.matrixvectormultiply(&bc.clone());
4835            eta += &delta;
4836        } else {
4837            log::warn!(
4838                "predict_gamwith_uncertainty: bias-correction dimension mismatch \
4839                (beta {}, bc {}); skipping bias correction",
4840                beta.len(),
4841                bc.len()
4842            );
4843        }
4844    }
4845    let fitted_link_state = fit.fitted_link_state(family).ok();
4846    let mixture_state = match fitted_link_state.as_ref() {
4847        Some(FittedLinkState::Mixture { state, .. }) => Some(state.clone()),
4848        _ => None,
4849    };
4850    let sas_state = match fitted_link_state.as_ref() {
4851        Some(FittedLinkState::Sas { state, .. })
4852        | Some(FittedLinkState::BetaLogistic { state, .. }) => Some(*state),
4853        _ => None,
4854    };
4855    let link_kind = match fitted_link_state.as_ref() {
4856        Some(FittedLinkState::Standard(Some(link))) => Some(InverseLink::Standard(*link)),
4857        Some(FittedLinkState::LatentCLogLog { state }) => Some(InverseLink::LatentCLogLog(*state)),
4858        Some(FittedLinkState::Sas { state, .. }) => Some(InverseLink::Sas(*state)),
4859        Some(FittedLinkState::BetaLogistic { state, .. }) => {
4860            Some(InverseLink::BetaLogistic(*state))
4861        }
4862        Some(FittedLinkState::Mixture { state, .. }) => Some(InverseLink::Mixture(state.clone())),
4863        Some(FittedLinkState::Standard(None)) | None => None,
4864    };
4865    let strategy = strategy_for_family(family, link_kind.as_ref());
4866    let mean = apply_family_inverse_link(&eta, family, link_kind.as_ref())?;
4867
4868    let etavar_raw = linear_predictorvariance_from_backend(&x, &backend)?;
4869    let n_rows = etavar_raw.len();
4870
4871    // ── Coverage corrections ────────────────────────────────────────────
4872    // Variance inflation (boundary + OOD). Both are per-row multipliers
4873    // ≥ 1 applied to Var(η_i); they propagate through to eta_se and
4874    // observation intervals consistently.
4875    let mut variance_inflation = Array1::<f64>::ones(n_rows);
4876    if (options.boundary_correction || options.ood_inflation)
4877        && let (Some(predictor_x), Some(support)) = (
4878            options.predictor_x_for_corrections.as_ref(),
4879            options.training_support.as_ref(),
4880        )
4881        && predictor_x.nrows() == n_rows
4882        && predictor_x.ncols() == support.axis_min.len()
4883        && support.axis_min.len() == support.axis_max.len()
4884    {
4885        for i in 0..n_rows {
4886            let row = predictor_x.row(i);
4887            let mut factor = 1.0_f64;
4888            if options.boundary_correction {
4889                factor *= boundary_variance_inflation_factor(
4890                    row,
4891                    support.axis_min.view(),
4892                    support.axis_max.view(),
4893                    options.boundary_alpha,
4894                    options.boundary_band_fraction,
4895                );
4896            }
4897            if options.ood_inflation {
4898                factor *= ood_variance_inflation_factor(
4899                    row,
4900                    support.axis_min.view(),
4901                    support.axis_max.view(),
4902                    options.ood_gamma,
4903                );
4904            }
4905            variance_inflation[i] = factor;
4906        }
4907    }
4908    let etavar = if variance_inflation.iter().all(|&f| f == 1.0) {
4909        etavar_raw.clone()
4910    } else {
4911        Array1::from_iter(
4912            etavar_raw
4913                .iter()
4914                .zip(variance_inflation.iter())
4915                .map(|(&v, &f)| v * f),
4916        )
4917    };
4918    let eta_standard_error = etavar.mapv(|v| v.max(0.0).sqrt());
4919
4920    // Per-row z multipliers. Joint adjustment widens the central level
4921    // first; Edgeworth then optionally splits the lower/upper tails.
4922    let level = options.confidence_level;
4923    let z_central = if options.multi_point_joint {
4924        let m = options.joint_query_count.unwrap_or(n_rows).max(1);
4925        multi_point_joint_z(level, m).map_err(EstimationError::InvalidInput)?
4926    } else {
4927        standard_normal_quantile(0.5 + 0.5 * level).map_err(EstimationError::InvalidInput)?
4928    };
4929    let mut z_lower_per_row = Array1::<f64>::from_elem(n_rows, z_central);
4930    let mut z_upper_per_row = Array1::<f64>::from_elem(n_rows, z_central);
4931    if options.edgeworth_one_sided
4932        && let Some(skew) = options.eta_skewness_for_corrections.as_ref()
4933        && skew.len() == n_rows
4934    {
4935        for i in 0..n_rows {
4936            let adj = edgeworth_one_sided_quantile(z_central, skew[i]);
4937            z_lower_per_row[i] = adj.z_lower;
4938            z_upper_per_row[i] = adj.z_upper;
4939        }
4940    }
4941    let eta_lower = Array1::from_iter(
4942        eta.iter()
4943            .zip(eta_standard_error.iter())
4944            .zip(z_lower_per_row.iter())
4945            .map(|((&e, &s), &zl)| e - zl * s),
4946    );
4947    let eta_upper = Array1::from_iter(
4948        eta.iter()
4949            .zip(eta_standard_error.iter())
4950            .zip(z_upper_per_row.iter())
4951            .map(|((&e, &s), &zu)| e + zu * s),
4952    );
4953    let quadctx = crate::quadrature::QuadratureContext::new();
4954
4955    // Derivative of inverse link g^{-1}(η) used for delta-method:
4956    //   Var(μ_i) ≈ [d g^{-1}(η_i)/dη]^2 Var(η_i).
4957    //
4958    // For logit:
4959    //   g^{-1}(η)=sigmoid(η), dμ/dη=μ(1-μ).
4960    // If η itself is uncertain (η ~ N(m,v)), the exact predictive mean is
4961    // E[sigmoid(η)] (logistic-normal integral) as documented above.
4962    //
4963    // For cloglog:
4964    //   g^{-1}(η)=1-exp(-exp(η)), dμ/dη=exp(η)exp(-exp(η)).
4965    // With uncertain η the exact moments can be written via I(λ)=E[exp(-λexp(η))],
4966    // and:
4967    //   E[μ]   = 1 - I(1),
4968    //   E[μ²]  = 1 - 2I(1) + I(2),
4969    //   Var(μ) = I(2) - I(1)^2.
4970    // These identities characterize the exact cloglog moments under Gaussian η uncertainty.
4971    let mean_standard_error = Array1::from_vec(
4972        (0..eta.len())
4973            .into_par_iter()
4974            .map(|i| -> Result<f64, EstimationError> {
4975                let se_i = etavar[i].max(0.0).sqrt();
4976                let (_, mut meanvar) = strategy.posterior_meanvariance(&quadctx, eta[i], se_i)?;
4977                if matches!(family, crate::types::LikelihoodFamily::BinomialSas)
4978                    && let Some(cov_theta) = fitted_link_state.as_ref().and_then(|s| match s {
4979                        FittedLinkState::Sas { covariance, .. } => covariance.as_ref(),
4980                        _ => None,
4981                    })
4982                {
4983                    let sas = sas_state.ok_or_else(|| {
4984                        EstimationError::InvalidInput(
4985                            "BinomialSas uncertainty requires fitted sas_epsilon/sas_log_delta"
4986                                .to_string(),
4987                        )
4988                    })?;
4989                    let jets =
4990                        sas_inverse_link_jetwith_param_partials(eta[i], sas.epsilon, sas.log_delta);
4991                    let g = [jets.djet_depsilon.mu, jets.djet_dlog_delta.mu];
4992                    meanvar += quadratic_form(cov_theta, &g)?;
4993                }
4994                if matches!(family, crate::types::LikelihoodFamily::BinomialBetaLogistic)
4995                    && let Some(cov_theta) = fitted_link_state.as_ref().and_then(|s| match s {
4996                        FittedLinkState::BetaLogistic { covariance, .. } => covariance.as_ref(),
4997                        _ => None,
4998                    })
4999                {
5000                    let sas = sas_state.ok_or_else(|| {
5001                        EstimationError::InvalidInput(
5002                            "BinomialBetaLogistic uncertainty requires fitted parameters"
5003                                .to_string(),
5004                        )
5005                    })?;
5006                    let jets = beta_logistic_inverse_link_jetwith_param_partials(
5007                        eta[i],
5008                        sas.log_delta,
5009                        sas.epsilon,
5010                    );
5011                    let g = [jets.djet_depsilon.mu, jets.djet_dlog_delta.mu];
5012                    meanvar += quadratic_form(cov_theta, &g)?;
5013                }
5014                if matches!(family, crate::types::LikelihoodFamily::BinomialMixture)
5015                    && let Some(cov_theta) = fitted_link_state.as_ref().and_then(|s| match s {
5016                        FittedLinkState::Mixture { covariance, .. } => covariance.as_ref(),
5017                        _ => None,
5018                    })
5019                    && let Some(state) = mixture_state.as_ref()
5020                {
5021                    let mut mix_partials = vec![
5022                        InverseLinkJet {
5023                            mu: 0.0,
5024                            d1: 0.0,
5025                            d2: 0.0,
5026                            d3: 0.0,
5027                        };
5028                        state.rho.len()
5029                    ];
5030                    mixture_inverse_link_jetwith_rho_partials_into(
5031                        state,
5032                        eta[i],
5033                        &mut mix_partials,
5034                    );
5035                    meanvar += quadratic_form_from_jetmu(cov_theta, &mix_partials)?;
5036                }
5037                Ok(meanvar.max(0.0).sqrt())
5038            })
5039            .collect::<Result<Vec<_>, _>>()?,
5040    );
5041
5042    let (mut mean_lower, mut mean_upper) = match options.mean_interval_method {
5043        MeanIntervalMethod::Delta => (
5044            Array1::from_iter(
5045                mean.iter()
5046                    .zip(mean_standard_error.iter())
5047                    .zip(z_lower_per_row.iter())
5048                    .map(|((&m, &s), &zl)| m - zl * s),
5049            ),
5050            Array1::from_iter(
5051                mean.iter()
5052                    .zip(mean_standard_error.iter())
5053                    .zip(z_upper_per_row.iter())
5054                    .map(|((&m, &s), &zu)| m + zu * s),
5055            ),
5056        ),
5057        MeanIntervalMethod::TransformEta => {
5058            let transformed_lower =
5059                apply_family_inverse_link(&eta_lower, family, link_kind.as_ref())?;
5060            let transformed_upper =
5061                apply_family_inverse_link(&eta_upper, family, link_kind.as_ref())?;
5062            (
5063                Array1::from_iter(
5064                    transformed_lower
5065                        .iter()
5066                        .zip(transformed_upper.iter())
5067                        .map(|(&lo, &hi)| lo.min(hi)),
5068                ),
5069                Array1::from_iter(
5070                    transformed_lower
5071                        .iter()
5072                        .zip(transformed_upper.iter())
5073                        .map(|(&lo, &hi)| lo.max(hi)),
5074                ),
5075            )
5076        }
5077    };
5078
5079    if matches!(
5080        family,
5081        crate::types::LikelihoodFamily::BinomialLogit
5082            | crate::types::LikelihoodFamily::BinomialProbit
5083            | crate::types::LikelihoodFamily::BinomialCLogLog
5084            | crate::types::LikelihoodFamily::BinomialSas
5085            | crate::types::LikelihoodFamily::BinomialBetaLogistic
5086            | crate::types::LikelihoodFamily::BinomialMixture
5087            | crate::types::LikelihoodFamily::RoystonParmar
5088    ) {
5089        mean_lower.mapv_inplace(|v| v.clamp(0.0, 1.0));
5090        mean_upper.mapv_inplace(|v| v.clamp(0.0, 1.0));
5091    }
5092
5093    let (observation_lower, observation_upper) = if options.includeobservation_interval
5094        && matches!(family, crate::types::LikelihoodFamily::GaussianIdentity)
5095    {
5096        let obsvar = fit.standard_deviation.max(0.0).powi(2);
5097        let obs_se = etavar.mapv(|v| (v + obsvar).max(0.0).sqrt());
5098        let lower = Array1::from_iter(
5099            eta.iter()
5100                .zip(obs_se.iter())
5101                .zip(z_lower_per_row.iter())
5102                .map(|((&e, &s), &zl)| e - zl * s),
5103        );
5104        let upper = Array1::from_iter(
5105            eta.iter()
5106                .zip(obs_se.iter())
5107                .zip(z_upper_per_row.iter())
5108                .map(|((&e, &s), &zu)| e + zu * s),
5109        );
5110        (Some(lower), Some(upper))
5111    } else {
5112        (None, None)
5113    };
5114
5115    Ok(PredictUncertaintyResult {
5116        eta,
5117        mean,
5118        eta_standard_error,
5119        mean_standard_error,
5120        eta_lower,
5121        eta_upper,
5122        mean_lower,
5123        mean_upper,
5124        observation_lower,
5125        observation_upper,
5126        covariance_mode_requested: requested_mode,
5127        covariance_corrected_used,
5128    })
5129}
5130
5131/// Coefficient-level uncertainty and confidence intervals.
5132pub fn coefficient_uncertainty(
5133    fit: &UnifiedFitResult,
5134    confidence_level: f64,
5135    covariance_mode: InferenceCovarianceMode,
5136) -> Result<CoefficientUncertaintyResult, EstimationError> {
5137    coefficient_uncertaintywith_mode(fit, confidence_level, covariance_mode)
5138}
5139
5140/// Coefficient-level uncertainty and confidence intervals with explicit covariance mode.
5141pub fn coefficient_uncertaintywith_mode(
5142    fit: &UnifiedFitResult,
5143    confidence_level: f64,
5144    covariance_mode: InferenceCovarianceMode,
5145) -> Result<CoefficientUncertaintyResult, EstimationError> {
5146    if !(confidence_level.is_finite() && confidence_level > 0.0 && confidence_level < 1.0) {
5147        return Err(EstimationError::InvalidInput(format!(
5148            "confidence_level must be in (0,1), got {}",
5149            confidence_level
5150        )));
5151    }
5152    // Coefficient SEs are extracted from either:
5153    // - conditional covariance H^{-1}, or
5154    // - first-order corrected covariance H^{-1} + J V_rho J^T.
5155    let (se, corrected) = match covariance_mode {
5156        InferenceCovarianceMode::Conditional => (
5157            fit.beta_standard_errors().cloned().ok_or_else(|| {
5158                EstimationError::InvalidInput(
5159                    "fit result does not contain conditional coefficient standard errors"
5160                        .to_string(),
5161                )
5162            })?,
5163            false,
5164        ),
5165        InferenceCovarianceMode::ConditionalPlusSmoothingPreferred => {
5166            if let Some(se_corr) = fit.beta_standard_errors_corrected() {
5167                (se_corr.clone(), true)
5168            } else if let Some(se_base) = fit.beta_standard_errors() {
5169                (se_base.clone(), false)
5170            } else {
5171                return Err(EstimationError::InvalidInput(
5172                    "fit result does not contain coefficient standard errors".to_string(),
5173                ));
5174            }
5175        }
5176        InferenceCovarianceMode::ConditionalPlusSmoothingRequired => (
5177            fit.beta_standard_errors_corrected()
5178                .cloned()
5179                .ok_or_else(|| {
5180                    EstimationError::InvalidInput(
5181                        "fit result does not contain smoothing-corrected coefficient standard errors"
5182                            .to_string(),
5183                    )
5184                })?,
5185            true,
5186        ),
5187    };
5188
5189    if se.len() != fit.beta.len() {
5190        return Err(EstimationError::InvalidInput(format!(
5191            "standard error length mismatch: beta has {}, se has {}",
5192            fit.beta.len(),
5193            se.len()
5194        )));
5195    }
5196
5197    let z = standard_normal_quantile(0.5 + 0.5 * confidence_level)
5198        .map_err(EstimationError::InvalidInput)?;
5199    let lower = &fit.beta - &se.mapv(|s| z * s);
5200    let upper = &fit.beta + &se.mapv(|s| z * s);
5201    Ok(CoefficientUncertaintyResult {
5202        estimate: fit.beta.clone(),
5203        standard_error: se,
5204        lower,
5205        upper,
5206        corrected,
5207        covariance_mode_requested: covariance_mode,
5208    })
5209}
5210
5211#[cfg(test)]
5212mod tests {
5213    use super::*;
5214    use crate::estimate::{
5215        BlockRole, FitArtifacts, FittedBlock, FittedLinkState, UnifiedFitResult,
5216        UnifiedFitResultParts,
5217    };
5218    use crate::inference::model::SavedAnchoredDeviationRuntime;
5219    use crate::pirls::PirlsStatus;
5220    use crate::types::LinkFunction;
5221    use ndarray::{Array1, Array2, array};
5222
5223    fn saved_runtime_from_deviation_runtime(
5224        runtime: &crate::families::bernoulli_marginal_slope::DeviationRuntime,
5225    ) -> SavedAnchoredDeviationRuntime {
5226        SavedAnchoredDeviationRuntime {
5227            kernel:
5228                crate::families::bernoulli_marginal_slope::exact_kernel::ANCHORED_DEVIATION_KERNEL
5229                    .to_string(),
5230            breakpoints: runtime.breakpoints().to_vec(),
5231            basis_dim: runtime.basis_dim(),
5232            span_c0: runtime
5233                .span_c0()
5234                .outer_iter()
5235                .map(|row| row.to_vec())
5236                .collect(),
5237            span_c1: runtime
5238                .span_c1()
5239                .outer_iter()
5240                .map(|row| row.to_vec())
5241                .collect(),
5242            span_c2: runtime
5243                .span_c2()
5244                .outer_iter()
5245                .map(|row| row.to_vec())
5246                .collect(),
5247            span_c3: runtime
5248                .span_c3()
5249                .outer_iter()
5250                .map(|row| row.to_vec())
5251                .collect(),
5252            anchor_residual_coefficients: None,
5253            anchor_residual_components: Vec::new(),
5254            anchor_residual_rotation: None,
5255        }
5256    }
5257
5258    fn test_fit_with_covariance(beta: Array1<f64>, covariance: Array2<f64>) -> UnifiedFitResult {
5259        UnifiedFitResult::try_from_parts(UnifiedFitResultParts {
5260            blocks: vec![FittedBlock {
5261                beta: beta.clone(),
5262                role: BlockRole::Mean,
5263                edf: 0.0,
5264                lambdas: Array1::zeros(0),
5265            }],
5266            log_lambdas: Array1::zeros(0),
5267            lambdas: Array1::zeros(0),
5268            likelihood_family: Some(crate::types::LikelihoodFamily::GaussianIdentity),
5269            likelihood_scale: crate::types::LikelihoodScaleMetadata::ProfiledGaussian,
5270            log_likelihood_normalization: crate::types::LogLikelihoodNormalization::Full,
5271            log_likelihood: 0.0,
5272            deviance: 0.0,
5273            reml_score: 0.0,
5274            stable_penalty_term: 0.0,
5275            penalized_objective: 0.0,
5276            outer_iterations: 0,
5277            outer_converged: true,
5278            outer_gradient_norm: 0.0,
5279            standard_deviation: 1.0,
5280            covariance_conditional: Some(covariance),
5281            covariance_corrected: None,
5282            inference: None,
5283            fitted_link: FittedLinkState::Standard(None),
5284            geometry: None,
5285            block_states: Vec::new(),
5286            pirls_status: PirlsStatus::Converged,
5287            max_abs_eta: 0.0,
5288            constraint_kkt: None,
5289            artifacts: FitArtifacts {
5290                pirls: None,
5291                ..Default::default()
5292            },
5293            inner_cycles: 0,
5294        })
5295        .expect("test fit")
5296    }
5297
5298    fn gaussian_location_scale_fit_with_covariance(
5299        beta_mu: Array1<f64>,
5300        beta_noise: Array1<f64>,
5301        covariance: Array2<f64>,
5302    ) -> UnifiedFitResult {
5303        UnifiedFitResult::try_from_parts(UnifiedFitResultParts {
5304            blocks: vec![
5305                FittedBlock {
5306                    beta: beta_mu,
5307                    role: BlockRole::Location,
5308                    edf: 0.0,
5309                    lambdas: Array1::zeros(0),
5310                },
5311                FittedBlock {
5312                    beta: beta_noise,
5313                    role: BlockRole::Scale,
5314                    edf: 0.0,
5315                    lambdas: Array1::zeros(0),
5316                },
5317            ],
5318            log_lambdas: Array1::zeros(0),
5319            lambdas: Array1::zeros(0),
5320            likelihood_family: Some(crate::types::LikelihoodFamily::GaussianIdentity),
5321            likelihood_scale: crate::types::LikelihoodScaleMetadata::ProfiledGaussian,
5322            log_likelihood_normalization: crate::types::LogLikelihoodNormalization::Full,
5323            log_likelihood: 0.0,
5324            deviance: 0.0,
5325            reml_score: 0.0,
5326            stable_penalty_term: 0.0,
5327            penalized_objective: 0.0,
5328            outer_iterations: 0,
5329            outer_converged: true,
5330            outer_gradient_norm: 0.0,
5331            standard_deviation: 1.0,
5332            covariance_conditional: Some(covariance),
5333            covariance_corrected: None,
5334            inference: None,
5335            fitted_link: FittedLinkState::Standard(None),
5336            geometry: None,
5337            block_states: Vec::new(),
5338            pirls_status: PirlsStatus::Converged,
5339            max_abs_eta: 0.0,
5340            constraint_kkt: None,
5341            artifacts: FitArtifacts {
5342                pirls: None,
5343                ..Default::default()
5344            },
5345            inner_cycles: 0,
5346        })
5347        .expect("gaussian location-scale fit")
5348    }
5349
5350    fn survival_fit_with_covariance(
5351        beta_threshold: Array1<f64>,
5352        beta_log_sigma: Array1<f64>,
5353        covariance: Array2<f64>,
5354    ) -> UnifiedFitResult {
5355        UnifiedFitResult::try_from_parts(UnifiedFitResultParts {
5356            blocks: vec![
5357                FittedBlock {
5358                    beta: beta_threshold,
5359                    role: BlockRole::Threshold,
5360                    edf: 0.0,
5361                    lambdas: Array1::zeros(0),
5362                },
5363                FittedBlock {
5364                    beta: beta_log_sigma,
5365                    role: BlockRole::Scale,
5366                    edf: 0.0,
5367                    lambdas: Array1::zeros(0),
5368                },
5369            ],
5370            log_lambdas: Array1::zeros(0),
5371            lambdas: Array1::zeros(0),
5372            likelihood_family: Some(crate::types::LikelihoodFamily::RoystonParmar),
5373            likelihood_scale: crate::types::LikelihoodScaleMetadata::FixedDispersion { phi: 1.0 },
5374            log_likelihood_normalization: crate::types::LogLikelihoodNormalization::Full,
5375            log_likelihood: 0.0,
5376            deviance: 0.0,
5377            reml_score: 0.0,
5378            stable_penalty_term: 0.0,
5379            penalized_objective: 0.0,
5380            outer_iterations: 0,
5381            outer_converged: true,
5382            outer_gradient_norm: 0.0,
5383            standard_deviation: 1.0,
5384            covariance_conditional: Some(covariance),
5385            covariance_corrected: None,
5386            inference: None,
5387            fitted_link: FittedLinkState::Standard(None),
5388            geometry: None,
5389            block_states: Vec::new(),
5390            pirls_status: PirlsStatus::Converged,
5391            max_abs_eta: 0.0,
5392            constraint_kkt: None,
5393            artifacts: FitArtifacts {
5394                pirls: None,
5395                ..Default::default()
5396            },
5397            inner_cycles: 0,
5398        })
5399        .expect("survival fit")
5400    }
5401
5402    #[test]
5403    fn predict_posterior_mean_probit_matches_closed_form_reference() {
5404        let x = array![[1.0], [1.0]];
5405        let beta = array![0.7];
5406        let offset = array![0.0, 0.0];
5407        let covariance = Array2::from_diag(&array![0.25]);
5408        let out = predict_gam_posterior_mean(
5409            x,
5410            beta.view(),
5411            offset.view(),
5412            crate::types::LikelihoodFamily::BinomialProbit,
5413            covariance.view(),
5414        )
5415        .expect("predict posterior mean");
5416        let expected = crate::quadrature::probit_posterior_meanwith_deriv_exact(0.7, 0.5).mean;
5417        assert!((out.mean[0] - expected).abs() <= 1e-12);
5418        assert!((out.mean[1] - expected).abs() <= 1e-12);
5419    }
5420
5421    #[test]
5422    fn predict_posterior_mean_logit_uses_integrated_dispatch() {
5423        let x = array![[1.0], [1.0]];
5424        let beta = array![0.4];
5425        let offset = array![0.0, 0.0];
5426        let covariance = Array2::from_diag(&array![0.16]);
5427        let out = predict_gam_posterior_mean(
5428            x,
5429            beta.view(),
5430            offset.view(),
5431            crate::types::LikelihoodFamily::BinomialLogit,
5432            covariance.view(),
5433        )
5434        .expect("predict posterior mean");
5435        let quadctx = crate::quadrature::QuadratureContext::new();
5436        let expected = crate::quadrature::integrated_inverse_link_mean_and_derivative(
5437            &quadctx,
5438            LinkFunction::Logit,
5439            0.4,
5440            0.4,
5441        )
5442        .expect("logit integrated inverse-link moments should evaluate")
5443        .mean;
5444        assert!((out.mean[0] - expected).abs() <= 1e-12);
5445        assert!((out.mean[1] - expected).abs() <= 1e-12);
5446    }
5447
5448    #[test]
5449    fn bernoulli_marginal_slope_predictor_rejects_structurally_invalid_or_unknown_runtime_kernel() {
5450        let seed = array![-1.5, -0.2, 0.6, 1.4];
5451        let prepared =
5452            crate::families::bernoulli_marginal_slope::build_score_warp_deviation_block_from_seed(
5453                &seed,
5454                &crate::families::bernoulli_marginal_slope::DeviationBlockConfig {
5455                    degree: 3,
5456                    num_internal_knots: 3,
5457                    ..Default::default()
5458                },
5459            )
5460            .expect("production score-warp runtime");
5461        let production_runtime = saved_runtime_from_deviation_runtime(&prepared.runtime);
5462        let score_only = BernoulliMarginalSlopePredictor {
5463            beta_marginal: array![0.8],
5464            beta_logslope: array![1.6],
5465            beta_score_warp: Some(array![0.7, -0.4]),
5466            beta_link_dev: None,
5467            base_link: InverseLink::Standard(crate::types::LinkFunction::Probit),
5468            z_column: "z".to_string(),
5469            latent_z_normalization: SavedLatentZNormalization { mean: 0.0, sd: 1.0 },
5470            latent_measure: LatentMeasureKind::StandardNormal,
5471            baseline_marginal: 0.0,
5472            baseline_logslope: 0.0,
5473            covariance: None,
5474            score_warp_runtime: Some(SavedAnchoredDeviationRuntime {
5475                kernel: "OldQuadrature".to_string(),
5476                ..production_runtime.clone()
5477            }),
5478            // existing field-init order (link_deviation_runtime is the next).
5479            link_deviation_runtime: None,
5480            gaussian_frailty_sd: None,
5481            latent_z_calibration: None,
5482        };
5483        let err = score_only
5484            .score_warp_runtime
5485            .as_ref()
5486            .unwrap()
5487            .design(&array![0.0])
5488            .unwrap_err();
5489        assert!(err.contains("DenestedCubicTransport"));
5490
5491        let err =
5492            crate::families::bernoulli_marginal_slope::build_score_warp_deviation_block_from_seed(
5493                &seed,
5494                &crate::families::bernoulli_marginal_slope::DeviationBlockConfig {
5495                    degree: 2,
5496                    num_internal_knots: 3,
5497                    ..Default::default()
5498                },
5499            )
5500            .expect_err("non-cubic deviation runtimes should be rejected");
5501        assert!(err.contains("degree must be 3"));
5502
5503        let mut structurally_invalid = production_runtime.clone();
5504        structurally_invalid.span_c0[0].pop();
5505        let err = structurally_invalid.design(&array![0.0]).unwrap_err();
5506        assert!(err.contains("c0 row 0 has width"));
5507
5508        let cubic = production_runtime;
5509        assert!(cubic.design(&array![0.0]).is_ok());
5510    }
5511
5512    #[test]
5513    fn saved_anchored_deviation_runtime_local_cubic_reconstructs_values() {
5514        let seed = array![-2.0, -0.75, 0.0, 1.0, 3.0];
5515        let prepared =
5516            crate::families::bernoulli_marginal_slope::build_score_warp_deviation_block_from_seed(
5517                &seed,
5518                &crate::families::bernoulli_marginal_slope::DeviationBlockConfig {
5519                    num_internal_knots: 4,
5520                    ..Default::default()
5521                },
5522            )
5523            .expect("build saved anchored deviation runtime");
5524        let runtime = saved_runtime_from_deviation_runtime(&prepared.runtime);
5525        let beta = Array1::from_iter(
5526            (0..runtime.basis_dim)
5527                .map(|idx| 0.02 * (idx as f64 + 1.0) * (-1.0_f64).powi(idx as i32)),
5528        );
5529        let n_spans = runtime.span_count().expect("span count");
5530        assert!(n_spans >= 2);
5531        for span_idx in 0..n_spans {
5532            let cubic = runtime
5533                .local_cubic_on_span(&beta, span_idx)
5534                .expect("local cubic");
5535            let x_eval = array![cubic.left, 0.5 * (cubic.left + cubic.right), cubic.right];
5536            let expected = runtime.design(&x_eval).expect("design").dot(&beta);
5537            let expected_d1 = runtime
5538                .first_derivative_design(&x_eval)
5539                .expect("d1 design")
5540                .dot(&beta);
5541            for i in 0..x_eval.len() {
5542                let x = x_eval[i];
5543                assert!((cubic.evaluate(x) - expected[i]).abs() < 1e-10);
5544                assert!((cubic.first_derivative(x) - expected_d1[i]).abs() < 1e-10);
5545                let selected = runtime.local_cubic_at(&beta, x).expect("local cubic at x");
5546                let expected_span_idx = if i == 0 && span_idx > 0 {
5547                    span_idx - 1
5548                } else {
5549                    span_idx
5550                };
5551                let expected_cubic = runtime
5552                    .local_cubic_on_span(&beta, expected_span_idx)
5553                    .expect("expected local cubic on span");
5554                assert_eq!(selected.left, expected_cubic.left);
5555                assert_eq!(selected.right, expected_cubic.right);
5556            }
5557        }
5558    }
5559
5560    #[test]
5561    fn saved_anchored_deviation_runtime_design_with_anchor_rows_applies_residual() {
5562        use crate::families::bernoulli_marginal_slope::deviation_runtime::ParametricAnchorBlock;
5563        use crate::inference::model::{SavedAnchorComponent, SavedAnchorKind};
5564
5565        let seed = array![-2.0, -0.75, 0.0, 1.0, 3.0];
5566        let prepared =
5567            crate::families::bernoulli_marginal_slope::build_score_warp_deviation_block_from_seed(
5568                &seed,
5569                &crate::families::bernoulli_marginal_slope::DeviationBlockConfig {
5570                    num_internal_knots: 4,
5571                    ..Default::default()
5572                },
5573            )
5574            .expect("build saved anchored deviation runtime");
5575        let mut runtime = saved_runtime_from_deviation_runtime(&prepared.runtime);
5576
5577        // Inject a non-trivial anchor residual: d = 3 anchor cols,
5578        // M = arbitrary 3 × basis_dim matrix, identity rotation.
5579        let d = 3usize;
5580        let m: Vec<Vec<f64>> = (0..d)
5581            .map(|i| {
5582                (0..runtime.basis_dim)
5583                    .map(|j| 0.1 * (i as f64 + 1.0) - 0.05 * (j as f64 + 1.0))
5584                    .collect()
5585            })
5586            .collect();
5587        runtime.anchor_residual_coefficients = Some(m.clone());
5588        runtime.anchor_residual_components = vec![SavedAnchorComponent {
5589            kind: SavedAnchorKind::Parametric {
5590                block: ParametricAnchorBlock::Marginal,
5591                ncols: d,
5592            },
5593        }];
5594        runtime.anchor_residual_rotation = None;
5595
5596        let values = array![-1.0, 0.0, 0.5, 2.0];
5597        let n = values.len();
5598        let anchor_rows = Array2::from_shape_fn((n, d), |(i, j)| {
5599            0.3 * (i as f64 + 1.0) - 0.1 * (j as f64 + 1.0)
5600        });
5601
5602        let raw = runtime
5603            .design_uncorrected(&values)
5604            .expect("uncorrected design");
5605        let corrected = runtime
5606            .design_with_anchor_rows(&values, anchor_rows.view())
5607            .expect("design with anchor rows");
5608
5609        // Manually compute expected: raw - anchor_rows · M
5610        let mut m_dense = Array2::<f64>::zeros((d, runtime.basis_dim));
5611        for (i, row) in m.iter().enumerate() {
5612            for (j, &v) in row.iter().enumerate() {
5613                m_dense[[i, j]] = v;
5614            }
5615        }
5616        let expected = &raw - &anchor_rows.dot(&m_dense);
5617
5618        for i in 0..n {
5619            for j in 0..runtime.basis_dim {
5620                assert!(
5621                    (corrected[[i, j]] - expected[[i, j]]).abs() < 1e-12,
5622                    "residual-corrected design mismatch at ({i}, {j}): \
5623                     got {got}, expected {exp}",
5624                    got = corrected[[i, j]],
5625                    exp = expected[[i, j]],
5626                );
5627            }
5628        }
5629
5630        // anchor_correction_matrix should produce N · M (n × basis_dim) so
5631        // that raw - correction == corrected, row by row.
5632        let correction = runtime
5633            .anchor_correction_matrix(anchor_rows.view())
5634            .expect("anchor correction matrix")
5635            .expect("Some correction when residual is present");
5636        for i in 0..n {
5637            for j in 0..runtime.basis_dim {
5638                assert!((raw[[i, j]] - correction[[i, j]] - corrected[[i, j]]).abs() < 1e-12,);
5639            }
5640        }
5641    }
5642
5643    #[test]
5644    fn bernoulli_marginal_slope_rigid_gaussian_frailty_uses_scaled_closed_form() {
5645        let predictor = BernoulliMarginalSlopePredictor {
5646            beta_marginal: array![0.7],
5647            beta_logslope: array![-0.4],
5648            beta_score_warp: None,
5649            beta_link_dev: None,
5650            base_link: InverseLink::Standard(crate::types::LinkFunction::Probit),
5651            z_column: "z".to_string(),
5652            latent_z_normalization: SavedLatentZNormalization { mean: 0.0, sd: 1.0 },
5653            latent_measure: LatentMeasureKind::StandardNormal,
5654            baseline_marginal: 0.1,
5655            baseline_logslope: -0.2,
5656            covariance: None,
5657            score_warp_runtime: None,
5658            link_deviation_runtime: None,
5659            gaussian_frailty_sd: Some(0.8),
5660            latent_z_calibration: None,
5661        };
5662        let theta = predictor.theta();
5663        let input = PredictInput {
5664            design: DesignMatrix::from(array![[1.0], [1.0]]),
5665            offset: array![0.0, 0.05],
5666            design_noise: Some(DesignMatrix::from(array![[1.0], [1.0]])),
5667            offset_noise: Some(array![0.0, -0.1]),
5668            auxiliary_scalar: Some(array![-0.3, 1.2]),
5669            auxiliary_matrix: None,
5670        };
5671
5672        let (eta, grad) = predictor
5673            .final_eta_and_gradient_from_theta(&input, &theta, true)
5674            .expect("rigid frailty path should evaluate");
5675
5676        let scale = predictor.probit_frailty_scale();
5677        let marginal_eta = array![0.8, 0.85];
5678        let logslope_eta = array![-0.6, -0.7];
5679        let z = array![-0.3, 1.2];
5680        for i in 0..eta.len() {
5681            let sb = scale * logslope_eta[i];
5682            let c = (1.0 + sb * sb).sqrt();
5683            let expected_eta = marginal_eta[i] * c + sb * z[i];
5684            assert!((eta[i] - expected_eta).abs() <= 1e-12);
5685            let expected_d_marginal = c;
5686            let expected_d_logslope =
5687                marginal_eta[i] * scale * scale * logslope_eta[i] / c + scale * z[i];
5688            let grad = grad.as_ref().expect("gradient should be returned");
5689            assert!((grad[[i, 0]] - expected_d_marginal).abs() <= 1e-12);
5690            assert!((grad[[i, 1]] - expected_d_logslope).abs() <= 1e-12);
5691        }
5692    }
5693
5694    #[test]
5695    fn bernoulli_marginal_slope_predictor_uses_local_empirical_latent_law() {
5696        let grids = vec![
5697            EmpiricalZGrid {
5698                nodes: vec![-1.2, -0.2, 0.7],
5699                weights: vec![0.45, 0.35, 0.20],
5700            },
5701            EmpiricalZGrid {
5702                nodes: vec![-0.4, 0.6, 2.4],
5703                weights: vec![0.20, 0.35, 0.45],
5704            },
5705        ];
5706        let predictor = BernoulliMarginalSlopePredictor {
5707            beta_marginal: array![0.2],
5708            beta_logslope: array![0.9],
5709            beta_score_warp: None,
5710            beta_link_dev: None,
5711            base_link: InverseLink::Standard(crate::types::LinkFunction::Probit),
5712            z_column: "z".to_string(),
5713            latent_z_normalization: SavedLatentZNormalization { mean: 0.0, sd: 1.0 },
5714            latent_measure: LatentMeasureKind::LocalEmpirical {
5715                feature_cols: vec![0],
5716                input_scales: None,
5717                centers: vec![vec![-1.0], vec![1.0]],
5718                grids: grids.clone(),
5719                top_k: 1,
5720                bandwidth: 0.25,
5721                train_row_mixtures: std::sync::Arc::new(Vec::new()),
5722            },
5723            baseline_marginal: 0.0,
5724            baseline_logslope: 0.0,
5725            covariance: None,
5726            score_warp_runtime: None,
5727            link_deviation_runtime: None,
5728            gaussian_frailty_sd: None,
5729            latent_z_calibration: None,
5730        };
5731        let input = PredictInput {
5732            design: DesignMatrix::from(array![[1.0], [1.0]]),
5733            offset: array![0.0, 0.0],
5734            design_noise: Some(DesignMatrix::from(array![[1.0], [1.0]])),
5735            offset_noise: Some(array![0.0, 0.0]),
5736            auxiliary_scalar: Some(array![0.0, 0.0]),
5737            auxiliary_matrix: Some(array![[-1.0], [1.0]]),
5738        };
5739
5740        let (eta, _) = predictor
5741            .final_eta_and_gradient_from_theta(&input, &predictor.theta(), true)
5742            .expect("local empirical prediction");
5743        let (chain_eta, deta_dq) = predictor
5744            .predict_eta_and_q_chain(&input)
5745            .expect("local empirical q chain");
5746
5747        for (row, grid) in grids.iter().enumerate() {
5748            let expected_intercept = empirical_intercept_from_marginal(
5749                normal_cdf(0.2),
5750                0.2,
5751                0.9,
5752                1.0,
5753                &grid.nodes,
5754                &grid.weights,
5755                None,
5756            )
5757            .expect("expected empirical intercept");
5758            assert!((eta[row] - expected_intercept).abs() <= 1e-10);
5759            assert!((chain_eta[row] - eta[row]).abs() <= 1e-12);
5760            assert!(deta_dq[row].is_finite() && deta_dq[row] > 0.0);
5761        }
5762    }
5763
5764    #[test]
5765    fn bernoulli_marginal_slope_predictor_rejects_nonprobit_base_link_scale() {
5766        let predictor = BernoulliMarginalSlopePredictor {
5767            beta_marginal: array![0.7],
5768            beta_logslope: array![-0.4],
5769            beta_score_warp: None,
5770            beta_link_dev: None,
5771            base_link: InverseLink::Standard(crate::types::LinkFunction::Logit),
5772            z_column: "z".to_string(),
5773            latent_z_normalization: SavedLatentZNormalization { mean: 0.0, sd: 1.0 },
5774            latent_measure: LatentMeasureKind::StandardNormal,
5775            baseline_marginal: 0.1,
5776            baseline_logslope: -0.2,
5777            covariance: None,
5778            score_warp_runtime: None,
5779            link_deviation_runtime: None,
5780            gaussian_frailty_sd: Some(0.8),
5781            latent_z_calibration: None,
5782        };
5783        let theta = predictor.theta();
5784        let input = PredictInput {
5785            design: DesignMatrix::from(array![[1.0], [1.0]]),
5786            offset: array![0.0, 0.05],
5787            design_noise: Some(DesignMatrix::from(array![[1.0], [1.0]])),
5788            offset_noise: Some(array![0.0, -0.1]),
5789            auxiliary_scalar: Some(array![-0.3, 1.2]),
5790            auxiliary_matrix: None,
5791        };
5792
5793        let err = predictor
5794            .final_eta_and_gradient_from_theta(&input, &theta, true)
5795            .expect_err("non-probit marginal-slope prediction should be rejected");
5796        assert!(err.to_string().contains("requires link(type=probit)"));
5797    }
5798
5799    #[test]
5800    fn saved_anchored_deviation_runtime_basis_cubic_matches_basis_column() {
5801        let seed = array![-2.0, -0.75, 0.0, 1.0, 3.0];
5802        let prepared =
5803            crate::families::bernoulli_marginal_slope::build_score_warp_deviation_block_from_seed(
5804                &seed,
5805                &crate::families::bernoulli_marginal_slope::DeviationBlockConfig {
5806                    num_internal_knots: 4,
5807                    ..Default::default()
5808                },
5809            )
5810            .expect("build saved anchored deviation runtime");
5811        let runtime = saved_runtime_from_deviation_runtime(&prepared.runtime);
5812        let cubic = runtime.basis_span_cubic(0, 1).expect("basis span cubic");
5813        let x_eval = array![cubic.left, 0.5 * (cubic.left + cubic.right), cubic.right];
5814        let design = runtime.design(&x_eval).expect("basis design");
5815        let d1 = runtime
5816            .first_derivative_design(&x_eval)
5817            .expect("basis d1 design");
5818        for i in 0..x_eval.len() {
5819            let x = x_eval[i];
5820            assert!((cubic.evaluate(x) - design[[i, 1]]).abs() < 1e-10);
5821            assert!((cubic.first_derivative(x) - d1[[i, 1]]).abs() < 1e-10);
5822            let selected = runtime.basis_cubic_at(1, x).expect("basis cubic at x");
5823            let expected_span_idx = 0;
5824            let expected_cubic = runtime
5825                .basis_span_cubic(expected_span_idx, 1)
5826                .expect("expected basis span cubic");
5827            assert_eq!(selected.left, expected_cubic.left);
5828            assert_eq!(selected.right, expected_cubic.right);
5829        }
5830    }
5831
5832    #[test]
5833    fn predict_royston_parmar_point_prediction_returns_survival_probability() {
5834        let x = array![[1.0], [1.0]];
5835        let beta = array![0.4];
5836        let offset = array![0.0, 0.8];
5837        let out = predict_gam(
5838            x,
5839            beta.view(),
5840            offset.view(),
5841            crate::types::LikelihoodFamily::RoystonParmar,
5842        )
5843        .expect("royston-parmar point prediction");
5844        let expected_eta = array![0.4, 1.2];
5845        let expected_mean = expected_eta.mapv(|eta: f64| (-(eta.exp())).exp().clamp(0.0, 1.0));
5846        // Approximate comparison: delta-regularization bias can introduce ~1e-15 drift
5847        for i in 0..out.eta.len() {
5848            assert!(
5849                (out.eta[i] - expected_eta[i]).abs() <= 1e-14,
5850                "eta[{i}] mismatch"
5851            );
5852        }
5853        for i in 0..out.mean.len() {
5854            assert!((out.mean[i] - expected_mean[i]).abs() <= 1e-12);
5855        }
5856    }
5857
5858    #[test]
5859    fn predict_royston_parmar_posterior_mean_matches_quadrature_and_fit_path() {
5860        let x = array![[1.0], [1.0]];
5861        let beta = array![0.35];
5862        let offset = array![0.0, 0.0];
5863        let covariance = Array2::from_diag(&array![0.09]);
5864        let fit = test_fit_with_covariance(beta.clone(), covariance.clone());
5865
5866        let out = predict_gam_posterior_mean(
5867            x.clone(),
5868            beta.view(),
5869            offset.view(),
5870            crate::types::LikelihoodFamily::RoystonParmar,
5871            covariance.view(),
5872        )
5873        .expect("royston-parmar posterior mean");
5874        let out_with_fit = predict_gam_posterior_meanwith_fit(
5875            x,
5876            beta.view(),
5877            offset.view(),
5878            crate::types::LikelihoodFamily::RoystonParmar,
5879            covariance.view(),
5880            &fit,
5881        )
5882        .expect("royston-parmar posterior mean with fit");
5883
5884        let quadctx = crate::quadrature::QuadratureContext::new();
5885        let expected = crate::quadrature::survival_posterior_mean(&quadctx, 0.35, 0.3);
5886        for i in 0..out.mean.len() {
5887            assert!((out.mean[i] - expected).abs() <= 1e-12);
5888            assert!((out_with_fit.mean[i] - expected).abs() <= 1e-12);
5889            assert!((out_with_fit.mean[i] - out.mean[i]).abs() <= 1e-12);
5890            assert!(
5891                (out_with_fit.eta_standard_error[i] - out.eta_standard_error[i]).abs() <= 1e-12
5892            );
5893        }
5894    }
5895
5896    #[test]
5897    fn predict_royston_parmar_uncertainty_clamps_and_orders_intervals() {
5898        let x = array![[1.0]];
5899        let beta = array![0.6];
5900        let offset = array![0.0];
5901        let covariance = Array2::from_diag(&array![0.25]);
5902        let fit = test_fit_with_covariance(beta.clone(), covariance);
5903        let options = PredictUncertaintyOptions {
5904            confidence_level: 0.95,
5905            covariance_mode: InferenceCovarianceMode::Conditional,
5906            mean_interval_method: MeanIntervalMethod::TransformEta,
5907            includeobservation_interval: false,
5908            apply_bias_correction: false,
5909            // Coverage corrections off so the test asserts the legacy
5910            // unadjusted interval semantics.
5911            edgeworth_one_sided: false,
5912            boundary_correction: false,
5913            ood_inflation: false,
5914            multi_point_joint: false,
5915            ..PredictUncertaintyOptions::default()
5916        };
5917
5918        let out = predict_gamwith_uncertainty(
5919            x,
5920            beta.view(),
5921            offset.view(),
5922            crate::types::LikelihoodFamily::RoystonParmar,
5923            &fit,
5924            &options,
5925        )
5926        .expect("royston-parmar uncertainty");
5927
5928        let quadctx = crate::quadrature::QuadratureContext::new();
5929        let (_, variance) = crate::quadrature::survival_posterior_meanvariance(&quadctx, 0.6, 0.5);
5930        assert!((out.mean[0] - (-(0.6_f64.exp())).exp()).abs() <= 1e-12);
5931        assert!((out.eta_standard_error[0] - 0.5).abs() <= 1e-12);
5932        assert!((out.mean_standard_error[0] - variance.sqrt()).abs() <= 1e-12);
5933        assert!(out.mean_lower[0] <= out.mean_upper[0]);
5934        assert!((0.0..=1.0).contains(&out.mean_lower[0]));
5935        assert!((0.0..=1.0).contains(&out.mean_upper[0]));
5936    }
5937
5938    #[test]
5939    fn gaussian_location_scale_sigma_includes_noise_offset() {
5940        let predictor = GaussianLocationScalePredictor {
5941            beta_mu: array![0.0],
5942            beta_noise: array![0.0],
5943            response_scale: 2.0,
5944            covariance: None,
5945            link_wiggle: None,
5946        };
5947        let input = PredictInput {
5948            design: DesignMatrix::from(array![[1.0], [1.0]]),
5949            offset: array![0.0, 0.0],
5950            design_noise: Some(DesignMatrix::from(array![[1.0], [1.0]])),
5951            offset_noise: Some(array![(3.0f64).ln(), (5.0f64).ln()]),
5952            auxiliary_scalar: None,
5953            auxiliary_matrix: None,
5954        };
5955
5956        let sigma = predictor
5957            .predict_noise_scale(&input)
5958            .expect("gaussian location-scale sigma")
5959            .expect("sigma should be returned");
5960        // σ = (LOGB_SIGMA_FLOOR + exp(η + offset)) * scale; (0.01 + 3) * 2 = 6.02.
5961        assert!((sigma[0] - 6.02).abs() <= 1e-12);
5962        assert!((sigma[1] - 10.02).abs() <= 1e-12);
5963        let out = predictor
5964            .predict_with_uncertainty(&input)
5965            .expect("gaussian location-scale uncertainty");
5966        assert!(out.eta_se.is_none());
5967        assert!(out.mean_se.is_none());
5968    }
5969
5970    #[test]
5971    fn gaussian_location_scale_eta_se_pads_scale_block_without_wiggle() {
5972        let predictor = GaussianLocationScalePredictor {
5973            beta_mu: array![0.5],
5974            beta_noise: array![0.1],
5975            response_scale: 1.0,
5976            covariance: Some(array![[4.0, 0.0], [0.0, 9.0]]),
5977            link_wiggle: None,
5978        };
5979        let fit = gaussian_location_scale_fit_with_covariance(
5980            array![0.5],
5981            array![0.1],
5982            array![[4.0, 0.0], [0.0, 9.0]],
5983        );
5984        let input = PredictInput {
5985            design: DesignMatrix::from(array![[1.0]]),
5986            offset: array![0.0],
5987            design_noise: Some(DesignMatrix::from(array![[1.0]])),
5988            offset_noise: None,
5989            auxiliary_scalar: None,
5990            auxiliary_matrix: None,
5991        };
5992
5993        let out = predictor
5994            .predict_posterior_mean(&input, &fit, None)
5995            .expect("gaussian location-scale posterior mean");
5996        assert!((out.eta_standard_error[0] - 2.0).abs() <= 1e-12);
5997    }
5998
5999    #[test]
6000    fn survival_eta_se_pads_log_sigma_block() {
6001        let predictor = SurvivalPredictor {
6002            beta_threshold: array![0.5],
6003            beta_log_sigma: array![0.0],
6004            inverse_link: InverseLink::Standard(LinkFunction::Probit),
6005            covariance: Some(array![[9.0, 0.0], [0.0, 16.0]]),
6006        };
6007        let input = PredictInput {
6008            design: DesignMatrix::from(array![[1.0]]),
6009            offset: array![0.0],
6010            design_noise: Some(DesignMatrix::from(array![[1.0]])),
6011            offset_noise: Some(array![0.0]),
6012            auxiliary_scalar: None,
6013            auxiliary_matrix: None,
6014        };
6015
6016        let out = predictor
6017            .predict_with_uncertainty(&input)
6018            .expect("survival uncertainty");
6019        let eta_se = out.eta_se.expect("eta_se should be present");
6020        assert!((eta_se[0] - 3.0).abs() <= 1e-12);
6021    }
6022
6023    #[test]
6024    fn survival_predictor_cloglog_point_and_se_use_upper_tail_at_q0() {
6025        let predictor = SurvivalPredictor {
6026            beta_threshold: array![-1.0],
6027            beta_log_sigma: array![0.0],
6028            inverse_link: InverseLink::Standard(LinkFunction::CLogLog),
6029            covariance: Some(array![[4.0, 0.0], [0.0, 0.0]]),
6030        };
6031        let input = PredictInput {
6032            design: DesignMatrix::from(array![[1.0]]),
6033            offset: array![0.0],
6034            design_noise: Some(DesignMatrix::from(array![[1.0]])),
6035            offset_noise: Some(array![0.0]),
6036            auxiliary_scalar: None,
6037            auxiliary_matrix: None,
6038        };
6039
6040        let out = predictor
6041            .predict_with_uncertainty(&input)
6042            .expect("cloglog survival prediction");
6043        let q0 = 1.0_f64;
6044        let expected_survival = (-(q0.exp())).exp();
6045        let expected_mean_se = 2.0 * (q0 - q0.exp()).exp();
6046
6047        assert!((out.mean[0] - expected_survival).abs() <= 1e-12);
6048        assert!(
6049            (out.mean_se.expect("mean_se should be present")[0] - expected_mean_se).abs() <= 1e-12
6050        );
6051    }
6052
6053    #[test]
6054    fn survival_predictor_cloglog_posterior_mean_zero_covariance_matches_point_prediction() {
6055        let predictor = SurvivalPredictor {
6056            beta_threshold: array![-1.0],
6057            beta_log_sigma: array![0.0],
6058            inverse_link: InverseLink::Standard(LinkFunction::CLogLog),
6059            covariance: Some(Array2::zeros((2, 2))),
6060        };
6061        let fit = survival_fit_with_covariance(array![-1.0], array![0.0], Array2::zeros((2, 2)));
6062        let input = PredictInput {
6063            design: DesignMatrix::from(array![[1.0]]),
6064            offset: array![0.0],
6065            design_noise: Some(DesignMatrix::from(array![[1.0]])),
6066            offset_noise: Some(array![0.0]),
6067            auxiliary_scalar: None,
6068            auxiliary_matrix: None,
6069        };
6070
6071        let point = predictor
6072            .predict_plugin_response(&input)
6073            .expect("cloglog survival point prediction");
6074        let posterior = predictor
6075            .predict_posterior_mean(&input, &fit, None)
6076            .expect("cloglog survival posterior mean");
6077
6078        assert!((posterior.mean[0] - point.mean[0]).abs() <= 1e-12);
6079    }
6080
6081    #[test]
6082    fn survival_predictor_zero_threshold_with_tiny_sigma_stays_finite() {
6083        let predictor = SurvivalPredictor {
6084            beta_threshold: array![0.0],
6085            beta_log_sigma: array![0.0],
6086            inverse_link: InverseLink::Standard(LinkFunction::CLogLog),
6087            covariance: None,
6088        };
6089        let input = PredictInput {
6090            design: DesignMatrix::from(array![[1.0]]),
6091            offset: array![0.0],
6092            design_noise: Some(DesignMatrix::from(array![[1.0]])),
6093            offset_noise: Some(array![-1000.0]),
6094            auxiliary_scalar: None,
6095            auxiliary_matrix: None,
6096        };
6097
6098        let point = predictor
6099            .predict_plugin_response(&input)
6100            .expect("cloglog survival point prediction");
6101        let expected = (-1.0_f64).exp();
6102
6103        assert!(point.mean[0].is_finite());
6104        assert!((point.mean[0] - expected).abs() <= 1e-12);
6105    }
6106
6107    // ─── O(n⁻¹) frequentist bias correction tests ─────────────────────────
6108
6109    fn test_fit_with_bias_correction(
6110        beta: Array1<f64>,
6111        covariance: Array2<f64>,
6112        bias_correction_beta: Option<Array1<f64>>,
6113    ) -> UnifiedFitResult {
6114        use crate::estimate::FitInference;
6115        let p = beta.len();
6116        let inf = FitInference {
6117            // No penalty in this fixture (lambdas empty), so leave edf_by_block
6118            // empty to satisfy the EDF/lambdas count invariant.
6119            edf_by_block: vec![],
6120            edf_total: p as f64,
6121            smoothing_correction: None,
6122            penalized_hessian: Array2::<f64>::eye(p),
6123            working_weights: Array1::zeros(0),
6124            working_response: Array1::zeros(0),
6125            reparam_qs: None,
6126            beta_covariance: Some(covariance.clone()),
6127            beta_standard_errors: None,
6128            beta_covariance_corrected: None,
6129            beta_standard_errors_corrected: None,
6130            bias_correction_beta,
6131        };
6132        UnifiedFitResult::new_for_test_unchecked(UnifiedFitResultParts {
6133            blocks: vec![FittedBlock {
6134                beta: beta.clone(),
6135                role: BlockRole::Mean,
6136                edf: p as f64,
6137                lambdas: Array1::zeros(0),
6138            }],
6139            log_lambdas: Array1::zeros(0),
6140            lambdas: Array1::zeros(0),
6141            likelihood_family: Some(crate::types::LikelihoodFamily::GaussianIdentity),
6142            likelihood_scale: crate::types::LikelihoodScaleMetadata::ProfiledGaussian,
6143            log_likelihood_normalization: crate::types::LogLikelihoodNormalization::Full,
6144            log_likelihood: 0.0,
6145            deviance: 0.0,
6146            reml_score: 0.0,
6147            stable_penalty_term: 0.0,
6148            penalized_objective: 0.0,
6149            outer_iterations: 0,
6150            outer_converged: true,
6151            outer_gradient_norm: 0.0,
6152            standard_deviation: 1.0,
6153            covariance_conditional: Some(covariance),
6154            covariance_corrected: None,
6155            inference: Some(inf),
6156            fitted_link: FittedLinkState::Standard(Some(LinkFunction::Identity)),
6157            geometry: None,
6158            block_states: Vec::new(),
6159            pirls_status: PirlsStatus::Converged,
6160            max_abs_eta: 0.0,
6161            constraint_kkt: None,
6162            artifacts: FitArtifacts {
6163                pirls: None,
6164                ..Default::default()
6165            },
6166            inner_cycles: 0,
6167        })
6168    }
6169
6170    fn bc_options(apply: bool) -> PredictUncertaintyOptions {
6171        PredictUncertaintyOptions {
6172            confidence_level: 0.95,
6173            covariance_mode: InferenceCovarianceMode::Conditional,
6174            mean_interval_method: MeanIntervalMethod::TransformEta,
6175            includeobservation_interval: false,
6176            apply_bias_correction: apply,
6177            edgeworth_one_sided: false,
6178            boundary_correction: false,
6179            ood_inflation: false,
6180            multi_point_joint: false,
6181            ..PredictUncertaintyOptions::default()
6182        }
6183    }
6184
6185    #[test]
6186    fn test_bias_correction_idempotent_with_flag() {
6187        // With bc=[0.1, -0.05] and x=[[1, 2]], delta_eta = [1*0.1 + 2*(-0.05)] = [0].
6188        // Use a non-degenerate row to see a real shift.
6189        let x = array![[1.0, 0.5]];
6190        let beta = array![1.0, 2.0];
6191        let bc = array![0.1, -0.05];
6192        let cov = Array2::<f64>::eye(2);
6193        let fit = test_fit_with_bias_correction(beta.clone(), cov, Some(bc.clone()));
6194        let offset = array![0.0];
6195
6196        // Raw eta = [1.0 + 1.0] = 2.0; corrected eta = 2.0 + (0.1 + 0.5*(-0.05)) = 2.075.
6197        let pred_off = predict_gamwith_uncertainty(
6198            x.clone(),
6199            beta.view(),
6200            offset.view(),
6201            crate::types::LikelihoodFamily::GaussianIdentity,
6202            &fit,
6203            &bc_options(false),
6204        )
6205        .expect("predict no-bc");
6206        let pred_on = predict_gamwith_uncertainty(
6207            x.clone(),
6208            beta.view(),
6209            offset.view(),
6210            crate::types::LikelihoodFamily::GaussianIdentity,
6211            &fit,
6212            &bc_options(true),
6213        )
6214        .expect("predict bc");
6215        assert!((pred_off.eta[0] - 2.0).abs() < 1e-12);
6216        let expected_delta = 1.0 * 0.1 + 0.5 * (-0.05);
6217        assert!((pred_on.eta[0] - (2.0 + expected_delta)).abs() < 1e-12);
6218        // SE unchanged at first order: identical covariance and design.
6219        assert!(
6220            (pred_off.eta_standard_error[0] - pred_on.eta_standard_error[0]).abs() < 1e-14,
6221            "bias correction must not affect eta standard error"
6222        );
6223    }
6224
6225    #[test]
6226    fn test_bias_correction_zero_when_unset() {
6227        // Without bias_correction_beta, prediction must equal raw plug-in regardless
6228        // of the apply_bias_correction flag.
6229        let x = array![[1.0, 0.5]];
6230        let beta = array![1.0, 2.0];
6231        let cov = Array2::<f64>::eye(2);
6232        let fit = test_fit_with_bias_correction(beta.clone(), cov, None);
6233        let offset = array![0.0];
6234
6235        let pred = predict_gamwith_uncertainty(
6236            x,
6237            beta.view(),
6238            offset.view(),
6239            crate::types::LikelihoodFamily::GaussianIdentity,
6240            &fit,
6241            &bc_options(true),
6242        )
6243        .expect("predict");
6244        assert!((pred.eta[0] - 2.0).abs() < 1e-12);
6245    }
6246
6247    #[test]
6248    fn test_bias_correction_does_not_affect_posterior_se() {
6249        // SE depends only on cov and design rows, not on β or the BC vector.
6250        let x = array![[1.0, 0.5], [0.7, -0.3]];
6251        let beta = array![0.4, 0.9];
6252        let bc = array![0.2, -0.1];
6253        let cov = array![[1.0, 0.1], [0.1, 0.5]];
6254        let fit_with = test_fit_with_bias_correction(beta.clone(), cov.clone(), Some(bc));
6255        let fit_without = test_fit_with_bias_correction(beta.clone(), cov, None);
6256        let offset = array![0.0, 0.0];
6257
6258        let pred_with = predict_gamwith_uncertainty(
6259            x.clone(),
6260            beta.view(),
6261            offset.view(),
6262            crate::types::LikelihoodFamily::GaussianIdentity,
6263            &fit_with,
6264            &bc_options(true),
6265        )
6266        .expect("predict with bc");
6267        let pred_without = predict_gamwith_uncertainty(
6268            x,
6269            beta.view(),
6270            offset.view(),
6271            crate::types::LikelihoodFamily::GaussianIdentity,
6272            &fit_without,
6273            &bc_options(true),
6274        )
6275        .expect("predict without bc");
6276        for i in 0..2 {
6277            assert!(
6278                (pred_with.eta_standard_error[i] - pred_without.eta_standard_error[i]).abs()
6279                    < 1e-14,
6280                "BC must not perturb eta SE at index {i}"
6281            );
6282        }
6283    }
6284
6285    #[test]
6286    fn test_bias_correction_accessor_propagates() {
6287        // bias_correction_beta() accessor returns the value stored on FitInference.
6288        let beta = array![1.0, 2.0];
6289        let bc = array![0.3, -0.2];
6290        let cov = Array2::<f64>::eye(2);
6291        let fit = test_fit_with_bias_correction(beta, cov, Some(bc.clone()));
6292        let recovered = fit
6293            .bias_correction_beta()
6294            .expect("bias correction should be present");
6295        assert_eq!(recovered.len(), bc.len());
6296        for i in 0..bc.len() {
6297            assert!((recovered[i] - bc[i]).abs() < 1e-15);
6298        }
6299    }
6300
6301    // ─── Stronger, adversarial bias-correction tests ──────────────────────
6302
6303    /// Solve a small symmetric 3x3 SPD system H y = r by closed-form 3x3
6304    /// inverse via the cofactor / adjugate formula. Used to compute the
6305    /// expected bias_correction_beta = H^{-1} S β̂ by hand.
6306    fn solve_3x3_spd(h: &Array2<f64>, r: &Array1<f64>) -> Array1<f64> {
6307        assert_eq!(h.nrows(), 3);
6308        assert_eq!(h.ncols(), 3);
6309        let m = |i: usize, j: usize| h[[i, j]];
6310        let det = m(0, 0) * (m(1, 1) * m(2, 2) - m(1, 2) * m(2, 1))
6311            - m(0, 1) * (m(1, 0) * m(2, 2) - m(1, 2) * m(2, 0))
6312            + m(0, 2) * (m(1, 0) * m(2, 1) - m(1, 1) * m(2, 0));
6313        assert!(det.abs() > 1e-12, "singular matrix in solve_3x3_spd");
6314        // Cofactor matrix; inverse = adj/det = transpose(cof)/det.
6315        let cof = array![
6316            [
6317                m(1, 1) * m(2, 2) - m(1, 2) * m(2, 1),
6318                -(m(1, 0) * m(2, 2) - m(1, 2) * m(2, 0)),
6319                m(1, 0) * m(2, 1) - m(1, 1) * m(2, 0)
6320            ],
6321            [
6322                -(m(0, 1) * m(2, 2) - m(0, 2) * m(2, 1)),
6323                m(0, 0) * m(2, 2) - m(0, 2) * m(2, 0),
6324                -(m(0, 0) * m(2, 1) - m(0, 1) * m(2, 0))
6325            ],
6326            [
6327                m(0, 1) * m(1, 2) - m(0, 2) * m(1, 1),
6328                -(m(0, 0) * m(1, 2) - m(0, 2) * m(1, 0)),
6329                m(0, 0) * m(1, 1) - m(0, 1) * m(1, 0)
6330            ]
6331        ];
6332        // adj = cof^T
6333        let mut y = Array1::<f64>::zeros(3);
6334        for i in 0..3 {
6335            let mut acc = 0.0;
6336            for j in 0..3 {
6337                acc += cof[[j, i]] * r[j];
6338            }
6339            y[i] = acc / det;
6340        }
6341        y
6342    }
6343
6344    /// Tiny deterministic LCG for reproducibility without an external crate.
6345    struct Lcg(u64);
6346    impl Lcg {
6347        fn new(seed: u64) -> Self {
6348            Self(
6349                seed.wrapping_mul(6364136223846793005)
6350                    .wrapping_add(1442695040888963407),
6351            )
6352        }
6353        fn next_u64(&mut self) -> u64 {
6354            self.0 = self
6355                .0
6356                .wrapping_mul(6364136223846793005)
6357                .wrapping_add(1442695040888963407);
6358            self.0
6359        }
6360        fn unif(&mut self) -> f64 {
6361            // Take top 53 bits → [0, 1).
6362            ((self.next_u64() >> 11) as f64) / ((1u64 << 53) as f64)
6363        }
6364        /// Box–Muller standard normal.
6365        fn normal(&mut self) -> f64 {
6366            let u1 = self.unif().max(1e-300);
6367            let u2 = self.unif();
6368            (-2.0 * u1.ln()).sqrt() * (2.0 * std::f64::consts::PI * u2).cos()
6369        }
6370    }
6371
6372    /// Test 1: η̂_BC at x = I_p columns equals β̂ + b̂ component-wise,
6373    /// where b̂ = H⁻¹ S β̂ is computed by hand.
6374    #[test]
6375    fn test_bias_correction_matches_explicit_formula() {
6376        // p = 3. Pick H SPD (= XᵀWX + S in spirit), S, β̂, then solve H b = S β̂.
6377        let h = array![[4.0_f64, 0.5, 0.2], [0.5, 3.0, 0.1], [0.2, 0.1, 2.0]];
6378        let s_pen = array![[1.0_f64, 0.0, 0.0], [0.0, 0.5, 0.0], [0.0, 0.0, 2.0]];
6379        let beta = array![0.7_f64, -1.3, 0.4];
6380        let s_beta = s_pen.dot(&beta);
6381        let b_hat = solve_3x3_spd(&h, &s_beta);
6382
6383        // Cov is just a placeholder for the SE machinery; not used in this assertion.
6384        let cov = Array2::<f64>::eye(3);
6385        let fit = test_fit_with_bias_correction(beta.clone(), cov, Some(b_hat.clone()));
6386
6387        // Predict at the standard-basis rows: η_raw = β, η_BC = β + b_hat.
6388        let x = array![[1.0, 0.0, 0.0], [0.0, 1.0, 0.0], [0.0, 0.0, 1.0]];
6389        let offset = array![0.0, 0.0, 0.0];
6390
6391        let pred_raw = predict_gamwith_uncertainty(
6392            x.clone(),
6393            beta.view(),
6394            offset.view(),
6395            crate::types::LikelihoodFamily::GaussianIdentity,
6396            &fit,
6397            &bc_options(false),
6398        )
6399        .expect("raw predict");
6400        let pred_bc = predict_gamwith_uncertainty(
6401            x,
6402            beta.view(),
6403            offset.view(),
6404            crate::types::LikelihoodFamily::GaussianIdentity,
6405            &fit,
6406            &bc_options(true),
6407        )
6408        .expect("bc predict");
6409
6410        for i in 0..3 {
6411            assert!(
6412                (pred_raw.eta[i] - beta[i]).abs() < 1e-12,
6413                "raw eta[{i}] = {} expected {}",
6414                pred_raw.eta[i],
6415                beta[i]
6416            );
6417            let expected = beta[i] + b_hat[i];
6418            assert!(
6419                (pred_bc.eta[i] - expected).abs() < 1e-12,
6420                "BC eta[{i}] = {} expected β+b̂ = {} (b̂[{i}] = {})",
6421                pred_bc.eta[i],
6422                expected,
6423                b_hat[i]
6424            );
6425        }
6426    }
6427
6428    /// Test 2: S = 0 ⇒ b̂ = H⁻¹ · 0 · β̂ = 0; corrected prediction equals raw.
6429    #[test]
6430    fn test_bias_correction_zero_for_zero_penalty() {
6431        // With S = 0, the canonical fit-time computation produces b̂ = 0.
6432        // Inject a zero bias_correction_beta and verify η_BC == η_raw exactly.
6433        let beta = array![0.5_f64, -0.4, 1.7];
6434        let bc_zero = Array1::<f64>::zeros(3);
6435        let cov = Array2::<f64>::eye(3);
6436        let fit = test_fit_with_bias_correction(beta.clone(), cov, Some(bc_zero));
6437
6438        let x = array![[1.0, 2.0, -0.5], [0.3, -0.7, 1.2], [2.0, 0.1, 0.0]];
6439        let offset = array![0.0, 0.0, 0.0];
6440
6441        let pred_raw = predict_gamwith_uncertainty(
6442            x.clone(),
6443            beta.view(),
6444            offset.view(),
6445            crate::types::LikelihoodFamily::GaussianIdentity,
6446            &fit,
6447            &bc_options(false),
6448        )
6449        .expect("raw predict");
6450        let pred_bc = predict_gamwith_uncertainty(
6451            x,
6452            beta.view(),
6453            offset.view(),
6454            crate::types::LikelihoodFamily::GaussianIdentity,
6455            &fit,
6456            &bc_options(true),
6457        )
6458        .expect("bc predict");
6459
6460        for i in 0..3 {
6461            assert!(
6462                (pred_bc.eta[i] - pred_raw.eta[i]).abs() < 1e-15,
6463                "S=0 ⇒ BC must be a no-op; got Δ={} at i={i}",
6464                pred_bc.eta[i] - pred_raw.eta[i]
6465            );
6466        }
6467    }
6468
6469    /// Test 3: ‖η̂_BC − η̂_raw‖ is monotone-increasing in the scalar λ
6470    /// multiplier of S. Specifically, for fixed H_base = XᵀWX, set
6471    /// H(λ) = H_base + λI and S(λ) = λI, so b̂(λ) = H(λ)⁻¹ (λI) β̂.
6472    #[test]
6473    fn test_bias_correction_increases_with_penalty_strength() {
6474        // Use p = 3 and the same H_base / β̂ across runs.
6475        let h_base = array![[3.0_f64, 0.4, 0.1], [0.4, 2.5, 0.2], [0.1, 0.2, 4.0]];
6476        let beta = array![1.2_f64, -0.8, 0.5];
6477        let x = array![[1.0, 0.5, -0.2], [0.3, -0.4, 0.9], [0.7, 0.7, 0.7]];
6478        let offset = array![0.0, 0.0, 0.0];
6479
6480        let lambdas = [0.1_f64, 1.0, 10.0];
6481        let mut deltas = Vec::with_capacity(lambdas.len());
6482        for &lam in &lambdas {
6483            // H(λ) = H_base + λ I; S(λ) = λ I.
6484            let mut h = h_base.clone();
6485            for k in 0..3 {
6486                h[[k, k]] += lam;
6487            }
6488            let s_beta = beta.mapv(|v| lam * v);
6489            let b_hat = solve_3x3_spd(&h, &s_beta);
6490
6491            let cov = Array2::<f64>::eye(3);
6492            let fit = test_fit_with_bias_correction(beta.clone(), cov, Some(b_hat));
6493
6494            let pred_raw = predict_gamwith_uncertainty(
6495                x.clone(),
6496                beta.view(),
6497                offset.view(),
6498                crate::types::LikelihoodFamily::GaussianIdentity,
6499                &fit,
6500                &bc_options(false),
6501            )
6502            .expect("raw predict");
6503            let pred_bc = predict_gamwith_uncertainty(
6504                x.clone(),
6505                beta.view(),
6506                offset.view(),
6507                crate::types::LikelihoodFamily::GaussianIdentity,
6508                &fit,
6509                &bc_options(true),
6510            )
6511            .expect("bc predict");
6512
6513            let mut sumsq = 0.0;
6514            for i in 0..3 {
6515                let d = pred_bc.eta[i] - pred_raw.eta[i];
6516                sumsq += d * d;
6517            }
6518            deltas.push(sumsq.sqrt());
6519        }
6520
6521        assert!(
6522            deltas[0] < deltas[1],
6523            "‖η_BC − η_raw‖ must grow with λ: λ={} gave {}, λ={} gave {}",
6524            lambdas[0],
6525            deltas[0],
6526            lambdas[1],
6527            deltas[1]
6528        );
6529        assert!(
6530            deltas[1] < deltas[2],
6531            "‖η_BC − η_raw‖ must grow with λ: λ={} gave {}, λ={} gave {}",
6532            lambdas[1],
6533            deltas[1],
6534            lambdas[2],
6535            deltas[2]
6536        );
6537        // And there should be a meaningful gap, not numerical noise.
6538        assert!(
6539            deltas[2] > 10.0 * deltas[0],
6540            "expected order-of-magnitude growth in BC magnitude across λ ∈ {{0.1,1,10}}; got {:?}",
6541            deltas
6542        );
6543    }
6544
6545    /// Test 4: under strong shrinkage, the bias-corrected predictor moves
6546    /// closer to the unpenalized OLS predictor than the raw penalized
6547    /// predictor. We hand-construct a fixture where:
6548    ///   β̂   = small-shrunk version of β_OLS,
6549    ///   H   = XᵀX + S,  with S = λI,
6550    ///   b̂   = H⁻¹ S β̂.
6551    /// At ≥90% of test points, |η_OLS − η_BC| < |η_OLS − η_raw|.
6552    #[test]
6553    fn test_bias_correction_recovers_unpenalized_in_simulation() {
6554        let n = 200usize;
6555        let p = 5usize;
6556        let mut rng = Lcg::new(0xC0FFEE_u64);
6557
6558        // Design matrix X (n × p) with column 0 = 1 (intercept-like).
6559        let mut x_data = vec![0.0_f64; n * p];
6560        for i in 0..n {
6561            x_data[i * p] = 1.0;
6562            for j in 1..p {
6563                x_data[i * p + j] = rng.normal();
6564            }
6565        }
6566        let x = Array2::from_shape_vec((n, p), x_data).expect("X shape");
6567
6568        // True beta and (unpenalized) OLS beta from y = Xβ_true + ε.
6569        let beta_true = array![0.5_f64, 1.0, -0.7, 0.3, 0.8];
6570        let mut y = Array1::<f64>::zeros(n);
6571        for i in 0..n {
6572            let mut eta = 0.0;
6573            for j in 0..p {
6574                eta += x[[i, j]] * beta_true[j];
6575            }
6576            y[i] = eta + 0.3 * rng.normal();
6577        }
6578        // β_OLS = (XᵀX)⁻¹ Xᵀy. Use ndarray-via-explicit approach: solve via LU
6579        // by leveraging the existing 3x3 helper is impossible at p=5; instead
6580        // form the Cholesky-like solve via faer-free Gauss elimination.
6581        let xtx = x.t().dot(&x);
6582        let xty = x.t().dot(&y);
6583        let beta_ols = solve_dense_spd(&xtx, &xty);
6584
6585        // Pretend the penalized fit shrunk OLS by factor 0.6: β̂ = 0.6·β_OLS.
6586        let shrink = 0.6_f64;
6587        let beta_hat = beta_ols.mapv(|v| shrink * v);
6588
6589        // S = λ I with λ chosen so shrinkage matches the target. Exact match
6590        // is not required; we just need a consistent (H, S, β̂) triple.
6591        let lambda = 100.0_f64;
6592        let mut h = xtx.clone();
6593        for k in 0..p {
6594            h[[k, k]] += lambda;
6595        }
6596        let s_beta = beta_hat.mapv(|v| lambda * v);
6597        let b_hat = solve_dense_spd(&h, &s_beta);
6598
6599        let cov = Array2::<f64>::eye(p);
6600        let fit = test_fit_with_bias_correction(beta_hat.clone(), cov, Some(b_hat.clone()));
6601
6602        // Test points: a held-out random batch of 50 rows.
6603        let m = 50usize;
6604        let mut xt_data = vec![0.0_f64; m * p];
6605        for i in 0..m {
6606            xt_data[i * p] = 1.0;
6607            for j in 1..p {
6608                xt_data[i * p + j] = rng.normal();
6609            }
6610        }
6611        let xt = Array2::from_shape_vec((m, p), xt_data).expect("Xtest shape");
6612        let offset = Array1::<f64>::zeros(m);
6613
6614        let pred_raw = predict_gamwith_uncertainty(
6615            xt.clone(),
6616            beta_hat.view(),
6617            offset.view(),
6618            crate::types::LikelihoodFamily::GaussianIdentity,
6619            &fit,
6620            &bc_options(false),
6621        )
6622        .expect("raw predict");
6623        let pred_bc = predict_gamwith_uncertainty(
6624            xt.clone(),
6625            beta_hat.view(),
6626            offset.view(),
6627            crate::types::LikelihoodFamily::GaussianIdentity,
6628            &fit,
6629            &bc_options(true),
6630        )
6631        .expect("bc predict");
6632        let eta_ols = xt.dot(&beta_ols);
6633
6634        let mut closer = 0usize;
6635        for i in 0..m {
6636            let raw_gap = (eta_ols[i] - pred_raw.eta[i]).abs();
6637            let bc_gap = (eta_ols[i] - pred_bc.eta[i]).abs();
6638            if bc_gap < raw_gap {
6639                closer += 1;
6640            }
6641        }
6642        let frac = closer as f64 / m as f64;
6643        assert!(
6644            frac >= 0.9,
6645            "BC must close the OLS gap at ≥90% of test points; got {}/{} = {:.2}",
6646            closer,
6647            m,
6648            frac
6649        );
6650    }
6651
6652    /// Test 5: bias is O(n⁻¹) — it should shrink as n grows when λ is held
6653    /// at a fixed (n-independent) value. The previous formulation drew a
6654    /// fresh (X, y) at each seed and averaged across 12 seeds; with σ²=0.25
6655    /// and p=4, the per-seed coefficient SE Var(β̂)≈σ²/n is comparable to
6656    /// or larger than the true bias H⁻¹λβ ≈ (λ/n)·β at n=5000, so the
6657    /// MC-averaged "bias" estimator is dominated by sampling noise of η̂
6658    /// rather than by the bias signal — the headline ratio cannot be
6659    /// resolved at this scale with 12 seeds.
6660    ///
6661    /// The principled comparison is deterministic. For Gaussian-identity
6662    /// ridge with penalty S = λ I and design X (fixed), the conditional
6663    /// mean of the penalized estimator is
6664    ///     E[β̂ | X] = (XᵀX + λI)⁻¹ XᵀX β = β - H⁻¹ S β.
6665    /// The bias-correction vector is b̂(β̂) = H⁻¹ S β̂, so the conditional
6666    /// mean of the corrected estimator is
6667    ///     E[β̂_BC | X] = E[β̂|X] + H⁻¹ S E[β̂|X] = β - (H⁻¹ S)² β.
6668    /// Thus the conditional bias of η̂_raw is -xᵀH⁻¹Sβ (order λ/n), and
6669    /// the conditional bias of η̂_BC is -xᵀ(H⁻¹S)²β (order (λ/n)²). The
6670    /// ratio scales like λ/(n+λ), which at n=5000 and λ=5 is ≈ 10⁻³.
6671    ///
6672    /// We run the production prediction pipeline with `β̂ := E[β̂|X]` and
6673    /// `b̂ := H⁻¹ S β̂` (both deterministic). The eta we read back is
6674    /// exactly E[η̂_*|X], so |Δη| against η_true measures conditional bias
6675    /// without any Monte-Carlo overlay. This both (a) eliminates the
6676    /// signal-vs-noise floor and (b) still exercises the BC wiring inside
6677    /// `predict_gamwith_uncertainty`.
6678    #[test]
6679    fn test_bias_correction_bias_drops_with_n_simulation() {
6680        let p = 4usize;
6681        let beta_true = array![0.4_f64, 0.9, -0.5, 0.6];
6682        let lambda = 5.0_f64;
6683        let ns = [200usize, 1000, 5000];
6684
6685        // Held-out test points are reused across n (they are just probes).
6686        let m = 32usize;
6687        let mut probe_rng = Lcg::new(424242);
6688        let mut xt_data = vec![0.0_f64; m * p];
6689        for i in 0..m {
6690            xt_data[i * p] = 1.0;
6691            for j in 1..p {
6692                xt_data[i * p + j] = probe_rng.normal();
6693            }
6694        }
6695        let xt = Array2::from_shape_vec((m, p), xt_data).expect("Xtest shape");
6696        let eta_true = xt.dot(&beta_true);
6697        let offset = Array1::<f64>::zeros(m);
6698
6699        let mut mean_abs_raw_bias = [0.0_f64; 3];
6700        let mut mean_abs_bc_bias = [0.0_f64; 3];
6701
6702        // Use independent outer cases as the parallel work unit. Each case
6703        // builds its own design and performs two small dense SPD solves; keep
6704        // those solves serial to avoid fine-grained Rayon overhead inside the
6705        // dense elimination kernel itself.
6706        //
6707        // Each n still starts from the same deterministic LCG seed. Different
6708        // n therefore share the same seed prefix for their first min(n_a, n_b)
6709        // rows, isolating the ratio drop to scale alone rather than to a
6710        // confounding draw.
6711        let bias_by_n: Vec<(usize, f64, f64)> = (0..ns.len())
6712            .into_par_iter()
6713            .map(|kn| {
6714                let n = ns[kn];
6715                let mut rng = Lcg::new(0xBEEFu64);
6716                let mut x_data = vec![0.0_f64; n * p];
6717                for i in 0..n {
6718                    x_data[i * p] = 1.0;
6719                    for j in 1..p {
6720                        x_data[i * p + j] = rng.normal();
6721                    }
6722                }
6723                let x = Array2::from_shape_vec((n, p), x_data).expect("X shape");
6724                let xtx = x.t().dot(&x);
6725                let mut h = xtx.clone();
6726                for k in 0..p {
6727                    h[[k, k]] += lambda;
6728                }
6729
6730                // E[β̂ | X] = β - H⁻¹ S β = (XᵀX + λI)⁻¹ XᵀX β.
6731                let xtx_beta = xtx.dot(&beta_true);
6732                let beta_mean = solve_dense_spd(&h, &xtx_beta);
6733                // b̂(β̂) at β̂ = E[β̂|X]: b̂ = H⁻¹ λ β̂.
6734                let s_beta_mean = beta_mean.mapv(|v| lambda * v);
6735                let b_hat = solve_dense_spd(&h, &s_beta_mean);
6736
6737                let cov = Array2::<f64>::eye(p);
6738                let fit = test_fit_with_bias_correction(beta_mean.clone(), cov, Some(b_hat));
6739
6740                let pred_raw = predict_gamwith_uncertainty(
6741                    xt.clone(),
6742                    beta_mean.view(),
6743                    offset.view(),
6744                    crate::types::LikelihoodFamily::GaussianIdentity,
6745                    &fit,
6746                    &bc_options(false),
6747                )
6748                .expect("raw predict");
6749                let pred_bc = predict_gamwith_uncertainty(
6750                    xt.clone(),
6751                    beta_mean.view(),
6752                    offset.view(),
6753                    crate::types::LikelihoodFamily::GaussianIdentity,
6754                    &fit,
6755                    &bc_options(true),
6756                )
6757                .expect("bc predict");
6758
6759                let mut acc_raw = 0.0;
6760                let mut acc_bc = 0.0;
6761                for i in 0..m {
6762                    acc_raw += (pred_raw.eta[i] - eta_true[i]).abs();
6763                    acc_bc += (pred_bc.eta[i] - eta_true[i]).abs();
6764                }
6765                (kn, acc_raw / m as f64, acc_bc / m as f64)
6766            })
6767            .collect();
6768        for (kn, raw, bc) in bias_by_n {
6769            mean_abs_raw_bias[kn] = raw;
6770            mean_abs_bc_bias[kn] = bc;
6771        }
6772
6773        // Raw bias should itself be decreasing in n (sanity check; otherwise
6774        // the test conditions are wrong, not the BC).
6775        assert!(
6776            mean_abs_raw_bias[2] < mean_abs_raw_bias[0],
6777            "raw penalized conditional bias should shrink with n: got {:?}",
6778            mean_abs_raw_bias
6779        );
6780        // The headline claim: BC is much smaller than raw at large n. The
6781        // analytic ratio is λ/(n+λ); at n=5000, λ=5 this is ≈10⁻³, so the
6782        // 0.5 threshold is conservative and the test fails decisively if
6783        // the BC sign or scale is wrong (e.g. dropping the H⁻¹, swapping
6784        // sign, or using cov instead of H).
6785        let ratio_large = mean_abs_bc_bias[2] / mean_abs_raw_bias[2].max(1e-300);
6786        assert!(
6787            ratio_large < 0.5,
6788            "BC must reduce conditional bias by >2× at n={}; raw={}, bc={}, ratio={}",
6789            ns[2],
6790            mean_abs_raw_bias[2],
6791            mean_abs_bc_bias[2],
6792            ratio_large
6793        );
6794        // And the BC/raw ratio should decrease (or at least not grow) with n.
6795        let ratio_small = mean_abs_bc_bias[0] / mean_abs_raw_bias[0].max(1e-300);
6796        assert!(
6797            ratio_large <= ratio_small + 1e-6,
6798            "BC/raw ratio should not grow with n: small-n ratio={}, large-n ratio={}",
6799            ratio_small,
6800            ratio_large
6801        );
6802    }
6803
6804    /// Test 6: invariance under invertible reparameterization. If β = Q θ,
6805    /// the design becomes X̃ = X Q⁻¹ in coefficient-θ space and the penalty
6806    /// becomes S̃ = Q⁻ᵀ S Q⁻¹. Then η̂_BC must equal η̂_BC(original) for any
6807    /// row x. We verify that swapping (β, b_hat, X) ↔ (θ, b̃, X̃) gives the
6808    /// same prediction.
6809    #[test]
6810    fn test_bias_correction_identity_in_basis_change() {
6811        // Original parameterization (p = 3).
6812        let h = array![[4.0_f64, 0.5, 0.2], [0.5, 3.0, 0.1], [0.2, 0.1, 2.5]];
6813        let s_pen = array![[0.7_f64, 0.1, 0.0], [0.1, 0.5, 0.05], [0.0, 0.05, 1.2]];
6814        let beta = array![0.6_f64, -0.4, 1.1];
6815        let s_beta = s_pen.dot(&beta);
6816        let b_hat = solve_3x3_spd(&h, &s_beta);
6817
6818        // Pick an invertible Q (upper-triangular with unit diagonal).
6819        let q = array![[1.0_f64, 0.3, -0.2], [0.0, 1.0, 0.5], [0.0, 0.0, 1.0]];
6820        // θ = Q⁻¹ β; with this triangular Q we can solve directly.
6821        let qinv = invert_upper_triangular_3(&q);
6822        let theta = qinv.dot(&beta);
6823        // b̃ = Q⁻¹ b̂.
6824        let b_tilde = qinv.dot(&b_hat);
6825
6826        // Test row x; in θ-space the row becomes x̃ = Q⁻ᵀ x  → but predicted
6827        // η is xᵀβ = xᵀ Q θ ⇒ x̃ = Qᵀ x. Use that form.
6828        let x_row = array![[0.4_f64, -0.7, 0.9]];
6829        let mut x_tilde = Array2::<f64>::zeros((1, 3));
6830        for j in 0..3 {
6831            let mut acc = 0.0;
6832            for i in 0..3 {
6833                acc += q[[i, j]] * x_row[[0, i]];
6834            }
6835            x_tilde[[0, j]] = acc;
6836        }
6837        let offset = array![0.0_f64];
6838
6839        let cov = Array2::<f64>::eye(3);
6840        let fit_orig = test_fit_with_bias_correction(beta.clone(), cov.clone(), Some(b_hat));
6841        let fit_repar = test_fit_with_bias_correction(theta.clone(), cov, Some(b_tilde));
6842
6843        let pred_orig = predict_gamwith_uncertainty(
6844            x_row,
6845            beta.view(),
6846            offset.view(),
6847            crate::types::LikelihoodFamily::GaussianIdentity,
6848            &fit_orig,
6849            &bc_options(true),
6850        )
6851        .expect("orig predict");
6852        let pred_repar = predict_gamwith_uncertainty(
6853            x_tilde,
6854            theta.view(),
6855            offset.view(),
6856            crate::types::LikelihoodFamily::GaussianIdentity,
6857            &fit_repar,
6858            &bc_options(true),
6859        )
6860        .expect("repar predict");
6861
6862        assert!(
6863            (pred_orig.eta[0] - pred_repar.eta[0]).abs() < 1e-12,
6864            "BC must be invariant under reparameterization: orig η={} repar η={} Δ={}",
6865            pred_orig.eta[0],
6866            pred_repar.eta[0],
6867            (pred_orig.eta[0] - pred_repar.eta[0]).abs()
6868        );
6869    }
6870
6871    /// Test 7: stronger no-SE-leakage check. Across 100 random test rows,
6872    /// the SE with BC enabled and SE with BC disabled differ by < 1e-14
6873    /// (relative magnitude). Catches accidental contamination of the
6874    /// variance pipeline by bias_correction_beta.
6875    #[test]
6876    fn test_bias_correction_does_not_inflate_se() {
6877        let p = 4usize;
6878        let beta = array![0.5_f64, -0.7, 1.1, 0.3];
6879        // Non-trivial covariance.
6880        let cov = array![
6881            [2.0_f64, 0.3, 0.1, 0.0],
6882            [0.3, 1.5, 0.2, 0.05],
6883            [0.1, 0.2, 1.8, 0.1],
6884            [0.0, 0.05, 0.1, 2.2]
6885        ];
6886        let bc = array![0.2_f64, -0.15, 0.05, 0.1];
6887        let fit = test_fit_with_bias_correction(beta.clone(), cov, Some(bc));
6888
6889        let m = 100usize;
6890        let mut rng = Lcg::new(0xBEEFCAFE_u64);
6891        let mut x_data = vec![0.0_f64; m * p];
6892        for i in 0..m {
6893            for j in 0..p {
6894                x_data[i * p + j] = rng.normal();
6895            }
6896        }
6897        let x = Array2::from_shape_vec((m, p), x_data).expect("X shape");
6898        let offset = Array1::<f64>::zeros(m);
6899
6900        let pred_off = predict_gamwith_uncertainty(
6901            x.clone(),
6902            beta.view(),
6903            offset.view(),
6904            crate::types::LikelihoodFamily::GaussianIdentity,
6905            &fit,
6906            &bc_options(false),
6907        )
6908        .expect("predict no-bc");
6909        let pred_on = predict_gamwith_uncertainty(
6910            x,
6911            beta.view(),
6912            offset.view(),
6913            crate::types::LikelihoodFamily::GaussianIdentity,
6914            &fit,
6915            &bc_options(true),
6916        )
6917        .expect("predict bc");
6918
6919        for i in 0..m {
6920            let a = pred_off.eta_standard_error[i];
6921            let b = pred_on.eta_standard_error[i];
6922            let rel = (a - b).abs() / a.abs().max(b.abs()).max(1e-300);
6923            assert!(
6924                rel < 1e-14,
6925                "SE leakage detected at i={}: off={}, on={}, relΔ={}",
6926                i,
6927                a,
6928                b,
6929                rel
6930            );
6931        }
6932    }
6933
6934    /// Test 8: pathological β̂ (NaN/Inf entries) must not panic. NaNs
6935    /// propagate into η rather than triggering an unwrap.
6936    #[test]
6937    fn test_bias_correction_finite_for_pathological_inputs() {
6938        let beta = array![1.0_f64, f64::NAN, 0.5];
6939        let bc = array![0.1_f64, 0.2, f64::INFINITY];
6940        let cov = Array2::<f64>::eye(3);
6941        let fit = test_fit_with_bias_correction(beta.clone(), cov, Some(bc));
6942
6943        let x = array![[1.0_f64, 1.0, 1.0]];
6944        let offset = array![0.0_f64];
6945        let pred = predict_gamwith_uncertainty(
6946            x,
6947            beta.view(),
6948            offset.view(),
6949            crate::types::LikelihoodFamily::GaussianIdentity,
6950            &fit,
6951            &bc_options(true),
6952        )
6953        .expect("pathological predict should not error, only propagate NaN/Inf");
6954        assert!(
6955            !pred.eta[0].is_finite(),
6956            "expected non-finite η to propagate; got η = {}",
6957            pred.eta[0]
6958        );
6959    }
6960
6961    /// Test 9: with apply_bias_correction = false, η̂ == β̂·x_* up to
6962    /// 1e-15 even when bias_correction_beta is loaded onto the fit.
6963    #[test]
6964    fn test_bias_correction_disabled_via_options_returns_raw() {
6965        let beta = array![1.5_f64, -0.7];
6966        let bc = array![0.4_f64, -0.3];
6967        let cov = Array2::<f64>::eye(2);
6968        let fit = test_fit_with_bias_correction(beta.clone(), cov, Some(bc.clone()));
6969
6970        let x = array![[1.0_f64, 0.5], [0.7, -0.3]];
6971        let offset = array![0.0_f64, 0.0];
6972        let pred = predict_gamwith_uncertainty(
6973            x.clone(),
6974            beta.view(),
6975            offset.view(),
6976            crate::types::LikelihoodFamily::GaussianIdentity,
6977            &fit,
6978            &bc_options(false),
6979        )
6980        .expect("predict no-bc");
6981
6982        // Raw η = X β.
6983        let expected = x.dot(&beta);
6984        for i in 0..2 {
6985            let d = (pred.eta[i] - expected[i]).abs();
6986            assert!(
6987                d < 1e-15,
6988                "apply_bias_correction=false must return raw plug-in: η[{i}]={} expected={} Δ={}",
6989                pred.eta[i],
6990                expected[i],
6991                d
6992            );
6993        }
6994    }
6995
6996    /// Test 10: bias correction must use the *penalized* Hessian H = XᵀWX + S,
6997    /// not the inverse of the supplied covariance. We construct a fixture
6998    /// where the supplied covariance ≠ H⁻¹ (we deliberately pass a different
6999    /// covariance into FitInference) and verify that prediction still uses
7000    /// the externally-supplied bias_correction_beta verbatim — i.e. the
7001    /// prediction code does NOT recompute b̂ from cov⁻¹ S β.
7002    #[test]
7003    fn test_bias_correction_with_nonidentity_covariance_uses_correct_h() {
7004        // True (XᵀWX + S) implied by the fit:
7005        let h_true = array![[5.0_f64, 0.7, 0.2], [0.7, 4.0, 0.3], [0.2, 0.3, 3.5]];
7006        let s_pen = array![[0.8_f64, 0.0, 0.0], [0.0, 1.2, 0.0], [0.0, 0.0, 0.6]];
7007        let beta = array![0.9_f64, -1.1, 0.4];
7008        let s_beta = s_pen.dot(&beta);
7009        let b_hat_correct = solve_3x3_spd(&h_true, &s_beta);
7010
7011        // Also compute the WRONG b̂ that one would get if the code used
7012        // covariance⁻¹ instead of H. We pick a covariance that is clearly
7013        // not H⁻¹: a tridiagonal SPD matrix.
7014        let cov_wrong = array![[2.0_f64, 0.4, 0.0], [0.4, 1.5, 0.3], [0.0, 0.3, 1.8]];
7015        // cov_wrong is not equal to H_true^{-1}.
7016        let h_inv = invert_3x3_spd(&h_true);
7017        let mut diff = 0.0;
7018        for i in 0..3 {
7019            for j in 0..3 {
7020                diff += (h_inv[[i, j]] - cov_wrong[[i, j]]).abs();
7021            }
7022        }
7023        assert!(
7024            diff > 0.5,
7025            "test setup error: cov_wrong should be far from H_true⁻¹ (diff={})",
7026            diff
7027        );
7028
7029        // Build the fit with the WRONG covariance but the CORRECT bias vector.
7030        // Predictions must reflect b_hat_correct (not whatever the code might
7031        // compute from cov_wrong).
7032        let fit =
7033            test_fit_with_bias_correction(beta.clone(), cov_wrong, Some(b_hat_correct.clone()));
7034
7035        let x = array![[1.0_f64, 0.0, 0.0], [0.0, 1.0, 0.0], [0.0, 0.0, 1.0]];
7036        let offset = array![0.0_f64, 0.0, 0.0];
7037        let pred = predict_gamwith_uncertainty(
7038            x,
7039            beta.view(),
7040            offset.view(),
7041            crate::types::LikelihoodFamily::GaussianIdentity,
7042            &fit,
7043            &bc_options(true),
7044        )
7045        .expect("predict bc");
7046
7047        for i in 0..3 {
7048            let expected = beta[i] + b_hat_correct[i];
7049            assert!(
7050                (pred.eta[i] - expected).abs() < 1e-12,
7051                "prediction must use the supplied bias_correction_beta verbatim: \
7052                 η[{i}]={} expected={} (β+b̂_correct[{i}]={})",
7053                pred.eta[i],
7054                expected,
7055                b_hat_correct[i]
7056            );
7057        }
7058    }
7059
7060    /// Test 11: bias_correction_beta survives serde JSON round-trip.
7061    /// Catches missing serde fields or skip_serializing attributes.
7062    #[test]
7063    fn test_bias_correction_propagates_through_unified_fit_result() {
7064        let beta = array![0.7_f64, -0.4, 1.2];
7065        let bc = array![0.123456789_f64, -0.987654321, 0.5];
7066        let cov = Array2::<f64>::eye(3);
7067        let fit = test_fit_with_bias_correction(beta, cov, Some(bc.clone()));
7068
7069        let json = serde_json::to_string(&fit).expect("serialize unified fit");
7070        let decoded: UnifiedFitResult =
7071            serde_json::from_str(&json).expect("deserialize unified fit");
7072        let recovered = decoded
7073            .bias_correction_beta()
7074            .expect("bias_correction_beta must survive JSON round-trip");
7075        assert_eq!(
7076            recovered.len(),
7077            bc.len(),
7078            "bc length changed across round-trip"
7079        );
7080        for i in 0..bc.len() {
7081            assert!(
7082                (recovered[i] - bc[i]).abs() < 1e-15,
7083                "bc[{i}] drifted across JSON round-trip: in={}, out={}",
7084                bc[i],
7085                recovered[i]
7086            );
7087        }
7088    }
7089
7090    // ─── Local linear-algebra helpers for the bias-correction tests ──────
7091
7092    /// Solve H y = r for general dense SPD H (small p) via Gauss elimination
7093    /// with partial pivoting. Used in the simulation tests where p > 3 makes
7094    /// the closed-form 3×3 helper insufficient.
7095    fn solve_dense_spd(h: &Array2<f64>, r: &Array1<f64>) -> Array1<f64> {
7096        let n = h.nrows();
7097        assert_eq!(h.ncols(), n);
7098        assert_eq!(r.len(), n);
7099        let mut a = Array2::<f64>::zeros((n, n + 1));
7100        for i in 0..n {
7101            for j in 0..n {
7102                a[[i, j]] = h[[i, j]];
7103            }
7104            a[[i, n]] = r[i];
7105        }
7106        for k in 0..n {
7107            // Partial pivot.
7108            let mut piv = k;
7109            let mut best = a[[k, k]].abs();
7110            for i in (k + 1)..n {
7111                if a[[i, k]].abs() > best {
7112                    best = a[[i, k]].abs();
7113                    piv = i;
7114                }
7115            }
7116            assert!(best > 1e-14, "near-singular system in solve_dense_spd");
7117            if piv != k {
7118                for j in 0..=n {
7119                    let tmp = a[[k, j]];
7120                    a[[k, j]] = a[[piv, j]];
7121                    a[[piv, j]] = tmp;
7122                }
7123            }
7124            for i in (k + 1)..n {
7125                let factor = a[[i, k]] / a[[k, k]];
7126                for j in k..=n {
7127                    a[[i, j]] -= factor * a[[k, j]];
7128                }
7129            }
7130        }
7131        let mut y = Array1::<f64>::zeros(n);
7132        for i in (0..n).rev() {
7133            let mut acc = a[[i, n]];
7134            for j in (i + 1)..n {
7135                acc -= a[[i, j]] * y[j];
7136            }
7137            y[i] = acc / a[[i, i]];
7138        }
7139        y
7140    }
7141
7142    /// Invert a 3x3 SPD matrix using the same cofactor formula as solve_3x3_spd.
7143    fn invert_3x3_spd(h: &Array2<f64>) -> Array2<f64> {
7144        let mut out = Array2::<f64>::zeros((3, 3));
7145        for col in 0..3 {
7146            let mut e = Array1::<f64>::zeros(3);
7147            e[col] = 1.0;
7148            let v = solve_3x3_spd(h, &e);
7149            for row in 0..3 {
7150                out[[row, col]] = v[row];
7151            }
7152        }
7153        out
7154    }
7155
7156    /// Invert a 3x3 unit-diagonal upper-triangular matrix exactly.
7157    fn invert_upper_triangular_3(q: &Array2<f64>) -> Array2<f64> {
7158        // Q is upper triangular with unit diagonal:
7159        //   [1  a  b]
7160        //   [0  1  c]
7161        //   [0  0  1]
7162        // Q⁻¹ = [[1, -a, ac-b], [0, 1, -c], [0, 0, 1]].
7163        let a = q[[0, 1]];
7164        let b = q[[0, 2]];
7165        let c = q[[1, 2]];
7166        array![[1.0, -a, a * c - b], [0.0, 1.0, -c], [0.0, 0.0, 1.0]]
7167    }
7168
7169    // ─── Coverage correction unit tests (Task #9) ─────────────────────────
7170
7171    /// Build a minimal Gaussian-identity fit (intercept-only design) with a
7172    /// non-zero variance on β so prediction returns a non-degenerate
7173    /// interval. Used to feed corrections without coupling to a fitter.
7174    fn coverage_correction_fixture() -> (UnifiedFitResult, Array2<f64>, Array1<f64>, Array1<f64>) {
7175        let beta = array![1.0];
7176        let cov = array![[0.25_f64]];
7177        let fit = test_fit_with_bias_correction(beta.clone(), cov.clone(), None);
7178        // Single batch row with x=1 (intercept).
7179        let x = array![[1.0_f64]];
7180        let offset = array![0.0_f64];
7181        (fit, x, beta, offset)
7182    }
7183
7184    fn corrections_baseline_options() -> PredictUncertaintyOptions {
7185        PredictUncertaintyOptions {
7186            confidence_level: 0.95,
7187            covariance_mode: InferenceCovarianceMode::Conditional,
7188            mean_interval_method: MeanIntervalMethod::TransformEta,
7189            includeobservation_interval: false,
7190            apply_bias_correction: false,
7191            // All four corrections OFF for the regression baseline.
7192            edgeworth_one_sided: false,
7193            boundary_correction: false,
7194            ood_inflation: false,
7195            multi_point_joint: false,
7196            ..PredictUncertaintyOptions::default()
7197        }
7198    }
7199
7200    #[test]
7201    fn coverage_corrections_all_off_matches_legacy() {
7202        // Regression baseline: with every correction OFF the output must
7203        // match the un-corrected interval exactly. Locks the legacy
7204        // semantics so we can detect accidental drift in the hot path.
7205        let (fit, x, beta, offset) = coverage_correction_fixture();
7206        let opts = corrections_baseline_options();
7207        let pred = predict_gamwith_uncertainty(
7208            x.view(),
7209            beta.view(),
7210            offset.view(),
7211            crate::types::LikelihoodFamily::GaussianIdentity,
7212            &fit,
7213            &opts,
7214        )
7215        .expect("prediction baseline");
7216
7217        let z = standard_normal_quantile(0.5 + 0.5 * 0.95).unwrap();
7218        let expected_se = (0.25_f64).sqrt();
7219        assert!((pred.eta_standard_error[0] - expected_se).abs() <= 1e-12);
7220        let expected_lower = 1.0 - z * expected_se;
7221        let expected_upper = 1.0 + z * expected_se;
7222        assert!(
7223            (pred.eta_lower[0] - expected_lower).abs() <= 1e-12,
7224            "baseline lower drifted: got {}, expected {}",
7225            pred.eta_lower[0],
7226            expected_lower
7227        );
7228        assert!(
7229            (pred.eta_upper[0] - expected_upper).abs() <= 1e-12,
7230            "baseline upper drifted: got {}, expected {}",
7231            pred.eta_upper[0],
7232            expected_upper
7233        );
7234    }
7235
7236    #[test]
7237    fn edgeworth_one_sided_makes_interval_asymmetric_with_positive_skew() {
7238        let (fit, x, beta, offset) = coverage_correction_fixture();
7239        let mut opts = corrections_baseline_options();
7240        opts.edgeworth_one_sided = true;
7241        opts.eta_skewness_for_corrections = Some(array![0.6_f64]);
7242
7243        let pred = predict_gamwith_uncertainty(
7244            x.view(),
7245            beta.view(),
7246            offset.view(),
7247            crate::types::LikelihoodFamily::GaussianIdentity,
7248            &fit,
7249            &opts,
7250        )
7251        .expect("edgeworth prediction");
7252
7253        // Cornish–Fisher with κ₃ = 0.6, z ≈ 1.96: bump = (z²−1)·0.6/6 > 0
7254        // ⇒ z_upper > z_central > z_lower ⇒ upper tail moves further right
7255        // and the lower tail moves *closer* to η̂. Equivalently, the
7256        // (η_upper − η̂) > (η̂ − η_lower).
7257        let dist_upper = pred.eta_upper[0] - 1.0;
7258        let dist_lower = 1.0 - pred.eta_lower[0];
7259        assert!(
7260            dist_upper > dist_lower + 1e-9,
7261            "positive skew should push upper tail further than lower: \
7262             upper-dist={dist_upper}, lower-dist={dist_lower}"
7263        );
7264        // Skew = 0 must reduce to the symmetric interval (parity check).
7265        opts.eta_skewness_for_corrections = Some(array![0.0_f64]);
7266        let pred_sym = predict_gamwith_uncertainty(
7267            x.view(),
7268            beta.view(),
7269            offset.view(),
7270            crate::types::LikelihoodFamily::GaussianIdentity,
7271            &fit,
7272            &opts,
7273        )
7274        .expect("edgeworth zero-skew prediction");
7275        let sym_upper = pred_sym.eta_upper[0] - 1.0;
7276        let sym_lower = 1.0 - pred_sym.eta_lower[0];
7277        assert!((sym_upper - sym_lower).abs() <= 1e-12);
7278    }
7279
7280    #[test]
7281    fn boundary_correction_widens_interval_near_edge() {
7282        // Two query rows on a single axis with training support [0, 10].
7283        // Row 0 lies in the interior (x=5 ⇒ d_edge=5, well outside the
7284        // boundary band β·range=0.05·10=0.5). Row 1 is near the edge
7285        // (x=9.9 ⇒ d_edge=0.1, inside the band) and must receive a
7286        // strictly wider interval than the baseline.
7287        let beta = array![1.0_f64];
7288        let cov = array![[0.25_f64]];
7289        let fit = test_fit_with_bias_correction(beta.clone(), cov, None);
7290        let x = array![[1.0_f64], [1.0_f64]];
7291        let offset = array![0.0_f64, 0.0_f64];
7292
7293        let mut opts = corrections_baseline_options();
7294        opts.boundary_correction = true;
7295        opts.predictor_x_for_corrections = Some(array![[5.0_f64], [9.9_f64]]);
7296        opts.training_support = Some(TrainingSupport {
7297            axis_min: array![0.0_f64],
7298            axis_max: array![10.0_f64],
7299        });
7300
7301        let pred = predict_gamwith_uncertainty(
7302            x.view(),
7303            beta.view(),
7304            offset.view(),
7305            crate::types::LikelihoodFamily::GaussianIdentity,
7306            &fit,
7307            &opts,
7308        )
7309        .expect("boundary-corrected prediction");
7310
7311        let baseline_se = (0.25_f64).sqrt();
7312        // Interior row (x=5) is outside the boundary band ⇒ no inflation.
7313        assert!(
7314            (pred.eta_standard_error[0] - baseline_se).abs() <= 1e-12,
7315            "interior row must not be inflated: {} vs {}",
7316            pred.eta_standard_error[0],
7317            baseline_se
7318        );
7319        // Near-edge row must have strictly higher SE.
7320        assert!(
7321            pred.eta_standard_error[1] > baseline_se + 1e-9,
7322            "near-edge row must be inflated: got {}, baseline {}",
7323            pred.eta_standard_error[1],
7324            baseline_se
7325        );
7326        // Direction: interval must be wider, not narrower.
7327        let width0 = pred.eta_upper[0] - pred.eta_lower[0];
7328        let width1 = pred.eta_upper[1] - pred.eta_lower[1];
7329        assert!(
7330            width1 > width0 + 1e-9,
7331            "near-edge interval not wider: width0={width0}, width1={width1}"
7332        );
7333    }
7334
7335    #[test]
7336    fn ood_inflation_widens_interval_outside_support() {
7337        let beta = array![1.0_f64];
7338        let cov = array![[0.25_f64]];
7339        let fit = test_fit_with_bias_correction(beta.clone(), cov, None);
7340        let x = array![[1.0_f64], [1.0_f64]];
7341        let offset = array![0.0_f64, 0.0_f64];
7342
7343        // Row 0: in-support (x=5). Row 1: well past the upper bound (x=15
7344        // outside [0, 10]).
7345        let mut opts = corrections_baseline_options();
7346        opts.ood_inflation = true;
7347        opts.predictor_x_for_corrections = Some(array![[5.0_f64], [15.0_f64]]);
7348        opts.training_support = Some(TrainingSupport {
7349            axis_min: array![0.0_f64],
7350            axis_max: array![10.0_f64],
7351        });
7352
7353        let pred = predict_gamwith_uncertainty(
7354            x.view(),
7355            beta.view(),
7356            offset.view(),
7357            crate::types::LikelihoodFamily::GaussianIdentity,
7358            &fit,
7359            &opts,
7360        )
7361        .expect("ood-inflated prediction");
7362
7363        let baseline_se = (0.25_f64).sqrt();
7364        assert!((pred.eta_standard_error[0] - baseline_se).abs() <= 1e-12);
7365        // Excess fraction = (15-10)/10 = 0.5 ⇒ factor = 1 + γ·0.25 with
7366        // default γ = 1 ⇒ 1.25 ⇒ se = sqrt(0.25·1.25) = sqrt(0.3125).
7367        let expected = (0.25_f64 * 1.25).sqrt();
7368        assert!(
7369            (pred.eta_standard_error[1] - expected).abs() <= 1e-12,
7370            "ood inflation factor wrong: got {}, expected {}",
7371            pred.eta_standard_error[1],
7372            expected
7373        );
7374        assert!(pred.eta_standard_error[1] > baseline_se);
7375    }
7376
7377    #[test]
7378    fn multi_point_joint_widens_interval_relative_to_per_row() {
7379        let beta = array![1.0_f64];
7380        let cov = array![[0.25_f64]];
7381        let fit = test_fit_with_bias_correction(beta.clone(), cov, None);
7382        // Five identical query rows; joint over m=5 must widen each
7383        // interval relative to the per-row baseline, by the Bonferroni z.
7384        let x = Array2::<f64>::from_elem((5, 1), 1.0_f64);
7385        let offset = Array1::zeros(5);
7386        let mut opts = corrections_baseline_options();
7387        opts.multi_point_joint = true;
7388        // Don't set joint_query_count so the helper uses batch size = 5.
7389
7390        let pred = predict_gamwith_uncertainty(
7391            x.view(),
7392            beta.view(),
7393            offset.view(),
7394            crate::types::LikelihoodFamily::GaussianIdentity,
7395            &fit,
7396            &opts,
7397        )
7398        .expect("joint-adjusted prediction");
7399
7400        let z_per_row = standard_normal_quantile(0.5 + 0.5 * 0.95).unwrap();
7401        let z_joint = standard_normal_quantile(0.5 + 0.5 * (1.0 - 0.05_f64 / 5.0)).unwrap();
7402        assert!(
7403            z_joint > z_per_row + 1e-6,
7404            "Bonferroni z must exceed per-row z: joint={z_joint}, per-row={z_per_row}"
7405        );
7406        let baseline_se = (0.25_f64).sqrt();
7407        // Width per row should be 2·z_joint·se.
7408        for i in 0..5 {
7409            let width = pred.eta_upper[i] - pred.eta_lower[i];
7410            let expected = 2.0 * z_joint * baseline_se;
7411            assert!(
7412                (width - expected).abs() <= 1e-12,
7413                "joint row {i} width mismatch: got {width}, expected {expected}"
7414            );
7415        }
7416    }
7417
7418    #[test]
7419    fn edgeworth_helper_zero_skew_returns_central_z() {
7420        let z = 1.96_f64;
7421        let adj = edgeworth_one_sided_quantile(z, 0.0);
7422        assert!((adj.z_lower - z).abs() <= 1e-12);
7423        assert!((adj.z_upper - z).abs() <= 1e-12);
7424    }
7425
7426    #[test]
7427    fn boundary_helper_returns_one_in_interior() {
7428        let f = boundary_variance_inflation_factor(
7429            array![5.0_f64].view(),
7430            array![0.0_f64].view(),
7431            array![10.0_f64].view(),
7432            0.25,
7433            0.05,
7434        );
7435        assert!((f - 1.0).abs() <= 1e-12);
7436    }
7437
7438    #[test]
7439    fn ood_helper_returns_one_inside_box() {
7440        let f = ood_variance_inflation_factor(
7441            array![5.0_f64].view(),
7442            array![0.0_f64].view(),
7443            array![10.0_f64].view(),
7444            1.0,
7445        );
7446        assert!((f - 1.0).abs() <= 1e-12);
7447    }
7448
7449    #[test]
7450    fn multi_point_joint_z_passthrough_at_m_one() {
7451        let z1 = multi_point_joint_z(0.95, 1).unwrap();
7452        let z_baseline = standard_normal_quantile(0.5 + 0.5 * 0.95).unwrap();
7453        assert!((z1 - z_baseline).abs() <= 1e-12);
7454    }
7455}