Skip to main content

gam_models/survival/marginal_slope/
identifiability.rs

1//! Survival marginal-slope concrete impls for the family-agnostic
2//! identifiability compiler (`gam_identifiability::families::compiler`).
3//!
4//! Survival's row primary state is the 4-vector `u_i = (q0, q1, qd1, g)`,
5//! so `K = 4`. The row Hessian is the 4×4 second-derivative block of the
6//! per-row neg-log-likelihood kernel `row_primary_closed_form` at a pilot
7//! `β`, PSD-clamped via eigendecomposition (negative eigenvalues projected
8//! to zero) to handle pilot points far from the optimum.
9//!
10//! Each block exposes its row Jacobian as the contribution of `δβ_block`
11//! to the row primary-state vector:
12//!
13//! - **TimeBlockOperator**: `(δq0, δq1, δqd1, 0)` from `design_entry`,
14//!   `design_exit`, `design_derivative_exit` rows.
15//! - **MarginalBlockOperator**: `(δq, δq, δqd_marginal, 0)` from the
16//!   marginal design row (shared by q0 and q1; qd contribution zero unless
17//!   timewiggle is active — captured by an explicit derivative row matrix).
18//! - **LogslopeBlockOperator**: `(0, 0, 0, δg)` from the logslope design.
19//! - **ScoreWarpBlockOperator**: `(δq, δq, δqd_warp, 0)` from the warp
20//!   basis (shifts q at entry/exit; chain rule via dq0_seed/dt for qd1).
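//! - **LinkDevBlockOperator**: `(δq, δq, δqd_link, 0)` from the link-dev
//!   basis on the rigid/pilot q-seed.
//!
//! The marginal, score-warp, and link-dev contributions share one structure
//! (`δq0 = δq1`, `δg = 0`); in this file all three roles are implemented by
//! the single `QChannelBlockOperator` type, built from that block's `dq` and
//! `dqd1` designs.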
23//!
24//! Phase 4a delivery: trait impls + an input-builder helper. Phase 4b
25//! threads these through SMGS's construction site and the migrated pilot
26//! β; Phase 4c deletes the legacy
27//! `install_compiled_flex_block_into_runtime` path.
28
29use std::sync::Arc;
30
31use ndarray::{Array1, Array2, Array3};
32
33use gam_identifiability::families::compiler::{
34    BlockOrder, RowHessian, RowJacobianOperator, scale_jacobian_by_sqrt_h_with,
35};
36use gam_problem::gauge::assemble_block_triangular_t;
37use faer::Side;
38use gam_linalg::faer_ndarray::FaerEigh;
39use gam_linalg::matrix::{CoefficientTransformOperator, DenseDesignMatrix, DesignMatrix};
40use gam_problem::{FamilyChannelHessian, PenaltyMatrix};
41
/// Number of primary-state channels per survival row: `(q0, q1, qd1, g)`.
const K_SURVIVAL: usize = 4;

/// Per-coefficient magnitude at or below which `β` still counts as the
/// trivial (all-zero) pilot point in the drift-detection audit.
///
/// At β ≈ 0 the primary-state coupling g vanishes (c ≡ 1), so the frozen
/// pilot W is exact. Any |β_j| above this bound makes β "non-trivial", and
/// the family scalars are then needed to re-evaluate W(β). The bound sits
/// far below any meaningful fitted coefficient.
const BETA_NONTRIVIAL_ABS_THRESHOLD: f64 = 1.0e-12;
50
51/// Per-row 4×4 row Hessian for the survival marginal-slope likelihood at a
52/// pilot `β`. The pilot supplies the primary-state vector
53/// `(q0_i, q1_i, qd1_i, g_i)` and the per-row sample weight + event
54/// indicator + z + probit scale. The 4×4 block is evaluated via the
55/// existing `row_primary_closed_form` kernel (which already returns the
56/// full Hessian in `(q0, q1, qd1, g)` order) and PSD-clamped per row.
57pub struct SurvivalRowHessian {
58    /// PSD-projected per-row 4×4 Hessian, stored row-major as
59    /// `(n × 4 × 4)`.
60    h: Array3<f64>,
61}
62
63impl SurvivalRowHessian {
64    /// Construct from explicit per-row pilot primary-state and the row
65    /// data needed by `row_primary_closed_form`. Negative eigenvalues are
66    /// projected to zero before storage so the matrix is PSD.
67    pub fn from_pilot_primary_state(
68        q0: &Array1<f64>,
69        q1: &Array1<f64>,
70        qd1: &Array1<f64>,
71        g: &Array1<f64>,
72        z: &Array1<f64>,
73        weights: &Array1<f64>,
74        event: &Array1<f64>,
75        derivative_guard: f64,
76        probit_scale: f64,
77    ) -> Result<Self, String> {
78        let n = q0.len();
79        if [
80            q1.len(),
81            qd1.len(),
82            g.len(),
83            z.len(),
84            weights.len(),
85            event.len(),
86        ]
87        .iter()
88        .any(|&l| l != n)
89        {
90            return Err(format!(
91                "SurvivalRowHessian: length mismatch \
92                 q0={n}, q1={}, qd1={}, g={}, z={}, weights={}, event={}",
93                q1.len(),
94                qd1.len(),
95                g.len(),
96                z.len(),
97                weights.len(),
98                event.len()
99            ));
100        }
101        let mut h_full = Array3::<f64>::zeros((n, K_SURVIVAL, K_SURVIVAL));
102        for i in 0..n {
103            let (_, _grad, hess) =
104                crate::survival::marginal_slope::row_primary_for_compiler(
105                    q0[i],
106                    q1[i],
107                    qd1[i],
108                    g[i],
109                    z[i],
110                    weights[i],
111                    event[i],
112                    derivative_guard,
113                    probit_scale,
114                )?;
115            // PSD-clamp via eigendecomposition: project negative eigvals to 0.
116            let mut h_i = Array2::<f64>::zeros((K_SURVIVAL, K_SURVIVAL));
117            for a in 0..K_SURVIVAL {
118                for b in 0..K_SURVIVAL {
119                    h_i[[a, b]] = hess[a][b];
120                }
121            }
122            let clamped = psd_clamp_4x4(&h_i);
123            for a in 0..K_SURVIVAL {
124                for b in 0..K_SURVIVAL {
125                    h_full[[i, a, b]] = clamped[[a, b]];
126                }
127            }
128        }
129        Ok(Self { h: h_full })
130    }
131
132    /// Construct from an already-PSD per-row tensor. Used by callers that
133    /// have computed the Hessian via a different route.
134    pub fn from_full(h: Array3<f64>) -> Self {
135        assert_eq!(h.shape()[1], K_SURVIVAL);
136        assert_eq!(h.shape()[2], K_SURVIVAL);
137        Self { h }
138    }
139}
140
141impl RowHessian for SurvivalRowHessian {
142    fn k(&self) -> usize {
143        K_SURVIVAL
144    }
145    fn nrows(&self) -> usize {
146        self.h.shape()[0]
147    }
148    fn fill_row(&self, row: usize, out: &mut [f64]) {
149        assert_eq!(out.len(), K_SURVIVAL * K_SURVIVAL);
150        for a in 0..K_SURVIVAL {
151            for b in 0..K_SURVIVAL {
152                out[a * K_SURVIVAL + b] = self.h[[row, a, b]];
153            }
154        }
155    }
156    fn evaluate_full(&self) -> Array3<f64> {
157        self.h.clone()
158    }
159}
160
161/// `FamilyChannelHessian` for survival marginal-slope.
162///
163/// The 4×4 per-subject W_i is the Hessian of the row negative log-likelihood
164/// `ρ_i(q0, q1, qd1, g) = −δ_i log f(η_1, ad1) − (1−δ_i) log S(η_1) + log S(η_0)`
165/// with respect to the 4-vector primary state `(q0, q1, qd1, g)`.
166///
167/// Derivation of W_i entries (all from `row_primary_closed_form`):
168///
169/// - W[0,0] = u2_η0 · c²  (q0–q0; only η0 depends on q0)
170/// - W[1,1] = (u2_η1 + w·δ) · c²  (q1–q1; η1 and log-φ both depend on q1)
171/// - W[2,2] = w·δ · (∂ad1/∂qd1)² · (−1/ad1²)  (qd1–qd1 via neglog(ad1))
172/// - W[3,3] = u2_η0·(∂η0/∂g)² + u1_η0·(∂²η0/∂g²) + u2_η1·(∂η1/∂g)² + ...
173/// - W[0,3] = W[3,0] = u2_η0·c·(q0·c1 + s_f·z) + u1_η0·c1  (cross q0–g)
174/// - W[1,3] = W[3,1] = u2_η1·c·(q1·c1 + s_f·z) + u1_η1·c1  (cross q1–g)
175/// - W[2,3] = W[3,2] = u2_ad1·c·(qd1·c1) + u1_ad1·c1  (cross qd1–g)
176/// - All other off-diagonals are zero (η0, η1, ad1 depend on non-overlapping
177///   subsets of (q0,q1,qd1,g), and only g is shared across all three).
178///
179/// This is already computed by `row_primary_closed_form` and stored in
180/// `SurvivalRowHessian::h` after PSD-clamping.
181///
182/// # β-dependent W via `channel_hessian_at`
183///
184/// `channel_hessian_at` overrides the default β-independent path.  When
185/// `family_scalars` carries `SurvivalMarginalSlopeFamilyScalars`, the
186/// current per-row primary state `(q0_i, q1_i, qd1_i, g_i)` is read from
187/// those scalars and the 4×4 W_i is recomputed via `row_primary_for_compiler`.
188/// This makes `I(β) = J(β)^T W(β) J(β)` accurate at the current β instead of
189/// at the frozen pilot β=0 state.
190///
191/// When `family_scalars` is `None` but `beta` is zero-ish (all entries ≤ ε),
192/// the frozen pilot W is returned unchanged.  When `family_scalars` is `None`
193/// and any `beta` entry is non-trivial, `Err` is returned — the caller must
194/// supply scalars for a correct W at non-pilot β (same contract as T26's
195/// Jacobian callbacks: scalars required when β affects the primary state).
196impl FamilyChannelHessian for SurvivalRowHessian {
197    fn n_outputs(&self) -> usize {
198        K_SURVIVAL
199    }
200
201    fn n_subjects(&self) -> usize {
202        self.h.shape()[0]
203    }
204
205    fn fill_subject(&self, i: usize, out: &mut [f64]) {
206        assert_eq!(out.len(), K_SURVIVAL * K_SURVIVAL);
207        for a in 0..K_SURVIVAL {
208            for b in 0..K_SURVIVAL {
209                out[a * K_SURVIVAL + b] = self.h[[i, a, b]];
210            }
211        }
212    }
213
214    fn evaluate_full(&self) -> ndarray::Array3<f64> {
215        self.h.clone()
216    }
217
218    fn channel_hessian_at(
219        &self,
220        beta: &[f64],
221        family_scalars: Option<&Arc<dyn std::any::Any + Send + Sync>>,
222    ) -> Result<Arc<dyn FamilyChannelHessian>, String> {
223        use crate::survival::marginal_slope::SurvivalMarginalSlopeFamilyScalars;
224
225        let scalars_opt =
226            family_scalars.and_then(|a| a.downcast_ref::<SurvivalMarginalSlopeFamilyScalars>());
227
228        // Determine whether beta is non-trivial (any |β_j| > ε).
229        let beta_nontrivial = beta
230            .iter()
231            .any(|&b| b.abs() > BETA_NONTRIVIAL_ABS_THRESHOLD);
232
233        match scalars_opt {
234            None if beta_nontrivial => {
235                // β is non-zero in a way that would change W via the primary-state
236                // coupling (g ≠ 0 → c ≠ 1 → W changes).  Scalars are required.
237                Err(
238                    "SurvivalRowHessian::channel_hessian_at: beta is non-trivial but \
239                     family_scalars is None; supply SurvivalMarginalSlopeFamilyScalars \
240                     via FamilyLinearizationState::family_scalars to evaluate W(β) \
241                     correctly (same contract as T26 Jacobian callbacks)."
242                        .to_string(),
243                )
244            }
245            None => {
246                // β ≈ 0: return the frozen pilot W unchanged.
247                Ok(Arc::new(gam_problem::TensorChannelHessian {
248                    h: self.h.clone(),
249                }))
250            }
251            Some(sc) => {
252                let n = self.h.shape()[0];
253                if sc.q0_i.len() != n
254                    || sc.q1_i.len() != n
255                    || sc.qd1_i.len() != n
256                    || sc.g_i.len() != n
257                    || sc.z_i.len() != n
258                {
259                    return Err(format!(
260                        "SurvivalRowHessian::channel_hessian_at: scalars length mismatch \
261                         (expected n={n}, got q0={} q1={} qd1={} g={} z={})",
262                        sc.q0_i.len(),
263                        sc.q1_i.len(),
264                        sc.qd1_i.len(),
265                        sc.g_i.len(),
266                        sc.z_i.len(),
267                    ));
268                }
269                // We do not have weights/event stored in SurvivalRowHessian itself.
270                // The scalars carry the per-row primary state; we need per-row weights
271                // and event indicators to call row_primary_for_compiler.  Those are
272                // NOT stored in SurvivalMarginalSlopeFamilyScalars — so we can only
273                // recompute W's structural shape (the 4×4 curvature geometry) using
274                // unit weights and event=1, which gives us the correct _direction_ of
275                // W at the current β even if the magnitude is off by the sample weight.
276                //
277                // For the drift-detection audit the direction matters more than the
278                // exact per-row magnitudes: rank changes emerge from structural
279                // identifiability, not from per-row weight scaling. Using w=1, d=1
280                // is therefore the principled approximation for the audit path.
281                //
282                // Production callers that need exact W (e.g. for the Fisher Gram in
283                // the compiler) should use SurvivalRowHessian::from_pilot_primary_state
284                // directly with the true per-row weights and event indicators.
285                let mut h_full = Array3::<f64>::zeros((n, K_SURVIVAL, K_SURVIVAL));
286                for i in 0..n {
287                    let q0 = sc.q0_i[i];
288                    let q1 = sc.q1_i[i];
289                    let qd1 = sc.qd1_i[i];
290                    let g = sc.g_i[i];
291                    let z = sc.z_i[i];
292                    // Use unit weight and d=1 (event indicator 1) for the audit path.
293                    // The derivative_guard is the family default (small but non-zero).
294                    match crate::survival::marginal_slope::row_primary_for_compiler(
295                        q0, q1, qd1, g, z, 1.0,  // w = unit weight
296                        1.0,  // d = event
297                        crate::survival::marginal_slope::DEFAULT_SURVIVAL_MARGINAL_SLOPE_DERIVATIVE_GUARD,
298                        sc.s, // probit_scale from scalars
299                    ) {
300                        Ok((_nll, _grad, hess)) => {
301                            let mut h_i = ndarray::Array2::<f64>::zeros((K_SURVIVAL, K_SURVIVAL));
302                            for a in 0..K_SURVIVAL {
303                                for b in 0..K_SURVIVAL {
304                                    h_i[[a, b]] = hess[a][b];
305                                }
306                            }
307                            let clamped = psd_clamp_4x4(&h_i);
308                            for a in 0..K_SURVIVAL {
309                                for b in 0..K_SURVIVAL {
310                                    h_full[[i, a, b]] = clamped[[a, b]];
311                                }
312                            }
313                        }
314                        Err(_) => {
315                            // Monotonicity violation or other numerical issue at this
316                            // row: fall back to the frozen pilot W for this row only.
317                            for a in 0..K_SURVIVAL {
318                                for b in 0..K_SURVIVAL {
319                                    h_full[[i, a, b]] = self.h[[i, a, b]];
320                                }
321                            }
322                        }
323                    }
324                }
325                Ok(Arc::new(SurvivalRowHessian::from_full(h_full)))
326            }
327        }
328    }
329}
330
331/// Project a 4×4 symmetric matrix onto the PSD cone: zero negative
332/// eigenvalues. If the eigendecomposition fails (extremely defensive —
333/// `row_primary_closed_form` already guarantees finite entries), return
334/// the diagonal with negatives clamped.
335fn psd_clamp_4x4(m: &Array2<f64>) -> Array2<f64> {
336    let k = m.nrows();
337    let (evals, evecs) = match m.eigh(Side::Lower) {
338        Ok(pair) => pair,
339        Err(_) => {
340            let mut out = Array2::<f64>::zeros((k, k));
341            for i in 0..k {
342                out[[i, i]] = m[[i, i]].max(0.0);
343            }
344            return out;
345        }
346    };
347    let mut out = Array2::<f64>::zeros((k, k));
348    for i in 0..k {
349        for j in 0..k {
350            let mut acc = 0.0;
351            for l in 0..k {
352                acc += evecs[[i, l]] * evals[l].max(0.0) * evecs[[j, l]];
353            }
354            out[[i, j]] = acc;
355        }
356    }
357    out
358}
359
360/// Row Jacobian operator for the survival time block. Channels (q0, q1,
361/// qd1) come from the three time designs; the g channel is zero.
362pub struct TimeBlockOperator {
363    dq0: Array2<f64>,
364    dq1: Array2<f64>,
365    dqd1: Array2<f64>,
366}
367
368impl TimeBlockOperator {
369    pub fn new(dq0: Array2<f64>, dq1: Array2<f64>, dqd1: Array2<f64>) -> Self {
370        assert_eq!(dq0.dim(), dq1.dim());
371        assert_eq!(dq0.dim(), dqd1.dim());
372        Self { dq0, dq1, dqd1 }
373    }
374}
375
376impl RowJacobianOperator for TimeBlockOperator {
377    fn k(&self) -> usize {
378        K_SURVIVAL
379    }
380    fn ncols(&self) -> usize {
381        self.dq0.ncols()
382    }
383    fn nrows(&self) -> usize {
384        self.dq0.nrows()
385    }
386    fn apply_row(&self, row: usize, delta_beta: &[f64], out: &mut [f64]) {
387        assert_eq!(out.len(), K_SURVIVAL);
388        assert_eq!(delta_beta.len(), self.dq0.ncols());
389        let mut acc = [0.0_f64; K_SURVIVAL];
390        for (j, &b) in delta_beta.iter().enumerate() {
391            acc[0] += self.dq0[[row, j]] * b;
392            acc[1] += self.dq1[[row, j]] * b;
393            acc[2] += self.dqd1[[row, j]] * b;
394        }
395        out.copy_from_slice(&acc);
396    }
397    fn evaluate_full(&self) -> Array3<f64> {
398        let n = self.dq0.nrows();
399        let p = self.dq0.ncols();
400        let mut out = Array3::<f64>::zeros((n, p, K_SURVIVAL));
401        for i in 0..n {
402            for j in 0..p {
403                out[[i, j, 0]] = self.dq0[[i, j]];
404                out[[i, j, 1]] = self.dq1[[i, j]];
405                out[[i, j, 2]] = self.dqd1[[i, j]];
406            }
407        }
408        out
409    }
410    fn scaled_design_by_sqrt_h(&self, h_full: &Array3<f64>) -> Array2<f64> {
411        // Scale straight out of the three compact `(n, p)` channel designs —
412        // the compiler consumes the `(n·K, p)` sqrt(H)-scaled design, so the
413        // dense `(n, p, K)` tensor (3 of its 4 channels held explicitly, the
414        // 4th identically zero) that `evaluate_full()` builds is never needed.
415        // (#738: a capability is not a representation.)
416        let n = self.dq0.nrows();
417        let p = self.dq0.ncols();
418        scale_jacobian_by_sqrt_h_with(n, p, K_SURVIVAL, h_full, |i, a, c| match c {
419            0 => self.dq0[[i, a]],
420            1 => self.dq1[[i, a]],
421            2 => self.dqd1[[i, a]],
422            _ => 0.0,
423        })
424    }
425}
426
427/// Row Jacobian operator for a block whose contribution flows into the
428/// q-channels (q0 and q1 identically) and optionally the qd1 channel.
429/// Covers the survival marginal, score-warp, and link-dev blocks (all
430/// three share the structural property `δq0 = δq1 = basis·δβ`, `δg = 0`).
431pub struct QChannelBlockOperator {
432    dq: Array2<f64>,
433    dqd1: Array2<f64>,
434}
435
436impl QChannelBlockOperator {
437    pub fn new(dq: Array2<f64>, dqd1: Array2<f64>) -> Self {
438        assert_eq!(dq.dim(), dqd1.dim());
439        Self { dq, dqd1 }
440    }
441}
442
443impl RowJacobianOperator for QChannelBlockOperator {
444    fn k(&self) -> usize {
445        K_SURVIVAL
446    }
447    fn ncols(&self) -> usize {
448        self.dq.ncols()
449    }
450    fn nrows(&self) -> usize {
451        self.dq.nrows()
452    }
453    fn apply_row(&self, row: usize, delta_beta: &[f64], out: &mut [f64]) {
454        assert_eq!(out.len(), K_SURVIVAL);
455        assert_eq!(delta_beta.len(), self.dq.ncols());
456        let mut dq_acc = 0.0;
457        let mut dqd_acc = 0.0;
458        for (j, &b) in delta_beta.iter().enumerate() {
459            dq_acc += self.dq[[row, j]] * b;
460            dqd_acc += self.dqd1[[row, j]] * b;
461        }
462        out[0] = dq_acc;
463        out[1] = dq_acc;
464        out[2] = dqd_acc;
465        out[3] = 0.0;
466    }
467    fn evaluate_full(&self) -> Array3<f64> {
468        let n = self.dq.nrows();
469        let p = self.dq.ncols();
470        let mut out = Array3::<f64>::zeros((n, p, K_SURVIVAL));
471        for i in 0..n {
472            for j in 0..p {
473                let v = self.dq[[i, j]];
474                out[[i, j, 0]] = v;
475                out[[i, j, 1]] = v;
476                out[[i, j, 2]] = self.dqd1[[i, j]];
477            }
478        }
479        out
480    }
481    fn scaled_design_by_sqrt_h(&self, h_full: &Array3<f64>) -> Array2<f64> {
482        // q0 and q1 share `dq`; qd1 is `dqd1`; the g channel is identically
483        // zero. Scale directly from the compact `(n, p)` designs, skipping the
484        // dense `(n, p, K)` tensor `evaluate_full()` would build. (#738.)
485        let n = self.dq.nrows();
486        let p = self.dq.ncols();
487        scale_jacobian_by_sqrt_h_with(n, p, K_SURVIVAL, h_full, |i, a, c| match c {
488            0 | 1 => self.dq[[i, a]],
489            2 => self.dqd1[[i, a]],
490            _ => 0.0,
491        })
492    }
493}
494
495/// Row Jacobian operator for the survival logslope block: contribution
496/// lives entirely on the g channel.
497pub struct LogslopeBlockOperator {
498    dg: Array2<f64>,
499}
500
501impl LogslopeBlockOperator {
502    pub fn new(dg: Array2<f64>) -> Self {
503        Self { dg }
504    }
505}
506
507impl RowJacobianOperator for LogslopeBlockOperator {
508    fn k(&self) -> usize {
509        K_SURVIVAL
510    }
511    fn ncols(&self) -> usize {
512        self.dg.ncols()
513    }
514    fn nrows(&self) -> usize {
515        self.dg.nrows()
516    }
517    fn apply_row(&self, row: usize, delta_beta: &[f64], out: &mut [f64]) {
518        assert_eq!(out.len(), K_SURVIVAL);
519        assert_eq!(delta_beta.len(), self.dg.ncols());
520        let mut acc = 0.0;
521        for (j, &b) in delta_beta.iter().enumerate() {
522            acc += self.dg[[row, j]] * b;
523        }
524        out[0] = 0.0;
525        out[1] = 0.0;
526        out[2] = 0.0;
527        out[3] = acc;
528    }
529    fn evaluate_full(&self) -> Array3<f64> {
530        let n = self.dg.nrows();
531        let p = self.dg.ncols();
532        let mut out = Array3::<f64>::zeros((n, p, K_SURVIVAL));
533        for i in 0..n {
534            for j in 0..p {
535                out[[i, j, 3]] = self.dg[[i, j]];
536            }
537        }
538        out
539    }
540    fn scaled_design_by_sqrt_h(&self, h_full: &Array3<f64>) -> Array2<f64> {
541        // The logslope contribution lives entirely on the g channel (3); the
542        // other three channels are identically zero. Scale directly from the
543        // compact `(n, p)` design, skipping the mostly-zero dense `(n, p, K)`
544        // tensor `evaluate_full()` would build. (#738.)
545        let n = self.dg.nrows();
546        let p = self.dg.ncols();
547        scale_jacobian_by_sqrt_h_with(n, p, K_SURVIVAL, h_full, |i, a, c| {
548            if c == 3 { self.dg[[i, a]] } else { 0.0 }
549        })
550    }
551}
552
553/// Inputs assembled for the survival fit driver to feed `compile()`. The
554/// ordering follows `gauge_priority` descending (time=200 → marginal=150 →
555/// logslope=120 → score_warp=80 → link_dev=60).
556pub struct SurvivalCompilerInputs {
557    pub operators: Vec<Arc<dyn RowJacobianOperator>>,
558    pub ordering: Vec<BlockOrder>,
559}
560
561/// Per-block V reparameterisation matrices for the three parametric
562/// survival blocks emitted by [`compile_survival_parametric_designs`].
563/// Each `v_*` is a `(p_block_raw × p_block_kept)` selection-or-rotation
564/// matrix that maps a `β_kept` coefficient vector to its `β_raw`
565/// equivalent: `β_raw = V · β_kept`. The construction site applies these
566/// to the raw block designs (`design_raw · V → design_compiled`) and
567/// to the penalties (`Vᵀ S V`) before building `ParameterBlockSpec`s
568/// and passing the compiled designs into `make_family`.
569///
570/// Phase-4b architecture: this is the seam where the family-agnostic
571/// row-Jacobian compiler hands control back to the family-specific
572/// construction site. Each `v_*` width equals the corresponding
573/// `CompiledBlocks::blocks[i].t_lw.ncols()` — i.e., the kept-direction
574/// count after sqrt-H-metric residualisation and post-walk RRQR
575/// trailing-pivot drop.
576pub struct SurvivalParametricCompiled {
577    pub v_time: Array2<f64>,
578    pub v_marginal: Array2<f64>,
579    pub v_logslope: Array2<f64>,
580    /// Per-block dropped raw-column count, indexed
581    /// `(time_dropped, marginal_dropped, logslope_dropped)`. Equal to
582    /// `(p_raw − v.ncols())` for each block. Useful for logging the
583    /// gauge-attribution summary at the construction site.
584    pub drops_by_block: (usize, usize, usize),
585}
586
587fn wrap_design_with_transform(
588    raw: DesignMatrix,
589    v: &Array2<f64>,
590    context: &str,
591) -> Result<DesignMatrix, String> {
592    if raw.ncols() != v.nrows() {
593        return Err(format!(
594            "{context}: raw design has {} cols but V has {} rows (V is {}×{})",
595            raw.ncols(),
596            v.nrows(),
597            v.nrows(),
598            v.ncols(),
599        ));
600    }
601    let inner_dense = match raw {
602        DesignMatrix::Dense(d) => d,
603        DesignMatrix::Sparse(_) => {
604            let dense = raw
605                .try_to_dense_by_chunks(&format!("{context} sparse→dense for V apply"))
606                .map_err(|reason| format!("{context}: densify failed: {reason}"))?;
607            DenseDesignMatrix::from(dense)
608        }
609    };
610    let op = CoefficientTransformOperator::new(inner_dense, v.clone())
611        .map_err(|reason| format!("{context}: CoefficientTransformOperator::new: {reason}"))?;
612    Ok(DesignMatrix::Dense(DenseDesignMatrix::from(Arc::new(op))))
613}
614
615/// Per-term V reparameterisation matrices for the three parametric
616/// survival blocks. Each block's full V is the block-diagonal assembly
617/// of its per-term V's (one entry per element of the input
618/// `*_partition`). Preserves per-term penalty structure: applying
619/// `V_b = block_diag(V_term1, ..., V_termM)` to a per-term BlockwisePenalty
620/// pulls each penalty back only via its OWN term's V, so what was a
621/// per-term λ tunable in REML stays per-term tunable.
622pub struct SurvivalParametricCompiledPerTerm {
623    pub v_time_per_term: Vec<Array2<f64>>,
624    pub v_marginal_per_term: Vec<Array2<f64>>,
625    pub v_logslope_per_term: Vec<Array2<f64>>,
626    /// Per-term residualised reparam `R_b = M_b · V_b` from the
627    /// identifiability compiler, in the same global compile order
628    /// (time terms, then marginal terms, then logslope terms). `None`
629    /// for the very first compiled block (no anchor). Used by the
630    /// V+M-exact apply path to emit residualised rows
631    /// `C_b·V_b − A_{<b}·R_b` and to assemble the full triangular T.
632    pub r_lw_per_term: Vec<Option<Array2<f64>>>,
633    /// Per-block drops (raw_cols − sum(kept_cols across terms)).
634    pub drops_by_block: (usize, usize, usize),
635}
636
637/// Per-term-aware compile: residualise each block's TERMS individually
638/// in priority order so the emitted V is block-diagonal on term
639/// boundaries. This preserves the per-term penalty structure that
640/// REML's per-λ accounting depends on.
641///
642/// Each `*_partition` is a list of disjoint contiguous column ranges
643/// covering `[0..p_block)`. For the marginal/logslope blocks the
644/// natural source is the union of `BlockwisePenalty::col_range` values
645/// (one per smoothness penalty / term) plus the complement
646/// (unpenalised parametric columns).
647///
648/// Order of residualisation: time terms first (in their partition
649/// order), then marginal terms, then logslope terms. Within each
650/// block, terms are residualised against ALL prior anchor columns
651/// (terms from earlier blocks + earlier terms within this block).
652/// Aliased directions land in the lowest-priority block that contains
653/// them, in the natural term order within that block — matching the
654/// gauge-priority ownership contract.
655pub fn compile_survival_parametric_designs_per_term(
656    time_dq0: Array2<f64>,
657    time_dq1: Array2<f64>,
658    time_dqd1: Array2<f64>,
659    time_partition: &[std::ops::Range<usize>],
660    marginal_dq: Array2<f64>,
661    marginal_dqd1: Array2<f64>,
662    marginal_partition: &[std::ops::Range<usize>],
663    logslope_dg: Array2<f64>,
664    logslope_partition: &[std::ops::Range<usize>],
665    row_hess: &dyn RowHessian,
666    protect_time: bool,
667) -> Result<SurvivalParametricCompiledPerTerm, String> {
668    use gam_identifiability::families::compiler::compile_protected;
669
670    let p_time = time_dq0.ncols();
671    let p_marg = marginal_dq.ncols();
672    let p_log = logslope_dg.ncols();
673    validate_partition(time_partition, p_time, "time")?;
674    validate_partition(marginal_partition, p_marg, "marginal")?;
675    validate_partition(logslope_partition, p_log, "logslope")?;
676
677    // Build per-term operators. Each term gets its own RowJacobianOperator
678    // restricted to its column slice; the operator type matches the
679    // block's K-channel signature (Time, QChannel, Logslope).
680    let mut operators: Vec<Arc<dyn RowJacobianOperator>> = Vec::new();
681    let mut ordering: Vec<BlockOrder> = Vec::new();
682    for range in time_partition {
683        let dq0 = time_dq0.slice(ndarray::s![.., range.clone()]).to_owned();
684        let dq1 = time_dq1.slice(ndarray::s![.., range.clone()]).to_owned();
685        let dqd1 = time_dqd1.slice(ndarray::s![.., range.clone()]).to_owned();
686        operators.push(Arc::new(TimeBlockOperator::new(dq0, dq1, dqd1)));
687        ordering.push(BlockOrder::Time);
688    }
689    for range in marginal_partition {
690        let dq = marginal_dq.slice(ndarray::s![.., range.clone()]).to_owned();
691        let dqd1 = marginal_dqd1
692            .slice(ndarray::s![.., range.clone()])
693            .to_owned();
694        operators.push(Arc::new(QChannelBlockOperator::new(dq, dqd1)));
695        ordering.push(BlockOrder::Marginal);
696    }
697    for range in logslope_partition {
698        let dg = logslope_dg.slice(ndarray::s![.., range.clone()]).to_owned();
699        operators.push(Arc::new(LogslopeBlockOperator::new(dg)));
700        ordering.push(BlockOrder::Logslope);
701    }
702
703    // The time block carries the monotone time-wiggle basis, whose effective
704    // Jacobian is a fixed nonlinear functional basis rather than a linear
705    // design. When `protect_time` is set it must be kept at full raw width: a
706    // linear reparameterisation of it would desynchronise the raw-width
707    // wiggle-basis chain rule (`SmsTimewiggleTimeJacobian`), which recomputes
708    // the basis on every evaluation. Marginal/logslope still reduce against the
709    // full time anchor. The time block spans operators `0..n_time` (pushed
710    // first above); mark exactly those protected.
711    let n_time = time_partition.len();
712    let protected: Vec<bool> = if protect_time {
713        (0..operators.len()).map(|i| i < n_time).collect()
714    } else {
715        Vec::new()
716    };
717    let compiled =
718        compile_protected(&operators, row_hess, &ordering, &protected).map_err(|e| {
719            format!("identifiability::families::compiler::compile (per-term) failed: {e}")
720        })?;
721    let blocks = compiled.blocks;
722    let n_marg = marginal_partition.len();
723    let n_log = logslope_partition.len();
724    if blocks.len() != n_time + n_marg + n_log {
725        return Err(format!(
726            "per-term compile: expected {} compiled blocks (time={}, marg={}, log={}), got {}",
727            n_time + n_marg + n_log,
728            n_time,
729            n_marg,
730            n_log,
731            blocks.len(),
732        ));
733    }
734    let mut iter = blocks.into_iter();
735    let mut v_time_per_term: Vec<Array2<f64>> = Vec::with_capacity(n_time);
736    let mut r_time_per_term: Vec<Option<Array2<f64>>> = Vec::with_capacity(n_time);
737    for _ in 0..n_time {
738        let blk = iter.next().unwrap();
739        v_time_per_term.push(blk.t_lw);
740        r_time_per_term.push(blk.r_lw);
741    }
742    let mut v_marginal_per_term: Vec<Array2<f64>> = Vec::with_capacity(n_marg);
743    let mut r_marginal_per_term: Vec<Option<Array2<f64>>> = Vec::with_capacity(n_marg);
744    for _ in 0..n_marg {
745        let blk = iter.next().unwrap();
746        v_marginal_per_term.push(blk.t_lw);
747        r_marginal_per_term.push(blk.r_lw);
748    }
749    let mut v_logslope_per_term: Vec<Array2<f64>> = Vec::with_capacity(n_log);
750    let mut r_logslope_per_term: Vec<Option<Array2<f64>>> = Vec::with_capacity(n_log);
751    for _ in 0..n_log {
752        let blk = iter.next().unwrap();
753        v_logslope_per_term.push(blk.t_lw);
754        r_logslope_per_term.push(blk.r_lw);
755    }
756    let mut r_lw_per_term: Vec<Option<Array2<f64>>> = Vec::with_capacity(n_time + n_marg + n_log);
757    r_lw_per_term.extend(r_time_per_term);
758    r_lw_per_term.extend(r_marginal_per_term);
759    r_lw_per_term.extend(r_logslope_per_term);
760    let drops_time: usize = time_partition
761        .iter()
762        .zip(v_time_per_term.iter())
763        .map(|(r, v)| r.len().saturating_sub(v.ncols()))
764        .sum();
765    let drops_marg: usize = marginal_partition
766        .iter()
767        .zip(v_marginal_per_term.iter())
768        .map(|(r, v)| r.len().saturating_sub(v.ncols()))
769        .sum();
770    let drops_log: usize = logslope_partition
771        .iter()
772        .zip(v_logslope_per_term.iter())
773        .map(|(r, v)| r.len().saturating_sub(v.ncols()))
774        .sum();
775    Ok(SurvivalParametricCompiledPerTerm {
776        v_time_per_term,
777        v_marginal_per_term,
778        v_logslope_per_term,
779        r_lw_per_term,
780        drops_by_block: (drops_time, drops_marg, drops_log),
781    })
782}
783
/// Check that `partition` is a gap-free, overlap-free cover of `[0, p_block)`
/// made of non-empty contiguous ranges, reporting the first violation found.
///
/// An empty partition is accepted only for a zero-width block. Violations are
/// reported in this order: wrong first start, wrong final end, then (per
/// adjacent pair) a gap/overlap followed by an empty leading range, and
/// finally an empty trailing range.
fn validate_partition(
    partition: &[std::ops::Range<usize>],
    p_block: usize,
    label: &str,
) -> Result<(), String> {
    let (head, tail) = match (partition.first(), partition.last()) {
        (Some(head), Some(tail)) => (head, tail),
        // No ranges at all: legal only when there is nothing to cover.
        _ if p_block == 0 => return Ok(()),
        _ => {
            return Err(format!(
                "{label} partition empty but block has p={p_block} columns"
            ));
        }
    };
    if head.start != 0 {
        return Err(format!(
            "{label} partition must start at 0, got start={}",
            head.start
        ));
    }
    if tail.end != p_block {
        return Err(format!(
            "{label} partition must cover [0, {p_block}); last range ends at {}",
            tail.end
        ));
    }
    for pair in partition.windows(2) {
        let (prev, next) = (&pair[0], &pair[1]);
        if prev.end != next.start {
            return Err(format!(
                "{label} partition has gap/overlap between [{}..{}) and [{}..{})",
                prev.start, prev.end, next.start, next.end
            ));
        }
        if prev.is_empty() {
            return Err(format!(
                "{label} partition has empty range [{}..{})",
                prev.start, prev.end
            ));
        }
    }
    // `windows(2)` never inspects the final range as a leading element.
    if tail.is_empty() {
        return Err(format!("{label} partition's final range is empty"));
    }
    Ok(())
}
828
/// Derive a disjoint, contiguous partition of `[0..p_block)` from a list of
/// `BlockwisePenalty` column ranges.
///
/// Every penalty-range endpoint (clamped to `p_block`), together with `0` and
/// `p_block`, becomes a cut point. The partition is the sequence of non-empty
/// intervals between consecutive distinct cut points. Consequently:
///
/// - multiple penalties sharing the SAME `col_range` (e.g. tensor anisotropy
///   axes) coalesce into a single term;
/// - a run of unpenalised columns before, between, or after penalty ranges
///   becomes ONE partition spanning the whole gap (not one partition per
///   column), e.g. `p_block = 10`, ranges `[2..5]` → `0..2`, `2..5`, `5..10`;
/// - partially overlapping or nested penalty ranges are split at every
///   endpoint rather than merged, e.g. `0..5` and `3..8` → `0..3`, `3..5`,
///   `5..8`;
/// - range endpoints beyond `p_block` are clamped to `p_block`.
///
/// For `p_block == 0` the result is empty. For `p_block > 0` the result starts
/// at 0, ends at `p_block`, and has no gaps or empty ranges, i.e. it satisfies
/// the contract enforced by `validate_partition`.
pub fn extract_term_partition_from_penalty_ranges(
    p_block: usize,
    penalty_ranges: &[std::ops::Range<usize>],
) -> Vec<std::ops::Range<usize>> {
    use std::collections::BTreeSet;
    // A BTreeSet both sorts and de-duplicates the cut points, which is what
    // makes identical col_ranges coalesce into a single term.
    let mut cuts: BTreeSet<usize> = BTreeSet::new();
    cuts.insert(0);
    cuts.insert(p_block);
    for r in penalty_ranges {
        cuts.insert(r.start.min(p_block));
        cuts.insert(r.end.min(p_block));
    }
    let cuts: Vec<usize> = cuts.into_iter().collect();
    cuts.windows(2)
        .filter(|w| w[0] < w[1])
        .map(|w| w[0]..w[1])
        .collect()
}
851
852/// Pull a single raw block-local [`BlockwisePenalty`] back through the
853/// block's own diagonal reparameterisation `V_b` (the `(b, b)` block of
854/// the triangular T), producing a per-block-width compiled penalty.
855///
856/// The penalty's `local` is `pen.col_range.len()` square and covers a
857/// sub-region of the raw block at offset `pen.col_range.start` (which is
858/// block-local, i.e. relative to the block's first raw column). It is
859/// embedded into the full raw block width `v_block.nrows()` at that
860/// offset, then pulled back as `V_bᵀ · embed(S) · V_b`, giving a
861/// `(w_b_compiled × w_b_compiled)` symmetric `PenaltyMatrix::Dense`
862/// where `w_b_compiled == v_block.ncols()`.
863///
864/// This is the penalty contract a per-block `ParameterBlockSpec`
865/// requires: each block's penalty acts on that block's own compiled
866/// coordinate `θ_b`. The cross-block residualisation `R_{a→b}` carried
867/// in T's strict-upper triangle is absorbed into the *design* columns
868/// (the residualised emitted design `C_b V_b − A_{<b} R_b`), not into
869/// the penalty — exactly as the VM-exact compile-map path
870/// [`apply_compiled_map_to_designs`] does. Pulling the penalty back
871/// through the full joint T instead would yield a `(p_compiled × p_compiled)` dense
872/// matrix that cannot live in a single block's `penalties` slot and
873/// would violate the `p_b × p_b` block-spec validation.
874pub fn pull_back_blockwise_penalty_through_block_v(
875    pen: &gam_terms::smooth::BlockwisePenalty,
876    v_block: &Array2<f64>,
877) -> Result<PenaltyMatrix, String> {
878    let raw_p = v_block.nrows();
879    let compiled_p = v_block.ncols();
880    let block_p = pen.col_range.len();
881    let embed_start = pen.col_range.start;
882    let embed_end = pen.col_range.end;
883    if embed_end > raw_p {
884        return Err(format!(
885            "pull_back_blockwise_penalty_through_block_v: penalty col_range {embed_start}..{embed_end} \
886             exceeds block raw width {raw_p}"
887        ));
888    }
889    if pen.local.nrows() != block_p || pen.local.ncols() != block_p {
890        return Err(format!(
891            "pull_back_blockwise_penalty_through_block_v: penalty local is {}x{} but col_range \
892             width is {block_p}",
893            pen.local.nrows(),
894            pen.local.ncols(),
895        ));
896    }
897    let mut embedded = Array2::<f64>::zeros((raw_p, raw_p));
898    if block_p > 0 {
899        let mut dst =
900            embedded.slice_mut(ndarray::s![embed_start..embed_end, embed_start..embed_end]);
901        for i in 0..block_p {
902            for j in 0..block_p {
903                dst[[i, j]] = pen.local[[i, j]];
904            }
905        }
906    }
907    // V_bᵀ · embed(S) · V_b → (compiled_p × compiled_p).
908    let temp = embedded.dot(v_block);
909    let pulled = v_block.t().dot(&temp);
910    let mut sym = Array2::<f64>::zeros((compiled_p, compiled_p));
911    for i in 0..compiled_p {
912        for j in 0..compiled_p {
913            sym[[i, j]] = 0.5 * (pulled[[i, j]] + pulled[[j, i]]);
914        }
915    }
916    Ok(PenaltyMatrix::Dense(sym))
917}
918
919/// Assemble a 3-block [`CompiledMap`] (time, marginal, logslope) from a
920/// [`SurvivalParametricCompiledPerTerm`] produced by the full 4×4 row-Hessian
921/// driver [`compile_survival_parametric_designs_per_term`].
922///
923/// The full global triangular `T` is built from the per-term `V`/`R` blocks
924/// (diagonal `V_b`, strict-upper `−R_{a→b}` — identical to the matrix the
925/// result-time lift [`Gauge::from_v_and_r`] uses), then partitioned
926/// into the three *block* ranges (raw = summed per-term raw widths, compiled =
927/// summed per-term kept widths). The resulting `CompiledMap` is interchangeable
928/// with one from
929/// [`gam_identifiability::families::compiler::compile_from_raw_grams`], so the
930/// existing [`apply_compiled_map_to_designs`] +
931/// [`Gauge::from_compiled_map`] machinery consumes it unchanged.
932///
933/// This is the seam that lets the survival closed-form fast path engage on the
934/// *correct* identifiable quotient: the cheap η₁-only rawstack metric can
935/// falsely collapse a whole channel (marginal/logslope share a PC surface in
936/// the η₁ row curvature), but the full survival row Hessian is 4×4 in
937/// `(q0, q1, qd1, g)` and chains differently into each block, so it keeps the
938/// channels distinct when no *true* alias exists. The reduced basis it emits
939/// goes to Newton in place of the rank-deficient raw basis.
940pub fn compiled_map_from_per_term(
941    compiled: &SurvivalParametricCompiledPerTerm,
942) -> gam_identifiability::families::compiler::CompiledMap {
943    // Per-term V's and R's in global compile order: time terms, then marginal,
944    // then logslope — exactly the order `r_lw_per_term` is stored in.
945    let mut v_all: Vec<Array2<f64>> = Vec::new();
946    v_all.extend(compiled.v_time_per_term.iter().cloned());
947    v_all.extend(compiled.v_marginal_per_term.iter().cloned());
948    v_all.extend(compiled.v_logslope_per_term.iter().cloned());
949
950    let t_full = assemble_block_triangular_t(&v_all, &compiled.r_lw_per_term);
951
952    // Per-block raw / compiled widths = summed per-term widths within the block.
953    let raw_w = |terms: &[Array2<f64>]| -> usize { terms.iter().map(|v| v.nrows()).sum() };
954    let kept_w = |terms: &[Array2<f64>]| -> usize { terms.iter().map(|v| v.ncols()).sum() };
955    let raw_time = raw_w(&compiled.v_time_per_term);
956    let raw_marg = raw_w(&compiled.v_marginal_per_term);
957    let raw_log = raw_w(&compiled.v_logslope_per_term);
958    let kept_time = kept_w(&compiled.v_time_per_term);
959    let kept_marg = kept_w(&compiled.v_marginal_per_term);
960    let kept_log = kept_w(&compiled.v_logslope_per_term);
961
962    let raw_block_ranges = vec![
963        0..raw_time,
964        raw_time..(raw_time + raw_marg),
965        (raw_time + raw_marg)..(raw_time + raw_marg + raw_log),
966    ];
967    let compiled_block_ranges = vec![
968        0..kept_time,
969        kept_time..(kept_time + kept_marg),
970        (kept_time + kept_marg)..(kept_time + kept_marg + kept_log),
971    ];
972
973    gam_identifiability::families::compiler::CompiledMap {
974        raw_from_compiled: t_full,
975        compiled_block_ranges,
976        raw_block_ranges,
977    }
978}
979
980/// Build a W-orthogonal **partial** reduced-logslope reparameterisation `T`
981/// (`p_log × r`, `0 < r < p_log`) for the survival marginal↔logslope confound,
982/// mirroring the proven-correct BMS effective-Schur-Gram construction
983/// [`crate::bms::block_specs::reduced_logslope_transform_effective`] but in
984/// survival's per-row 4×4 primary-state Hessian metric.
985///
986/// # Why this exists (#979)
987///
988/// The survival marginal and logslope channels share the SAME spatial basis
989/// (e.g. `matern(PC1,PC2,PC3)`), so on clustered-PC data the full 4×4
990/// row-Hessian identifiability compiler can attribute the *entire* shared
991/// surface to the lowest-priority logslope block and collapse it to zero width
992/// — which the `#741` required-channel guard rejects, forcing a fallback to the
993/// UNREDUCED design + Jeffreys conditioning. That fallback leaves a
994/// quadratically-flat near-null direction in the joint penalised Hessian
995/// `M = JᵀHJ + S`, so the inner joint-Newton cannot certify stationarity and
996/// the outer wall-clock deadline becomes load-bearing rather than a backstop.
997///
998/// The BMS path never hits this because it does a *partial* reduction: it
999/// removes from the logslope block ONLY the directions whose effective image is
1000/// W-explained by the marginal span (the confounded null space of the effective
1001/// Schur Gram), keeping every surviving logslope direction. The result is
1002/// full-rank `M` BY CONSTRUCTION — no runtime projection, deadline demoted to a
1003/// pure backstop.
1004///
1005/// # The metric collapse to scalar weights
1006///
1007/// At the pilot the marginal design feeds the primary channels `(q0, q1)`
1008/// identically (`∂q0/∂β_m = ∂q1/∂β_m = m`, and `∂qd1/∂β_m = 0` because the
1009/// `#808` fallback always builds a zero marginal-derivative design) and the
1010/// logslope design feeds only `g` (`∂g/∂β_s = g_dg`). With the per-row PSD 4×4
1011/// Hessian `H` in channel order `(q0, q1, qd1, g)` and `H[0,1] = 0` (q0 and q1
1012/// enter the disjoint outputs η0, η1), the combined effective Gram
1013/// `[[A, Bᵀ], [B, C]] = J_combinedᵀ H J_combined` (PSD per row) collapses to
1014/// scalar-weighted Grams of the raw block designs:
1015///
1016/// ```text
1017///     w_mm = H00 + H11           (marginal self weight, ≥ 0)
1018///     w_mg = H03 + H13           (marginal↔logslope cross weight)
1019///     w_gg = H33                 (logslope self weight, ≥ 0)
1020///     A = m_dqᵀ diag(w_mm) m_dq + εI     (p_m × p_m)
1021///     B = m_dqᵀ diag(w_mg) g_dg          (p_m × p_log)
1022///     C = g_dgᵀ diag(w_gg) g_dg          (p_log × p_log)
1023///     Gtt = C − Bᵀ A⁻¹ B                 (p_log × p_log, PSD Schur complement)
1024/// ```
1025///
1026/// `T` is the orthonormal eigenbasis of `Gtt` for eigenvalues above a tolerance
1027/// relative to the effective logslope energy scale (single-sourced from the BMS
1028/// reference cut). Returns `Ok(None)` when there is nothing to reduce
1029/// (`r == p_log`) or the entire effective logslope image collapses into the
1030/// marginal span (`r == 0`); in both cases the caller keeps its existing path.
1031///
1032/// Precondition: `marginal_dq`'s derivative-into-qd1 contribution is zero (the
1033/// `#808` fallback constructs `m_dqd1` as an all-zero matrix), so marginal
1034/// touches only `(q0, q1)` and the scalar collapse above is exact.
1035pub fn survival_reduced_logslope_transform_effective(
1036    marginal_dq: ndarray::ArrayView2<'_, f64>,
1037    logslope_dg: ndarray::ArrayView2<'_, f64>,
1038    row_hess: &SurvivalRowHessian,
1039) -> Result<Option<Array2<f64>>, String> {
1040    use crate::bms::block_specs::LOGSLOPE_REDUCED_BASIS_RELATIVE_TOL;
1041    use gam_linalg::faer_ndarray::{
1042        FaerArrayView, factorize_symmetricwith_fallback, fast_atb, fast_xt_diag_x, fast_xt_diag_y,
1043    };
1044
1045    let n = marginal_dq.nrows();
1046    let p_m = marginal_dq.ncols();
1047    let p_log = logslope_dg.ncols();
1048    if p_m == 0 || p_log == 0 {
1049        return Ok(None);
1050    }
1051    if logslope_dg.nrows() != n || row_hess.h.shape()[0] != n {
1052        return Err(format!(
1053            "survival reduced logslope: row mismatch marginal={n}, logslope={}, row_hess={}",
1054            logslope_dg.nrows(),
1055            row_hess.h.shape()[0],
1056        ));
1057    }
1058
1059    // Scalar effective weights from the per-row 4×4 PSD Hessian, channel order
1060    // (q0, q1, qd1, g). Marginal → {q0, q1} (identical column m), logslope → {g}.
1061    let mut w_mm = Array1::<f64>::zeros(n);
1062    let mut w_mg = Array1::<f64>::zeros(n);
1063    let mut w_gg = Array1::<f64>::zeros(n);
1064    for i in 0..n {
1065        w_mm[i] = row_hess.h[[i, 0, 0]] + row_hess.h[[i, 1, 1]];
1066        w_mg[i] = row_hess.h[[i, 0, 3]] + row_hess.h[[i, 1, 3]];
1067        w_gg[i] = row_hess.h[[i, 3, 3]];
1068        if !(w_mm[i].is_finite() && w_mg[i].is_finite() && w_gg[i].is_finite()) {
1069            return Err("survival reduced logslope: non-finite row Hessian weight".to_string());
1070        }
1071    }
1072
1073    let marg = marginal_dq.to_owned();
1074    let log = logslope_dg.to_owned();
1075
1076    // C = G_effᵀ W G_eff (raw-coordinate effective logslope Gram); its diagonal
1077    // sets the energy scale for the relative kept-direction tolerance.
1078    let c_gram = fast_xt_diag_x(&log, &w_gg);
1079    let energy_scale = (0..p_log).map(|i| c_gram[[i, i]]).fold(0.0_f64, f64::max);
1080    if !energy_scale.is_finite() || energy_scale <= 0.0 {
1081        return Ok(None);
1082    }
1083
1084    // A = M_effᵀ W M_eff + εI (ridge relative to the marginal effective energy
1085    // so the Schur solve is well-posed even when the marginal pilot Gram is
1086    // rank-soft; the ridge only under-removes, i.e. is conservative).
1087    let mut a_gram = fast_xt_diag_x(&marg, &w_mm);
1088    let a_scale = (0..p_m).map(|i| a_gram[[i, i]]).fold(0.0_f64, f64::max);
1089    let a_ridge = (a_scale * LOGSLOPE_REDUCED_BASIS_RELATIVE_TOL).max(f64::EPSILON);
1090    for i in 0..p_m {
1091        a_gram[[i, i]] += a_ridge;
1092    }
1093
1094    // B = M_effᵀ W G_eff (p_m × p_log);  Gtt = C − Bᵀ A⁻¹ B (p_log × p_log, PSD).
1095    let b_cross = fast_xt_diag_y(&marg, &w_mg, &log);
1096    let a_view = FaerArrayView::new(&a_gram);
1097    let a_factor = factorize_symmetricwith_fallback(a_view.as_ref(), Side::Lower).map_err(|e| {
1098        format!("survival reduced logslope: marginal effective Gram factorization failed: {e}")
1099    })?;
1100    let b_view = FaerArrayView::new(&b_cross);
1101    let solved = a_factor.solve(b_view.as_ref()); // A⁻¹ B  (p_m × p_log)
1102    let a_inv_b = Array2::from_shape_fn((p_m, p_log), |(i, j)| solved[(i, j)]);
1103    let schur = fast_atb(&b_cross, &a_inv_b); // Bᵀ A⁻¹ B  (p_log × p_log)
1104    let mut stt = &c_gram - &schur;
1105    stt = (&stt + &stt.t()) * 0.5;
1106    if stt.iter().any(|v| !v.is_finite()) {
1107        return Err(
1108            "survival reduced logslope: effective Schur Gram produced non-finite entries"
1109                .to_string(),
1110        );
1111    }
1112
1113    let (evals, evecs) = stt
1114        .eigh(Side::Lower)
1115        .map_err(|e| format!("survival reduced logslope: eigendecomposition failed: {e:?}"))?;
1116    // A `Gtt` eigenvalue far below the effective logslope energy scale means that
1117    // direction's effective logslope column is W-explained by the effective
1118    // marginal span — exactly the joint-Hessian rank-soft confounded direction.
1119    let tol = energy_scale * LOGSLOPE_REDUCED_BASIS_RELATIVE_TOL;
1120    let mut kept: Vec<usize> = (0..evals.len()).filter(|&i| evals[i] > tol).collect();
1121    kept.sort_by(|&a, &b| {
1122        evals[b]
1123            .partial_cmp(&evals[a])
1124            .unwrap_or(std::cmp::Ordering::Equal)
1125    });
1126    let r = kept.len();
1127    // r == p_log: no confounded direction to remove. r == 0: the whole effective
1128    // logslope image is in the marginal span. In both cases keep the raw design.
1129    if r == p_log || r == 0 {
1130        return Ok(None);
1131    }
1132    let mut transform = Array2::<f64>::zeros((p_log, r));
1133    for (out_col, &src) in kept.iter().enumerate() {
1134        transform.column_mut(out_col).assign(&evecs.column(src));
1135    }
1136    if transform.iter().any(|v| !v.is_finite()) {
1137        return Err(
1138            "survival reduced logslope: reduced transform produced non-finite entries".to_string(),
1139        );
1140    }
1141    Ok(Some(transform))
1142}
1143
1144/// Assemble a block-diagonal 3-block [`CompiledMap`] that passes the time and
1145/// marginal blocks through unchanged (identity) and reparameterises ONLY the
1146/// logslope block via `t_log` (`p_log × r`). Used by the survival `#979`
1147/// partial reduced-logslope confound removal
1148/// ([`survival_reduced_logslope_transform_effective`]): the marginal/time
1149/// channels are untouched, the logslope block drops only its confounded
1150/// directions, and the joint penalised Hessian is full-rank by construction.
1151///
1152/// The resulting `CompiledMap` is interchangeable with one from
1153/// [`compiled_map_from_per_term`] /
1154/// [`gam_identifiability::families::compiler::compile_from_raw_grams`], so the
1155/// existing [`apply_compiled_map_to_designs`] + [`Gauge::from_compiled_map`]
1156/// machinery consumes it unchanged. Because the map is block-diagonal there is
1157/// no strict-upper cross-block residual `R`, and `apply_compiled_map_to_designs`
1158/// reads only the per-block diagonal `V_b = T[raw_b, compiled_b]` — `V_time` and
1159/// `V_marg` are identities, `V_log = t_log`.
1160pub fn survival_block_diagonal_logslope_map(
1161    p_time: usize,
1162    p_marg: usize,
1163    t_log: &Array2<f64>,
1164) -> gam_identifiability::families::compiler::CompiledMap {
1165    let p_log = t_log.nrows();
1166    let r = t_log.ncols();
1167    let raw_total = p_time + p_marg + p_log;
1168    let compiled_total = p_time + p_marg + r;
1169    let mut t_full = Array2::<f64>::zeros((raw_total, compiled_total));
1170    for i in 0..p_time {
1171        t_full[[i, i]] = 1.0;
1172    }
1173    for i in 0..p_marg {
1174        t_full[[p_time + i, p_time + i]] = 1.0;
1175    }
1176    for ri in 0..p_log {
1177        for cj in 0..r {
1178            t_full[[p_time + p_marg + ri, p_time + p_marg + cj]] = t_log[[ri, cj]];
1179        }
1180    }
1181    gam_identifiability::families::compiler::CompiledMap {
1182        raw_from_compiled: t_full,
1183        compiled_block_ranges: vec![
1184            0..p_time,
1185            p_time..(p_time + p_marg),
1186            (p_time + p_marg)..compiled_total,
1187        ],
1188        raw_block_ranges: vec![
1189            0..p_time,
1190            p_time..(p_time + p_marg),
1191            (p_time + p_marg)..raw_total,
1192        ],
1193    }
1194}
1195
1196/// Apply a global [`CompiledMap`] T directly to the three survival
1197/// parametric block designs (time/marginal/logslope). Slices the
1198/// per-block diagonal of T into `V_b = T[raw_range_b, compiled_range_b]`
1199/// (shape `p_b_raw × w_b_compiled`), wraps each channel's raw design via
1200/// [`wrap_design_with_transform`], and pulls each block's penalties back
1201/// through that block's OWN `V_b` via
1202/// [`pull_back_blockwise_penalty_through_block_v`], producing
1203/// per-block-width `(w_b_compiled × w_b_compiled)` penalties — the shape
1204/// a per-block `ParameterBlockSpec.penalties` slot requires.
1205///
1206/// `map.raw_block_ranges` must equal three contiguous ranges in the
1207/// order Time → Marginal → Logslope (matching the input designs).
1208/// `map.compiled_block_ranges` runs in the same order.
1209///
1210/// Penalties supplied to this function:
1211/// - `time_penalties` are `BlockwisePenalty`s whose `col_range` is in
1212///   the time block's local raw coords (e.g. `0..p_time`).
1213/// - `marginal_penalties` / `logslope_penalties` likewise — local to
1214///   their own channel's raw width.
1215///
1216/// Each penalty's block-local `col_range` is embedded into the block's
1217/// raw width and pulled back as `V_bᵀ S_b V_b`. The cross-block
1218/// residualisation `R_{a→b}` carried in T's strict-upper triangle is
1219/// absorbed into the residualised *design* columns, not the penalty, so
1220/// the per-block penalty model stays exact for the highest-priority
1221/// block (time, no anchor → `R = []`) and matches the sibling per-block
1222/// compile path for the rest.
1223pub fn apply_compiled_map_to_designs(
1224    map: &gam_identifiability::families::compiler::CompiledMap,
1225    time_design_entry: DesignMatrix,
1226    time_design_exit: DesignMatrix,
1227    time_design_derivative_exit: DesignMatrix,
1228    marginal_design: DesignMatrix,
1229    logslope_design: DesignMatrix,
1230    time_penalties: &[gam_terms::smooth::BlockwisePenalty],
1231    marginal_penalties: &[gam_terms::smooth::BlockwisePenalty],
1232    logslope_penalties: &[gam_terms::smooth::BlockwisePenalty],
1233) -> Result<CompiledSurvivalDesignsVMExact, String> {
1234    if map.raw_block_ranges.len() != 3 || map.compiled_block_ranges.len() != 3 {
1235        return Err(format!(
1236            "apply_compiled_map_to_designs: expected exactly 3 blocks (time, marginal, logslope), \
1237             got {} raw / {} compiled",
1238            map.raw_block_ranges.len(),
1239            map.compiled_block_ranges.len(),
1240        ));
1241    }
1242    let time_raw = map.raw_block_ranges[0].clone();
1243    let marg_raw = map.raw_block_ranges[1].clone();
1244    let log_raw = map.raw_block_ranges[2].clone();
1245    let time_compiled = map.compiled_block_ranges[0].clone();
1246    let marg_compiled = map.compiled_block_ranges[1].clone();
1247    let log_compiled = map.compiled_block_ranges[2].clone();
1248
1249    let t = &map.raw_from_compiled;
1250    let raw_total = t.nrows();
1251    let compiled_total = t.ncols();
1252    let expected_raw_total = log_raw.end;
1253    if raw_total != expected_raw_total {
1254        return Err(format!(
1255            "apply_compiled_map_to_designs: T has {raw_total} raw rows but block ranges sum to \
1256             {expected_raw_total}"
1257        ));
1258    }
1259    let expected_compiled_total = log_compiled.end;
1260    if compiled_total != expected_compiled_total {
1261        return Err(format!(
1262            "apply_compiled_map_to_designs: T has {compiled_total} compiled cols but block ranges \
1263             sum to {expected_compiled_total}"
1264        ));
1265    }
1266
1267    let v_time = t
1268        .slice(ndarray::s![time_raw.clone(), time_compiled.clone()])
1269        .to_owned();
1270    let v_marg = t
1271        .slice(ndarray::s![marg_raw.clone(), marg_compiled.clone()])
1272        .to_owned();
1273    let v_log = t
1274        .slice(ndarray::s![log_raw.clone(), log_compiled.clone()])
1275        .to_owned();
1276
1277    let time_entry_out =
1278        wrap_design_with_transform(time_design_entry, &v_time, "compiled-map: time entry")?;
1279    let time_exit_out =
1280        wrap_design_with_transform(time_design_exit, &v_time, "compiled-map: time exit")?;
1281    let time_deriv_out = wrap_design_with_transform(
1282        time_design_derivative_exit,
1283        &v_time,
1284        "compiled-map: time derivative_exit",
1285    )?;
1286    let marg_out = wrap_design_with_transform(marginal_design, &v_marg, "compiled-map: marginal")?;
1287    let log_out = wrap_design_with_transform(logslope_design, &v_log, "compiled-map: logslope")?;
1288
1289    // Pull each block's penalties back through that block's OWN diagonal
1290    // reparameterisation V_b (= the (b, b) block of T). This produces a
1291    // per-block-width `(w_b_compiled × w_b_compiled)` penalty — the only
1292    // shape a per-block `ParameterBlockSpec.penalties` slot accepts.
1293    //
1294    // The block-local penalty `V_bᵀ S_b V_b` is the correct per-block
1295    // penalty: in raw coords the model penalises `γ_bᵀ S_b γ_b` on block
1296    // b's own coefficients, and under the residualised reparameterisation
1297    // the cross-block carry `R_{a→b}` lives entirely in the *design*
1298    // columns (`C_b V_b − A_{<b} R_b`), not in the penalty.
1299    //
1300    // Pulling penalties back through the full joint triangular T instead
1301    // (`Tᵀ blkdiag(S_b) T`) yields a `(p_compiled × p_compiled)` dense
1302    // matrix whose off-diagonal couples θ_b to earlier blocks' θ_a;
1303    // jamming that joint-width matrix into a single block's `penalties`
1304    // produced the `block 0 penalty 0 must be 12x12, got 17x17` mismatch
1305    // that surfaced as the `assert_valid_blockspecs` FFI panic. The two
1306    // agree whenever the residualisation `R_{a→b}` lands in the null space
1307    // of S_a (the shared low-order / parametric directions the identifiable
1308    // quotient strips), which is the case the compiler targets.
1309    let pull_set = |pens: &[gam_terms::smooth::BlockwisePenalty],
1310                    v_block: &Array2<f64>,
1311                    channel: &str|
1312     -> Result<Vec<PenaltyMatrix>, String> {
1313        pens.iter()
1314            .map(|p| {
1315                pull_back_blockwise_penalty_through_block_v(p, v_block).map_err(|e| {
1316                    format!("apply_compiled_map_to_designs: {channel} penalty pullback: {e}")
1317                })
1318            })
1319            .collect()
1320    };
1321
1322    let time_penalties = pull_set(time_penalties, &v_time, "time")?;
1323    let marginal_penalties = pull_set(marginal_penalties, &v_marg, "marginal")?;
1324    let logslope_penalties = pull_set(logslope_penalties, &v_log, "logslope")?;
1325    validate_block_penalty_shapes("time", time_exit_out.ncols(), &time_penalties)?;
1326    validate_block_penalty_shapes("marginal", marg_out.ncols(), &marginal_penalties)?;
1327    validate_block_penalty_shapes("logslope", log_out.ncols(), &logslope_penalties)?;
1328
1329    Ok(CompiledSurvivalDesignsVMExact {
1330        time_design_entry: time_entry_out,
1331        time_design_exit: time_exit_out,
1332        time_design_derivative_exit: time_deriv_out,
1333        marginal_design: marg_out,
1334        logslope_design: log_out,
1335        time_penalties,
1336        marginal_penalties,
1337        logslope_penalties,
1338    })
1339}
1340
1341fn validate_block_penalty_shapes(
1342    block: &str,
1343    width: usize,
1344    penalties: &[PenaltyMatrix],
1345) -> Result<(), String> {
1346    for (idx, penalty) in penalties.iter().enumerate() {
1347        let shape = penalty.shape();
1348        if shape != (width, width) {
1349            return Err(format!(
1350                "apply_compiled_map_to_designs: {block} penalty {idx} must be {width}x{width}, got {}x{}",
1351                shape.0, shape.1
1352            ));
1353        }
1354    }
1355    Ok(())
1356}
1357
1358/// Run the identifiability compiler on the three survival parametric
1359/// blocks (time, marginal, logslope) at a pilot β and return the per-
1360/// block V reparameterisation matrices.
1361///
1362/// `row_hess` must be a PSD per-row 4×4 Hessian of `−log L_i(u_i)` at
1363/// the pilot β (see [`SurvivalRowHessian::from_pilot_primary_state`]).
1364/// The compiler residualises blocks left-to-right in priority order
1365/// (time → marginal → logslope) in the sqrt-H-metric so any aliased
1366/// direction lands in the lower-priority block, then runs a post-walk
1367/// column-pivoted QR on the cumulative anchor and drops trailing
1368/// pivots from the latest block. The returned V matrices are ready to
1369/// be applied to each block's raw design and penalty before the
1370/// `ParameterBlockSpec` list is assembled.
1371///
1372/// On `FullyAliased` from `compile()` (a block fully absorbed by its
1373/// cumulative anchor) this returns `Err`. The construction site should
1374/// surface that as a structured user-facing diagnostic — the model is
1375/// asking the compiler to assign zero degrees of freedom to a named
1376/// parametric block, which is a model-spec bug not a numerical one.
1377///
1378/// Sibling Phase-4b wiring (`bernoulli_marginal_slope::install_compiled_flex_block_into_runtime`)
1379/// already calls `compile()` for the flex blocks. This helper extends
1380/// that contract to the parametric blocks by giving the SMGS
1381/// construction site a one-line entry point — it does NOT yet apply
1382/// the V transforms to the family's captured designs (the captured-
1383/// design update is the remaining integration step that touches the
1384/// family's row-Hessian assembly assertions).
1385pub fn compile_survival_parametric_designs(
1386    time_dq0: Array2<f64>,
1387    time_dq1: Array2<f64>,
1388    time_dqd1: Array2<f64>,
1389    marginal_dq: Array2<f64>,
1390    marginal_dqd1: Array2<f64>,
1391    logslope_dg: Array2<f64>,
1392    row_hess: &dyn RowHessian,
1393) -> Result<SurvivalParametricCompiled, String> {
1394    use gam_identifiability::families::compiler::compile;
1395
1396    let p_time_raw = time_dq0.ncols();
1397    let p_marg_raw = marginal_dq.ncols();
1398    let p_log_raw = logslope_dg.ncols();
1399
1400    let inputs = build_survival_compiler_inputs(
1401        time_dq0,
1402        time_dq1,
1403        time_dqd1,
1404        marginal_dq,
1405        marginal_dqd1,
1406        logslope_dg,
1407        None,
1408        None,
1409    );
1410    if inputs.operators.len() != 3 {
1411        return Err(format!(
1412            "compile_survival_parametric_designs: expected exactly 3 parametric operators \
1413             (time, marginal, logslope); got {}",
1414            inputs.operators.len(),
1415        ));
1416    }
1417    let compiled = compile(&inputs.operators, row_hess, &inputs.ordering)
1418        .map_err(|e| format!("identifiability::families::compiler::compile failed: {e}"))?;
1419    if compiled.blocks.len() != 3 {
1420        return Err(format!(
1421            "compile_survival_parametric_designs: compiler emitted {} blocks; expected 3",
1422            compiled.blocks.len(),
1423        ));
1424    }
1425    let v_time = compiled.blocks[0].t_lw.clone();
1426    let v_marginal = compiled.blocks[1].t_lw.clone();
1427    let v_logslope = compiled.blocks[2].t_lw.clone();
1428    let drops_by_block = (
1429        p_time_raw.saturating_sub(v_time.ncols()),
1430        p_marg_raw.saturating_sub(v_marginal.ncols()),
1431        p_log_raw.saturating_sub(v_logslope.ncols()),
1432    );
1433    Ok(SurvivalParametricCompiled {
1434        v_time,
1435        v_marginal,
1436        v_logslope,
1437        drops_by_block,
1438    })
1439}
1440
1441/// Build the operator stack from already-materialised dense designs.
1442///
1443/// `time_dq0/dq1/dqd1` are the time block's three primary-state Jacobians
1444/// at training rows. `marginal_dq` and `marginal_dqd1` are the marginal
1445/// block's contributions to q (shared between q0 and q1) and to qd1
1446/// (typically zero unless timewiggle interacts). `logslope_dg` is the
1447/// logslope block's contribution to g.
1448///
1449/// `score_warp_(dq, dqd1)` / `link_dev_(dq, dqd1)` are present only when
1450/// the corresponding flex block is active. The returned `ordering` parallels
1451/// `operators` so the caller can route compiled outputs back to runtime slots.
1452pub fn build_survival_compiler_inputs(
1453    time_dq0: Array2<f64>,
1454    time_dq1: Array2<f64>,
1455    time_dqd1: Array2<f64>,
1456    marginal_dq: Array2<f64>,
1457    marginal_dqd1: Array2<f64>,
1458    logslope_dg: Array2<f64>,
1459    score_warp_dq_dqd1: Option<(Array2<f64>, Array2<f64>)>,
1460    link_dev_dq_dqd1: Option<(Array2<f64>, Array2<f64>)>,
1461) -> SurvivalCompilerInputs {
1462    let mut operators: Vec<Arc<dyn RowJacobianOperator>> = Vec::with_capacity(5);
1463    let mut ordering: Vec<BlockOrder> = Vec::with_capacity(5);
1464
1465    operators.push(Arc::new(TimeBlockOperator::new(
1466        time_dq0, time_dq1, time_dqd1,
1467    )));
1468    ordering.push(BlockOrder::Time);
1469
1470    operators.push(Arc::new(QChannelBlockOperator::new(
1471        marginal_dq,
1472        marginal_dqd1,
1473    )));
1474    ordering.push(BlockOrder::Marginal);
1475
1476    operators.push(Arc::new(LogslopeBlockOperator::new(logslope_dg)));
1477    ordering.push(BlockOrder::Logslope);
1478
1479    if let Some((dq, dqd1)) = score_warp_dq_dqd1 {
1480        operators.push(Arc::new(QChannelBlockOperator::new(dq, dqd1)));
1481        ordering.push(BlockOrder::ScoreWarp);
1482    }
1483    if let Some((dq, dqd1)) = link_dev_dq_dqd1 {
1484        operators.push(Arc::new(QChannelBlockOperator::new(dq, dqd1)));
1485        ordering.push(BlockOrder::LinkDev);
1486    }
1487
1488    SurvivalCompilerInputs {
1489        operators,
1490        ordering,
1491    }
1492}
1493
1494/// V+M-exact compiled designs + per-block penalties for the survival
1495/// time/marginal/logslope blocks, produced by
1496/// [`apply_compiled_map_to_designs`] from a `CompiledMap`. The
1497/// construction site swaps raw designs/penalties for these compiled
1498/// versions before building `ParameterBlockSpec`s.
1499///
1500/// The emitted designs carry the exact residualised `C_b·V_b − A_{<b}·R_b`
1501/// row form (via [`wrap_design_with_transform`] on `V_b = T[raw_b, comp_b]`):
1502/// the cross-block residualisation `R_{a→b}` lives in those design columns,
1503/// while each block's penalty is pulled back through that block's own
1504/// diagonal `V_b` as `V_bᵀ S_b V_b` (the `*_penalties` fields).
1505///
1506/// At fit result the joint compiled β is lifted back to raw via the
1507/// `gam_solve::gauge::Gauge` built from the *same* `CompiledMap`
1508/// (`β_raw = T · θ`, T block-upper-triangular with `V_b` on the diagonal
1509/// and `-R_{a→b}` off-diagonal). The full T therefore lives on that
1510/// `Gauge`, not on this struct — the caller holds the `CompiledMap` and
1511/// constructs both from it, so duplicating T here would be dead state.
1512pub struct CompiledSurvivalDesignsVMExact {
1513    pub time_design_entry: DesignMatrix,
1514    pub time_design_exit: DesignMatrix,
1515    pub time_design_derivative_exit: DesignMatrix,
1516    pub marginal_design: DesignMatrix,
1517    pub logslope_design: DesignMatrix,
1518    /// Per-block penalties, each pulled back through that block's OWN
1519    /// diagonal reparameterisation `V_b` as `V_bᵀ S_b V_b`. The result
1520    /// is a per-block-width `PenaltyMatrix::Dense`
1521    /// (`w_b_compiled × w_b_compiled`) — the shape a per-block
1522    /// `ParameterBlockSpec.penalties` slot requires. Cross-block
1523    /// residualisation `R_{a→b}` is carried by the residualised design
1524    /// columns, not the penalty.
1525    pub time_penalties: Vec<PenaltyMatrix>,
1526    pub marginal_penalties: Vec<PenaltyMatrix>,
1527    pub logslope_penalties: Vec<PenaltyMatrix>,
1528}
1529
1530#[cfg(test)]
1531mod tests {
1532    use super::*;
1533    use gam_problem::Gauge;
1534
1535    #[test]
1536    fn psd_clamp_zeros_negative_eigenvalues() {
1537        // Construct M = U diag(2, -1, 0.5, -0.25) Uᵀ for a fixed U from
1538        // a small rotation, verify the clamped matrix has eigenvalues
1539        // (2, 0, 0.5, 0).
1540        let mut m = Array2::<f64>::zeros((4, 4));
1541        // Diagonal with mixed signs is sufficient for the test: the
1542        // eigenvalues equal the diagonal and the eigenvectors are e_i.
1543        m[[0, 0]] = 2.0;
1544        m[[1, 1]] = -1.0;
1545        m[[2, 2]] = 0.5;
1546        m[[3, 3]] = -0.25;
1547        let clamped = psd_clamp_4x4(&m);
1548        assert!((clamped[[0, 0]] - 2.0).abs() < 1e-12);
1549        assert!(clamped[[1, 1]].abs() < 1e-12);
1550        assert!((clamped[[2, 2]] - 0.5).abs() < 1e-12);
1551        assert!(clamped[[3, 3]].abs() < 1e-12);
1552    }
1553
1554    #[test]
1555    fn time_block_operator_evaluate_full_shape() {
1556        let n = 6;
1557        let p = 3;
1558        let dq0 = Array2::from_shape_fn((n, p), |(i, j)| (i + j) as f64);
1559        let dq1 = Array2::from_shape_fn((n, p), |(i, j)| (i as f64) * 2.0 + j as f64);
1560        let dqd1 = Array2::from_shape_fn((n, p), |(i, j)| 0.5 * ((i * j) as f64));
1561        let op = TimeBlockOperator::new(dq0.clone(), dq1.clone(), dqd1.clone());
1562        let full = op.evaluate_full();
1563        assert_eq!(full.shape(), &[n, p, K_SURVIVAL]);
1564        for i in 0..n {
1565            for j in 0..p {
1566                assert_eq!(full[[i, j, 0]], dq0[[i, j]]);
1567                assert_eq!(full[[i, j, 1]], dq1[[i, j]]);
1568                assert_eq!(full[[i, j, 2]], dqd1[[i, j]]);
1569                assert_eq!(full[[i, j, 3]], 0.0);
1570            }
1571        }
1572    }
1573
1574    #[test]
1575    fn q_channel_block_apply_row_shares_q0_q1() {
1576        let n = 5;
1577        let p = 2;
1578        let dq = Array2::from_shape_fn((n, p), |(i, j)| (i as f64) * (j as f64 + 1.0));
1579        let dqd1 = Array2::from_shape_fn((n, p), |(i, j)| (j as f64) - (i as f64));
1580        let op = QChannelBlockOperator::new(dq.clone(), dqd1.clone());
1581        let mut out = [0.0_f64; K_SURVIVAL];
1582        let delta = [1.0_f64, -0.5];
1583        op.apply_row(3, &delta, &mut out);
1584        let want_q = dq[[3, 0]] * 1.0 + dq[[3, 1]] * (-0.5);
1585        let want_qd = dqd1[[3, 0]] * 1.0 + dqd1[[3, 1]] * (-0.5);
1586        assert!((out[0] - want_q).abs() < 1e-12);
1587        assert!((out[1] - want_q).abs() < 1e-12);
1588        assert!((out[2] - want_qd).abs() < 1e-12);
1589        assert_eq!(out[3], 0.0);
1590    }
1591
1592    #[test]
1593    fn logslope_block_writes_only_g_channel() {
1594        let n = 4;
1595        let p = 2;
1596        let dg = Array2::from_shape_fn((n, p), |(i, j)| (i as f64) + 0.1 * (j as f64));
1597        let op = LogslopeBlockOperator::new(dg.clone());
1598        let mut out = [0.0_f64; K_SURVIVAL];
1599        let delta = [2.0_f64, -1.0];
1600        op.apply_row(1, &delta, &mut out);
1601        assert_eq!(out[0], 0.0);
1602        assert_eq!(out[1], 0.0);
1603        assert_eq!(out[2], 0.0);
1604        let want = dg[[1, 0]] * 2.0 + dg[[1, 1]] * (-1.0);
1605        assert!((out[3] - want).abs() < 1e-12);
1606    }
1607
1608    #[test]
1609    fn extract_term_partition_simple_cases() {
1610        let full = 0..5usize;
1611        // No penalties: whole block is one term.
1612        let part = extract_term_partition_from_penalty_ranges(5, &[]);
1613        assert_eq!(part.as_slice(), std::slice::from_ref(&full));
1614        // One penalty covering the whole block.
1615        let part = extract_term_partition_from_penalty_ranges(5, std::slice::from_ref(&full));
1616        assert_eq!(part.as_slice(), std::slice::from_ref(&full));
1617        // Two penalties with a gap: produces three terms (pen1, gap, pen2).
1618        let part = extract_term_partition_from_penalty_ranges(10, &[0..3, 6..10]);
1619        assert_eq!(part, vec![0..3, 3..6, 6..10]);
1620        // Duplicate penalty ranges coalesce.
1621        let part = extract_term_partition_from_penalty_ranges(6, &[0..3, 0..3, 3..6]);
1622        assert_eq!(part, vec![0..3, 3..6]);
1623        // Empty block.
1624        let part = extract_term_partition_from_penalty_ranges(0, &[]);
1625        assert!(part.is_empty());
1626    }
1627
1628    #[test]
1629    fn assemble_block_triangular_t_identity_when_v_eye_and_r_none() {
1630        let v_a = Array2::<f64>::eye(2);
1631        let v_b = Array2::<f64>::eye(2);
1632        let t = assemble_block_triangular_t(&[v_a, v_b], &[None, None]);
1633        assert_eq!(t.dim(), (4, 4));
1634        let eye4 = Array2::<f64>::eye(4);
1635        for i in 0..4 {
1636            for j in 0..4 {
1637                assert!((t[[i, j]] - eye4[[i, j]]).abs() < 1e-14);
1638            }
1639        }
1640    }
1641
1642    #[test]
1643    fn assemble_block_triangular_t_with_drops_and_nonzero_r() {
1644        let mut v_a = Array2::<f64>::zeros((3, 2));
1645        v_a[[0, 0]] = 1.0;
1646        v_a[[1, 0]] = 0.5;
1647        v_a[[2, 1]] = 1.0;
1648        let v_b = Array2::<f64>::eye(2);
1649        let r_ab =
1650            Array2::<f64>::from_shape_fn((3, 2), |(i, j)| 1.0 + (i as f64) + 0.25 * (j as f64));
1651        let t =
1652            assemble_block_triangular_t(&[v_a.clone(), v_b.clone()], &[None, Some(r_ab.clone())]);
1653        assert_eq!(t.dim(), (5, 4));
1654        for i in 0..3 {
1655            for j in 0..2 {
1656                assert!((t[[i, j]] - v_a[[i, j]]).abs() < 1e-14);
1657            }
1658        }
1659        for i in 0..2 {
1660            for j in 0..2 {
1661                assert!((t[[3 + i, 2 + j]] - v_b[[i, j]]).abs() < 1e-14);
1662            }
1663        }
1664        for i in 0..3 {
1665            for j in 0..2 {
1666                assert!((t[[i, 2 + j]] + r_ab[[i, j]]).abs() < 1e-14);
1667            }
1668        }
1669        for i in 0..2 {
1670            for j in 0..2 {
1671                assert_eq!(t[[3 + i, j]], 0.0);
1672            }
1673        }
1674    }
1675
1676    #[test]
1677    fn validate_partition_rejects_bad_partitions() {
1678        let bad_start = 1..5usize;
1679        let short_cover = 0..3usize;
1680        let full_cover = 0..5usize;
1681        // Doesn't start at 0.
1682        assert!(validate_partition(std::slice::from_ref(&bad_start), 5, "test").is_err());
1683        // Doesn't cover the block.
1684        assert!(validate_partition(std::slice::from_ref(&short_cover), 5, "test").is_err());
1685        // Has a gap.
1686        assert!(validate_partition(&[0..2, 3..5], 5, "test").is_err());
1687        // Has overlap.
1688        assert!(validate_partition(&[0..3, 2..5], 5, "test").is_err());
1689        // Has empty range.
1690        assert!(validate_partition(&[0..0, 0..5], 5, "test").is_err());
1691        // Empty block + empty partition OK.
1692        assert!(validate_partition(&[], 0, "test").is_ok());
1693        // Valid partition.
1694        assert!(validate_partition(&[0..2, 2..5], 5, "test").is_ok());
1695        assert!(validate_partition(std::slice::from_ref(&full_cover), 5, "test").is_ok());
1696    }
1697
1698    /// Regression for #368: the phase-4b compiled-map penalty pullback must
1699    /// emit a PER-BLOCK-WIDTH penalty for every block (sized to that block's
1700    /// COMPILED design width), even when a block drops columns and the
1701    /// triangular T carries nonzero off-diagonal cross-block residualisation
1702    /// `R_{a→b}`. The original bug pulled penalties back through the full
1703    /// joint T (`Tᵀ S T`), producing joint-compiled-width penalties (e.g.
1704    /// 7×7) that did not fit a single per-block `ParameterBlockSpec.penalties`
1705    /// slot (e.g. time block compiled width 3), making `validate_blockspecs`
1706    /// fail and `assert_valid_blockspecs` panic across the FFI boundary on
1707    /// ordinary survival data.
1708    #[test]
1709    fn compiled_map_penalty_pullback_is_per_block_width_with_nonzero_residual() {
1710        use gam_identifiability::families::compiler::CompiledMap;
1711        use gam_terms::smooth::BlockwisePenalty;
1712
1713        let n = 10;
1714        // Time raw 3 → compiled 3 (block 0: no anchor, V pure, R=None).
1715        // Marginal raw 3 → compiled 2 (a real drop, with nonzero R against time).
1716        // Logslope raw 2 → compiled 2 (nonzero R against time+marginal).
1717        let v_time =
1718            Array2::<f64>::from_shape_fn(
1719                (3, 3),
1720                |(i, j)| {
1721                    if i == j { 1.0 } else { 0.1 * ((i + j) as f64) }
1722                },
1723            );
1724        let v_marg = Array2::<f64>::from_shape_fn((3, 2), |(i, j)| {
1725            0.5 + 0.3 * (i as f64) - 0.2 * (j as f64)
1726        });
1727        let v_log = Array2::<f64>::from_shape_fn((2, 2), |(i, j)| if i == j { 1.2 } else { 0.4 });
1728        // R_marg: rows = time raw width 3, cols = marginal compiled width 2.
1729        let r_marg = Array2::<f64>::from_shape_fn((3, 2), |(i, j)| 0.7 - 0.1 * ((i + j) as f64));
1730        // R_log: rows = time+marg RAW width 6 (3 + 3), cols = logslope compiled
1731        // width 2. `assemble_block_triangular_t` stacks R_{a→b} over a<b, so the row
1732        // count is the sum of the RAW widths of the prior blocks (not their
1733        // compiled widths — marginal's compiled width is 2 but its raw width is 3).
1734        let r_log =
1735            Array2::<f64>::from_shape_fn((6, 2), |(i, j)| 0.3 + 0.05 * ((i * 2 + j) as f64));
1736
1737        let t = assemble_block_triangular_t(
1738            &[v_time.clone(), v_marg.clone(), v_log.clone()],
1739            &[None, Some(r_marg.clone()), Some(r_log.clone())],
1740        );
1741        assert_eq!(t.dim(), (8, 7), "joint raw 8 × joint compiled 7");
1742
1743        let map = CompiledMap {
1744            raw_from_compiled: t.clone(),
1745            compiled_block_ranges: vec![0..3, 3..5, 5..7],
1746            raw_block_ranges: vec![0..3, 3..6, 6..8],
1747        };
1748
1749        // Raw designs (dense, n rows).
1750        let raw_time_entry = DesignMatrix::Dense(DenseDesignMatrix::from(
1751            Array2::<f64>::from_shape_fn((n, 3), |(i, j)| 1.0 + (i as f64) * 0.1 + (j as f64)),
1752        ));
1753        let raw_time_exit = raw_time_entry.clone();
1754        let raw_time_deriv = raw_time_entry.clone();
1755        let raw_marg = DesignMatrix::Dense(DenseDesignMatrix::from(Array2::<f64>::from_shape_fn(
1756            (n, 3),
1757            |(i, j)| 0.2 * (i as f64) - 0.3 * (j as f64),
1758        )));
1759        let raw_log = DesignMatrix::Dense(DenseDesignMatrix::from(Array2::<f64>::from_shape_fn(
1760            (n, 2),
1761            |(i, j)| 0.5 + (i as f64) * (j as f64 + 1.0),
1762        )));
1763
1764        // Block-local penalties (col_range relative to each block's first col).
1765        let s_time =
1766            Array2::<f64>::from_shape_fn(
1767                (3, 3),
1768                |(i, j)| if i == j { (i + 2) as f64 } else { 0.3 },
1769            );
1770        let s_marg =
1771            Array2::<f64>::from_shape_fn(
1772                (3, 3),
1773                |(i, j)| if i == j { 1.5 + i as f64 } else { 0.2 },
1774            );
1775        let s_log = Array2::<f64>::from_shape_fn((2, 2), |(i, j)| if i == j { 2.0 } else { 0.5 });
1776        let time_pens = vec![BlockwisePenalty::new(0..3, s_time.clone())];
1777        let marg_pens = vec![BlockwisePenalty::new(0..3, s_marg.clone())];
1778        let log_pens = vec![BlockwisePenalty::new(0..2, s_log.clone())];
1779
1780        let out = apply_compiled_map_to_designs(
1781            &map,
1782            raw_time_entry,
1783            raw_time_exit,
1784            raw_time_deriv,
1785            raw_marg,
1786            raw_log,
1787            &time_pens,
1788            &marg_pens,
1789            &log_pens,
1790        )
1791        .expect("apply_compiled_map_to_designs must succeed");
1792
1793        // Designs carry per-block compiled widths.
1794        assert_eq!(out.time_design_entry.ncols(), 3);
1795        assert_eq!(out.marginal_design.ncols(), 2);
1796        assert_eq!(out.logslope_design.ncols(), 2);
1797
1798        // Core invariant the bug violated: every penalty is sized to ITS
1799        // OWN block's compiled width, NOT the joint compiled width (7).
1800        for s in &out.time_penalties {
1801            assert_eq!(
1802                s.as_dense_cow().dim(),
1803                (3, 3),
1804                "time penalty must be per-block 3×3, not joint-width"
1805            );
1806        }
1807        for s in &out.marginal_penalties {
1808            assert_eq!(
1809                s.as_dense_cow().dim(),
1810                (2, 2),
1811                "marginal penalty must match reduced compiled width 2, not joint 7"
1812            );
1813        }
1814        for s in &out.logslope_penalties {
1815            assert_eq!(s.as_dense_cow().dim(), (2, 2));
1816        }
1817
1818        // For the time block (block 0, no anchor ⇒ R=None), the per-block
1819        // pullback is EXACT: θ_timeᵀ P_time θ_time == γ_timeᵀ S_time γ_time
1820        // with γ_time = V_time · θ_time. Verify the quadratic-form identity.
1821        let p_time_dense = out.time_penalties[0].as_dense_cow().into_owned();
1822        let theta_time = Array1::<f64>::from_shape_fn(3, |k| 0.4 + 0.7 * (k as f64));
1823        let gamma_time = v_time.dot(&theta_time);
1824        let lhs = theta_time.dot(&p_time_dense.dot(&theta_time));
1825        let rhs = gamma_time.dot(&s_time.dot(&gamma_time));
1826        assert!(
1827            (lhs - rhs).abs() < 1e-10,
1828            "time-block per-block pullback must be exact: lhs={lhs}, rhs={rhs}"
1829        );
1830
1831        // The marginal pullback must equal V_margᵀ S_marg V_marg exactly
1832        // (block-local; the cross-block R_marg lives in the design, not here).
1833        let p_marg_dense = out.marginal_penalties[0].as_dense_cow().into_owned();
1834        let want_marg = v_marg.t().dot(&s_marg.dot(&v_marg));
1835        for i in 0..2 {
1836            for j in 0..2 {
1837                assert!(
1838                    (p_marg_dense[[i, j]] - want_marg[[i, j]]).abs() < 1e-12,
1839                    "marginal penalty must be V_margᵀ S_marg V_marg at ({i},{j})"
1840                );
1841            }
1842        }
1843    }
1844
1845    /// Top-level Phase-4b API test for the SMGS parametric path:
1846    /// call `compile_survival_parametric_designs` on a shared-constant
1847    /// alias between time and marginal, with an identity row Hessian.
1848    /// Verify the returned `v_*` matrices have the expected widths
1849    /// (time keeps all 3, marginal loses 1, logslope keeps both) and
1850    /// `drops_by_block` reports `(0, 1, 0)`.
1851    #[test]
1852    fn compile_survival_parametric_designs_helper_attributes_drop_to_marginal() {
1853        let n = 24;
1854        let p_time = 3;
1855        let p_marginal = 3;
1856        let p_logslope = 2;
1857        let x: Vec<f64> = (0..n)
1858            .map(|i| -1.0 + 2.0 * (i as f64) / (n as f64 - 1.0))
1859            .collect();
1860        let mut time_dq0 = Array2::<f64>::zeros((n, p_time));
1861        let mut time_dq1 = Array2::<f64>::zeros((n, p_time));
1862        let mut time_dqd1 = Array2::<f64>::zeros((n, p_time));
1863        let mut marg_dq = Array2::<f64>::zeros((n, p_marginal));
1864        let marg_dqd1 = Array2::<f64>::zeros((n, p_marginal));
1865        let mut log_dg = Array2::<f64>::zeros((n, p_logslope));
1866        for i in 0..n {
1867            time_dq0[[i, 0]] = 1.0;
1868            time_dq0[[i, 1]] = x[i];
1869            time_dq0[[i, 2]] = x[i] * x[i];
1870            time_dq1[[i, 0]] = 1.0;
1871            time_dq1[[i, 1]] = x[i];
1872            time_dq1[[i, 2]] = x[i] * x[i];
1873            time_dqd1[[i, 0]] = 0.0;
1874            time_dqd1[[i, 1]] = 1.0;
1875            time_dqd1[[i, 2]] = 2.0 * x[i];
1876            marg_dq[[i, 0]] = 1.0; // alias with time col 0
1877            marg_dq[[i, 1]] = x[i] * x[i] * x[i];
1878            marg_dq[[i, 2]] = x[i].sin();
1879            log_dg[[i, 0]] = (2.0 * x[i]).cos();
1880            log_dg[[i, 1]] = x[i].tanh();
1881        }
1882        let mut h_full = Array3::<f64>::zeros((n, K_SURVIVAL, K_SURVIVAL));
1883        for i in 0..n {
1884            for k in 0..K_SURVIVAL {
1885                h_full[[i, k, k]] = 1.0;
1886            }
1887        }
1888        let row_hess = SurvivalRowHessian::from_full(h_full);
1889        let out = compile_survival_parametric_designs(
1890            time_dq0, time_dq1, time_dqd1, marg_dq, marg_dqd1, log_dg, &row_hess,
1891        )
1892        .expect("Phase-4b parametric compile must succeed on single-direction alias");
1893        assert_eq!(out.v_time.ncols(), p_time, "time keeps all columns");
1894        assert_eq!(
1895            out.v_marginal.ncols(),
1896            p_marginal - 1,
1897            "marginal loses exactly the shared-constant direction"
1898        );
1899        assert_eq!(out.v_logslope.ncols(), p_logslope, "logslope is clean");
1900        assert_eq!(
1901            out.drops_by_block,
1902            (0, 1, 0),
1903            "attribution: zero from time/logslope, one from marginal",
1904        );
1905    }
1906
1907    /// End-to-end Phase-4b smoke test: build the full 3-block survival
1908    /// parametric operator stack (time + marginal + logslope) with a
1909    /// shared-constant alias seeded between the time and marginal
1910    /// blocks, feed it into `compile()` with an identity 4×4 row
1911    /// Hessian on every row, and verify the compiler:
1912    ///
1913    ///   (1) returns a [`CompiledBlocks`] with one block per input;
1914    ///   (2) preserves all 3 columns of the highest-priority `Time`
1915    ///       block in `t_lw` (the time block enters first in the
1916    ///       ordering, so its full column span survives);
1917    ///   (3) drops exactly one direction from `Marginal` (the
1918    ///       constant aliased with the time intercept), leaving its
1919    ///       remaining columns in `t_lw`;
1920    ///   (4) reports `joint_rank` = (raw_total - 1).
1921    ///
1922    /// This validates the Phase-4b construction-time orthogonalisation
1923    /// path on the survival K=4 row primary state and then feeds the
1924    /// compiled per-block reduced bases through the SMGS lift [`Gauge`]
1925    /// (step 6), asserting the lift's reduced/raw block structure agrees
1926    /// with the compiled rank-drop — the construction contract end to end.
1927    #[test]
1928    fn compile_survival_three_block_with_shared_constant_drops_one_direction() {
1929        use gam_identifiability::families::compiler::compile;
1930
1931        let n = 32;
1932        let p_time = 3;
1933        let p_marginal = 3;
1934        let p_logslope = 2;
1935
1936        // Time block:
1937        //   col 0 = ones (the shared constant — aliases marginal col 0);
1938        //   col 1 = linear x;
1939        //   col 2 = quadratic x².
1940        // q0/q1 share the same design (so the alias surfaces in both
1941        // the entry and exit primary channels); qd1 is the derivative
1942        // of the design w.r.t. time at the exit point, which for the
1943        // constant column is exactly zero (the gauge identity that
1944        // makes the constant a true null direction under (q0, q1, qd1)
1945        // joint).
1946        let x: Vec<f64> = (0..n)
1947            .map(|i| -1.0 + 2.0 * (i as f64) / (n as f64 - 1.0))
1948            .collect();
1949        let mut time_dq0 = Array2::<f64>::zeros((n, p_time));
1950        let mut time_dq1 = Array2::<f64>::zeros((n, p_time));
1951        let mut time_dqd1 = Array2::<f64>::zeros((n, p_time));
1952        for i in 0..n {
1953            time_dq0[[i, 0]] = 1.0;
1954            time_dq0[[i, 1]] = x[i];
1955            time_dq0[[i, 2]] = x[i] * x[i];
1956            time_dq1[[i, 0]] = 1.0;
1957            time_dq1[[i, 1]] = x[i];
1958            time_dq1[[i, 2]] = x[i] * x[i];
1959            // d/dt of a constant = 0; d/dt of x ≡ 1; d/dt of x² ≡ 2x.
1960            time_dqd1[[i, 0]] = 0.0;
1961            time_dqd1[[i, 1]] = 1.0;
1962            time_dqd1[[i, 2]] = 2.0 * x[i];
1963        }
1964
1965        // Marginal block (q-channel only; qd1 contribution zero — no
1966        // timewiggle in this scenario):
1967        //   col 0 = ones (the shared constant);
1968        //   col 1 = x³;
1969        //   col 2 = sin(x).
1970        let mut marg_dq = Array2::<f64>::zeros((n, p_marginal));
1971        let marg_dqd1 = Array2::<f64>::zeros((n, p_marginal));
1972        for i in 0..n {
1973            marg_dq[[i, 0]] = 1.0;
1974            marg_dq[[i, 1]] = x[i] * x[i] * x[i];
1975            marg_dq[[i, 2]] = x[i].sin();
1976        }
1977
1978        // Logslope block (g-channel only):
1979        //   col 0 = cos(2x);
1980        //   col 1 = tanh(x).  (no shared constant — logslope is clean)
1981        let mut log_dg = Array2::<f64>::zeros((n, p_logslope));
1982        for i in 0..n {
1983            log_dg[[i, 0]] = (2.0 * x[i]).cos();
1984            log_dg[[i, 1]] = x[i].tanh();
1985        }
1986
1987        let inputs = build_survival_compiler_inputs(
1988            time_dq0, time_dq1, time_dqd1, marg_dq, marg_dqd1, log_dg, None, None,
1989        );
1990
1991        // Identity 4×4 row Hessian on every row. With H_i = I the
1992        // sqrt-H metric collapses to the standard Frobenius metric,
1993        // so the compiler's residualisation is ordinary least-squares
1994        // projection — exactly what we want for verifying the
1995        // structural rank-deficiency attribution.
1996        let mut h_full = Array3::<f64>::zeros((n, K_SURVIVAL, K_SURVIVAL));
1997        for i in 0..n {
1998            for k in 0..K_SURVIVAL {
1999                h_full[[i, k, k]] = 1.0;
2000            }
2001        }
2002        let row_hess = SurvivalRowHessian::from_full(h_full);
2003
2004        let compiled = compile(&inputs.operators, &row_hess, &inputs.ordering)
2005            .expect("survival 3-block compile must succeed; aliasing is single-direction");
2006
2007        // (1) One CompiledBlock per input.
2008        assert_eq!(compiled.blocks.len(), 3, "expected 3 CompiledBlocks");
2009
2010        // (2) Time enters first; under sqrt-I metric every column of
2011        // the time block is residual-vs-empty-anchor and therefore
2012        // survives the eigendecomposition with positive eigenvalue.
2013        // V_time has p_time columns.
2014        let v_time = &compiled.blocks[0].t_lw;
2015        assert_eq!(
2016            v_time.ncols(),
2017            p_time,
2018            "time block (first in ordering) must retain all {p_time} of its columns; V_time={:?}",
2019            v_time.dim(),
2020        );
2021
2022        // (3) Marginal enters second. Its constant column is aliased
2023        // with time's constant column in (q0, q1) and contributes zero
2024        // to qd1. After residualising against the time anchor in the
2025        // K=4 stacked metric, the residual Gram has rank
2026        // p_marginal − 1 (one direction collapsed by the alias). So
2027        // V_marginal has exactly (p_marginal − 1) columns.
2028        let v_marg = &compiled.blocks[1].t_lw;
2029        assert_eq!(
2030            v_marg.ncols(),
2031            p_marginal - 1,
2032            "marginal block must lose exactly the shared-constant direction; \
2033             V_marginal cols = {}, expected {}",
2034            v_marg.ncols(),
2035            p_marginal - 1,
2036        );
2037
2038        // (4) Logslope enters third and carries no shared direction
2039        // with time or marginal in the g-channel. Both columns survive.
2040        let v_log = &compiled.blocks[2].t_lw;
2041        assert_eq!(
2042            v_log.ncols(),
2043            p_logslope,
2044            "logslope block (no shared direction) must retain all {p_logslope} columns",
2045        );
2046
2047        // (5) Joint rank consistency: sum of compiled column counts
2048        // equals raw_total minus the one aliased direction.
2049        let raw_total = p_time + p_marginal + p_logslope;
2050        let kept_total: usize = compiled.blocks.iter().map(|b| b.t_lw.ncols()).sum();
2051        assert_eq!(
2052            kept_total,
2053            raw_total - 1,
2054            "joint kept = raw_total − aliased; got {kept_total}, expected {}",
2055            raw_total - 1,
2056        );
2057        assert_eq!(
2058            compiled.joint_rank, kept_total,
2059            "CompiledBlocks::joint_rank must match the sum of per-block t_lw widths",
2060        );
2061
2062        // (6) SMGS construction contract. Feed the compiled per-block reduced
2063        // bases (V_k = t_lw, shaped raw_k × kept_k) into the SMGS lift `Gauge`
2064        // and verify the lift's coordinate bookkeeping matches the compiler's
2065        // rank attribution: the reduced dimension equals `joint_rank`, the
2066        // reduced block boundaries advance by each block's kept width, and —
2067        // with R = None (no residualised cross-block reparam in this V-only
2068        // construction) — the raw block boundaries advance by each block's raw
2069        // width. This exercises the SMGS construction hook directly on the
2070        // compiled output rather than asserting against a hypothetical shape.
2071        let v_per_term: Vec<Array2<f64>> = compiled.blocks.iter().map(|b| b.t_lw.clone()).collect();
2072        let r_per_term: Vec<Option<Array2<f64>>> = vec![None; v_per_term.len()];
2073        let gauge = Gauge::from_v_and_r(&v_per_term, &r_per_term);
2074
2075        let mut expected_reduced = vec![0usize];
2076        let mut expected_raw = vec![0usize];
2077        for b in &compiled.blocks {
2078            let prev_reduced = *expected_reduced.last().unwrap();
2079            expected_reduced.push(prev_reduced + b.t_lw.ncols());
2080            let prev_raw = *expected_raw.last().unwrap();
2081            expected_raw.push(prev_raw + b.t_lw.nrows());
2082        }
2083        assert_eq!(
2084            *gauge.block_starts_reduced.last().unwrap(),
2085            compiled.joint_rank,
2086            "SMGS lift reduced dimension must equal the compiled joint_rank",
2087        );
2088        assert_eq!(
2089            gauge.block_starts_reduced, expected_reduced,
2090            "SMGS lift reduced block boundaries must match the compiled kept widths",
2091        );
2092        assert_eq!(
2093            gauge.block_starts_raw, expected_raw,
2094            "SMGS lift raw block boundaries must match the compiled per-block raw widths",
2095        );
2096
2097        // (7) Every kept direction is finite and non-degenerate. A retained
2098        // column with a zero or non-finite norm would be a spurious rank
2099        // contribution that the count-only checks above cannot catch, so verify
2100        // each compiled block's surviving directions directly.
2101        for (bi, block) in compiled.blocks.iter().enumerate() {
2102            for j in 0..block.t_lw.ncols() {
2103                let col = block.t_lw.column(j);
2104                assert!(
2105                    col.iter().all(|v| v.is_finite()),
2106                    "block {bi} kept direction {j} has a non-finite entry",
2107                );
2108                let norm = col.dot(&col).sqrt();
2109                assert!(
2110                    norm > 1e-10,
2111                    "block {bi} kept direction {j} is degenerate (norm {norm:.3e})",
2112                );
2113            }
2114        }
2115    }
2116
2117    /// `T = I` case: per-block V = identity, R = None. The triangular
2118    /// lift must be the identity on each block.
2119    #[test]
2120    fn smgs_lift_via_t_identity_passes_through() {
2121        let v0 = Array2::<f64>::eye(3);
2122        let v1 = Array2::<f64>::eye(2);
2123        let v_per_term = vec![v0, v1];
2124        let r_per_term: Vec<Option<Array2<f64>>> = vec![None, None];
2125        let lift = Gauge::from_v_and_r(&v_per_term, &r_per_term);
2126        assert_eq!(lift.t_full.dim(), (5, 5));
2127        assert_eq!(lift.block_starts_reduced, vec![0, 3, 5]);
2128        assert_eq!(lift.block_starts_raw, vec![0, 3, 5]);
2129        for i in 0..5 {
2130            for j in 0..5 {
2131                let want = if i == j { 1.0 } else { 0.0 };
2132                assert!((lift.t_full[[i, j]] - want).abs() < 1e-14);
2133            }
2134        }
2135        let theta_0 = Array1::from(vec![1.0_f64, -2.0, 3.5]);
2136        let theta_1 = Array1::from(vec![-0.5_f64, 7.0]);
2137        let lifted = lift.lift_block_betas(&[theta_0.clone(), theta_1.clone()]);
2138        assert_eq!(lifted.len(), 2);
2139        for (a, b) in theta_0.iter().zip(lifted[0].iter()) {
2140            assert!((a - b).abs() < 1e-14);
2141        }
2142        for (a, b) in theta_1.iter().zip(lifted[1].iter()) {
2143            assert!((a - b).abs() < 1e-14);
2144        }
2145    }
2146
2147    /// Two-block toy: V_a = I_3, V_b drops the middle column, R is a
2148    /// non-trivial residualised reparam. Verify β_a_raw = θ_a − R · θ_b
2149    /// and β_b_raw = V_b · θ_b.
2150    #[test]
2151    fn smgs_lift_via_t_two_block_with_residualisation() {
2152        let v_a = Array2::<f64>::eye(3);
2153        let mut v_b = Array2::<f64>::zeros((3, 2));
2154        v_b[[0, 0]] = 1.0;
2155        v_b[[2, 1]] = 1.0;
2156        let mut r_b = Array2::<f64>::zeros((3, 2));
2157        r_b[[0, 0]] = 0.4;
2158        r_b[[0, 1]] = -0.1;
2159        r_b[[1, 0]] = 0.7;
2160        r_b[[1, 1]] = 1.3;
2161        r_b[[2, 0]] = -0.2;
2162        r_b[[2, 1]] = 0.5;
2163        let lift = Gauge::from_v_and_r(&[v_a.clone(), v_b.clone()], &[None, Some(r_b.clone())]);
2164        assert_eq!(lift.t_full.dim(), (6, 5));
2165        assert_eq!(lift.block_starts_reduced, vec![0, 3, 5]);
2166        assert_eq!(lift.block_starts_raw, vec![0, 3, 6]);
2167
2168        let theta_a = Array1::from(vec![1.0_f64, 2.0, -1.5]);
2169        let theta_b = Array1::from(vec![0.5_f64, -0.25]);
2170        let lifted = lift.lift_block_betas(&[theta_a.clone(), theta_b.clone()]);
2171        let r_theta_b = r_b.dot(&theta_b);
2172        let expected_a = &theta_a - &r_theta_b;
2173        assert_eq!(lifted[0].len(), 3);
2174        for (got, want) in lifted[0].iter().zip(expected_a.iter()) {
2175            assert!((got - want).abs() < 1e-12, "got {got}, want {want}");
2176        }
2177        assert_eq!(lifted[1].len(), 3);
2178        assert!((lifted[1][0] - theta_b[0]).abs() < 1e-12);
2179        assert!(lifted[1][1].abs() < 1e-12);
2180        assert!((lifted[1][2] - theta_b[1]).abs() < 1e-12);
2181    }
2182
2183    /// Covariance pushforward `Σ_raw = T · Σ_θ · Tᵀ` must be the exact
2184    /// inference companion of the point-estimate lift. Two invariants:
2185    ///
2186    /// 1. Identity T (V = I, R = None): the lifted covariance equals the
2187    ///    input covariance — a true no-op for a rank-clean fit.
2188    /// 2. Rank-1 consistency with the β lift: for a degenerate posterior
2189    ///    `Σ_θ = θ θᵀ`, the pushforward must equal `(T θ)(T θ)ᵀ`, i.e.
2190    ///    lifting the covariance of a point mass agrees with lifting the
2191    ///    point itself. This couples `lift_covariance` to
2192    ///    `lift_block_betas` exactly, so the mean and its
2193    ///    uncertainty can never drift into inconsistent coordinates.
2194    #[test]
2195    fn smgs_lift_covariance_identity_and_rank1_consistency() {
2196        // ── Invariant 1: identity T leaves the covariance unchanged. ──
2197        let lift_id = Gauge::from_v_and_r(
2198            &[Array2::<f64>::eye(2), Array2::<f64>::eye(2)],
2199            &[None, None],
2200        );
2201        let mut cov = Array2::<f64>::zeros((4, 4));
2202        // An arbitrary symmetric PSD-ish covariance.
2203        for i in 0..4 {
2204            for j in 0..4 {
2205                cov[[i, j]] = 1.0 / (1.0 + (i as f64 - j as f64).abs());
2206            }
2207        }
2208        let lifted_id = lift_id.lift_covariance(&cov);
2209        assert_eq!(lifted_id.dim(), (4, 4));
2210        for i in 0..4 {
2211            for j in 0..4 {
2212                assert!(
2213                    (lifted_id[[i, j]] - cov[[i, j]]).abs() < 1e-12,
2214                    "identity-T covariance lift must be a no-op at [{i},{j}]",
2215                );
2216            }
2217        }
2218
2219        // ── Invariant 2: rank-1 Σ_θ = θθᵀ pushes to (Tθ)(Tθ)ᵀ. ──
2220        // Reuse the two-block-with-residualisation geometry: V_a = I_3,
2221        // V_b drops the middle raw column, R_b non-trivial → raw width 6,
2222        // compiled width 5.
2223        let v_a = Array2::<f64>::eye(3);
2224        let mut v_b = Array2::<f64>::zeros((3, 2));
2225        v_b[[0, 0]] = 1.0;
2226        v_b[[2, 1]] = 1.0;
2227        let mut r_b = Array2::<f64>::zeros((3, 2));
2228        r_b[[0, 0]] = 0.4;
2229        r_b[[0, 1]] = -0.1;
2230        r_b[[1, 0]] = 0.7;
2231        r_b[[1, 1]] = 1.3;
2232        r_b[[2, 0]] = -0.2;
2233        r_b[[2, 1]] = 0.5;
2234        let lift = Gauge::from_v_and_r(&[v_a, v_b], &[None, Some(r_b)]);
2235
2236        let theta_a = Array1::from(vec![1.0_f64, 2.0, -1.5]);
2237        let theta_b = Array1::from(vec![0.5_f64, -0.25]);
2238        // Concatenated compiled θ (width 5).
2239        let theta_full = Array1::from(vec![
2240            theta_a[0], theta_a[1], theta_a[2], theta_b[0], theta_b[1],
2241        ]);
2242        // Σ_θ = θ θᵀ (rank-1).
2243        let mut cov_rank1 = Array2::<f64>::zeros((5, 5));
2244        for i in 0..5 {
2245            for j in 0..5 {
2246                cov_rank1[[i, j]] = theta_full[i] * theta_full[j];
2247            }
2248        }
2249        let lifted_cov = lift.lift_covariance(&cov_rank1);
2250        // Reference: (T θ)(T θ)ᵀ via the point-estimate lift.
2251        let lifted_blocks = lift.lift_block_betas(&[theta_a, theta_b]);
2252        let beta_raw = Array1::from(
2253            lifted_blocks
2254                .iter()
2255                .flat_map(|b| b.iter().copied())
2256                .collect::<Vec<f64>>(),
2257        );
2258        assert_eq!(lifted_cov.dim(), (6, 6));
2259        assert_eq!(beta_raw.len(), 6);
2260        for i in 0..6 {
2261            for j in 0..6 {
2262                let want = beta_raw[i] * beta_raw[j];
2263                assert!(
2264                    (lifted_cov[[i, j]] - want).abs() < 1e-10,
2265                    "rank-1 covariance pushforward must equal (Tθ)(Tθ)ᵀ at [{i},{j}]: got {}, want {want}",
2266                    lifted_cov[[i, j]],
2267                );
2268            }
2269        }
2270        // Symmetry sanity.
2271        for i in 0..6 {
2272            for j in 0..6 {
2273                assert!((lifted_cov[[i, j]] - lifted_cov[[j, i]]).abs() < 1e-14);
2274            }
2275        }
2276    }
2277
2278    /// When all R's are None, the triangular gauge lift must equal the
2279    /// strictly per-block `V_b · θ_b` lift.
2280    #[test]
2281    fn smgs_lift_via_t_zero_r_matches_per_block_v_lift() {
2282        let mut v_a = Array2::<f64>::zeros((3, 2));
2283        v_a[[0, 0]] = 0.6;
2284        v_a[[1, 0]] = -0.8;
2285        v_a[[1, 1]] = 0.3;
2286        v_a[[2, 1]] = 0.9;
2287        let mut v_b = Array2::<f64>::zeros((4, 3));
2288        v_b[[0, 0]] = 1.0;
2289        v_b[[1, 1]] = -0.4;
2290        v_b[[2, 0]] = 0.2;
2291        v_b[[2, 2]] = 0.7;
2292        v_b[[3, 2]] = -1.1;
2293        let v_per_term = vec![v_a.clone(), v_b.clone()];
2294        let lift = Gauge::from_v_and_r(&v_per_term, &[None, None]);
2295        let theta_a = Array1::from(vec![0.3_f64, -1.4]);
2296        let theta_b = Array1::from(vec![2.1_f64, 0.0, -0.7]);
2297        let via_t = lift.lift_block_betas(&[theta_a.clone(), theta_b.clone()]);
2298        let ref_a = v_a.dot(&theta_a);
2299        let ref_b = v_b.dot(&theta_b);
2300        assert_eq!(via_t[0].len(), ref_a.len());
2301        for (g, w) in via_t[0].iter().zip(ref_a.iter()) {
2302            assert!((g - w).abs() < 1e-12);
2303        }
2304        assert_eq!(via_t[1].len(), ref_b.len());
2305        for (g, w) in via_t[1].iter().zip(ref_b.iter()) {
2306            assert!((g - w).abs() < 1e-12);
2307        }
2308    }
2309
2310    /// Recompile-after-first-PIRLS-accept refinement: under a structural
2311    /// (identity) row Hessian, a direction that is *only* identifiable
2312    /// through the q1 channel survives the per-term compile; under a
2313    /// data-adaptive row Hessian that happens to zero out the q1/qd1/g
2314    /// metric weight (everything except q0), the same direction collapses.
2315    /// This pins the diagnostic the production hook in
2316    /// `fit_survival_marginal_slope_terms` watches for: the two row
2317    /// Hessians produce different `drops_by_block` on identical raw
2318    /// designs.
2319    #[test]
2320    fn recompile_after_accept_diff_detection_pilot_curvature_trap() {
2321        let n = 6usize;
2322        // Time block: a single column that only contributes through q0
2323        // (entry-time channel). Both row Hessians see it identically on
2324        // the q0 axis.
2325        let time_dq0 = Array2::<f64>::from_elem((n, 1), 1.0);
2326        let time_dq1 = Array2::<f64>::zeros((n, 1));
2327        let time_dqd1 = Array2::<f64>::zeros((n, 1));
2328        // Marginal block: a single column whose q0 part is colinear with
2329        // the time block's q0 (both are ones-vectors). Its q-channel maps
2330        // into BOTH q0 and q1 under QChannelBlockOperator, so under a
2331        // metric that weighs q1 it carries a non-colinear component.
2332        let marg_dq = Array2::<f64>::from_elem((n, 1), 1.0);
2333        let marg_dqd1 = Array2::<f64>::zeros((n, 1));
2334        // No logslope columns.
2335        let log_dg = Array2::<f64>::zeros((n, 0));
2336        let mut time_partition: Vec<std::ops::Range<usize>> = Vec::with_capacity(1);
2337        time_partition.push(0..1);
2338        let mut marg_partition: Vec<std::ops::Range<usize>> = Vec::with_capacity(1);
2339        marg_partition.push(0..1);
2340        let log_partition: Vec<std::ops::Range<usize>> = Vec::new();
2341
2342        // Pass 1: structural identity row Hessian. q0/q1/qd1/g all weighted
2343        // equally → marg's q1 component is visible, so marg is identifiable
2344        // after residualising against the time block (drops_marg = 0).
2345        let mut h_ident = Array3::<f64>::zeros((n, K_SURVIVAL, K_SURVIVAL));
2346        for i in 0..n {
2347            for k in 0..K_SURVIVAL {
2348                h_ident[[i, k, k]] = 1.0;
2349            }
2350        }
2351        let row_hess_ident = SurvivalRowHessian::from_full(h_ident);
2352        let compiled_ident = compile_survival_parametric_designs_per_term(
2353            time_dq0.clone(),
2354            time_dq1.clone(),
2355            time_dqd1.clone(),
2356            &time_partition,
2357            marg_dq.clone(),
2358            marg_dqd1.clone(),
2359            &marg_partition,
2360            log_dg.clone(),
2361            &log_partition,
2362            &row_hess_ident,
2363            false,
2364        )
2365        .expect("identity-H compile must succeed");
2366
2367        // Pass 2: data-adaptive row Hessian that only weighs q0 (all
2368        // other channel diagonals zero). Marg's q1 contribution is now
2369        // invisible → marg fully aliases with time on q0 → drops_marg = 1.
2370        let mut h_q0_only = Array3::<f64>::zeros((n, K_SURVIVAL, K_SURVIVAL));
2371        for i in 0..n {
2372            h_q0_only[[i, 0, 0]] = 1.0;
2373        }
2374        let row_hess_q0 = SurvivalRowHessian::from_full(h_q0_only);
2375        let compiled_q0 = compile_survival_parametric_designs_per_term(
2376            time_dq0,
2377            time_dq1,
2378            time_dqd1,
2379            &time_partition,
2380            marg_dq,
2381            marg_dqd1,
2382            &marg_partition,
2383            log_dg,
2384            &log_partition,
2385            &row_hess_q0,
2386            false,
2387        )
2388        .expect("q0-only-H compile must succeed");
2389
2390        // The two drops_by_block tuples disagree on the marginal block —
2391        // this is exactly the "pilot-curvature trap" the recompile-after-
2392        // accept hook is designed to surface.
2393        assert_ne!(
2394            compiled_ident.drops_by_block, compiled_q0.drops_by_block,
2395            "structural-H and data-adaptive-H compiles must produce different \
2396             drops_by_block on the constructed pilot-curvature-trap design; \
2397             identity={:?} q0-only={:?}",
2398            compiled_ident.drops_by_block, compiled_q0.drops_by_block,
2399        );
2400        // Under identity H, marg survives (no drop).
2401        assert_eq!(
2402            compiled_ident.drops_by_block.1, 0,
2403            "identity-H marg drops expected 0, got {:?}",
2404            compiled_ident.drops_by_block,
2405        );
2406        // Under q0-only H, marg fully aliases with time on q0.
2407        assert_eq!(
2408            compiled_q0.drops_by_block.1, 1,
2409            "q0-only-H marg drops expected 1, got {:?}",
2410            compiled_q0.drops_by_block,
2411        );
2412    }
2413
2414    #[test]
2415    fn compiled_map_from_per_term_partitions_and_lift_round_trip() {
2416        // Build a per-term compile by hand: time has one term (raw 2, kept 2),
2417        // marginal one term (raw 2, kept 1 — a drop), logslope one term
2418        // (raw 1, kept 1). No required channel is fully collapsed.
2419        let v_time = Array2::<f64>::eye(2);
2420        let mut v_marg = Array2::<f64>::zeros((2, 1));
2421        v_marg[[0, 0]] = 1.0;
2422        v_marg[[1, 0]] = 0.5;
2423        let v_log = Array2::<f64>::eye(1);
2424        // R for the marginal block (anchor = time, raw width 2) and logslope
2425        // block (anchors = time + marginal, raw width 2 + 2 = 4).
2426        let r_marg = Array2::<f64>::from_shape_fn((2, 1), |(i, _)| 0.25 + i as f64);
2427        let r_log = Array2::<f64>::from_shape_fn((4, 1), |(i, _)| 0.1 * (i as f64 + 1.0));
2428        let per_term = SurvivalParametricCompiledPerTerm {
2429            v_time_per_term: vec![v_time.clone()],
2430            v_marginal_per_term: vec![v_marg.clone()],
2431            v_logslope_per_term: vec![v_log.clone()],
2432            r_lw_per_term: vec![None, Some(r_marg.clone()), Some(r_log.clone())],
2433            drops_by_block: (0, 1, 0),
2434        };
2435
2436        let map = compiled_map_from_per_term(&per_term);
2437
2438        // Raw block ranges: time 0..2, marginal 2..4, logslope 4..5.
2439        assert_eq!(map.raw_block_ranges, vec![0..2, 2..4, 4..5]);
2440        // Compiled block ranges: time 0..2, marginal 2..3, logslope 3..4.
2441        assert_eq!(map.compiled_block_ranges, vec![0..2, 2..3, 3..4]);
2442        assert_eq!(map.raw_from_compiled.dim(), (5, 4));
2443
2444        // The block-diagonal slices recovered by apply_compiled_map_to_designs
2445        // must equal the per-term V's exactly.
2446        let v_time_slice = map
2447            .raw_from_compiled
2448            .slice(ndarray::s![0..2, 0..2])
2449            .to_owned();
2450        let v_marg_slice = map
2451            .raw_from_compiled
2452            .slice(ndarray::s![2..4, 2..3])
2453            .to_owned();
2454        let v_log_slice = map
2455            .raw_from_compiled
2456            .slice(ndarray::s![4..5, 3..4])
2457            .to_owned();
2458        for i in 0..2 {
2459            for j in 0..2 {
2460                assert!((v_time_slice[[i, j]] - v_time[[i, j]]).abs() < 1e-14);
2461            }
2462            assert!((v_marg_slice[[i, 0]] - v_marg[[i, 0]]).abs() < 1e-14);
2463        }
2464        assert!((v_log_slice[[0, 0]] - v_log[[0, 0]]).abs() < 1e-14);
2465
2466        // The cross-block carry (-R) must sit in the strict upper triangle, so
2467        // the map agrees with the lift assembled directly from V and R.
2468        let ordering = [
2469            gam_identifiability::families::compiler::BlockOrder::Time,
2470            gam_identifiability::families::compiler::BlockOrder::Marginal,
2471            gam_identifiability::families::compiler::BlockOrder::Logslope,
2472        ];
2473        let lift_from_map = Gauge::from_compiled_map(&map, &ordering);
2474        let v_all = vec![v_time, v_marg, v_log];
2475        let lift_direct = Gauge::from_v_and_r(&v_all, &[None, Some(r_marg), Some(r_log)]);
2476        assert_eq!(lift_from_map.t_full.dim(), lift_direct.t_full.dim());
2477        for i in 0..lift_from_map.t_full.nrows() {
2478            for j in 0..lift_from_map.t_full.ncols() {
2479                assert!(
2480                    (lift_from_map.t_full[[i, j]] - lift_direct.t_full[[i, j]]).abs() < 1e-14,
2481                    "T mismatch at ({i},{j}): map={} direct={}",
2482                    lift_from_map.t_full[[i, j]],
2483                    lift_direct.t_full[[i, j]],
2484                );
2485            }
2486        }
2487    }
2488
2489    // ----- #979 effective reduced-logslope confound removal -----------------
2490    //
2491    // Direct unit coverage of the two numerical routines added for #979,
2492    // mirroring the BMS reference cuts
2493    // (`bms::block_specs` `effective_reduction_*`): the scalar weight
2494    // contraction off the per-row 4×4 Hessian, and the block-diagonal map
2495    // assembly. The 900s end-to-end `survival_marginal_slope_converges_*`
2496    // guard exercises the same path but is slow and data-dependent; these pin
2497    // the distinguishing logic deterministically.
2498
2499    /// Constant per-row 4×4 PSD Hessian carrying ONLY the (q0, g) coupling,
2500    /// channel order (q0, q1, qd1, g): `H[0,0]=h00`, `H[0,3]=H[3,0]=h03`,
2501    /// `H[3,3]=h33`, all else zero. The effective scalar weights the
2502    /// contraction reads are then `w_mm=h00`, `w_mg=h03`, `w_gg=h33`. The
2503    /// 2×2 (q0,g) block `[[h00,h03],[h03,h33]]` is PSD when `h00·h33 ≥ h03²`.
2504    fn const_row_hess_q0g(n: usize, h00: f64, h03: f64, h33: f64) -> SurvivalRowHessian {
2505        let mut h = Array3::<f64>::zeros((n, K_SURVIVAL, K_SURVIVAL));
2506        for i in 0..n {
2507            h[[i, 0, 0]] = h00;
2508            h[[i, 0, 3]] = h03;
2509            h[[i, 3, 0]] = h03;
2510            h[[i, 3, 3]] = h33;
2511        }
2512        SurvivalRowHessian::from_full(h)
2513    }
2514
2515    #[test]
2516    fn survival_reduced_logslope_drops_confounded_keeps_free_979() {
2517        // p_m=1 marginal column m; p_log=2 logslope columns [l1, l2] with
2518        // l1 == m (an exact rank-1 (q0,g) confound: h00·h33 = h03²) so l1 is
2519        // fully marginal-explained, and l2 ⊥ m with 100× the energy so it is
2520        // unambiguously free. The effective Schur Gram must drop ONLY the
2521        // confounded direction: 0 < r == 1 < p_log == 2.
2522        let n = 4;
2523        let row_hess = const_row_hess_q0g(n, 2.0, 2.0, 2.0); // (q0,g) = [[2,2],[2,2]], rank-1
2524        let marg = Array2::from_shape_vec((n, 1), vec![1.0, 1.0, 1.0, 1.0]).unwrap();
2525        // l1 = m (confounded); l2 = [10,-10,10,-10] (Euclidean-orthogonal to m,
2526        // ‖l2‖² = 400 ≫ ‖m‖² = 4, so Gtt's free eigenvalue ≫ tol).
2527        let log =
2528            Array2::from_shape_vec((n, 2), vec![1.0, 10.0, 1.0, -10.0, 1.0, 10.0, 1.0, -10.0])
2529                .unwrap();
2530        let t = survival_reduced_logslope_transform_effective(marg.view(), log.view(), &row_hess)
2531            .expect("contraction must succeed")
2532            .expect("a partial confound must yield a reduced transform");
2533        assert_eq!(t.dim(), (2, 1), "exactly one logslope direction survives");
2534        // The kept eigenvector is the free column ≈ e2 (up to sign); the
2535        // confounded e1 component is dropped.
2536        assert!(
2537            t[[0, 0]].abs() < 1e-6,
2538            "confounded (e1) direction must be dropped, got {}",
2539            t[[0, 0]]
2540        );
2541        assert!(
2542            (t[[1, 0]].abs() - 1.0).abs() < 1e-6,
2543            "free (e2) direction must be kept as a unit vector, got {}",
2544            t[[1, 0]]
2545        );
2546    }
2547
2548    #[test]
2549    fn survival_reduced_logslope_fully_confounded_returns_none_979() {
2550        // A single logslope column equal to the marginal column under the exact
2551        // rank-1 (q0,g) confound: the whole effective logslope image lies in the
2552        // marginal span. The conservative ridge floors the residual eigenvalue at
2553        // energy_scale·TOL/(1+TOL) < tol, so r == 0 → Ok(None) (keep the raw
2554        // design, defer to the measured-phantom gate).
2555        let n = 4;
2556        let row_hess = const_row_hess_q0g(n, 2.0, 2.0, 2.0);
2557        let marg = Array2::from_shape_vec((n, 1), vec![1.0, 1.0, 1.0, 1.0]).unwrap();
2558        let log = marg.clone();
2559        let out =
2560            survival_reduced_logslope_transform_effective(marg.view(), log.view(), &row_hess)
2561                .expect("contraction must succeed");
2562        assert!(
2563            out.is_none(),
2564            "a fully marginal-explained logslope column reduces to nothing → keep raw"
2565        );
2566    }
2567
2568    #[test]
2569    fn survival_reduced_logslope_no_confound_returns_none_979() {
2570        // No marginal↔logslope cross weight (h03 = 0): the channels are
2571        // W-orthogonal, so every logslope direction is free (r == p_log) and
2572        // there is nothing to remove → Ok(None).
2573        let n = 4;
2574        let row_hess = const_row_hess_q0g(n, 2.0, 0.0, 2.0);
2575        let marg = Array2::from_shape_vec((n, 1), vec![1.0, 1.0, 1.0, 1.0]).unwrap();
2576        let log =
2577            Array2::from_shape_vec((n, 2), vec![1.0, 10.0, 1.0, -10.0, 1.0, 10.0, 1.0, -10.0])
2578                .unwrap();
2579        let out =
2580            survival_reduced_logslope_transform_effective(marg.view(), log.view(), &row_hess)
2581                .expect("contraction must succeed");
2582        assert!(out.is_none(), "W-orthogonal channels need no reduction → keep raw");
2583    }
2584
2585    #[test]
2586    fn survival_block_diagonal_logslope_map_is_identity_on_time_and_marginal_979() {
2587        // Time (p=2) and marginal (p=3) blocks pass through as identities; only
2588        // the logslope block (raw p_log=4) is reparameterised by t_log (4×2).
2589        let p_time = 2;
2590        let p_marg = 3;
2591        let t_log = Array2::from_shape_fn((4, 2), |(i, j)| 1.0 + (i * 2 + j) as f64);
2592        let map = survival_block_diagonal_logslope_map(p_time, p_marg, &t_log);
2593
2594        assert_eq!(map.raw_block_ranges, vec![0..2, 2..5, 5..9]);
2595        assert_eq!(map.compiled_block_ranges, vec![0..2, 2..5, 5..7]);
2596        assert_eq!(map.raw_from_compiled.dim(), (9, 7));
2597
2598        let t = &map.raw_from_compiled;
2599        // V_time = I2.
2600        for i in 0..p_time {
2601            for j in 0..p_time {
2602                let want = if i == j { 1.0 } else { 0.0 };
2603                assert!((t[[i, j]] - want).abs() < 1e-14, "V_time[{i},{j}]");
2604            }
2605        }
2606        // V_marg = I3.
2607        for i in 0..p_marg {
2608            for j in 0..p_marg {
2609                let want = if i == j { 1.0 } else { 0.0 };
2610                assert!((t[[p_time + i, p_time + j]] - want).abs() < 1e-14, "V_marg[{i},{j}]");
2611            }
2612        }
2613        // V_log = t_log.
2614        for i in 0..4 {
2615            for j in 0..2 {
2616                assert!(
2617                    (t[[p_time + p_marg + i, p_time + p_marg + j]] - t_log[[i, j]]).abs() < 1e-14,
2618                    "V_log[{i},{j}]"
2619                );
2620            }
2621        }
2622        // No cross-block bleed: the only nonzeros are the two identities and the
2623        // t_log block (every t_log entry here is nonzero).
2624        let nnz = t.iter().filter(|&&v| v != 0.0).count();
2625        assert_eq!(nnz, p_time + p_marg + t_log.iter().filter(|&&v| v != 0.0).count());
2626    }
2627}