Skip to main content

gam_models/survival/marginal_slope/
identifiability.rs

1//! Survival marginal-slope concrete impls for the family-agnostic
2//! identifiability compiler (`gam_identifiability::families::compiler`).
3//!
4//! Survival's row primary state is the 4-vector `u_i = (q0, q1, qd1, g)`,
5//! so `K = 4`. The row Hessian is the 4×4 second-derivative block of the
6//! per-row neg-log-likelihood kernel `row_primary_closed_form` at a pilot
7//! `β`, PSD-clamped via eigendecomposition (negative eigenvalues projected
8//! to zero) to handle pilot points far from the optimum.
9//!
10//! Each block exposes its row Jacobian as the contribution of `δβ_block`
11//! to the row primary-state vector:
12//!
13//! - **TimeBlockOperator**: `(δq0, δq1, δqd1, 0)` from `design_entry`,
14//!   `design_exit`, `design_derivative_exit` rows.
15//! - **MarginalBlockOperator**: `(δq, δq, δqd_marginal, 0)` from the
16//!   marginal design row (shared by q0 and q1; qd contribution zero unless
17//!   timewiggle is active — captured by an explicit derivative row matrix).
18//! - **LogslopeBlockOperator**: `(0, 0, 0, δg)` from the logslope design.
19//! - **ScoreWarpBlockOperator**: `(δq, δq, δqd_warp, 0)` from the warp
20//!   basis (shifts q at entry/exit; chain rule via dq0_seed/dt for qd1).
21//! - **LinkDevBlockOperator**: `(δq, δq, δqd_link, 0)` from the link-dev
22//!   basis on the rigid/pilot q-seed.
23//!
24//! Phase 4a delivery: trait impls + an input-builder helper. Phase 4b
25//! threads these through SMGS's construction site and the migrated pilot
26//! β; Phase 4c deletes the legacy
27//! `install_compiled_flex_block_into_runtime` path.
28
29use std::sync::Arc;
30
31use ndarray::{Array1, Array2, Array3};
32
33use gam_identifiability::families::compiler::{
34    BlockOrder, RowHessian, RowJacobianOperator, scale_jacobian_by_sqrt_h_with,
35};
36use gam_problem::gauge::assemble_block_triangular_t;
37use faer::Side;
38use gam_linalg::faer_ndarray::FaerEigh;
39use gam_linalg::matrix::{CoefficientTransformOperator, DenseDesignMatrix, DesignMatrix};
40use gam_problem::{FamilyChannelHessian, PenaltyMatrix};
41
/// Number of channels in the survival row primary state `(q0, q1, qd1, g)`;
/// every row Hessian in this module is `K_SURVIVAL × K_SURVIVAL`.
const K_SURVIVAL: usize = 4;

