Skip to main content

gam_models/survival/marginal_slope/
identifiability.rs

1//! Survival marginal-slope concrete impls for the family-agnostic
2//! identifiability compiler (`gam_identifiability::families::compiler`).
3//!
4//! Survival's row primary state is the 4-vector `u_i = (q0, q1, qd1, g)`,
5//! so `K = 4`. The row Hessian is the 4×4 second-derivative block of the
6//! per-row neg-log-likelihood kernel `row_primary_closed_form` at a pilot
7//! `β`, PSD-clamped via eigendecomposition (negative eigenvalues projected
8//! to zero) to handle pilot points far from the optimum.
9//!
10//! Each block exposes its row Jacobian as the contribution of `δβ_block`
11//! to the row primary-state vector:
12//!
13//! - **TimeBlockOperator**: `(δq0, δq1, δqd1, 0)` from `design_entry`,
14//!   `design_exit`, `design_derivative_exit` rows.
15//! - **MarginalBlockOperator**: `(δq, δq, δqd_marginal, 0)` from the
16//!   marginal design row (shared by q0 and q1; qd contribution zero unless
17//!   timewiggle is active — captured by an explicit derivative row matrix).
18//! - **LogslopeBlockOperator**: `(0, 0, 0, δg)` from the logslope design.
19//! - **ScoreWarpBlockOperator**: `(δq, δq, δqd_warp, 0)` from the warp
20//!   basis (shifts q at entry/exit; chain rule via dq0_seed/dt for qd1).
21//! - **LinkDevBlockOperator**: `(δq, δq, δqd_link, 0)` from the link-dev
22//!   basis on the rigid/pilot q-seed.
23//!
24//! Phase 4a delivery: trait impls + an input-builder helper. Phase 4b
25//! threads these through SMGS's construction site and the migrated pilot
26//! β; Phase 4c deletes the legacy
27//! `install_compiled_flex_block_into_runtime` path.
28
29use std::sync::Arc;
30
31use ndarray::{Array1, Array2, Array3};
32
33use gam_identifiability::families::compiler::{
34    BlockOrder, RowHessian, RowJacobianOperator, scale_jacobian_by_sqrt_h_with,
35};
36use gam_problem::gauge::assemble_block_triangular_t;
37use faer::Side;
38use gam_linalg::faer_ndarray::FaerEigh;
39use gam_linalg::matrix::{CoefficientTransformOperator, DenseDesignMatrix, DesignMatrix};
40use gam_problem::{FamilyChannelHessian, PenaltyMatrix};
41
/// Number of primary-state channels per row, `(q0, q1, qd1, g)`; every row
/// Hessian in this file is `K_SURVIVAL × K_SURVIVAL` and every row Jacobian
/// has `K_SURVIVAL` output channels.
const K_SURVIVAL: usize = 4;
43
44/// Threshold below which a coefficient vector is treated as the trivial
45/// (all-zero) pilot point in the drift-detection audit. At β ≈ 0 the
46/// primary-state coupling g vanishes (c ≡ 1), so the frozen pilot W is exact;
47/// any |β_j| above this is "non-trivial" and requires the family scalars to
48/// re-evaluate W(β). The bound is well below any meaningful fitted coefficient.
49const BETA_NONTRIVIAL_ABS_THRESHOLD: f64 = 1e-12;
50
/// Per-row 4×4 row Hessian for the survival marginal-slope likelihood at a
/// pilot `β`. The pilot supplies the primary-state vector
/// `(q0_i, q1_i, qd1_i, g_i)` and the per-row sample weight + event
/// indicator + z + probit scale. The 4×4 block is evaluated via the
/// existing `row_primary_closed_form` kernel (which already returns the
/// full Hessian in `(q0, q1, qd1, g)` order) and PSD-clamped per row.
///
/// NOTE(review): the constructor in this file calls
/// `row_primary_for_compiler`; presumably that is the compiler-facing entry
/// point to the `row_primary_closed_form` kernel — confirm.
pub struct SurvivalRowHessian {
    /// Per-row 4×4 Hessian with shape `(n, 4, 4)`, indexed as
    /// `h[[row, a, b]]`. PSD-projected when built by
    /// `from_pilot_primary_state`; `from_full` stores the caller's tensor
    /// as given and only checks the two trailing dimensions.
    h: Array3<f64>,
}
62
63impl SurvivalRowHessian {
64    /// Construct from explicit per-row pilot primary-state and the row
65    /// data needed by `row_primary_closed_form`. Negative eigenvalues are
66    /// projected to zero before storage so the matrix is PSD.
67    pub fn from_pilot_primary_state(
68        q0: &Array1<f64>,
69        q1: &Array1<f64>,
70        qd1: &Array1<f64>,
71        g: &Array1<f64>,
72        z: &Array1<f64>,
73        weights: &Array1<f64>,
74        event: &Array1<f64>,
75        derivative_guard: f64,
76        probit_scale: f64,
77    ) -> Result<Self, String> {
78        let n = q0.len();
79        if [
80            q1.len(),
81            qd1.len(),
82            g.len(),
83            z.len(),
84            weights.len(),
85            event.len(),
86        ]
87        .iter()
88        .any(|&l| l != n)
89        {
90            return Err(format!(
91                "SurvivalRowHessian: length mismatch \
92                 q0={n}, q1={}, qd1={}, g={}, z={}, weights={}, event={}",
93                q1.len(),
94                qd1.len(),
95                g.len(),
96                z.len(),
97                weights.len(),
98                event.len()
99            ));
100        }
101        let mut h_full = Array3::<f64>::zeros((n, K_SURVIVAL, K_SURVIVAL));
102        for i in 0..n {
103            let (_, _grad, hess) =
104                crate::survival::marginal_slope::row_primary_for_compiler(
105                    q0[i],
106                    q1[i],
107                    qd1[i],
108                    g[i],
109                    z[i],
110                    weights[i],
111                    event[i],
112                    derivative_guard,
113                    probit_scale,
114                )?;
115            // PSD-clamp via eigendecomposition: project negative eigvals to 0.
116            let mut h_i = Array2::<f64>::zeros((K_SURVIVAL, K_SURVIVAL));
117            for a in 0..K_SURVIVAL {
118                for b in 0..K_SURVIVAL {
119                    h_i[[a, b]] = hess[a][b];
120                }
121            }
122            let clamped = psd_clamp_4x4(&h_i);
123            for a in 0..K_SURVIVAL {
124                for b in 0..K_SURVIVAL {
125                    h_full[[i, a, b]] = clamped[[a, b]];
126                }
127            }
128        }
129        Ok(Self { h: h_full })
130    }
131
132    /// Construct from an already-PSD per-row tensor. Used by callers that
133    /// have computed the Hessian via a different route.
134    pub fn from_full(h: Array3<f64>) -> Self {
135        assert_eq!(h.shape()[1], K_SURVIVAL);
136        assert_eq!(h.shape()[2], K_SURVIVAL);
137        Self { h }
138    }
139}
140
141impl RowHessian for SurvivalRowHessian {
142    fn k(&self) -> usize {
143        K_SURVIVAL
144    }
145    fn nrows(&self) -> usize {
146        self.h.shape()[0]
147    }
148    fn fill_row(&self, row: usize, out: &mut [f64]) {
149        assert_eq!(out.len(), K_SURVIVAL * K_SURVIVAL);
150        for a in 0..K_SURVIVAL {
151            for b in 0..K_SURVIVAL {
152                out[a * K_SURVIVAL + b] = self.h[[row, a, b]];
153            }
154        }
155    }
156    fn evaluate_full(&self) -> Array3<f64> {
157        self.h.clone()
158    }
159}
160
161/// `FamilyChannelHessian` for survival marginal-slope.
162///
163/// The 4×4 per-subject W_i is the Hessian of the row negative log-likelihood
164/// `ρ_i(q0, q1, qd1, g) = −δ_i log f(η_1, ad1) − (1−δ_i) log S(η_1) + log S(η_0)`
165/// with respect to the 4-vector primary state `(q0, q1, qd1, g)`.
166///
167/// Derivation of W_i entries (all from `row_primary_closed_form`):
168///
169/// - W[0,0] = u2_η0 · c²  (q0–q0; only η0 depends on q0)
170/// - W[1,1] = (u2_η1 + w·δ) · c²  (q1–q1; η1 and log-φ both depend on q1)
171/// - W[2,2] = w·δ · (∂ad1/∂qd1)² · (−1/ad1²)  (qd1–qd1 via neglog(ad1))
172/// - W[3,3] = u2_η0·(∂η0/∂g)² + u1_η0·(∂²η0/∂g²) + u2_η1·(∂η1/∂g)² + ...
173/// - W[0,3] = W[3,0] = u2_η0·c·(q0·c1 + s_f·z) + u1_η0·c1  (cross q0–g)
174/// - W[1,3] = W[3,1] = u2_η1·c·(q1·c1 + s_f·z) + u1_η1·c1  (cross q1–g)
175/// - W[2,3] = W[3,2] = u2_ad1·c·(qd1·c1) + u1_ad1·c1  (cross qd1–g)
176/// - All other off-diagonals are zero (η0, η1, ad1 depend on non-overlapping
177///   subsets of (q0,q1,qd1,g), and only g is shared across all three).
178///
179/// This is already computed by `row_primary_closed_form` and stored in
180/// `SurvivalRowHessian::h` after PSD-clamping.
181///
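/// # β-dependent W via `channel_hessian_at`
///
/// `channel_hessian_at` overrides the default β-independent path.  When
/// `family_scalars` carries `SurvivalMarginalSlopeFamilyScalars`, the
/// current per-row primary state `(q0_i, q1_i, qd1_i, g_i)` is read from
/// those scalars and the 4×4 W_i is recomputed via `row_primary_for_compiler`.
/// This makes `I(β) = J(β)^T W(β) J(β)` track the current β instead of the
/// frozen pilot β=0 state.  Caveat: the scalars carry neither sample weights
/// nor event indicators, so the recomputation uses unit weight and event = 1
/// for every row; the result is meant for the drift-detection audit and is
/// not the exactly weighted W(β).  Rows on which the kernel fails keep their
/// frozen pilot W_i.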
190///
191/// When `family_scalars` is `None` but `beta` is zero-ish (all entries ≤ ε),
192/// the frozen pilot W is returned unchanged.  When `family_scalars` is `None`
193/// and any `beta` entry is non-trivial, `Err` is returned — the caller must
194/// supply scalars for a correct W at non-pilot β (same contract as T26's
195/// Jacobian callbacks: scalars required when β affects the primary state).
196impl FamilyChannelHessian for SurvivalRowHessian {
197    fn n_outputs(&self) -> usize {
198        K_SURVIVAL
199    }
200
201    fn n_subjects(&self) -> usize {
202        self.h.shape()[0]
203    }
204
205    fn fill_subject(&self, i: usize, out: &mut [f64]) {
206        assert_eq!(out.len(), K_SURVIVAL * K_SURVIVAL);
207        for a in 0..K_SURVIVAL {
208            for b in 0..K_SURVIVAL {
209                out[a * K_SURVIVAL + b] = self.h[[i, a, b]];
210            }
211        }
212    }
213
214    fn evaluate_full(&self) -> ndarray::Array3<f64> {
215        self.h.clone()
216    }
217
218    fn channel_hessian_at(
219        &self,
220        beta: &[f64],
221        family_scalars: Option<&Arc<dyn std::any::Any + Send + Sync>>,
222    ) -> Result<Arc<dyn FamilyChannelHessian>, String> {
223        use crate::survival::marginal_slope::SurvivalMarginalSlopeFamilyScalars;
224
225        let scalars_opt =
226            family_scalars.and_then(|a| a.downcast_ref::<SurvivalMarginalSlopeFamilyScalars>());
227
228        // Determine whether beta is non-trivial (any |β_j| > ε).
229        let beta_nontrivial = beta
230            .iter()
231            .any(|&b| b.abs() > BETA_NONTRIVIAL_ABS_THRESHOLD);
232
233        match scalars_opt {
234            None if beta_nontrivial => {
235                // β is non-zero in a way that would change W via the primary-state
236                // coupling (g ≠ 0 → c ≠ 1 → W changes).  Scalars are required.
237                Err(
238                    "SurvivalRowHessian::channel_hessian_at: beta is non-trivial but \
239                     family_scalars is None; supply SurvivalMarginalSlopeFamilyScalars \
240                     via FamilyLinearizationState::family_scalars to evaluate W(β) \
241                     correctly (same contract as T26 Jacobian callbacks)."
242                        .to_string(),
243                )
244            }
245            None => {
246                // β ≈ 0: return the frozen pilot W unchanged.
247                Ok(Arc::new(gam_problem::TensorChannelHessian {
248                    h: self.h.clone(),
249                }))
250            }
251            Some(sc) => {
252                let n = self.h.shape()[0];
253                if sc.q0_i.len() != n
254                    || sc.q1_i.len() != n
255                    || sc.qd1_i.len() != n
256                    || sc.g_i.len() != n
257                    || sc.z_i.len() != n
258                {
259                    return Err(format!(
260                        "SurvivalRowHessian::channel_hessian_at: scalars length mismatch \
261                         (expected n={n}, got q0={} q1={} qd1={} g={} z={})",
262                        sc.q0_i.len(),
263                        sc.q1_i.len(),
264                        sc.qd1_i.len(),
265                        sc.g_i.len(),
266                        sc.z_i.len(),
267                    ));
268                }
269                // We do not have weights/event stored in SurvivalRowHessian itself.
270                // The scalars carry the per-row primary state; we need per-row weights
271                // and event indicators to call row_primary_for_compiler.  Those are
272                // NOT stored in SurvivalMarginalSlopeFamilyScalars — so we can only
273                // recompute W's structural shape (the 4×4 curvature geometry) using
274                // unit weights and event=1, which gives us the correct _direction_ of
275                // W at the current β even if the magnitude is off by the sample weight.
276                //
277                // For the drift-detection audit the direction matters more than the
278                // exact per-row magnitudes: rank changes emerge from structural
279                // identifiability, not from per-row weight scaling. Using w=1, d=1
280                // is therefore the principled approximation for the audit path.
281                //
282                // Production callers that need exact W (e.g. for the Fisher Gram in
283                // the compiler) should use SurvivalRowHessian::from_pilot_primary_state
284                // directly with the true per-row weights and event indicators.
285                let mut h_full = Array3::<f64>::zeros((n, K_SURVIVAL, K_SURVIVAL));
286                for i in 0..n {
287                    let q0 = sc.q0_i[i];
288                    let q1 = sc.q1_i[i];
289                    let qd1 = sc.qd1_i[i];
290                    let g = sc.g_i[i];
291                    let z = sc.z_i[i];
292                    // Use unit weight and d=1 (event indicator 1) for the audit path.
293                    // The derivative_guard is the family default (small but non-zero).
294                    match crate::survival::marginal_slope::row_primary_for_compiler(
295                        q0, q1, qd1, g, z, 1.0,  // w = unit weight
296                        1.0,  // d = event
297                        crate::survival::marginal_slope::DEFAULT_SURVIVAL_MARGINAL_SLOPE_DERIVATIVE_GUARD,
298                        sc.s, // probit_scale from scalars
299                    ) {
300                        Ok((_nll, _grad, hess)) => {
301                            let mut h_i = ndarray::Array2::<f64>::zeros((K_SURVIVAL, K_SURVIVAL));
302                            for a in 0..K_SURVIVAL {
303                                for b in 0..K_SURVIVAL {
304                                    h_i[[a, b]] = hess[a][b];
305                                }
306                            }
307                            let clamped = psd_clamp_4x4(&h_i);
308                            for a in 0..K_SURVIVAL {
309                                for b in 0..K_SURVIVAL {
310                                    h_full[[i, a, b]] = clamped[[a, b]];
311                                }
312                            }
313                        }
314                        Err(_) => {
315                            // Monotonicity violation or other numerical issue at this
316                            // row: fall back to the frozen pilot W for this row only.
317                            for a in 0..K_SURVIVAL {
318                                for b in 0..K_SURVIVAL {
319                                    h_full[[i, a, b]] = self.h[[i, a, b]];
320                                }
321                            }
322                        }
323                    }
324                }
325                Ok(Arc::new(SurvivalRowHessian::from_full(h_full)))
326            }
327        }
328    }
329}
330
331/// Project a 4×4 symmetric matrix onto the PSD cone: zero negative
332/// eigenvalues. If the eigendecomposition fails (extremely defensive —
333/// `row_primary_closed_form` already guarantees finite entries), return
334/// the diagonal with negatives clamped.
335fn psd_clamp_4x4(m: &Array2<f64>) -> Array2<f64> {
336    let k = m.nrows();
337    let (evals, evecs) = match m.eigh(Side::Lower) {
338        Ok(pair) => pair,
339        Err(_) => {
340            let mut out = Array2::<f64>::zeros((k, k));
341            for i in 0..k {
342                out[[i, i]] = m[[i, i]].max(0.0);
343            }
344            return out;
345        }
346    };
347    let mut out = Array2::<f64>::zeros((k, k));
348    for i in 0..k {
349        for j in 0..k {
350            let mut acc = 0.0;
351            for l in 0..k {
352                acc += evecs[[i, l]] * evals[l].max(0.0) * evecs[[j, l]];
353            }
354            out[[i, j]] = acc;
355        }
356    }
357    out
358}
359
360/// Row Jacobian operator for the survival time block. Channels (q0, q1,
361/// qd1) come from the three time designs; the g channel is zero.
362pub struct TimeBlockOperator {
363    dq0: Array2<f64>,
364    dq1: Array2<f64>,
365    dqd1: Array2<f64>,
366}
367
368impl TimeBlockOperator {
369    pub fn new(dq0: Array2<f64>, dq1: Array2<f64>, dqd1: Array2<f64>) -> Self {
370        assert_eq!(dq0.dim(), dq1.dim());
371        assert_eq!(dq0.dim(), dqd1.dim());
372        Self { dq0, dq1, dqd1 }
373    }
374}
375
376impl RowJacobianOperator for TimeBlockOperator {
377    fn k(&self) -> usize {
378        K_SURVIVAL
379    }
380    fn ncols(&self) -> usize {
381        self.dq0.ncols()
382    }
383    fn nrows(&self) -> usize {
384        self.dq0.nrows()
385    }
386    fn apply_row(&self, row: usize, delta_beta: &[f64], out: &mut [f64]) {
387        assert_eq!(out.len(), K_SURVIVAL);
388        assert_eq!(delta_beta.len(), self.dq0.ncols());
389        let mut acc = [0.0_f64; K_SURVIVAL];
390        for (j, &b) in delta_beta.iter().enumerate() {
391            acc[0] += self.dq0[[row, j]] * b;
392            acc[1] += self.dq1[[row, j]] * b;
393            acc[2] += self.dqd1[[row, j]] * b;
394        }
395        out.copy_from_slice(&acc);
396    }
397    fn evaluate_full(&self) -> Array3<f64> {
398        let n = self.dq0.nrows();
399        let p = self.dq0.ncols();
400        let mut out = Array3::<f64>::zeros((n, p, K_SURVIVAL));
401        for i in 0..n {
402            for j in 0..p {
403                out[[i, j, 0]] = self.dq0[[i, j]];
404                out[[i, j, 1]] = self.dq1[[i, j]];
405                out[[i, j, 2]] = self.dqd1[[i, j]];
406            }
407        }
408        out
409    }
410    fn scaled_design_by_sqrt_h(&self, h_full: &Array3<f64>) -> Array2<f64> {
411        // Scale straight out of the three compact `(n, p)` channel designs —
412        // the compiler consumes the `(n·K, p)` sqrt(H)-scaled design, so the
413        // dense `(n, p, K)` tensor (3 of its 4 channels held explicitly, the
414        // 4th identically zero) that `evaluate_full()` builds is never needed.
415        // (#738: a capability is not a representation.)
416        let n = self.dq0.nrows();
417        let p = self.dq0.ncols();
418        scale_jacobian_by_sqrt_h_with(n, p, K_SURVIVAL, h_full, |i, a, c| match c {
419            0 => self.dq0[[i, a]],
420            1 => self.dq1[[i, a]],
421            2 => self.dqd1[[i, a]],
422            _ => 0.0,
423        })
424    }
425}
426
427/// Row Jacobian operator for a block whose contribution flows into the
428/// q-channels (q0 and q1 identically) and optionally the qd1 channel.
429/// Covers the survival marginal, score-warp, and link-dev blocks (all
430/// three share the structural property `δq0 = δq1 = basis·δβ`, `δg = 0`).
431pub struct QChannelBlockOperator {
432    dq: Array2<f64>,
433    dqd1: Array2<f64>,
434}
435
436impl QChannelBlockOperator {
437    pub fn new(dq: Array2<f64>, dqd1: Array2<f64>) -> Self {
438        assert_eq!(dq.dim(), dqd1.dim());
439        Self { dq, dqd1 }
440    }
441}
442
443impl RowJacobianOperator for QChannelBlockOperator {
444    fn k(&self) -> usize {
445        K_SURVIVAL
446    }
447    fn ncols(&self) -> usize {
448        self.dq.ncols()
449    }
450    fn nrows(&self) -> usize {
451        self.dq.nrows()
452    }
453    fn apply_row(&self, row: usize, delta_beta: &[f64], out: &mut [f64]) {
454        assert_eq!(out.len(), K_SURVIVAL);
455        assert_eq!(delta_beta.len(), self.dq.ncols());
456        let mut dq_acc = 0.0;
457        let mut dqd_acc = 0.0;
458        for (j, &b) in delta_beta.iter().enumerate() {
459            dq_acc += self.dq[[row, j]] * b;
460            dqd_acc += self.dqd1[[row, j]] * b;
461        }
462        out[0] = dq_acc;
463        out[1] = dq_acc;
464        out[2] = dqd_acc;
465        out[3] = 0.0;
466    }
467    fn evaluate_full(&self) -> Array3<f64> {
468        let n = self.dq.nrows();
469        let p = self.dq.ncols();
470        let mut out = Array3::<f64>::zeros((n, p, K_SURVIVAL));
471        for i in 0..n {
472            for j in 0..p {
473                let v = self.dq[[i, j]];
474                out[[i, j, 0]] = v;
475                out[[i, j, 1]] = v;
476                out[[i, j, 2]] = self.dqd1[[i, j]];
477            }
478        }
479        out
480    }
481    fn scaled_design_by_sqrt_h(&self, h_full: &Array3<f64>) -> Array2<f64> {
482        // q0 and q1 share `dq`; qd1 is `dqd1`; the g channel is identically
483        // zero. Scale directly from the compact `(n, p)` designs, skipping the
484        // dense `(n, p, K)` tensor `evaluate_full()` would build. (#738.)
485        let n = self.dq.nrows();
486        let p = self.dq.ncols();
487        scale_jacobian_by_sqrt_h_with(n, p, K_SURVIVAL, h_full, |i, a, c| match c {
488            0 | 1 => self.dq[[i, a]],
489            2 => self.dqd1[[i, a]],
490            _ => 0.0,
491        })
492    }
493}
494
495/// Row Jacobian operator for the survival logslope block: contribution
496/// lives entirely on the g channel.
497pub struct LogslopeBlockOperator {
498    dg: Array2<f64>,
499}
500
501impl LogslopeBlockOperator {
502    pub fn new(dg: Array2<f64>) -> Self {
503        Self { dg }
504    }
505}
506
507impl RowJacobianOperator for LogslopeBlockOperator {
508    fn k(&self) -> usize {
509        K_SURVIVAL
510    }
511    fn ncols(&self) -> usize {
512        self.dg.ncols()
513    }
514    fn nrows(&self) -> usize {
515        self.dg.nrows()
516    }
517    fn apply_row(&self, row: usize, delta_beta: &[f64], out: &mut [f64]) {
518        assert_eq!(out.len(), K_SURVIVAL);
519        assert_eq!(delta_beta.len(), self.dg.ncols());
520        let mut acc = 0.0;
521        for (j, &b) in delta_beta.iter().enumerate() {
522            acc += self.dg[[row, j]] * b;
523        }
524        out[0] = 0.0;
525        out[1] = 0.0;
526        out[2] = 0.0;
527        out[3] = acc;
528    }
529    fn evaluate_full(&self) -> Array3<f64> {
530        let n = self.dg.nrows();
531        let p = self.dg.ncols();
532        let mut out = Array3::<f64>::zeros((n, p, K_SURVIVAL));
533        for i in 0..n {
534            for j in 0..p {
535                out[[i, j, 3]] = self.dg[[i, j]];
536            }
537        }
538        out
539    }
540    fn scaled_design_by_sqrt_h(&self, h_full: &Array3<f64>) -> Array2<f64> {
541        // The logslope contribution lives entirely on the g channel (3); the
542        // other three channels are identically zero. Scale directly from the
543        // compact `(n, p)` design, skipping the mostly-zero dense `(n, p, K)`
544        // tensor `evaluate_full()` would build. (#738.)
545        let n = self.dg.nrows();
546        let p = self.dg.ncols();
547        scale_jacobian_by_sqrt_h_with(n, p, K_SURVIVAL, h_full, |i, a, c| {
548            if c == 3 { self.dg[[i, a]] } else { 0.0 }
549        })
550    }
551}
552
/// Inputs assembled for the survival fit driver to feed `compile()`. The
/// ordering follows `gauge_priority` descending (time=200 → marginal=150 →
/// logslope=120 → score_warp=80 → link_dev=60).
pub struct SurvivalCompilerInputs {
    /// Row-Jacobian operators handed to the compiler, shared behind `Arc`.
    pub operators: Vec<Arc<dyn RowJacobianOperator>>,
    /// Block-order tags for the compiler; presumably parallel to
    /// `operators` (one tag per operator, same index) — confirm at the
    /// construction site.
    pub ordering: Vec<BlockOrder>,
}
560
/// Per-block V reparameterisation matrices for the three parametric
/// survival blocks emitted by [`compile_survival_parametric_designs`].
/// Each `v_*` is a `(p_block_raw × p_block_kept)` selection-or-rotation
/// matrix that maps a `β_kept` coefficient vector to its `β_raw`
/// equivalent: `β_raw = V · β_kept`. The construction site applies these
/// to the raw block designs (`design_raw · V → design_compiled`) and
/// to the penalties (`Vᵀ S V`) before building `ParameterBlockSpec`s
/// and passing the compiled designs into `make_family`.
///
/// Phase-4b architecture: this is the seam where the family-agnostic
/// row-Jacobian compiler hands control back to the family-specific
/// construction site. Each `v_*` width equals the corresponding
/// `CompiledBlocks::blocks[i].t_lw.ncols()` — i.e., the kept-direction
/// count after sqrt-H-metric residualisation and post-walk RRQR
/// trailing-pivot drop.
pub struct SurvivalParametricCompiled {
    /// V for the time block.
    pub v_time: Array2<f64>,
    /// V for the marginal block.
    pub v_marginal: Array2<f64>,
    /// V for the logslope block.
    pub v_logslope: Array2<f64>,
    /// Per-block dropped raw-column count, indexed
    /// `(time_dropped, marginal_dropped, logslope_dropped)`. Equal to
    /// `(p_raw − v.ncols())` for each block. Useful for logging the
    /// gauge-attribution summary at the construction site.
    pub drops_by_block: (usize, usize, usize),
}
586
587fn wrap_design_with_transform(
588    raw: DesignMatrix,
589    v: &Array2<f64>,
590    context: &str,
591) -> Result<DesignMatrix, String> {
592    if raw.ncols() != v.nrows() {
593        return Err(format!(
594            "{context}: raw design has {} cols but V has {} rows (V is {}×{})",
595            raw.ncols(),
596            v.nrows(),
597            v.nrows(),
598            v.ncols(),
599        ));
600    }
601    let inner_dense = match raw {
602        DesignMatrix::Dense(d) => d,
603        DesignMatrix::Sparse(_) => {
604            let dense = raw
605                .try_to_dense_by_chunks(&format!("{context} sparse→dense for V apply"))
606                .map_err(|reason| format!("{context}: densify failed: {reason}"))?;
607            DenseDesignMatrix::from(dense)
608        }
609    };
610    let op = CoefficientTransformOperator::new(inner_dense, v.clone())
611        .map_err(|reason| format!("{context}: CoefficientTransformOperator::new: {reason}"))?;
612    Ok(DesignMatrix::Dense(DenseDesignMatrix::from(Arc::new(op))))
613}
614
/// Per-term V reparameterisation matrices for the three parametric
/// survival blocks. Each block's full V is the block-diagonal assembly
/// of its per-term V's (one entry per element of the input
/// `*_partition`). Preserves per-term penalty structure: applying
/// `V_b = block_diag(V_term1, ..., V_termM)` to a per-term BlockwisePenalty
/// pulls each penalty back only via its OWN term's V, so what was a
/// per-term λ tunable in REML stays per-term tunable.
pub struct SurvivalParametricCompiledPerTerm {
    /// One V per entry of the time partition, in partition order (the
    /// compiled block's `t_lw`).
    pub v_time_per_term: Vec<Array2<f64>>,
    /// One V per entry of the marginal partition, in partition order.
    pub v_marginal_per_term: Vec<Array2<f64>>,
    /// One V per entry of the logslope partition, in partition order.
    pub v_logslope_per_term: Vec<Array2<f64>>,
    /// Per-term residualised reparam `R_b = M_b · V_b` from the
    /// identifiability compiler, in the same global compile order
    /// (time terms, then marginal terms, then logslope terms). `None`
    /// for the very first compiled block (no anchor). Used by the
    /// V+M-exact apply path to emit residualised rows
    /// `C_b·V_b − A_{<b}·R_b` and to assemble the full triangular T.
    pub r_lw_per_term: Vec<Option<Array2<f64>>>,
    /// Per-block drops (raw_cols − sum(kept_cols across terms)), ordered
    /// `(time, marginal, logslope)`; each term's difference saturates at
    /// zero before summing.
    pub drops_by_block: (usize, usize, usize),
}
636
637/// Per-term-aware compile: residualise each block's TERMS individually
638/// in priority order so the emitted V is block-diagonal on term
639/// boundaries. This preserves the per-term penalty structure that
640/// REML's per-λ accounting depends on.
641///
642/// Each `*_partition` is a list of disjoint contiguous column ranges
643/// covering `[0..p_block)`. For the marginal/logslope blocks the
644/// natural source is the union of `BlockwisePenalty::col_range` values
645/// (one per smoothness penalty / term) plus the complement
646/// (unpenalised parametric columns).
647///
648/// Order of residualisation: time terms first (in their partition
649/// order), then marginal terms, then logslope terms. Within each
650/// block, terms are residualised against ALL prior anchor columns
651/// (terms from earlier blocks + earlier terms within this block).
652/// Aliased directions land in the lowest-priority block that contains
653/// them, in the natural term order within that block — matching the
654/// gauge-priority ownership contract.
655pub fn compile_survival_parametric_designs_per_term(
656    time_dq0: Array2<f64>,
657    time_dq1: Array2<f64>,
658    time_dqd1: Array2<f64>,
659    time_partition: &[std::ops::Range<usize>],
660    marginal_dq: Array2<f64>,
661    marginal_dqd1: Array2<f64>,
662    marginal_partition: &[std::ops::Range<usize>],
663    logslope_dg: Array2<f64>,
664    logslope_partition: &[std::ops::Range<usize>],
665    row_hess: &dyn RowHessian,
666) -> Result<SurvivalParametricCompiledPerTerm, String> {
667    use gam_identifiability::families::compiler::compile;
668
669    let p_time = time_dq0.ncols();
670    let p_marg = marginal_dq.ncols();
671    let p_log = logslope_dg.ncols();
672    validate_partition(time_partition, p_time, "time")?;
673    validate_partition(marginal_partition, p_marg, "marginal")?;
674    validate_partition(logslope_partition, p_log, "logslope")?;
675
676    // Build per-term operators. Each term gets its own RowJacobianOperator
677    // restricted to its column slice; the operator type matches the
678    // block's K-channel signature (Time, QChannel, Logslope).
679    let mut operators: Vec<Arc<dyn RowJacobianOperator>> = Vec::new();
680    let mut ordering: Vec<BlockOrder> = Vec::new();
681    for range in time_partition {
682        let dq0 = time_dq0.slice(ndarray::s![.., range.clone()]).to_owned();
683        let dq1 = time_dq1.slice(ndarray::s![.., range.clone()]).to_owned();
684        let dqd1 = time_dqd1.slice(ndarray::s![.., range.clone()]).to_owned();
685        operators.push(Arc::new(TimeBlockOperator::new(dq0, dq1, dqd1)));
686        ordering.push(BlockOrder::Time);
687    }
688    for range in marginal_partition {
689        let dq = marginal_dq.slice(ndarray::s![.., range.clone()]).to_owned();
690        let dqd1 = marginal_dqd1
691            .slice(ndarray::s![.., range.clone()])
692            .to_owned();
693        operators.push(Arc::new(QChannelBlockOperator::new(dq, dqd1)));
694        ordering.push(BlockOrder::Marginal);
695    }
696    for range in logslope_partition {
697        let dg = logslope_dg.slice(ndarray::s![.., range.clone()]).to_owned();
698        operators.push(Arc::new(LogslopeBlockOperator::new(dg)));
699        ordering.push(BlockOrder::Logslope);
700    }
701
702    let compiled = compile(&operators, row_hess, &ordering).map_err(|e| {
703        format!("identifiability::families::compiler::compile (per-term) failed: {e}")
704    })?;
705    let blocks = compiled.blocks;
706    let n_time = time_partition.len();
707    let n_marg = marginal_partition.len();
708    let n_log = logslope_partition.len();
709    if blocks.len() != n_time + n_marg + n_log {
710        return Err(format!(
711            "per-term compile: expected {} compiled blocks (time={}, marg={}, log={}), got {}",
712            n_time + n_marg + n_log,
713            n_time,
714            n_marg,
715            n_log,
716            blocks.len(),
717        ));
718    }
719    let mut iter = blocks.into_iter();
720    let mut v_time_per_term: Vec<Array2<f64>> = Vec::with_capacity(n_time);
721    let mut r_time_per_term: Vec<Option<Array2<f64>>> = Vec::with_capacity(n_time);
722    for _ in 0..n_time {
723        let blk = iter.next().unwrap();
724        v_time_per_term.push(blk.t_lw);
725        r_time_per_term.push(blk.r_lw);
726    }
727    let mut v_marginal_per_term: Vec<Array2<f64>> = Vec::with_capacity(n_marg);
728    let mut r_marginal_per_term: Vec<Option<Array2<f64>>> = Vec::with_capacity(n_marg);
729    for _ in 0..n_marg {
730        let blk = iter.next().unwrap();
731        v_marginal_per_term.push(blk.t_lw);
732        r_marginal_per_term.push(blk.r_lw);
733    }
734    let mut v_logslope_per_term: Vec<Array2<f64>> = Vec::with_capacity(n_log);
735    let mut r_logslope_per_term: Vec<Option<Array2<f64>>> = Vec::with_capacity(n_log);
736    for _ in 0..n_log {
737        let blk = iter.next().unwrap();
738        v_logslope_per_term.push(blk.t_lw);
739        r_logslope_per_term.push(blk.r_lw);
740    }
741    let mut r_lw_per_term: Vec<Option<Array2<f64>>> = Vec::with_capacity(n_time + n_marg + n_log);
742    r_lw_per_term.extend(r_time_per_term);
743    r_lw_per_term.extend(r_marginal_per_term);
744    r_lw_per_term.extend(r_logslope_per_term);
745    let drops_time: usize = time_partition
746        .iter()
747        .zip(v_time_per_term.iter())
748        .map(|(r, v)| r.len().saturating_sub(v.ncols()))
749        .sum();
750    let drops_marg: usize = marginal_partition
751        .iter()
752        .zip(v_marginal_per_term.iter())
753        .map(|(r, v)| r.len().saturating_sub(v.ncols()))
754        .sum();
755    let drops_log: usize = logslope_partition
756        .iter()
757        .zip(v_logslope_per_term.iter())
758        .map(|(r, v)| r.len().saturating_sub(v.ncols()))
759        .sum();
760    Ok(SurvivalParametricCompiledPerTerm {
761        v_time_per_term,
762        v_marginal_per_term,
763        v_logslope_per_term,
764        r_lw_per_term,
765        drops_by_block: (drops_time, drops_marg, drops_log),
766    })
767}
768
/// Check that `partition` is an ordered tiling of `[0, p_block)`.
///
/// A valid partition starts at column 0, ends at `p_block`, has every range
/// begin exactly where its predecessor ends, and contains no empty range.
/// An empty `partition` is accepted only for a zero-width block.
///
/// `label` names the block in the returned error message.
///
/// # Errors
///
/// Returns a description of the first violated rule.
fn validate_partition(
    partition: &[std::ops::Range<usize>],
    p_block: usize,
    label: &str,
) -> Result<(), String> {
    let (first, last) = match (partition.first(), partition.last()) {
        (Some(first), Some(last)) => (first, last),
        // No ranges at all: only legal when there are no columns to cover.
        _ if p_block == 0 => return Ok(()),
        _ => {
            return Err(format!(
                "{label} partition empty but block has p={p_block} columns"
            ));
        }
    };
    if first.start != 0 {
        return Err(format!(
            "{label} partition must start at 0, got start={}",
            first.start
        ));
    }
    if last.end != p_block {
        return Err(format!(
            "{label} partition must cover [0, {p_block}); last range ends at {}",
            last.end
        ));
    }
    for pair in partition.windows(2) {
        let (left, right) = (&pair[0], &pair[1]);
        if left.end != right.start {
            return Err(format!(
                "{label} partition has gap/overlap between [{}..{}) and [{}..{})",
                left.start, left.end, right.start, right.end
            ));
        }
        if left.is_empty() {
            return Err(format!(
                "{label} partition has empty range [{}..{})",
                left.start, left.end
            ));
        }
    }
    // `windows(2)` only inspects the left element of each pair, so the final
    // range needs its own emptiness check.
    if last.is_empty() {
        return Err(format!("{label} partition's final range is empty",));
    }
    Ok(())
}
813
/// Derive a disjoint contiguous partition of `[0..p_block)` from a
/// list of BlockwisePenalty col_ranges.
///
/// Every range endpoint (clamped to `p_block`), together with `0` and
/// `p_block`, is a term boundary; each pair of consecutive distinct
/// boundaries delimits one term. Consequently:
///
/// - a run of unpenalised columns between two penalty ranges becomes ONE
///   term spanning the whole gap (e.g. `[0..3, 6..10]` on a 10-column
///   block yields `[0..3, 3..6, 6..10]`);
/// - multiple penalties with the SAME col_range (e.g. tensor anisotropy
///   axes) coalesce to one term;
/// - partially overlapping ranges are split at every endpoint;
/// - endpoints beyond `p_block` are clamped, so the result never extends
///   past the block;
/// - `p_block == 0` yields an empty partition.
///
/// The returned ranges are non-empty, sorted, and tile `[0..p_block)`.
pub fn extract_term_partition_from_penalty_ranges(
    p_block: usize,
    penalty_ranges: &[std::ops::Range<usize>],
) -> Vec<std::ops::Range<usize>> {
    use std::collections::BTreeSet;
    let mut boundaries: BTreeSet<usize> = BTreeSet::new();
    boundaries.extend([0, p_block]);
    boundaries.extend(
        penalty_ranges
            .iter()
            .flat_map(|r| [r.start.min(p_block), r.end.min(p_block)]),
    );
    // A `BTreeSet` iterates in strictly increasing order, so every adjacent
    // pair already forms a non-empty range; no filtering is needed.
    boundaries
        .iter()
        .zip(boundaries.iter().skip(1))
        .map(|(&lo, &hi)| lo..hi)
        .collect()
}
836
837/// Pull a single raw block-local [`BlockwisePenalty`] back through the
838/// block's own diagonal reparameterisation `V_b` (the `(b, b)` block of
839/// the triangular T), producing a per-block-width compiled penalty.
840///
841/// The penalty's `local` is `pen.col_range.len()` square and covers a
842/// sub-region of the raw block at offset `pen.col_range.start` (which is
843/// block-local, i.e. relative to the block's first raw column). It is
844/// embedded into the full raw block width `v_block.nrows()` at that
845/// offset, then pulled back as `V_bᵀ · embed(S) · V_b`, giving a
846/// `(w_b_compiled × w_b_compiled)` symmetric `PenaltyMatrix::Dense`
847/// where `w_b_compiled == v_block.ncols()`.
848///
849/// This is the penalty contract a per-block `ParameterBlockSpec`
850/// requires: each block's penalty acts on that block's own compiled
851/// coordinate `θ_b`. The cross-block residualisation `R_{a→b}` carried
852/// in T's strict-upper triangle is absorbed into the *design* columns
853/// (the residualised emitted design `C_b V_b − A_{<b} R_b`), not into
854/// the penalty — exactly as the VM-exact compile-map path
855/// [`apply_compiled_map_to_designs`] does. Pulling the penalty back
856/// through the full joint T instead would yield a `(p_compiled × p_compiled)` dense
857/// matrix that cannot live in a single block's `penalties` slot and
858/// would violate the `p_b × p_b` block-spec validation.
859pub fn pull_back_blockwise_penalty_through_block_v(
860    pen: &gam_terms::smooth::BlockwisePenalty,
861    v_block: &Array2<f64>,
862) -> Result<PenaltyMatrix, String> {
863    let raw_p = v_block.nrows();
864    let compiled_p = v_block.ncols();
865    let block_p = pen.col_range.len();
866    let embed_start = pen.col_range.start;
867    let embed_end = pen.col_range.end;
868    if embed_end > raw_p {
869        return Err(format!(
870            "pull_back_blockwise_penalty_through_block_v: penalty col_range {embed_start}..{embed_end} \
871             exceeds block raw width {raw_p}"
872        ));
873    }
874    if pen.local.nrows() != block_p || pen.local.ncols() != block_p {
875        return Err(format!(
876            "pull_back_blockwise_penalty_through_block_v: penalty local is {}x{} but col_range \
877             width is {block_p}",
878            pen.local.nrows(),
879            pen.local.ncols(),
880        ));
881    }
882    let mut embedded = Array2::<f64>::zeros((raw_p, raw_p));
883    if block_p > 0 {
884        let mut dst =
885            embedded.slice_mut(ndarray::s![embed_start..embed_end, embed_start..embed_end]);
886        for i in 0..block_p {
887            for j in 0..block_p {
888                dst[[i, j]] = pen.local[[i, j]];
889            }
890        }
891    }
892    // V_bᵀ · embed(S) · V_b → (compiled_p × compiled_p).
893    let temp = embedded.dot(v_block);
894    let pulled = v_block.t().dot(&temp);
895    let mut sym = Array2::<f64>::zeros((compiled_p, compiled_p));
896    for i in 0..compiled_p {
897        for j in 0..compiled_p {
898            sym[[i, j]] = 0.5 * (pulled[[i, j]] + pulled[[j, i]]);
899        }
900    }
901    Ok(PenaltyMatrix::Dense(sym))
902}
903
904/// Assemble a 3-block [`CompiledMap`] (time, marginal, logslope) from a
905/// [`SurvivalParametricCompiledPerTerm`] produced by the full 4×4 row-Hessian
906/// driver [`compile_survival_parametric_designs_per_term`].
907///
908/// The full global triangular `T` is built from the per-term `V`/`R` blocks
909/// (diagonal `V_b`, strict-upper `−R_{a→b}` — identical to the matrix the
910/// result-time lift [`Gauge::from_v_and_r`] uses), then partitioned
911/// into the three *block* ranges (raw = summed per-term raw widths, compiled =
912/// summed per-term kept widths). The resulting `CompiledMap` is interchangeable
913/// with one from
914/// [`gam_identifiability::families::compiler::compile_from_raw_grams`], so the
915/// existing [`apply_compiled_map_to_designs`] +
916/// [`Gauge::from_compiled_map`] machinery consumes it unchanged.
917///
918/// This is the seam that lets the survival closed-form fast path engage on the
919/// *correct* identifiable quotient: the cheap η₁-only rawstack metric can
920/// falsely collapse a whole channel (marginal/logslope share a PC surface in
921/// the η₁ row curvature), but the full survival row Hessian is 4×4 in
922/// `(q0, q1, qd1, g)` and chains differently into each block, so it keeps the
923/// channels distinct when no *true* alias exists. The reduced basis it emits
924/// goes to Newton in place of the rank-deficient raw basis.
925pub fn compiled_map_from_per_term(
926    compiled: &SurvivalParametricCompiledPerTerm,
927) -> gam_identifiability::families::compiler::CompiledMap {
928    // Per-term V's and R's in global compile order: time terms, then marginal,
929    // then logslope — exactly the order `r_lw_per_term` is stored in.
930    let mut v_all: Vec<Array2<f64>> = Vec::new();
931    v_all.extend(compiled.v_time_per_term.iter().cloned());
932    v_all.extend(compiled.v_marginal_per_term.iter().cloned());
933    v_all.extend(compiled.v_logslope_per_term.iter().cloned());
934
935    let t_full = assemble_block_triangular_t(&v_all, &compiled.r_lw_per_term);
936
937    // Per-block raw / compiled widths = summed per-term widths within the block.
938    let raw_w = |terms: &[Array2<f64>]| -> usize { terms.iter().map(|v| v.nrows()).sum() };
939    let kept_w = |terms: &[Array2<f64>]| -> usize { terms.iter().map(|v| v.ncols()).sum() };
940    let raw_time = raw_w(&compiled.v_time_per_term);
941    let raw_marg = raw_w(&compiled.v_marginal_per_term);
942    let raw_log = raw_w(&compiled.v_logslope_per_term);
943    let kept_time = kept_w(&compiled.v_time_per_term);
944    let kept_marg = kept_w(&compiled.v_marginal_per_term);
945    let kept_log = kept_w(&compiled.v_logslope_per_term);
946
947    let raw_block_ranges = vec![
948        0..raw_time,
949        raw_time..(raw_time + raw_marg),
950        (raw_time + raw_marg)..(raw_time + raw_marg + raw_log),
951    ];
952    let compiled_block_ranges = vec![
953        0..kept_time,
954        kept_time..(kept_time + kept_marg),
955        (kept_time + kept_marg)..(kept_time + kept_marg + kept_log),
956    ];
957
958    gam_identifiability::families::compiler::CompiledMap {
959        raw_from_compiled: t_full,
960        compiled_block_ranges,
961        raw_block_ranges,
962    }
963}
964
965/// Apply a global [`CompiledMap`] T directly to the three survival
966/// parametric block designs (time/marginal/logslope). Slices the
967/// per-block diagonal of T into `V_b = T[raw_range_b, compiled_range_b]`
968/// (shape `p_b_raw × w_b_compiled`), wraps each channel's raw design via
969/// [`wrap_design_with_transform`], and pulls each block's penalties back
970/// through that block's OWN `V_b` via
971/// [`pull_back_blockwise_penalty_through_block_v`], producing
972/// per-block-width `(w_b_compiled × w_b_compiled)` penalties — the shape
973/// a per-block `ParameterBlockSpec.penalties` slot requires.
974///
975/// `map.raw_block_ranges` must equal three contiguous ranges in the
976/// order Time → Marginal → Logslope (matching the input designs).
977/// `map.compiled_block_ranges` runs in the same order.
978///
979/// Penalties supplied to this function:
980/// - `time_penalties` are `BlockwisePenalty`s whose `col_range` is in
981///   the time block's local raw coords (e.g. `0..p_time`).
982/// - `marginal_penalties` / `logslope_penalties` likewise — local to
983///   their own channel's raw width.
984///
985/// Each penalty's block-local `col_range` is embedded into the block's
986/// raw width and pulled back as `V_bᵀ S_b V_b`. The cross-block
987/// residualisation `R_{a→b}` carried in T's strict-upper triangle is
988/// absorbed into the residualised *design* columns, not the penalty, so
989/// the per-block penalty model stays exact for the highest-priority
990/// block (time, no anchor → `R = []`) and matches the sibling per-block
991/// compile path for the rest.
992pub fn apply_compiled_map_to_designs(
993    map: &gam_identifiability::families::compiler::CompiledMap,
994    time_design_entry: DesignMatrix,
995    time_design_exit: DesignMatrix,
996    time_design_derivative_exit: DesignMatrix,
997    marginal_design: DesignMatrix,
998    logslope_design: DesignMatrix,
999    time_penalties: &[gam_terms::smooth::BlockwisePenalty],
1000    marginal_penalties: &[gam_terms::smooth::BlockwisePenalty],
1001    logslope_penalties: &[gam_terms::smooth::BlockwisePenalty],
1002) -> Result<CompiledSurvivalDesignsVMExact, String> {
1003    if map.raw_block_ranges.len() != 3 || map.compiled_block_ranges.len() != 3 {
1004        return Err(format!(
1005            "apply_compiled_map_to_designs: expected exactly 3 blocks (time, marginal, logslope), \
1006             got {} raw / {} compiled",
1007            map.raw_block_ranges.len(),
1008            map.compiled_block_ranges.len(),
1009        ));
1010    }
1011    let time_raw = map.raw_block_ranges[0].clone();
1012    let marg_raw = map.raw_block_ranges[1].clone();
1013    let log_raw = map.raw_block_ranges[2].clone();
1014    let time_compiled = map.compiled_block_ranges[0].clone();
1015    let marg_compiled = map.compiled_block_ranges[1].clone();
1016    let log_compiled = map.compiled_block_ranges[2].clone();
1017
1018    let t = &map.raw_from_compiled;
1019    let raw_total = t.nrows();
1020    let compiled_total = t.ncols();
1021    let expected_raw_total = log_raw.end;
1022    if raw_total != expected_raw_total {
1023        return Err(format!(
1024            "apply_compiled_map_to_designs: T has {raw_total} raw rows but block ranges sum to \
1025             {expected_raw_total}"
1026        ));
1027    }
1028    let expected_compiled_total = log_compiled.end;
1029    if compiled_total != expected_compiled_total {
1030        return Err(format!(
1031            "apply_compiled_map_to_designs: T has {compiled_total} compiled cols but block ranges \
1032             sum to {expected_compiled_total}"
1033        ));
1034    }
1035
1036    let v_time = t
1037        .slice(ndarray::s![time_raw.clone(), time_compiled.clone()])
1038        .to_owned();
1039    let v_marg = t
1040        .slice(ndarray::s![marg_raw.clone(), marg_compiled.clone()])
1041        .to_owned();
1042    let v_log = t
1043        .slice(ndarray::s![log_raw.clone(), log_compiled.clone()])
1044        .to_owned();
1045
1046    let time_entry_out =
1047        wrap_design_with_transform(time_design_entry, &v_time, "compiled-map: time entry")?;
1048    let time_exit_out =
1049        wrap_design_with_transform(time_design_exit, &v_time, "compiled-map: time exit")?;
1050    let time_deriv_out = wrap_design_with_transform(
1051        time_design_derivative_exit,
1052        &v_time,
1053        "compiled-map: time derivative_exit",
1054    )?;
1055    let marg_out = wrap_design_with_transform(marginal_design, &v_marg, "compiled-map: marginal")?;
1056    let log_out = wrap_design_with_transform(logslope_design, &v_log, "compiled-map: logslope")?;
1057
1058    // Pull each block's penalties back through that block's OWN diagonal
1059    // reparameterisation V_b (= the (b, b) block of T). This produces a
1060    // per-block-width `(w_b_compiled × w_b_compiled)` penalty — the only
1061    // shape a per-block `ParameterBlockSpec.penalties` slot accepts.
1062    //
1063    // The block-local penalty `V_bᵀ S_b V_b` is the correct per-block
1064    // penalty: in raw coords the model penalises `γ_bᵀ S_b γ_b` on block
1065    // b's own coefficients, and under the residualised reparameterisation
1066    // the cross-block carry `R_{a→b}` lives entirely in the *design*
1067    // columns (`C_b V_b − A_{<b} R_b`), not in the penalty.
1068    //
1069    // Pulling penalties back through the full joint triangular T instead
1070    // (`Tᵀ blkdiag(S_b) T`) yields a `(p_compiled × p_compiled)` dense
1071    // matrix whose off-diagonal couples θ_b to earlier blocks' θ_a;
1072    // jamming that joint-width matrix into a single block's `penalties`
1073    // produced the `block 0 penalty 0 must be 12x12, got 17x17` mismatch
1074    // that surfaced as the `assert_valid_blockspecs` FFI panic. The two
1075    // agree whenever the residualisation `R_{a→b}` lands in the null space
1076    // of S_a (the shared low-order / parametric directions the identifiable
1077    // quotient strips), which is the case the compiler targets.
1078    let pull_set = |pens: &[gam_terms::smooth::BlockwisePenalty],
1079                    v_block: &Array2<f64>,
1080                    channel: &str|
1081     -> Result<Vec<PenaltyMatrix>, String> {
1082        pens.iter()
1083            .map(|p| {
1084                pull_back_blockwise_penalty_through_block_v(p, v_block).map_err(|e| {
1085                    format!("apply_compiled_map_to_designs: {channel} penalty pullback: {e}")
1086                })
1087            })
1088            .collect()
1089    };
1090
1091    let time_penalties = pull_set(time_penalties, &v_time, "time")?;
1092    let marginal_penalties = pull_set(marginal_penalties, &v_marg, "marginal")?;
1093    let logslope_penalties = pull_set(logslope_penalties, &v_log, "logslope")?;
1094    validate_block_penalty_shapes("time", time_exit_out.ncols(), &time_penalties)?;
1095    validate_block_penalty_shapes("marginal", marg_out.ncols(), &marginal_penalties)?;
1096    validate_block_penalty_shapes("logslope", log_out.ncols(), &logslope_penalties)?;
1097
1098    Ok(CompiledSurvivalDesignsVMExact {
1099        time_design_entry: time_entry_out,
1100        time_design_exit: time_exit_out,
1101        time_design_derivative_exit: time_deriv_out,
1102        marginal_design: marg_out,
1103        logslope_design: log_out,
1104        time_penalties,
1105        marginal_penalties,
1106        logslope_penalties,
1107    })
1108}
1109
1110fn validate_block_penalty_shapes(
1111    block: &str,
1112    width: usize,
1113    penalties: &[PenaltyMatrix],
1114) -> Result<(), String> {
1115    for (idx, penalty) in penalties.iter().enumerate() {
1116        let shape = penalty.shape();
1117        if shape != (width, width) {
1118            return Err(format!(
1119                "apply_compiled_map_to_designs: {block} penalty {idx} must be {width}x{width}, got {}x{}",
1120                shape.0, shape.1
1121            ));
1122        }
1123    }
1124    Ok(())
1125}
1126
1127/// Run the identifiability compiler on the three survival parametric
1128/// blocks (time, marginal, logslope) at a pilot β and return the per-
1129/// block V reparameterisation matrices.
1130///
1131/// `row_hess` must be a PSD per-row 4×4 Hessian of `−log L_i(u_i)` at
1132/// the pilot β (see [`SurvivalRowHessian::from_pilot_primary_state`]).
1133/// The compiler residualises blocks left-to-right in priority order
1134/// (time → marginal → logslope) in the sqrt-H-metric so any aliased
1135/// direction lands in the lower-priority block, then runs a post-walk
1136/// column-pivoted QR on the cumulative anchor and drops trailing
1137/// pivots from the latest block. The returned V matrices are ready to
1138/// be applied to each block's raw design and penalty before the
1139/// `ParameterBlockSpec` list is assembled.
1140///
1141/// On `FullyAliased` from `compile()` (a block fully absorbed by its
1142/// cumulative anchor) this returns `Err`. The construction site should
1143/// surface that as a structured user-facing diagnostic — the model is
1144/// asking the compiler to assign zero degrees of freedom to a named
1145/// parametric block, which is a model-spec bug not a numerical one.
1146///
1147/// Sibling Phase-4b wiring (`bernoulli_marginal_slope::install_compiled_flex_block_into_runtime`)
1148/// already calls `compile()` for the flex blocks. This helper extends
1149/// that contract to the parametric blocks by giving the SMGS
1150/// construction site a one-line entry point — it does NOT yet apply
1151/// the V transforms to the family's captured designs (the captured-
1152/// design update is the remaining integration step that touches the
1153/// family's row-Hessian assembly assertions).
1154pub fn compile_survival_parametric_designs(
1155    time_dq0: Array2<f64>,
1156    time_dq1: Array2<f64>,
1157    time_dqd1: Array2<f64>,
1158    marginal_dq: Array2<f64>,
1159    marginal_dqd1: Array2<f64>,
1160    logslope_dg: Array2<f64>,
1161    row_hess: &dyn RowHessian,
1162) -> Result<SurvivalParametricCompiled, String> {
1163    use gam_identifiability::families::compiler::compile;
1164
1165    let p_time_raw = time_dq0.ncols();
1166    let p_marg_raw = marginal_dq.ncols();
1167    let p_log_raw = logslope_dg.ncols();
1168
1169    let inputs = build_survival_compiler_inputs(
1170        time_dq0,
1171        time_dq1,
1172        time_dqd1,
1173        marginal_dq,
1174        marginal_dqd1,
1175        logslope_dg,
1176        None,
1177        None,
1178    );
1179    if inputs.operators.len() != 3 {
1180        return Err(format!(
1181            "compile_survival_parametric_designs: expected exactly 3 parametric operators \
1182             (time, marginal, logslope); got {}",
1183            inputs.operators.len(),
1184        ));
1185    }
1186    let compiled = compile(&inputs.operators, row_hess, &inputs.ordering)
1187        .map_err(|e| format!("identifiability::families::compiler::compile failed: {e}"))?;
1188    if compiled.blocks.len() != 3 {
1189        return Err(format!(
1190            "compile_survival_parametric_designs: compiler emitted {} blocks; expected 3",
1191            compiled.blocks.len(),
1192        ));
1193    }
1194    let v_time = compiled.blocks[0].t_lw.clone();
1195    let v_marginal = compiled.blocks[1].t_lw.clone();
1196    let v_logslope = compiled.blocks[2].t_lw.clone();
1197    let drops_by_block = (
1198        p_time_raw.saturating_sub(v_time.ncols()),
1199        p_marg_raw.saturating_sub(v_marginal.ncols()),
1200        p_log_raw.saturating_sub(v_logslope.ncols()),
1201    );
1202    Ok(SurvivalParametricCompiled {
1203        v_time,
1204        v_marginal,
1205        v_logslope,
1206        drops_by_block,
1207    })
1208}
1209
1210/// Build the operator stack from already-materialised dense designs.
1211///
1212/// `time_dq0/dq1/dqd1` are the time block's three primary-state Jacobians
1213/// at training rows. `marginal_dq` and `marginal_dqd1` are the marginal
1214/// block's contributions to q (shared between q0 and q1) and to qd1
1215/// (typically zero unless timewiggle interacts). `logslope_dg` is the
1216/// logslope block's contribution to g.
1217///
1218/// `score_warp_(dq, dqd1)` / `link_dev_(dq, dqd1)` are present only when
1219/// the corresponding flex block is active. The returned `ordering` parallels
1220/// `operators` so the caller can route compiled outputs back to runtime slots.
1221pub fn build_survival_compiler_inputs(
1222    time_dq0: Array2<f64>,
1223    time_dq1: Array2<f64>,
1224    time_dqd1: Array2<f64>,
1225    marginal_dq: Array2<f64>,
1226    marginal_dqd1: Array2<f64>,
1227    logslope_dg: Array2<f64>,
1228    score_warp_dq_dqd1: Option<(Array2<f64>, Array2<f64>)>,
1229    link_dev_dq_dqd1: Option<(Array2<f64>, Array2<f64>)>,
1230) -> SurvivalCompilerInputs {
1231    let mut operators: Vec<Arc<dyn RowJacobianOperator>> = Vec::with_capacity(5);
1232    let mut ordering: Vec<BlockOrder> = Vec::with_capacity(5);
1233
1234    operators.push(Arc::new(TimeBlockOperator::new(
1235        time_dq0, time_dq1, time_dqd1,
1236    )));
1237    ordering.push(BlockOrder::Time);
1238
1239    operators.push(Arc::new(QChannelBlockOperator::new(
1240        marginal_dq,
1241        marginal_dqd1,
1242    )));
1243    ordering.push(BlockOrder::Marginal);
1244
1245    operators.push(Arc::new(LogslopeBlockOperator::new(logslope_dg)));
1246    ordering.push(BlockOrder::Logslope);
1247
1248    if let Some((dq, dqd1)) = score_warp_dq_dqd1 {
1249        operators.push(Arc::new(QChannelBlockOperator::new(dq, dqd1)));
1250        ordering.push(BlockOrder::ScoreWarp);
1251    }
1252    if let Some((dq, dqd1)) = link_dev_dq_dqd1 {
1253        operators.push(Arc::new(QChannelBlockOperator::new(dq, dqd1)));
1254        ordering.push(BlockOrder::LinkDev);
1255    }
1256
1257    SurvivalCompilerInputs {
1258        operators,
1259        ordering,
1260    }
1261}
1262
1263/// V+M-exact compiled designs + per-block penalties for the survival
1264/// time/marginal/logslope blocks, produced by
1265/// [`apply_compiled_map_to_designs`] from a `CompiledMap`. The
1266/// construction site swaps raw designs/penalties for these compiled
1267/// versions before building `ParameterBlockSpec`s.
1268///
1269/// The emitted designs carry the exact residualised `C_b·V_b − A_{<b}·R_b`
1270/// row form (via [`wrap_design_with_transform`] on `V_b = T[raw_b, comp_b]`):
1271/// the cross-block residualisation `R_{a→b}` lives in those design columns,
1272/// while each block's penalty is pulled back through that block's own
1273/// diagonal `V_b` as `V_bᵀ S_b V_b` (the `*_penalties` fields).
1274///
1275/// At fit result the joint compiled β is lifted back to raw via the
1276/// `gam_solve::gauge::Gauge` built from the *same* `CompiledMap`
1277/// (`β_raw = T · θ`, T block-upper-triangular with `V_b` on the diagonal
1278/// and `-R_{a→b}` off-diagonal). The full T therefore lives on that
1279/// `Gauge`, not on this struct — the caller holds the `CompiledMap` and
1280/// constructs both from it, so duplicating T here would be dead state.
1281pub struct CompiledSurvivalDesignsVMExact {
1282    pub time_design_entry: DesignMatrix,
1283    pub time_design_exit: DesignMatrix,
1284    pub time_design_derivative_exit: DesignMatrix,
1285    pub marginal_design: DesignMatrix,
1286    pub logslope_design: DesignMatrix,
1287    /// Per-block penalties, each pulled back through that block's OWN
1288    /// diagonal reparameterisation `V_b` as `V_bᵀ S_b V_b`. The result
1289    /// is a per-block-width `PenaltyMatrix::Dense`
1290    /// (`w_b_compiled × w_b_compiled`) — the shape a per-block
1291    /// `ParameterBlockSpec.penalties` slot requires. Cross-block
1292    /// residualisation `R_{a→b}` is carried by the residualised design
1293    /// columns, not the penalty.
1294    pub time_penalties: Vec<PenaltyMatrix>,
1295    pub marginal_penalties: Vec<PenaltyMatrix>,
1296    pub logslope_penalties: Vec<PenaltyMatrix>,
1297}
1298
1299#[cfg(test)]
1300mod tests {
1301    use super::*;
1302    use gam_problem::Gauge;
1303
1304    #[test]
1305    fn psd_clamp_zeros_negative_eigenvalues() {
1306        // Construct M = U diag(2, -1, 0.5, -0.25) Uᵀ for a fixed U from
1307        // a small rotation, verify the clamped matrix has eigenvalues
1308        // (2, 0, 0.5, 0).
1309        let mut m = Array2::<f64>::zeros((4, 4));
1310        // Diagonal with mixed signs is sufficient for the test: the
1311        // eigenvalues equal the diagonal and the eigenvectors are e_i.
1312        m[[0, 0]] = 2.0;
1313        m[[1, 1]] = -1.0;
1314        m[[2, 2]] = 0.5;
1315        m[[3, 3]] = -0.25;
1316        let clamped = psd_clamp_4x4(&m);
1317        assert!((clamped[[0, 0]] - 2.0).abs() < 1e-12);
1318        assert!(clamped[[1, 1]].abs() < 1e-12);
1319        assert!((clamped[[2, 2]] - 0.5).abs() < 1e-12);
1320        assert!(clamped[[3, 3]].abs() < 1e-12);
1321    }
1322
1323    #[test]
1324    fn time_block_operator_evaluate_full_shape() {
1325        let n = 6;
1326        let p = 3;
1327        let dq0 = Array2::from_shape_fn((n, p), |(i, j)| (i + j) as f64);
1328        let dq1 = Array2::from_shape_fn((n, p), |(i, j)| (i as f64) * 2.0 + j as f64);
1329        let dqd1 = Array2::from_shape_fn((n, p), |(i, j)| 0.5 * ((i * j) as f64));
1330        let op = TimeBlockOperator::new(dq0.clone(), dq1.clone(), dqd1.clone());
1331        let full = op.evaluate_full();
1332        assert_eq!(full.shape(), &[n, p, K_SURVIVAL]);
1333        for i in 0..n {
1334            for j in 0..p {
1335                assert_eq!(full[[i, j, 0]], dq0[[i, j]]);
1336                assert_eq!(full[[i, j, 1]], dq1[[i, j]]);
1337                assert_eq!(full[[i, j, 2]], dqd1[[i, j]]);
1338                assert_eq!(full[[i, j, 3]], 0.0);
1339            }
1340        }
1341    }
1342
1343    #[test]
1344    fn q_channel_block_apply_row_shares_q0_q1() {
1345        let n = 5;
1346        let p = 2;
1347        let dq = Array2::from_shape_fn((n, p), |(i, j)| (i as f64) * (j as f64 + 1.0));
1348        let dqd1 = Array2::from_shape_fn((n, p), |(i, j)| (j as f64) - (i as f64));
1349        let op = QChannelBlockOperator::new(dq.clone(), dqd1.clone());
1350        let mut out = [0.0_f64; K_SURVIVAL];
1351        let delta = [1.0_f64, -0.5];
1352        op.apply_row(3, &delta, &mut out);
1353        let want_q = dq[[3, 0]] * 1.0 + dq[[3, 1]] * (-0.5);
1354        let want_qd = dqd1[[3, 0]] * 1.0 + dqd1[[3, 1]] * (-0.5);
1355        assert!((out[0] - want_q).abs() < 1e-12);
1356        assert!((out[1] - want_q).abs() < 1e-12);
1357        assert!((out[2] - want_qd).abs() < 1e-12);
1358        assert_eq!(out[3], 0.0);
1359    }
1360
1361    #[test]
1362    fn logslope_block_writes_only_g_channel() {
1363        let n = 4;
1364        let p = 2;
1365        let dg = Array2::from_shape_fn((n, p), |(i, j)| (i as f64) + 0.1 * (j as f64));
1366        let op = LogslopeBlockOperator::new(dg.clone());
1367        let mut out = [0.0_f64; K_SURVIVAL];
1368        let delta = [2.0_f64, -1.0];
1369        op.apply_row(1, &delta, &mut out);
1370        assert_eq!(out[0], 0.0);
1371        assert_eq!(out[1], 0.0);
1372        assert_eq!(out[2], 0.0);
1373        let want = dg[[1, 0]] * 2.0 + dg[[1, 1]] * (-1.0);
1374        assert!((out[3] - want).abs() < 1e-12);
1375    }
1376
1377    #[test]
1378    fn extract_term_partition_simple_cases() {
1379        let full = 0..5usize;
1380        // No penalties: whole block is one term.
1381        let part = extract_term_partition_from_penalty_ranges(5, &[]);
1382        assert_eq!(part.as_slice(), std::slice::from_ref(&full));
1383        // One penalty covering the whole block.
1384        let part = extract_term_partition_from_penalty_ranges(5, std::slice::from_ref(&full));
1385        assert_eq!(part.as_slice(), std::slice::from_ref(&full));
1386        // Two penalties with a gap: produces three terms (pen1, gap, pen2).
1387        let part = extract_term_partition_from_penalty_ranges(10, &[0..3, 6..10]);
1388        assert_eq!(part, vec![0..3, 3..6, 6..10]);
1389        // Duplicate penalty ranges coalesce.
1390        let part = extract_term_partition_from_penalty_ranges(6, &[0..3, 0..3, 3..6]);
1391        assert_eq!(part, vec![0..3, 3..6]);
1392        // Empty block.
1393        let part = extract_term_partition_from_penalty_ranges(0, &[]);
1394        assert!(part.is_empty());
1395    }
1396
1397    #[test]
1398    fn assemble_block_triangular_t_identity_when_v_eye_and_r_none() {
1399        let v_a = Array2::<f64>::eye(2);
1400        let v_b = Array2::<f64>::eye(2);
1401        let t = assemble_block_triangular_t(&[v_a, v_b], &[None, None]);
1402        assert_eq!(t.dim(), (4, 4));
1403        let eye4 = Array2::<f64>::eye(4);
1404        for i in 0..4 {
1405            for j in 0..4 {
1406                assert!((t[[i, j]] - eye4[[i, j]]).abs() < 1e-14);
1407            }
1408        }
1409    }
1410
1411    #[test]
1412    fn assemble_block_triangular_t_with_drops_and_nonzero_r() {
1413        let mut v_a = Array2::<f64>::zeros((3, 2));
1414        v_a[[0, 0]] = 1.0;
1415        v_a[[1, 0]] = 0.5;
1416        v_a[[2, 1]] = 1.0;
1417        let v_b = Array2::<f64>::eye(2);
1418        let r_ab =
1419            Array2::<f64>::from_shape_fn((3, 2), |(i, j)| 1.0 + (i as f64) + 0.25 * (j as f64));
1420        let t =
1421            assemble_block_triangular_t(&[v_a.clone(), v_b.clone()], &[None, Some(r_ab.clone())]);
1422        assert_eq!(t.dim(), (5, 4));
1423        for i in 0..3 {
1424            for j in 0..2 {
1425                assert!((t[[i, j]] - v_a[[i, j]]).abs() < 1e-14);
1426            }
1427        }
1428        for i in 0..2 {
1429            for j in 0..2 {
1430                assert!((t[[3 + i, 2 + j]] - v_b[[i, j]]).abs() < 1e-14);
1431            }
1432        }
1433        for i in 0..3 {
1434            for j in 0..2 {
1435                assert!((t[[i, 2 + j]] + r_ab[[i, j]]).abs() < 1e-14);
1436            }
1437        }
1438        for i in 0..2 {
1439            for j in 0..2 {
1440                assert_eq!(t[[3 + i, j]], 0.0);
1441            }
1442        }
1443    }
1444
1445    #[test]
1446    fn validate_partition_rejects_bad_partitions() {
1447        let bad_start = 1..5usize;
1448        let short_cover = 0..3usize;
1449        let full_cover = 0..5usize;
1450        // Doesn't start at 0.
1451        assert!(validate_partition(std::slice::from_ref(&bad_start), 5, "test").is_err());
1452        // Doesn't cover the block.
1453        assert!(validate_partition(std::slice::from_ref(&short_cover), 5, "test").is_err());
1454        // Has a gap.
1455        assert!(validate_partition(&[0..2, 3..5], 5, "test").is_err());
1456        // Has overlap.
1457        assert!(validate_partition(&[0..3, 2..5], 5, "test").is_err());
1458        // Has empty range.
1459        assert!(validate_partition(&[0..0, 0..5], 5, "test").is_err());
1460        // Empty block + empty partition OK.
1461        assert!(validate_partition(&[], 0, "test").is_ok());
1462        // Valid partition.
1463        assert!(validate_partition(&[0..2, 2..5], 5, "test").is_ok());
1464        assert!(validate_partition(std::slice::from_ref(&full_cover), 5, "test").is_ok());
1465    }
1466
1467    /// Regression for #368: the phase-4b compiled-map penalty pullback must
1468    /// emit a PER-BLOCK-WIDTH penalty for every block (sized to that block's
1469    /// COMPILED design width), even when a block drops columns and the
1470    /// triangular T carries nonzero off-diagonal cross-block residualisation
1471    /// `R_{a→b}`. The original bug pulled penalties back through the full
1472    /// joint T (`Tᵀ S T`), producing joint-compiled-width penalties (e.g.
1473    /// 7×7) that did not fit a single per-block `ParameterBlockSpec.penalties`
1474    /// slot (e.g. time block compiled width 3), making `validate_blockspecs`
1475    /// fail and `assert_valid_blockspecs` panic across the FFI boundary on
1476    /// ordinary survival data.
1477    #[test]
1478    fn compiled_map_penalty_pullback_is_per_block_width_with_nonzero_residual() {
1479        use gam_identifiability::families::compiler::CompiledMap;
1480        use gam_terms::smooth::BlockwisePenalty;
1481
1482        let n = 10;
1483        // Time raw 3 → compiled 3 (block 0: no anchor, V pure, R=None).
1484        // Marginal raw 3 → compiled 2 (a real drop, with nonzero R against time).
1485        // Logslope raw 2 → compiled 2 (nonzero R against time+marginal).
1486        let v_time =
1487            Array2::<f64>::from_shape_fn(
1488                (3, 3),
1489                |(i, j)| {
1490                    if i == j { 1.0 } else { 0.1 * ((i + j) as f64) }
1491                },
1492            );
1493        let v_marg = Array2::<f64>::from_shape_fn((3, 2), |(i, j)| {
1494            0.5 + 0.3 * (i as f64) - 0.2 * (j as f64)
1495        });
1496        let v_log = Array2::<f64>::from_shape_fn((2, 2), |(i, j)| if i == j { 1.2 } else { 0.4 });
1497        // R_marg: rows = time raw width 3, cols = marginal compiled width 2.
1498        let r_marg = Array2::<f64>::from_shape_fn((3, 2), |(i, j)| 0.7 - 0.1 * ((i + j) as f64));
1499        // R_log: rows = time+marg RAW width 6 (3 + 3), cols = logslope compiled
1500        // width 2. `assemble_block_triangular_t` stacks R_{a→b} over a<b, so the row
1501        // count is the sum of the RAW widths of the prior blocks (not their
1502        // compiled widths — marginal's compiled width is 2 but its raw width is 3).
1503        let r_log =
1504            Array2::<f64>::from_shape_fn((6, 2), |(i, j)| 0.3 + 0.05 * ((i * 2 + j) as f64));
1505
1506        let t = assemble_block_triangular_t(
1507            &[v_time.clone(), v_marg.clone(), v_log.clone()],
1508            &[None, Some(r_marg.clone()), Some(r_log.clone())],
1509        );
1510        assert_eq!(t.dim(), (8, 7), "joint raw 8 × joint compiled 7");
1511
1512        let map = CompiledMap {
1513            raw_from_compiled: t.clone(),
1514            compiled_block_ranges: vec![0..3, 3..5, 5..7],
1515            raw_block_ranges: vec![0..3, 3..6, 6..8],
1516        };
1517
1518        // Raw designs (dense, n rows).
1519        let raw_time_entry = DesignMatrix::Dense(DenseDesignMatrix::from(
1520            Array2::<f64>::from_shape_fn((n, 3), |(i, j)| 1.0 + (i as f64) * 0.1 + (j as f64)),
1521        ));
1522        let raw_time_exit = raw_time_entry.clone();
1523        let raw_time_deriv = raw_time_entry.clone();
1524        let raw_marg = DesignMatrix::Dense(DenseDesignMatrix::from(Array2::<f64>::from_shape_fn(
1525            (n, 3),
1526            |(i, j)| 0.2 * (i as f64) - 0.3 * (j as f64),
1527        )));
1528        let raw_log = DesignMatrix::Dense(DenseDesignMatrix::from(Array2::<f64>::from_shape_fn(
1529            (n, 2),
1530            |(i, j)| 0.5 + (i as f64) * (j as f64 + 1.0),
1531        )));
1532
1533        // Block-local penalties (col_range relative to each block's first col).
1534        let s_time =
1535            Array2::<f64>::from_shape_fn(
1536                (3, 3),
1537                |(i, j)| if i == j { (i + 2) as f64 } else { 0.3 },
1538            );
1539        let s_marg =
1540            Array2::<f64>::from_shape_fn(
1541                (3, 3),
1542                |(i, j)| if i == j { 1.5 + i as f64 } else { 0.2 },
1543            );
1544        let s_log = Array2::<f64>::from_shape_fn((2, 2), |(i, j)| if i == j { 2.0 } else { 0.5 });
1545        let time_pens = vec![BlockwisePenalty::new(0..3, s_time.clone())];
1546        let marg_pens = vec![BlockwisePenalty::new(0..3, s_marg.clone())];
1547        let log_pens = vec![BlockwisePenalty::new(0..2, s_log.clone())];
1548
1549        let out = apply_compiled_map_to_designs(
1550            &map,
1551            raw_time_entry,
1552            raw_time_exit,
1553            raw_time_deriv,
1554            raw_marg,
1555            raw_log,
1556            &time_pens,
1557            &marg_pens,
1558            &log_pens,
1559        )
1560        .expect("apply_compiled_map_to_designs must succeed");
1561
1562        // Designs carry per-block compiled widths.
1563        assert_eq!(out.time_design_entry.ncols(), 3);
1564        assert_eq!(out.marginal_design.ncols(), 2);
1565        assert_eq!(out.logslope_design.ncols(), 2);
1566
1567        // Core invariant the bug violated: every penalty is sized to ITS
1568        // OWN block's compiled width, NOT the joint compiled width (7).
1569        for s in &out.time_penalties {
1570            assert_eq!(
1571                s.as_dense_cow().dim(),
1572                (3, 3),
1573                "time penalty must be per-block 3×3, not joint-width"
1574            );
1575        }
1576        for s in &out.marginal_penalties {
1577            assert_eq!(
1578                s.as_dense_cow().dim(),
1579                (2, 2),
1580                "marginal penalty must match reduced compiled width 2, not joint 7"
1581            );
1582        }
1583        for s in &out.logslope_penalties {
1584            assert_eq!(s.as_dense_cow().dim(), (2, 2));
1585        }
1586
1587        // For the time block (block 0, no anchor ⇒ R=None), the per-block
1588        // pullback is EXACT: θ_timeᵀ P_time θ_time == γ_timeᵀ S_time γ_time
1589        // with γ_time = V_time · θ_time. Verify the quadratic-form identity.
1590        let p_time_dense = out.time_penalties[0].as_dense_cow().into_owned();
1591        let theta_time = Array1::<f64>::from_shape_fn(3, |k| 0.4 + 0.7 * (k as f64));
1592        let gamma_time = v_time.dot(&theta_time);
1593        let lhs = theta_time.dot(&p_time_dense.dot(&theta_time));
1594        let rhs = gamma_time.dot(&s_time.dot(&gamma_time));
1595        assert!(
1596            (lhs - rhs).abs() < 1e-10,
1597            "time-block per-block pullback must be exact: lhs={lhs}, rhs={rhs}"
1598        );
1599
1600        // The marginal pullback must equal V_margᵀ S_marg V_marg exactly
1601        // (block-local; the cross-block R_marg lives in the design, not here).
1602        let p_marg_dense = out.marginal_penalties[0].as_dense_cow().into_owned();
1603        let want_marg = v_marg.t().dot(&s_marg.dot(&v_marg));
1604        for i in 0..2 {
1605            for j in 0..2 {
1606                assert!(
1607                    (p_marg_dense[[i, j]] - want_marg[[i, j]]).abs() < 1e-12,
1608                    "marginal penalty must be V_margᵀ S_marg V_marg at ({i},{j})"
1609                );
1610            }
1611        }
1612    }
1613
1614    /// Top-level Phase-4b API test for the SMGS parametric path:
1615    /// call `compile_survival_parametric_designs` on a shared-constant
1616    /// alias between time and marginal, with an identity row Hessian.
1617    /// Verify the returned `v_*` matrices have the expected widths
1618    /// (time keeps all 3, marginal loses 1, logslope keeps both) and
1619    /// `drops_by_block` reports `(0, 1, 0)`.
1620    #[test]
1621    fn compile_survival_parametric_designs_helper_attributes_drop_to_marginal() {
1622        let n = 24;
1623        let p_time = 3;
1624        let p_marginal = 3;
1625        let p_logslope = 2;
1626        let x: Vec<f64> = (0..n)
1627            .map(|i| -1.0 + 2.0 * (i as f64) / (n as f64 - 1.0))
1628            .collect();
1629        let mut time_dq0 = Array2::<f64>::zeros((n, p_time));
1630        let mut time_dq1 = Array2::<f64>::zeros((n, p_time));
1631        let mut time_dqd1 = Array2::<f64>::zeros((n, p_time));
1632        let mut marg_dq = Array2::<f64>::zeros((n, p_marginal));
1633        let marg_dqd1 = Array2::<f64>::zeros((n, p_marginal));
1634        let mut log_dg = Array2::<f64>::zeros((n, p_logslope));
1635        for i in 0..n {
1636            time_dq0[[i, 0]] = 1.0;
1637            time_dq0[[i, 1]] = x[i];
1638            time_dq0[[i, 2]] = x[i] * x[i];
1639            time_dq1[[i, 0]] = 1.0;
1640            time_dq1[[i, 1]] = x[i];
1641            time_dq1[[i, 2]] = x[i] * x[i];
1642            time_dqd1[[i, 0]] = 0.0;
1643            time_dqd1[[i, 1]] = 1.0;
1644            time_dqd1[[i, 2]] = 2.0 * x[i];
1645            marg_dq[[i, 0]] = 1.0; // alias with time col 0
1646            marg_dq[[i, 1]] = x[i] * x[i] * x[i];
1647            marg_dq[[i, 2]] = x[i].sin();
1648            log_dg[[i, 0]] = (2.0 * x[i]).cos();
1649            log_dg[[i, 1]] = x[i].tanh();
1650        }
1651        let mut h_full = Array3::<f64>::zeros((n, K_SURVIVAL, K_SURVIVAL));
1652        for i in 0..n {
1653            for k in 0..K_SURVIVAL {
1654                h_full[[i, k, k]] = 1.0;
1655            }
1656        }
1657        let row_hess = SurvivalRowHessian::from_full(h_full);
1658        let out = compile_survival_parametric_designs(
1659            time_dq0, time_dq1, time_dqd1, marg_dq, marg_dqd1, log_dg, &row_hess,
1660        )
1661        .expect("Phase-4b parametric compile must succeed on single-direction alias");
1662        assert_eq!(out.v_time.ncols(), p_time, "time keeps all columns");
1663        assert_eq!(
1664            out.v_marginal.ncols(),
1665            p_marginal - 1,
1666            "marginal loses exactly the shared-constant direction"
1667        );
1668        assert_eq!(out.v_logslope.ncols(), p_logslope, "logslope is clean");
1669        assert_eq!(
1670            out.drops_by_block,
1671            (0, 1, 0),
1672            "attribution: zero from time/logslope, one from marginal",
1673        );
1674    }
1675
1676    /// End-to-end Phase-4b smoke test: build the full 3-block survival
1677    /// parametric operator stack (time + marginal + logslope) with a
1678    /// shared-constant alias seeded between the time and marginal
1679    /// blocks, feed it into `compile()` with an identity 4×4 row
1680    /// Hessian on every row, and verify the compiler:
1681    ///
1682    ///   (1) returns a [`CompiledBlocks`] with one block per input;
1683    ///   (2) preserves all 3 columns of the highest-priority `Time`
1684    ///       block in `t_lw` (the time block enters first in the
1685    ///       ordering, so its full column span survives);
1686    ///   (3) drops exactly one direction from `Marginal` (the
1687    ///       constant aliased with the time intercept), leaving its
1688    ///       remaining columns in `t_lw`;
1689    ///   (4) reports `joint_rank` = (raw_total - 1).
1690    ///
1691    /// This validates the Phase-4b construction-time orthogonalisation
1692    /// path on the survival K=4 row primary state and then feeds the
1693    /// compiled per-block reduced bases through the SMGS lift [`Gauge`]
1694    /// (step 6), asserting the lift's reduced/raw block structure agrees
1695    /// with the compiled rank-drop — the construction contract end to end.
1696    #[test]
1697    fn compile_survival_three_block_with_shared_constant_drops_one_direction() {
1698        use gam_identifiability::families::compiler::compile;
1699
1700        let n = 32;
1701        let p_time = 3;
1702        let p_marginal = 3;
1703        let p_logslope = 2;
1704
1705        // Time block:
1706        //   col 0 = ones (the shared constant — aliases marginal col 0);
1707        //   col 1 = linear x;
1708        //   col 2 = quadratic x².
1709        // q0/q1 share the same design (so the alias surfaces in both
1710        // the entry and exit primary channels); qd1 is the derivative
1711        // of the design w.r.t. time at the exit point, which for the
1712        // constant column is exactly zero (the gauge identity that
1713        // makes the constant a true null direction under (q0, q1, qd1)
1714        // joint).
1715        let x: Vec<f64> = (0..n)
1716            .map(|i| -1.0 + 2.0 * (i as f64) / (n as f64 - 1.0))
1717            .collect();
1718        let mut time_dq0 = Array2::<f64>::zeros((n, p_time));
1719        let mut time_dq1 = Array2::<f64>::zeros((n, p_time));
1720        let mut time_dqd1 = Array2::<f64>::zeros((n, p_time));
1721        for i in 0..n {
1722            time_dq0[[i, 0]] = 1.0;
1723            time_dq0[[i, 1]] = x[i];
1724            time_dq0[[i, 2]] = x[i] * x[i];
1725            time_dq1[[i, 0]] = 1.0;
1726            time_dq1[[i, 1]] = x[i];
1727            time_dq1[[i, 2]] = x[i] * x[i];
1728            // d/dt of a constant = 0; d/dt of x ≡ 1; d/dt of x² ≡ 2x.
1729            time_dqd1[[i, 0]] = 0.0;
1730            time_dqd1[[i, 1]] = 1.0;
1731            time_dqd1[[i, 2]] = 2.0 * x[i];
1732        }
1733
1734        // Marginal block (q-channel only; qd1 contribution zero — no
1735        // timewiggle in this scenario):
1736        //   col 0 = ones (the shared constant);
1737        //   col 1 = x³;
1738        //   col 2 = sin(x).
1739        let mut marg_dq = Array2::<f64>::zeros((n, p_marginal));
1740        let marg_dqd1 = Array2::<f64>::zeros((n, p_marginal));
1741        for i in 0..n {
1742            marg_dq[[i, 0]] = 1.0;
1743            marg_dq[[i, 1]] = x[i] * x[i] * x[i];
1744            marg_dq[[i, 2]] = x[i].sin();
1745        }
1746
1747        // Logslope block (g-channel only):
1748        //   col 0 = cos(2x);
1749        //   col 1 = tanh(x).  (no shared constant — logslope is clean)
1750        let mut log_dg = Array2::<f64>::zeros((n, p_logslope));
1751        for i in 0..n {
1752            log_dg[[i, 0]] = (2.0 * x[i]).cos();
1753            log_dg[[i, 1]] = x[i].tanh();
1754        }
1755
1756        let inputs = build_survival_compiler_inputs(
1757            time_dq0, time_dq1, time_dqd1, marg_dq, marg_dqd1, log_dg, None, None,
1758        );
1759
1760        // Identity 4×4 row Hessian on every row. With H_i = I the
1761        // sqrt-H metric collapses to the standard Frobenius metric,
1762        // so the compiler's residualisation is ordinary least-squares
1763        // projection — exactly what we want for verifying the
1764        // structural rank-deficiency attribution.
1765        let mut h_full = Array3::<f64>::zeros((n, K_SURVIVAL, K_SURVIVAL));
1766        for i in 0..n {
1767            for k in 0..K_SURVIVAL {
1768                h_full[[i, k, k]] = 1.0;
1769            }
1770        }
1771        let row_hess = SurvivalRowHessian::from_full(h_full);
1772
1773        let compiled = compile(&inputs.operators, &row_hess, &inputs.ordering)
1774            .expect("survival 3-block compile must succeed; aliasing is single-direction");
1775
1776        // (1) One CompiledBlock per input.
1777        assert_eq!(compiled.blocks.len(), 3, "expected 3 CompiledBlocks");
1778
1779        // (2) Time enters first; under sqrt-I metric every column of
1780        // the time block is residual-vs-empty-anchor and therefore
1781        // survives the eigendecomposition with positive eigenvalue.
1782        // V_time has p_time columns.
1783        let v_time = &compiled.blocks[0].t_lw;
1784        assert_eq!(
1785            v_time.ncols(),
1786            p_time,
1787            "time block (first in ordering) must retain all {p_time} of its columns; V_time={:?}",
1788            v_time.dim(),
1789        );
1790
1791        // (3) Marginal enters second. Its constant column is aliased
1792        // with time's constant column in (q0, q1) and contributes zero
1793        // to qd1. After residualising against the time anchor in the
1794        // K=4 stacked metric, the residual Gram has rank
1795        // p_marginal − 1 (one direction collapsed by the alias). So
1796        // V_marginal has exactly (p_marginal − 1) columns.
1797        let v_marg = &compiled.blocks[1].t_lw;
1798        assert_eq!(
1799            v_marg.ncols(),
1800            p_marginal - 1,
1801            "marginal block must lose exactly the shared-constant direction; \
1802             V_marginal cols = {}, expected {}",
1803            v_marg.ncols(),
1804            p_marginal - 1,
1805        );
1806
1807        // (4) Logslope enters third and carries no shared direction
1808        // with time or marginal in the g-channel. Both columns survive.
1809        let v_log = &compiled.blocks[2].t_lw;
1810        assert_eq!(
1811            v_log.ncols(),
1812            p_logslope,
1813            "logslope block (no shared direction) must retain all {p_logslope} columns",
1814        );
1815
1816        // (5) Joint rank consistency: sum of compiled column counts
1817        // equals raw_total minus the one aliased direction.
1818        let raw_total = p_time + p_marginal + p_logslope;
1819        let kept_total: usize = compiled.blocks.iter().map(|b| b.t_lw.ncols()).sum();
1820        assert_eq!(
1821            kept_total,
1822            raw_total - 1,
1823            "joint kept = raw_total − aliased; got {kept_total}, expected {}",
1824            raw_total - 1,
1825        );
1826        assert_eq!(
1827            compiled.joint_rank, kept_total,
1828            "CompiledBlocks::joint_rank must match the sum of per-block t_lw widths",
1829        );
1830
1831        // (6) SMGS construction contract. Feed the compiled per-block reduced
1832        // bases (V_k = t_lw, shaped raw_k × kept_k) into the SMGS lift `Gauge`
1833        // and verify the lift's coordinate bookkeeping matches the compiler's
1834        // rank attribution: the reduced dimension equals `joint_rank`, the
1835        // reduced block boundaries advance by each block's kept width, and —
1836        // with R = None (no residualised cross-block reparam in this V-only
1837        // construction) — the raw block boundaries advance by each block's raw
1838        // width. This exercises the SMGS construction hook directly on the
1839        // compiled output rather than asserting against a hypothetical shape.
1840        let v_per_term: Vec<Array2<f64>> = compiled.blocks.iter().map(|b| b.t_lw.clone()).collect();
1841        let r_per_term: Vec<Option<Array2<f64>>> = vec![None; v_per_term.len()];
1842        let gauge = Gauge::from_v_and_r(&v_per_term, &r_per_term);
1843
1844        let mut expected_reduced = vec![0usize];
1845        let mut expected_raw = vec![0usize];
1846        for b in &compiled.blocks {
1847            let prev_reduced = *expected_reduced.last().unwrap();
1848            expected_reduced.push(prev_reduced + b.t_lw.ncols());
1849            let prev_raw = *expected_raw.last().unwrap();
1850            expected_raw.push(prev_raw + b.t_lw.nrows());
1851        }
1852        assert_eq!(
1853            *gauge.block_starts_reduced.last().unwrap(),
1854            compiled.joint_rank,
1855            "SMGS lift reduced dimension must equal the compiled joint_rank",
1856        );
1857        assert_eq!(
1858            gauge.block_starts_reduced, expected_reduced,
1859            "SMGS lift reduced block boundaries must match the compiled kept widths",
1860        );
1861        assert_eq!(
1862            gauge.block_starts_raw, expected_raw,
1863            "SMGS lift raw block boundaries must match the compiled per-block raw widths",
1864        );
1865
1866        // (7) Every kept direction is finite and non-degenerate. A retained
1867        // column with a zero or non-finite norm would be a spurious rank
1868        // contribution that the count-only checks above cannot catch, so verify
1869        // each compiled block's surviving directions directly.
1870        for (bi, block) in compiled.blocks.iter().enumerate() {
1871            for j in 0..block.t_lw.ncols() {
1872                let col = block.t_lw.column(j);
1873                assert!(
1874                    col.iter().all(|v| v.is_finite()),
1875                    "block {bi} kept direction {j} has a non-finite entry",
1876                );
1877                let norm = col.dot(&col).sqrt();
1878                assert!(
1879                    norm > 1e-10,
1880                    "block {bi} kept direction {j} is degenerate (norm {norm:.3e})",
1881                );
1882            }
1883        }
1884    }
1885
1886    /// `T = I` case: per-block V = identity, R = None. The triangular
1887    /// lift must be the identity on each block.
1888    #[test]
1889    fn smgs_lift_via_t_identity_passes_through() {
1890        let v0 = Array2::<f64>::eye(3);
1891        let v1 = Array2::<f64>::eye(2);
1892        let v_per_term = vec![v0, v1];
1893        let r_per_term: Vec<Option<Array2<f64>>> = vec![None, None];
1894        let lift = Gauge::from_v_and_r(&v_per_term, &r_per_term);
1895        assert_eq!(lift.t_full.dim(), (5, 5));
1896        assert_eq!(lift.block_starts_reduced, vec![0, 3, 5]);
1897        assert_eq!(lift.block_starts_raw, vec![0, 3, 5]);
1898        for i in 0..5 {
1899            for j in 0..5 {
1900                let want = if i == j { 1.0 } else { 0.0 };
1901                assert!((lift.t_full[[i, j]] - want).abs() < 1e-14);
1902            }
1903        }
1904        let theta_0 = Array1::from(vec![1.0_f64, -2.0, 3.5]);
1905        let theta_1 = Array1::from(vec![-0.5_f64, 7.0]);
1906        let lifted = lift.lift_block_betas(&[theta_0.clone(), theta_1.clone()]);
1907        assert_eq!(lifted.len(), 2);
1908        for (a, b) in theta_0.iter().zip(lifted[0].iter()) {
1909            assert!((a - b).abs() < 1e-14);
1910        }
1911        for (a, b) in theta_1.iter().zip(lifted[1].iter()) {
1912            assert!((a - b).abs() < 1e-14);
1913        }
1914    }
1915
1916    /// Two-block toy: V_a = I_3, V_b drops the middle column, R is a
1917    /// non-trivial residualised reparam. Verify β_a_raw = θ_a − R · θ_b
1918    /// and β_b_raw = V_b · θ_b.
1919    #[test]
1920    fn smgs_lift_via_t_two_block_with_residualisation() {
1921        let v_a = Array2::<f64>::eye(3);
1922        let mut v_b = Array2::<f64>::zeros((3, 2));
1923        v_b[[0, 0]] = 1.0;
1924        v_b[[2, 1]] = 1.0;
1925        let mut r_b = Array2::<f64>::zeros((3, 2));
1926        r_b[[0, 0]] = 0.4;
1927        r_b[[0, 1]] = -0.1;
1928        r_b[[1, 0]] = 0.7;
1929        r_b[[1, 1]] = 1.3;
1930        r_b[[2, 0]] = -0.2;
1931        r_b[[2, 1]] = 0.5;
1932        let lift = Gauge::from_v_and_r(&[v_a.clone(), v_b.clone()], &[None, Some(r_b.clone())]);
1933        assert_eq!(lift.t_full.dim(), (6, 5));
1934        assert_eq!(lift.block_starts_reduced, vec![0, 3, 5]);
1935        assert_eq!(lift.block_starts_raw, vec![0, 3, 6]);
1936
1937        let theta_a = Array1::from(vec![1.0_f64, 2.0, -1.5]);
1938        let theta_b = Array1::from(vec![0.5_f64, -0.25]);
1939        let lifted = lift.lift_block_betas(&[theta_a.clone(), theta_b.clone()]);
1940        let r_theta_b = r_b.dot(&theta_b);
1941        let expected_a = &theta_a - &r_theta_b;
1942        assert_eq!(lifted[0].len(), 3);
1943        for (got, want) in lifted[0].iter().zip(expected_a.iter()) {
1944            assert!((got - want).abs() < 1e-12, "got {got}, want {want}");
1945        }
1946        assert_eq!(lifted[1].len(), 3);
1947        assert!((lifted[1][0] - theta_b[0]).abs() < 1e-12);
1948        assert!(lifted[1][1].abs() < 1e-12);
1949        assert!((lifted[1][2] - theta_b[1]).abs() < 1e-12);
1950    }
1951
1952    /// Covariance pushforward `Σ_raw = T · Σ_θ · Tᵀ` must be the exact
1953    /// inference companion of the point-estimate lift. Two invariants:
1954    ///
1955    /// 1. Identity T (V = I, R = None): the lifted covariance equals the
1956    ///    input covariance — a true no-op for a rank-clean fit.
1957    /// 2. Rank-1 consistency with the β lift: for a degenerate posterior
1958    ///    `Σ_θ = θ θᵀ`, the pushforward must equal `(T θ)(T θ)ᵀ`, i.e.
1959    ///    lifting the covariance of a point mass agrees with lifting the
1960    ///    point itself. This couples `lift_covariance` to
1961    ///    `lift_block_betas` exactly, so the mean and its
1962    ///    uncertainty can never drift into inconsistent coordinates.
1963    #[test]
1964    fn smgs_lift_covariance_identity_and_rank1_consistency() {
1965        // ── Invariant 1: identity T leaves the covariance unchanged. ──
1966        let lift_id = Gauge::from_v_and_r(
1967            &[Array2::<f64>::eye(2), Array2::<f64>::eye(2)],
1968            &[None, None],
1969        );
1970        let mut cov = Array2::<f64>::zeros((4, 4));
1971        // An arbitrary symmetric PSD-ish covariance.
1972        for i in 0..4 {
1973            for j in 0..4 {
1974                cov[[i, j]] = 1.0 / (1.0 + (i as f64 - j as f64).abs());
1975            }
1976        }
1977        let lifted_id = lift_id.lift_covariance(&cov);
1978        assert_eq!(lifted_id.dim(), (4, 4));
1979        for i in 0..4 {
1980            for j in 0..4 {
1981                assert!(
1982                    (lifted_id[[i, j]] - cov[[i, j]]).abs() < 1e-12,
1983                    "identity-T covariance lift must be a no-op at [{i},{j}]",
1984                );
1985            }
1986        }
1987
1988        // ── Invariant 2: rank-1 Σ_θ = θθᵀ pushes to (Tθ)(Tθ)ᵀ. ──
1989        // Reuse the two-block-with-residualisation geometry: V_a = I_3,
1990        // V_b drops the middle raw column, R_b non-trivial → raw width 6,
1991        // compiled width 5.
1992        let v_a = Array2::<f64>::eye(3);
1993        let mut v_b = Array2::<f64>::zeros((3, 2));
1994        v_b[[0, 0]] = 1.0;
1995        v_b[[2, 1]] = 1.0;
1996        let mut r_b = Array2::<f64>::zeros((3, 2));
1997        r_b[[0, 0]] = 0.4;
1998        r_b[[0, 1]] = -0.1;
1999        r_b[[1, 0]] = 0.7;
2000        r_b[[1, 1]] = 1.3;
2001        r_b[[2, 0]] = -0.2;
2002        r_b[[2, 1]] = 0.5;
2003        let lift = Gauge::from_v_and_r(&[v_a, v_b], &[None, Some(r_b)]);
2004
2005        let theta_a = Array1::from(vec![1.0_f64, 2.0, -1.5]);
2006        let theta_b = Array1::from(vec![0.5_f64, -0.25]);
2007        // Concatenated compiled θ (width 5).
2008        let theta_full = Array1::from(vec![
2009            theta_a[0], theta_a[1], theta_a[2], theta_b[0], theta_b[1],
2010        ]);
2011        // Σ_θ = θ θᵀ (rank-1).
2012        let mut cov_rank1 = Array2::<f64>::zeros((5, 5));
2013        for i in 0..5 {
2014            for j in 0..5 {
2015                cov_rank1[[i, j]] = theta_full[i] * theta_full[j];
2016            }
2017        }
2018        let lifted_cov = lift.lift_covariance(&cov_rank1);
2019        // Reference: (T θ)(T θ)ᵀ via the point-estimate lift.
2020        let lifted_blocks = lift.lift_block_betas(&[theta_a, theta_b]);
2021        let beta_raw = Array1::from(
2022            lifted_blocks
2023                .iter()
2024                .flat_map(|b| b.iter().copied())
2025                .collect::<Vec<f64>>(),
2026        );
2027        assert_eq!(lifted_cov.dim(), (6, 6));
2028        assert_eq!(beta_raw.len(), 6);
2029        for i in 0..6 {
2030            for j in 0..6 {
2031                let want = beta_raw[i] * beta_raw[j];
2032                assert!(
2033                    (lifted_cov[[i, j]] - want).abs() < 1e-10,
2034                    "rank-1 covariance pushforward must equal (Tθ)(Tθ)ᵀ at [{i},{j}]: got {}, want {want}",
2035                    lifted_cov[[i, j]],
2036                );
2037            }
2038        }
2039        // Symmetry sanity.
2040        for i in 0..6 {
2041            for j in 0..6 {
2042                assert!((lifted_cov[[i, j]] - lifted_cov[[j, i]]).abs() < 1e-14);
2043            }
2044        }
2045    }
2046
2047    /// When all R's are None, the triangular gauge lift must equal the
2048    /// strictly per-block `V_b · θ_b` lift.
2049    #[test]
2050    fn smgs_lift_via_t_zero_r_matches_per_block_v_lift() {
2051        let mut v_a = Array2::<f64>::zeros((3, 2));
2052        v_a[[0, 0]] = 0.6;
2053        v_a[[1, 0]] = -0.8;
2054        v_a[[1, 1]] = 0.3;
2055        v_a[[2, 1]] = 0.9;
2056        let mut v_b = Array2::<f64>::zeros((4, 3));
2057        v_b[[0, 0]] = 1.0;
2058        v_b[[1, 1]] = -0.4;
2059        v_b[[2, 0]] = 0.2;
2060        v_b[[2, 2]] = 0.7;
2061        v_b[[3, 2]] = -1.1;
2062        let v_per_term = vec![v_a.clone(), v_b.clone()];
2063        let lift = Gauge::from_v_and_r(&v_per_term, &[None, None]);
2064        let theta_a = Array1::from(vec![0.3_f64, -1.4]);
2065        let theta_b = Array1::from(vec![2.1_f64, 0.0, -0.7]);
2066        let via_t = lift.lift_block_betas(&[theta_a.clone(), theta_b.clone()]);
2067        let ref_a = v_a.dot(&theta_a);
2068        let ref_b = v_b.dot(&theta_b);
2069        assert_eq!(via_t[0].len(), ref_a.len());
2070        for (g, w) in via_t[0].iter().zip(ref_a.iter()) {
2071            assert!((g - w).abs() < 1e-12);
2072        }
2073        assert_eq!(via_t[1].len(), ref_b.len());
2074        for (g, w) in via_t[1].iter().zip(ref_b.iter()) {
2075            assert!((g - w).abs() < 1e-12);
2076        }
2077    }
2078
2079    /// Recompile-after-first-PIRLS-accept refinement: under a structural
2080    /// (identity) row Hessian, a direction that is *only* identifiable
2081    /// through the q1 channel survives the per-term compile; under a
2082    /// data-adaptive row Hessian that happens to zero out the q1/qd1/g
2083    /// metric weight (everything except q0), the same direction collapses.
2084    /// This pins the diagnostic the production hook in
2085    /// `fit_survival_marginal_slope_terms` watches for: the two row
2086    /// Hessians produce different `drops_by_block` on identical raw
2087    /// designs.
2088    #[test]
2089    fn recompile_after_accept_diff_detection_pilot_curvature_trap() {
2090        let n = 6usize;
2091        // Time block: a single column that only contributes through q0
2092        // (entry-time channel). Both row Hessians see it identically on
2093        // the q0 axis.
2094        let time_dq0 = Array2::<f64>::from_elem((n, 1), 1.0);
2095        let time_dq1 = Array2::<f64>::zeros((n, 1));
2096        let time_dqd1 = Array2::<f64>::zeros((n, 1));
2097        // Marginal block: a single column whose q0 part is colinear with
2098        // the time block's q0 (both are ones-vectors). Its q-channel maps
2099        // into BOTH q0 and q1 under QChannelBlockOperator, so under a
2100        // metric that weighs q1 it carries a non-colinear component.
2101        let marg_dq = Array2::<f64>::from_elem((n, 1), 1.0);
2102        let marg_dqd1 = Array2::<f64>::zeros((n, 1));
2103        // No logslope columns.
2104        let log_dg = Array2::<f64>::zeros((n, 0));
2105        let mut time_partition: Vec<std::ops::Range<usize>> = Vec::with_capacity(1);
2106        time_partition.push(0..1);
2107        let mut marg_partition: Vec<std::ops::Range<usize>> = Vec::with_capacity(1);
2108        marg_partition.push(0..1);
2109        let log_partition: Vec<std::ops::Range<usize>> = Vec::new();
2110
2111        // Pass 1: structural identity row Hessian. q0/q1/qd1/g all weighted
2112        // equally → marg's q1 component is visible, so marg is identifiable
2113        // after residualising against the time block (drops_marg = 0).
2114        let mut h_ident = Array3::<f64>::zeros((n, K_SURVIVAL, K_SURVIVAL));
2115        for i in 0..n {
2116            for k in 0..K_SURVIVAL {
2117                h_ident[[i, k, k]] = 1.0;
2118            }
2119        }
2120        let row_hess_ident = SurvivalRowHessian::from_full(h_ident);
2121        let compiled_ident = compile_survival_parametric_designs_per_term(
2122            time_dq0.clone(),
2123            time_dq1.clone(),
2124            time_dqd1.clone(),
2125            &time_partition,
2126            marg_dq.clone(),
2127            marg_dqd1.clone(),
2128            &marg_partition,
2129            log_dg.clone(),
2130            &log_partition,
2131            &row_hess_ident,
2132        )
2133        .expect("identity-H compile must succeed");
2134
2135        // Pass 2: data-adaptive row Hessian that only weighs q0 (all
2136        // other channel diagonals zero). Marg's q1 contribution is now
2137        // invisible → marg fully aliases with time on q0 → drops_marg = 1.
2138        let mut h_q0_only = Array3::<f64>::zeros((n, K_SURVIVAL, K_SURVIVAL));
2139        for i in 0..n {
2140            h_q0_only[[i, 0, 0]] = 1.0;
2141        }
2142        let row_hess_q0 = SurvivalRowHessian::from_full(h_q0_only);
2143        let compiled_q0 = compile_survival_parametric_designs_per_term(
2144            time_dq0,
2145            time_dq1,
2146            time_dqd1,
2147            &time_partition,
2148            marg_dq,
2149            marg_dqd1,
2150            &marg_partition,
2151            log_dg,
2152            &log_partition,
2153            &row_hess_q0,
2154        )
2155        .expect("q0-only-H compile must succeed");
2156
2157        // The two drops_by_block tuples disagree on the marginal block —
2158        // this is exactly the "pilot-curvature trap" the recompile-after-
2159        // accept hook is designed to surface.
2160        assert_ne!(
2161            compiled_ident.drops_by_block, compiled_q0.drops_by_block,
2162            "structural-H and data-adaptive-H compiles must produce different \
2163             drops_by_block on the constructed pilot-curvature-trap design; \
2164             identity={:?} q0-only={:?}",
2165            compiled_ident.drops_by_block, compiled_q0.drops_by_block,
2166        );
2167        // Under identity H, marg survives (no drop).
2168        assert_eq!(
2169            compiled_ident.drops_by_block.1, 0,
2170            "identity-H marg drops expected 0, got {:?}",
2171            compiled_ident.drops_by_block,
2172        );
2173        // Under q0-only H, marg fully aliases with time on q0.
2174        assert_eq!(
2175            compiled_q0.drops_by_block.1, 1,
2176            "q0-only-H marg drops expected 1, got {:?}",
2177            compiled_q0.drops_by_block,
2178        );
2179    }
2180
2181    #[test]
2182    fn compiled_map_from_per_term_partitions_and_lift_round_trip() {
2183        // Build a per-term compile by hand: time has one term (raw 2, kept 2),
2184        // marginal one term (raw 2, kept 1 — a drop), logslope one term
2185        // (raw 1, kept 1). No required channel is fully collapsed.
2186        let v_time = Array2::<f64>::eye(2);
2187        let mut v_marg = Array2::<f64>::zeros((2, 1));
2188        v_marg[[0, 0]] = 1.0;
2189        v_marg[[1, 0]] = 0.5;
2190        let v_log = Array2::<f64>::eye(1);
2191        // R for the marginal block (anchor = time, raw width 2) and logslope
2192        // block (anchors = time + marginal, raw width 2 + 2 = 4).
2193        let r_marg = Array2::<f64>::from_shape_fn((2, 1), |(i, _)| 0.25 + i as f64);
2194        let r_log = Array2::<f64>::from_shape_fn((4, 1), |(i, _)| 0.1 * (i as f64 + 1.0));
2195        let per_term = SurvivalParametricCompiledPerTerm {
2196            v_time_per_term: vec![v_time.clone()],
2197            v_marginal_per_term: vec![v_marg.clone()],
2198            v_logslope_per_term: vec![v_log.clone()],
2199            r_lw_per_term: vec![None, Some(r_marg.clone()), Some(r_log.clone())],
2200            drops_by_block: (0, 1, 0),
2201        };
2202
2203        let map = compiled_map_from_per_term(&per_term);
2204
2205        // Raw block ranges: time 0..2, marginal 2..4, logslope 4..5.
2206        assert_eq!(map.raw_block_ranges, vec![0..2, 2..4, 4..5]);
2207        // Compiled block ranges: time 0..2, marginal 2..3, logslope 3..4.
2208        assert_eq!(map.compiled_block_ranges, vec![0..2, 2..3, 3..4]);
2209        assert_eq!(map.raw_from_compiled.dim(), (5, 4));
2210
2211        // The block-diagonal slices recovered by apply_compiled_map_to_designs
2212        // must equal the per-term V's exactly.
2213        let v_time_slice = map
2214            .raw_from_compiled
2215            .slice(ndarray::s![0..2, 0..2])
2216            .to_owned();
2217        let v_marg_slice = map
2218            .raw_from_compiled
2219            .slice(ndarray::s![2..4, 2..3])
2220            .to_owned();
2221        let v_log_slice = map
2222            .raw_from_compiled
2223            .slice(ndarray::s![4..5, 3..4])
2224            .to_owned();
2225        for i in 0..2 {
2226            for j in 0..2 {
2227                assert!((v_time_slice[[i, j]] - v_time[[i, j]]).abs() < 1e-14);
2228            }
2229            assert!((v_marg_slice[[i, 0]] - v_marg[[i, 0]]).abs() < 1e-14);
2230        }
2231        assert!((v_log_slice[[0, 0]] - v_log[[0, 0]]).abs() < 1e-14);
2232
2233        // The cross-block carry (-R) must sit in the strict upper triangle, so
2234        // the map agrees with the lift assembled directly from V and R.
2235        let ordering = [
2236            gam_identifiability::families::compiler::BlockOrder::Time,
2237            gam_identifiability::families::compiler::BlockOrder::Marginal,
2238            gam_identifiability::families::compiler::BlockOrder::Logslope,
2239        ];
2240        let lift_from_map = Gauge::from_compiled_map(&map, &ordering);
2241        let v_all = vec![v_time, v_marg, v_log];
2242        let lift_direct = Gauge::from_v_and_r(&v_all, &[None, Some(r_marg), Some(r_log)]);
2243        assert_eq!(lift_from_map.t_full.dim(), lift_direct.t_full.dim());
2244        for i in 0..lift_from_map.t_full.nrows() {
2245            for j in 0..lift_from_map.t_full.ncols() {
2246                assert!(
2247                    (lift_from_map.t_full[[i, j]] - lift_direct.t_full[[i, j]]).abs() < 1e-14,
2248                    "T mismatch at ({i},{j}): map={} direct={}",
2249                    lift_from_map.t_full[[i, j]],
2250                    lift_direct.t_full[[i, j]],
2251                );
2252            }
2253        }
2254    }
2255}