Skip to main content

gam_models/inference/
predict_io.rs

1use crate::bms::{
2    EmpiricalZGrid, LatentMeasureKind, LatentZConditionalCalibration, LatentZRankIntCalibration,
3    bernoulli_marginal_link_map, empirical_intercept_from_marginal,
4};
5use crate::marginal_slope_shared::{
6    ObservedDenestedCellPartials, eval_coeff4_at,
7    probit_frailty_scale as marginal_slope_probit_frailty_scale, scale_coeff4,
8};
9use crate::survival::lognormal_kernel::FrailtySpec;
10use crate::inference::model::{SavedCompiledFlexBlock, SavedLatentZNormalization};
11use gam_linalg::matrix::DesignMatrix;
12use gam_math::probability::{normal_cdf, normal_pdf};
13use gam_solve::estimate::{EstimationError, UnifiedFitResult};
14use gam_problem::types::{InverseLink, LikelihoodSpec};
15use ndarray::{Array1, Array2, ArrayView1};
16use rayon::iter::{IntoParallelIterator, ParallelIterator};
17
18pub struct PredictResult {
19    pub eta: Array1<f64>,
20    pub mean: Array1<f64>,
21}
22
23/// Input to prediction routines. Contains the design matrix and metadata
24/// needed for point prediction plus uncertainty quantification.
25pub struct PredictInput {
26    /// Design matrix for the primary (mean/location) block.
27    pub design: DesignMatrix,
28    /// Offset vector for the primary block.
29    pub offset: Array1<f64>,
30    /// Optional design matrix for the noise/scale block (GAMLSS/survival).
31    pub design_noise: Option<DesignMatrix>,
32    /// Optional offset vector for the noise/scale block.
33    pub offset_noise: Option<Array1<f64>>,
34    /// Optional auxiliary scalar covariate used by specialized predictors.
35    pub auxiliary_scalar: Option<Array1<f64>>,
36    /// Optional auxiliary matrix used by specialized predictors.
37    pub auxiliary_matrix: Option<Array2<f64>>,
38}
39
40pub struct BernoulliMarginalSlopePredictor {
41    pub beta_marginal: Array1<f64>,
42    pub beta_logslope: Array1<f64>,
43    pub beta_score_warp: Option<Array1<f64>>,
44    pub beta_link_dev: Option<Array1<f64>>,
45    pub base_link: InverseLink,
46    pub z_column: String,
47    pub latent_z_normalization: SavedLatentZNormalization,
48    pub latent_measure: LatentMeasureKind,
49    pub baseline_marginal: f64,
50    pub baseline_logslope: f64,
51    pub covariance: Option<Array2<f64>>,
52    pub score_warp_runtime: Option<SavedCompiledFlexBlock>,
53    pub link_deviation_runtime: Option<SavedCompiledFlexBlock>,
54    pub gaussian_frailty_sd: Option<f64>,
55    pub latent_z_calibration: Option<LatentZRankIntCalibration>,
56    pub latent_z_conditional_calibration: Option<LatentZConditionalCalibration>,
57}
58
/// Number of prediction rows to process per chunk.
///
/// Chunk size is chosen so that `rows × bytes_per_row` stays near a 2 MiB
/// budget. `bytes_per_row` is `parameter_dim × local_dim × size_of::<f64>() × 4`,
/// with zero dimensions treated as 1. The result is clamped to `[16, 4096]`
/// rows and never exceeds `n_rows`.
///
/// An empty batch (`n_rows == 0`) yields `1`, so the result is always a
/// non-zero chunk size.
fn prediction_chunk_rows(parameter_dim: usize, local_dim: usize, n_rows: usize) -> usize {
    const PREDICTION_TARGET_WORK_BYTES: usize = 2 * 1024 * 1024;
    const PREDICTION_MIN_CHUNK_ROWS: usize = 16;
    const PREDICTION_MAX_CHUNK_ROWS: usize = 4096;
    if n_rows == 0 {
        return 1;
    }
    // Every factor is >= 1 and `size_of::<f64>() * 4 == 32`, so this is
    // always >= 32 (never zero). Saturation only guards against overflow for
    // absurdly large dimensions, which then fall back to the minimum chunk.
    let bytes_per_row = parameter_dim
        .max(1)
        .saturating_mul(local_dim.max(1))
        .saturating_mul(std::mem::size_of::<f64>())
        .saturating_mul(4);
    (PREDICTION_TARGET_WORK_BYTES / bytes_per_row)
        .clamp(PREDICTION_MIN_CHUNK_ROWS, PREDICTION_MAX_CHUNK_ROWS)
        .min(n_rows)
}
81
82/// Per-runtime predict-time anchor correction matrices.
83///
84/// Built once per top-level predict call from the marginal + logslope
85/// designs at the prediction rows. Each `Array2<f64>` is shaped
86/// `n_predict × runtime.basis_dim` and holds `n_row(i) · M` for every
87/// prediction row, where `n_row(i)` is the concatenation of the marginal
88/// and logslope design rows in the runtime's anchor component order.
89///
90/// At any `local_cubic_at` / `basis_cubic_at` / `design` call site we
91/// subtract the appropriate slice of these matrices from the raw cubic
92/// output to apply the cross-block residual `n_row · M` correction.
93///
94/// `n_anchor_rows` is the underlying `n × d` parametric anchor stack
95/// (per-runtime layouts: score_warp gets `[marginal | logslope]`;
96/// link_dev gets `[marginal | logslope | score_warp_design(z)]` when the
97/// fit-time identifiability stage threaded the score-warp basis in as a
98/// flex-evaluation anchor). These layouts must match the column order
99/// `install_compiled_flex_block_into_runtime` used at fit time.
100#[derive(Default)]
101struct BmsAnchorCorrections {
102    /// `[marginal | logslope]` at predict rows. `Some` whenever any
103    /// runtime carries an anchor residual.
104    score_warp_anchor_rows: Option<Array2<f64>>,
105    /// `[marginal | logslope | score_warp_design(z)]` at predict rows.
106    /// `Some` whenever the link-deviation runtime carries an anchor
107    /// residual; the score-warp tail is included iff the saved
108    /// link-deviation runtime's residual components include a
109    /// `FlexEvaluation` entry.
110    link_dev_anchor_rows: Option<Array2<f64>>,
111    score_warp: Option<Array2<f64>>,
112    link_dev: Option<Array2<f64>>,
113}
114
115impl BmsAnchorCorrections {
116    fn score_warp_row(&self, row: usize) -> Option<ndarray::ArrayView1<'_, f64>> {
117        self.score_warp.as_ref().map(|m| m.row(row))
118    }
119
120    fn link_dev_row(&self, row: usize) -> Option<ndarray::ArrayView1<'_, f64>> {
121        self.link_dev.as_ref().map(|m| m.row(row))
122    }
123
124    fn score_warp_anchor_rows_view(&self) -> Option<ndarray::ArrayView2<'_, f64>> {
125        self.score_warp_anchor_rows.as_ref().map(|m| m.view())
126    }
127
128    fn link_dev_anchor_rows_view(&self) -> Option<ndarray::ArrayView2<'_, f64>> {
129        self.link_dev_anchor_rows.as_ref().map(|m| m.view())
130    }
131}
132
133impl BernoulliMarginalSlopePredictor {
    /// Build the anchor correction matrices for a given predict-input batch.
    ///
    /// Returns an empty bundle (all `None`) when neither runtime carries
    /// an anchor residual — this is the fast path for fits without
    /// cross-block residualisation. When at least one runtime has a
    /// residual, materialises the marginal + logslope designs at the
    /// predict rows once and computes the per-runtime correction matrices
    /// against each runtime's stored `M`.
    ///
    /// # Arguments
    /// * `input` — prediction batch; only `input.design` (the marginal block) is read.
    /// * `design_logslope` — log-slope design at the same prediction rows.
    /// * `z` — per-row latent score. It is length-checked against the designs and
    ///   used to evaluate the score-warp basis when the link-deviation anchor has a
    ///   `FlexEvaluation` tail.
    ///
    /// # Errors
    /// Returns `InvalidInput` in these cases:
    /// - a design cannot be densified;
    /// - the row counts of the marginal design, the log-slope design and `z` disagree;
    /// - the saved anchor components violate the fit-time layout;
    /// - the score-warp basis width differs from the saved `FlexEvaluation` tail width.
    ///
    /// Errors from the runtimes' correction/design calls are converted with
    /// `EstimationError::from`.
    fn build_anchor_correction_matrices(
        &self,
        input: &PredictInput,
        design_logslope: &DesignMatrix,
        z: &Array1<f64>,
    ) -> Result<BmsAnchorCorrections, EstimationError> {
        use crate::inference::model::SavedAnchorKind;
        let needs_score = self
            .score_warp_runtime
            .as_ref()
            .is_some_and(|r| r.anchor_correction.is_some());
        let needs_link = self
            .link_deviation_runtime
            .as_ref()
            .is_some_and(|r| r.anchor_correction.is_some());
        if !needs_score && !needs_link {
            return Ok(BmsAnchorCorrections::default());
        }
        // Materialise the marginal + logslope designs at predict rows.
        // The dense copies are `n_rows × (p_marginal + p_logslope)`.
        // NOTE(review): this relies on callers chunking large batches (e.g. via
        // `prediction_chunk_rows`). That chunking is outside this function — confirm.
        let marginal_dense = input
            .design
            .try_to_dense_arc(
                "bernoulli marginal-slope predict-time marginal anchor materialisation",
            )
            .map_err(EstimationError::InvalidInput)?;
        let logslope_dense = design_logslope
            .try_to_dense_arc(
                "bernoulli marginal-slope predict-time logslope anchor materialisation",
            )
            .map_err(EstimationError::InvalidInput)?;
        let n_rows = marginal_dense.nrows();
        if logslope_dense.nrows() != n_rows {
            return Err(EstimationError::InvalidInput(format!(
                "bernoulli marginal-slope predict anchor materialisation row mismatch: marginal {} vs logslope {}",
                n_rows,
                logslope_dense.nrows()
            )));
        }
        if z.len() != n_rows {
            return Err(EstimationError::InvalidInput(format!(
                "bernoulli marginal-slope predict anchor materialisation: z has {} entries, expected {}",
                z.len(),
                n_rows
            )));
        }
        // Stack `[marginal | logslope]` column-wise; this is the parametric
        // prefix shared by both runtimes' anchor layouts.
        let p_marginal = marginal_dense.ncols();
        let p_logslope = logslope_dense.ncols();
        let d_parametric = p_marginal + p_logslope;
        let mut parametric_rows = Array2::<f64>::zeros((n_rows, d_parametric));
        parametric_rows
            .slice_mut(ndarray::s![.., 0..p_marginal])
            .assign(&marginal_dense.view());
        parametric_rows
            .slice_mut(ndarray::s![.., p_marginal..d_parametric])
            .assign(&logslope_dense.view());

        // Score-warp anchor layout is `[marginal | logslope]` (parametric
        // only; flex-flex anchoring goes the other direction).
        let score_warp = if needs_score {
            // `needs_score` implies the runtime is `Some`.
            let runtime = self.score_warp_runtime.as_ref().unwrap();
            self.validate_runtime_anchor_layout_parametric_only(runtime, "score_warp")?;
            runtime
                .anchor_correction_matrix(parametric_rows.view())
                .map_err(EstimationError::from)?
        } else {
            None
        };

        // Link-deviation anchor layout matches the fit-time stacking in
        // `install_compiled_flex_block_into_runtime`: parametric
        // columns first, then (if a FlexEvaluation component is present)
        // the score-warp runtime's reparameterised basis at predict rows.
        let (link_dev_anchor_rows, link_dev) = if needs_link {
            // `needs_link` implies the runtime is `Some`.
            let runtime = self.link_deviation_runtime.as_ref().unwrap();
            // Determine whether the saved link-dev residual carries a
            // FlexEvaluation tail and validate ordering matches the
            // fit-time invariant (all parametric components first, then
            // at most one FlexEvaluation tail).
            let mut saw_flex_tail = false;
            let mut flex_tail_ncols: usize = 0;
            for (idx, component) in runtime.anchor_components.iter().enumerate() {
                match &component.kind {
                    SavedAnchorKind::Parametric { .. } => {
                        if saw_flex_tail {
                            return Err(EstimationError::InvalidInput(format!(
                                "bernoulli marginal-slope link-deviation saved anchor components \
                                 are out of order: parametric component at index {idx} follows \
                                 a FlexEvaluation tail",
                            )));
                        }
                    }
                    SavedAnchorKind::FlexEvaluation { ncols } => {
                        if saw_flex_tail {
                            return Err(EstimationError::InvalidInput(
                                "bernoulli marginal-slope link-deviation saved anchor components \
                                 carry more than one FlexEvaluation tail; fit-time stacking emits \
                                 at most one (score-warp)"
                                    .to_string(),
                            ));
                        }
                        saw_flex_tail = true;
                        flex_tail_ncols = *ncols;
                    }
                }
            }
            let rows = if saw_flex_tail {
                let score_runtime = self.score_warp_runtime.as_ref().ok_or_else(|| {
                    EstimationError::InvalidInput(
                        "bernoulli marginal-slope link-deviation saved anchor includes a \
                         FlexEvaluation tail but the saved score-warp runtime is missing"
                            .to_string(),
                    )
                })?;
                // Evaluate the score-warp runtime at predict-row z. When
                // the score-warp itself carries an anchor residual, route
                // through `design_with_anchor_rows` so the per-row
                // subtraction is applied; otherwise the raw `design(z)`
                // is the reparameterised basis.
                let score_basis = if score_runtime.anchor_correction.is_some() {
                    score_runtime
                        .design_with_anchor_rows(z, parametric_rows.view())
                        .map_err(EstimationError::from)?
                } else {
                    score_runtime.design(z).map_err(EstimationError::from)?
                };
                if score_basis.ncols() != flex_tail_ncols {
                    return Err(EstimationError::InvalidInput(format!(
                        "bernoulli marginal-slope link-deviation FlexEvaluation tail expects \
                         {} score-warp basis columns at predict rows, got {}",
                        flex_tail_ncols,
                        score_basis.ncols()
                    )));
                }
                // `[parametric | score_warp_basis]`, matching the saved component order.
                let mut combined = Array2::<f64>::zeros((n_rows, d_parametric + flex_tail_ncols));
                combined
                    .slice_mut(ndarray::s![.., 0..d_parametric])
                    .assign(&parametric_rows.view());
                combined
                    .slice_mut(ndarray::s![.., d_parametric..])
                    .assign(&score_basis.view());
                combined
            } else {
                // Cloned because `parametric_rows` is also moved into
                // `score_warp_anchor_rows` below.
                parametric_rows.clone()
            };
            let corr = runtime
                .anchor_correction_matrix(rows.view())
                .map_err(EstimationError::from)?;
            (Some(rows), corr)
        } else {
            (None, None)
        };

        Ok(BmsAnchorCorrections {
            // Always populated past the fast path above.
            score_warp_anchor_rows: Some(parametric_rows),
            link_dev_anchor_rows,
            score_warp,
            link_dev,
        })
    }
304
305    /// Validate that a saved deviation runtime's anchor residual contains
306    /// only `Parametric` components (no `FlexEvaluation` tail). Used for
307    /// the score-warp runtime, whose fit-time stacking is parametric-only.
308    fn validate_runtime_anchor_layout_parametric_only(
309        &self,
310        runtime: &SavedCompiledFlexBlock,
311        runtime_label: &str,
312    ) -> Result<(), EstimationError> {
313        use crate::inference::model::SavedAnchorKind;
314        for (idx, component) in runtime.anchor_components.iter().enumerate() {
315            match &component.kind {
316                SavedAnchorKind::Parametric { .. } => {}
317                SavedAnchorKind::FlexEvaluation { .. } => {
318                    return Err(EstimationError::InvalidInput(format!(
319                        "bernoulli marginal-slope {runtime_label} saved anchor component at \
320                         index {idx} is FlexEvaluation; only Parametric components are \
321                         expected for this runtime",
322                    )));
323                }
324            }
325        }
326        Ok(())
327    }
328
329    pub fn likelihood_family(&self) -> LikelihoodSpec {
330        LikelihoodSpec::binomial_probit()
331    }
332
333    pub fn mean_from_eta(&self, eta: &Array1<f64>) -> Result<Array1<f64>, EstimationError> {
334        Ok(eta.mapv(normal_cdf))
335    }
336
337    pub fn mean_derivative_from_eta(
338        &self,
339        eta: &Array1<f64>,
340    ) -> Result<Array1<f64>, EstimationError> {
341        Ok(eta.mapv(normal_pdf))
342    }
343
344    pub(crate) fn probit_frailty_scale(&self) -> f64 {
345        marginal_slope_probit_frailty_scale(self.gaussian_frailty_sd)
346    }
347
348    /// Apply the (optional) rank-INT latent-z calibration to a batch of
349    /// normalized predict-time z values.
350    ///
351    /// The calibration was fit on the training z + weights as a Blom-
352    /// rankit weighted rank inverse-normal transform; the calibrated
353    /// sample is N(0, 1) by construction (exact, not approximate), which
354    /// is why the BMS standard-normal closed-form kernel is correct on
355    /// the calibrated scale. At predict time, every z that flows into a
356    /// kernel evaluation site (`final_eta_and_gradient_from_theta`,
357    /// `predict_eta_and_q_chain`, and indirectly the per-row `solve_intercept_scalar`
358    /// / `evaluate_prediction_calibration` / `observed_denested_cell_partials_at_z`
359    /// helpers that consume per-row scalar z values from the closure-
360    /// captured `z` array) must be routed through the same monotone
361    /// transform. When `latent_z_calibration` is `None`, this returns
362    /// the input unchanged — that case corresponds to training-time z
363    /// having passed the strict normality check, so no transform was
364    /// applied at fit time either.
365    fn apply_latent_z_calibration(&self, z: &Array1<f64>) -> Array1<f64> {
366        match &self.latent_z_calibration {
367            Some(cal) => Array1::from_iter(z.iter().map(|&zi| cal.apply_at_predict(zi))),
368            None => z.clone(),
369        }
370    }
371
372    /// Apply the (optional) conditional location-scale latent-z calibration
373    /// (#905) to a batch of normalized predict-time z values.
374    ///
375    /// When `Some`, training detected a conditional `E[z|C]`/`Var(z|C)` shift
376    /// and replaced its latent score by `ζ = (z − m(C))/√v(C)`. The predictor
377    /// MUST apply the identical map, rebuilding the conditioning span `a(C)`
378    /// from the marginal prediction design (`input.design`) — the same span the
379    /// fit regressed z on. `None` ⇒ no conditional calibration was applied at
380    /// fit time, so z passes through unchanged (mutually exclusive with the
381    /// rank-INT calibration above).
382    fn apply_latent_z_conditional_calibration(
383        &self,
384        z: &Array1<f64>,
385        input: &PredictInput,
386    ) -> Result<Array1<f64>, EstimationError> {
387        let Some(cal) = self.latent_z_conditional_calibration.as_ref() else {
388            return Ok(z.clone());
389        };
390        let a_block = input.design.to_dense();
391        cal.apply(z.view(), a_block.view())
392            .map_err(EstimationError::InvalidInput)
393    }
394
395    fn rigid_intercept_from_marginal(&self, marginal_eta: f64, slope: f64) -> f64 {
396        let probit_scale = self.probit_frailty_scale();
397        marginal_eta * (1.0 + (probit_scale * slope).powi(2)).sqrt() / probit_scale
398    }
399
400    fn empirical_rigid_intercept_and_gradient(
401        &self,
402        marginal_eta: f64,
403        slope: f64,
404        nodes: &[f64],
405        weights: &[f64],
406    ) -> Result<(f64, f64, f64), EstimationError> {
407        let marginal = bernoulli_marginal_link_map(&self.base_link, marginal_eta)
408            .map_err(EstimationError::InvalidInput)?;
409        let scale = self.probit_frailty_scale();
410        let intercept = empirical_intercept_from_marginal(
411            marginal.mu,
412            marginal.q,
413            slope,
414            scale,
415            nodes,
416            weights,
417            None,
418        )
419        .map_err(EstimationError::InvalidInput)?;
420        let observed_slope = scale * slope;
421        let mut f_a = 0.0;
422        let mut f_b = 0.0;
423        for (&node, &weight) in nodes.iter().zip(weights.iter()) {
424            let eta = intercept + observed_slope * node;
425            let pdf = normal_pdf(eta);
426            f_a += weight * pdf;
427            f_b += weight * pdf * scale * node;
428        }
429        if !(f_a.is_finite() && f_a > 0.0 && f_b.is_finite()) {
430            return Err(EstimationError::InvalidInput(format!(
431                "empirical latent prediction calibration derivative is invalid: F_a={f_a}, F_b={f_b}"
432            )));
433        }
434        let a_marginal_eta = marginal.mu1 / f_a;
435        let a_slope = -f_b / f_a;
436        Ok((intercept, a_marginal_eta, a_slope))
437    }
438
439    fn local_empirical_mixture_for_point(
440        point: &[f64],
441        centers: &[Vec<f64>],
442        top_k: usize,
443        bandwidth: f64,
444    ) -> Result<Vec<(usize, f64)>, EstimationError> {
445        if centers.is_empty() {
446            return Err(EstimationError::InvalidInput(
447                "local empirical latent prediction has no centers".to_string(),
448            ));
449        }
450        if top_k == 0 {
451            return Err(EstimationError::InvalidInput(
452                "local empirical latent prediction top_k must be positive".to_string(),
453            ));
454        }
455        if !(bandwidth.is_finite() && bandwidth > 0.0) {
456            return Err(EstimationError::InvalidInput(format!(
457                "local empirical latent prediction bandwidth must be finite and positive, got {bandwidth}"
458            )));
459        }
460        let bw2 = bandwidth * bandwidth;
461        let mut distances = Vec::<(usize, f64)>::with_capacity(centers.len());
462        for (idx, center) in centers.iter().enumerate() {
463            if center.len() != point.len() {
464                return Err(EstimationError::InvalidInput(format!(
465                    "local empirical latent prediction center {idx} dimension mismatch: center={}, point={}",
466                    center.len(),
467                    point.len()
468                )));
469            }
470            let d2 = center
471                .iter()
472                .zip(point.iter())
473                .map(|(&c, &x)| {
474                    let delta = x - c;
475                    delta * delta
476                })
477                .sum::<f64>();
478            if !d2.is_finite() {
479                return Err(EstimationError::InvalidInput(
480                    "local empirical latent prediction distance is non-finite".to_string(),
481                ));
482            }
483            distances.push((idx, d2));
484        }
485        distances.sort_by(|left, right| {
486            left.1
487                .partial_cmp(&right.1)
488                .expect("validated local empirical distances are finite")
489        });
490        let k = top_k.min(distances.len());
491        let mut mixture = Vec::with_capacity(k);
492        let mut total = 0.0;
493        for &(idx, d2) in distances.iter().take(k) {
494            let weight = (-0.5 * d2 / bw2).exp().max(1e-300);
495            mixture.push((idx, weight));
496            total += weight;
497        }
498        if !(total.is_finite() && total > 0.0) {
499            return Err(EstimationError::InvalidInput(
500                "local empirical latent prediction mixture has non-positive total weight"
501                    .to_string(),
502            ));
503        }
504        for (_, weight) in &mut mixture {
505            *weight /= total;
506        }
507        Ok(mixture)
508    }
509
510    fn combine_empirical_grids(
511        grids: &[EmpiricalZGrid],
512        mixture: &[(usize, f64)],
513    ) -> Result<EmpiricalZGrid, EstimationError> {
514        let total_len = mixture
515            .iter()
516            .map(|&(idx, _)| grids.get(idx).map_or(0, |grid| grid.nodes.len()))
517            .sum::<usize>();
518        let mut nodes = Vec::with_capacity(total_len);
519        let mut weights = Vec::with_capacity(total_len);
520        let mut total_weight = 0.0;
521        for &(grid_idx, grid_weight) in mixture {
522            if !(grid_weight.is_finite() && grid_weight >= 0.0) {
523                return Err(EstimationError::InvalidInput(format!(
524                    "local empirical latent prediction mixture weight must be finite and non-negative, got {grid_weight}"
525                )));
526            }
527            let grid = grids.get(grid_idx).ok_or_else(|| {
528                EstimationError::InvalidInput(format!(
529                    "local empirical latent prediction grid index {grid_idx} is out of bounds for {} grids",
530                    grids.len()
531                ))
532            })?;
533            if grid.nodes.len() != grid.weights.len() || grid.nodes.is_empty() {
534                return Err(EstimationError::InvalidInput(format!(
535                    "local empirical latent prediction grid {grid_idx} is invalid: nodes={}, weights={}",
536                    grid.nodes.len(),
537                    grid.weights.len()
538                )));
539            }
540            for (node, weight) in grid.pairs() {
541                let combined_weight = grid_weight * weight;
542                if !(node.is_finite() && combined_weight.is_finite() && combined_weight >= 0.0) {
543                    return Err(EstimationError::InvalidInput(
544                        "local empirical latent prediction grid contains invalid node/weight"
545                            .to_string(),
546                    ));
547                }
548                nodes.push(node);
549                weights.push(combined_weight);
550                total_weight += combined_weight;
551            }
552        }
553        if !(total_weight.is_finite() && total_weight > 0.0) {
554            return Err(EstimationError::InvalidInput(
555                "local empirical latent prediction combined grid has non-positive total weight"
556                    .to_string(),
557            ));
558        }
559        for weight in &mut weights {
560            *weight /= total_weight;
561        }
562        Ok(EmpiricalZGrid { nodes, weights })
563    }
564
565    fn empirical_grid_for_prediction_row(
566        &self,
567        input: &PredictInput,
568        row: usize,
569    ) -> Result<Option<EmpiricalZGrid>, EstimationError> {
570        match &self.latent_measure {
571            LatentMeasureKind::StandardNormal => Ok(None),
572            LatentMeasureKind::GlobalEmpirical { grid } => Ok(Some(grid.clone())),
573            LatentMeasureKind::LocalEmpirical {
574                centers,
575                grids,
576                top_k,
577                bandwidth,
578                ..
579            } => {
580                let conditioning = input.auxiliary_matrix.as_ref().ok_or_else(|| {
581                    EstimationError::InvalidInput(
582                        "bernoulli marginal-slope local empirical prediction requires auxiliary conditioning matrix"
583                            .to_string(),
584                    )
585                })?;
586                if row >= conditioning.nrows() {
587                    return Err(EstimationError::InvalidInput(format!(
588                        "local empirical latent prediction row {row} is out of bounds for {} conditioning rows",
589                        conditioning.nrows()
590                    )));
591                }
592                let expected_dim = centers.first().map_or(0, Vec::len);
593                if conditioning.ncols() != expected_dim {
594                    return Err(EstimationError::InvalidInput(format!(
595                        "local empirical latent prediction conditioning dimension mismatch: got {}, expected {expected_dim}",
596                        conditioning.ncols()
597                    )));
598                }
599                let point = conditioning.row(row).to_vec();
600                let mixture =
601                    Self::local_empirical_mixture_for_point(&point, centers, *top_k, *bandwidth)?;
602                Self::combine_empirical_grids(grids, &mixture).map(Some)
603            }
604        }
605    }
606
607    fn transform_internal_eta_to_base_scale(
608        &self,
609        internal_eta: Array1<f64>,
610        internal_grad: Option<Array2<f64>>,
611    ) -> Result<(Array1<f64>, Option<Array2<f64>>), EstimationError> {
612        Ok((internal_eta, internal_grad))
613    }
614
615    fn link_terms_value_d1(
616        &self,
617        eta0: &Array1<f64>,
618        beta_link_dev: Option<&Array1<f64>>,
619        link_dev_correction_for_row: Option<ndarray::ArrayView1<'_, f64>>,
620    ) -> Result<(Array1<f64>, Array1<f64>), EstimationError> {
621        if let (Some(runtime), Some(beta)) = (&self.link_deviation_runtime, beta_link_dev) {
622            // When the runtime carries a cross-block anchor residual, every
623            // raw-design row needs `n_row · M` subtracted. `correction_for_row`
624            // already holds the precomputed `n_row · M` for this predict row
625            // (length basis_dim), so the corrected basis contribution to η
626            // is `basis · beta - correction.dot(beta)` for every eta0 entry.
627            // Derivative paths are unaffected (the anchor argument is a
628            // different scalar than eta0).
629            let basis = runtime
630                .design_uncorrected(eta0)
631                .map_err(EstimationError::from)?;
632            let mut value = &basis.dot(beta) + eta0;
633            if let Some(corr) = link_dev_correction_for_row {
634                let offset = corr.dot(beta);
635                for v in value.iter_mut() {
636                    *v -= offset;
637                }
638            } else if runtime.anchor_correction.is_some() {
639                return Err(EstimationError::InvalidInput(
640                    "bernoulli marginal-slope link-deviation runtime has an anchor residual but \
641                     no per-row correction was supplied to link_terms_value_d1"
642                        .to_string(),
643                ));
644            }
645            let d1 = runtime
646                .first_derivative_design(eta0)
647                .map_err(EstimationError::from)?;
648            Ok((value, d1.dot(beta) + 1.0))
649        } else {
650            Ok((eta0.clone(), Array1::ones(eta0.len())))
651        }
652    }
653
654    fn denested_partition_cells(
655        &self,
656        a: f64,
657        b: f64,
658        beta_score_warp: Option<&Array1<f64>>,
659        beta_link_dev: Option<&Array1<f64>>,
660        score_warp_correction_for_row: Option<ndarray::ArrayView1<'_, f64>>,
661        link_dev_correction_for_row: Option<ndarray::ArrayView1<'_, f64>>,
662    ) -> Result<Vec<crate::cubic_cell_kernel::DenestedPartitionCell>, EstimationError> {
663        let score_breaks = if let Some(runtime) = self.score_warp_runtime.as_ref() {
664            runtime.breakpoints().map_err(EstimationError::from)?
665        } else {
666            Vec::new()
667        };
668        let link_breaks = if let Some(runtime) = self.link_deviation_runtime.as_ref() {
669            runtime.breakpoints().map_err(EstimationError::from)?
670        } else {
671            Vec::new()
672        };
673        let mut cells =
674            crate::cubic_cell_kernel::build_denested_partition_cells_with_tails(
675                a,
676                b,
677                &score_breaks,
678                &link_breaks,
679                |z| {
680                    if let (Some(runtime), Some(beta)) =
681                        (self.score_warp_runtime.as_ref(), beta_score_warp)
682                    {
683                        let mut span = runtime.local_cubic_at(beta, z)?;
684                        // `local_cubic_at`'s c0 is `Σ_j basis_c0[span][j] · beta[j]`.
685                        // The cross-block residual replaces basis_c0 by
686                        // basis_c0 − n_row · M, contributing a row-constant
687                        // `correction.dot(beta)` to c0. Higher coefficients
688                        // (c1..c3) depend on derivatives of the basis w.r.t.
689                        // its own argument and are untouched.
690                        if let Some(corr) = score_warp_correction_for_row {
691                            span.c0 -= corr.dot(beta);
692                        }
693                        Ok(span)
694                    } else {
695                        Ok(crate::cubic_cell_kernel::LocalSpanCubic {
696                            left: 0.0,
697                            right: 1.0,
698                            c0: 0.0,
699                            c1: 0.0,
700                            c2: 0.0,
701                            c3: 0.0,
702                        })
703                    }
704                },
705                |u| {
706                    if let (Some(runtime), Some(beta)) =
707                        (self.link_deviation_runtime.as_ref(), beta_link_dev)
708                    {
709                        let mut span = runtime.local_cubic_at(beta, u)?;
710                        if let Some(corr) = link_dev_correction_for_row {
711                            span.c0 -= corr.dot(beta);
712                        }
713                        Ok(span)
714                    } else {
715                        Ok(crate::cubic_cell_kernel::LocalSpanCubic {
716                            left: 0.0,
717                            right: 1.0,
718                            c0: 0.0,
719                            c1: 0.0,
720                            c2: 0.0,
721                            c3: 0.0,
722                        })
723                    }
724                },
725            )
726            .map_err(EstimationError::InvalidInput)?;
727        let scale = self.probit_frailty_scale();
728        if scale != 1.0 {
729            for partition_cell in &mut cells {
730                partition_cell.cell.c0 *= scale;
731                partition_cell.cell.c1 *= scale;
732                partition_cell.cell.c2 *= scale;
733                partition_cell.cell.c3 *= scale;
734            }
735        }
736        Ok(cells)
737    }
738
739    fn evaluate_denested_calibration(
740        &self,
741        a: f64,
742        marginal_eta: f64,
743        slope: f64,
744        beta_score_warp: Option<&Array1<f64>>,
745        beta_link_dev: Option<&Array1<f64>>,
746        score_warp_correction_for_row: Option<ndarray::ArrayView1<'_, f64>>,
747        link_dev_correction_for_row: Option<ndarray::ArrayView1<'_, f64>>,
748    ) -> Result<(f64, f64, f64), EstimationError> {
749        let marginal = bernoulli_marginal_link_map(&self.base_link, marginal_eta)
750            .map_err(EstimationError::InvalidInput)?;
751        let cells = self.denested_partition_cells(
752            a,
753            slope,
754            beta_score_warp,
755            beta_link_dev,
756            score_warp_correction_for_row,
757            link_dev_correction_for_row,
758        )?;
759        let scale = self.probit_frailty_scale();
760        let mut f = -marginal.mu;
761        let mut f_a = 0.0;
762        let mut f_aa = 0.0;
763        for partition_cell in cells {
764            let cell = partition_cell.cell;
765            let (dc_da_raw, _) =
766                crate::cubic_cell_kernel::denested_cell_coefficient_partials(
767                    partition_cell.score_span,
768                    partition_cell.link_span,
769                    a,
770                    slope,
771                );
772            let (d2c_da2_raw, _, _) =
773                crate::cubic_cell_kernel::denested_cell_second_partials(
774                    partition_cell.score_span,
775                    partition_cell.link_span,
776                    a,
777                    slope,
778                );
779            let dc_da = scale_coeff4(dc_da_raw, scale);
780            let d2c_da2 = scale_coeff4(d2c_da2_raw, scale);
781            // Derive the moment `max_degree` from the contractions consumed
782            // below, instead of hardcoding a magic constant. The second-
783            // derivative contraction dominates the first-derivative one, so
784            // its required degree is the binding bound. Hardcoding 7 here
785            // produced 8 moments while the contraction needs 10 (#321).
786            let max_degree =
787                crate::cubic_cell_kernel::cell_second_derivative_required_max_degree(
788                    &dc_da, &dc_da, &d2c_da2,
789                );
790            let state = crate::cubic_cell_kernel::evaluate_cell_moments(cell, max_degree)
791                .map_err(EstimationError::InvalidInput)?;
792            f += state.value;
793            f_a += crate::cubic_cell_kernel::cell_first_derivative_from_moments(
794                &dc_da,
795                &state.moments,
796            )
797            .map_err(EstimationError::InvalidInput)?;
798            f_aa += crate::cubic_cell_kernel::cell_second_derivative_from_moments(
799                cell,
800                &dc_da,
801                &dc_da,
802                &d2c_da2,
803                &state.moments,
804            )
805            .map_err(EstimationError::InvalidInput)?;
806        }
807        Ok((f, f_a, f_aa))
808    }
809
810    fn observed_denested_cell_partials_at_z(
811        &self,
812        z_value: f64,
813        a: f64,
814        b: f64,
815        beta_score_warp: Option<&Array1<f64>>,
816        beta_link_dev: Option<&Array1<f64>>,
817        score_warp_correction_for_row: Option<ndarray::ArrayView1<'_, f64>>,
818        link_dev_correction_for_row: Option<ndarray::ArrayView1<'_, f64>>,
819    ) -> Result<ObservedDenestedCellPartials, EstimationError> {
820        use crate::cubic_cell_kernel as exact;
821
822        let zero_span = exact::LocalSpanCubic {
823            left: 0.0,
824            right: 1.0,
825            c0: 0.0,
826            c1: 0.0,
827            c2: 0.0,
828            c3: 0.0,
829        };
830        let u_value = a + b * z_value;
831        let score_span = if let (Some(runtime), Some(beta)) =
832            (self.score_warp_runtime.as_ref(), beta_score_warp)
833        {
834            let mut span = runtime
835                .local_cubic_at(beta, z_value)
836                .map_err(EstimationError::from)?;
837            if let Some(corr) = score_warp_correction_for_row {
838                span.c0 -= corr.dot(beta);
839            }
840            span
841        } else {
842            zero_span
843        };
844        let link_span = if let (Some(runtime), Some(beta)) =
845            (self.link_deviation_runtime.as_ref(), beta_link_dev)
846        {
847            let mut span = runtime
848                .local_cubic_at(beta, u_value)
849                .map_err(EstimationError::from)?;
850            if let Some(corr) = link_dev_correction_for_row {
851                span.c0 -= corr.dot(beta);
852            }
853            span
854        } else {
855            zero_span
856        };
857        let scale = self.probit_frailty_scale();
858        let coeff = scale_coeff4(
859            exact::denested_cell_coefficients(score_span, link_span, a, b),
860            scale,
861        );
862        let (dc_da_raw, dc_db_raw) =
863            exact::denested_cell_coefficient_partials(score_span, link_span, a, b);
864        let (dc_daa_raw, dc_dab_raw, dc_dbb_raw) =
865            exact::denested_cell_second_partials(score_span, link_span, a, b);
866        let (dc_daaa, dc_daab, dc_dabb, dc_dbbb) = exact::denested_cell_third_partials(link_span);
867        Ok(ObservedDenestedCellPartials {
868            coeff,
869            dc_da: scale_coeff4(dc_da_raw, scale),
870            dc_db: scale_coeff4(dc_db_raw, scale),
871            dc_daa: scale_coeff4(dc_daa_raw, scale),
872            dc_dab: scale_coeff4(dc_dab_raw, scale),
873            dc_dbb: scale_coeff4(dc_dbb_raw, scale),
874            dc_daaa: scale_coeff4(dc_daaa, scale),
875            dc_daab: scale_coeff4(dc_daab, scale),
876            dc_dabb: scale_coeff4(dc_dabb, scale),
877            dc_dbbb: scale_coeff4(dc_dbbb, scale),
878        })
879    }
880
881    fn evaluate_empirical_denested_calibration(
882        &self,
883        a: f64,
884        marginal_eta: f64,
885        slope: f64,
886        beta_score_warp: Option<&Array1<f64>>,
887        beta_link_dev: Option<&Array1<f64>>,
888        grid: &EmpiricalZGrid,
889        score_warp_correction_for_row: Option<ndarray::ArrayView1<'_, f64>>,
890        link_dev_correction_for_row: Option<ndarray::ArrayView1<'_, f64>>,
891    ) -> Result<(f64, f64, f64), EstimationError> {
892        let marginal = bernoulli_marginal_link_map(&self.base_link, marginal_eta)
893            .map_err(EstimationError::InvalidInput)?;
894        let mut f = -marginal.mu;
895        let mut f_a = 0.0;
896        let mut f_aa = 0.0;
897        for (node, weight) in grid.pairs() {
898            let obs = self.observed_denested_cell_partials_at_z(
899                node,
900                a,
901                slope,
902                beta_score_warp,
903                beta_link_dev,
904                score_warp_correction_for_row,
905                link_dev_correction_for_row,
906            )?;
907            let eta = eval_coeff4_at(&obs.coeff, node);
908            let eta_a = eval_coeff4_at(&obs.dc_da, node);
909            let eta_aa = eval_coeff4_at(&obs.dc_daa, node);
910            let pdf = normal_pdf(eta);
911            f += weight * normal_cdf(eta);
912            f_a += weight * pdf * eta_a;
913            f_aa += weight * pdf * (eta_aa - eta * eta_a * eta_a);
914        }
915        Ok((f, f_a, f_aa))
916    }
917
918    fn evaluate_prediction_calibration(
919        &self,
920        a: f64,
921        marginal_eta: f64,
922        slope: f64,
923        beta_score_warp: Option<&Array1<f64>>,
924        beta_link_dev: Option<&Array1<f64>>,
925        empirical_grid: Option<&EmpiricalZGrid>,
926        score_warp_correction_for_row: Option<ndarray::ArrayView1<'_, f64>>,
927        link_dev_correction_for_row: Option<ndarray::ArrayView1<'_, f64>>,
928    ) -> Result<(f64, f64, f64), EstimationError> {
929        if let Some(grid) = empirical_grid {
930            self.evaluate_empirical_denested_calibration(
931                a,
932                marginal_eta,
933                slope,
934                beta_score_warp,
935                beta_link_dev,
936                grid,
937                score_warp_correction_for_row,
938                link_dev_correction_for_row,
939            )
940        } else {
941            self.evaluate_denested_calibration(
942                a,
943                marginal_eta,
944                slope,
945                beta_score_warp,
946                beta_link_dev,
947                score_warp_correction_for_row,
948                link_dev_correction_for_row,
949            )
950        }
951    }
952
953    pub fn from_unified(
954        unified: &UnifiedFitResult,
955        z_column: String,
956        latent_z_normalization: SavedLatentZNormalization,
957        latent_measure: LatentMeasureKind,
958        baseline_marginal: f64,
959        baseline_logslope: f64,
960        base_link: InverseLink,
961        frailty: FrailtySpec,
962        score_warp_runtime: Option<SavedCompiledFlexBlock>,
963        link_deviation_runtime: Option<SavedCompiledFlexBlock>,
964        latent_z_calibration: Option<crate::bms::LatentZRankIntCalibration>,
965        latent_z_conditional_calibration: Option<crate::bms::LatentZConditionalCalibration>,
966    ) -> Result<Self, String> {
967        let gaussian_frailty_sd = match frailty {
968            FrailtySpec::None => None,
969            FrailtySpec::GaussianShift {
970                sigma_fixed: Some(sigma),
971            } => Some(sigma),
972            FrailtySpec::GaussianShift { sigma_fixed: None } => {
973                return Err(
974                    "bernoulli marginal-slope predictor requires a fixed GaussianShift sigma"
975                        .to_string(),
976                );
977            }
978            FrailtySpec::HazardMultiplier { .. } => {
979                return Err(
980                    "bernoulli marginal-slope predictor does not support HazardMultiplier frailty"
981                        .to_string(),
982                );
983            }
984        };
985        if !matches!(
986            base_link,
987            InverseLink::Standard(gam_problem::types::StandardLink::Probit)
988        ) {
989            return Err(
990                "bernoulli marginal-slope predictor requires link(type=probit); saved non-probit marginal-slope models must be refit"
991                    .to_string(),
992            );
993        }
994        if let Some(runtime) = score_warp_runtime.as_ref() {
995            runtime.validate_exact_replay_contract().map_err(|e| {
996                format!("bernoulli marginal-slope score-warp runtime is invalid: {e}")
997            })?;
998        }
999        if let Some(runtime) = link_deviation_runtime.as_ref() {
1000            runtime.validate_exact_replay_contract().map_err(|e| {
1001                format!("bernoulli marginal-slope link-deviation runtime is invalid: {e}")
1002            })?;
1003        }
1004        // Cross-block anchor residuals on either runtime are now applied
1005        // per-row by every predict-time `local_cubic_at` / `basis_cubic_at`
1006        // / `design` call site via `build_anchor_correction_matrices`.
1007        latent_z_normalization
1008            .validate("bernoulli marginal-slope predictor")
1009            .map_err(|e| {
1010                format!("bernoulli marginal-slope predictor latent z normalization is invalid: {e}")
1011            })?;
1012        latent_measure
1013            .validate("bernoulli marginal-slope predictor latent measure")
1014            .map_err(|e| {
1015                format!("bernoulli marginal-slope predictor latent measure is invalid: {e}")
1016            })?;
1017        let blocks = &unified.blocks;
1018        let expected_blocks = 2
1019            + usize::from(score_warp_runtime.is_some())
1020            + usize::from(link_deviation_runtime.is_some());
1021        if blocks.len() != expected_blocks {
1022            return Err(format!(
1023                "bernoulli marginal-slope predictor requires exactly {expected_blocks} coefficient blocks under the current exact de-nested semantics, got {}",
1024                blocks.len()
1025            ));
1026        }
1027        let mut cursor = 2usize;
1028        let beta_score_warp = if score_warp_runtime.is_some() {
1029            let beta = blocks
1030                .get(cursor)
1031                .ok_or_else(|| "missing score-warp coefficient block".to_string())?
1032                .beta
1033                .clone();
1034            cursor += 1;
1035            Some(beta)
1036        } else {
1037            None
1038        };
1039        let beta_link_dev = if link_deviation_runtime.is_some() {
1040            Some(
1041                blocks
1042                    .get(cursor)
1043                    .ok_or_else(|| "missing link-deviation coefficient block".to_string())?
1044                    .beta
1045                    .clone(),
1046            )
1047        } else {
1048            None
1049        };
1050        Ok(Self {
1051            beta_marginal: blocks[0].beta.clone(),
1052            beta_logslope: blocks[1].beta.clone(),
1053            beta_score_warp,
1054            beta_link_dev,
1055            base_link,
1056            z_column,
1057            latent_z_normalization,
1058            latent_measure,
1059            baseline_marginal,
1060            baseline_logslope,
1061            covariance: unified.beta_covariance().cloned(),
1062            score_warp_runtime,
1063            link_deviation_runtime,
1064            gaussian_frailty_sd,
1065            latent_z_calibration,
1066            latent_z_conditional_calibration,
1067        })
1068    }
1069
1070    pub fn theta(&self) -> Array1<f64> {
1071        let total = self.beta_marginal.len()
1072            + self.beta_logslope.len()
1073            + self.beta_score_warp.as_ref().map_or(0, |b| b.len())
1074            + self.beta_link_dev.as_ref().map_or(0, |b| b.len());
1075        let mut theta = Array1::<f64>::zeros(total);
1076        let mut cursor = 0usize;
1077        theta
1078            .slice_mut(ndarray::s![cursor..cursor + self.beta_marginal.len()])
1079            .assign(&self.beta_marginal);
1080        cursor += self.beta_marginal.len();
1081        theta
1082            .slice_mut(ndarray::s![cursor..cursor + self.beta_logslope.len()])
1083            .assign(&self.beta_logslope);
1084        cursor += self.beta_logslope.len();
1085        if let Some(beta) = self.beta_score_warp.as_ref() {
1086            theta
1087                .slice_mut(ndarray::s![cursor..cursor + beta.len()])
1088                .assign(beta);
1089            cursor += beta.len();
1090        }
1091        if let Some(beta) = self.beta_link_dev.as_ref() {
1092            theta
1093                .slice_mut(ndarray::s![cursor..cursor + beta.len()])
1094                .assign(beta);
1095        }
1096        theta
1097    }
1098
1099    fn split_theta<'a>(
1100        &'a self,
1101        theta: &'a Array1<f64>,
1102    ) -> Result<
1103        (
1104            ArrayView1<'a, f64>,
1105            ArrayView1<'a, f64>,
1106            Option<ArrayView1<'a, f64>>,
1107            Option<ArrayView1<'a, f64>>,
1108        ),
1109        EstimationError,
1110    > {
1111        let expected = self.theta().len();
1112        if theta.len() != expected {
1113            return Err(EstimationError::InvalidInput(format!(
1114                "bernoulli marginal-slope theta length mismatch: expected {expected}, got {}",
1115                theta.len()
1116            )));
1117        }
1118        let mut cursor = 0usize;
1119        let marginal = theta.slice(ndarray::s![cursor..cursor + self.beta_marginal.len()]);
1120        cursor += self.beta_marginal.len();
1121        let logslope = theta.slice(ndarray::s![cursor..cursor + self.beta_logslope.len()]);
1122        cursor += self.beta_logslope.len();
1123        let score_warp = self.beta_score_warp.as_ref().map(|beta| {
1124            let view = theta.slice(ndarray::s![cursor..cursor + beta.len()]);
1125            cursor += beta.len();
1126            view
1127        });
1128        let link_dev = self
1129            .beta_link_dev
1130            .as_ref()
1131            .map(|beta| theta.slice(ndarray::s![cursor..cursor + beta.len()]));
1132        Ok((marginal, logslope, score_warp, link_dev))
1133    }
1134
1135    /// Safeguarded monotone root solve for the marginal intercept under the
1136    /// de-nested flexible model
1137    ///   η(z) = a + b z + b Δ_h(z) + Δ_w(a + b z).
1138    fn solve_intercept_scalar(
1139        &self,
1140        marginal_eta: f64,
1141        slope: f64,
1142        link_dev_beta: Option<&Array1<f64>>,
1143        score_warp_beta: Option<&Array1<f64>>,
1144        empirical_grid: Option<&EmpiricalZGrid>,
1145        warm_start_buf: &mut Array1<f64>,
1146        score_warp_correction_for_row: Option<ndarray::ArrayView1<'_, f64>>,
1147        link_dev_correction_for_row: Option<ndarray::ArrayView1<'_, f64>>,
1148    ) -> Result<f64, EstimationError> {
1149        let marginal = bernoulli_marginal_link_map(&self.base_link, marginal_eta)
1150            .map_err(EstimationError::InvalidInput)?;
1151        let eval = |a: f64| -> Result<(f64, f64, f64), String> {
1152            self.evaluate_prediction_calibration(
1153                a,
1154                marginal_eta,
1155                slope,
1156                score_warp_beta,
1157                link_dev_beta,
1158                empirical_grid,
1159                score_warp_correction_for_row,
1160                link_dev_correction_for_row,
1161            )
1162            .map_err(|err| err.to_string())
1163        };
1164
1165        let probit_scale = self.probit_frailty_scale();
1166        let a_rigid = self.rigid_intercept_from_marginal(marginal.q, slope);
1167        let mut intercept = a_rigid;
1168        if let (Some(_), Some(beta)) = (self.link_deviation_runtime.as_ref(), link_dev_beta) {
1169            warm_start_buf[0] = a_rigid;
1170            let one_pt = warm_start_buf.slice(ndarray::s![0..1]).to_owned();
1171            let (l_val, l_d1) =
1172                self.link_terms_value_d1(&one_pt, Some(beta), link_dev_correction_for_row)?;
1173            let ell1 = l_d1[0];
1174            if ell1 > 1e-8 {
1175                let ell0 = l_val[0] - ell1 * a_rigid;
1176                let observed_logslope = probit_scale * ell1 * slope;
1177                intercept = (marginal.q * (1.0 + observed_logslope * observed_logslope).sqrt()
1178                    / probit_scale
1179                    - ell0)
1180                    / ell1;
1181            }
1182        }
1183
1184        // Same adaptive tolerance the acceptance check below uses; passing
1185        // a tighter `convergence_tol` would just iterate past what we accept.
1186        let target = marginal.mu;
1187        let abs_tol = 1e-8_f64.max(1e-4 * target.abs());
1188
1189        let (root, _, f_best) = crate::monotone_root::solve_monotone_root(
1190            eval,
1191            intercept,
1192            "saved bernoulli intercept",
1193            abs_tol,
1194            64,
1195            48,
1196        )?;
1197
1198        if f_best.abs() > abs_tol {
1199            return Err(EstimationError::InvalidInput(format!(
1200                "saved bernoulli marginal-slope intercept solve failed: residual={f_best:.3e} at a={root:.6}, target mu={target:.6}"
1201            )));
1202        }
1203        Ok(root)
1204    }
1205
1206    pub fn final_eta_and_gradient_from_theta(
1207        &self,
1208        input: &PredictInput,
1209        theta: &Array1<f64>,
1210        need_gradient: bool,
1211    ) -> Result<(Array1<f64>, Option<Array2<f64>>), EstimationError> {
1212        let z_raw = input.auxiliary_scalar.as_ref().ok_or_else(|| {
1213            EstimationError::InvalidInput(format!(
1214                "bernoulli marginal-slope prediction requires auxiliary z column '{}'",
1215                self.z_column
1216            ))
1217        })?;
1218        let z_normalized = self
1219            .latent_z_normalization
1220            .apply(z_raw, "bernoulli marginal-slope prediction")
1221            .map_err(EstimationError::from)?;
1222        // P4: when training applied a rank-INT calibration to the latent
1223        // z (so the BMS rigid kernel could use the closed-form
1224        // standard-normal path), the predictor MUST apply the same
1225        // monotone transform to predict-time z before any kernel
1226        // evaluation. The transform is mathematically exact: piecewise-
1227        // linear interpolation on (sorted_z, weighted_cdf) followed by
1228        // Φ⁻¹, both strictly monotone and invertible up to the empirical
1229        // CDF resolution. `None` ⇒ training-time z passed the strict
1230        // normality check, no transform was applied, leave z unchanged.
1231        let z = self.apply_latent_z_calibration(&z_normalized);
1232        // #905: replace z by ζ = (z − m(C))/√v(C) when training engaged the
1233        // conditional Auto gate (no-op otherwise; mutually exclusive with the
1234        // rank-INT calibration above).
1235        let z = self.apply_latent_z_conditional_calibration(&z, input)?;
1236        let design_logslope = input.design_noise.as_ref().ok_or_else(|| {
1237            EstimationError::InvalidInput(
1238                "bernoulli marginal-slope prediction requires logslope design".to_string(),
1239            )
1240        })?;
1241        let (beta_marginal, beta_logslope, beta_score_warp, beta_link_dev) =
1242            self.split_theta(theta)?;
1243        if self.score_warp_runtime.is_some() != beta_score_warp.is_some() {
1244            return Err(EstimationError::InvalidInput(
1245                "bernoulli marginal-slope saved score-warp runtime/coefficients are inconsistent"
1246                    .to_string(),
1247            ));
1248        }
1249        if self.link_deviation_runtime.is_some() != beta_link_dev.is_some() {
1250            return Err(EstimationError::InvalidInput(
1251                "bernoulli marginal-slope saved link-deviation runtime/coefficients are inconsistent"
1252                    .to_string(),
1253            ));
1254        }
1255        let n = z.len();
1256        if input.offset.len() != n {
1257            return Err(EstimationError::InvalidInput(format!(
1258                "bernoulli marginal-slope prediction primary offset length mismatch: rows={n}, offset={}",
1259                input.offset.len()
1260            )));
1261        }
1262        let logslope_offset = input
1263            .offset_noise
1264            .as_ref()
1265            .map_or_else(|| Array1::zeros(n), Clone::clone);
1266        if logslope_offset.len() != n {
1267            return Err(EstimationError::InvalidInput(format!(
1268                "bernoulli marginal-slope prediction logslope offset length mismatch: rows={n}, offset_noise={}",
1269                logslope_offset.len()
1270            )));
1271        }
1272        let marginal_eta = input
1273            .design
1274            .dot(&beta_marginal.to_owned())
1275            .mapv(|v| v + self.baseline_marginal)
1276            + &input.offset;
1277        let logslope_eta = design_logslope
1278            .dot(&beta_logslope.to_owned())
1279            .mapv(|v| v + self.baseline_logslope)
1280            + &logslope_offset;
1281        let flex_active =
1282            self.score_warp_runtime.is_some() || self.link_deviation_runtime.is_some();
1283        let marginal_dim = self.beta_marginal.len();
1284        let logslope_dim = self.beta_logslope.len();
1285        let score_warp_dim = self.beta_score_warp.as_ref().map_or(0, Array1::len);
1286        let link_dev_dim = self.beta_link_dev.as_ref().map_or(0, Array1::len);
1287        let logslope_offset = marginal_dim;
1288        let score_warp_offset = logslope_offset + logslope_dim;
1289        let link_dev_offset = score_warp_offset + score_warp_dim;
1290        let chunk_size = prediction_chunk_rows(theta.len(), 1, n);
1291        let num_chunks = n.div_ceil(chunk_size);
1292        let scale = self.probit_frailty_scale();
1293        // Cross-block anchor corrections: when either runtime carries an
1294        // anchor residual, precompute the per-row correction matrices
1295        // (n × runtime_basis_dim) once. Each subsequent per-row evaluation
1296        // subtracts the corresponding row of these matrices from the raw
1297        // cubic-span basis output. When neither runtime has a residual,
1298        // the returned bundle is empty and threading is a no-op.
1299        let anchor_corrections =
1300            self.build_anchor_correction_matrices(input, design_logslope, &z)?;
1301        let marginal_map = marginal_eta
1302            .iter()
1303            .map(|&eta| {
1304                bernoulli_marginal_link_map(&self.base_link, eta)
1305                    .map_err(EstimationError::InvalidInput)
1306            })
1307            .collect::<Result<Vec<_>, _>>()?;
1308
1309        if !flex_active {
1310            let (final_eta_internal, marginal_scales, logslope_scales) = match &self.latent_measure
1311            {
1312                LatentMeasureKind::StandardNormal => {
1313                    let sb_vec = logslope_eta.mapv(|b| scale * b);
1314                    let c_vec = sb_vec.mapv(|sb| (1.0 + sb * sb).sqrt());
1315                    let final_eta_internal = Array1::from_iter(
1316                        (0..n).map(|i| c_vec[i] * marginal_eta[i] + sb_vec[i] * z[i]),
1317                    );
1318                    let marginal_scales = c_vec;
1319                    let logslope_scales = Array1::from_iter((0..n).map(|i| {
1320                        marginal_eta[i] * (scale * scale) * logslope_eta[i] / marginal_scales[i]
1321                            + scale * z[i]
1322                    }));
1323                    (final_eta_internal, marginal_scales, logslope_scales)
1324                }
1325                LatentMeasureKind::GlobalEmpirical { grid } => {
1326                    let mut final_eta = Array1::<f64>::zeros(n);
1327                    let mut marginal_scales = Array1::<f64>::zeros(n);
1328                    let mut logslope_scales = Array1::<f64>::zeros(n);
1329                    for i in 0..n {
1330                        let (intercept, a_marginal, a_slope) = self
1331                            .empirical_rigid_intercept_and_gradient(
1332                                marginal_eta[i],
1333                                logslope_eta[i],
1334                                &grid.nodes,
1335                                &grid.weights,
1336                            )?;
1337                        final_eta[i] = intercept + scale * logslope_eta[i] * z[i];
1338                        marginal_scales[i] = a_marginal;
1339                        logslope_scales[i] = a_slope + scale * z[i];
1340                    }
1341                    (final_eta, marginal_scales, logslope_scales)
1342                }
1343                LatentMeasureKind::LocalEmpirical { .. } => {
1344                    let mut final_eta = Array1::<f64>::zeros(n);
1345                    let mut marginal_scales = Array1::<f64>::zeros(n);
1346                    let mut logslope_scales = Array1::<f64>::zeros(n);
1347                    for i in 0..n {
1348                        let grid = self
1349                            .empirical_grid_for_prediction_row(input, i)?
1350                            .ok_or_else(|| {
1351                                EstimationError::InvalidInput(
1352                                    "local empirical latent prediction did not produce a row grid"
1353                                        .to_string(),
1354                                )
1355                            })?;
1356                        let (intercept, a_marginal, a_slope) = self
1357                            .empirical_rigid_intercept_and_gradient(
1358                                marginal_eta[i],
1359                                logslope_eta[i],
1360                                &grid.nodes,
1361                                &grid.weights,
1362                            )?;
1363                        final_eta[i] = intercept + scale * logslope_eta[i] * z[i];
1364                        marginal_scales[i] = a_marginal;
1365                        logslope_scales[i] = a_slope + scale * z[i];
1366                    }
1367                    (final_eta, marginal_scales, logslope_scales)
1368                }
1369            };
1370
1371            if !need_gradient {
1372                return self.transform_internal_eta_to_base_scale(final_eta_internal, None);
1373            }
1374
1375            // Chunk Jacobian: one pass per row fills both blocks.
1376            let mut grad_internal = Array2::<f64>::zeros((n, theta.len()));
1377            let mut start = 0usize;
1378            while start < n {
1379                let end = (start + chunk_size).min(n);
1380                let mc = input
1381                    .design
1382                    .try_row_chunk(start..end)
1383                    .map_err(|e| EstimationError::InvalidInput(e.to_string()))?;
1384                let lc = design_logslope
1385                    .try_row_chunk(start..end)
1386                    .map_err(|e| EstimationError::InvalidInput(e.to_string()))?;
1387
1388                for li in 0..(end - start) {
1389                    let i = start + li;
1390                    let c = marginal_scales[i];
1391                    let g_scale = logslope_scales[i];
1392                    let mut row = grad_internal.row_mut(i);
1393                    for j in 0..marginal_dim {
1394                        row[j] = c * mc[[li, j]];
1395                    }
1396                    for j in 0..logslope_dim {
1397                        row[logslope_offset + j] = g_scale * lc[[li, j]];
1398                    }
1399                }
1400
1401                start = end;
1402            }
1403            return self
1404                .transform_internal_eta_to_base_scale(final_eta_internal, Some(grad_internal));
1405        }
1406
1407        // ── Flexible path: per-row intercept solve, chunked Jacobians ──
1408        let score_warp_obs_design = self
1409            .score_warp_runtime
1410            .as_ref()
1411            .map(|runtime| {
1412                if runtime.anchor_correction.is_some() {
1413                    let anchor_rows = anchor_corrections
1414                        .score_warp_anchor_rows_view()
1415                        .ok_or_else(|| {
1416                            EstimationError::InvalidInput(
1417                                "bernoulli marginal-slope score-warp anchor residual present but \
1418                                 anchor_corrections bundle is missing the parametric anchor rows"
1419                                    .to_string(),
1420                            )
1421                        })?;
1422                    runtime
1423                        .design_with_anchor_rows(&z, anchor_rows)
1424                        .map_err(EstimationError::from)
1425                } else {
1426                    runtime.design(&z).map_err(EstimationError::from)
1427                }
1428            })
1429            .transpose()?;
1430        let score_dev_obs =
1431            if let (Some(design), Some(beta)) = (score_warp_obs_design.as_ref(), beta_score_warp) {
1432                design.dot(&beta.to_owned())
1433            } else {
1434                Array1::zeros(n)
1435            };
1436
1437        // Solve intercepts and (when gradient needed) IFT scalars in chunk-parallel passes.
1438        // Outputs are preallocated and each parallel worker writes directly into
1439        // its exclusive `axis_chunks_iter_mut` slice; no per-chunk owned buffer
1440        // and no serial copy pass over the result chunks.
1441        let score_warp_beta_owned = beta_score_warp.as_ref().map(|v| v.to_owned());
1442        let link_dev_beta_owned = beta_link_dev.as_ref().map(|v| v.to_owned());
1443        let mut intercepts = Array1::<f64>::zeros(n);
1444        let mut a_q_vec = need_gradient.then(|| Array1::<f64>::zeros(n));
1445        let mut a_b_vec = need_gradient.then(|| Array1::<f64>::zeros(n));
1446        let mut a_h_rows = if need_gradient && score_warp_dim > 0 {
1447            Some(Array2::<f64>::zeros((n, score_warp_dim)))
1448        } else {
1449            None
1450        };
1451        let mut a_w_rows = if need_gradient && link_dev_dim > 0 {
1452            Some(Array2::<f64>::zeros((n, link_dev_dim)))
1453        } else {
1454            None
1455        };
1456        let solve_result: Result<(), EstimationError> = {
1457            use ndarray::Axis;
1458            use rayon::iter::IndexedParallelIterator;
1459            let intercepts_chunks: Vec<ndarray::ArrayViewMut1<f64>> = intercepts
1460                .axis_chunks_iter_mut(Axis(0), chunk_size)
1461                .collect();
1462            let a_q_chunks: Option<Vec<ndarray::ArrayViewMut1<f64>>> = a_q_vec
1463                .as_mut()
1464                .map(|a| a.axis_chunks_iter_mut(Axis(0), chunk_size).collect());
1465            let a_b_chunks: Option<Vec<ndarray::ArrayViewMut1<f64>>> = a_b_vec
1466                .as_mut()
1467                .map(|a| a.axis_chunks_iter_mut(Axis(0), chunk_size).collect());
1468            let a_h_chunks: Option<Vec<ndarray::ArrayViewMut2<f64>>> = a_h_rows
1469                .as_mut()
1470                .map(|a| a.axis_chunks_iter_mut(Axis(0), chunk_size).collect());
1471            let a_w_chunks: Option<Vec<ndarray::ArrayViewMut2<f64>>> = a_w_rows
1472                .as_mut()
1473                .map(|a| a.axis_chunks_iter_mut(Axis(0), chunk_size).collect());
1474
1475            // Bundle per-chunk sinks so each parallel worker owns disjoint mutable
1476            // views into the shared output arrays.
1477            struct FlexSolveSink<'a> {
1478                intercepts: ndarray::ArrayViewMut1<'a, f64>,
1479                a_q: Option<ndarray::ArrayViewMut1<'a, f64>>,
1480                a_b: Option<ndarray::ArrayViewMut1<'a, f64>>,
1481                a_h: Option<ndarray::ArrayViewMut2<'a, f64>>,
1482                a_w: Option<ndarray::ArrayViewMut2<'a, f64>>,
1483            }
1484            let mut sinks: Vec<FlexSolveSink<'_>> = Vec::with_capacity(num_chunks);
1485            // Move each Option<Vec> into iterators so we can zip them.
1486            let mut intercepts_iter = intercepts_chunks.into_iter();
1487            let mut a_q_iter = a_q_chunks.map(|v| v.into_iter());
1488            let mut a_b_iter = a_b_chunks.map(|v| v.into_iter());
1489            let mut a_h_iter = a_h_chunks.map(|v| v.into_iter());
1490            let mut a_w_iter = a_w_chunks.map(|v| v.into_iter());
1491            for _ in 0..num_chunks {
1492                sinks.push(FlexSolveSink {
1493                    intercepts: intercepts_iter.next().expect("chunk count matches"),
1494                    a_q: a_q_iter
1495                        .as_mut()
1496                        .map(|it| it.next().expect("chunk count matches")),
1497                    a_b: a_b_iter
1498                        .as_mut()
1499                        .map(|it| it.next().expect("chunk count matches")),
1500                    a_h: a_h_iter
1501                        .as_mut()
1502                        .map(|it| it.next().expect("chunk count matches")),
1503                    a_w: a_w_iter
1504                        .as_mut()
1505                        .map(|it| it.next().expect("chunk count matches")),
1506                });
1507            }
1508
1509            // Precompute the score-warp basis cubic table once when the latent
1510            // grid is row-constant (`GlobalEmpirical`). The per-row inner loop
1511            // calls `basis_cubic_at(j, node)` with `node` taken from the grid,
1512            // which is identical for every row in this code path, so the
1513            // n_rows × n_nodes × score_warp_dim table can be hoisted out of
1514            // the parallel chunk dispatch. Per-row work only touches the
1515            // basis-function-specific `c0` shift via `score_corr_row`, which
1516            // stays inside the row loop. Computed at the top level so no
1517            // OnceLock / lazy init lives inside the par closure (per the
1518            // OnceLock + nested rayon deadlock rule).
1519            let global_score_basis_table: Option<
1520                Vec<Vec<crate::cubic_cell_kernel::LocalSpanCubic>>,
1521            > = if let (LatentMeasureKind::GlobalEmpirical { grid }, Some(runtime)) =
1522                (&self.latent_measure, self.score_warp_runtime.as_ref())
1523            {
1524                let mut table = Vec::with_capacity(score_warp_dim);
1525                for j in 0..score_warp_dim {
1526                    let mut row = Vec::with_capacity(grid.nodes.len());
1527                    for &node in &grid.nodes {
1528                        row.push(
1529                            runtime
1530                                .basis_cubic_at(j, node)
1531                                .map_err(EstimationError::from)?,
1532                        );
1533                    }
1534                    table.push(row);
1535                }
1536                Some(table)
1537            } else {
1538                None
1539            };
1540            let global_score_basis_table = global_score_basis_table.as_ref();
1541
1542            sinks
1543                .into_par_iter()
1544                .enumerate()
1545                .try_for_each(|(chunk_idx, mut sink)| -> Result<(), EstimationError> {
1546                let start = chunk_idx * chunk_size;
1547                let end = (start + chunk_size).min(n);
1548                let rows = end - start;
1549                // Destructure the sink into independent `&mut` references so we
1550                // can borrow them disjointly across iterations of the inner row
1551                // loop without further reborrowing through `Option::as_mut`.
1552                let intercepts_view = &mut sink.intercepts;
1553                let mut a_q = sink.a_q.as_mut();
1554                let mut a_b = sink.a_b.as_mut();
1555                let mut a_h = sink.a_h.as_mut();
1556                let mut a_w = sink.a_w.as_mut();
1557                let mut warm_start_buf = Array1::<f64>::zeros(1);
1558                let mut f_h_row = vec![0.0; score_warp_dim];
1559                let mut f_w_row = vec![0.0; link_dev_dim];
1560
1561                for local_row in 0..rows {
1562                    let i = start + local_row;
1563                    let slope = logslope_eta[i];
1564                    let q = marginal_eta[i];
1565                    let empirical_grid = self.empirical_grid_for_prediction_row(input, i)?;
1566                    let score_corr_row = anchor_corrections.score_warp_row(i);
1567                    let link_corr_row = anchor_corrections.link_dev_row(i);
1568                    intercepts_view[local_row] = self.solve_intercept_scalar(
1569                        q,
1570                        slope,
1571                        link_dev_beta_owned.as_ref(),
1572                        score_warp_beta_owned.as_ref(),
1573                        empirical_grid.as_ref(),
1574                        &mut warm_start_buf,
1575                        score_corr_row,
1576                        link_corr_row,
1577                    )?;
1578
1579                    if !need_gradient {
1580                        continue;
1581                    }
1582
1583                    let intercept = intercepts_view[local_row];
1584                    let (_, m_a_raw, _) = self.evaluate_prediction_calibration(
1585                        intercept,
1586                        q,
1587                        slope,
1588                        score_warp_beta_owned.as_ref(),
1589                        link_dev_beta_owned.as_ref(),
1590                        empirical_grid.as_ref(),
1591                        score_corr_row,
1592                        link_corr_row,
1593                    )?;
1594                    let m_a = m_a_raw.max(1e-12);
1595                    a_q.as_mut().expect("a_q allocated when need_gradient")[local_row] =
1596                        marginal_map[i].mu1 / m_a;
1597                    let mut f_b = 0.0;
1598                    f_h_row.fill(0.0);
1599                    f_w_row.fill(0.0);
1600                    if let Some(grid) = empirical_grid.as_ref() {
1601                        for (node_idx, (node, weight)) in grid.pairs().enumerate() {
1602                            let obs = self.observed_denested_cell_partials_at_z(
1603                                node,
1604                                intercept,
1605                                slope,
1606                                score_warp_beta_owned.as_ref(),
1607                                link_dev_beta_owned.as_ref(),
1608                                score_corr_row,
1609                                link_corr_row,
1610                            )?;
1611                            let eta = eval_coeff4_at(&obs.coeff, node);
1612                            let pdf = normal_pdf(eta);
1613                            f_b += weight * pdf * eval_coeff4_at(&obs.dc_db, node);
1614
1615                            if let Some(runtime) = self.score_warp_runtime.as_ref() {
1616                                for j in 0..score_warp_dim {
1617                                    // When the latent grid is row-constant
1618                                    // (`GlobalEmpirical`), the per-(j, node)
1619                                    // basis cubic is identical for every row
1620                                    // and lives in `global_score_basis_table`.
1621                                    // Otherwise (`LocalEmpirical`) the grid
1622                                    // varies per row and we fall back to a
1623                                    // direct `basis_cubic_at` call.
1624                                    let mut basis_span = if let Some(table) =
1625                                        global_score_basis_table
1626                                    {
1627                                        table[j][node_idx]
1628                                    } else {
1629                                        runtime
1630                                            .basis_cubic_at(j, node)
1631                                            .map_err(EstimationError::from)?
1632                                    };
1633                                    // `basis_cubic_at` returns the j-th basis
1634                                    // function's local cubic; the residual
1635                                    // subtracts `correction[j]` from the
1636                                    // constant term (row-constant, basis-
1637                                    // function-specific). Higher span
1638                                    // coefficients are unaffected.
1639                                    if let Some(corr) = score_corr_row {
1640                                        basis_span.c0 -= corr[j];
1641                                    }
1642                                    let coeffs = crate::cubic_cell_kernel::score_basis_cell_coefficients(
1643                                        basis_span,
1644                                        slope,
1645                                    );
1646                                    let coeffs = scale_coeff4(coeffs, scale);
1647                                    f_h_row[j] += weight * pdf * eval_coeff4_at(&coeffs, node);
1648                                }
1649                            }
1650
1651                            if let Some(runtime) = self.link_deviation_runtime.as_ref() {
1652                                for j in 0..link_dev_dim {
1653                                    let mut basis_span = runtime
1654                                        .basis_cubic_at(j, intercept + slope * node)
1655                                        .map_err(EstimationError::from)?;
1656                                    if let Some(corr) = link_corr_row {
1657                                        basis_span.c0 -= corr[j];
1658                                    }
1659                                    let coeffs = crate::cubic_cell_kernel::link_basis_cell_coefficients(
1660                                        basis_span,
1661                                        intercept,
1662                                        slope,
1663                                    );
1664                                    let coeffs = scale_coeff4(coeffs, scale);
1665                                    f_w_row[j] += weight * pdf * eval_coeff4_at(&coeffs, node);
1666                                }
1667                            }
1668                        }
1669                    } else {
1670                        let cells = self.denested_partition_cells(
1671                            intercept,
1672                            slope,
1673                            score_warp_beta_owned.as_ref(),
1674                            link_dev_beta_owned.as_ref(),
1675                            score_corr_row,
1676                            link_corr_row,
1677                        )?;
1678                        for partition_cell in cells {
1679                            let cell = partition_cell.cell;
1680                            let state =
1681                                crate::cubic_cell_kernel::evaluate_cell_moments(
1682                                    cell, 9,
1683                                )
1684                                .map_err(EstimationError::InvalidInput)?;
1685                            let (_, dc_db_raw) = crate::cubic_cell_kernel::denested_cell_coefficient_partials(
1686                                partition_cell.score_span,
1687                                partition_cell.link_span,
1688                                intercept,
1689                                slope,
1690                            );
1691                            // `denested_partition_cells` scales the cell itself for
1692                            // Gaussian frailty, so every coefficient partial of
1693                            // F(a, theta) must carry the same probit scale as F_a.
1694                            let dc_db = scale_coeff4(dc_db_raw, scale);
1695                            f_b += crate::cubic_cell_kernel::cell_first_derivative_from_moments(
1696                                &dc_db,
1697                                &state.moments,
1698                            )
1699                            .map_err(EstimationError::InvalidInput)?;
1700
1701                            let mid = 0.5 * (cell.left + cell.right);
1702                            if let Some(runtime) = self.score_warp_runtime.as_ref() {
1703                                for j in 0..score_warp_dim {
1704                                    let mut basis_span = runtime
1705                                        .basis_cubic_at(j, mid)
1706                                        .map_err(EstimationError::from)?;
1707                                    if let Some(corr) = score_corr_row {
1708                                        basis_span.c0 -= corr[j];
1709                                    }
1710                                    let coeffs = crate::cubic_cell_kernel::score_basis_cell_coefficients(
1711                                        basis_span, slope,
1712                                    );
1713                                    let coeffs = scale_coeff4(coeffs, scale);
1714                                    f_h_row[j] += crate::cubic_cell_kernel::cell_first_derivative_from_moments(
1715                                        &coeffs,
1716                                        &state.moments,
1717                                    )
1718                                    .map_err(EstimationError::InvalidInput)?;
1719                                }
1720                            }
1721
1722                            if let Some(runtime) = self.link_deviation_runtime.as_ref() {
1723                                for j in 0..link_dev_dim {
1724                                    let mut basis_span = runtime
1725                                        .basis_cubic_at(j, intercept + slope * mid)
1726                                        .map_err(EstimationError::from)?;
1727                                    if let Some(corr) = link_corr_row {
1728                                        basis_span.c0 -= corr[j];
1729                                    }
1730                                    let coeffs = crate::cubic_cell_kernel::link_basis_cell_coefficients(
1731                                        basis_span,
1732                                        intercept,
1733                                        slope,
1734                                    );
1735                                    let coeffs = scale_coeff4(coeffs, scale);
1736                                    f_w_row[j] += crate::cubic_cell_kernel::cell_first_derivative_from_moments(
1737                                        &coeffs,
1738                                        &state.moments,
1739                                    )
1740                                    .map_err(EstimationError::InvalidInput)?;
1741                                }
1742                            }
1743                        }
1744                    }
1745                    if let Some(a_h_view) = a_h.as_mut() {
1746                        let factor = -1.0 / m_a;
1747                        for j in 0..score_warp_dim {
1748                            a_h_view[[local_row, j]] = factor * f_h_row[j];
1749                        }
1750                    }
1751                    if let Some(a_w_view) = a_w.as_mut() {
1752                        let factor = -1.0 / m_a;
1753                        for j in 0..link_dev_dim {
1754                            a_w_view[[local_row, j]] = factor * f_w_row[j];
1755                        }
1756                    }
1757                    a_b.as_mut().expect("a_b allocated when need_gradient")[local_row] =
1758                        -f_b / m_a;
1759                }
1760                Ok(())
1761            })
1762        };
1763        solve_result?;
1764
1765        let eta_base = &intercepts + &(&logslope_eta * &z);
1766
1767        let mut link_c_obs: Option<Array1<f64>> = None;
1768        let mut link_basis_obs: Option<Array2<f64>> = None;
1769        let link_dev_obs = if let (Some(runtime), Some(beta_owned)) = (
1770            self.link_deviation_runtime.as_ref(),
1771            link_dev_beta_owned.as_ref(),
1772        ) {
1773            let basis = if runtime.anchor_correction.is_some() {
1774                let anchor_rows =
1775                    anchor_corrections
1776                        .link_dev_anchor_rows_view()
1777                        .ok_or_else(|| {
1778                            EstimationError::InvalidInput(
1779                            "bernoulli marginal-slope link-deviation anchor residual present but \
1780                             anchor_corrections bundle is missing the parametric anchor rows"
1781                                .to_string(),
1782                        )
1783                        })?;
1784                runtime
1785                    .design_with_anchor_rows(&eta_base, anchor_rows)
1786                    .map_err(EstimationError::from)?
1787            } else {
1788                runtime.design(&eta_base).map_err(EstimationError::from)?
1789            };
1790            let dev = basis.dot(beta_owned);
1791            if need_gradient {
1792                let d1 = runtime
1793                    .first_derivative_design(&eta_base)
1794                    .map_err(EstimationError::from)?;
1795                let mut c_obs = d1.dot(beta_owned);
1796                c_obs.mapv_inplace(|v| v + 1.0);
1797                link_c_obs = Some(c_obs);
1798                link_basis_obs = Some(basis);
1799            }
1800            dev
1801        } else {
1802            Array1::zeros(n)
1803        };
1804        let final_eta_internal =
1805            (&eta_base + &(&logslope_eta * &score_dev_obs) + &link_dev_obs).mapv(|v| scale * v);
1806
1807        if !need_gradient {
1808            return self.transform_internal_eta_to_base_scale(final_eta_internal, None);
1809        }
1810
1811        let a_q_vec = a_q_vec.unwrap();
1812        let a_b_vec = a_b_vec.unwrap();
1813
1814        // Emit chunk Jacobians using precomputed scalars; each worker writes
1815        // directly into its exclusive `axis_chunks_iter_mut` slice of the
1816        // preallocated `grad` output so no serial copy pass is needed.
1817        let mut grad = Array2::<f64>::zeros((n, theta.len()));
1818        {
1819            use ndarray::Axis;
1820            use rayon::iter::IndexedParallelIterator;
1821            let grad_result: Result<(), String> = grad
1822                .axis_chunks_iter_mut(Axis(0), chunk_size)
1823                .into_par_iter()
1824                .enumerate()
1825                .try_for_each(|(chunk_idx, mut grad_chunk)| -> Result<(), String> {
1826                    let start = chunk_idx * chunk_size;
1827                    let end = (start + chunk_size).min(n);
1828                    let mc = input
1829                        .design
1830                        .try_row_chunk(start..end)
1831                        .map_err(|e| e.to_string())?;
1832                    let lc = design_logslope
1833                        .try_row_chunk(start..end)
1834                        .map_err(|e| e.to_string())?;
1835                    let rows = end - start;
1836
1837                    for li in 0..rows {
1838                        let i = start + li;
1839                        let mut row = grad_chunk.row_mut(li);
1840
1841                        let a_q = a_q_vec[i];
1842                        for j in 0..marginal_dim {
1843                            row[j] = a_q * mc[[li, j]];
1844                        }
1845
1846                        let base_multiplier = link_c_obs.as_ref().map_or(1.0, |c| c[i]);
1847                        let g_scale = base_multiplier * (a_b_vec[i] + z[i]) + score_dev_obs[i];
1848                        for j in 0..logslope_dim {
1849                            row[logslope_offset + j] = g_scale * lc[[li, j]];
1850                        }
1851
1852                        if let (Some(a_h_rows), Some(obs_design)) =
1853                            (a_h_rows.as_ref(), score_warp_obs_design.as_ref())
1854                        {
1855                            let slope = logslope_eta[i];
1856                            for j in 0..score_warp_dim {
1857                                row[score_warp_offset + j] =
1858                                    base_multiplier * a_h_rows[[i, j]] + slope * obs_design[[i, j]];
1859                            }
1860                        }
1861
1862                        if let Some(a_w_rows) = a_w_rows.as_ref() {
1863                            for j in 0..link_dev_dim {
1864                                row[link_dev_offset + j] = a_w_rows[[i, j]];
1865                            }
1866                        }
1867
1868                        if let (Some(link_c), Some(link_basis)) =
1869                            (link_c_obs.as_ref(), link_basis_obs.as_ref())
1870                        {
1871                            let c = link_c[i];
1872                            for j in 0..marginal_dim {
1873                                row[j] *= c;
1874                            }
1875                            for j in 0..link_dev_dim {
1876                                row[link_dev_offset + j] =
1877                                    c * row[link_dev_offset + j] + link_basis[[i, j]];
1878                            }
1879                        }
1880                    }
1881                    Ok(())
1882                });
1883            grad_result.map_err(EstimationError::InvalidInput)?;
1884        }
1885        if scale != 1.0 {
1886            grad.mapv_inplace(|v| scale * v);
1887        }
1888        self.transform_internal_eta_to_base_scale(final_eta_internal, Some(grad))
1889    }
1890
1891    /// Per-row final (base-scale) linear predictor for an arbitrary
1892    /// coefficient vector `theta` in the saved `[marginal | logslope |
1893    /// score_warp? | link_dev?]` block order. The marginal-slope rigid
1894    /// kernel is applied exactly per row, so the returned η is the same
1895    /// object the point predictor consumes — only parameterised by an
1896    /// external draw instead of `self.theta()`. Used by the posterior
1897    /// predictive path (#1049) to map each Laplace draw to its η surface
1898    /// before the shared eta→bands collapse; the response scale is the
1899    /// probit inverse link `μ = Φ(η)`.
1900    pub fn final_eta_from_theta(
1901        &self,
1902        input: &PredictInput,
1903        theta: &Array1<f64>,
1904    ) -> Result<Array1<f64>, EstimationError> {
1905        let (eta, _) = self.final_eta_and_gradient_from_theta(input, theta, false)?;
1906        Ok(eta)
1907    }
1908
1909    /// Length of the concatenated coefficient vector this predictor
1910    /// consumes (`marginal + logslope + score_warp? + link_dev?`). The
1911    /// posterior predictive path validates each saved draw against this
1912    /// before mapping it through [`Self::final_eta_from_theta`].
1913    pub fn theta_len(&self) -> usize {
1914        self.beta_marginal.len()
1915            + self.beta_logslope.len()
1916            + self.beta_score_warp.as_ref().map_or(0, Array1::len)
1917            + self.beta_link_dev.as_ref().map_or(0, Array1::len)
1918    }
1919
1920    /// Per-row `(eta, ∂eta/∂q_marginal)` under the exact IFT pull-back.
1921    ///
1922    /// Returns the same `eta` as `predict_plugin_response`/`predict_linear_predictor`
1923    /// plus the analytic derivative of the internal probit index with respect to
1924    /// the per-row marginal q (the linear predictor before the de-nested
1925    /// calibration). Survival prediction multiplies the second component by the
1926    /// per-row `dq/dt` to obtain the exact hazard time derivative under
1927    /// score-warp / link-deviation flex blocks.
1928    ///
1929    /// Rigid path (no flex blocks): `∂eta/∂q = c = sqrt(1 + (s b)^2)`, recovering
1930    /// the rigid-path probit-frailty composition. Flex path: `∂eta/∂q =
1931    /// scale · link_c_obs · a_q` where `link_c_obs = 1 + Δ_w'(eta_base)` is the
1932    /// link-deviation slope at the observed `eta_base = a + b z` and `a_q =
1933    /// φ(q) / |F_a|` is the implicit-function derivative of the calibration
1934    /// intercept (mirrors the bernoulli `final_eta_and_gradient_from_theta`
1935    /// flex branch lines 1399-1593).
1936    pub fn predict_eta_and_q_chain(
1937        &self,
1938        input: &PredictInput,
1939    ) -> Result<(Array1<f64>, Array1<f64>), EstimationError> {
1940        let z_raw = input.auxiliary_scalar.as_ref().ok_or_else(|| {
1941            EstimationError::InvalidInput(format!(
1942                "bernoulli marginal-slope prediction requires auxiliary z column '{}'",
1943                self.z_column
1944            ))
1945        })?;
1946        let z_normalized = self
1947            .latent_z_normalization
1948            .apply(z_raw, "bernoulli marginal-slope prediction")
1949            .map_err(EstimationError::from)?;
1950        // P4: see `final_eta_and_gradient_from_theta` for the rationale.
1951        // The rank-INT calibration is a mathematically exact monotone
1952        // transform; both the rigid standard-normal kernel and the
1953        // implicit-function chain rule consume the calibrated z, never
1954        // the raw normalized z, exactly mirroring fit-time semantics.
1955        let z = self.apply_latent_z_calibration(&z_normalized);
1956        // #905: replace z by ζ = (z − m(C))/√v(C) when training engaged the
1957        // conditional Auto gate (no-op otherwise; mutually exclusive with the
1958        // rank-INT calibration above).
1959        let z = self.apply_latent_z_conditional_calibration(&z, input)?;
1960        let design_logslope = input.design_noise.as_ref().ok_or_else(|| {
1961            EstimationError::InvalidInput(
1962                "bernoulli marginal-slope prediction requires logslope design".to_string(),
1963            )
1964        })?;
1965        let n = z.len();
1966        if input.offset.len() != n {
1967            return Err(EstimationError::InvalidInput(format!(
1968                "bernoulli marginal-slope prediction primary offset length mismatch: rows={n}, offset={}",
1969                input.offset.len()
1970            )));
1971        }
1972        let logslope_offset = input
1973            .offset_noise
1974            .as_ref()
1975            .map_or_else(|| Array1::zeros(n), Clone::clone);
1976        if logslope_offset.len() != n {
1977            return Err(EstimationError::InvalidInput(format!(
1978                "bernoulli marginal-slope prediction logslope offset length mismatch: rows={n}, offset_noise={}",
1979                logslope_offset.len()
1980            )));
1981        }
1982        let marginal_eta = input
1983            .design
1984            .dot(&self.beta_marginal)
1985            .mapv(|v| v + self.baseline_marginal)
1986            + &input.offset;
1987        let logslope_eta = design_logslope
1988            .dot(&self.beta_logslope)
1989            .mapv(|v| v + self.baseline_logslope)
1990            + &logslope_offset;
1991        let scale = self.probit_frailty_scale();
1992        let flex_active =
1993            self.score_warp_runtime.is_some() || self.link_deviation_runtime.is_some();
1994
1995        // Rigid path mirrors `final_eta_and_gradient_from_theta` lines 1342-1383:
1996        //   eta = c·q + s·b·z,  ∂eta/∂q = c.
1997        if !flex_active {
1998            match &self.latent_measure {
1999                LatentMeasureKind::StandardNormal => {
2000                    // Vectorize: sb = scale·logslope, c = sqrt(1 + sb²),
2001                    // eta = c·marginal_eta + sb·z, ∂eta/∂q = c.
2002                    let sb = logslope_eta.mapv(|x| scale * x);
2003                    let deta_dq = sb.mapv(|s| (1.0 + s * s).sqrt());
2004                    let eta = &deta_dq * marginal_eta + &sb * z;
2005                    return Ok((eta, deta_dq));
2006                }
2007                _ => {
2008                    let mut eta = Array1::<f64>::zeros(n);
2009                    let mut deta_dq = Array1::<f64>::zeros(n);
2010                    for i in 0..n {
2011                        let grid = self
2012                            .empirical_grid_for_prediction_row(input, i)?
2013                            .ok_or_else(|| {
2014                                EstimationError::InvalidInput(
2015                                    "empirical latent prediction did not produce a row grid"
2016                                        .to_string(),
2017                                )
2018                            })?;
2019                        let (intercept, a_marginal, _) = self
2020                            .empirical_rigid_intercept_and_gradient(
2021                                marginal_eta[i],
2022                                logslope_eta[i],
2023                                &grid.nodes,
2024                                &grid.weights,
2025                            )?;
2026                        eta[i] = intercept + scale * logslope_eta[i] * z[i];
2027                        deta_dq[i] = a_marginal;
2028                    }
2029                    return Ok((eta, deta_dq));
2030                }
2031            }
2032        }
2033
2034        // Flex path: solve the per-row intercept, then evaluate
2035        //   eta = scale · (a + b·z + b·Δ_h(z) + Δ_w(a + b·z))
2036        //   ∂eta/∂q = scale · (1 + Δ_w'(a + b·z)) · ∂a/∂q,
2037        //   ∂a/∂q   = φ(q) / |F_a|         (IFT, marginal_link is probit so mu1 = φ(q))
2038        // Mirrors `final_eta_and_gradient_from_theta` lines 1385-1621.
2039        let marginal_map = marginal_eta
2040            .iter()
2041            .map(|&eta_marg| {
2042                bernoulli_marginal_link_map(&self.base_link, eta_marg)
2043                    .map_err(EstimationError::InvalidInput)
2044            })
2045            .collect::<Result<Vec<_>, _>>()?;
2046        // Cross-block anchor corrections (see final_eta_and_gradient_from_theta
2047        // for the design); precompute once before the per-row loop.
2048        let anchor_corrections =
2049            self.build_anchor_correction_matrices(input, design_logslope, &z)?;
2050        // Per-row: solve intercept scalar, evaluate denested calibration,
2051        // record (intercept, a_q). The `warm_start_buf` is just per-call
2052        // scratch — give each rayon worker its own buffer via fold init.
2053        use rayon::iter::{IntoParallelIterator, ParallelIterator};
2054        let pairs: Result<Vec<(f64, f64)>, EstimationError> = (0..n)
2055            .into_par_iter()
2056            .map_init(
2057                || Array1::<f64>::zeros(1),
2058                |warm_start_buf, i| {
2059                    let q = marginal_eta[i];
2060                    let slope = logslope_eta[i];
2061                    let empirical_grid = self.empirical_grid_for_prediction_row(input, i)?;
2062                    let score_corr_row = anchor_corrections.score_warp_row(i);
2063                    let link_corr_row = anchor_corrections.link_dev_row(i);
2064                    let intercept = self.solve_intercept_scalar(
2065                        q,
2066                        slope,
2067                        self.beta_link_dev.as_ref(),
2068                        self.beta_score_warp.as_ref(),
2069                        empirical_grid.as_ref(),
2070                        warm_start_buf,
2071                        score_corr_row,
2072                        link_corr_row,
2073                    )?;
2074                    let (_, m_a_raw, _) = self.evaluate_prediction_calibration(
2075                        intercept,
2076                        q,
2077                        slope,
2078                        self.beta_score_warp.as_ref(),
2079                        self.beta_link_dev.as_ref(),
2080                        empirical_grid.as_ref(),
2081                        score_corr_row,
2082                        link_corr_row,
2083                    )?;
2084                    let m_a = m_a_raw.max(1e-12);
2085                    Ok((intercept, marginal_map[i].mu1 / m_a))
2086                },
2087            )
2088            .collect();
2089        let pairs = pairs?;
2090        let mut intercepts = Array1::<f64>::zeros(n);
2091        let mut a_q = Array1::<f64>::zeros(n);
2092        for (i, (intercept, a)) in pairs.into_iter().enumerate() {
2093            intercepts[i] = intercept;
2094            a_q[i] = a;
2095        }
2096
2097        let score_dev_obs = if let (Some(runtime), Some(beta)) = (
2098            self.score_warp_runtime.as_ref(),
2099            self.beta_score_warp.as_ref(),
2100        ) {
2101            let design = if runtime.anchor_correction.is_some() {
2102                let anchor_rows = anchor_corrections
2103                    .score_warp_anchor_rows_view()
2104                    .ok_or_else(|| {
2105                        EstimationError::InvalidInput(
2106                            "bernoulli marginal-slope score-warp anchor residual present but \
2107                             anchor_corrections bundle is missing the parametric anchor rows"
2108                                .to_string(),
2109                        )
2110                    })?;
2111                runtime
2112                    .design_with_anchor_rows(&z, anchor_rows)
2113                    .map_err(EstimationError::from)?
2114            } else {
2115                runtime.design(&z).map_err(EstimationError::from)?
2116            };
2117            design.dot(beta)
2118        } else {
2119            Array1::zeros(n)
2120        };
2121        let eta_base = &intercepts + &(&logslope_eta * &z);
2122        let (link_dev_obs, link_c_obs) = if let (Some(runtime), Some(beta)) = (
2123            self.link_deviation_runtime.as_ref(),
2124            self.beta_link_dev.as_ref(),
2125        ) {
2126            let basis = if runtime.anchor_correction.is_some() {
2127                let anchor_rows =
2128                    anchor_corrections
2129                        .link_dev_anchor_rows_view()
2130                        .ok_or_else(|| {
2131                            EstimationError::InvalidInput(
2132                            "bernoulli marginal-slope link-deviation anchor residual present but \
2133                             anchor_corrections bundle is missing the parametric anchor rows"
2134                                .to_string(),
2135                        )
2136                        })?;
2137                runtime
2138                    .design_with_anchor_rows(&eta_base, anchor_rows)
2139                    .map_err(EstimationError::from)?
2140            } else {
2141                runtime.design(&eta_base).map_err(EstimationError::from)?
2142            };
2143            let dev = basis.dot(beta);
2144            let d1 = runtime
2145                .first_derivative_design(&eta_base)
2146                .map_err(EstimationError::from)?;
2147            let mut c_obs = d1.dot(beta);
2148            c_obs.mapv_inplace(|v| v + 1.0);
2149            (dev, c_obs)
2150        } else {
2151            (Array1::zeros(n), Array1::ones(n))
2152        };
2153        let final_eta_internal =
2154            (&eta_base + &(&logslope_eta * &score_dev_obs) + &link_dev_obs).mapv(|v| scale * v);
2155        let deta_dq = (&link_c_obs * &a_q).mapv(|v| scale * v);
2156        Ok((final_eta_internal, deta_dq))
2157    }
2158}