/// Absolute-value cutoff separating the trivial (all-zero) pilot point from
/// a "non-trivial" coefficient vector in the drift-detection audit. At β ≈ 0
/// the primary-state coupling g vanishes (c ≡ 1), so the frozen pilot W is
/// exact; once any |β_j| exceeds this cutoff the family scalars are needed
/// to re-evaluate W(β). The cutoff sits well below any meaningful fitted
/// coefficient.
const BETA_NONTRIVIAL_ABS_THRESHOLD: f64 = 1e-12;
50
51/// Per-row 4×4 row Hessian for the survival marginal-slope likelihood at a
52/// pilot `β`. The pilot supplies the primary-state vector
53/// `(q0_i, q1_i, qd1_i, g_i)` and the per-row sample weight + event
54/// indicator + z + probit scale. The 4×4 block is evaluated via the
55/// existing `row_primary_closed_form` kernel (which already returns the
56/// full Hessian in `(q0, q1, qd1, g)` order) and PSD-clamped per row.
57pub struct SurvivalRowHessian {
58    /// PSD-projected per-row 4×4 Hessian, stored row-major as
59    /// `(n × 4 × 4)`.
60    h: Array3<f64>,
61}
62
63impl SurvivalRowHessian {
64    /// Construct from explicit per-row pilot primary-state and the row
65    /// data needed by `row_primary_closed_form`. Negative eigenvalues are
66    /// projected to zero before storage so the matrix is PSD.
67    pub fn from_pilot_primary_state(
68        q0: &Array1<f64>,
69        q1: &Array1<f64>,
70        qd1: &Array1<f64>,
71        g: &Array1<f64>,
72        z: &Array1<f64>,
73        weights: &Array1<f64>,
74        event: &Array1<f64>,
75        derivative_guard: f64,
76        probit_scale: f64,
77    ) -> Result<Self, String> {
78        let n = q0.len();
79        if [
80            q1.len(),
81            qd1.len(),
82            g.len(),
83            z.len(),
84            weights.len(),
85            event.len(),
86        ]
87        .iter()
88        .any(|&l| l != n)
89        {
90            return Err(format!(
91                "SurvivalRowHessian: length mismatch \
92                 q0={n}, q1={}, qd1={}, g={}, z={}, weights={}, event={}",
93                q1.len(),
94                qd1.len(),
95                g.len(),
96                z.len(),
97                weights.len(),
98                event.len()
99            ));
100        }
101        let mut h_full = Array3::<f64>::zeros((n, K_SURVIVAL, K_SURVIVAL));
102        for i in 0..n {
103            let (_, _grad, hess) =
104                crate::survival::marginal_slope::row_primary_for_compiler(
105                    q0[i],
106                    q1[i],
107                    qd1[i],
108                    g[i],
109                    z[i],
110                    weights[i],
111                    event[i],
112                    derivative_guard,
113                    probit_scale,
114                )?;
115            // PSD-clamp via eigendecomposition: project negative eigvals to 0.
116            let mut h_i = Array2::<f64>::zeros((K_SURVIVAL, K_SURVIVAL));
117            for a in 0..K_SURVIVAL {
118                for b in 0..K_SURVIVAL {
119                    h_i[[a, b]] = hess[a][b];
120                }
121            }
122            let clamped = psd_clamp_4x4(&h_i);
123            for a in 0..K_SURVIVAL {
124                for b in 0..K_SURVIVAL {
125                    h_full[[i, a, b]] = clamped[[a, b]];
126                }
127            }
128        }
129        Ok(Self { h: h_full })
130    }
131
132    /// Construct from an already-PSD per-row tensor. Used by callers that
133    /// have computed the Hessian via a different route.
134    pub fn from_full(h: Array3<f64>) -> Self {
135        assert_eq!(h.shape()[1], K_SURVIVAL);
136        assert_eq!(h.shape()[2], K_SURVIVAL);
137        Self { h }
138    }
139}
140
141impl RowHessian for SurvivalRowHessian {
142    fn k(&self) -> usize {
143        K_SURVIVAL
144    }
145    fn nrows(&self) -> usize {
146        self.h.shape()[0]
147    }
148    fn fill_row(&self, row: usize, out: &mut [f64]) {
149        assert_eq!(out.len(), K_SURVIVAL * K_SURVIVAL);
150        for a in 0..K_SURVIVAL {
151            for b in 0..K_SURVIVAL {
152                out[a * K_SURVIVAL + b] = self.h[[row, a, b]];
153            }
154        }
155    }
156    fn evaluate_full(&self) -> Array3<f64> {
157        self.h.clone()
158    }
159}
160
161/// `FamilyChannelHessian` for survival marginal-slope.
162///
163/// The 4×4 per-subject W_i is the Hessian of the row negative log-likelihood
164/// `ρ_i(q0, q1, qd1, g) = −δ_i log f(η_1, ad1) − (1−δ_i) log S(η_1) + log S(η_0)`
165/// with respect to the 4-vector primary state `(q0, q1, qd1, g)`.
166///
167/// Derivation of W_i entries (all from `row_primary_closed_form`):
168///
169/// - W[0,0] = u2_η0 · c²  (q0–q0; only η0 depends on q0)
170/// - W[1,1] = (u2_η1 + w·δ) · c²  (q1–q1; η1 and log-φ both depend on q1)
171/// - W[2,2] = w·δ · (∂ad1/∂qd1)² · (−1/ad1²)  (qd1–qd1 via neglog(ad1))
172/// - W[3,3] = u2_η0·(∂η0/∂g)² + u1_η0·(∂²η0/∂g²) + u2_η1·(∂η1/∂g)² + ...
173/// - W[0,3] = W[3,0] = u2_η0·c·(q0·c1 + s_f·z) + u1_η0·c1  (cross q0–g)
174/// - W[1,3] = W[3,1] = u2_η1·c·(q1·c1 + s_f·z) + u1_η1·c1  (cross q1–g)
175/// - W[2,3] = W[3,2] = u2_ad1·c·(qd1·c1) + u1_ad1·c1  (cross qd1–g)
176/// - All other off-diagonals are zero (η0, η1, ad1 depend on non-overlapping
177///   subsets of (q0,q1,qd1,g), and only g is shared across all three).
178///
179/// This is already computed by `row_primary_closed_form` and stored in
180/// `SurvivalRowHessian::h` after PSD-clamping.
181///
182/// # β-dependent W via `channel_hessian_at`
183///
184/// `channel_hessian_at` overrides the default β-independent path.  When
185/// `family_scalars` carries `SurvivalMarginalSlopeFamilyScalars`, the
186/// current per-row primary state `(q0_i, q1_i, qd1_i, g_i)` is read from
187/// those scalars and the 4×4 W_i is recomputed via `row_primary_for_compiler`.
188/// This makes `I(β) = J(β)^T W(β) J(β)` accurate at the current β instead of
189/// at the frozen pilot β=0 state.
190///
191/// When `family_scalars` is `None` but `beta` is zero-ish (all entries ≤ ε),
192/// the frozen pilot W is returned unchanged.  When `family_scalars` is `None`
193/// and any `beta` entry is non-trivial, `Err` is returned — the caller must
194/// supply scalars for a correct W at non-pilot β (same contract as T26's
195/// Jacobian callbacks: scalars required when β affects the primary state).
196impl FamilyChannelHessian for SurvivalRowHessian {
197    fn n_outputs(&self) -> usize {
198        K_SURVIVAL
199    }
200
201    fn n_subjects(&self) -> usize {
202        self.h.shape()[0]
203    }
204
205    fn fill_subject(&self, i: usize, out: &mut [f64]) {
206        assert_eq!(out.len(), K_SURVIVAL * K_SURVIVAL);
207        for a in 0..K_SURVIVAL {
208            for b in 0..K_SURVIVAL {
209                out[a * K_SURVIVAL + b] = self.h[[i, a, b]];
210            }
211        }
212    }
213
214    fn evaluate_full(&self) -> ndarray::Array3<f64> {
215        self.h.clone()
216    }
217
218    fn channel_hessian_at(
219        &self,
220        beta: &[f64],
221        family_scalars: Option<&Arc<dyn std::any::Any + Send + Sync>>,
222    ) -> Result<Arc<dyn FamilyChannelHessian>, String> {
223        use crate::survival::marginal_slope::SurvivalMarginalSlopeFamilyScalars;
224
225        let scalars_opt =
226            family_scalars.and_then(|a| a.downcast_ref::<SurvivalMarginalSlopeFamilyScalars>());
227
228        // Determine whether beta is non-trivial (any |β_j| > ε).
229        let beta_nontrivial = beta
230            .iter()
231            .any(|&b| b.abs() > BETA_NONTRIVIAL_ABS_THRESHOLD);
232
233        match scalars_opt {
234            None if beta_nontrivial => {
235                // β is non-zero in a way that would change W via the primary-state
236                // coupling (g ≠ 0 → c ≠ 1 → W changes).  Scalars are required.
237                Err(
238                    "SurvivalRowHessian::channel_hessian_at: beta is non-trivial but \
239                     family_scalars is None; supply SurvivalMarginalSlopeFamilyScalars \
240                     via FamilyLinearizationState::family_scalars to evaluate W(β) \
241                     correctly (same contract as T26 Jacobian callbacks)."
242                        .to_string(),
243                )
244            }
245            None => {
246                // β ≈ 0: return the frozen pilot W unchanged.
247                Ok(Arc::new(gam_problem::TensorChannelHessian {
248                    h: self.h.clone(),
249                }))
250            }
251            Some(sc) => {
252                let n = self.h.shape()[0];
253                if sc.q0_i.len() != n
254                    || sc.q1_i.len() != n
255                    || sc.qd1_i.len() != n
256                    || sc.g_i.len() != n
257                    || sc.z_i.len() != n
258                {
259                    return Err(format!(
260                        "SurvivalRowHessian::channel_hessian_at: scalars length mismatch \
261                         (expected n={n}, got q0={} q1={} qd1={} g={} z={})",
262                        sc.q0_i.len(),
263                        sc.q1_i.len(),
264                        sc.qd1_i.len(),
265                        sc.g_i.len(),
266                        sc.z_i.len(),
267                    ));
268                }
269                // We do not have weights/event stored in SurvivalRowHessian itself.
270                // The scalars carry the per-row primary state; we need per-row weights
271                // and event indicators to call row_primary_for_compiler.  Those are
272                // NOT stored in SurvivalMarginalSlopeFamilyScalars — so we can only
273                // recompute W's structural shape (the 4×4 curvature geometry) using
274                // unit weights and event=1, which gives us the correct _direction_ of
275                // W at the current β even if the magnitude is off by the sample weight.
276                //
277                // For the drift-detection audit the direction matters more than the
278                // exact per-row magnitudes: rank changes emerge from structural
279                // identifiability, not from per-row weight scaling. Using w=1, d=1
280                // is therefore the principled approximation for the audit path.
281                //
282                // Production callers that need exact W (e.g. for the Fisher Gram in
283                // the compiler) should use SurvivalRowHessian::from_pilot_primary_state
284                // directly with the true per-row weights and event indicators.
285                let mut h_full = Array3::<f64>::zeros((n, K_SURVIVAL, K_SURVIVAL));
286                for i in 0..n {
287                    let q0 = sc.q0_i[i];
288                    let q1 = sc.q1_i[i];
289                    let qd1 = sc.qd1_i[i];
290                    let g = sc.g_i[i];
291                    let z = sc.z_i[i];
292                    // Use unit weight and d=1 (event indicator 1) for the audit path.
293                    // The derivative_guard is the family default (small but non-zero).
294                    match crate::survival::marginal_slope::row_primary_for_compiler(
295                        q0, q1, qd1, g, z, 1.0,  // w = unit weight
296                        1.0,  // d = event
297                        crate::survival::marginal_slope::DEFAULT_SURVIVAL_MARGINAL_SLOPE_DERIVATIVE_GUARD,
298                        sc.s, // probit_scale from scalars
299                    ) {
300                        Ok((_nll, _grad, hess)) => {
301                            let mut h_i = ndarray::Array2::<f64>::zeros((K_SURVIVAL, K_SURVIVAL));
302                            for a in 0..K_SURVIVAL {
303                                for b in 0..K_SURVIVAL {
304                                    h_i[[a, b]] = hess[a][b];
305                                }
306                            }
307                            let clamped = psd_clamp_4x4(&h_i);
308                            for a in 0..K_SURVIVAL {
309                                for b in 0..K_SURVIVAL {
310                                    h_full[[i, a, b]] = clamped[[a, b]];
311                                }
312                            }
313                        }
314                        Err(_) => {
315                            // Monotonicity violation or other numerical issue at this
316                            // row: fall back to the frozen pilot W for this row only.
317                            for a in 0..K_SURVIVAL {
318                                for b in 0..K_SURVIVAL {
319                                    h_full[[i, a, b]] = self.h[[i, a, b]];
320                                }
321                            }
322                        }
323                    }
324                }
325                Ok(Arc::new(SurvivalRowHessian::from_full(h_full)))
326            }
327        }
328    }
329}
330
331/// Project a 4×4 symmetric matrix onto the PSD cone: zero negative
332/// eigenvalues. If the eigendecomposition fails (extremely defensive —
333/// `row_primary_closed_form` already guarantees finite entries), return
334/// the diagonal with negatives clamped.
335fn psd_clamp_4x4(m: &Array2<f64>) -> Array2<f64> {
336    let k = m.nrows();
337    let (evals, evecs) = match m.eigh(Side::Lower) {
338        Ok(pair) => pair,
339        Err(_) => {
340            let mut out = Array2::<f64>::zeros((k, k));
341            for i in 0..k {
342                out[[i, i]] = m[[i, i]].max(0.0);
343            }
344            return out;
345        }
346    };
347    let mut out = Array2::<f64>::zeros((k, k));
348    for i in 0..k {
349        for j in 0..k {
350            let mut acc = 0.0;
351            for l in 0..k {
352                acc += evecs[[i, l]] * evals[l].max(0.0) * evecs[[j, l]];
353            }
354            out[[i, j]] = acc;
355        }
356    }
357    out
358}
359
360/// Row Jacobian operator for the survival time block. Channels (q0, q1,
361/// qd1) come from the three time designs; the g channel is zero.
362pub struct TimeBlockOperator {
363    dq0: Array2<f64>,
364    dq1: Array2<f64>,
365    dqd1: Array2<f64>,
366}
367
368impl TimeBlockOperator {
369    pub fn new(dq0: Array2<f64>, dq1: Array2<f64>, dqd1: Array2<f64>) -> Self {
370        assert_eq!(dq0.dim(), dq1.dim());
371        assert_eq!(dq0.dim(), dqd1.dim());
372        Self { dq0, dq1, dqd1 }
373    }
374}
375
376impl RowJacobianOperator for TimeBlockOperator {
377    fn k(&self) -> usize {
378        K_SURVIVAL
379    }
380    fn ncols(&self) -> usize {
381        self.dq0.ncols()
382    }
383    fn nrows(&self) -> usize {
384        self.dq0.nrows()
385    }
386    fn apply_row(&self, row: usize, delta_beta: &[f64], out: &mut [f64]) {
387        assert_eq!(out.len(), K_SURVIVAL);
388        assert_eq!(delta_beta.len(), self.dq0.ncols());
389        let mut acc = [0.0_f64; K_SURVIVAL];
390        for (j, &b) in delta_beta.iter().enumerate() {
391            acc[0] += self.dq0[[row, j]] * b;
392            acc[1] += self.dq1[[row, j]] * b;
393            acc[2] += self.dqd1[[row, j]] * b;
394        }
395        out.copy_from_slice(&acc);
396    }
397    fn evaluate_full(&self) -> Array3<f64> {
398        let n = self.dq0.nrows();
399        let p = self.dq0.ncols();
400        let mut out = Array3::<f64>::zeros((n, p, K_SURVIVAL));
401        for i in 0..n {
402            for j in 0..p {
403                out[[i, j, 0]] = self.dq0[[i, j]];
404                out[[i, j, 1]] = self.dq1[[i, j]];
405                out[[i, j, 2]] = self.dqd1[[i, j]];
406            }
407        }
408        out
409    }
410    fn scaled_design_by_sqrt_h(&self, h_full: &Array3<f64>) -> Array2<f64> {
411        // Scale straight out of the three compact `(n, p)` channel designs —
412        // the compiler consumes the `(n·K, p)` sqrt(H)-scaled design, so the
413        // dense `(n, p, K)` tensor (3 of its 4 channels held explicitly, the
414        // 4th identically zero) that `evaluate_full()` builds is never needed.
415        // (#738: a capability is not a representation.)
416        let n = self.dq0.nrows();
417        let p = self.dq0.ncols();
418        scale_jacobian_by_sqrt_h_with(n, p, K_SURVIVAL, h_full, |i, a, c| match c {
419            0 => self.dq0[[i, a]],
420            1 => self.dq1[[i, a]],
421            2 => self.dqd1[[i, a]],
422            _ => 0.0,
423        })
424    }
425}
426
427/// Row Jacobian operator for a block whose contribution flows into the
428/// q-channels (q0 and q1 identically) and optionally the qd1 channel.
429/// Covers the survival marginal, score-warp, and link-dev blocks (all
430/// three share the structural property `δq0 = δq1 = basis·δβ`, `δg = 0`).
431pub struct QChannelBlockOperator {
432    dq: Array2<f64>,
433    dqd1: Array2<f64>,
434}
435
436impl QChannelBlockOperator {
437    pub fn new(dq: Array2<f64>, dqd1: Array2<f64>) -> Self {
438        assert_eq!(dq.dim(), dqd1.dim());
439        Self { dq, dqd1 }
440    }
441}
442
443impl RowJacobianOperator for QChannelBlockOperator {
444    fn k(&self) -> usize {
445        K_SURVIVAL
446    }
447    fn ncols(&self) -> usize {
448        self.dq.ncols()
449    }
450    fn nrows(&self) -> usize {
451        self.dq.nrows()
452    }
453    fn apply_row(&self, row: usize, delta_beta: &[f64], out: &mut [f64]) {
454        assert_eq!(out.len(), K_SURVIVAL);
455        assert_eq!(delta_beta.len(), self.dq.ncols());
456        let mut dq_acc = 0.0;
457        let mut dqd_acc = 0.0;
458        for (j, &b) in delta_beta.iter().enumerate() {
459            dq_acc += self.dq[[row, j]] * b;
460            dqd_acc += self.dqd1[[row, j]] * b;
461        }
462        out[0] = dq_acc;
463        out[1] = dq_acc;
464        out[2] = dqd_acc;
465        out[3] = 0.0;
466    }
467    fn evaluate_full(&self) -> Array3<f64> {
468        let n = self.dq.nrows();
469        let p = self.dq.ncols();
470        let mut out = Array3::<f64>::zeros((n, p, K_SURVIVAL));
471        for i in 0..n {
472            for j in 0..p {
473                let v = self.dq[[i, j]];
474                out[[i, j, 0]] = v;
475                out[[i, j, 1]] = v;
476                out[[i, j, 2]] = self.dqd1[[i, j]];
477            }
478        }
479        out
480    }
481    fn scaled_design_by_sqrt_h(&self, h_full: &Array3<f64>) -> Array2<f64> {
482        // q0 and q1 share `dq`; qd1 is `dqd1`; the g channel is identically
483        // zero. Scale directly from the compact `(n, p)` designs, skipping the
484        // dense `(n, p, K)` tensor `evaluate_full()` would build. (#738.)
485        let n = self.dq.nrows();
486        let p = self.dq.ncols();
487        scale_jacobian_by_sqrt_h_with(n, p, K_SURVIVAL, h_full, |i, a, c| match c {
488            0 | 1 => self.dq[[i, a]],
489            2 => self.dqd1[[i, a]],
490            _ => 0.0,
491        })
492    }
493}
494
495/// Row Jacobian operator for the survival logslope block: contribution
496/// lives entirely on the g channel.
497pub struct LogslopeBlockOperator {
498    dg: Array2<f64>,
499}
500
501impl LogslopeBlockOperator {
502    pub fn new(dg: Array2<f64>) -> Self {
503        Self { dg }
504    }
505}
506
507impl RowJacobianOperator for LogslopeBlockOperator {
508    fn k(&self) -> usize {
509        K_SURVIVAL
510    }
511    fn ncols(&self) -> usize {
512        self.dg.ncols()
513    }
514    fn nrows(&self) -> usize {
515        self.dg.nrows()
516    }
517    fn apply_row(&self, row: usize, delta_beta: &[f64], out: &mut [f64]) {
518        assert_eq!(out.len(), K_SURVIVAL);
519        assert_eq!(delta_beta.len(), self.dg.ncols());
520        let mut acc = 0.0;
521        for (j, &b) in delta_beta.iter().enumerate() {
522            acc += self.dg[[row, j]] * b;
523        }
524        out[0] = 0.0;
525        out[1] = 0.0;
526        out[2] = 0.0;
527        out[3] = acc;
528    }
529    fn evaluate_full(&self) -> Array3<f64> {
530        let n = self.dg.nrows();
531        let p = self.dg.ncols();
532        let mut out = Array3::<f64>::zeros((n, p, K_SURVIVAL));
533        for i in 0..n {
534            for j in 0..p {
535                out[[i, j, 3]] = self.dg[[i, j]];
536            }
537        }
538        out
539    }
540    fn scaled_design_by_sqrt_h(&self, h_full: &Array3<f64>) -> Array2<f64> {
541        // The logslope contribution lives entirely on the g channel (3); the
542        // other three channels are identically zero. Scale directly from the
543        // compact `(n, p)` design, skipping the mostly-zero dense `(n, p, K)`
544        // tensor `evaluate_full()` would build. (#738.)
545        let n = self.dg.nrows();
546        let p = self.dg.ncols();
547        scale_jacobian_by_sqrt_h_with(n, p, K_SURVIVAL, h_full, |i, a, c| {
548            if c == 3 { self.dg[[i, a]] } else { 0.0 }
549        })
550    }
551}
552
553/// Inputs assembled for the survival fit driver to feed `compile()`. The
554/// ordering follows `gauge_priority` descending (time=200 → marginal=150 →
555/// logslope=120 → score_warp=80 → link_dev=60).
556pub struct SurvivalCompilerInputs {
557    pub operators: Vec<Arc<dyn RowJacobianOperator>>,
558    pub ordering: Vec<BlockOrder>,
559}
560
561/// Per-block V reparameterisation matrices for the three parametric
562/// survival blocks emitted by [`compile_survival_parametric_designs`].
563/// Each `v_*` is a `(p_block_raw × p_block_kept)` selection-or-rotation
564/// matrix that maps a `β_kept` coefficient vector to its `β_raw`
565/// equivalent: `β_raw = V · β_kept`. The construction site applies these
566/// to the raw block designs (`design_raw · V → design_compiled`) and
567/// to the penalties (`Vᵀ S V`) before building `ParameterBlockSpec`s
568/// and passing the compiled designs into `make_family`.
569///
570/// Phase-4b architecture: this is the seam where the family-agnostic
571/// row-Jacobian compiler hands control back to the family-specific
572/// construction site. Each `v_*` width equals the corresponding
573/// `CompiledBlocks::blocks[i].t_lw.ncols()` — i.e., the kept-direction
574/// count after sqrt-H-metric residualisation and post-walk RRQR
575/// trailing-pivot drop.
576pub struct SurvivalParametricCompiled {
577    pub v_time: Array2<f64>,
578    pub v_marginal: Array2<f64>,
579    pub v_logslope: Array2<f64>,
580    /// Per-block dropped raw-column count, indexed
581    /// `(time_dropped, marginal_dropped, logslope_dropped)`. Equal to
582    /// `(p_raw − v.ncols())` for each block. Useful for logging the
583    /// gauge-attribution summary at the construction site.
584    pub drops_by_block: (usize, usize, usize),
585}
586
587fn wrap_design_with_transform(
588    raw: DesignMatrix,
589    v: &Array2<f64>,
590    context: &str,
591) -> Result<DesignMatrix, String> {
592    if raw.ncols() != v.nrows() {
593        return Err(format!(
594            "{context}: raw design has {} cols but V has {} rows (V is {}×{})",
595            raw.ncols(),
596            v.nrows(),
597            v.nrows(),
598            v.ncols(),
599        ));
600    }
601    let inner_dense = match raw {
602        DesignMatrix::Dense(d) => d,
603        DesignMatrix::Sparse(_) => {
604            let dense = raw
605                .try_to_dense_by_chunks(&format!("{context} sparse→dense for V apply"))
606                .map_err(|reason| format!("{context}: densify failed: {reason}"))?;
607            DenseDesignMatrix::from(dense)
608        }
609    };
610    let op = CoefficientTransformOperator::new(inner_dense, v.clone())
611        .map_err(|reason| format!("{context}: CoefficientTransformOperator::new: {reason}"))?;
612    Ok(DesignMatrix::Dense(DenseDesignMatrix::from(Arc::new(op))))
613}
614
615/// Per-term V reparameterisation matrices for the three parametric
616/// survival blocks. Each block's full V is the block-diagonal assembly
617/// of its per-term V's (one entry per element of the input
618/// `*_partition`). Preserves per-term penalty structure: applying
619/// `V_b = block_diag(V_term1, ..., V_termM)` to a per-term BlockwisePenalty
620/// pulls each penalty back only via its OWN term's V, so what was a
621/// per-term λ tunable in REML stays per-term tunable.
622pub struct SurvivalParametricCompiledPerTerm {
623    pub v_time_per_term: Vec<Array2<f64>>,
624    pub v_marginal_per_term: Vec<Array2<f64>>,
625    pub v_logslope_per_term: Vec<Array2<f64>>,
626    /// Per-term residualised reparam `R_b = M_b · V_b` from the
627    /// identifiability compiler, in the same global compile order
628    /// (time terms, then marginal terms, then logslope terms). `None`
629    /// for the very first compiled block (no anchor). Used by the
630    /// V+M-exact apply path to emit residualised rows
631    /// `C_b·V_b − A_{<b}·R_b` and to assemble the full triangular T.
632    pub r_lw_per_term: Vec<Option<Array2<f64>>>,
633    /// Per-block drops (raw_cols − sum(kept_cols across terms)).
634    pub drops_by_block: (usize, usize, usize),
635}
636
637/// Per-term-aware compile: residualise each block's TERMS individually
638/// in priority order so the emitted V is block-diagonal on term
639/// boundaries. This preserves the per-term penalty structure that
640/// REML's per-λ accounting depends on.
641///
642/// Each `*_partition` is a list of disjoint contiguous column ranges
643/// covering `[0..p_block)`. For the marginal/logslope blocks the
644/// natural source is the union of `BlockwisePenalty::col_range` values
645/// (one per smoothness penalty / term) plus the complement
646/// (unpenalised parametric columns).
647///
648/// Order of residualisation: time terms first (in their partition
649/// order), then marginal terms, then logslope terms. Within each
650/// block, terms are residualised against ALL prior anchor columns
651/// (terms from earlier blocks + earlier terms within this block).
652/// Aliased directions land in the lowest-priority block that contains
653/// them, in the natural term order within that block — matching the
654/// gauge-priority ownership contract.
655pub fn compile_survival_parametric_designs_per_term(
656    time_dq0: Array2<f64>,
657    time_dq1: Array2<f64>,
658    time_dqd1: Array2<f64>,
659    time_partition: &[std::ops::Range<usize>],
660    marginal_dq: Array2<f64>,
661    marginal_dqd1: Array2<f64>,
662    marginal_partition: &[std::ops::Range<usize>],
663    logslope_dg: Array2<f64>,
664    logslope_partition: &[std::ops::Range<usize>],
665    row_hess: &dyn RowHessian,
666    protect_time: bool,
667) -> Result<SurvivalParametricCompiledPerTerm, String> {
668    use gam_identifiability::families::compiler::compile_protected;
669
670    let p_time = time_dq0.ncols();
671    let p_marg = marginal_dq.ncols();
672    let p_log = logslope_dg.ncols();
673    validate_partition(time_partition, p_time, "time")?;
674    validate_partition(marginal_partition, p_marg, "marginal")?;
675    validate_partition(logslope_partition, p_log, "logslope")?;
676
677    // Build per-term operators. Each term gets its own RowJacobianOperator
678    // restricted to its column slice; the operator type matches the
679    // block's K-channel signature (Time, QChannel, Logslope).
680    let mut operators: Vec<Arc<dyn RowJacobianOperator>> = Vec::new();
681    let mut ordering: Vec<BlockOrder> = Vec::new();
682    for range in time_partition {
683        let dq0 = time_dq0.slice(ndarray::s![.., range.clone()]).to_owned();
684        let dq1 = time_dq1.slice(ndarray::s![.., range.clone()]).to_owned();
685        let dqd1 = time_dqd1.slice(ndarray::s![.., range.clone()]).to_owned();
686        operators.push(Arc::new(TimeBlockOperator::new(dq0, dq1, dqd1)));
687        ordering.push(BlockOrder::Time);
688    }
689    for range in marginal_partition {
690        let dq = marginal_dq.slice(ndarray::s![.., range.clone()]).to_owned();
691        let dqd1 = marginal_dqd1
692            .slice(ndarray::s![.., range.clone()])
693            .to_owned();
694        operators.push(Arc::new(QChannelBlockOperator::new(dq, dqd1)));
695        ordering.push(BlockOrder::Marginal);
696    }
697    for range in logslope_partition {
698        let dg = logslope_dg.slice(ndarray::s![.., range.clone()]).to_owned();
699        operators.push(Arc::new(LogslopeBlockOperator::new(dg)));
700        ordering.push(BlockOrder::Logslope);
701    }
702
703    // The time block carries the monotone time-wiggle basis, whose effective
704    // Jacobian is a fixed nonlinear functional basis rather than a linear
705    // design. When `protect_time` is set it must be kept at full raw width: a
706    // linear reparameterisation of it would desynchronise the raw-width
707    // wiggle-basis chain rule (`SmsTimewiggleTimeJacobian`), which recomputes
708    // the basis on every evaluation. Marginal/logslope still reduce against the
709    // full time anchor. The time block spans operators `0..n_time` (pushed
710    // first above); mark exactly those protected.
711    let n_time = time_partition.len();
712    let protected: Vec<bool> = if protect_time {
713        (0..operators.len()).map(|i| i < n_time).collect()
714    } else {
715        Vec::new()
716    };
717    let compiled =
718        compile_protected(&operators, row_hess, &ordering, &protected).map_err(|e| {
719            format!("identifiability::families::compiler::compile (per-term) failed: {e}")
720        })?;
721    let blocks = compiled.blocks;
722    let n_marg = marginal_partition.len();
723    let n_log = logslope_partition.len();
724    if blocks.len() != n_time + n_marg + n_log {
725        return Err(format!(
726            "per-term compile: expected {} compiled blocks (time={}, marg={}, log={}), got {}",
727            n_time + n_marg + n_log,
728            n_time,
729            n_marg,
730            n_log,
731            blocks.len(),
732        ));
733    }
734    let mut iter = blocks.into_iter();
735    let mut v_time_per_term: Vec<Array2<f64>> = Vec::with_capacity(n_time);
736    let mut r_time_per_term: Vec<Option<Array2<f64>>> = Vec::with_capacity(n_time);
737    for _ in 0..n_time {
738        let blk = iter.next().unwrap();
739        v_time_per_term.push(blk.t_lw);
740        r_time_per_term.push(blk.r_lw);
741    }
742    let mut v_marginal_per_term: Vec<Array2<f64>> = Vec::with_capacity(n_marg);
743    let mut r_marginal_per_term: Vec<Option<Array2<f64>>> = Vec::with_capacity(n_marg);
744    for _ in 0..n_marg {
745        let blk = iter.next().unwrap();
746        v_marginal_per_term.push(blk.t_lw);
747        r_marginal_per_term.push(blk.r_lw);
748    }
749    let mut v_logslope_per_term: Vec<Array2<f64>> = Vec::with_capacity(n_log);
750    let mut r_logslope_per_term: Vec<Option<Array2<f64>>> = Vec::with_capacity(n_log);
751    for _ in 0..n_log {
752        let blk = iter.next().unwrap();
753        v_logslope_per_term.push(blk.t_lw);
754        r_logslope_per_term.push(blk.r_lw);
755    }
756    let mut r_lw_per_term: Vec<Option<Array2<f64>>> = Vec::with_capacity(n_time + n_marg + n_log);
757    r_lw_per_term.extend(r_time_per_term);
758    r_lw_per_term.extend(r_marginal_per_term);
759    r_lw_per_term.extend(r_logslope_per_term);
760    let drops_time: usize = time_partition
761        .iter()
762        .zip(v_time_per_term.iter())
763        .map(|(r, v)| r.len().saturating_sub(v.ncols()))
764        .sum();
765    let drops_marg: usize = marginal_partition
766        .iter()
767        .zip(v_marginal_per_term.iter())
768        .map(|(r, v)| r.len().saturating_sub(v.ncols()))
769        .sum();
770    let drops_log: usize = logslope_partition
771        .iter()
772        .zip(v_logslope_per_term.iter())
773        .map(|(r, v)| r.len().saturating_sub(v.ncols()))
774        .sum();
775    Ok(SurvivalParametricCompiledPerTerm {
776        v_time_per_term,
777        v_marginal_per_term,
778        v_logslope_per_term,
779        r_lw_per_term,
780        drops_by_block: (drops_time, drops_marg, drops_log),
781    })
782}
783
/// Check that `partition` tiles `[0, p_block)` with non-empty ranges laid
/// end-to-start in order.
///
/// An empty `partition` is accepted only when `p_block == 0`. Otherwise the
/// first range must start at 0, the last must end at `p_block`, consecutive
/// ranges must touch exactly, and no range may be empty.
///
/// # Errors
///
/// Returns a message prefixed with `label` describing the first violated
/// rule, checked in the order listed above.
fn validate_partition(
    partition: &[std::ops::Range<usize>],
    p_block: usize,
    label: &str,
) -> Result<(), String> {
    let (first, last) = match (partition.first(), partition.last()) {
        (Some(first), Some(last)) => (first, last),
        // No ranges at all: only legal for a zero-width block.
        _ if p_block == 0 => return Ok(()),
        _ => {
            return Err(format!(
                "{label} partition empty but block has p={p_block} columns"
            ));
        }
    };
    if first.start != 0 {
        return Err(format!(
            "{label} partition must start at 0, got start={}",
            first.start
        ));
    }
    if last.end != p_block {
        return Err(format!(
            "{label} partition must cover [0, {p_block}); last range ends at {}",
            last.end
        ));
    }
    for pair in partition.windows(2) {
        let (prev, next) = (&pair[0], &pair[1]);
        if prev.end != next.start {
            return Err(format!(
                "{label} partition has gap/overlap between [{}..{}) and [{}..{})",
                prev.start, prev.end, next.start, next.end
            ));
        }
        if prev.is_empty() {
            return Err(format!(
                "{label} partition has empty range [{}..{})",
                prev.start, prev.end
            ));
        }
    }
    // `windows(2)` only ever inspects the left element, so the final range
    // needs its own emptiness check.
    if last.is_empty() {
        return Err(format!("{label} partition's final range is empty"));
    }
    Ok(())
}
828
/// Derive a disjoint contiguous partition of `[0..p_block)` from a
/// list of BlockwisePenalty col_ranges.
///
/// Every penalty range's `start` and `end` (clamped to `p_block`) becomes a
/// cut point, together with `0` and `p_block`; the partition is the list of
/// intervals between consecutive distinct cut points, in ascending order.
/// Consequently:
///
/// - Distinct penalty ranges define term boundaries.
/// - A gap between penalty ranges (unpenalised columns) becomes ONE range
///   spanning the whole gap, whatever its width.
/// - Multiple penalties with the SAME col_range (e.g. tensor anisotropy
///   axes) coalesce to one term.
/// - Partially overlapping penalty ranges are split at every boundary.
///
/// Returns an empty vector when `p_block == 0`.
pub fn extract_term_partition_from_penalty_ranges(
    p_block: usize,
    penalty_ranges: &[std::ops::Range<usize>],
) -> Vec<std::ops::Range<usize>> {
    use std::collections::BTreeSet;
    // The set both sorts and de-duplicates the cut points, so consecutive
    // entries are strictly increasing and every emitted range is non-empty.
    let cuts: BTreeSet<usize> = penalty_ranges
        .iter()
        .flat_map(|r| [r.start, r.end])
        .map(|edge| edge.min(p_block))
        .chain([0, p_block])
        .collect();
    cuts.iter()
        .zip(cuts.iter().skip(1))
        .map(|(&lo, &hi)| lo..hi)
        .collect()
}
851
852/// Pull a single raw block-local [`BlockwisePenalty`] back through the
853/// block's own diagonal reparameterisation `V_b` (the `(b, b)` block of
854/// the triangular T), producing a per-block-width compiled penalty.
855///
856/// The penalty's `local` is `pen.col_range.len()` square and covers a
857/// sub-region of the raw block at offset `pen.col_range.start` (which is
858/// block-local, i.e. relative to the block's first raw column). It is
859/// embedded into the full raw block width `v_block.nrows()` at that
860/// offset, then pulled back as `V_bᵀ · embed(S) · V_b`, giving a
861/// `(w_b_compiled × w_b_compiled)` symmetric `PenaltyMatrix::Dense`
862/// where `w_b_compiled == v_block.ncols()`.
863///
864/// This is the penalty contract a per-block `ParameterBlockSpec`
865/// requires: each block's penalty acts on that block's own compiled
866/// coordinate `θ_b`. The cross-block residualisation `R_{a→b}` carried
867/// in T's strict-upper triangle is absorbed into the *design* columns
868/// (the residualised emitted design `C_b V_b − A_{<b} R_b`), not into
869/// the penalty — exactly as the VM-exact compile-map path
870/// [`apply_compiled_map_to_designs`] does. Pulling the penalty back
871/// through the full joint T instead would yield a `(p_compiled × p_compiled)` dense
872/// matrix that cannot live in a single block's `penalties` slot and
873/// would violate the `p_b × p_b` block-spec validation.
874pub fn pull_back_blockwise_penalty_through_block_v(
875    pen: &gam_terms::smooth::BlockwisePenalty,
876    v_block: &Array2<f64>,
877) -> Result<PenaltyMatrix, String> {
878    let raw_p = v_block.nrows();
879    let compiled_p = v_block.ncols();
880    let block_p = pen.col_range.len();
881    let embed_start = pen.col_range.start;
882    let embed_end = pen.col_range.end;
883    if embed_end > raw_p {
884        return Err(format!(
885            "pull_back_blockwise_penalty_through_block_v: penalty col_range {embed_start}..{embed_end} \
886             exceeds block raw width {raw_p}"
887        ));
888    }
889    if pen.local.nrows() != block_p || pen.local.ncols() != block_p {
890        return Err(format!(
891            "pull_back_blockwise_penalty_through_block_v: penalty local is {}x{} but col_range \
892             width is {block_p}",
893            pen.local.nrows(),
894            pen.local.ncols(),
895        ));
896    }
897    let mut embedded = Array2::<f64>::zeros((raw_p, raw_p));
898    if block_p > 0 {
899        let mut dst =
900            embedded.slice_mut(ndarray::s![embed_start..embed_end, embed_start..embed_end]);
901        for i in 0..block_p {
902            for j in 0..block_p {
903                dst[[i, j]] = pen.local[[i, j]];
904            }
905        }
906    }
907    // V_bᵀ · embed(S) · V_b → (compiled_p × compiled_p).
908    let temp = embedded.dot(v_block);
909    let pulled = v_block.t().dot(&temp);
910    let mut sym = Array2::<f64>::zeros((compiled_p, compiled_p));
911    for i in 0..compiled_p {
912        for j in 0..compiled_p {
913            sym[[i, j]] = 0.5 * (pulled[[i, j]] + pulled[[j, i]]);
914        }
915    }
916    Ok(PenaltyMatrix::Dense(sym))
917}
918
919/// Assemble a 3-block [`CompiledMap`] (time, marginal, logslope) from a
920/// [`SurvivalParametricCompiledPerTerm`] produced by the full 4×4 row-Hessian
921/// driver [`compile_survival_parametric_designs_per_term`].
922///
923/// The full global triangular `T` is built from the per-term `V`/`R` blocks
924/// (diagonal `V_b`, strict-upper `−R_{a→b}` — identical to the matrix the
925/// result-time lift [`Gauge::from_v_and_r`] uses), then partitioned
926/// into the three *block* ranges (raw = summed per-term raw widths, compiled =
927/// summed per-term kept widths). The resulting `CompiledMap` is interchangeable
928/// with one from
929/// [`gam_identifiability::families::compiler::compile_from_raw_grams`], so the
930/// existing [`apply_compiled_map_to_designs`] +
931/// [`Gauge::from_compiled_map`] machinery consumes it unchanged.
932///
933/// This is the seam that lets the survival closed-form fast path engage on the
934/// *correct* identifiable quotient: the cheap η₁-only rawstack metric can
935/// falsely collapse a whole channel (marginal/logslope share a PC surface in
936/// the η₁ row curvature), but the full survival row Hessian is 4×4 in
937/// `(q0, q1, qd1, g)` and chains differently into each block, so it keeps the
938/// channels distinct when no *true* alias exists. The reduced basis it emits
939/// goes to Newton in place of the rank-deficient raw basis.
940pub fn compiled_map_from_per_term(
941    compiled: &SurvivalParametricCompiledPerTerm,
942) -> gam_identifiability::families::compiler::CompiledMap {
943    // Per-term V's and R's in global compile order: time terms, then marginal,
944    // then logslope — exactly the order `r_lw_per_term` is stored in.
945    let mut v_all: Vec<Array2<f64>> = Vec::new();
946    v_all.extend(compiled.v_time_per_term.iter().cloned());
947    v_all.extend(compiled.v_marginal_per_term.iter().cloned());
948    v_all.extend(compiled.v_logslope_per_term.iter().cloned());
949
950    let t_full = assemble_block_triangular_t(&v_all, &compiled.r_lw_per_term);
951
952    // Per-block raw / compiled widths = summed per-term widths within the block.
953    let raw_w = |terms: &[Array2<f64>]| -> usize { terms.iter().map(|v| v.nrows()).sum() };
954    let kept_w = |terms: &[Array2<f64>]| -> usize { terms.iter().map(|v| v.ncols()).sum() };
955    let raw_time = raw_w(&compiled.v_time_per_term);
956    let raw_marg = raw_w(&compiled.v_marginal_per_term);
957    let raw_log = raw_w(&compiled.v_logslope_per_term);
958    let kept_time = kept_w(&compiled.v_time_per_term);
959    let kept_marg = kept_w(&compiled.v_marginal_per_term);
960    let kept_log = kept_w(&compiled.v_logslope_per_term);
961
962    let raw_block_ranges = vec![
963        0..raw_time,
964        raw_time..(raw_time + raw_marg),
965        (raw_time + raw_marg)..(raw_time + raw_marg + raw_log),
966    ];
967    let compiled_block_ranges = vec![
968        0..kept_time,
969        kept_time..(kept_time + kept_marg),
970        (kept_time + kept_marg)..(kept_time + kept_marg + kept_log),
971    ];
972
973    gam_identifiability::families::compiler::CompiledMap {
974        raw_from_compiled: t_full,
975        compiled_block_ranges,
976        raw_block_ranges,
977    }
978}
979
980/// Build a W-orthogonal **partial** reduced-logslope reparameterisation `T`
981/// (`p_log × r`, `0 < r < p_log`) for the survival marginal↔logslope confound,
982/// mirroring the proven-correct BMS effective-Schur-Gram construction
983/// [`crate::bms::block_specs::reduced_logslope_transform_effective`] but in
984/// survival's per-row 4×4 primary-state Hessian metric.
985///
986/// # Why this exists (#979)
987///
988/// The survival marginal and logslope channels share the SAME spatial basis
989/// (e.g. `matern(PC1,PC2,PC3)`), so on clustered-PC data the full 4×4
990/// row-Hessian identifiability compiler can attribute the *entire* shared
991/// surface to the lowest-priority logslope block and collapse it to zero width
992/// — which the `#741` required-channel guard rejects, forcing a fallback to the
993/// UNREDUCED design + Jeffreys conditioning. That fallback leaves a
994/// quadratically-flat near-null direction in the joint penalised Hessian
995/// `M = JᵀHJ + S`, so the inner joint-Newton cannot certify stationarity and
996/// runs to its bounded iteration cap without converging.
997///
998/// The BMS path never hits this because it does a *partial* reduction: it
999/// removes from the logslope block ONLY the directions whose effective image is
1000/// W-explained by the marginal span (the confounded null space of the effective
1001/// Schur Gram), keeping every surviving logslope direction. The result is
1002/// full-rank `M` BY CONSTRUCTION — no runtime projection needed.
1003///
1004/// # The metric collapse to scalar weights
1005///
1006/// At the pilot the marginal design feeds the primary channels `(q0, q1)`
1007/// identically (`∂q0/∂β_m = ∂q1/∂β_m = m`, and `∂qd1/∂β_m = 0` because the
1008/// `#808` fallback always builds a zero marginal-derivative design) and the
1009/// logslope design feeds only `g` (`∂g/∂β_s = g_dg`). With the per-row PSD 4×4
1010/// Hessian `H` in channel order `(q0, q1, qd1, g)` and `H[0,1] = 0` (q0 and q1
1011/// enter the disjoint outputs η0, η1), the combined effective Gram
1012/// `[[A, Bᵀ], [B, C]] = J_combinedᵀ H J_combined` (PSD per row) collapses to
1013/// scalar-weighted Grams of the raw block designs:
1014///
1015/// ```text
1016///     w_mm = H00 + H11           (marginal self weight, ≥ 0)
1017///     w_mg = H03 + H13           (marginal↔logslope cross weight)
1018///     w_gg = H33                 (logslope self weight, ≥ 0)
1019///     A = m_dqᵀ diag(w_mm) m_dq + εI     (p_m × p_m)
1020///     B = m_dqᵀ diag(w_mg) g_dg          (p_m × p_log)
1021///     C = g_dgᵀ diag(w_gg) g_dg          (p_log × p_log)
1022///     Gtt = C − Bᵀ A⁻¹ B                 (p_log × p_log, PSD Schur complement)
1023/// ```
1024///
1025/// `T` is the orthonormal eigenbasis of `Gtt` for eigenvalues above a tolerance
1026/// relative to the effective logslope energy scale (single-sourced from the BMS
1027/// reference cut). Returns `Ok(None)` when there is nothing to reduce
1028/// (`r == p_log`) or the entire effective logslope image collapses into the
1029/// marginal span (`r == 0`); in both cases the caller keeps its existing path.
1030///
1031/// Precondition: `marginal_dq`'s derivative-into-qd1 contribution is zero (the
1032/// `#808` fallback constructs `m_dqd1` as an all-zero matrix), so marginal
1033/// touches only `(q0, q1)` and the scalar collapse above is exact.
1034pub fn survival_reduced_logslope_transform_effective(
1035    marginal_dq: ndarray::ArrayView2<'_, f64>,
1036    logslope_dg: ndarray::ArrayView2<'_, f64>,
1037    row_hess: &SurvivalRowHessian,
1038) -> Result<Option<Array2<f64>>, String> {
1039    use crate::bms::block_specs::LOGSLOPE_REDUCED_BASIS_RELATIVE_TOL;
1040    use gam_linalg::faer_ndarray::{
1041        FaerArrayView, factorize_symmetricwith_fallback, fast_atb, fast_xt_diag_x, fast_xt_diag_y,
1042    };
1043
1044    let n = marginal_dq.nrows();
1045    let p_m = marginal_dq.ncols();
1046    let p_log = logslope_dg.ncols();
1047    if p_m == 0 || p_log == 0 {
1048        return Ok(None);
1049    }
1050    if logslope_dg.nrows() != n || row_hess.h.shape()[0] != n {
1051        return Err(format!(
1052            "survival reduced logslope: row mismatch marginal={n}, logslope={}, row_hess={}",
1053            logslope_dg.nrows(),
1054            row_hess.h.shape()[0],
1055        ));
1056    }
1057
1058    // Scalar effective weights from the per-row 4×4 PSD Hessian, channel order
1059    // (q0, q1, qd1, g). Marginal → {q0, q1} (identical column m), logslope → {g}.
1060    let mut w_mm = Array1::<f64>::zeros(n);
1061    let mut w_mg = Array1::<f64>::zeros(n);
1062    let mut w_gg = Array1::<f64>::zeros(n);
1063    for i in 0..n {
1064        w_mm[i] = row_hess.h[[i, 0, 0]] + row_hess.h[[i, 1, 1]];
1065        w_mg[i] = row_hess.h[[i, 0, 3]] + row_hess.h[[i, 1, 3]];
1066        w_gg[i] = row_hess.h[[i, 3, 3]];
1067        if !(w_mm[i].is_finite() && w_mg[i].is_finite() && w_gg[i].is_finite()) {
1068            return Err("survival reduced logslope: non-finite row Hessian weight".to_string());
1069        }
1070    }
1071
1072    let marg = marginal_dq.to_owned();
1073    let log = logslope_dg.to_owned();
1074
1075    // C = G_effᵀ W G_eff (raw-coordinate effective logslope Gram); its diagonal
1076    // sets the energy scale for the relative kept-direction tolerance.
1077    let c_gram = fast_xt_diag_x(&log, &w_gg);
1078    let energy_scale = (0..p_log).map(|i| c_gram[[i, i]]).fold(0.0_f64, f64::max);
1079    if !energy_scale.is_finite() || energy_scale <= 0.0 {
1080        return Ok(None);
1081    }
1082
1083    // A = M_effᵀ W M_eff + εI (ridge relative to the marginal effective energy
1084    // so the Schur solve is well-posed even when the marginal pilot Gram is
1085    // rank-soft; the ridge only under-removes, i.e. is conservative).
1086    let mut a_gram = fast_xt_diag_x(&marg, &w_mm);
1087    let a_scale = (0..p_m).map(|i| a_gram[[i, i]]).fold(0.0_f64, f64::max);
1088    let a_ridge = (a_scale * LOGSLOPE_REDUCED_BASIS_RELATIVE_TOL).max(f64::EPSILON);
1089    for i in 0..p_m {
1090        a_gram[[i, i]] += a_ridge;
1091    }
1092
1093    // B = M_effᵀ W G_eff (p_m × p_log);  Gtt = C − Bᵀ A⁻¹ B (p_log × p_log, PSD).
1094    let b_cross = fast_xt_diag_y(&marg, &w_mg, &log);
1095    let a_view = FaerArrayView::new(&a_gram);
1096    let a_factor = factorize_symmetricwith_fallback(a_view.as_ref(), Side::Lower).map_err(|e| {
1097        format!("survival reduced logslope: marginal effective Gram factorization failed: {e}")
1098    })?;
1099    let b_view = FaerArrayView::new(&b_cross);
1100    let solved = a_factor.solve(b_view.as_ref()); // A⁻¹ B  (p_m × p_log)
1101    let a_inv_b = Array2::from_shape_fn((p_m, p_log), |(i, j)| solved[(i, j)]);
1102    let schur = fast_atb(&b_cross, &a_inv_b); // Bᵀ A⁻¹ B  (p_log × p_log)
1103    let mut stt = &c_gram - &schur;
1104    stt = (&stt + &stt.t()) * 0.5;
1105    if stt.iter().any(|v| !v.is_finite()) {
1106        return Err(
1107            "survival reduced logslope: effective Schur Gram produced non-finite entries"
1108                .to_string(),
1109        );
1110    }
1111
1112    let (evals, evecs) = stt
1113        .eigh(Side::Lower)
1114        .map_err(|e| format!("survival reduced logslope: eigendecomposition failed: {e:?}"))?;
1115    // A `Gtt` eigenvalue far below the effective logslope energy scale means that
1116    // direction's effective logslope column is W-explained by the effective
1117    // marginal span — exactly the joint-Hessian rank-soft confounded direction.
1118    let tol = energy_scale * LOGSLOPE_REDUCED_BASIS_RELATIVE_TOL;
1119    let mut kept: Vec<usize> = (0..evals.len()).filter(|&i| evals[i] > tol).collect();
1120    kept.sort_by(|&a, &b| {
1121        evals[b]
1122            .partial_cmp(&evals[a])
1123            .unwrap_or(std::cmp::Ordering::Equal)
1124    });
1125    let r = kept.len();
1126    // r == p_log: no confounded direction to remove. r == 0: the whole effective
1127    // logslope image is in the marginal span. In both cases keep the raw design.
1128    if r == p_log || r == 0 {
1129        return Ok(None);
1130    }
1131    let mut transform = Array2::<f64>::zeros((p_log, r));
1132    for (out_col, &src) in kept.iter().enumerate() {
1133        transform.column_mut(out_col).assign(&evecs.column(src));
1134    }
1135    if transform.iter().any(|v| !v.is_finite()) {
1136        return Err(
1137            "survival reduced logslope: reduced transform produced non-finite entries".to_string(),
1138        );
1139    }
1140    Ok(Some(transform))
1141}
1142
1143/// Assemble a block-diagonal 3-block [`CompiledMap`] that passes the time and
1144/// marginal blocks through unchanged (identity) and reparameterises ONLY the
1145/// logslope block via `t_log` (`p_log × r`). Used by the survival `#979`
1146/// partial reduced-logslope confound removal
1147/// ([`survival_reduced_logslope_transform_effective`]): the marginal/time
1148/// channels are untouched, the logslope block drops only its confounded
1149/// directions, and the joint penalised Hessian is full-rank by construction.
1150///
1151/// The resulting `CompiledMap` is interchangeable with one from
1152/// [`compiled_map_from_per_term`] /
1153/// [`gam_identifiability::families::compiler::compile_from_raw_grams`], so the
1154/// existing [`apply_compiled_map_to_designs`] + [`Gauge::from_compiled_map`]
1155/// machinery consumes it unchanged. Because the map is block-diagonal there is
1156/// no strict-upper cross-block residual `R`, and `apply_compiled_map_to_designs`
1157/// reads only the per-block diagonal `V_b = T[raw_b, compiled_b]` — `V_time` and
1158/// `V_marg` are identities, `V_log = t_log`.
1159pub fn survival_block_diagonal_logslope_map(
1160    p_time: usize,
1161    p_marg: usize,
1162    t_log: &Array2<f64>,
1163) -> gam_identifiability::families::compiler::CompiledMap {
1164    let p_log = t_log.nrows();
1165    let r = t_log.ncols();
1166    let raw_total = p_time + p_marg + p_log;
1167    let compiled_total = p_time + p_marg + r;
1168    let mut t_full = Array2::<f64>::zeros((raw_total, compiled_total));
1169    for i in 0..p_time {
1170        t_full[[i, i]] = 1.0;
1171    }
1172    for i in 0..p_marg {
1173        t_full[[p_time + i, p_time + i]] = 1.0;
1174    }
1175    for ri in 0..p_log {
1176        for cj in 0..r {
1177            t_full[[p_time + p_marg + ri, p_time + p_marg + cj]] = t_log[[ri, cj]];
1178        }
1179    }
1180    gam_identifiability::families::compiler::CompiledMap {
1181        raw_from_compiled: t_full,
1182        compiled_block_ranges: vec![
1183            0..p_time,
1184            p_time..(p_time + p_marg),
1185            (p_time + p_marg)..compiled_total,
1186        ],
1187        raw_block_ranges: vec![
1188            0..p_time,
1189            p_time..(p_time + p_marg),
1190            (p_time + p_marg)..raw_total,
1191        ],
1192    }
1193}
1194
1195/// Apply a global [`CompiledMap`] T directly to the three survival
1196/// parametric block designs (time/marginal/logslope). Slices the
1197/// per-block diagonal of T into `V_b = T[raw_range_b, compiled_range_b]`
1198/// (shape `p_b_raw × w_b_compiled`), wraps each channel's raw design via
1199/// [`wrap_design_with_transform`], and pulls each block's penalties back
1200/// through that block's OWN `V_b` via
1201/// [`pull_back_blockwise_penalty_through_block_v`], producing
1202/// per-block-width `(w_b_compiled × w_b_compiled)` penalties — the shape
1203/// a per-block `ParameterBlockSpec.penalties` slot requires.
1204///
1205/// `map.raw_block_ranges` must equal three contiguous ranges in the
1206/// order Time → Marginal → Logslope (matching the input designs).
1207/// `map.compiled_block_ranges` runs in the same order.
1208///
1209/// Penalties supplied to this function:
1210/// - `time_penalties` are `BlockwisePenalty`s whose `col_range` is in
1211///   the time block's local raw coords (e.g. `0..p_time`).
1212/// - `marginal_penalties` / `logslope_penalties` likewise — local to
1213///   their own channel's raw width.
1214///
1215/// Each penalty's block-local `col_range` is embedded into the block's
1216/// raw width and pulled back as `V_bᵀ S_b V_b`. The cross-block
1217/// residualisation `R_{a→b}` carried in T's strict-upper triangle is
1218/// absorbed into the residualised *design* columns, not the penalty, so
1219/// the per-block penalty model stays exact for the highest-priority
1220/// block (time, no anchor → `R = []`) and matches the sibling per-block
1221/// compile path for the rest.
1222pub fn apply_compiled_map_to_designs(
1223    map: &gam_identifiability::families::compiler::CompiledMap,
1224    time_design_entry: DesignMatrix,
1225    time_design_exit: DesignMatrix,
1226    time_design_derivative_exit: DesignMatrix,
1227    marginal_design: DesignMatrix,
1228    logslope_design: DesignMatrix,
1229    time_penalties: &[gam_terms::smooth::BlockwisePenalty],
1230    marginal_penalties: &[gam_terms::smooth::BlockwisePenalty],
1231    logslope_penalties: &[gam_terms::smooth::BlockwisePenalty],
1232) -> Result<CompiledSurvivalDesignsVMExact, String> {
1233    if map.raw_block_ranges.len() != 3 || map.compiled_block_ranges.len() != 3 {
1234        return Err(format!(
1235            "apply_compiled_map_to_designs: expected exactly 3 blocks (time, marginal, logslope), \
1236             got {} raw / {} compiled",
1237            map.raw_block_ranges.len(),
1238            map.compiled_block_ranges.len(),
1239        ));
1240    }
1241    let time_raw = map.raw_block_ranges[0].clone();
1242    let marg_raw = map.raw_block_ranges[1].clone();
1243    let log_raw = map.raw_block_ranges[2].clone();
1244    let time_compiled = map.compiled_block_ranges[0].clone();
1245    let marg_compiled = map.compiled_block_ranges[1].clone();
1246    let log_compiled = map.compiled_block_ranges[2].clone();
1247
1248    let t = &map.raw_from_compiled;
1249    let raw_total = t.nrows();
1250    let compiled_total = t.ncols();
1251    let expected_raw_total = log_raw.end;
1252    if raw_total != expected_raw_total {
1253        return Err(format!(
1254            "apply_compiled_map_to_designs: T has {raw_total} raw rows but block ranges sum to \
1255             {expected_raw_total}"
1256        ));
1257    }
1258    let expected_compiled_total = log_compiled.end;
1259    if compiled_total != expected_compiled_total {
1260        return Err(format!(
1261            "apply_compiled_map_to_designs: T has {compiled_total} compiled cols but block ranges \
1262             sum to {expected_compiled_total}"
1263        ));
1264    }
1265
1266    let v_time = t
1267        .slice(ndarray::s![time_raw.clone(), time_compiled.clone()])
1268        .to_owned();
1269    let v_marg = t
1270        .slice(ndarray::s![marg_raw.clone(), marg_compiled.clone()])
1271        .to_owned();
1272    let v_log = t
1273        .slice(ndarray::s![log_raw.clone(), log_compiled.clone()])
1274        .to_owned();
1275
1276    let time_entry_out =
1277        wrap_design_with_transform(time_design_entry, &v_time, "compiled-map: time entry")?;
1278    let time_exit_out =
1279        wrap_design_with_transform(time_design_exit, &v_time, "compiled-map: time exit")?;
1280    let time_deriv_out = wrap_design_with_transform(
1281        time_design_derivative_exit,
1282        &v_time,
1283        "compiled-map: time derivative_exit",
1284    )?;
1285    let marg_out = wrap_design_with_transform(marginal_design, &v_marg, "compiled-map: marginal")?;
1286    let log_out = wrap_design_with_transform(logslope_design, &v_log, "compiled-map: logslope")?;
1287
1288    // Pull each block's penalties back through that block's OWN diagonal
1289    // reparameterisation V_b (= the (b, b) block of T). This produces a
1290    // per-block-width `(w_b_compiled × w_b_compiled)` penalty — the only
1291    // shape a per-block `ParameterBlockSpec.penalties` slot accepts.
1292    //
1293    // The block-local penalty `V_bᵀ S_b V_b` is the correct per-block
1294    // penalty: in raw coords the model penalises `γ_bᵀ S_b γ_b` on block
1295    // b's own coefficients, and under the residualised reparameterisation
1296    // the cross-block carry `R_{a→b}` lives entirely in the *design*
1297    // columns (`C_b V_b − A_{<b} R_b`), not in the penalty.
1298    //
1299    // Pulling penalties back through the full joint triangular T instead
1300    // (`Tᵀ blkdiag(S_b) T`) yields a `(p_compiled × p_compiled)` dense
1301    // matrix whose off-diagonal couples θ_b to earlier blocks' θ_a;
1302    // jamming that joint-width matrix into a single block's `penalties`
1303    // produced the `block 0 penalty 0 must be 12x12, got 17x17` mismatch
1304    // that surfaced as the `assert_valid_blockspecs` FFI panic. The two
1305    // agree whenever the residualisation `R_{a→b}` lands in the null space
1306    // of S_a (the shared low-order / parametric directions the identifiable
1307    // quotient strips), which is the case the compiler targets.
1308    let pull_set = |pens: &[gam_terms::smooth::BlockwisePenalty],
1309                    v_block: &Array2<f64>,
1310                    channel: &str|
1311     -> Result<Vec<PenaltyMatrix>, String> {
1312        pens.iter()
1313            .map(|p| {
1314                pull_back_blockwise_penalty_through_block_v(p, v_block).map_err(|e| {
1315                    format!("apply_compiled_map_to_designs: {channel} penalty pullback: {e}")
1316                })
1317            })
1318            .collect()
1319    };
1320
1321    let time_penalties = pull_set(time_penalties, &v_time, "time")?;
1322    let marginal_penalties = pull_set(marginal_penalties, &v_marg, "marginal")?;
1323    let logslope_penalties = pull_set(logslope_penalties, &v_log, "logslope")?;
1324    validate_block_penalty_shapes("time", time_exit_out.ncols(), &time_penalties)?;
1325    validate_block_penalty_shapes("marginal", marg_out.ncols(), &marginal_penalties)?;
1326    validate_block_penalty_shapes("logslope", log_out.ncols(), &logslope_penalties)?;
1327
1328    Ok(CompiledSurvivalDesignsVMExact {
1329        time_design_entry: time_entry_out,
1330        time_design_exit: time_exit_out,
1331        time_design_derivative_exit: time_deriv_out,
1332        marginal_design: marg_out,
1333        logslope_design: log_out,
1334        time_penalties,
1335        marginal_penalties,
1336        logslope_penalties,
1337    })
1338}
1339
1340fn validate_block_penalty_shapes(
1341    block: &str,
1342    width: usize,
1343    penalties: &[PenaltyMatrix],
1344) -> Result<(), String> {
1345    for (idx, penalty) in penalties.iter().enumerate() {
1346        let shape = penalty.shape();
1347        if shape != (width, width) {
1348            return Err(format!(
1349                "apply_compiled_map_to_designs: {block} penalty {idx} must be {width}x{width}, got {}x{}",
1350                shape.0, shape.1
1351            ));
1352        }
1353    }
1354    Ok(())
1355}
1356
1357/// Run the identifiability compiler on the three survival parametric
1358/// blocks (time, marginal, logslope) at a pilot β and return the per-
1359/// block V reparameterisation matrices.
1360///
1361/// `row_hess` must be a PSD per-row 4×4 Hessian of `−log L_i(u_i)` at
1362/// the pilot β (see [`SurvivalRowHessian::from_pilot_primary_state`]).
1363/// The compiler residualises blocks left-to-right in priority order
1364/// (time → marginal → logslope) in the sqrt-H-metric so any aliased
1365/// direction lands in the lower-priority block, then runs a post-walk
1366/// column-pivoted QR on the cumulative anchor and drops trailing
1367/// pivots from the latest block. The returned V matrices are ready to
1368/// be applied to each block's raw design and penalty before the
1369/// `ParameterBlockSpec` list is assembled.
1370///
1371/// On `FullyAliased` from `compile()` (a block fully absorbed by its
1372/// cumulative anchor) this returns `Err`. The construction site should
1373/// surface that as a structured user-facing diagnostic — the model is
1374/// asking the compiler to assign zero degrees of freedom to a named
1375/// parametric block, which is a model-spec bug not a numerical one.
1376///
1377/// Sibling Phase-4b wiring (`bernoulli_marginal_slope::install_compiled_flex_block_into_runtime`)
1378/// already calls `compile()` for the flex blocks. This helper extends
1379/// that contract to the parametric blocks by giving the SMGS
1380/// construction site a one-line entry point — it does NOT yet apply
1381/// the V transforms to the family's captured designs (the captured-
1382/// design update is the remaining integration step that touches the
1383/// family's row-Hessian assembly assertions).
1384pub fn compile_survival_parametric_designs(
1385    time_dq0: Array2<f64>,
1386    time_dq1: Array2<f64>,
1387    time_dqd1: Array2<f64>,
1388    marginal_dq: Array2<f64>,
1389    marginal_dqd1: Array2<f64>,
1390    logslope_dg: Array2<f64>,
1391    row_hess: &dyn RowHessian,
1392) -> Result<SurvivalParametricCompiled, String> {
1393    use gam_identifiability::families::compiler::compile;
1394
1395    let p_time_raw = time_dq0.ncols();
1396    let p_marg_raw = marginal_dq.ncols();
1397    let p_log_raw = logslope_dg.ncols();
1398
1399    let inputs = build_survival_compiler_inputs(
1400        time_dq0,
1401        time_dq1,
1402        time_dqd1,
1403        marginal_dq,
1404        marginal_dqd1,
1405        logslope_dg,
1406        None,
1407        None,
1408    );
1409    if inputs.operators.len() != 3 {
1410        return Err(format!(
1411            "compile_survival_parametric_designs: expected exactly 3 parametric operators \
1412             (time, marginal, logslope); got {}",
1413            inputs.operators.len(),
1414        ));
1415    }
1416    let compiled = compile(&inputs.operators, row_hess, &inputs.ordering)
1417        .map_err(|e| format!("identifiability::families::compiler::compile failed: {e}"))?;
1418    if compiled.blocks.len() != 3 {
1419        return Err(format!(
1420            "compile_survival_parametric_designs: compiler emitted {} blocks; expected 3",
1421            compiled.blocks.len(),
1422        ));
1423    }
1424    let v_time = compiled.blocks[0].t_lw.clone();
1425    let v_marginal = compiled.blocks[1].t_lw.clone();
1426    let v_logslope = compiled.blocks[2].t_lw.clone();
1427    let drops_by_block = (
1428        p_time_raw.saturating_sub(v_time.ncols()),
1429        p_marg_raw.saturating_sub(v_marginal.ncols()),
1430        p_log_raw.saturating_sub(v_logslope.ncols()),
1431    );
1432    Ok(SurvivalParametricCompiled {
1433        v_time,
1434        v_marginal,
1435        v_logslope,
1436        drops_by_block,
1437    })
1438}
1439
1440/// Build the operator stack from already-materialised dense designs.
1441///
1442/// `time_dq0/dq1/dqd1` are the time block's three primary-state Jacobians
1443/// at training rows. `marginal_dq` and `marginal_dqd1` are the marginal
1444/// block's contributions to q (shared between q0 and q1) and to qd1
1445/// (typically zero unless timewiggle interacts). `logslope_dg` is the
1446/// logslope block's contribution to g.
1447///
1448/// `score_warp_(dq, dqd1)` / `link_dev_(dq, dqd1)` are present only when
1449/// the corresponding flex block is active. The returned `ordering` parallels
1450/// `operators` so the caller can route compiled outputs back to runtime slots.
1451pub fn build_survival_compiler_inputs(
1452    time_dq0: Array2<f64>,
1453    time_dq1: Array2<f64>,
1454    time_dqd1: Array2<f64>,
1455    marginal_dq: Array2<f64>,
1456    marginal_dqd1: Array2<f64>,
1457    logslope_dg: Array2<f64>,
1458    score_warp_dq_dqd1: Option<(Array2<f64>, Array2<f64>)>,
1459    link_dev_dq_dqd1: Option<(Array2<f64>, Array2<f64>)>,
1460) -> SurvivalCompilerInputs {
1461    let mut operators: Vec<Arc<dyn RowJacobianOperator>> = Vec::with_capacity(5);
1462    let mut ordering: Vec<BlockOrder> = Vec::with_capacity(5);
1463
1464    operators.push(Arc::new(TimeBlockOperator::new(
1465        time_dq0, time_dq1, time_dqd1,
1466    )));
1467    ordering.push(BlockOrder::Time);
1468
1469    operators.push(Arc::new(QChannelBlockOperator::new(
1470        marginal_dq,
1471        marginal_dqd1,
1472    )));
1473    ordering.push(BlockOrder::Marginal);
1474
1475    operators.push(Arc::new(LogslopeBlockOperator::new(logslope_dg)));
1476    ordering.push(BlockOrder::Logslope);
1477
1478    if let Some((dq, dqd1)) = score_warp_dq_dqd1 {
1479        operators.push(Arc::new(QChannelBlockOperator::new(dq, dqd1)));
1480        ordering.push(BlockOrder::ScoreWarp);
1481    }
1482    if let Some((dq, dqd1)) = link_dev_dq_dqd1 {
1483        operators.push(Arc::new(QChannelBlockOperator::new(dq, dqd1)));
1484        ordering.push(BlockOrder::LinkDev);
1485    }
1486
1487    SurvivalCompilerInputs {
1488        operators,
1489        ordering,
1490    }
1491}
1492
/// V+M-exact compiled designs + per-block penalties for the survival
/// time/marginal/logslope blocks, produced by
/// [`apply_compiled_map_to_designs`] from a `CompiledMap`. The
/// construction site swaps raw designs/penalties for these compiled
/// versions before building `ParameterBlockSpec`s.
///
/// The emitted designs carry the exact residualised `C_b·V_b − A_{<b}·R_b`
/// row form (via [`wrap_design_with_transform`] on `V_b = T[raw_b, comp_b]`):
/// the cross-block residualisation `R_{a→b}` lives in those design columns,
/// while each block's penalty is pulled back through that block's own
/// diagonal `V_b` as `V_bᵀ S_b V_b` (the `*_penalties` fields).
///
/// Once the fit has produced a result, the joint compiled β is lifted
/// back to raw via the `gam_solve::gauge::Gauge` built from the *same*
/// `CompiledMap` (`β_raw = T · θ`, T block-upper-triangular with `V_b` on
/// the diagonal and `-R_{a→b}` off-diagonal). The full T therefore lives
/// on that `Gauge`, not on this struct — the caller holds the
/// `CompiledMap` and constructs both from it, so duplicating T here would
/// be dead state.
///
/// NOTE(review): this file's tests import `Gauge` from `gam_problem`;
/// confirm whether the `gam_solve::gauge::Gauge` path above is current.
pub struct CompiledSurvivalDesignsVMExact {
    /// Compiled counterpart of the raw time-block entry design.
    pub time_design_entry: DesignMatrix,
    /// Compiled counterpart of the raw time-block exit design.
    pub time_design_exit: DesignMatrix,
    /// Compiled counterpart of the raw time-block derivative-at-exit design.
    pub time_design_derivative_exit: DesignMatrix,
    /// Compiled counterpart of the raw marginal-block design.
    pub marginal_design: DesignMatrix,
    /// Compiled counterpart of the raw logslope-block design.
    pub logslope_design: DesignMatrix,
    /// Per-block penalties, each pulled back through that block's OWN
    /// diagonal reparameterisation `V_b` as `V_bᵀ S_b V_b`. The result
    /// is a per-block-width `PenaltyMatrix::Dense`
    /// (`w_b_compiled × w_b_compiled`) — the shape a per-block
    /// `ParameterBlockSpec.penalties` slot requires. Cross-block
    /// residualisation `R_{a→b}` is carried by the residualised design
    /// columns, not the penalty.
    ///
    /// This description applies to all three `*_penalties` fields;
    /// `time_penalties` holds the time block's penalties.
    pub time_penalties: Vec<PenaltyMatrix>,
    /// Marginal-block penalties, pulled back as described on
    /// `time_penalties`.
    pub marginal_penalties: Vec<PenaltyMatrix>,
    /// Logslope-block penalties, pulled back as described on
    /// `time_penalties`.
    pub logslope_penalties: Vec<PenaltyMatrix>,
}
1528
1529#[cfg(test)]
1530mod tests {
1531    use super::*;
1532    use gam_problem::Gauge;
1533
1534    #[test]
1535    fn psd_clamp_zeros_negative_eigenvalues() {
1536        // Construct M = U diag(2, -1, 0.5, -0.25) Uᵀ for a fixed U from
1537        // a small rotation, verify the clamped matrix has eigenvalues
1538        // (2, 0, 0.5, 0).
1539        let mut m = Array2::<f64>::zeros((4, 4));
1540        // Diagonal with mixed signs is sufficient for the test: the
1541        // eigenvalues equal the diagonal and the eigenvectors are e_i.
1542        m[[0, 0]] = 2.0;
1543        m[[1, 1]] = -1.0;
1544        m[[2, 2]] = 0.5;
1545        m[[3, 3]] = -0.25;
1546        let clamped = psd_clamp_4x4(&m);
1547        assert!((clamped[[0, 0]] - 2.0).abs() < 1e-12);
1548        assert!(clamped[[1, 1]].abs() < 1e-12);
1549        assert!((clamped[[2, 2]] - 0.5).abs() < 1e-12);
1550        assert!(clamped[[3, 3]].abs() < 1e-12);
1551    }
1552
1553    #[test]
1554    fn time_block_operator_evaluate_full_shape() {
1555        let n = 6;
1556        let p = 3;
1557        let dq0 = Array2::from_shape_fn((n, p), |(i, j)| (i + j) as f64);
1558        let dq1 = Array2::from_shape_fn((n, p), |(i, j)| (i as f64) * 2.0 + j as f64);
1559        let dqd1 = Array2::from_shape_fn((n, p), |(i, j)| 0.5 * ((i * j) as f64));
1560        let op = TimeBlockOperator::new(dq0.clone(), dq1.clone(), dqd1.clone());
1561        let full = op.evaluate_full();
1562        assert_eq!(full.shape(), &[n, p, K_SURVIVAL]);
1563        for i in 0..n {
1564            for j in 0..p {
1565                assert_eq!(full[[i, j, 0]], dq0[[i, j]]);
1566                assert_eq!(full[[i, j, 1]], dq1[[i, j]]);
1567                assert_eq!(full[[i, j, 2]], dqd1[[i, j]]);
1568                assert_eq!(full[[i, j, 3]], 0.0);
1569            }
1570        }
1571    }
1572
1573    #[test]
1574    fn q_channel_block_apply_row_shares_q0_q1() {
1575        let n = 5;
1576        let p = 2;
1577        let dq = Array2::from_shape_fn((n, p), |(i, j)| (i as f64) * (j as f64 + 1.0));
1578        let dqd1 = Array2::from_shape_fn((n, p), |(i, j)| (j as f64) - (i as f64));
1579        let op = QChannelBlockOperator::new(dq.clone(), dqd1.clone());
1580        let mut out = [0.0_f64; K_SURVIVAL];
1581        let delta = [1.0_f64, -0.5];
1582        op.apply_row(3, &delta, &mut out);
1583        let want_q = dq[[3, 0]] * 1.0 + dq[[3, 1]] * (-0.5);
1584        let want_qd = dqd1[[3, 0]] * 1.0 + dqd1[[3, 1]] * (-0.5);
1585        assert!((out[0] - want_q).abs() < 1e-12);
1586        assert!((out[1] - want_q).abs() < 1e-12);
1587        assert!((out[2] - want_qd).abs() < 1e-12);
1588        assert_eq!(out[3], 0.0);
1589    }
1590
1591    #[test]
1592    fn logslope_block_writes_only_g_channel() {
1593        let n = 4;
1594        let p = 2;
1595        let dg = Array2::from_shape_fn((n, p), |(i, j)| (i as f64) + 0.1 * (j as f64));
1596        let op = LogslopeBlockOperator::new(dg.clone());
1597        let mut out = [0.0_f64; K_SURVIVAL];
1598        let delta = [2.0_f64, -1.0];
1599        op.apply_row(1, &delta, &mut out);
1600        assert_eq!(out[0], 0.0);
1601        assert_eq!(out[1], 0.0);
1602        assert_eq!(out[2], 0.0);
1603        let want = dg[[1, 0]] * 2.0 + dg[[1, 1]] * (-1.0);
1604        assert!((out[3] - want).abs() < 1e-12);
1605    }
1606
1607    #[test]
1608    fn extract_term_partition_simple_cases() {
1609        let full = 0..5usize;
1610        // No penalties: whole block is one term.
1611        let part = extract_term_partition_from_penalty_ranges(5, &[]);
1612        assert_eq!(part.as_slice(), std::slice::from_ref(&full));
1613        // One penalty covering the whole block.
1614        let part = extract_term_partition_from_penalty_ranges(5, std::slice::from_ref(&full));
1615        assert_eq!(part.as_slice(), std::slice::from_ref(&full));
1616        // Two penalties with a gap: produces three terms (pen1, gap, pen2).
1617        let part = extract_term_partition_from_penalty_ranges(10, &[0..3, 6..10]);
1618        assert_eq!(part, vec![0..3, 3..6, 6..10]);
1619        // Duplicate penalty ranges coalesce.
1620        let part = extract_term_partition_from_penalty_ranges(6, &[0..3, 0..3, 3..6]);
1621        assert_eq!(part, vec![0..3, 3..6]);
1622        // Empty block.
1623        let part = extract_term_partition_from_penalty_ranges(0, &[]);
1624        assert!(part.is_empty());
1625    }
1626
1627    #[test]
1628    fn assemble_block_triangular_t_identity_when_v_eye_and_r_none() {
1629        let v_a = Array2::<f64>::eye(2);
1630        let v_b = Array2::<f64>::eye(2);
1631        let t = assemble_block_triangular_t(&[v_a, v_b], &[None, None]);
1632        assert_eq!(t.dim(), (4, 4));
1633        let eye4 = Array2::<f64>::eye(4);
1634        for i in 0..4 {
1635            for j in 0..4 {
1636                assert!((t[[i, j]] - eye4[[i, j]]).abs() < 1e-14);
1637            }
1638        }
1639    }
1640
1641    #[test]
1642    fn assemble_block_triangular_t_with_drops_and_nonzero_r() {
1643        let mut v_a = Array2::<f64>::zeros((3, 2));
1644        v_a[[0, 0]] = 1.0;
1645        v_a[[1, 0]] = 0.5;
1646        v_a[[2, 1]] = 1.0;
1647        let v_b = Array2::<f64>::eye(2);
1648        let r_ab =
1649            Array2::<f64>::from_shape_fn((3, 2), |(i, j)| 1.0 + (i as f64) + 0.25 * (j as f64));
1650        let t =
1651            assemble_block_triangular_t(&[v_a.clone(), v_b.clone()], &[None, Some(r_ab.clone())]);
1652        assert_eq!(t.dim(), (5, 4));
1653        for i in 0..3 {
1654            for j in 0..2 {
1655                assert!((t[[i, j]] - v_a[[i, j]]).abs() < 1e-14);
1656            }
1657        }
1658        for i in 0..2 {
1659            for j in 0..2 {
1660                assert!((t[[3 + i, 2 + j]] - v_b[[i, j]]).abs() < 1e-14);
1661            }
1662        }
1663        for i in 0..3 {
1664            for j in 0..2 {
1665                assert!((t[[i, 2 + j]] + r_ab[[i, j]]).abs() < 1e-14);
1666            }
1667        }
1668        for i in 0..2 {
1669            for j in 0..2 {
1670                assert_eq!(t[[3 + i, j]], 0.0);
1671            }
1672        }
1673    }
1674
1675    #[test]
1676    fn validate_partition_rejects_bad_partitions() {
1677        let bad_start = 1..5usize;
1678        let short_cover = 0..3usize;
1679        let full_cover = 0..5usize;
1680        // Doesn't start at 0.
1681        assert!(validate_partition(std::slice::from_ref(&bad_start), 5, "test").is_err());
1682        // Doesn't cover the block.
1683        assert!(validate_partition(std::slice::from_ref(&short_cover), 5, "test").is_err());
1684        // Has a gap.
1685        assert!(validate_partition(&[0..2, 3..5], 5, "test").is_err());
1686        // Has overlap.
1687        assert!(validate_partition(&[0..3, 2..5], 5, "test").is_err());
1688        // Has empty range.
1689        assert!(validate_partition(&[0..0, 0..5], 5, "test").is_err());
1690        // Empty block + empty partition OK.
1691        assert!(validate_partition(&[], 0, "test").is_ok());
1692        // Valid partition.
1693        assert!(validate_partition(&[0..2, 2..5], 5, "test").is_ok());
1694        assert!(validate_partition(std::slice::from_ref(&full_cover), 5, "test").is_ok());
1695    }
1696
1697    /// Regression for #368: the phase-4b compiled-map penalty pullback must
1698    /// emit a PER-BLOCK-WIDTH penalty for every block (sized to that block's
1699    /// COMPILED design width), even when a block drops columns and the
1700    /// triangular T carries nonzero off-diagonal cross-block residualisation
1701    /// `R_{a→b}`. The original bug pulled penalties back through the full
1702    /// joint T (`Tᵀ S T`), producing joint-compiled-width penalties (e.g.
1703    /// 7×7) that did not fit a single per-block `ParameterBlockSpec.penalties`
1704    /// slot (e.g. time block compiled width 3), making `validate_blockspecs`
1705    /// fail and `assert_valid_blockspecs` panic across the FFI boundary on
1706    /// ordinary survival data.
1707    #[test]
1708    fn compiled_map_penalty_pullback_is_per_block_width_with_nonzero_residual() {
1709        use gam_identifiability::families::compiler::CompiledMap;
1710        use gam_terms::smooth::BlockwisePenalty;
1711
1712        let n = 10;
1713        // Time raw 3 → compiled 3 (block 0: no anchor, V pure, R=None).
1714        // Marginal raw 3 → compiled 2 (a real drop, with nonzero R against time).
1715        // Logslope raw 2 → compiled 2 (nonzero R against time+marginal).
1716        let v_time =
1717            Array2::<f64>::from_shape_fn(
1718                (3, 3),
1719                |(i, j)| {
1720                    if i == j { 1.0 } else { 0.1 * ((i + j) as f64) }
1721                },
1722            );
1723        let v_marg = Array2::<f64>::from_shape_fn((3, 2), |(i, j)| {
1724            0.5 + 0.3 * (i as f64) - 0.2 * (j as f64)
1725        });
1726        let v_log = Array2::<f64>::from_shape_fn((2, 2), |(i, j)| if i == j { 1.2 } else { 0.4 });
1727        // R_marg: rows = time raw width 3, cols = marginal compiled width 2.
1728        let r_marg = Array2::<f64>::from_shape_fn((3, 2), |(i, j)| 0.7 - 0.1 * ((i + j) as f64));
1729        // R_log: rows = time+marg RAW width 6 (3 + 3), cols = logslope compiled
1730        // width 2. `assemble_block_triangular_t` stacks R_{a→b} over a<b, so the row
1731        // count is the sum of the RAW widths of the prior blocks (not their
1732        // compiled widths — marginal's compiled width is 2 but its raw width is 3).
1733        let r_log =
1734            Array2::<f64>::from_shape_fn((6, 2), |(i, j)| 0.3 + 0.05 * ((i * 2 + j) as f64));
1735
1736        let t = assemble_block_triangular_t(
1737            &[v_time.clone(), v_marg.clone(), v_log.clone()],
1738            &[None, Some(r_marg.clone()), Some(r_log.clone())],
1739        );
1740        assert_eq!(t.dim(), (8, 7), "joint raw 8 × joint compiled 7");
1741
1742        let map = CompiledMap {
1743            raw_from_compiled: t.clone(),
1744            compiled_block_ranges: vec![0..3, 3..5, 5..7],
1745            raw_block_ranges: vec![0..3, 3..6, 6..8],
1746        };
1747
1748        // Raw designs (dense, n rows).
1749        let raw_time_entry = DesignMatrix::Dense(DenseDesignMatrix::from(
1750            Array2::<f64>::from_shape_fn((n, 3), |(i, j)| 1.0 + (i as f64) * 0.1 + (j as f64)),
1751        ));
1752        let raw_time_exit = raw_time_entry.clone();
1753        let raw_time_deriv = raw_time_entry.clone();
1754        let raw_marg = DesignMatrix::Dense(DenseDesignMatrix::from(Array2::<f64>::from_shape_fn(
1755            (n, 3),
1756            |(i, j)| 0.2 * (i as f64) - 0.3 * (j as f64),
1757        )));
1758        let raw_log = DesignMatrix::Dense(DenseDesignMatrix::from(Array2::<f64>::from_shape_fn(
1759            (n, 2),
1760            |(i, j)| 0.5 + (i as f64) * (j as f64 + 1.0),
1761        )));
1762
1763        // Block-local penalties (col_range relative to each block's first col).
1764        let s_time =
1765            Array2::<f64>::from_shape_fn(
1766                (3, 3),
1767                |(i, j)| if i == j { (i + 2) as f64 } else { 0.3 },
1768            );
1769        let s_marg =
1770            Array2::<f64>::from_shape_fn(
1771                (3, 3),
1772                |(i, j)| if i == j { 1.5 + i as f64 } else { 0.2 },
1773            );
1774        let s_log = Array2::<f64>::from_shape_fn((2, 2), |(i, j)| if i == j { 2.0 } else { 0.5 });
1775        let time_pens = vec![BlockwisePenalty::new(0..3, s_time.clone())];
1776        let marg_pens = vec![BlockwisePenalty::new(0..3, s_marg.clone())];
1777        let log_pens = vec![BlockwisePenalty::new(0..2, s_log.clone())];
1778
1779        let out = apply_compiled_map_to_designs(
1780            &map,
1781            raw_time_entry,
1782            raw_time_exit,
1783            raw_time_deriv,
1784            raw_marg,
1785            raw_log,
1786            &time_pens,
1787            &marg_pens,
1788            &log_pens,
1789        )
1790        .expect("apply_compiled_map_to_designs must succeed");
1791
1792        // Designs carry per-block compiled widths.
1793        assert_eq!(out.time_design_entry.ncols(), 3);
1794        assert_eq!(out.marginal_design.ncols(), 2);
1795        assert_eq!(out.logslope_design.ncols(), 2);
1796
1797        // Core invariant the bug violated: every penalty is sized to ITS
1798        // OWN block's compiled width, NOT the joint compiled width (7).
1799        for s in &out.time_penalties {
1800            assert_eq!(
1801                s.as_dense_cow().dim(),
1802                (3, 3),
1803                "time penalty must be per-block 3×3, not joint-width"
1804            );
1805        }
1806        for s in &out.marginal_penalties {
1807            assert_eq!(
1808                s.as_dense_cow().dim(),
1809                (2, 2),
1810                "marginal penalty must match reduced compiled width 2, not joint 7"
1811            );
1812        }
1813        for s in &out.logslope_penalties {
1814            assert_eq!(s.as_dense_cow().dim(), (2, 2));
1815        }
1816
1817        // For the time block (block 0, no anchor ⇒ R=None), the per-block
1818        // pullback is EXACT: θ_timeᵀ P_time θ_time == γ_timeᵀ S_time γ_time
1819        // with γ_time = V_time · θ_time. Verify the quadratic-form identity.
1820        let p_time_dense = out.time_penalties[0].as_dense_cow().into_owned();
1821        let theta_time = Array1::<f64>::from_shape_fn(3, |k| 0.4 + 0.7 * (k as f64));
1822        let gamma_time = v_time.dot(&theta_time);
1823        let lhs = theta_time.dot(&p_time_dense.dot(&theta_time));
1824        let rhs = gamma_time.dot(&s_time.dot(&gamma_time));
1825        assert!(
1826            (lhs - rhs).abs() < 1e-10,
1827            "time-block per-block pullback must be exact: lhs={lhs}, rhs={rhs}"
1828        );
1829
1830        // The marginal pullback must equal V_margᵀ S_marg V_marg exactly
1831        // (block-local; the cross-block R_marg lives in the design, not here).
1832        let p_marg_dense = out.marginal_penalties[0].as_dense_cow().into_owned();
1833        let want_marg = v_marg.t().dot(&s_marg.dot(&v_marg));
1834        for i in 0..2 {
1835            for j in 0..2 {
1836                assert!(
1837                    (p_marg_dense[[i, j]] - want_marg[[i, j]]).abs() < 1e-12,
1838                    "marginal penalty must be V_margᵀ S_marg V_marg at ({i},{j})"
1839                );
1840            }
1841        }
1842    }
1843
1844    /// Top-level Phase-4b API test for the SMGS parametric path:
1845    /// call `compile_survival_parametric_designs` on a shared-constant
1846    /// alias between time and marginal, with an identity row Hessian.
1847    /// Verify the returned `v_*` matrices have the expected widths
1848    /// (time keeps all 3, marginal loses 1, logslope keeps both) and
1849    /// `drops_by_block` reports `(0, 1, 0)`.
1850    #[test]
1851    fn compile_survival_parametric_designs_helper_attributes_drop_to_marginal() {
1852        let n = 24;
1853        let p_time = 3;
1854        let p_marginal = 3;
1855        let p_logslope = 2;
1856        let x: Vec<f64> = (0..n)
1857            .map(|i| -1.0 + 2.0 * (i as f64) / (n as f64 - 1.0))
1858            .collect();
1859        let mut time_dq0 = Array2::<f64>::zeros((n, p_time));
1860        let mut time_dq1 = Array2::<f64>::zeros((n, p_time));
1861        let mut time_dqd1 = Array2::<f64>::zeros((n, p_time));
1862        let mut marg_dq = Array2::<f64>::zeros((n, p_marginal));
1863        let marg_dqd1 = Array2::<f64>::zeros((n, p_marginal));
1864        let mut log_dg = Array2::<f64>::zeros((n, p_logslope));
1865        for i in 0..n {
1866            time_dq0[[i, 0]] = 1.0;
1867            time_dq0[[i, 1]] = x[i];
1868            time_dq0[[i, 2]] = x[i] * x[i];
1869            time_dq1[[i, 0]] = 1.0;
1870            time_dq1[[i, 1]] = x[i];
1871            time_dq1[[i, 2]] = x[i] * x[i];
1872            time_dqd1[[i, 0]] = 0.0;
1873            time_dqd1[[i, 1]] = 1.0;
1874            time_dqd1[[i, 2]] = 2.0 * x[i];
1875            marg_dq[[i, 0]] = 1.0; // alias with time col 0
1876            marg_dq[[i, 1]] = x[i] * x[i] * x[i];
1877            marg_dq[[i, 2]] = x[i].sin();
1878            log_dg[[i, 0]] = (2.0 * x[i]).cos();
1879            log_dg[[i, 1]] = x[i].tanh();
1880        }
1881        let mut h_full = Array3::<f64>::zeros((n, K_SURVIVAL, K_SURVIVAL));
1882        for i in 0..n {
1883            for k in 0..K_SURVIVAL {
1884                h_full[[i, k, k]] = 1.0;
1885            }
1886        }
1887        let row_hess = SurvivalRowHessian::from_full(h_full);
1888        let out = compile_survival_parametric_designs(
1889            time_dq0, time_dq1, time_dqd1, marg_dq, marg_dqd1, log_dg, &row_hess,
1890        )
1891        .expect("Phase-4b parametric compile must succeed on single-direction alias");
1892        assert_eq!(out.v_time.ncols(), p_time, "time keeps all columns");
1893        assert_eq!(
1894            out.v_marginal.ncols(),
1895            p_marginal - 1,
1896            "marginal loses exactly the shared-constant direction"
1897        );
1898        assert_eq!(out.v_logslope.ncols(), p_logslope, "logslope is clean");
1899        assert_eq!(
1900            out.drops_by_block,
1901            (0, 1, 0),
1902            "attribution: zero from time/logslope, one from marginal",
1903        );
1904    }
1905
1906    /// End-to-end Phase-4b smoke test: build the full 3-block survival
1907    /// parametric operator stack (time + marginal + logslope) with a
1908    /// shared-constant alias seeded between the time and marginal
1909    /// blocks, feed it into `compile()` with an identity 4×4 row
1910    /// Hessian on every row, and verify the compiler:
1911    ///
1912    ///   (1) returns a [`CompiledBlocks`] with one block per input;
1913    ///   (2) preserves all 3 columns of the highest-priority `Time`
1914    ///       block in `t_lw` (the time block enters first in the
1915    ///       ordering, so its full column span survives);
1916    ///   (3) drops exactly one direction from `Marginal` (the
1917    ///       constant aliased with the time intercept), leaving its
1918    ///       remaining columns in `t_lw`;
1919    ///   (4) reports `joint_rank` = (raw_total - 1).
1920    ///
1921    /// This validates the Phase-4b construction-time orthogonalisation
1922    /// path on the survival K=4 row primary state and then feeds the
1923    /// compiled per-block reduced bases through the SMGS lift [`Gauge`]
1924    /// (step 6), asserting the lift's reduced/raw block structure agrees
1925    /// with the compiled rank-drop — the construction contract end to end.
1926    #[test]
1927    fn compile_survival_three_block_with_shared_constant_drops_one_direction() {
1928        use gam_identifiability::families::compiler::compile;
1929
1930        let n = 32;
1931        let p_time = 3;
1932        let p_marginal = 3;
1933        let p_logslope = 2;
1934
1935        // Time block:
1936        //   col 0 = ones (the shared constant — aliases marginal col 0);
1937        //   col 1 = linear x;
1938        //   col 2 = quadratic x².
1939        // q0/q1 share the same design (so the alias surfaces in both
1940        // the entry and exit primary channels); qd1 is the derivative
1941        // of the design w.r.t. time at the exit point, which for the
1942        // constant column is exactly zero (the gauge identity that
1943        // makes the constant a true null direction under (q0, q1, qd1)
1944        // joint).
1945        let x: Vec<f64> = (0..n)
1946            .map(|i| -1.0 + 2.0 * (i as f64) / (n as f64 - 1.0))
1947            .collect();
1948        let mut time_dq0 = Array2::<f64>::zeros((n, p_time));
1949        let mut time_dq1 = Array2::<f64>::zeros((n, p_time));
1950        let mut time_dqd1 = Array2::<f64>::zeros((n, p_time));
1951        for i in 0..n {
1952            time_dq0[[i, 0]] = 1.0;
1953            time_dq0[[i, 1]] = x[i];
1954            time_dq0[[i, 2]] = x[i] * x[i];
1955            time_dq1[[i, 0]] = 1.0;
1956            time_dq1[[i, 1]] = x[i];
1957            time_dq1[[i, 2]] = x[i] * x[i];
1958            // d/dt of a constant = 0; d/dt of x ≡ 1; d/dt of x² ≡ 2x.
1959            time_dqd1[[i, 0]] = 0.0;
1960            time_dqd1[[i, 1]] = 1.0;
1961            time_dqd1[[i, 2]] = 2.0 * x[i];
1962        }
1963
1964        // Marginal block (q-channel only; qd1 contribution zero — no
1965        // timewiggle in this scenario):
1966        //   col 0 = ones (the shared constant);
1967        //   col 1 = x³;
1968        //   col 2 = sin(x).
1969        let mut marg_dq = Array2::<f64>::zeros((n, p_marginal));
1970        let marg_dqd1 = Array2::<f64>::zeros((n, p_marginal));
1971        for i in 0..n {
1972            marg_dq[[i, 0]] = 1.0;
1973            marg_dq[[i, 1]] = x[i] * x[i] * x[i];
1974            marg_dq[[i, 2]] = x[i].sin();
1975        }
1976
1977        // Logslope block (g-channel only):
1978        //   col 0 = cos(2x);
1979        //   col 1 = tanh(x).  (no shared constant — logslope is clean)
1980        let mut log_dg = Array2::<f64>::zeros((n, p_logslope));
1981        for i in 0..n {
1982            log_dg[[i, 0]] = (2.0 * x[i]).cos();
1983            log_dg[[i, 1]] = x[i].tanh();
1984        }
1985
1986        let inputs = build_survival_compiler_inputs(
1987            time_dq0, time_dq1, time_dqd1, marg_dq, marg_dqd1, log_dg, None, None,
1988        );
1989
1990        // Identity 4×4 row Hessian on every row. With H_i = I the
1991        // sqrt-H metric collapses to the standard Frobenius metric,
1992        // so the compiler's residualisation is ordinary least-squares
1993        // projection — exactly what we want for verifying the
1994        // structural rank-deficiency attribution.
1995        let mut h_full = Array3::<f64>::zeros((n, K_SURVIVAL, K_SURVIVAL));
1996        for i in 0..n {
1997            for k in 0..K_SURVIVAL {
1998                h_full[[i, k, k]] = 1.0;
1999            }
2000        }
2001        let row_hess = SurvivalRowHessian::from_full(h_full);
2002
2003        let compiled = compile(&inputs.operators, &row_hess, &inputs.ordering)
2004            .expect("survival 3-block compile must succeed; aliasing is single-direction");
2005
2006        // (1) One CompiledBlock per input.
2007        assert_eq!(compiled.blocks.len(), 3, "expected 3 CompiledBlocks");
2008
2009        // (2) Time enters first; under sqrt-I metric every column of
2010        // the time block is residual-vs-empty-anchor and therefore
2011        // survives the eigendecomposition with positive eigenvalue.
2012        // V_time has p_time columns.
2013        let v_time = &compiled.blocks[0].t_lw;
2014        assert_eq!(
2015            v_time.ncols(),
2016            p_time,
2017            "time block (first in ordering) must retain all {p_time} of its columns; V_time={:?}",
2018            v_time.dim(),
2019        );
2020
2021        // (3) Marginal enters second. Its constant column is aliased
2022        // with time's constant column in (q0, q1) and contributes zero
2023        // to qd1. After residualising against the time anchor in the
2024        // K=4 stacked metric, the residual Gram has rank
2025        // p_marginal − 1 (one direction collapsed by the alias). So
2026        // V_marginal has exactly (p_marginal − 1) columns.
2027        let v_marg = &compiled.blocks[1].t_lw;
2028        assert_eq!(
2029            v_marg.ncols(),
2030            p_marginal - 1,
2031            "marginal block must lose exactly the shared-constant direction; \
2032             V_marginal cols = {}, expected {}",
2033            v_marg.ncols(),
2034            p_marginal - 1,
2035        );
2036
2037        // (4) Logslope enters third and carries no shared direction
2038        // with time or marginal in the g-channel. Both columns survive.
2039        let v_log = &compiled.blocks[2].t_lw;
2040        assert_eq!(
2041            v_log.ncols(),
2042            p_logslope,
2043            "logslope block (no shared direction) must retain all {p_logslope} columns",
2044        );
2045
2046        // (5) Joint rank consistency: sum of compiled column counts
2047        // equals raw_total minus the one aliased direction.
2048        let raw_total = p_time + p_marginal + p_logslope;
2049        let kept_total: usize = compiled.blocks.iter().map(|b| b.t_lw.ncols()).sum();
2050        assert_eq!(
2051            kept_total,
2052            raw_total - 1,
2053            "joint kept = raw_total − aliased; got {kept_total}, expected {}",
2054            raw_total - 1,
2055        );
2056        assert_eq!(
2057            compiled.joint_rank, kept_total,
2058            "CompiledBlocks::joint_rank must match the sum of per-block t_lw widths",
2059        );
2060
2061        // (6) SMGS construction contract. Feed the compiled per-block reduced
2062        // bases (V_k = t_lw, shaped raw_k × kept_k) into the SMGS lift `Gauge`
2063        // and verify the lift's coordinate bookkeeping matches the compiler's
2064        // rank attribution: the reduced dimension equals `joint_rank`, the
2065        // reduced block boundaries advance by each block's kept width, and —
2066        // with R = None (no residualised cross-block reparam in this V-only
2067        // construction) — the raw block boundaries advance by each block's raw
2068        // width. This exercises the SMGS construction hook directly on the
2069        // compiled output rather than asserting against a hypothetical shape.
2070        let v_per_term: Vec<Array2<f64>> = compiled.blocks.iter().map(|b| b.t_lw.clone()).collect();
2071        let r_per_term: Vec<Option<Array2<f64>>> = vec![None; v_per_term.len()];
2072        let gauge = Gauge::from_v_and_r(&v_per_term, &r_per_term);
2073
2074        let mut expected_reduced = vec![0usize];
2075        let mut expected_raw = vec![0usize];
2076        for b in &compiled.blocks {
2077            let prev_reduced = *expected_reduced.last().unwrap();
2078            expected_reduced.push(prev_reduced + b.t_lw.ncols());
2079            let prev_raw = *expected_raw.last().unwrap();
2080            expected_raw.push(prev_raw + b.t_lw.nrows());
2081        }
2082        assert_eq!(
2083            *gauge.block_starts_reduced.last().unwrap(),
2084            compiled.joint_rank,
2085            "SMGS lift reduced dimension must equal the compiled joint_rank",
2086        );
2087        assert_eq!(
2088            gauge.block_starts_reduced, expected_reduced,
2089            "SMGS lift reduced block boundaries must match the compiled kept widths",
2090        );
2091        assert_eq!(
2092            gauge.block_starts_raw, expected_raw,
2093            "SMGS lift raw block boundaries must match the compiled per-block raw widths",
2094        );
2095
2096        // (7) Every kept direction is finite and non-degenerate. A retained
2097        // column with a zero or non-finite norm would be a spurious rank
2098        // contribution that the count-only checks above cannot catch, so verify
2099        // each compiled block's surviving directions directly.
2100        for (bi, block) in compiled.blocks.iter().enumerate() {
2101            for j in 0..block.t_lw.ncols() {
2102                let col = block.t_lw.column(j);
2103                assert!(
2104                    col.iter().all(|v| v.is_finite()),
2105                    "block {bi} kept direction {j} has a non-finite entry",
2106                );
2107                let norm = col.dot(&col).sqrt();
2108                assert!(
2109                    norm > 1e-10,
2110                    "block {bi} kept direction {j} is degenerate (norm {norm:.3e})",
2111                );
2112            }
2113        }
2114    }
2115
2116    /// `T = I` case: per-block V = identity, R = None. The triangular
2117    /// lift must be the identity on each block.
2118    #[test]
2119    fn smgs_lift_via_t_identity_passes_through() {
2120        let v0 = Array2::<f64>::eye(3);
2121        let v1 = Array2::<f64>::eye(2);
2122        let v_per_term = vec![v0, v1];
2123        let r_per_term: Vec<Option<Array2<f64>>> = vec![None, None];
2124        let lift = Gauge::from_v_and_r(&v_per_term, &r_per_term);
2125        assert_eq!(lift.t_full.dim(), (5, 5));
2126        assert_eq!(lift.block_starts_reduced, vec![0, 3, 5]);
2127        assert_eq!(lift.block_starts_raw, vec![0, 3, 5]);
2128        for i in 0..5 {
2129            for j in 0..5 {
2130                let want = if i == j { 1.0 } else { 0.0 };
2131                assert!((lift.t_full[[i, j]] - want).abs() < 1e-14);
2132            }
2133        }
2134        let theta_0 = Array1::from(vec![1.0_f64, -2.0, 3.5]);
2135        let theta_1 = Array1::from(vec![-0.5_f64, 7.0]);
2136        let lifted = lift.lift_block_betas(&[theta_0.clone(), theta_1.clone()]);
2137        assert_eq!(lifted.len(), 2);
2138        for (a, b) in theta_0.iter().zip(lifted[0].iter()) {
2139            assert!((a - b).abs() < 1e-14);
2140        }
2141        for (a, b) in theta_1.iter().zip(lifted[1].iter()) {
2142            assert!((a - b).abs() < 1e-14);
2143        }
2144    }
2145
2146    /// Two-block toy: V_a = I_3, V_b drops the middle column, R is a
2147    /// non-trivial residualised reparam. Verify β_a_raw = θ_a − R · θ_b
2148    /// and β_b_raw = V_b · θ_b.
2149    #[test]
2150    fn smgs_lift_via_t_two_block_with_residualisation() {
2151        let v_a = Array2::<f64>::eye(3);
2152        let mut v_b = Array2::<f64>::zeros((3, 2));
2153        v_b[[0, 0]] = 1.0;
2154        v_b[[2, 1]] = 1.0;
2155        let mut r_b = Array2::<f64>::zeros((3, 2));
2156        r_b[[0, 0]] = 0.4;
2157        r_b[[0, 1]] = -0.1;
2158        r_b[[1, 0]] = 0.7;
2159        r_b[[1, 1]] = 1.3;
2160        r_b[[2, 0]] = -0.2;
2161        r_b[[2, 1]] = 0.5;
2162        let lift = Gauge::from_v_and_r(&[v_a.clone(), v_b.clone()], &[None, Some(r_b.clone())]);
2163        assert_eq!(lift.t_full.dim(), (6, 5));
2164        assert_eq!(lift.block_starts_reduced, vec![0, 3, 5]);
2165        assert_eq!(lift.block_starts_raw, vec![0, 3, 6]);
2166
2167        let theta_a = Array1::from(vec![1.0_f64, 2.0, -1.5]);
2168        let theta_b = Array1::from(vec![0.5_f64, -0.25]);
2169        let lifted = lift.lift_block_betas(&[theta_a.clone(), theta_b.clone()]);
2170        let r_theta_b = r_b.dot(&theta_b);
2171        let expected_a = &theta_a - &r_theta_b;
2172        assert_eq!(lifted[0].len(), 3);
2173        for (got, want) in lifted[0].iter().zip(expected_a.iter()) {
2174            assert!((got - want).abs() < 1e-12, "got {got}, want {want}");
2175        }
2176        assert_eq!(lifted[1].len(), 3);
2177        assert!((lifted[1][0] - theta_b[0]).abs() < 1e-12);
2178        assert!(lifted[1][1].abs() < 1e-12);
2179        assert!((lifted[1][2] - theta_b[1]).abs() < 1e-12);
2180    }
2181
2182    /// Covariance pushforward `Σ_raw = T · Σ_θ · Tᵀ` must be the exact
2183    /// inference companion of the point-estimate lift. Two invariants:
2184    ///
2185    /// 1. Identity T (V = I, R = None): the lifted covariance equals the
2186    ///    input covariance — a true no-op for a rank-clean fit.
2187    /// 2. Rank-1 consistency with the β lift: for a degenerate posterior
2188    ///    `Σ_θ = θ θᵀ`, the pushforward must equal `(T θ)(T θ)ᵀ`, i.e.
2189    ///    lifting the covariance of a point mass agrees with lifting the
2190    ///    point itself. This couples `lift_covariance` to
2191    ///    `lift_block_betas` exactly, so the mean and its
2192    ///    uncertainty can never drift into inconsistent coordinates.
2193    #[test]
2194    fn smgs_lift_covariance_identity_and_rank1_consistency() {
2195        // ── Invariant 1: identity T leaves the covariance unchanged. ──
2196        let lift_id = Gauge::from_v_and_r(
2197            &[Array2::<f64>::eye(2), Array2::<f64>::eye(2)],
2198            &[None, None],
2199        );
2200        let mut cov = Array2::<f64>::zeros((4, 4));
2201        // An arbitrary symmetric PSD-ish covariance.
2202        for i in 0..4 {
2203            for j in 0..4 {
2204                cov[[i, j]] = 1.0 / (1.0 + (i as f64 - j as f64).abs());
2205            }
2206        }
2207        let lifted_id = lift_id.lift_covariance(&cov);
2208        assert_eq!(lifted_id.dim(), (4, 4));
2209        for i in 0..4 {
2210            for j in 0..4 {
2211                assert!(
2212                    (lifted_id[[i, j]] - cov[[i, j]]).abs() < 1e-12,
2213                    "identity-T covariance lift must be a no-op at [{i},{j}]",
2214                );
2215            }
2216        }
2217
2218        // ── Invariant 2: rank-1 Σ_θ = θθᵀ pushes to (Tθ)(Tθ)ᵀ. ──
2219        // Reuse the two-block-with-residualisation geometry: V_a = I_3,
2220        // V_b drops the middle raw column, R_b non-trivial → raw width 6,
2221        // compiled width 5.
2222        let v_a = Array2::<f64>::eye(3);
2223        let mut v_b = Array2::<f64>::zeros((3, 2));
2224        v_b[[0, 0]] = 1.0;
2225        v_b[[2, 1]] = 1.0;
2226        let mut r_b = Array2::<f64>::zeros((3, 2));
2227        r_b[[0, 0]] = 0.4;
2228        r_b[[0, 1]] = -0.1;
2229        r_b[[1, 0]] = 0.7;
2230        r_b[[1, 1]] = 1.3;
2231        r_b[[2, 0]] = -0.2;
2232        r_b[[2, 1]] = 0.5;
2233        let lift = Gauge::from_v_and_r(&[v_a, v_b], &[None, Some(r_b)]);
2234
2235        let theta_a = Array1::from(vec![1.0_f64, 2.0, -1.5]);
2236        let theta_b = Array1::from(vec![0.5_f64, -0.25]);
2237        // Concatenated compiled θ (width 5).
2238        let theta_full = Array1::from(vec![
2239            theta_a[0], theta_a[1], theta_a[2], theta_b[0], theta_b[1],
2240        ]);
2241        // Σ_θ = θ θᵀ (rank-1).
2242        let mut cov_rank1 = Array2::<f64>::zeros((5, 5));
2243        for i in 0..5 {
2244            for j in 0..5 {
2245                cov_rank1[[i, j]] = theta_full[i] * theta_full[j];
2246            }
2247        }
2248        let lifted_cov = lift.lift_covariance(&cov_rank1);
2249        // Reference: (T θ)(T θ)ᵀ via the point-estimate lift.
2250        let lifted_blocks = lift.lift_block_betas(&[theta_a, theta_b]);
2251        let beta_raw = Array1::from(
2252            lifted_blocks
2253                .iter()
2254                .flat_map(|b| b.iter().copied())
2255                .collect::<Vec<f64>>(),
2256        );
2257        assert_eq!(lifted_cov.dim(), (6, 6));
2258        assert_eq!(beta_raw.len(), 6);
2259        for i in 0..6 {
2260            for j in 0..6 {
2261                let want = beta_raw[i] * beta_raw[j];
2262                assert!(
2263                    (lifted_cov[[i, j]] - want).abs() < 1e-10,
2264                    "rank-1 covariance pushforward must equal (Tθ)(Tθ)ᵀ at [{i},{j}]: got {}, want {want}",
2265                    lifted_cov[[i, j]],
2266                );
2267            }
2268        }
2269        // Symmetry sanity.
2270        for i in 0..6 {
2271            for j in 0..6 {
2272                assert!((lifted_cov[[i, j]] - lifted_cov[[j, i]]).abs() < 1e-14);
2273            }
2274        }
2275    }
2276
2277    /// When all R's are None, the triangular gauge lift must equal the
2278    /// strictly per-block `V_b · θ_b` lift.
2279    #[test]
2280    fn smgs_lift_via_t_zero_r_matches_per_block_v_lift() {
2281        let mut v_a = Array2::<f64>::zeros((3, 2));
2282        v_a[[0, 0]] = 0.6;
2283        v_a[[1, 0]] = -0.8;
2284        v_a[[1, 1]] = 0.3;
2285        v_a[[2, 1]] = 0.9;
2286        let mut v_b = Array2::<f64>::zeros((4, 3));
2287        v_b[[0, 0]] = 1.0;
2288        v_b[[1, 1]] = -0.4;
2289        v_b[[2, 0]] = 0.2;
2290        v_b[[2, 2]] = 0.7;
2291        v_b[[3, 2]] = -1.1;
2292        let v_per_term = vec![v_a.clone(), v_b.clone()];
2293        let lift = Gauge::from_v_and_r(&v_per_term, &[None, None]);
2294        let theta_a = Array1::from(vec![0.3_f64, -1.4]);
2295        let theta_b = Array1::from(vec![2.1_f64, 0.0, -0.7]);
2296        let via_t = lift.lift_block_betas(&[theta_a.clone(), theta_b.clone()]);
2297        let ref_a = v_a.dot(&theta_a);
2298        let ref_b = v_b.dot(&theta_b);
2299        assert_eq!(via_t[0].len(), ref_a.len());
2300        for (g, w) in via_t[0].iter().zip(ref_a.iter()) {
2301            assert!((g - w).abs() < 1e-12);
2302        }
2303        assert_eq!(via_t[1].len(), ref_b.len());
2304        for (g, w) in via_t[1].iter().zip(ref_b.iter()) {
2305            assert!((g - w).abs() < 1e-12);
2306        }
2307    }
2308
2309    /// Recompile-after-first-PIRLS-accept refinement: under a structural
2310    /// (identity) row Hessian, a direction that is *only* identifiable
2311    /// through the q1 channel survives the per-term compile; under a
2312    /// data-adaptive row Hessian that happens to zero out the q1/qd1/g
2313    /// metric weight (everything except q0), the same direction collapses.
2314    /// This pins the diagnostic the production hook in
2315    /// `fit_survival_marginal_slope_terms` watches for: the two row
2316    /// Hessians produce different `drops_by_block` on identical raw
2317    /// designs.
2318    #[test]
2319    fn recompile_after_accept_diff_detection_pilot_curvature_trap() {
2320        let n = 6usize;
2321        // Time block: a single column that only contributes through q0
2322        // (entry-time channel). Both row Hessians see it identically on
2323        // the q0 axis.
2324        let time_dq0 = Array2::<f64>::from_elem((n, 1), 1.0);
2325        let time_dq1 = Array2::<f64>::zeros((n, 1));
2326        let time_dqd1 = Array2::<f64>::zeros((n, 1));
2327        // Marginal block: a single column whose q0 part is colinear with
2328        // the time block's q0 (both are ones-vectors). Its q-channel maps
2329        // into BOTH q0 and q1 under QChannelBlockOperator, so under a
2330        // metric that weighs q1 it carries a non-colinear component.
2331        let marg_dq = Array2::<f64>::from_elem((n, 1), 1.0);
2332        let marg_dqd1 = Array2::<f64>::zeros((n, 1));
2333        // No logslope columns.
2334        let log_dg = Array2::<f64>::zeros((n, 0));
2335        let mut time_partition: Vec<std::ops::Range<usize>> = Vec::with_capacity(1);
2336        time_partition.push(0..1);
2337        let mut marg_partition: Vec<std::ops::Range<usize>> = Vec::with_capacity(1);
2338        marg_partition.push(0..1);
2339        let log_partition: Vec<std::ops::Range<usize>> = Vec::new();
2340
2341        // Pass 1: structural identity row Hessian. q0/q1/qd1/g all weighted
2342        // equally → marg's q1 component is visible, so marg is identifiable
2343        // after residualising against the time block (drops_marg = 0).
2344        let mut h_ident = Array3::<f64>::zeros((n, K_SURVIVAL, K_SURVIVAL));
2345        for i in 0..n {
2346            for k in 0..K_SURVIVAL {
2347                h_ident[[i, k, k]] = 1.0;
2348            }
2349        }
2350        let row_hess_ident = SurvivalRowHessian::from_full(h_ident);
2351        let compiled_ident = compile_survival_parametric_designs_per_term(
2352            time_dq0.clone(),
2353            time_dq1.clone(),
2354            time_dqd1.clone(),
2355            &time_partition,
2356            marg_dq.clone(),
2357            marg_dqd1.clone(),
2358            &marg_partition,
2359            log_dg.clone(),
2360            &log_partition,
2361            &row_hess_ident,
2362            false,
2363        )
2364        .expect("identity-H compile must succeed");
2365
2366        // Pass 2: data-adaptive row Hessian that only weighs q0 (all
2367        // other channel diagonals zero). Marg's q1 contribution is now
2368        // invisible → marg fully aliases with time on q0 → drops_marg = 1.
2369        let mut h_q0_only = Array3::<f64>::zeros((n, K_SURVIVAL, K_SURVIVAL));
2370        for i in 0..n {
2371            h_q0_only[[i, 0, 0]] = 1.0;
2372        }
2373        let row_hess_q0 = SurvivalRowHessian::from_full(h_q0_only);
2374        let compiled_q0 = compile_survival_parametric_designs_per_term(
2375            time_dq0,
2376            time_dq1,
2377            time_dqd1,
2378            &time_partition,
2379            marg_dq,
2380            marg_dqd1,
2381            &marg_partition,
2382            log_dg,
2383            &log_partition,
2384            &row_hess_q0,
2385            false,
2386        )
2387        .expect("q0-only-H compile must succeed");
2388
2389        // The two drops_by_block tuples disagree on the marginal block —
2390        // this is exactly the "pilot-curvature trap" the recompile-after-
2391        // accept hook is designed to surface.
2392        assert_ne!(
2393            compiled_ident.drops_by_block, compiled_q0.drops_by_block,
2394            "structural-H and data-adaptive-H compiles must produce different \
2395             drops_by_block on the constructed pilot-curvature-trap design; \
2396             identity={:?} q0-only={:?}",
2397            compiled_ident.drops_by_block, compiled_q0.drops_by_block,
2398        );
2399        // Under identity H, marg survives (no drop).
2400        assert_eq!(
2401            compiled_ident.drops_by_block.1, 0,
2402            "identity-H marg drops expected 0, got {:?}",
2403            compiled_ident.drops_by_block,
2404        );
2405        // Under q0-only H, marg fully aliases with time on q0.
2406        assert_eq!(
2407            compiled_q0.drops_by_block.1, 1,
2408            "q0-only-H marg drops expected 1, got {:?}",
2409            compiled_q0.drops_by_block,
2410        );
2411    }
2412
2413    #[test]
2414    fn compiled_map_from_per_term_partitions_and_lift_round_trip() {
2415        // Build a per-term compile by hand: time has one term (raw 2, kept 2),
2416        // marginal one term (raw 2, kept 1 — a drop), logslope one term
2417        // (raw 1, kept 1). No required channel is fully collapsed.
2418        let v_time = Array2::<f64>::eye(2);
2419        let mut v_marg = Array2::<f64>::zeros((2, 1));
2420        v_marg[[0, 0]] = 1.0;
2421        v_marg[[1, 0]] = 0.5;
2422        let v_log = Array2::<f64>::eye(1);
2423        // R for the marginal block (anchor = time, raw width 2) and logslope
2424        // block (anchors = time + marginal, raw width 2 + 2 = 4).
2425        let r_marg = Array2::<f64>::from_shape_fn((2, 1), |(i, _)| 0.25 + i as f64);
2426        let r_log = Array2::<f64>::from_shape_fn((4, 1), |(i, _)| 0.1 * (i as f64 + 1.0));
2427        let per_term = SurvivalParametricCompiledPerTerm {
2428            v_time_per_term: vec![v_time.clone()],
2429            v_marginal_per_term: vec![v_marg.clone()],
2430            v_logslope_per_term: vec![v_log.clone()],
2431            r_lw_per_term: vec![None, Some(r_marg.clone()), Some(r_log.clone())],
2432            drops_by_block: (0, 1, 0),
2433        };
2434
2435        let map = compiled_map_from_per_term(&per_term);
2436
2437        // Raw block ranges: time 0..2, marginal 2..4, logslope 4..5.
2438        assert_eq!(map.raw_block_ranges, vec![0..2, 2..4, 4..5]);
2439        // Compiled block ranges: time 0..2, marginal 2..3, logslope 3..4.
2440        assert_eq!(map.compiled_block_ranges, vec![0..2, 2..3, 3..4]);
2441        assert_eq!(map.raw_from_compiled.dim(), (5, 4));
2442
2443        // The block-diagonal slices recovered by apply_compiled_map_to_designs
2444        // must equal the per-term V's exactly.
2445        let v_time_slice = map
2446            .raw_from_compiled
2447            .slice(ndarray::s![0..2, 0..2])
2448            .to_owned();
2449        let v_marg_slice = map
2450            .raw_from_compiled
2451            .slice(ndarray::s![2..4, 2..3])
2452            .to_owned();
2453        let v_log_slice = map
2454            .raw_from_compiled
2455            .slice(ndarray::s![4..5, 3..4])
2456            .to_owned();
2457        for i in 0..2 {
2458            for j in 0..2 {
2459                assert!((v_time_slice[[i, j]] - v_time[[i, j]]).abs() < 1e-14);
2460            }
2461            assert!((v_marg_slice[[i, 0]] - v_marg[[i, 0]]).abs() < 1e-14);
2462        }
2463        assert!((v_log_slice[[0, 0]] - v_log[[0, 0]]).abs() < 1e-14);
2464
2465        // The cross-block carry (-R) must sit in the strict upper triangle, so
2466        // the map agrees with the lift assembled directly from V and R.
2467        let ordering = [
2468            gam_identifiability::families::compiler::BlockOrder::Time,
2469            gam_identifiability::families::compiler::BlockOrder::Marginal,
2470            gam_identifiability::families::compiler::BlockOrder::Logslope,
2471        ];
2472        let lift_from_map = Gauge::from_compiled_map(&map, &ordering);
2473        let v_all = vec![v_time, v_marg, v_log];
2474        let lift_direct = Gauge::from_v_and_r(&v_all, &[None, Some(r_marg), Some(r_log)]);
2475        assert_eq!(lift_from_map.t_full.dim(), lift_direct.t_full.dim());
2476        for i in 0..lift_from_map.t_full.nrows() {
2477            for j in 0..lift_from_map.t_full.ncols() {
2478                assert!(
2479                    (lift_from_map.t_full[[i, j]] - lift_direct.t_full[[i, j]]).abs() < 1e-14,
2480                    "T mismatch at ({i},{j}): map={} direct={}",
2481                    lift_from_map.t_full[[i, j]],
2482                    lift_direct.t_full[[i, j]],
2483                );
2484            }
2485        }
2486    }
2487
2488    // ----- #979 effective reduced-logslope confound removal -----------------
2489    //
2490    // Direct unit coverage of the two numerical routines added for #979,
2491    // mirroring the BMS reference cuts
2492    // (`bms::block_specs` `effective_reduction_*`): the scalar weight
2493    // contraction off the per-row 4×4 Hessian, and the block-diagonal map
2494    // assembly. The 900s end-to-end `survival_marginal_slope_converges_*`
2495    // guard exercises the same path but is slow and data-dependent; these pin
2496    // the distinguishing logic deterministically.
2497
2498    /// Constant per-row 4×4 PSD Hessian carrying ONLY the (q0, g) coupling,
2499    /// channel order (q0, q1, qd1, g): `H[0,0]=h00`, `H[0,3]=H[3,0]=h03`,
2500    /// `H[3,3]=h33`, all else zero. The effective scalar weights the
2501    /// contraction reads are then `w_mm=h00`, `w_mg=h03`, `w_gg=h33`. The
2502    /// 2×2 (q0,g) block `[[h00,h03],[h03,h33]]` is PSD when `h00·h33 ≥ h03²`.
2503    fn const_row_hess_q0g(n: usize, h00: f64, h03: f64, h33: f64) -> SurvivalRowHessian {
2504        let mut h = Array3::<f64>::zeros((n, K_SURVIVAL, K_SURVIVAL));
2505        for i in 0..n {
2506            h[[i, 0, 0]] = h00;
2507            h[[i, 0, 3]] = h03;
2508            h[[i, 3, 0]] = h03;
2509            h[[i, 3, 3]] = h33;
2510        }
2511        SurvivalRowHessian::from_full(h)
2512    }
2513
2514    #[test]
2515    fn survival_reduced_logslope_drops_confounded_keeps_free_979() {
2516        // p_m=1 marginal column m; p_log=2 logslope columns [l1, l2] with
2517        // l1 == m (an exact rank-1 (q0,g) confound: h00·h33 = h03²) so l1 is
2518        // fully marginal-explained, and l2 ⊥ m with 100× the energy so it is
2519        // unambiguously free. The effective Schur Gram must drop ONLY the
2520        // confounded direction: 0 < r == 1 < p_log == 2.
2521        let n = 4;
2522        let row_hess = const_row_hess_q0g(n, 2.0, 2.0, 2.0); // (q0,g) = [[2,2],[2,2]], rank-1
2523        let marg = Array2::from_shape_vec((n, 1), vec![1.0, 1.0, 1.0, 1.0]).unwrap();
2524        // l1 = m (confounded); l2 = [10,-10,10,-10] (Euclidean-orthogonal to m,
2525        // ‖l2‖² = 400 ≫ ‖m‖² = 4, so Gtt's free eigenvalue ≫ tol).
2526        let log =
2527            Array2::from_shape_vec((n, 2), vec![1.0, 10.0, 1.0, -10.0, 1.0, 10.0, 1.0, -10.0])
2528                .unwrap();
2529        let t = survival_reduced_logslope_transform_effective(marg.view(), log.view(), &row_hess)
2530            .expect("contraction must succeed")
2531            .expect("a partial confound must yield a reduced transform");
2532        assert_eq!(t.dim(), (2, 1), "exactly one logslope direction survives");
2533        // The kept eigenvector is the free column ≈ e2 (up to sign); the
2534        // confounded e1 component is dropped.
2535        assert!(
2536            t[[0, 0]].abs() < 1e-6,
2537            "confounded (e1) direction must be dropped, got {}",
2538            t[[0, 0]]
2539        );
2540        assert!(
2541            (t[[1, 0]].abs() - 1.0).abs() < 1e-6,
2542            "free (e2) direction must be kept as a unit vector, got {}",
2543            t[[1, 0]]
2544        );
2545    }
2546
2547    #[test]
2548    fn survival_reduced_logslope_fully_confounded_returns_none_979() {
2549        // A single logslope column equal to the marginal column under the exact
2550        // rank-1 (q0,g) confound: the whole effective logslope image lies in the
2551        // marginal span. The conservative ridge floors the residual eigenvalue at
2552        // energy_scale·TOL/(1+TOL) < tol, so r == 0 → Ok(None) (keep the raw
2553        // design, defer to the measured-phantom gate).
2554        let n = 4;
2555        let row_hess = const_row_hess_q0g(n, 2.0, 2.0, 2.0);
2556        let marg = Array2::from_shape_vec((n, 1), vec![1.0, 1.0, 1.0, 1.0]).unwrap();
2557        let log = marg.clone();
2558        let out =
2559            survival_reduced_logslope_transform_effective(marg.view(), log.view(), &row_hess)
2560                .expect("contraction must succeed");
2561        assert!(
2562            out.is_none(),
2563            "a fully marginal-explained logslope column reduces to nothing → keep raw"
2564        );
2565    }
2566
2567    #[test]
2568    fn survival_reduced_logslope_no_confound_returns_none_979() {
2569        // No marginal↔logslope cross weight (h03 = 0): the channels are
2570        // W-orthogonal, so every logslope direction is free (r == p_log) and
2571        // there is nothing to remove → Ok(None).
2572        let n = 4;
2573        let row_hess = const_row_hess_q0g(n, 2.0, 0.0, 2.0);
2574        let marg = Array2::from_shape_vec((n, 1), vec![1.0, 1.0, 1.0, 1.0]).unwrap();
2575        let log =
2576            Array2::from_shape_vec((n, 2), vec![1.0, 10.0, 1.0, -10.0, 1.0, 10.0, 1.0, -10.0])
2577                .unwrap();
2578        let out =
2579            survival_reduced_logslope_transform_effective(marg.view(), log.view(), &row_hess)
2580                .expect("contraction must succeed");
2581        assert!(out.is_none(), "W-orthogonal channels need no reduction → keep raw");
2582    }
2583
2584    #[test]
2585    fn survival_block_diagonal_logslope_map_is_identity_on_time_and_marginal_979() {
2586        // Time (p=2) and marginal (p=3) blocks pass through as identities; only
2587        // the logslope block (raw p_log=4) is reparameterised by t_log (4×2).
2588        let p_time = 2;
2589        let p_marg = 3;
2590        let t_log = Array2::from_shape_fn((4, 2), |(i, j)| 1.0 + (i * 2 + j) as f64);
2591        let map = survival_block_diagonal_logslope_map(p_time, p_marg, &t_log);
2592
2593        assert_eq!(map.raw_block_ranges, vec![0..2, 2..5, 5..9]);
2594        assert_eq!(map.compiled_block_ranges, vec![0..2, 2..5, 5..7]);
2595        assert_eq!(map.raw_from_compiled.dim(), (9, 7));
2596
2597        let t = &map.raw_from_compiled;
2598        // V_time = I2.
2599        for i in 0..p_time {
2600            for j in 0..p_time {
2601                let want = if i == j { 1.0 } else { 0.0 };
2602                assert!((t[[i, j]] - want).abs() < 1e-14, "V_time[{i},{j}]");
2603            }
2604        }
2605        // V_marg = I3.
2606        for i in 0..p_marg {
2607            for j in 0..p_marg {
2608                let want = if i == j { 1.0 } else { 0.0 };
2609                assert!((t[[p_time + i, p_time + j]] - want).abs() < 1e-14, "V_marg[{i},{j}]");
2610            }
2611        }
2612        // V_log = t_log.
2613        for i in 0..4 {
2614            for j in 0..2 {
2615                assert!(
2616                    (t[[p_time + p_marg + i, p_time + p_marg + j]] - t_log[[i, j]]).abs() < 1e-14,
2617                    "V_log[{i},{j}]"
2618                );
2619            }
2620        }
2621        // No cross-block bleed: the only nonzeros are the two identities and the
2622        // t_log block (every t_log entry here is nonzero).
2623        let nnz = t.iter().filter(|&&v| v != 0.0).count();
2624        assert_eq!(nnz, p_time + p_marg + t_log.iter().filter(|&&v| v != 0.0).count());
2625    }
2626}