Skip to main content

gam_identifiability/families/
compiler.rs

1//! Family-agnostic identifiability compiler.
2//!
3//! Single source of truth for cross-block W-metric residualisation across
4//! every blockwise family (BMS, SMGS, …). Row-Jacobian compiler that
5//! orthogonalises parameter blocks in the *row primary-state* metric `H_i`. Each block
6//! exposes a [`RowJacobianOperator`] that maps a coefficient perturbation
7//! `δβ ∈ R^p` to its contribution to the per-row primary state
8//! `u_i ∈ R^K`. The compiler walks the supplied ordering left-to-right,
9//! solves the weighted Gram system against the cumulative anchor, and
10//! emits a [`CompiledBlock`] per stage. A post-walk column-pivoted QR
11//! audit on the joint primary-state design deterministically drops
12//! trailing pivots from the latest block when joint rank is lost.
13
14use std::ops::Range;
15use std::sync::Arc;
16
17use ndarray::{Array1, Array2, Array3, Axis, s};
18
19use faer::Side;
20use gam_linalg::faer_ndarray::{
21    FaerEigh, default_rrqr_rank_alpha, fast_ab, fast_ata, fast_atb, fast_xt_diag_y,
22    rrqr_with_permutation,
23};
24
25/// Slack factor (multiples of machine ε) for the rank-revealing eigenvalue
26/// threshold used when pseudo-inverting a Gram matrix or selecting the
27/// positive eigenspace of a residual Gram. The retain threshold is
28/// `scale · RANK_REVEAL_EPS_SLACK · size · ε`, where `scale` is the dominant
29/// eigenvalue (and matrix size accounts for the worst-case roundoff
30/// accumulation in the `O(size)` inner products forming each Gram entry). 64×
31/// keeps numerically-zero directions out of the kept subspace while preserving
32/// every genuinely identified direction at large-scale conditioning.
33const RANK_REVEAL_EPS_SLACK: f64 = 64.0;
34
35/// Maps a coefficient perturbation `δβ ∈ R^p` for one parameter block into
36/// its contribution to the per-row primary state `u_i ∈ R^K`.
37///
38/// For affine blocks (everything in this compiler), `J_i = ∂u_i/∂β_block` is
39/// independent of `β` and equals the transposed row of the block's effective
40/// design matrix lifted into `R^K`.
41pub trait RowJacobianOperator: Send + Sync {
42    /// Dimension of the row primary state (survival: 4, Bernoulli: 1).
43    fn k(&self) -> usize;
44
45    /// Number of coefficients in this block (= width of `J_i`).
46    fn ncols(&self) -> usize;
47
48    /// Number of training rows.
49    fn nrows(&self) -> usize;
50
51    /// Apply the row Jacobian: writes `J_i · δβ ∈ R^K` for `row` into `out`.
52    fn apply_row(&self, row: usize, delta_beta: &[f64], out: &mut [f64]);
53
54    /// Materialise the full operator as an `(n_rows × ncols × K)` tensor.
55    fn evaluate_full(&self) -> Array3<f64>;
56
57    /// Build the sqrt(H)-scaled design `W = stack_i sqrt(H_i) · J_i`, flattened
58    /// channel-major to `(n_rows·K × ncols)`.
59    ///
60    /// This is the representation the identifiability *compiler*
61    /// ([`compile_with_dual_metric`]) actually consumes — it residualises and
62    /// eigendecomposes Grams of `W`, and never indexes the per-row `(n, p, K)`
63    /// tensor element-wise. Requesting the scaled design directly lets an
64    /// operator with a structured / streaming form supply it without
65    /// materialising and cloning the whole `O(n·p·K)` tensor; the default
66    /// implementation routes through [`evaluate_full`] so existing operators
67    /// remain correct unchanged. (#738: a capability is not a representation —
68    /// the compiler asks for the scaled design it needs, not the dense tensor.)
69    ///
70    /// [`evaluate_full`]: RowJacobianOperator::evaluate_full
71    /// [`compile_with_dual_metric`]: crate::families::compiler::compile_with_dual_metric
72    fn scaled_design_by_sqrt_h(&self, h_full: &Array3<f64>) -> Array2<f64> {
73        scale_block_by_sqrt_h(&self.evaluate_full(), h_full)
74    }
75
76    /// Write the channel-flattened column `col` — the `(n_rows · K)` vector
77    /// whose entry `i·K + ch` is `J[i, col, ch]` — into `out`.
78    ///
79    /// This is the representation the identifiability *audit* actually consumes
80    /// (per-column leverage statistics and pairwise overlaps), as opposed to the
81    /// dense `(n, p, K)` tensor. Requesting a column directly lets an operator
82    /// that has a structured / streaming form supply it without materialising
83    /// and cloning the whole `O(n·p·K)` tensor on every audit pass; the default
84    /// implementation routes through [`evaluate_full`] so existing operators
85    /// remain correct unchanged. (#738: a capability is not a representation —
86    /// the audit asks for the column view it needs, not the tensor.)
87    ///
88    /// [`evaluate_full`]: RowJacobianOperator::evaluate_full
89    fn channel_flattened_column(&self, col: usize, out: &mut [f64]) {
90        let k = self.k();
91        let n = self.nrows();
92        assert!(
93            col < self.ncols(),
94            "channel_flattened_column col {col} out of range {}",
95            self.ncols()
96        );
97        assert_eq!(
98            out.len(),
99            n * k,
100            "channel_flattened_column out length {} != n*k = {}*{}",
101            out.len(),
102            n,
103            k
104        );
105        let full = self.evaluate_full();
106        for i in 0..n {
107            for ch in 0..k {
108                out[i * k + ch] = full[[i, col, ch]];
109            }
110        }
111    }
112
113    /// Write channel-flattened rows for `rows` into `out`.
114    ///
115    /// `out` has shape `(rows.len() * K, ncols)`, with row
116    /// `local_row * K + channel` holding `J[row, :, channel]`. The default
117    /// implementation materialises the full tensor for legacy operators; large
118    /// construction-time adapters override this to stream row chunks.
119    fn channel_flattened_rows(&self, rows: Range<usize>, out: &mut Array2<f64>) {
120        let n = self.nrows();
121        let start = rows.start.min(n);
122        let end = rows.end.min(n);
123        let chunk = end - start;
124        let k = self.k();
125        let p = self.ncols();
126        assert_eq!(out.shape(), &[chunk * k, p]);
127        let full = self.evaluate_full();
128        for local_i in 0..chunk {
129            let row = start + local_i;
130            for ch in 0..k {
131                for col in 0..p {
132                    out[[local_i * k + ch, col]] = full[[row, col, ch]];
133                }
134            }
135        }
136    }
137}
138
139/// Per-row `K × K` PSD Hessian of `−log L_i(u_i)` evaluated at a pilot β.
140pub trait RowHessian: Send + Sync {
141    fn k(&self) -> usize;
142    fn nrows(&self) -> usize;
143    /// Fill the `K × K` block at `row` into `out` (row-major).
144    fn fill_row(&self, row: usize, out: &mut [f64]);
145    /// Materialise full `(n_rows × K × K)` tensor.
146    fn evaluate_full(&self) -> Array3<f64>;
147}
148
149/// Identity row metric: `K^S_i = I_K` for every row. Default structural
150/// metric for [`compile_with_dual_metric`]. Decoupling the
151/// "which directions are real structural columns" decision from a
152/// possibly rank-deficient pilot curvature `H` prevents the compiler from
153/// wrongly dropping columns whose curvature happens to be zero at the
154/// pilot β but which would be kept at the optimum.
155pub struct IdentityRowHessian {
156    n: usize,
157    k: usize,
158}
159
160impl IdentityRowHessian {
161    /// Construct an identity row metric with `n` rows and `K`-channel
162    /// row primary state.
163    pub fn new(n: usize, k: usize) -> Self {
164        Self { n, k }
165    }
166}
167
168impl RowHessian for IdentityRowHessian {
169    fn k(&self) -> usize {
170        self.k
171    }
172    fn nrows(&self) -> usize {
173        self.n
174    }
175    fn fill_row(&self, row: usize, out: &mut [f64]) {
176        assert!(
177            row < self.n,
178            "IdentityRowHessian::fill_row row {row} out of range {n}",
179            n = self.n
180        );
181        assert_eq!(out.len(), self.k * self.k);
182        for i in 0..self.k {
183            for j in 0..self.k {
184                out[i * self.k + j] = if i == j { 1.0 } else { 0.0 };
185            }
186        }
187    }
188    fn evaluate_full(&self) -> Array3<f64> {
189        let mut out = Array3::<f64>::zeros((self.n, self.k, self.k));
190        for i in 0..self.n {
191            for c in 0..self.k {
192                out[[i, c, c]] = 1.0;
193            }
194        }
195        out
196    }
197}
198
199/// One compiled block: reparam matrix `V` (`t_lw`) and the optional anchor
200/// correction matrix `M` that downstream blocks consume as a first-class
201/// anchor.
202pub struct CompiledBlock {
203    /// Orthogonal-complement reparam matrix `V ∈ R^{p × p'}` (right-selector).
204    pub t_lw: Array2<f64>,
205    /// Residualised anchor correction `M ∈ R^{d_raw × p'}` at the compiled
206    /// width, expressed in *raw* cumulative-anchor-column coordinates: `d_raw`
207    /// is the sum of the raw column counts of every prior block, NOT the
208    /// (possibly smaller) count of kept anchor directions. The predict-time
209    /// row contribution is `(C(x)·V − A_raw(x)·M)·β`, where `A_raw(x)` is the
210    /// raw anchor evaluation. `None` for the first block in the ordering.
211    /// Synonymous with `r_lw`.
212    pub anchor_correction: Option<Array2<f64>>,
213    /// Residualised reparam `R_b = M_b · V_b` — what the residualised row
214    /// evaluator uses to subtract the anchor portion. `None` for the first
215    /// block in the ordering (no anchor). Equal to `anchor_correction`.
216    pub r_lw: Option<Array2<f64>>,
217}
218
219/// Output of [`compile`]: one [`CompiledBlock`] per input block plus the
220/// joint pre-fit audit verdict.
221pub struct CompiledBlocks {
222    pub blocks: Vec<CompiledBlock>,
223    /// Joint rank reported by the post-walk column-pivoted QR audit.
224    pub joint_rank: usize,
225    /// Columns deterministically dropped by the audit, as
226    /// `(block_idx, local_col)`. The audit drops only from the latest block.
227    pub dropped: Vec<(usize, usize)>,
228}
229
230/// Structural relationship between one raw penalized block and the higher-priority
231/// anchor already accepted by the identifiability compiler.
232#[derive(Debug, Clone, Copy, PartialEq, Eq)]
233pub enum PenalizedDirectionAnnotationKind {
234    /// The block kept its full realized-design span; none of its penalized
235    /// directions were already represented by a higher-priority block.
236    Independent,
237    /// Some, but not all, raw directions were absorbed by the higher-priority
238    /// anchor. The kept width is the independent residual span.
239    PartiallyAbsorbedByHigherPriority,
240    /// The entire block was the same realized-design direction/span as the
241    /// higher-priority anchor and therefore contributes no independent
242    /// coefficients or smoothing parameter directions.
243    FullyAbsorbedByHigherPriority,
244}
245
246/// Per-block structural annotation emitted by [`orthogonalize_design_blocks`].
247#[derive(Debug, Clone, Copy, PartialEq, Eq)]
248pub struct PenalizedDirectionAnnotation {
249    pub block_idx: usize,
250    pub raw_width: usize,
251    pub kept_width: usize,
252    pub absorbed_width: usize,
253    pub kind: PenalizedDirectionAnnotationKind,
254}
255
256/// Errors raised by [`compile`].
257#[derive(Debug)]
258pub enum CompilerError {
259    /// Operator/Hessian/ordering dimensions are inconsistent.
260    DimensionMismatch(String),
261    /// A block degenerated to zero residual span — fully aliased by the
262    /// cumulative anchor in the row metric.
263    FullyAliased { block_idx: usize, reason: String },
264    /// A linear-algebra step failed (Gram solve, eigendecomposition, QR).
265    LinalgFailure(String),
266}
267
268impl std::fmt::Display for CompilerError {
269    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
270        match self {
271            CompilerError::DimensionMismatch(msg) => write!(f, "dimension mismatch: {msg}"),
272            CompilerError::FullyAliased { block_idx, reason } => {
273                write!(f, "block {block_idx} fully aliased: {reason}")
274            }
275            CompilerError::LinalgFailure(msg) => write!(f, "linalg failure: {msg}"),
276        }
277    }
278}
279
280impl std::error::Error for CompilerError {}
281
282/// Semantic block label. The compiler does not need to know what the block
283/// *is*, only its relative order — but downstream consumers (per-family
284/// install paths) tag the input operators with these labels so that the
285/// compiled output can be routed back to the right runtime slot.
286#[derive(Debug, Clone, Copy, PartialEq, Eq)]
287pub enum BlockOrder {
288    Time,
289    Marginal,
290    Logslope,
291    ScoreWarp,
292    LinkDev,
293}
294
295/// Compile a sequence of row-Jacobian operators against a shared row
296/// Hessian. Walks `ordering` left-to-right, residualising each block
297/// against the cumulative anchor in the `H_i`-weighted row metric, then
298/// performs a joint-design audit and emits one [`CompiledBlock`] per
299/// input (in the same order as `operators`).
300///
301/// `ordering` parallels `operators` and supplies the semantic label for
302/// each block. The compiler treats `ordering[i]` purely as metadata —
303/// the *position* `i` is the residualisation order.
304pub fn compile(
305    operators: &[Arc<dyn RowJacobianOperator>],
306    row_hess: &dyn RowHessian,
307    ordering: &[BlockOrder],
308) -> Result<CompiledBlocks, CompilerError> {
309    compile_protected(operators, row_hess, ordering, &[])
310}
311
312/// Variant of [`compile`] that keeps designated blocks at full raw width.
313///
314/// `protected[b] == true` forces block `b` to retain every raw column,
315/// suppressing both the structural and curvature eigenspace drops for that
316/// block while still using it as a full-width anchor for later blocks. See
317/// [`compile_from_raw_grams_protected`] for the motivation: a block whose
318/// effective Jacobian is a fixed nonlinear functional basis (e.g. the survival
319/// marginal-slope time-wiggle block) cannot be expressed on a linearly reduced
320/// design, so it must not be reparameterised/dropped. `protected` may be
321/// shorter than `ordering`; an empty slice reproduces [`compile`] exactly.
322pub fn compile_protected(
323    operators: &[Arc<dyn RowJacobianOperator>],
324    row_hess: &dyn RowHessian,
325    ordering: &[BlockOrder],
326    protected: &[bool],
327) -> Result<CompiledBlocks, CompilerError> {
328    // Default structural metric is the per-row identity `K^S_i = I_K`.
329    // A pilot-curvature `H` can collapse a direction (zero eigenvalue) at
330    // a bad β even though the optimum keeps that direction; routing the
331    // rank decision through the structural metric and reserving `H` for
332    // *within-kept-subspace* curvature handling prevents that mis-drop.
333    let n = row_hess.nrows();
334    let k = row_hess.k();
335    let id_struct = IdentityRowHessian::new(n, k);
336    compile_with_dual_metric_protected(operators, row_hess, &id_struct, ordering, protected)
337}
338
339/// Compile a sequence of row-Jacobian operators using *separate* metrics
340/// for structural rank decisions and curvature-aware orthogonalisation.
341///
342/// - `row_hess` is the curvature row metric `K^H_i` (a PSD-clamped Hessian
343///   of `−log L_i(u_i)` at a pilot β).
344/// - `row_structural` is the structural row metric `K^S_i` — typically an
345///   [`IdentityRowHessian`] — used only to decide which columns survive
346///   block-against-block residualisation. A direction that the curvature
347///   `K^H` happens to see as zero at a bad pilot β is *not* dropped here
348///   as long as it is structurally non-degenerate.
349///
350/// Per-block algorithm (left-to-right walk over `ordering`):
351///
352/// 1. Residualise the block in the structural metric against the
353///    cumulative structural anchor; eigendecompose the structural residual
354///    Gram and drop only structural-zero eigenvalues → kept basis `D`
355///    (raw-block selector).
356/// 2. Residualise `W^H_b · D` in the curvature metric against the
357///    cumulative curvature anchor → curvature anchor correction
358///    `M^H_inner` and residual `R^H`.
359/// 3. Eigendecompose the curvature Gram of `R^H` and drop curvature-zero
360///    directions (a *within*-structurally-kept curvature alias is a true
361///    redundancy) → rotation/selector `T_inner`.
362/// 4. Compose: `V = D · T_inner`; compiled anchor correction is
363///    `M^H_inner · T_inner` so the predict-time row contribution stays
364///    `(C(x) · V − A(x) · anchor_correction) · β`.
365///
366/// When `row_structural` and `row_hess` represent the same metric (e.g.
367/// `compile()` with an identity row Hessian on both sides), the two
368/// passes collapse to the single-metric loop.
369pub fn compile_with_dual_metric(
370    operators: &[Arc<dyn RowJacobianOperator>],
371    row_hess: &dyn RowHessian,
372    row_structural: &dyn RowHessian,
373    ordering: &[BlockOrder],
374) -> Result<CompiledBlocks, CompilerError> {
375    compile_with_dual_metric_protected(operators, row_hess, row_structural, ordering, &[])
376}
377
378/// Variant of [`compile_with_dual_metric`] that keeps designated blocks at full
379/// raw width (see [`compile_protected`] / [`compile_from_raw_grams_protected`]
380/// for the motivation). `protected[b] == true` replaces block `b`'s structural
381/// and curvature eigenspace drops with identity, so the block emerges at full
382/// raw width while still anchoring later blocks. `protected` may be shorter
383/// than `ordering`; an empty slice reproduces [`compile_with_dual_metric`].
384pub fn compile_with_dual_metric_protected(
385    operators: &[Arc<dyn RowJacobianOperator>],
386    row_hess: &dyn RowHessian,
387    row_structural: &dyn RowHessian,
388    ordering: &[BlockOrder],
389    protected: &[bool],
390) -> Result<CompiledBlocks, CompilerError> {
391    if operators.len() != ordering.len() {
392        return Err(CompilerError::DimensionMismatch(format!(
393            "operators ({}) and ordering ({}) length mismatch",
394            operators.len(),
395            ordering.len()
396        )));
397    }
398    if operators.is_empty() {
399        return Ok(CompiledBlocks {
400            blocks: Vec::new(),
401            joint_rank: 0,
402            dropped: Vec::new(),
403        });
404    }
405
406    let k = row_hess.k();
407    let n = row_hess.nrows();
408    if row_structural.k() != k {
409        return Err(CompilerError::DimensionMismatch(format!(
410            "structural row metric has K={} but curvature row Hessian has K={k}",
411            row_structural.k()
412        )));
413    }
414    if row_structural.nrows() != n {
415        return Err(CompilerError::DimensionMismatch(format!(
416            "structural row metric has nrows={} but curvature row Hessian has nrows={n}",
417            row_structural.nrows()
418        )));
419    }
420    for (idx, op) in operators.iter().enumerate() {
421        if op.k() != k {
422            return Err(CompilerError::DimensionMismatch(format!(
423                "operator {idx} has K={} but row Hessian has K={k}",
424                op.k()
425            )));
426        }
427        if op.nrows() != n {
428            return Err(CompilerError::DimensionMismatch(format!(
429                "operator {idx} has nrows={} but row Hessian has nrows={n}",
430                op.nrows()
431            )));
432        }
433    }
434
435    // Materialise once per metric. K is tiny (1 or 4) so the K×K
436    // symmetric-sqrt cost is dominated by the joint-design audit below.
437    let h_full = row_hess.evaluate_full();
438    let s_full = row_structural.evaluate_full();
439
440    // Request each block's sqrt(H)-scaled design directly through the intent
441    // accessor — the `(n·K, p)` representation the compiler actually consumes —
442    // instead of first materialising the dense `(n, p, K)` per-row tensor and
443    // scaling it. The default `scaled_design_by_sqrt_h` impl still routes
444    // through `evaluate_full()`, so operators without a structured form stay
445    // correct unchanged; a streaming operator (e.g. `BlockJacobianAsRowOp`)
446    // overrides it to scale straight out of its stored layout, dropping the
447    // `O(n·p·K)` tensor clone that `evaluate_full()` performs per block at
448    // large-scale `n`. (#738: a capability is not a representation — the compiler
449    // asks for the scaled design it needs, never the dense tensor.)
450    let scaled_h: Vec<Array2<f64>> = operators
451        .iter()
452        .map(|op| op.scaled_design_by_sqrt_h(&h_full))
453        .collect();
454    let scaled_s: Vec<Array2<f64>> = operators
455        .iter()
456        .map(|op| op.scaled_design_by_sqrt_h(&s_full))
457        .collect();
458
459    let mut compiled: Vec<CompiledBlock> = Vec::with_capacity(operators.len());
460    // Demotions that happen *inside* the per-block walk (a structurally-kept
461    // block losing all its directions to a higher-priority anchor in the
462    // structural or curvature pass) are recorded here, one entry per demoted
463    // raw column, in the same `(block_idx, local_col)` convention that
464    // `audit_and_drop_trailing_pivots` emits at the joint-audit step. Without
465    // this, a zero-width demotion vanished from `dropped`, breaking the
466    // `kept_width + dropped_count == structural_pre_audit_width` accounting.
467    let mut walk_demotions: Vec<(usize, usize)> = Vec::new();
468    let mut anchor_h: Array2<f64> = Array2::zeros((n * k, 0));
469    let mut anchor_s: Array2<f64> = Array2::zeros((n * k, 0));
470    // Cumulative *raw* (un-residualised) curvature-scaled anchor: the
471    // horizontal stack of `sqrt(H)·J_b` for every block already walked,
472    // keeping one column per raw block column. Where `anchor_h` carries the
473    // residualised, kept-direction anchor (its width shrinks whenever a block
474    // sheds an aliased column), this matrix keeps the full raw column count so
475    // the emitted `anchor_correction` can be expressed in raw-anchor-column
476    // coordinates — exactly the basis the predict-time subtraction
477    // `A_raw(x)·M` evaluates against. See the `M_raw` derivation below.
478    let mut raw_anchor_h: Array2<f64> = Array2::zeros((n * k, 0));
479
480    for idx in 0..operators.len() {
481        let w_h = &scaled_h[idx];
482        let w_s = &scaled_s[idx];
483        let p_b = w_h.ncols();
484        let block_protected = protected.get(idx).copied().unwrap_or(false);
485
486        // A zero-width block owns no raw columns, so it cannot alias against any
487        // anchor and is trivially identifiable. Emit an empty compiled block and
488        // skip the structural/curvature passes: their residual Grams are 0×0 and
489        // yield no positive eigenspace, which the `anchor_h.ncols() == 0`
490        // first-block guards below would otherwise mis-report as `FullyAliased`
491        // even though there is nothing to alias. This mirrors the empty block a
492        // fully-absorbed later block compiles to, with no demotions to record
493        // (there are no columns) and no change to the running anchors.
494        if p_b == 0 {
495            compiled.push(CompiledBlock {
496                t_lw: Array2::<f64>::zeros((0, 0)),
497                anchor_correction: Some(Array2::<f64>::zeros((raw_anchor_h.ncols(), 0))),
498                r_lw: Some(Array2::<f64>::zeros((raw_anchor_h.ncols(), 0))),
499            });
500            continue;
501        }
502
503        // Pass 1 (structural): residualise W^S_b against cumulative
504        // structural anchor; eigendecompose the structural residual Gram
505        // and keep only directions with non-zero structural mass → D
506        // (raw-block selector).
507        // Only the structural residual is consumed downstream; the
508        // structural-metric correction M^S is intentionally discarded —
509        // predict-time subtraction uses the curvature metric correction
510        // (`M^H_inner` below), not the structural one.
511        let (residual_s, _) = residualise_in_metric(&anchor_s, w_s)?;
512        let g_s = fast_atb(&residual_s, &residual_s);
513        // Scale reference for the kept-eigenspace tolerance: the *original*
514        // (pre-residualisation) structural block Gram trace. A fully-absorbed
515        // block's residual collapses to ~ε² noise; anchoring tau to that would
516        // keep the noise directions and wrongly treat the block as
517        // structurally independent. The original-block trace is invariant to
518        // absorption, so a near-zero residual is rejected as fully absorbed.
519        let g_s_bb = fast_atb(w_s, w_s);
520        let g_s_trace: f64 = (0..p_b).map(|i| g_s_bb[[i, i]].max(0.0)).sum();
521        // A protected block keeps every raw column: the structural residual
522        // eigenfilter is replaced by identity so no within-block direction is
523        // dropped. It still anchors later blocks at full raw width.
524        let d = if block_protected {
525            Array2::<f64>::eye(p_b)
526        } else {
527            keep_positive_eigenspace(&g_s, n, k, g_s_trace)?
528        };
529        if d.ncols() == 0 {
530            if anchor_h.ncols() == 0 {
531                return Err(CompilerError::FullyAliased {
532                    block_idx: idx,
533                    reason: format!(
534                        "structural residual Gram has no positive eigenspace (block of width {p_b} has zero structural span before any anchor exists)"
535                    ),
536                });
537            }
538            compiled.push(CompiledBlock {
539                t_lw: Array2::<f64>::zeros((p_b, 0)),
540                anchor_correction: Some(Array2::<f64>::zeros((raw_anchor_h.ncols(), 0))),
541                r_lw: Some(Array2::<f64>::zeros((raw_anchor_h.ncols(), 0))),
542            });
543            // The structural pass fully absorbed all `p_b` raw columns into the
544            // higher-priority anchor: record each as a drop so the per-block
545            // width accounting (kept + dropped == raw width) stays exact.
546            for c in 0..p_b {
547                walk_demotions.push((idx, c));
548            }
549            raw_anchor_h = concat_cols(&raw_anchor_h, w_h);
550            continue;
551        }
552
553        // Pass 2 (curvature): form W^H_b · D and residualise against the
554        // cumulative curvature anchor. Eigendecompose the curvature
555        // residual Gram and drop curvature-zero directions inside D →
556        // T_inner. A direction kept by the structural pass but degenerate
557        // here is genuinely curvature-redundant *within* the
558        // structurally-kept basis, so dropping it is correct.
559        let w_h_d = fast_ab(w_h, &d);
560        let (residual_h, m_h_inner_opt) = residualise_in_metric(&anchor_h, &w_h_d)?;
561        let g_h = fast_atb(&residual_h, &residual_h);
562        let p_d = d.ncols();
563        // Scale reference: the *unresidualised* curvature block Gram trace of
564        // `W^H_b · D` (the same convention the closed-form `compile_from_raw_grams`
565        // path uses with `d_t_kh_d`). Anchoring to the residual trace would
566        // collapse to ~ε² when the block is fully curvature-absorbed and keep
567        // its noise directions.
568        let g_h_dd = fast_atb(&w_h_d, &w_h_d);
569        let g_h_trace: f64 = (0..p_d).map(|i| g_h_dd[[i, i]].max(0.0)).sum();
570        // Protected block: retain every structurally-kept direction (identity
571        // curvature span) instead of dropping curvature-degenerate ones; its own
572        // penalty nullspace regularises the conditioning downstream.
573        let t_inner = if block_protected {
574            Array2::<f64>::eye(p_d)
575        } else {
576            keep_positive_eigenspace(&g_h, n, k, g_h_trace)?
577        };
578        if t_inner.ncols() == 0 {
579            if anchor_h.ncols() == 0 {
580                return Err(CompilerError::FullyAliased {
581                    block_idx: idx,
582                    reason: format!(
583                        "curvature residual Gram has no positive eigenspace within structurally-kept basis (block of width {p_b}, structural-kept {p_d}) before any anchor exists"
584                    ),
585                });
586            }
587            compiled.push(CompiledBlock {
588                t_lw: Array2::<f64>::zeros((p_b, 0)),
589                anchor_correction: Some(Array2::<f64>::zeros((raw_anchor_h.ncols(), 0))),
590                r_lw: Some(Array2::<f64>::zeros((raw_anchor_h.ncols(), 0))),
591            });
592            // The structural pass kept `p_d` directions, but the curvature pass
593            // absorbed all of them into the higher-priority anchor. Record each
594            // structurally-kept-but-curvature-demoted direction as a drop so the
595            // pre-audit structural width is fully accounted for.
596            for c in 0..p_d {
597                walk_demotions.push((idx, c));
598            }
599            raw_anchor_h = concat_cols(&raw_anchor_h, w_h);
600            continue;
601        }
602
603        // Compose V = D · T_inner (raw-block → kept).
604        let v = fast_ab(&d, &t_inner);
605
606        // `m_h_inner_opt` was residualised against `anchor_h` as it stands
607        // *here*, i.e. the cumulative kept-direction anchor of all PRIOR
608        // blocks. Snapshot that pre-append anchor and its raw counterpart
609        // before this block's residual columns are appended below; the
610        // change-of-basis for this block's correction must be expressed
611        // against the prior-block anchor that `m` is indexed against, not the
612        // post-append anchor that already carries this block's own columns.
613        let prior_anchor_h = anchor_h.clone();
614        let prior_raw_anchor_h = raw_anchor_h.clone();
615
616        // Append residual-V columns to both cumulative anchors so future
617        // blocks see the structurally-orthogonal and curvature-orthogonal
618        // residual designs of this block, never the raw scaled block.
619        let residual_h_t = fast_ab(&residual_h, &t_inner);
620        anchor_h = concat_cols(&anchor_h, &residual_h_t);
621        // The structural anchor needs the structural-residual restricted
622        // to the kept directions: residual_s · v gives (W^S_b − A^S · M^S)·V.
623        let residual_s_v = fast_ab(&residual_s, &v);
624        anchor_s = concat_cols(&anchor_s, &residual_s_v);
625
626        // Compiled anchor correction lives in the curvature metric — the
627        // predict-time row contribution is `(C(x) · V − A(x) · M)·β`, where
628        // the subtraction makes residuals H-orthogonal at training and `A(x)`
629        // is the *raw* anchor evaluation (one column per raw anchor column).
630        //
631        // `m_h_inner_opt · t_inner` (call it `M_kept`) lives in the
632        // *kept-direction* anchor coordinates of the PRIOR-block anchor
633        // `prior_anchor_h` (the value `anchor_h` held when `m` was produced at
634        // `residualise_in_metric` above, before this block's residual columns
635        // were appended). Its row count is `prior_anchor_h.ncols()`, which
636        // equals the prior-block raw anchor width only when no upstream block
637        // shed an aliased column. The predict path multiplies by the raw
638        // anchor matrix `A_raw` (one column per raw anchor column of the prior
639        // blocks), so we must re-express `M_kept` in raw-anchor-column
640        // coordinates.
641        //
642        // `prior_anchor_h` and `prior_raw_anchor_h` span the same column space
643        // in the curvature metric (the residualisation/rotation only drops
644        // directions that lie inside that span), so there is an exact `Z` with
645        // `prior_raw_anchor_h · Z = prior_anchor_h`. Then
646        //   `prior_anchor_h · M_kept = prior_raw_anchor_h · (Z · M_kept)`,
647        // and the raw-coordinate correction is `M_raw = Z · M_kept`, with row
648        // count `prior_raw_anchor_h.ncols()` = the sum of prior raw anchor
649        // block widths. `Z = (Aᵀ A)⁺ Aᵀ prior_anchor_h` (with
650        // `A = prior_raw_anchor_h`) is the metric-exact least-squares change of
651        // basis (`solve_psd_system`).
652        let m_compiled = match m_h_inner_opt.as_ref() {
653            Some(m) => {
654                let m_kept = fast_ab(m, &t_inner);
655                if m_kept.nrows() != prior_anchor_h.ncols() {
656                    return Err(CompilerError::DimensionMismatch(format!(
657                        "anchor correction must be indexed by prior-block kept anchor directions: \
658                         m_kept has {} rows but prior_anchor_h has {} columns",
659                        m_kept.nrows(),
660                        prior_anchor_h.ncols()
661                    )));
662                }
663                let g_raw = fast_atb(&prior_raw_anchor_h, &prior_raw_anchor_h);
664                let z_rhs = fast_atb(&prior_raw_anchor_h, &prior_anchor_h);
665                let z = solve_psd_system(&g_raw, &z_rhs)?;
666                Some(fast_ab(&z, &m_kept))
667            }
668            None => None,
669        };
670        compiled.push(CompiledBlock {
671            t_lw: v,
672            anchor_correction: m_compiled.clone(),
673            r_lw: m_compiled,
674        });
675
676        // Append this block's raw curvature-scaled columns to the raw anchor
677        // accumulator so the *next* block's `M_raw` is expressed against the
678        // full raw column set of all blocks walked so far.
679        raw_anchor_h = concat_cols(&raw_anchor_h, w_h);
680    }
681
682    // Joint-design audit on the curvature-scaled cumulative anchor: the
683    // identifiability question the fit cares about is curvature-rank.
684    let audit_dropped = audit_and_drop_trailing_pivots(&anchor_h, &mut compiled)?;
685    // Combine in-walk demotions (structural / curvature full absorption of a
686    // block) with the joint-audit trailing-pivot drops so `dropped` accounts
687    // for *every* column the compiler removed, not just the joint-audit ones.
688    let mut dropped = walk_demotions;
689    dropped.extend(audit_dropped);
690    let joint_rank: usize = compiled.iter().map(|b| b.t_lw.ncols()).sum();
691
692    Ok(CompiledBlocks {
693        blocks: compiled,
694        joint_rank,
695        dropped,
696    })
697}
698
699/// Build `W_b = stack_i sqrt(H_i) · J_b,i` flattened to `(n*K, ncols)` from a
700/// materialised `(n, p, K)` tensor. Thin wrapper over
701/// [`scale_jacobian_by_sqrt_h_with`] that reads the tensor element-wise.
702fn scale_block_by_sqrt_h(jb: &Array3<f64>, h_full: &Array3<f64>) -> Array2<f64> {
703    let n = jb.shape()[0];
704    let p = jb.shape()[1];
705    let k = jb.shape()[2];
706    scale_jacobian_by_sqrt_h_with(n, p, k, h_full, |i, a, c| jb[[i, a, c]])
707}
708
709/// Build `W_b = stack_i sqrt(H_i) · J_b,i` flattened to `(n*K, ncols)` without
710/// ever requiring a materialised `(n, p, K)` tensor.
711///
712/// The Jacobian entries are pulled through the `jac` closure
713/// (`jac(i, a, c) = J_b,i[a, c]`), so a structured operator that stores its
714/// Jacobian in a compact / streaming form can supply the sqrt(H)-scaled design
715/// directly — the representation the compiler actually consumes — rather than
716/// being forced to clone a dense `(n, p, K)` tensor first. (#738: a capability
717/// is not a representation — the compiler asks for the scaled `(n·K, p)` design
718/// it needs, not the dense per-row tensor.)
719///
720/// `K` is tiny (1 or 4), so the per-row symmetric sqrt is negligible relative
721/// to the overall compile.
722pub fn scale_jacobian_by_sqrt_h_with(
723    n: usize,
724    p: usize,
725    k: usize,
726    h_full: &Array3<f64>,
727    jac: impl Fn(usize, usize, usize) -> f64,
728) -> Array2<f64> {
729    assert_eq!(h_full.shape(), &[n, k, k]);
730    let mut out = Array2::<f64>::zeros((n * k, p));
731    let mut sqrt_h = Array2::<f64>::zeros((k, k));
732    let mut scratch_jrow = Array2::<f64>::zeros((p, k));
733    for i in 0..n {
734        // Symmetric square root of H_i via eigendecomposition.
735        let h_i = h_full.index_axis(Axis(0), i).to_owned();
736        sqrt_h.fill(0.0);
737        symmetric_sqrt_into(&h_i, &mut sqrt_h);
738        // scratch_jrow[a, c] = J_b,i[a, c] (transpose-friendly layout for
739        // the GEMV below: we want (p × k) · (k,) = (p,) for each column of
740        // sqrt_h, but we batch by writing out[(i*k+c), a] = (sqrt_h · J_b,iᵀ)[c, a].
741        for a in 0..p {
742            for c in 0..k {
743                scratch_jrow[[a, c]] = jac(i, a, c);
744            }
745        }
746        for c in 0..k {
747            for a in 0..p {
748                let mut acc = 0.0;
749                for cp in 0..k {
750                    acc += sqrt_h[[c, cp]] * scratch_jrow[[a, cp]];
751                }
752                out[[i * k + c, a]] = acc;
753            }
754        }
755    }
756    out
757}
758
759/// Symmetric matrix square root via eigendecomposition with negative
760/// eigenvalues clamped to zero (PSD projection guard).
761pub(crate) fn symmetric_sqrt_into(m: &Array2<f64>, out: &mut Array2<f64>) {
762    let k = m.nrows();
763    assert_eq!(m.ncols(), k);
764    assert_eq!(out.shape(), &[k, k]);
765    if k == 1 {
766        out[[0, 0]] = m[[0, 0]].max(0.0).sqrt();
767        return;
768    }
769    let (evals, evecs) = match m.eigh(Side::Lower) {
770        Ok(pair) => pair,
771        Err(_) => {
772            // Fall back to clipped diagonal — extremely defensive for the
773            // K=4 row Hessian which is already PSD-clamped by the caller.
774            out.fill(0.0);
775            for i in 0..k {
776                out[[i, i]] = m[[i, i]].max(0.0).sqrt();
777            }
778            return;
779        }
780    };
781    // out = U · diag(sqrt(max(0, λ))) · Uᵀ
782    let mut scaled = evecs.clone();
783    for j in 0..k {
784        let s = evals[j].max(0.0).sqrt();
785        for i in 0..k {
786            scaled[[i, j]] *= s;
787        }
788    }
789    out.assign(&fast_atb(&evecs.t().to_owned(), &scaled.t().to_owned()));
790    // The above fast_atb computed (Uᵀ)ᵀ · (Uᵀ·diag(s)) = U · diag(s) · Uᵀ
791    // when the inputs are owned. To be safe and avoid layout surprises,
792    // re-do the small multiplication explicitly for K ≤ 4.
793    out.fill(0.0);
794    for i in 0..k {
795        for j in 0..k {
796            let mut acc = 0.0;
797            for l in 0..k {
798                acc += evecs[[i, l]] * evals[l].max(0.0).sqrt() * evecs[[j, l]];
799            }
800            out[[i, j]] = acc;
801        }
802    }
803}
804
805/// Solve `Aᵀ A · M = Aᵀ B` and return `(B − A·M, Some(M))`. With `A`
806/// having zero columns, returns `(B, None)` — the first block needs no
807/// anchor correction.
808fn residualise_in_metric(
809    a_scaled: &Array2<f64>,
810    b_scaled: &Array2<f64>,
811) -> Result<(Array2<f64>, Option<Array2<f64>>), CompilerError> {
812    let d = a_scaled.ncols();
813    if d == 0 {
814        return Ok((b_scaled.clone(), None));
815    }
816    let g_aa = fast_atb(a_scaled, a_scaled);
817    let g_ab = fast_atb(a_scaled, b_scaled);
818    let m = solve_psd_system(&g_aa, &g_ab)?;
819    let a_m = fast_ab(a_scaled, &m);
820    let residual = b_scaled - &a_m;
821    Ok((residual, Some(m)))
822}
823
824/// Solve a PSD linear system `G · M = R` for `M`. Tries the eigen-based
825/// pseudoinverse with a relative threshold and falls back to a damped
826/// solve if the spectrum is ill-conditioned beyond what the threshold
827/// can clean.
828fn solve_psd_system(g: &Array2<f64>, r: &Array2<f64>) -> Result<Array2<f64>, CompilerError> {
829    let n = g.nrows();
830    if n == 0 {
831        return Ok(Array2::zeros((0, r.ncols())));
832    }
833    let (evals, evecs) = g
834        .eigh(Side::Lower)
835        .map_err(|err| CompilerError::LinalgFailure(format!("Gram eigh failed: {err:?}")))?;
836    let lambda_max = evals.iter().cloned().fold(0.0_f64, f64::max).max(0.0);
837    let tol = lambda_max * RANK_REVEAL_EPS_SLACK * (n.max(1) as f64) * f64::EPSILON;
838    // M = U · diag(1/λ_kept) · Uᵀ · R
839    let u_t_r = fast_atb(&evecs, r);
840    let mut scaled = u_t_r.clone();
841    for i in 0..n {
842        let lam = evals[i];
843        let inv = if lam > tol { 1.0 / lam } else { 0.0 };
844        for j in 0..scaled.ncols() {
845            scaled[[i, j]] *= inv;
846        }
847    }
848    let m = fast_ab(&evecs, &scaled);
849    Ok(m)
850}
851
852/// Eigendecompose the residual Gram `G̃` and return `V` made of the
853/// eigenvectors whose eigenvalues exceed
854/// `τ = max(λ_max(G̃), tr(G_BB)) · RANK_REVEAL_EPS_SLACK · n · K · ε`.
855fn keep_positive_eigenspace(
856    g_tilde: &Array2<f64>,
857    n: usize,
858    k: usize,
859    g_bb_trace: f64,
860) -> Result<Array2<f64>, CompilerError> {
861    let p = g_tilde.nrows();
862    if p == 0 {
863        return Ok(Array2::zeros((0, 0)));
864    }
865    let (evals, evecs) = g_tilde.eigh(Side::Lower).map_err(|err| {
866        CompilerError::LinalgFailure(format!("residual Gram eigh failed: {err:?}"))
867    })?;
868    let lambda_max = evals.iter().cloned().fold(0.0_f64, f64::max).max(0.0);
869    let scale = lambda_max.max(g_bb_trace);
870    let nk = (n.saturating_mul(k)).max(p).max(1) as f64;
871    let tau = scale * RANK_REVEAL_EPS_SLACK * nk * f64::EPSILON;
872    // Collect kept column indices.
873    let mut kept: Vec<usize> = (0..p).filter(|&i| evals[i] > tau).collect();
874    // Sort kept indices by descending eigenvalue for a stable column order.
875    kept.sort_by(|&a, &b| {
876        evals[b]
877            .partial_cmp(&evals[a])
878            .unwrap_or(std::cmp::Ordering::Equal)
879    });
880    let mut v = Array2::<f64>::zeros((p, kept.len()));
881    for (out_col, &src_col) in kept.iter().enumerate() {
882        for row in 0..p {
883            v[[row, out_col]] = evecs[[row, src_col]];
884        }
885    }
886    Ok(v)
887}
888
889/// Concatenate two matrices column-wise. Both must have the same row count.
890fn concat_cols(left: &Array2<f64>, right: &Array2<f64>) -> Array2<f64> {
891    let nrows = left.nrows().max(right.nrows());
892    let lc = left.ncols();
893    let rc = right.ncols();
894    let mut out = Array2::<f64>::zeros((nrows, lc + rc));
895    if lc > 0 {
896        out.slice_mut(s![.., ..lc]).assign(left);
897    }
898    if rc > 0 {
899        out.slice_mut(s![.., lc..]).assign(right);
900    }
901    out
902}
903
904/// Post-walk audit: column-pivoted QR on the cumulative scaled design.
905/// If rank < p_total, deterministically drop trailing pivots from the
906/// latest block's `V`. Earlier blocks are never modified.
907fn audit_and_drop_trailing_pivots(
908    w_joint: &Array2<f64>,
909    compiled: &mut [CompiledBlock],
910) -> Result<Vec<(usize, usize)>, CompilerError> {
911    let p_total: usize = compiled.iter().map(|b| b.t_lw.ncols()).sum();
912    if p_total == 0 || w_joint.nrows() == 0 {
913        return Ok(Vec::new());
914    }
915
916    // RRQR rank with the codebase's default α.
917    let rrqr = rrqr_with_permutation(w_joint, default_rrqr_rank_alpha())
918        .map_err(|err| CompilerError::LinalgFailure(format!("audit RRQR failed: {err:?}")))?;
919    let rank = rrqr.rank;
920    if rank >= p_total {
921        return Ok(Vec::new());
922    }
923
924    // Trailing pivots are the redundant columns. Attribute every demoted
925    // global column to the *latest* block by truncating its V; earlier
926    // blocks keep their full V. The demoted suffix is sorted only by
927    // pivot order, but we drop deterministically: take the count of
928    // demoted columns and truncate that many trailing columns of the
929    // latest block.
930    let drop_count = p_total - rank;
931    let latest_idx = compiled.len() - 1;
932    let latest = &mut compiled[latest_idx];
933    let kept_local = latest.t_lw.ncols().saturating_sub(drop_count);
934    let dropped_locals: Vec<(usize, usize)> = (kept_local..latest.t_lw.ncols())
935        .map(|c| (latest_idx, c))
936        .collect();
937    // Truncate ALL kept-direction-indexed matrices in lockstep so the
938    // shape contract (`anchor_correction: d_total × k_kept`, `r_lw:
939    // d_total × k_kept`, `t_lw: p_raw × k_kept`) holds after the audit
940    // drops trailing pivots. Forgetting these two left
941    // `anchor_correction.ncols() == pre_truncation_k_kept` while
942    // `t_lw.ncols() == post_truncation_k_kept`, surfaced downstream as
943    // `cross-block identifiability: anchor_correction shape D×P does
944    // not match expected d_total=D × k_kept=K`.
945    latest.t_lw = latest.t_lw.slice(s![.., ..kept_local]).to_owned();
946    if let Some(m) = latest.anchor_correction.as_ref() {
947        latest.anchor_correction = Some(m.slice(s![.., ..kept_local]).to_owned());
948    }
949    if let Some(r) = latest.r_lw.as_ref() {
950        latest.r_lw = Some(r.slice(s![.., ..kept_local]).to_owned());
951    }
952    Ok(dropped_locals)
953}
954
955/// Channel-pair decomposition of every parameter block's row Jacobian.
956///
957/// For families with `K` primary-state channels (survival: K=4), each block
958/// `b` contributes a (n × p_b) channel matrix `X_b^(c)` per channel `c` that
959/// it touches. Blocks that do not contribute to a channel store `None` in
960/// that slot. The closed-form Gram compiler consumes this view directly to
961/// build the joint Gram `K^H` without ever materialising the full
962/// `(n·K) × p_total` weighted design `W = sqrt(H) · J`.
963pub struct PrimaryChannelBlocks {
964    /// Outer index: block. Inner index: channel `c ∈ 0..K`. `None` means the
965    /// block does not contribute to that channel.
966    pub blocks: Vec<Vec<Option<Array2<f64>>>>,
967}
968
969/// Closed-form Gram builder: `K^H[a, b] = Σ_{c,d} (X_a^(c))ᵀ · diag(h_{cd}) · X_b^(d)`.
970///
971/// Inputs:
972/// - `channel_blocks`: per-block channel decomposition of the row Jacobian.
973/// - `row_hess`: `(n × K × K)` per-row PSD Hessian (typically clamped to PSD
974///   by the family upstream).
975/// - `raw_block_ranges`: `[start, end)` column ranges of each block inside
976///   the full `p_total`-wide coefficient vector. Must be contiguous and
977///   non-overlapping; their union spans `0..p_total`.
978///
979/// Returns the symmetric `(p_total × p_total)` Gram matrix.
980pub fn build_raw_grams_from_channel_blocks(
981    channel_blocks: &PrimaryChannelBlocks,
982    row_hess: &dyn RowHessian,
983    raw_block_ranges: &[std::ops::Range<usize>],
984) -> Result<Array2<f64>, CompilerError> {
985    let num_blocks = channel_blocks.blocks.len();
986    if num_blocks != raw_block_ranges.len() {
987        return Err(CompilerError::DimensionMismatch(format!(
988            "channel_blocks ({num_blocks}) and raw_block_ranges ({}) length mismatch",
989            raw_block_ranges.len()
990        )));
991    }
992    if num_blocks == 0 {
993        return Ok(Array2::<f64>::zeros((0, 0)));
994    }
995    let k = row_hess.k();
996    let n = row_hess.nrows();
997    let p_total: usize = raw_block_ranges.iter().map(|r| r.end - r.start).sum();
998    let expected_total = raw_block_ranges.last().map(|r| r.end).unwrap_or(0);
999    if expected_total != p_total {
1000        return Err(CompilerError::DimensionMismatch(format!(
1001            "raw_block_ranges must be contiguous from 0; got p_total={p_total} but last end={expected_total}"
1002        )));
1003    }
1004    // Per-block channel-slot shape sanity.
1005    for (b, slots) in channel_blocks.blocks.iter().enumerate() {
1006        if slots.len() != k {
1007            return Err(CompilerError::DimensionMismatch(format!(
1008                "block {b}: expected {k} channel slots, got {}",
1009                slots.len()
1010            )));
1011        }
1012        let p_b = raw_block_ranges[b].end - raw_block_ranges[b].start;
1013        for (c, mat) in slots.iter().enumerate() {
1014            if let Some(x) = mat.as_ref() {
1015                if x.nrows() != n {
1016                    return Err(CompilerError::DimensionMismatch(format!(
1017                        "block {b} channel {c}: nrows={} but row Hessian nrows={n}",
1018                        x.nrows()
1019                    )));
1020                }
1021                if x.ncols() != p_b {
1022                    return Err(CompilerError::DimensionMismatch(format!(
1023                        "block {b} channel {c}: ncols={} but block width={p_b}",
1024                        x.ncols()
1025                    )));
1026                }
1027            }
1028        }
1029    }
1030
1031    // Materialise H once and slice it into K·K length-n vectors h_{cd}.
1032    let h_full = row_hess.evaluate_full();
1033    if h_full.shape() != &[n, k, k] {
1034        return Err(CompilerError::DimensionMismatch(format!(
1035            "row Hessian evaluate_full shape {:?} != [n={n}, k={k}, k={k}]",
1036            h_full.shape()
1037        )));
1038    }
1039    // h_pairs[c * k + d] = length-n vector of H_i[c, d].
1040    let mut h_pairs: Vec<Array1<f64>> = Vec::with_capacity(k * k);
1041    for c in 0..k {
1042        for d in 0..k {
1043            let mut v = Array1::<f64>::zeros(n);
1044            for i in 0..n {
1045                v[i] = h_full[[i, c, d]];
1046            }
1047            h_pairs.push(v);
1048        }
1049    }
1050
1051    let mut gram = Array2::<f64>::zeros((p_total, p_total));
1052    // Accumulate upper triangle (a ≤ b) then symmetrise.
1053    for a in 0..num_blocks {
1054        let range_a = raw_block_ranges[a].clone();
1055        for b in a..num_blocks {
1056            let range_b = raw_block_ranges[b].clone();
1057            let mut block_acc =
1058                Array2::<f64>::zeros((range_a.end - range_a.start, range_b.end - range_b.start));
1059            for c in 0..k {
1060                let Some(x_a_c) = channel_blocks.blocks[a][c].as_ref() else {
1061                    continue;
1062                };
1063                for d in 0..k {
1064                    let Some(x_b_d) = channel_blocks.blocks[b][d].as_ref() else {
1065                        continue;
1066                    };
1067                    let h_cd = &h_pairs[c * k + d];
1068                    // (X_a^(c))ᵀ · diag(h_cd) · X_b^(d)  →  (p_a × p_b).
1069                    let contrib = fast_xt_diag_y(x_a_c, h_cd, x_b_d);
1070                    block_acc += &contrib;
1071                }
1072            }
1073            // Write into upper triangle (and the diagonal block itself).
1074            gram.slice_mut(s![range_a.start..range_a.end, range_b.start..range_b.end])
1075                .assign(&block_acc);
1076        }
1077    }
1078    // Symmetrise: copy upper triangle to lower. Diagonal blocks are
1079    // themselves p_a × p_a — symmetrise within them too.
1080    for i in 0..p_total {
1081        for j in 0..i {
1082            let v = gram[[j, i]];
1083            gram[[i, j]] = v;
1084        }
1085    }
1086    Ok(gram)
1087}
1088
1089/// Structural Gram `K^S`: same shape as [`build_raw_grams_from_channel_blocks`]
1090/// but with the per-row Hessian replaced by the K×K identity. Used by the
1091/// dual-metric compiler as the un-weighted reference geometry.
1092///
1093/// `K^S[a, b] = Σ_c (X_a^(c))ᵀ · X_b^(c)` (cross-channel terms vanish under
1094/// `H_i = I_K`).
1095pub fn build_raw_grams_structural(
1096    channel_blocks: &PrimaryChannelBlocks,
1097    raw_block_ranges: &[std::ops::Range<usize>],
1098) -> Array2<f64> {
1099    let num_blocks = channel_blocks.blocks.len();
1100    assert_eq!(
1101        num_blocks,
1102        raw_block_ranges.len(),
1103        "channel_blocks ({num_blocks}) and raw_block_ranges ({}) length mismatch",
1104        raw_block_ranges.len()
1105    );
1106    if num_blocks == 0 {
1107        return Array2::<f64>::zeros((0, 0));
1108    }
1109    let p_total = raw_block_ranges.last().map(|r| r.end).unwrap_or(0);
1110    let mut gram = Array2::<f64>::zeros((p_total, p_total));
1111    for a in 0..num_blocks {
1112        let range_a = raw_block_ranges[a].clone();
1113        for b in a..num_blocks {
1114            let range_b = raw_block_ranges[b].clone();
1115            let p_a = range_a.end - range_a.start;
1116            let p_b = range_b.end - range_b.start;
1117            let k_a = channel_blocks.blocks[a].len();
1118            let k_b = channel_blocks.blocks[b].len();
1119            assert_eq!(
1120                k_a, k_b,
1121                "structural Gram: block {a} has {k_a} channels but block {b} has {k_b}",
1122            );
1123            let mut block_acc = Array2::<f64>::zeros((p_a, p_b));
1124            for c in 0..k_a {
1125                let (Some(x_a_c), Some(x_b_c)) = (
1126                    channel_blocks.blocks[a][c].as_ref(),
1127                    channel_blocks.blocks[b][c].as_ref(),
1128                ) else {
1129                    continue;
1130                };
1131                let contrib = if a == b {
1132                    // Diagonal block, same channel — symmetric XᵀX.
1133                    fast_ata(x_a_c)
1134                } else {
1135                    fast_atb(x_a_c, x_b_c)
1136                };
1137                block_acc += &contrib;
1138            }
1139            gram.slice_mut(s![range_a.start..range_a.end, range_b.start..range_b.end])
1140                .assign(&block_acc);
1141        }
1142    }
1143    for i in 0..p_total {
1144        for j in 0..i {
1145            let v = gram[[j, i]];
1146            gram[[i, j]] = v;
1147        }
1148    }
1149    gram
1150}
1151
1152/// Build the primary-state curvature Gram `K^H` and structural Gram `K^S`
1153/// for a block decomposition, preferring the device (GPU) path when
1154/// available and falling back to the CPU closed-form builders otherwise.
1155///
1156/// The GPU path is only attempted for survival-family geometry
1157/// (`K = CHANNELS = 4`) — that is the case the GPU kernel
1158/// ([`crate::families::gpu::try_primary_state_gram_cuda`])
1159/// is specialised for via the packed-symmetric `n × 10` weight layout.
1160/// For any other `K` the CPU builders are used unconditionally.
1161///
1162/// Returns `(gram_h, gram_struct)` with the same shape and semantics as
1163/// [`build_raw_grams_from_channel_blocks`] + [`build_raw_grams_structural`].
1164pub fn build_primary_grams_gpu_or_cpu(
1165    channel_blocks: &PrimaryChannelBlocks,
1166    row_hess: &dyn RowHessian,
1167    raw_block_ranges: &[std::ops::Range<usize>],
1168) -> Result<(Array2<f64>, Array2<f64>), CompilerError> {
1169    let k = row_hess.k();
1170    if k == crate::families::gpu::CHANNELS {
1171        let gpu_blocks: Vec<Vec<Option<Array2<f64>>>> = channel_blocks
1172            .blocks
1173            .iter()
1174            .map(|slots| slots.iter().cloned().collect())
1175            .collect();
1176        if let Some(h_packed) = pack_row_hessian_symmetric(row_hess) {
1177            if let Some(bundle) = crate::families::gpu::try_primary_state_gram_cuda(
1178                &gpu_blocks,
1179                &h_packed,
1180                raw_block_ranges,
1181            ) {
1182                log::info!("[identifiability_compile] gram path = gpu");
1183                return Ok((bundle.gram_h, bundle.gram_struct));
1184            }
1185        }
1186    }
1187    log::info!("[identifiability_compile] gram path = cpu");
1188    let gram_h = build_raw_grams_from_channel_blocks(channel_blocks, row_hess, raw_block_ranges)?;
1189    let gram_struct = build_raw_grams_structural(channel_blocks, raw_block_ranges);
1190    Ok((gram_h, gram_struct))
1191}
1192
1193/// Pack a per-row symmetric `K = 4` Hessian into the `n × 10`
1194/// upper-triangular row-major layout consumed by the GPU kernel
1195/// (`packed_index(c, d)` for `c ≤ d`). Returns `None` when `K != 4`.
1196fn pack_row_hessian_symmetric(row_hess: &dyn RowHessian) -> Option<Array2<f64>> {
1197    use crate::families::gpu::{CHANNELS, PACKED_LEN, packed_index};
1198    if row_hess.k() != CHANNELS {
1199        return None;
1200    }
1201    let n = row_hess.nrows();
1202    let h_full = row_hess.evaluate_full();
1203    if h_full.shape() != [n, CHANNELS, CHANNELS] {
1204        return None;
1205    }
1206    let mut packed = Array2::<f64>::zeros((n, PACKED_LEN));
1207    for i in 0..n {
1208        for c in 0..CHANNELS {
1209            for d in c..CHANNELS {
1210                packed[[i, packed_index(c, d)]] = h_full[[i, c, d]];
1211            }
1212        }
1213    }
1214    Some(packed)
1215}
1216
1217/// Closed-form Gram-based compile output: a single `p_raw × p_compiled`
1218/// reparam matrix `T` mapping compiled coordinates back to raw width.
1219/// `T · θ` lifts a fitted compiled-width β back to raw width; predict-time
1220/// row contribution is `X_raw · T · θ` where `X_raw` is the full raw design.
1221///
1222/// `compiled_block_ranges[b]` gives the column range inside `T` (and inside
1223/// the compiled-width coefficient vector) attributable to raw block `b`.
1224/// `raw_block_ranges[b]` gives the corresponding raw-width column range.
1225#[derive(Debug)]
1226pub struct CompiledMap {
1227    /// `(p_raw × p_compiled)` raw-from-compiled reparam matrix.
1228    pub raw_from_compiled: Array2<f64>,
1229    /// Per-block compiled-width column ranges, parallel to
1230    /// `raw_block_ranges`. Same length as the input `ordering`.
1231    pub compiled_block_ranges: Vec<std::ops::Range<usize>>,
1232    /// Per-block raw-width column ranges (copied through from input).
1233    pub raw_block_ranges: Vec<std::ops::Range<usize>>,
1234}
1235
1236/// Neutral view of this compiled reparametrisation for the gauge layer
1237/// (#1521): `Gauge::from_compiled_map` lives DOWN in `gam-problem` and
1238/// names only the `CompiledBlockMap` trait, never the concrete
1239/// `CompiledMap` (which lives ABOVE `gam-problem`). This `impl` supplies
1240/// the inverted dependency edge.
1241impl gam_problem::gauge::CompiledBlockMap for CompiledMap {
1242    fn raw_from_compiled(&self) -> &Array2<f64> {
1243        &self.raw_from_compiled
1244    }
1245    fn raw_block_ranges(&self) -> &[std::ops::Range<usize>] {
1246        &self.raw_block_ranges
1247    }
1248    fn compiled_block_ranges(&self) -> &[std::ops::Range<usize>] {
1249        &self.compiled_block_ranges
1250    }
1251}
1252
1253/// Closed-form Gram-based identifiability compile.
1254///
1255/// Sequential algorithm operating purely on the raw-width Grams
1256/// `K^H = Σ_i J_iᵀ H_i J_i` (curvature) and `K^S = Σ_i J_iᵀ J_i`
1257/// (structural). Walks `ordering` left-to-right; for each block `b` with
1258/// raw-width selector `P_b` (columns of the identity selecting that
1259/// block) and cumulative compiled map `T = [T_0, …, T_{b-1}]`:
1260///
1261/// 1. Structural rank step (drop true gauges):
1262///    `G^S_AA = Tᵀ K^S T`, `G^S_Ab = Tᵀ K^S P_b`, `G^S_bb = P_bᵀ K^S P_b`,
1263///    `R_S = (G^S_AA)^+ G^S_Ab`, `G^S_res = G^S_bb − G^S_Abᵀ R_S`.
1264///    Eigendecompose `G^S_res`; keep positive eigvecs `Q+`. Then
1265///    `D = (P_b − T R_S) · Q+` (raw-space cols, structurally independent
1266///    of `T`).
1267/// 2. Curvature step (within-block conditioning):
1268///    `G^H_AA = Tᵀ K^H T`, `G^H_AD = Tᵀ K^H D`,
1269///    `R_H = (G^H_AA)^+ G^H_AD`, `E = D − T R_H` (raw-space).
1270///    Curvature Gram `G^H_res = Dᵀ K^H D − G^H_ADᵀ R_H`. Eigendecompose
1271///    and keep positive eigvecs `U`. Then `T_b = E · U`.
1272/// 3. Append: `T ← [T, T_b]`.
1273///
1274/// Returns [`CompilerError::FullyAliased`] only when the first block has no
1275/// usable structural/curvature span. Later fully absorbed blocks compile to a
1276/// zero-width block range, which is the reduced-coordinate representation of
1277/// the lower-priority block owning no degrees of freedom.
1278pub fn compile_from_raw_grams(
1279    gram_h: &Array2<f64>,
1280    gram_struct: &Array2<f64>,
1281    raw_block_ranges: &[std::ops::Range<usize>],
1282    ordering: &[BlockOrder],
1283) -> Result<CompiledMap, CompilerError> {
1284    compile_from_raw_grams_protected(gram_h, gram_struct, raw_block_ranges, ordering, &[])
1285}
1286
1287/// Variant of [`compile_from_raw_grams`] that keeps designated blocks at full
1288/// raw width instead of dropping their near-null structural/curvature
1289/// directions.
1290///
1291/// `protected[b] == true` forces block `b` to retain **all** of its raw
1292/// columns: the structural and curvature eigenspace filters that would drop
1293/// weak directions are replaced by identity, so `T_b` embeds the full raw
1294/// block (orthogonalised against earlier anchors) rather than a reduced
1295/// section. The block still serves as a full-width anchor for every later
1296/// (unprotected) block, so cross-block aliasing against it is removed exactly
1297/// as before — only the protected block's own within-block reparameterisation
1298/// is suppressed.
1299///
1300/// This exists for blocks whose effective Jacobian is a **fixed nonlinear
1301/// functional basis** rather than a plain linear design (e.g. the survival
1302/// marginal-slope monotone time-wiggle block). Such a block's chain-rule
1303/// Jacobian recomputes its basis at the raw coefficient width on every
1304/// evaluation and therefore cannot be expressed on a linearly recombined /
1305/// reduced design; reparameterising it silently corrupts — and can index out
1306/// of bounds in — that basis evaluation. Keeping it at raw width lets its own
1307/// penalty nullspace regularise its conditioning, which is the correct
1308/// treatment for a within-block (as opposed to cross-block) rank deficiency.
1309///
1310/// `protected` may be shorter than `ordering` (missing entries default to
1311/// `false`); an empty slice reproduces [`compile_from_raw_grams`] exactly.
1312pub fn compile_from_raw_grams_protected(
1313    gram_h: &Array2<f64>,
1314    gram_struct: &Array2<f64>,
1315    raw_block_ranges: &[std::ops::Range<usize>],
1316    ordering: &[BlockOrder],
1317    protected: &[bool],
1318) -> Result<CompiledMap, CompilerError> {
1319    if raw_block_ranges.len() != ordering.len() {
1320        return Err(CompilerError::DimensionMismatch(format!(
1321            "raw_block_ranges ({}) and ordering ({}) length mismatch",
1322            raw_block_ranges.len(),
1323            ordering.len()
1324        )));
1325    }
1326    let p_raw = raw_block_ranges.last().map(|r| r.end).unwrap_or(0);
1327    if gram_h.shape() != [p_raw, p_raw] {
1328        return Err(CompilerError::DimensionMismatch(format!(
1329            "gram_h shape {:?} != [p_raw={p_raw}, p_raw={p_raw}]",
1330            gram_h.shape()
1331        )));
1332    }
1333    if gram_struct.shape() != [p_raw, p_raw] {
1334        return Err(CompilerError::DimensionMismatch(format!(
1335            "gram_struct shape {:?} != [p_raw={p_raw}, p_raw={p_raw}]",
1336            gram_struct.shape()
1337        )));
1338    }
1339    if raw_block_ranges.is_empty() {
1340        return Ok(CompiledMap {
1341            raw_from_compiled: Array2::<f64>::zeros((0, 0)),
1342            compiled_block_ranges: Vec::new(),
1343            raw_block_ranges: Vec::new(),
1344        });
1345    }
1346    // Validate contiguous ranges from 0.
1347    let mut expected_start = 0usize;
1348    for (b, r) in raw_block_ranges.iter().enumerate() {
1349        if r.start != expected_start {
1350            return Err(CompilerError::DimensionMismatch(format!(
1351                "raw_block_ranges must be contiguous from 0; block {b} starts at {} expected {expected_start}",
1352                r.start
1353            )));
1354        }
1355        expected_start = r.end;
1356    }
1357
1358    // Cumulative raw-from-compiled map. Starts empty (zero compiled cols).
1359    let mut t_cum: Array2<f64> = Array2::<f64>::zeros((p_raw, 0));
1360    let mut compiled_block_ranges: Vec<std::ops::Range<usize>> =
1361        Vec::with_capacity(raw_block_ranges.len());
1362
1363    for (idx, range_b) in raw_block_ranges.iter().enumerate() {
1364        let p_b = range_b.end - range_b.start;
1365        let block_protected = protected.get(idx).copied().unwrap_or(false);
1366        // A zero-width block owns no raw columns. It contributes no compiled
1367        // degrees of freedom and — having no columns — cannot alias against any
1368        // anchor, so it is trivially identifiable. Emit an empty compiled range
1369        // and skip the structural/curvature analysis: a 0×0 residual Gram has no
1370        // positive eigenspace, which the first-block guard below would otherwise
1371        // mis-report as `FullyAliased` even though there is literally nothing to
1372        // alias. This mirrors the empty range a fully-absorbed later block
1373        // already compiles to (see the `q_plus.ncols() == 0` / `u_mat.ncols() == 0`
1374        // branches), keeping `kept_width + dropped_count == raw_width` exact.
1375        if p_b == 0 {
1376            let at = t_cum.ncols();
1377            compiled_block_ranges.push(at..at);
1378            continue;
1379        }
1380        // Slice gram columns/rows by raw block range. P_bᵀ K X = rows
1381        // range_b of K X. K^S T and K^H T are full-rows products.
1382        // 1) Structural rank step.
1383        // K^S · T (p_raw × p_compiled)
1384        let ks_t = fast_ab(gram_struct, &t_cum);
1385        // G^S_AA = Tᵀ K^S T (p_compiled × p_compiled)
1386        let g_s_aa = fast_atb(&t_cum, &ks_t);
1387        // G^S_Ab = Tᵀ K^S P_b = Tᵀ · K^S[:, range_b] (p_compiled × p_b)
1388        let ks_pb = gram_struct
1389            .slice(s![.., range_b.start..range_b.end])
1390            .to_owned();
1391        let g_s_ab = fast_atb(&t_cum, &ks_pb);
1392        // G^S_bb = P_bᵀ K^S P_b = K^S[range_b, range_b] (p_b × p_b)
1393        let g_s_bb = gram_struct
1394            .slice(s![range_b.start..range_b.end, range_b.start..range_b.end])
1395            .to_owned();
1396        // R_S = (G^S_AA)^+ G^S_Ab (p_compiled × p_b)
1397        let r_s = solve_psd_system(&g_s_aa, &g_s_ab)?;
1398        // G^S_res = G^S_bb − G^S_Abᵀ R_S (p_b × p_b), symmetrise.
1399        let g_s_res_raw = &g_s_bb - &fast_atb(&g_s_ab, &r_s);
1400        let g_s_res = symmetrise(&g_s_res_raw);
1401        // Trace of the unresidualised diagonal block (scale ref).
1402        let g_s_bb_trace: f64 = (0..p_b).map(|i| g_s_bb[[i, i]].max(0.0)).sum();
1403        // p_raw stands in as the "n*K" scale for the closed-form tolerance.
1404        // A protected block keeps every raw column (identity structural span);
1405        // the residual-Gram eigenfilter that would drop weak directions is
1406        // suppressed so the block emerges at full raw width.
1407        let q_plus = if block_protected {
1408            Array2::<f64>::eye(p_b)
1409        } else {
1410            keep_positive_eigenspace(&g_s_res, p_raw, 1, g_s_bb_trace)?
1411        };
1412        if q_plus.ncols() == 0 {
1413            if t_cum.ncols() == 0 {
1414                return Err(CompilerError::FullyAliased {
1415                    block_idx: idx,
1416                    reason: format!(
1417                        "structural residual Gram has no positive eigenspace (block of width {p_b} has zero structural span before any anchor exists)"
1418                    ),
1419                });
1420            }
1421            let at = t_cum.ncols();
1422            compiled_block_ranges.push(at..at);
1423            continue;
1424        }
1425        // D = (P_b − T R_S) · Q+ (p_raw × k_kept). Build (P_b − T R_S)
1426        // explicitly as a p_raw × p_b matrix: columns of P_b are columns
1427        // range_b of I_p_raw, so (P_b − T R_S) places −T R_S in all rows
1428        // and adds the identity on rows range_b.
1429        let mut diff = Array2::<f64>::zeros((p_raw, p_b));
1430        if t_cum.ncols() > 0 {
1431            // diff = −T · R_S
1432            let t_rs = fast_ab(&t_cum, &r_s);
1433            for i in 0..p_raw {
1434                for j in 0..p_b {
1435                    diff[[i, j]] = -t_rs[[i, j]];
1436                }
1437            }
1438        }
1439        for j in 0..p_b {
1440            diff[[range_b.start + j, j]] += 1.0;
1441        }
1442        let d_mat = fast_ab(&diff, &q_plus);
1443
1444        // 2) Curvature step.
1445        // K^H · T (p_raw × p_compiled), K^H · D (p_raw × k_kept)
1446        let kh_t = fast_ab(gram_h, &t_cum);
1447        let g_h_aa = fast_atb(&t_cum, &kh_t);
1448        let kh_d = fast_ab(gram_h, &d_mat);
1449        let g_h_ad = fast_atb(&t_cum, &kh_d);
1450        let r_h = solve_psd_system(&g_h_aa, &g_h_ad)?;
1451        // G^H_res = Dᵀ K^H D − G^H_ADᵀ R_H (k_kept × k_kept)
1452        let d_t_kh_d = fast_atb(&d_mat, &kh_d);
1453        let g_h_res_raw = &d_t_kh_d - &fast_atb(&g_h_ad, &r_h);
1454        let g_h_res = symmetrise(&g_h_res_raw);
1455        let k_kept = q_plus.ncols();
1456        let g_h_dd_trace: f64 = (0..k_kept).map(|i| d_t_kh_d[[i, i]].max(0.0)).sum();
1457        // A protected block also retains every structurally-kept curvature
1458        // direction (identity curvature span), so no within-block conditioning
1459        // drop occurs; its own penalty nullspace regularises the fit instead.
1460        let u_mat = if block_protected {
1461            Array2::<f64>::eye(k_kept)
1462        } else {
1463            keep_positive_eigenspace(&g_h_res, p_raw, 1, g_h_dd_trace)?
1464        };
1465        if u_mat.ncols() == 0 {
1466            if t_cum.ncols() == 0 {
1467                return Err(CompilerError::FullyAliased {
1468                    block_idx: idx,
1469                    reason: format!(
1470                        "curvature residual Gram has no positive eigenspace within structurally-kept basis (block of width {p_b}, structural-kept {k_kept}) before any anchor exists"
1471                    ),
1472                });
1473            }
1474            let at = t_cum.ncols();
1475            compiled_block_ranges.push(at..at);
1476            continue;
1477        }
1478        // E = D − T · R_H (p_raw × k_kept); T_b = E · U.
1479        let mut e_mat = d_mat.clone();
1480        if t_cum.ncols() > 0 {
1481            let t_rh = fast_ab(&t_cum, &r_h);
1482            e_mat = &e_mat - &t_rh;
1483        }
1484        let t_b = fast_ab(&e_mat, &u_mat);
1485
1486        let start = t_cum.ncols();
1487        let end = start + t_b.ncols();
1488        compiled_block_ranges.push(start..end);
1489        t_cum = concat_cols(&t_cum, &t_b);
1490    }
1491
1492    // Finite check.
1493    for v in t_cum.iter() {
1494        if !v.is_finite() {
1495            return Err(CompilerError::LinalgFailure(
1496                "compile_from_raw_grams produced non-finite entry in raw_from_compiled".to_string(),
1497            ));
1498        }
1499    }
1500
1501    Ok(CompiledMap {
1502        raw_from_compiled: t_cum,
1503        compiled_block_ranges,
1504        raw_block_ranges: raw_block_ranges.to_vec(),
1505    })
1506}
1507
1508impl CompiledMap {
1509    /// Raw coefficient width (`p_raw`).
1510    pub fn p_raw(&self) -> usize {
1511        self.raw_from_compiled.nrows()
1512    }
1513
1514    /// Compiled (reduced) coefficient width (`p_compiled`).
1515    pub fn p_compiled(&self) -> usize {
1516        self.raw_from_compiled.ncols()
1517    }
1518
1519    /// Reparameterise a raw design into compiled coordinates:
1520    /// `X_compiled = X_raw · T` (`n × p_compiled`). Because the lift is
1521    /// `β_raw = T β_compiled`, the compiled design predicts identically to the
1522    /// raw design on every compiled coefficient: `X_compiled · θ = X_raw · (T θ)`.
1523    /// Families that build directly in reduced coordinates feed this compiled
1524    /// design (and the [`reduce_penalties_with_map`] penalties) to the solver;
1525    /// the rank-deficient raw basis never reaches Newton.
1526    pub fn reduce_design(&self, raw_design: &Array2<f64>) -> Result<Array2<f64>, String> {
1527        if raw_design.ncols() != self.p_raw() {
1528            return Err(format!(
1529                "CompiledMap::reduce_design: raw_design has {} columns, expected p_raw {}",
1530                raw_design.ncols(),
1531                self.p_raw()
1532            ));
1533        }
1534        Ok(fast_ab(raw_design, &self.raw_from_compiled))
1535    }
1536
1537    /// Lift a fitted compiled-width coefficient vector back to raw width:
1538    /// `β_raw = T · β_compiled`. This is the exact inverse direction of the
1539    /// quotient reduction — the reduced coordinates are what Newton/REML
1540    /// operate in, and this map carries the final estimate (and any linear
1541    /// functional of it) back to the original parameterisation so reported
1542    /// coefficients and predictions match the raw design.
1543    pub fn lift_coefficients(&self, beta_compiled: &Array1<f64>) -> Result<Array1<f64>, String> {
1544        if beta_compiled.len() != self.p_compiled() {
1545            return Err(format!(
1546                "CompiledMap::lift_coefficients: beta_compiled len {} != p_compiled {}",
1547                beta_compiled.len(),
1548                self.p_compiled()
1549            ));
1550        }
1551        Ok(self.raw_from_compiled.dot(beta_compiled))
1552    }
1553
1554    /// The rows of `T` belonging to raw block `b` (`T[raw_block_ranges[b], :]`,
1555    /// shape `p_b_raw × p_compiled`). A raw-block penalty `S_b` acts only on
1556    /// these raw columns, so the penalty's reduced-coordinate form depends on
1557    /// `T` only through this slice.
1558    fn raw_block_rows(&self, block_idx: usize) -> Result<Array2<f64>, String> {
1559        let range = self.raw_block_ranges.get(block_idx).ok_or_else(|| {
1560            format!(
1561                "CompiledMap::raw_block_rows: block {block_idx} out of range {}",
1562                self.raw_block_ranges.len()
1563            )
1564        })?;
1565        Ok(self
1566            .raw_from_compiled
1567            .slice(s![range.start..range.end, ..])
1568            .to_owned())
1569    }
1570}
1571
1572/// Transform a per-block raw-width penalty into the compiled (reduced)
1573/// coordinate frame defined by `map`.
1574///
1575/// `raw_penalties[b]` is the penalty matrix `S_b` acting on raw block `b`
1576/// (shape `p_b_raw × p_b_raw`), or `None` for an unpenalised block. The
1577/// returned `reduced[b]` is the **full** `(p_compiled × p_compiled)` penalty
1578/// `Tᵀ Ŝ_b T`, where `Ŝ_b` embeds `S_b` into the `p_raw × p_raw` zero matrix
1579/// at block `b`'s position. Because `Ŝ_b` is zero outside block `b`'s rows and
1580/// columns, this equals `T_bᵀ S_b T_b` with `T_b = T[raw_block_ranges[b], :]`,
1581/// so the reduced penalty is computed from the block's lift rows alone — no
1582/// dense `p_raw × p_raw` embedding is materialised.
1583///
1584/// Exactness: for any compiled coefficient `θ` with raw lift `β = T θ`, the raw
1585/// penalty energy `βᵀ Ŝ_b β = (T θ)ᵀ Ŝ_b (T θ) = θᵀ (Tᵀ Ŝ_b T) θ`, so the
1586/// reduced penalty reproduces the raw penalty energy on every lifted point.
1587/// A compiled block that absorbed to zero width simply contributes a zero
1588/// column range; its raw penalty (if any) projects onto the surviving
1589/// compiled directions through `T_b`, never lost.
1590pub fn reduce_penalties_with_map(
1591    map: &CompiledMap,
1592    raw_penalties: &[Option<Array2<f64>>],
1593) -> Result<Vec<Option<Array2<f64>>>, String> {
1594    if raw_penalties.len() != map.raw_block_ranges.len() {
1595        return Err(format!(
1596            "reduce_penalties_with_map: raw_penalties ({}) != blocks ({})",
1597            raw_penalties.len(),
1598            map.raw_block_ranges.len()
1599        ));
1600    }
1601    let p_compiled = map.p_compiled();
1602    let mut reduced: Vec<Option<Array2<f64>>> = Vec::with_capacity(raw_penalties.len());
1603    for (block_idx, raw_penalty) in raw_penalties.iter().enumerate() {
1604        let Some(s_b) = raw_penalty.as_ref() else {
1605            reduced.push(None);
1606            continue;
1607        };
1608        let p_b_raw = map.raw_block_ranges[block_idx].len();
1609        if s_b.shape() != [p_b_raw, p_b_raw] {
1610            return Err(format!(
1611                "reduce_penalties_with_map: block {block_idx} penalty shape {:?} != [{p_b_raw}, {p_b_raw}]",
1612                s_b.shape()
1613            ));
1614        }
1615        // T_b = T[raw rows of block b, :]  (p_b_raw × p_compiled)
1616        let t_b = map.raw_block_rows(block_idx)?;
1617        // S_compiled = T_bᵀ S_b T_b  (p_compiled × p_compiled)
1618        let s_t_b = fast_ab(s_b, &t_b); // (p_b_raw × p_compiled)
1619        let s_compiled_raw = fast_atb(&t_b, &s_t_b); // (p_compiled × p_compiled)
1620        let mut s_compiled = symmetrise(&s_compiled_raw);
1621        if s_compiled.shape() != [p_compiled, p_compiled] {
1622            return Err(format!(
1623                "reduce_penalties_with_map: block {block_idx} reduced penalty shape {:?} != [{p_compiled}, {p_compiled}]",
1624                s_compiled.shape()
1625            ));
1626        }
1627        for v in s_compiled.iter_mut() {
1628            if !v.is_finite() {
1629                return Err(format!(
1630                    "reduce_penalties_with_map: block {block_idx} reduced penalty has non-finite entry"
1631                ));
1632            }
1633        }
1634        reduced.push(Some(s_compiled));
1635    }
1636    Ok(reduced)
1637}
1638
1639/// Per-block exact orthogonal reparameterisation of structural confounds.
1640///
1641/// `block_transforms[b]` is a dense `(p_b × r_b)` reparam `V_b` mapping raw
1642/// block-`b` coefficients to reduced coordinates: the orthogonalised block
1643/// design is `X_b · V_b`, and a fitted reduced coefficient lifts back to raw
1644/// space exactly via `β_b_raw = V_b · θ_b`. `r_b ≤ p_b`; `r_b < p_b` exactly
1645/// when block `b` carries `p_b − r_b` directions already spanned (in the
1646/// pilot W-metric) by the cumulative anchor of all higher-priority blocks —
1647/// those directions are removed (not penalised), so the joint design
1648/// `[X_0 V_0 | X_1 V_1 | …]` has the overlap excised exactly.
1649pub struct BlockOrthogonalization {
1650    /// `block_transforms[b]`: the `(p_b × r_b)` reparam `V_b` for raw block `b`,
1651    /// in the **original block order** (parallel to the `block_designs` input).
1652    pub block_transforms: Vec<Array2<f64>>,
1653    /// `(block_idx, local_raw_col_count_dropped)` for every block whose
1654    /// reduced width is strictly smaller than its raw width — i.e. the blocks
1655    /// that shed overlap directions against the anchor. Empty when no block
1656    /// overlapped (every `V_b` is then a `p_b × p_b` rotation/identity).
1657    pub dropped: Vec<(usize, usize)>,
1658    /// One structural annotation per input block, in original block order.
1659    ///
1660    /// This is the explicit "same direction vs independent direction" verdict:
1661    /// `Independent` means the block kept its full realized-design rank, while
1662    /// `PartiallyAbsorbed...` / `FullyAbsorbed...` mean the lower-priority block
1663    /// shared realized-design directions with the cumulative anchor and those
1664    /// directions were removed rather than assigned a separate penalty.
1665    pub direction_annotations: Vec<PenalizedDirectionAnnotation>,
1666}
1667
1668/// Build per-block exact W-metric orthogonalising reparameterisations.
1669///
1670/// `block_designs[b]` is the raw `(n × p_b)` design of block `b`.
1671/// `priority[b]` is the block's gauge priority — blocks are residualised in
1672/// **descending** priority order, so the highest-priority block keeps its full
1673/// column span and lower-priority blocks shed only the directions already
1674/// explained by the cumulative higher-priority anchor. `weight` is the pilot
1675/// W-metric row weight `w_i ≥ 0` (the diagonal of the working GLM/GAM Hessian
1676/// at the pilot β); pass an all-ones vector for the plain Euclidean metric.
1677///
1678/// The returned `block_transforms` are in the **original** block order. For a
1679/// block whose columns are all W-orthogonal to the anchor, `V_b` is a square
1680/// `p_b × p_b` orthonormal rotation (rank preserved, round-trip exact). For a
1681/// block with an overlap of dimension `d`, `V_b` is `p_b × (p_b − d)` and the
1682/// `d` overlap directions are removed exactly.
1683///
1684/// Exactness / round-trip: `X_b · V_b` is the reduced design and
1685/// `β_b_raw = V_b · θ_b` lifts a reduced fit back to raw coordinates. `V_b` has
1686/// orthonormal columns (eigenvectors of the residual Gram), so the lift is the
1687/// minimum-norm raw representative of the reduced fit.
1688pub fn orthogonalize_design_blocks(
1689    block_designs: &[Array2<f64>],
1690    priority: &[u32],
1691    weight: &[f64],
1692) -> Result<BlockOrthogonalization, CompilerError> {
1693    if block_designs.len() != priority.len() {
1694        return Err(CompilerError::DimensionMismatch(format!(
1695            "block_designs ({}) and priority ({}) length mismatch",
1696            block_designs.len(),
1697            priority.len()
1698        )));
1699    }
1700    if block_designs.is_empty() {
1701        return Ok(BlockOrthogonalization {
1702            block_transforms: Vec::new(),
1703            dropped: Vec::new(),
1704            direction_annotations: Vec::new(),
1705        });
1706    }
1707    let n = block_designs[0].nrows();
1708    for (b, x) in block_designs.iter().enumerate() {
1709        if x.nrows() != n {
1710            return Err(CompilerError::DimensionMismatch(format!(
1711                "block {b} design has {} rows but block 0 has {n}",
1712                x.nrows()
1713            )));
1714        }
1715    }
1716    if weight.len() != n {
1717        return Err(CompilerError::DimensionMismatch(format!(
1718            "weight length {} != n {n}",
1719            weight.len()
1720        )));
1721    }
1722    // sqrt(W) row scale (clamp tiny-negative to zero — the pilot Hessian
1723    // diagonal is PSD-clamped upstream, but guard against round-off).
1724    let mut sqrt_w = Array1::<f64>::zeros(n);
1725    for i in 0..n {
1726        let wi = weight[i].max(0.0);
1727        sqrt_w[i] = wi.sqrt();
1728    }
1729
1730    // Descending-priority visitation order over the original block indices.
1731    // Stable on ties (preserves input order) so the anchor build is
1732    // deterministic.
1733    let mut order: Vec<usize> = (0..block_designs.len()).collect();
1734    order.sort_by(|&a, &b| priority[b].cmp(&priority[a]));
1735
1736    // Cumulative weighted anchor `A = sqrt(W) · [kept block designs]`.
1737    let mut anchor: Array2<f64> = Array2::<f64>::zeros((n, 0));
1738
1739    // Output transforms indexed by ORIGINAL block index (filled out of order).
1740    let mut block_transforms: Vec<Option<Array2<f64>>> = vec![None; block_designs.len()];
1741    let mut direction_annotations: Vec<Option<PenalizedDirectionAnnotation>> =
1742        vec![None; block_designs.len()];
1743    let mut dropped: Vec<(usize, usize)> = Vec::new();
1744
1745    for &b in order.iter() {
1746        let x_b = &block_designs[b];
1747        let p_b = x_b.ncols();
1748        // Weighted block design `W_b = sqrt(W) · X_b`.
1749        let mut w_b = x_b.clone();
1750        for i in 0..n {
1751            let s = sqrt_w[i];
1752            for j in 0..p_b {
1753                w_b[[i, j]] *= s;
1754            }
1755        }
1756        // Residualise `W_b` against the cumulative anchor in the W-metric and
1757        // eigendecompose the residual Gram. Eigenvectors with positive
1758        // eigenvalues span block `b`'s W-orthogonal-to-anchor column space;
1759        // the zero-eigenvalue directions are exactly the overlap with the
1760        // anchor and are removed.
1761        let (residual, _correction) = residualise_in_metric(&anchor, &w_b)?;
1762        let g_res = symmetrise(&fast_atb(&residual, &residual));
1763        // Scale reference for `keep_positive_eigenspace` must be the
1764        // *original* (pre-residualisation) weighted block Gram trace, NOT the
1765        // residual's. When `b` is fully absorbed by a higher-priority anchor
1766        // the residual collapses to floating-point noise (~ε² of the original
1767        // O(1) data); anchoring tau to that noise floor would keep the noise
1768        // eigenvalues and misreport a fully-absorbed block as `Independent`.
1769        // The original-block trace is invariant to absorption, so a near-zero
1770        // residual is correctly rejected as fully absorbed.
1771        let g_bb = fast_atb(&w_b, &w_b);
1772        let g_bb_trace: f64 = (0..p_b).map(|i| g_bb[[i, i]].max(0.0)).sum();
1773        let v_b = keep_positive_eigenspace(&g_res, n, 1, g_bb_trace)?;
1774        let r_b = v_b.ncols();
1775        let absorbed_width = p_b - r_b;
1776        let kind = if absorbed_width == 0 {
1777            PenalizedDirectionAnnotationKind::Independent
1778        } else if r_b == 0 {
1779            PenalizedDirectionAnnotationKind::FullyAbsorbedByHigherPriority
1780        } else {
1781            PenalizedDirectionAnnotationKind::PartiallyAbsorbedByHigherPriority
1782        };
1783        direction_annotations[b] = Some(PenalizedDirectionAnnotation {
1784            block_idx: b,
1785            raw_width: p_b,
1786            kept_width: r_b,
1787            absorbed_width,
1788            kind,
1789        });
1790        if absorbed_width > 0 {
1791            dropped.push((b, absorbed_width));
1792        }
1793        // Append this block's kept, W-orthogonalised weighted columns to the
1794        // anchor so lower-priority blocks residualise against them too. The
1795        // residual (already anchor-orthogonal) projected onto the kept basis
1796        // is `residual · V_b` — these are mutually orthogonal in the W-metric
1797        // by construction of `keep_positive_eigenspace`.
1798        let kept_weighted = fast_ab(&residual, &v_b);
1799        anchor = concat_cols(&anchor, &kept_weighted);
1800        block_transforms[b] = Some(v_b);
1801    }
1802
1803    let block_transforms: Vec<Array2<f64>> = block_transforms
1804        .into_iter()
1805        .enumerate()
1806        .map(|(b, t)| {
1807            t.ok_or_else(|| {
1808                CompilerError::LinalgFailure(format!(
1809                    "orthogonalize_design_blocks: block {b} transform was never assigned"
1810                ))
1811            })
1812        })
1813        .collect::<Result<Vec<_>, _>>()?;
1814    let direction_annotations: Vec<PenalizedDirectionAnnotation> = direction_annotations
1815        .into_iter()
1816        .enumerate()
1817        .map(|(b, annotation)| {
1818            annotation.ok_or_else(|| {
1819                CompilerError::LinalgFailure(format!(
1820                    "orthogonalize_design_blocks: block {b} direction annotation was never assigned"
1821                ))
1822            })
1823        })
1824        .collect::<Result<Vec<_>, _>>()?;
1825
1826    // Finite check on every transform.
1827    for (b, v) in block_transforms.iter().enumerate() {
1828        for value in v.iter() {
1829            if !value.is_finite() {
1830                return Err(CompilerError::LinalgFailure(format!(
1831                    "orthogonalize_design_blocks: block {b} transform has a non-finite entry"
1832                )));
1833            }
1834        }
1835    }
1836
1837    Ok(BlockOrthogonalization {
1838        block_transforms,
1839        dropped,
1840        direction_annotations,
1841    })
1842}
1843
1844/// Symmetrise a (nearly-symmetric) matrix by averaging with its transpose.
1845fn symmetrise(m: &Array2<f64>) -> Array2<f64> {
1846    let (r, c) = m.dim();
1847    assert_eq!(r, c, "symmetrise expects square matrix");
1848    let mut out = Array2::<f64>::zeros((r, c));
1849    for i in 0..r {
1850        for j in 0..c {
1851            out[[i, j]] = 0.5 * (m[[i, j]] + m[[j, i]]);
1852        }
1853    }
1854    out
1855}
1856
1857#[cfg(test)]
1858mod tests {
1859    use super::*;
1860    use ndarray::{Array1, Array2};
1861
1862    /// Convenience: wrap a dense `(n × p)` block design as a `K=1`
1863    /// row-Jacobian operator. Used by tests; production families ship their
1864    /// own concrete operators.
1865    struct DenseScalarOperator {
1866        design: Array2<f64>,
1867    }
1868
1869    impl DenseScalarOperator {
1870        fn new(design: Array2<f64>) -> Self {
1871            Self { design }
1872        }
1873    }
1874
1875    impl RowJacobianOperator for DenseScalarOperator {
1876        fn k(&self) -> usize {
1877            1
1878        }
1879        fn ncols(&self) -> usize {
1880            self.design.ncols()
1881        }
1882        fn nrows(&self) -> usize {
1883            self.design.nrows()
1884        }
1885        fn apply_row(&self, row: usize, delta_beta: &[f64], out: &mut [f64]) {
1886            assert_eq!(out.len(), 1);
1887            let mut acc = 0.0;
1888            for (j, &b) in delta_beta.iter().enumerate() {
1889                acc += self.design[[row, j]] * b;
1890            }
1891            out[0] = acc;
1892        }
1893        fn evaluate_full(&self) -> Array3<f64> {
1894            let n = self.design.nrows();
1895            let p = self.design.ncols();
1896            let mut out = Array3::<f64>::zeros((n, p, 1));
1897            for i in 0..n {
1898                for j in 0..p {
1899                    out[[i, j, 0]] = self.design[[i, j]];
1900                }
1901            }
1902            out
1903        }
1904    }
1905
1906    // `IdentityRowHessian` is re-exported from the parent module's `use
1907    // super::*;` above (now a public struct so the dual-metric API can
1908    // share the default structural metric with callers).
1909
1910    /// Diagonal row Hessian with per-row scalar weights (K=1 case).
1911    struct DiagonalScalarRowHessian {
1912        w: Array1<f64>,
1913    }
1914
1915    impl DiagonalScalarRowHessian {
1916        fn new(w: Array1<f64>) -> Self {
1917            Self { w }
1918        }
1919    }
1920
1921    impl RowHessian for DiagonalScalarRowHessian {
1922        fn k(&self) -> usize {
1923            1
1924        }
1925        fn nrows(&self) -> usize {
1926            self.w.len()
1927        }
1928        fn fill_row(&self, row: usize, out: &mut [f64]) {
1929            assert_eq!(out.len(), 1);
1930            out[0] = self.w[row];
1931        }
1932        fn evaluate_full(&self) -> Array3<f64> {
1933            let n = self.w.len();
1934            let mut out = Array3::<f64>::zeros((n, 1, 1));
1935            for i in 0..n {
1936                out[[i, 0, 0]] = self.w[i];
1937            }
1938            out
1939        }
1940    }
1941
1942    fn op(design: Array2<f64>) -> Arc<dyn RowJacobianOperator> {
1943        Arc::new(DenseScalarOperator::new(design))
1944    }
1945
1946    /// §10 test #1: two affine blocks, identity row Hessian. The compiled
1947    /// second-block design must be orthogonal to the first block under the
1948    /// (identity) row metric to machine epsilon.
1949    #[test]
1950    fn compile_two_block_orthogonalises_under_metric() {
1951        let n = 50;
1952        let a = Array2::from_shape_fn((n, 3), |(i, j)| ((i + 1) as f64).sin().powi((j + 1) as i32));
1953        // B partly aliases A's first column.
1954        let b = Array2::from_shape_fn((n, 2), |(i, j)| {
1955            0.5 * a[[i, 0]] + ((i as f64) * 0.13 + j as f64).cos()
1956        });
1957        let hess = IdentityRowHessian::new(n, 1);
1958        let ops = vec![op(a.clone()), op(b.clone())];
1959        let compiled = compile(&ops, &hess, &[BlockOrder::Marginal, BlockOrder::Logslope])
1960            .expect("compile should succeed");
1961        // Build A's design (no rotation) and B's compiled design B·V − A·M.
1962        let v_b = &compiled.blocks[1].t_lw;
1963        let m_b = compiled.blocks[1]
1964            .anchor_correction
1965            .as_ref()
1966            .expect("second block must carry an anchor correction");
1967        let b_v = b.dot(v_b);
1968        let a_m = a.dot(m_b);
1969        let b_compiled = &b_v - &a_m;
1970        // <A, B_compiled>_I = Aᵀ · B_compiled should be ≈ 0.
1971        let cross = a.t().dot(&b_compiled);
1972        let max_err = cross.iter().fold(0.0_f64, |acc, &v| acc.max(v.abs()));
1973        assert!(
1974            max_err < 1e-10,
1975            "orthogonality residual too large: {max_err:e}"
1976        );
1977    }
1978
1979    /// §10 test #2: three-block chain with sequential aliases.
1980    #[test]
1981    fn compile_three_block_chain() {
1982        let n = 80;
1983        let a = Array2::from_shape_fn((n, 2), |(i, j)| (i as f64 * 0.1 + j as f64).sin());
1984        let b = Array2::from_shape_fn((n, 2), |(i, j)| {
1985            0.3 * a[[i, 0]] + (j as f64) * (i as f64).cos()
1986        });
1987        let c = Array2::from_shape_fn((n, 2), |(i, j)| {
1988            0.2 * a[[i, 1]] + 0.4 * b[[i, 0]] + ((i + j) as f64).tan().min(5.0).max(-5.0)
1989        });
1990        let hess = IdentityRowHessian::new(n, 1);
1991        let ops = vec![op(a), op(b), op(c)];
1992        let compiled = compile(
1993            &ops,
1994            &hess,
1995            &[
1996                BlockOrder::Marginal,
1997                BlockOrder::Logslope,
1998                BlockOrder::LinkDev,
1999            ],
2000        )
2001        .expect("compile should succeed");
2002        let total: usize = compiled.blocks.iter().map(|b| b.t_lw.ncols()).sum();
2003        assert_eq!(
2004            compiled.joint_rank, total,
2005            "audit must report full rank on synthetic full-rank design"
2006        );
2007    }
2008
2009    /// `compile_protected` keeps a rank-deficient protected first block at full
2010    /// raw width (identity V) while the unprotected path drops its null
2011    /// direction, and later blocks still orthogonalise against the full anchor.
2012    /// Mirrors the `compile_from_raw_grams_protected` guard for the operator
2013    /// (per-term) reduction path used by the survival time-wiggle time block.
2014    #[test]
2015    fn compile_protected_keeps_rank_deficient_first_block_full_width() {
2016        let n = 40;
2017        // Block A: two identical columns → structural rank 1 (one within-block
2018        // null the unprotected filter drops).
2019        let a = Array2::from_shape_fn((n, 2), |(i, _)| ((i + 1) as f64 * 0.31).sin());
2020        let b = Array2::from_shape_fn((n, 2), |(i, j)| ((i as f64) * 0.17 + j as f64).cos());
2021        let hess = IdentityRowHessian::new(n, 1);
2022        let ordering = [BlockOrder::Time, BlockOrder::Marginal];
2023
2024        let unprotected = compile(&[op(a.clone()), op(b.clone())], &hess, &ordering)
2025            .expect("unprotected compile");
2026        assert_eq!(
2027            unprotected.blocks[0].t_lw.ncols(),
2028            1,
2029            "unprotected first block drops its duplicate column"
2030        );
2031
2032        let protected = compile_protected(
2033            &[op(a.clone()), op(b.clone())],
2034            &hess,
2035            &ordering,
2036            &[true, false],
2037        )
2038        .expect("protected compile");
2039        let v_a = &protected.blocks[0].t_lw;
2040        assert_eq!(
2041            v_a.ncols(),
2042            2,
2043            "protected first block retains its full raw width"
2044        );
2045        // V_a is the 2×2 identity: raw coords == compiled coords for the
2046        // protected first block.
2047        for i in 0..2 {
2048            for j in 0..2 {
2049                let expect = if i == j { 1.0 } else { 0.0 };
2050                assert!(
2051                    (v_a[[i, j]] - expect).abs() <= 1e-12,
2052                    "protected first block V must be identity, got [{i},{j}]={}",
2053                    v_a[[i, j]]
2054                );
2055            }
2056        }
2057    }
2058
2059    /// §10 test #3: non-identity row Hessian. With K=1 and weights `w`,
2060    /// the projection of a 1-col block `b` onto a 1-col block `a` is
2061    /// `Σ w·a·b / Σ w·a²`. Verify the Gram solve recovers this scalar.
2062    #[test]
2063    fn compile_weighted_metric_nontrivial() {
2064        let n = 32;
2065        let a: Array2<f64> = Array2::from_shape_fn((n, 1), |(i, _)| (i as f64 + 1.0).sqrt());
2066        let b: Array2<f64> =
2067            Array2::from_shape_fn((n, 1), |(i, _)| 0.7 * a[[i, 0]] + (i as f64 * 0.05).cos());
2068        let w = Array1::from_shape_fn(n, |i| 0.5 + (i as f64 * 0.2).sin().abs());
2069        let hess = DiagonalScalarRowHessian::new(w.clone());
2070        let ops = vec![op(a.clone()), op(b.clone())];
2071        let compiled = compile(&ops, &hess, &[BlockOrder::Marginal, BlockOrder::Logslope])
2072            .expect("compile should succeed");
2073        let m = compiled.blocks[1]
2074            .anchor_correction
2075            .as_ref()
2076            .expect("anchor correction present");
2077        let analytic_num: f64 = (0..n).map(|i| w[i] * a[[i, 0]] * b[[i, 0]]).sum();
2078        let analytic_den: f64 = (0..n).map(|i| w[i] * a[[i, 0]] * a[[i, 0]]).sum();
2079        let analytic = analytic_num / analytic_den;
2080        assert!(m.dim() == (1, 1));
2081        assert!(
2082            (m[[0, 0]] - analytic).abs() < 1e-10,
2083            "weighted projection mismatch: got {got}, analytic {analytic}",
2084            got = m[[0, 0]]
2085        );
2086    }
2087
2088    /// Regression for #372: an anchor block that internally sheds an aliased
2089    /// column makes the residualised kept-anchor width (`anchor_h.ncols()`)
2090    /// strictly smaller than the raw anchor width (`d_total`). The emitted
2091    /// `anchor_correction` must be expressed in *raw* anchor-column
2092    /// coordinates so the predict-time / install-time subtraction
2093    /// `A_raw(x)·M` is dimensionally and metrically correct. Previously the
2094    /// correction was indexed by kept directions, producing a (d_total−1)×k
2095    /// matrix and the failure
2096    /// `anchor_correction shape 36x6 does not match d_total=37`.
2097    #[test]
2098    fn compile_emits_anchor_correction_in_raw_column_coordinates() {
2099        let n = 64;
2100        // Anchor block A has 3 raw columns but only rank 2: col 2 is an exact
2101        // linear combination of cols 0 and 1, so the compiler keeps just two
2102        // anchor directions (kept width 2 < raw width 3).
2103        let a: Array2<f64> = Array2::from_shape_fn((n, 3), |(i, j)| {
2104            let c0 = (i as f64 * 0.07 + 1.0).ln();
2105            let c1 = (i as f64 * 0.13).sin();
2106            match j {
2107                0 => c0,
2108                1 => c1,
2109                _ => 2.0 * c0 - 0.5 * c1,
2110            }
2111        });
2112        // Candidate block C: partly aliases A's span plus genuine signal.
2113        let c: Array2<f64> = Array2::from_shape_fn((n, 2), |(i, j)| {
2114            0.4 * a[[i, 0]] + (j as f64) * (i as f64 * 0.05).cos() + (i as f64 * 0.011).tanh()
2115        });
2116        let w = Array1::from_shape_fn(n, |i| 0.3 + (i as f64 * 0.17).sin().abs());
2117        let hess = DiagonalScalarRowHessian::new(w.clone());
2118        let ops = vec![op(a.clone()), op(c.clone())];
2119        let compiled = compile(&ops, &hess, &[BlockOrder::Marginal, BlockOrder::LinkDev])
2120            .expect("compile should succeed");
2121
2122        let v = &compiled.blocks[1].t_lw;
2123        let m = compiled.blocks[1]
2124            .anchor_correction
2125            .as_ref()
2126            .expect("candidate block must carry an anchor correction");
2127        let k_kept = v.ncols();
2128        assert!(k_kept >= 1, "candidate must keep at least one direction");
2129
2130        // The off-by-one the issue tripped on: M must have one row per *raw*
2131        // anchor column (3), not per kept anchor direction (2).
2132        assert_eq!(
2133            m.nrows(),
2134            a.ncols(),
2135            "anchor_correction must be indexed by raw anchor columns (d_total), \
2136             got {} rows for {} raw anchor columns",
2137            m.nrows(),
2138            a.ncols(),
2139        );
2140        assert_eq!(m.ncols(), k_kept, "anchor_correction width must match V");
2141
2142        // Metric correctness: the raw-coordinate subtraction A_raw·M must make
2143        // the compiled candidate design W-orthogonal to the full raw anchor
2144        // span. C̃ = C·V − A·M; require Aᵀ W C̃ ≈ 0 column-wise.
2145        let c_v = c.dot(v);
2146        let a_m = a.dot(m);
2147        let c_tilde = &c_v - &a_m;
2148        let mut max_cross = 0.0_f64;
2149        for ac in 0..a.ncols() {
2150            for cc in 0..c_tilde.ncols() {
2151                let mut acc = 0.0;
2152                for i in 0..n {
2153                    acc += w[i] * a[[i, ac]] * c_tilde[[i, cc]];
2154                }
2155                max_cross = max_cross.max(acc.abs());
2156            }
2157        }
2158        assert!(
2159            max_cross < 1e-9,
2160            "raw-coordinate anchor correction must W-orthogonalise the candidate \
2161             against the raw anchor span; max |Aᵀ W C̃| = {max_cross:e}"
2162        );
2163    }
2164
2165    /// §10 test #4: deliberately rank-deficient joint design. The trailing
2166    /// pivot drop must come from the *latest* block in the ordering.
2167    #[test]
2168    fn compile_drops_trailing_pivots_from_latest_block() {
2169        let n = 40;
2170        let a = Array2::from_shape_fn((n, 2), |(i, j)| (i as f64 + 1.0).ln() * (j as f64 + 1.0));
2171        // c is exactly a's first column → after residualising c against a,
2172        // the residual span is zero in that direction, but a non-zero
2173        // independent column also exists. Add an extra exact-alias column
2174        // to force trailing-pivot drop at the audit stage.
2175        let c = Array2::from_shape_fn((n, 2), |(i, j)| {
2176            if j == 0 {
2177                a[[i, 0]]
2178            } else {
2179                (i as f64 * 0.1).cos()
2180            }
2181        });
2182        let hess = IdentityRowHessian::new(n, 1);
2183        let ops = vec![op(a), op(c)];
2184        // Manually inject a known alias: pass a second block whose
2185        // residualised columns will themselves be linearly dependent on
2186        // the first block after metric projection — already covered by the
2187        // eigenvalue threshold inside `compile`. Verify either drop path
2188        // (eigen-threshold or audit) attributes loss to block index 1.
2189        let compiled = compile(&ops, &hess, &[BlockOrder::Marginal, BlockOrder::Logslope])
2190            .expect("compile should succeed");
2191        // Either the eigen-threshold dropped a column from block 1, or
2192        // the audit did. In both cases block 1's V must have fewer than
2193        // its 2 input columns.
2194        let v1_cols = compiled.blocks[1].t_lw.ncols();
2195        assert!(
2196            v1_cols < 2 || !compiled.dropped.is_empty(),
2197            "expected rank loss attributed to block 1, got v1_cols={v1_cols}, dropped={dropped:?}",
2198            dropped = compiled.dropped
2199        );
2200        for (block_idx, _) in &compiled.dropped {
2201            assert_eq!(
2202                *block_idx, 1,
2203                "audit drops must come from the latest block only"
2204            );
2205        }
2206    }
2207
2208    /// Regression: when `audit_and_drop_trailing_pivots` truncates the
2209    /// latest block's `t_lw`, the sibling `anchor_correction` and `r_lw`
2210    /// matrices must be truncated to the same `k_kept` so the trailing-
2211    /// block install path sees a coherent
2212    /// `t_lw.ncols() == anchor_correction.ncols() == r_lw.ncols()` shape.
2213    ///
2214    /// Pre-fix bug: only `t_lw` got truncated. Downstream callers
2215    /// asserting `anchor_correction.ncols() == k_kept` then failed with
2216    /// `cross-block identifiability: anchor_correction shape D×P does
2217    /// not match expected d_total=D × k_kept=K` — surfaced via the
2218    /// large-scale V+M repro test.
2219    #[test]
2220    fn audit_truncation_keeps_t_lw_and_anchor_correction_in_lockstep() {
2221        let n = 40;
2222        let a = Array2::from_shape_fn((n, 2), |(i, j)| (i as f64 + 1.0).ln() * (j as f64 + 1.0));
2223        let c = Array2::from_shape_fn((n, 2), |(i, j)| {
2224            if j == 0 {
2225                a[[i, 0]]
2226            } else {
2227                (i as f64 * 0.1).cos()
2228            }
2229        });
2230        let hess = IdentityRowHessian::new(n, 1);
2231        let ops = vec![op(a), op(c)];
2232        let compiled = compile(&ops, &hess, &[BlockOrder::Marginal, BlockOrder::Logslope])
2233            .expect("compile should succeed");
2234        for (idx, block) in compiled.blocks.iter().enumerate() {
2235            let k_kept = block.t_lw.ncols();
2236            if let Some(m) = block.anchor_correction.as_ref() {
2237                assert_eq!(
2238                    m.ncols(),
2239                    k_kept,
2240                    "block {idx}: anchor_correction.ncols()={ac} must equal t_lw.ncols()={k_kept} \
2241                     after audit truncation",
2242                    ac = m.ncols(),
2243                );
2244            }
2245            if let Some(r) = block.r_lw.as_ref() {
2246                assert_eq!(
2247                    r.ncols(),
2248                    k_kept,
2249                    "block {idx}: r_lw.ncols()={r_cols} must equal t_lw.ncols()={k_kept} \
2250                     after audit truncation",
2251                    r_cols = r.ncols(),
2252                );
2253            }
2254        }
2255    }
2256
2257    /// §10 test #5: regression test for the deleted FlexEvaluation skip
2258    /// bug. A flex anchor (represented by a dense scalar operator with the
2259    /// same column span as the parametric reference) must receive the same
2260    /// residualisation as the parametric anchor.
2261    #[test]
2262    fn compile_flex_anchor_is_first_class() {
2263        let n = 60;
2264        // Two parametric blocks A, B; a third "flex" block C whose
2265        // operator is dense (modelling a compiled flex anchor's column
2266        // span). All-parametric reference vs. mixed parametric+flex must
2267        // produce identical compiled blocks B (residualised against A)
2268        // because the compiler treats every input as a `RowJacobianOperator`.
2269        let a = Array2::from_shape_fn((n, 2), |(i, j)| (i as f64 * 0.07 + j as f64).sin());
2270        let b = Array2::from_shape_fn((n, 2), |(i, j)| {
2271            0.4 * a[[i, 0]] + (j as f64) * (i as f64 + 1.0).ln()
2272        });
2273        let hess = IdentityRowHessian::new(n, 1);
2274
2275        let ops_param = vec![op(a.clone()), op(b.clone())];
2276        let compiled_param = compile(
2277            &ops_param,
2278            &hess,
2279            &[BlockOrder::Marginal, BlockOrder::Logslope],
2280        )
2281        .expect("compile should succeed");
2282
2283        // Now wrap A's design behind a mock anchor evaluator and feed it
2284        // to the compiler as a `DenseScalarOperator` with the same span.
2285        // The B-block result must match the parametric reference.
2286        let ops_flex = vec![op(a.clone()), op(b.clone())];
2287        let compiled_flex = compile(
2288            &ops_flex,
2289            &hess,
2290            &[BlockOrder::ScoreWarp, BlockOrder::LinkDev],
2291        )
2292        .expect("compile should succeed");
2293
2294        let m_param = compiled_param.blocks[1].anchor_correction.as_ref().unwrap();
2295        let m_flex = compiled_flex.blocks[1].anchor_correction.as_ref().unwrap();
2296        assert_eq!(m_param.dim(), m_flex.dim());
2297        let max_diff = (m_param - m_flex)
2298            .iter()
2299            .fold(0.0_f64, |acc, &v| acc.max(v.abs()));
2300        assert!(
2301            max_diff < 1e-12,
2302            "flex vs parametric anchor correction mismatch: {max_diff:e}"
2303        );
2304    }
2305
2306    /// §10 test #7: Bernoulli row Hessian = IRLS weight. Verified at the
2307    /// trait level — a `DiagonalScalarRowHessian` round-trips through
2308    /// `evaluate_full` to the same per-row scalar.
2309    #[test]
2310    fn bernoulli_row_hessian_matches_irls_weight() {
2311        let w = Array1::from(vec![0.1, 0.5, 0.9, 0.25, 0.75]);
2312        let hess = DiagonalScalarRowHessian::new(w.clone());
2313        let full = hess.evaluate_full();
2314        assert_eq!(full.shape(), &[5, 1, 1]);
2315        for i in 0..5 {
2316            assert_eq!(full[[i, 0, 0]], w[i]);
2317            let mut buf = [0.0_f64; 1];
2318            hess.fill_row(i, &mut buf);
2319            assert_eq!(buf[0], w[i]);
2320        }
2321    }
2322
2323    /// §10 test #8: predict-path roundtrip. With the parametric setting,
2324    /// the row-application of `(C(x)·V − A(x)·M)` at training rows must
2325    /// equal the in-metric residual computed during `compile`.
2326    #[test]
2327    fn compiler_predict_path_roundtrip() {
2328        let n = 24;
2329        let a = Array2::from_shape_fn((n, 2), |(i, j)| (i as f64 * 0.21).cos() + j as f64);
2330        let b = Array2::from_shape_fn((n, 2), |(i, j)| {
2331            0.3 * a[[i, 0]] + (i as f64 + j as f64).sqrt()
2332        });
2333        let hess = IdentityRowHessian::new(n, 1);
2334        let ops = vec![op(a.clone()), op(b.clone())];
2335        let compiled = compile(&ops, &hess, &[BlockOrder::Marginal, BlockOrder::Logslope])
2336            .expect("compile should succeed");
2337        let v_b = &compiled.blocks[1].t_lw;
2338        let m_b = compiled.blocks[1].anchor_correction.as_ref().unwrap();
2339        // Training-time residual: B · V − A · M.
2340        let predict_design = b.dot(v_b) - a.dot(m_b);
2341        // Compare to the algebraic in-metric residual: same expression
2342        // (identity row Hessian collapses sqrt(H) = I), so this is a
2343        // self-consistency / shape check ensuring V and M compose to the
2344        // promised predict-time operator.
2345        assert_eq!(predict_design.nrows(), n);
2346        assert_eq!(predict_design.ncols(), v_b.ncols());
2347        // Finite-value gate.
2348        for &val in predict_design.iter() {
2349            assert!(val.is_finite(), "predict design produced non-finite entry");
2350        }
2351    }
2352
2353    /// `r_lw` and `anchor_correction` are populated on every non-first
2354    /// block as `M_b · V_b` at compiled width. The first block carries
2355    /// `None`. Also verifies the H-orthogonality invariant that the
2356    /// cumulative anchor for the next iteration is orthogonal (in the row
2357    /// metric) to the prior block's design.
2358    #[test]
2359    fn compile_exposes_r_lw_equal_to_m_dot_v() {
2360        let n = 40;
2361        let a = Array2::from_shape_fn((n, 2), |(i, j)| (i as f64 * 0.17 + j as f64).sin());
2362        // B partially aliases A's first column, so anchor correction is non-trivial.
2363        let b = Array2::from_shape_fn((n, 2), |(i, j)| {
2364            0.6 * a[[i, 0]] + ((i as f64) * 0.11 + j as f64).cos()
2365        });
2366        let hess = IdentityRowHessian::new(n, 1);
2367        let ops = vec![op(a.clone()), op(b.clone())];
2368        let compiled = compile(&ops, &hess, &[BlockOrder::Marginal, BlockOrder::Logslope])
2369            .expect("compile should succeed");
2370
2371        // First block: no anchor → both fields None.
2372        assert!(compiled.blocks[0].r_lw.is_none());
2373        assert!(compiled.blocks[0].anchor_correction.is_none());
2374
2375        // Second block: r_lw and anchor_correction must both equal M·V at
2376        // compiled width (p_a_kept × p_b_kept).
2377        let v_a = &compiled.blocks[0].t_lw;
2378        let v_b = &compiled.blocks[1].t_lw;
2379        let m_compiled = compiled.blocks[1]
2380            .anchor_correction
2381            .as_ref()
2382            .expect("second block must carry an anchor correction");
2383        let r_lw = compiled.blocks[1]
2384            .r_lw
2385            .as_ref()
2386            .expect("second block must expose r_lw");
2387        let p_a_kept = v_a.ncols();
2388        let p_b_kept = v_b.ncols();
2389        assert_eq!(
2390            m_compiled.dim(),
2391            (p_a_kept, p_b_kept),
2392            "anchor_correction must be at compiled width"
2393        );
2394        assert_eq!(r_lw.dim(), (p_a_kept, p_b_kept));
2395        // r_lw and anchor_correction are synonymous.
2396        let diff = r_lw - m_compiled;
2397        let max_diff = diff.iter().fold(0.0_f64, |acc, &x| acc.max(x.abs()));
2398        assert!(
2399            max_diff == 0.0,
2400            "r_lw and anchor_correction must be identical"
2401        );
2402
2403        // H-orthogonality (identity row metric): the residualised
2404        // compiled B-design `B·V − A·(M·V)` must be orthogonal to A in
2405        // the column-inner-product sense. This validates that the
2406        // cumulative anchor build uses `(W_b − A·M)·V` rather than `W_b·V`.
2407        let b_compiled = b.dot(v_b) - a.dot(m_compiled);
2408        let cross = a.t().dot(&b_compiled);
2409        let max_cross = cross.iter().fold(0.0_f64, |acc, &x| acc.max(x.abs()));
2410        assert!(
2411            max_cross < 1e-10,
2412            "compiled B-design must be H-orthogonal to A: max cross = {max_cross:e}"
2413        );
2414    }
2415
2416    /// `K=4` dense row Hessian: per-row PSD matrix supplied directly.
2417    struct DenseRowHessian {
2418        h: Array3<f64>,
2419    }
2420
2421    impl RowHessian for DenseRowHessian {
2422        fn k(&self) -> usize {
2423            self.h.shape()[1]
2424        }
2425        fn nrows(&self) -> usize {
2426            self.h.shape()[0]
2427        }
2428        fn fill_row(&self, row: usize, out: &mut [f64]) {
2429            let k = self.k();
2430            assert_eq!(out.len(), k * k);
2431            for c in 0..k {
2432                for d in 0..k {
2433                    out[c * k + d] = self.h[[row, c, d]];
2434                }
2435            }
2436        }
2437        fn evaluate_full(&self) -> Array3<f64> {
2438            self.h.clone()
2439        }
2440    }
2441
2442    /// Reference W-based Gram for verification: build `W = sqrt(H) · J` then
2443    /// return `Wᵀ W`. Mirrors the in-walk path in [`compile`].
2444    fn reference_gram_from_w(j_full: &Array3<f64>, h_full: &Array3<f64>) -> Array2<f64> {
2445        let w = scale_block_by_sqrt_h(j_full, h_full);
2446        fast_ata(&w)
2447    }
2448
2449    /// Two-block toy at K=4: build per-channel (n × p_b) blocks and verify
2450    /// the closed-form Gram matches the reference W-based Gram.
2451    #[test]
2452    fn closed_form_gram_matches_reference_two_block_k4() {
2453        let n = 17;
2454        let k = 4;
2455        let p_a = 3;
2456        let p_b = 2;
2457
2458        // Random-ish per-channel design matrices for each block.
2459        let make_block = |seed: f64, n: usize, p: usize| -> Vec<Option<Array2<f64>>> {
2460            (0..4)
2461                .map(|c| {
2462                    let m = Array2::from_shape_fn((n, p), |(i, j)| {
2463                        ((i as f64 + 1.0) * (j as f64 + 1.0) * (c as f64 + 1.0) + seed).sin()
2464                    });
2465                    Some(m)
2466                })
2467                .collect()
2468        };
2469        let block_a = make_block(0.3, n, p_a);
2470        let block_b = make_block(1.1, n, p_b);
2471
2472        // Per-row PSD H: random symmetric PSD via Mᵀ M.
2473        let h = Array3::from_shape_fn((n, k, k), |(i, c, d)| {
2474            let mut acc = 0.0;
2475            for r in 0..k {
2476                let mc = ((i + 1) as f64 * (c + 1) as f64 * (r + 1) as f64 * 0.13).cos();
2477                let md = ((i + 1) as f64 * (d + 1) as f64 * (r + 1) as f64 * 0.13).cos();
2478                acc += mc * md;
2479            }
2480            acc + if c == d { 0.5 } else { 0.0 }
2481        });
2482        let row_hess = DenseRowHessian { h: h.clone() };
2483
2484        let channel_blocks = PrimaryChannelBlocks {
2485            blocks: vec![block_a.clone(), block_b.clone()],
2486        };
2487        let raw_ranges = vec![0..p_a, p_a..(p_a + p_b)];
2488
2489        let gram = build_raw_grams_from_channel_blocks(&channel_blocks, &row_hess, &raw_ranges)
2490            .expect("closed-form Gram should succeed");
2491
2492        // Reference: assemble full row Jacobian J as (n × p_total × K) by
2493        // placing per-block, per-channel slices at the right columns.
2494        let p_total = p_a + p_b;
2495        let mut j_full = Array3::<f64>::zeros((n, p_total, k));
2496        for c in 0..k {
2497            if let Some(xa) = block_a[c].as_ref() {
2498                for i in 0..n {
2499                    for j in 0..p_a {
2500                        j_full[[i, j, c]] = xa[[i, j]];
2501                    }
2502                }
2503            }
2504            if let Some(xb) = block_b[c].as_ref() {
2505                for i in 0..n {
2506                    for j in 0..p_b {
2507                        j_full[[i, p_a + j, c]] = xb[[i, j]];
2508                    }
2509                }
2510            }
2511        }
2512        let ref_gram = reference_gram_from_w(&j_full, &h);
2513
2514        let diff = &gram - &ref_gram;
2515        let max_err = diff.iter().fold(0.0_f64, |acc, &v| acc.max(v.abs()));
2516        let scale = ref_gram.iter().fold(0.0_f64, |acc, &v| acc.max(v.abs()));
2517        assert!(
2518            max_err < 1e-9 * scale.max(1.0),
2519            "closed-form Gram mismatches reference: max_err={max_err:e}, scale={scale:e}"
2520        );
2521
2522        // Symmetry of the result.
2523        for i in 0..p_total {
2524            for j in 0..p_total {
2525                assert!(
2526                    (gram[[i, j]] - gram[[j, i]]).abs() < 1e-12,
2527                    "closed-form Gram not symmetric at ({i},{j})"
2528                );
2529            }
2530        }
2531    }
2532
2533    /// Channel sparsity test: block A contributes only to channel 0, block B
2534    /// only to channel 3. Cross-block contribution must be exactly
2535    /// `(X_A^(0))ᵀ · diag(h_{03}) · X_B^(3)` — zero when `h_03 ≡ 0`,
2536    /// non-zero otherwise.
2537    #[test]
2538    fn closed_form_gram_channel_sparsity() {
2539        let n = 13;
2540        let k = 4;
2541        let p_a = 2;
2542        let p_b = 2;
2543
2544        let xa = Array2::from_shape_fn((n, p_a), |(i, j)| ((i + 1) as f64 * 0.21 + j as f64).cos());
2545        let xb = Array2::from_shape_fn((n, p_b), |(i, j)| {
2546            ((i + 1) as f64 * 0.17 + j as f64).sin() + 0.5
2547        });
2548
2549        let block_a: Vec<Option<Array2<f64>>> = vec![Some(xa.clone()), None, None, None];
2550        let block_b: Vec<Option<Array2<f64>>> = vec![None, None, None, Some(xb.clone())];
2551
2552        // Case 1: H with non-zero h_{03} (and h_{30}). The cross-block
2553        // (A, B) entries must equal `Xaᵀ · diag(h_03) · Xb`.
2554        let h_03_vec = Array1::from_shape_fn(n, |i| 0.7 + 0.3 * ((i as f64) * 0.4).sin());
2555        let h = Array3::from_shape_fn((n, k, k), |(i, c, d)| {
2556            // Symmetric: only the (0,3)/(3,0) off-diagonal carries weight,
2557            // plus a strong PSD diagonal so per-row H is PSD.
2558            if (c, d) == (0, 3) || (c, d) == (3, 0) {
2559                h_03_vec[i]
2560            } else if c == d {
2561                2.0
2562            } else {
2563                0.0
2564            }
2565        });
2566        let row_hess = DenseRowHessian { h: h.clone() };
2567
2568        let channel_blocks = PrimaryChannelBlocks {
2569            blocks: vec![block_a.clone(), block_b.clone()],
2570        };
2571        let raw_ranges = vec![0..p_a, p_a..(p_a + p_b)];
2572        let gram = build_raw_grams_from_channel_blocks(&channel_blocks, &row_hess, &raw_ranges)
2573            .expect("closed-form Gram should succeed");
2574
2575        // Cross-block submatrix.
2576        let cross = gram.slice(s![0..p_a, p_a..(p_a + p_b)]).to_owned();
2577        // Expected: only the (c=0, d=3) channel-pair survives.
2578        let expected = fast_xt_diag_y(&xa, &h_03_vec, &xb);
2579        let diff = &cross - &expected;
2580        let max_err = diff.iter().fold(0.0_f64, |acc, &v| acc.max(v.abs()));
2581        assert!(
2582            max_err < 1e-12,
2583            "cross-block Gram must equal Xaᵀ·diag(h_03)·Xb: max_err={max_err:e}"
2584        );
2585
2586        // Case 2: zero out h_{03} → cross-block must be zero.
2587        let h_zero = Array3::from_shape_fn((n, k, k), |(_, c, d)| if c == d { 2.0 } else { 0.0 });
2588        let row_hess_zero = DenseRowHessian { h: h_zero };
2589        let gram_zero =
2590            build_raw_grams_from_channel_blocks(&channel_blocks, &row_hess_zero, &raw_ranges)
2591                .expect("closed-form Gram should succeed");
2592        let cross_zero = gram_zero.slice(s![0..p_a, p_a..(p_a + p_b)]);
2593        let max_zero = cross_zero.iter().fold(0.0_f64, |acc, &v| acc.max(v.abs()));
2594        assert!(
2595            max_zero < 1e-12,
2596            "cross-block Gram must vanish when coupling channel pair is zero: got {max_zero:e}"
2597        );
2598    }
2599
2600    /// Structural Gram: identity per-row Hessian collapses the channel-pair
2601    /// sum to within-channel `XᵀX`. Validates [`build_raw_grams_structural`].
2602    #[test]
2603    fn structural_gram_matches_within_channel_sum() {
2604        let n = 11;
2605        let p_a = 2;
2606        let p_b = 3;
2607        let make_block = |seed: f64, n: usize, p: usize| -> Vec<Option<Array2<f64>>> {
2608            (0..4)
2609                .map(|c| {
2610                    if c == 1 {
2611                        // Sparse channel for variety.
2612                        return None;
2613                    }
2614                    Some(Array2::from_shape_fn((n, p), |(i, j)| {
2615                        ((i as f64 + 1.0) * (j as f64 + 1.0) + seed * (c as f64 + 1.0)).sin()
2616                    }))
2617                })
2618                .collect()
2619        };
2620        let block_a = make_block(0.1, n, p_a);
2621        let block_b = make_block(0.7, n, p_b);
2622        let channel_blocks = PrimaryChannelBlocks {
2623            blocks: vec![block_a.clone(), block_b.clone()],
2624        };
2625        let raw_ranges = vec![0..p_a, p_a..(p_a + p_b)];
2626        let gram = build_raw_grams_structural(&channel_blocks, &raw_ranges);
2627
2628        // Hand-compute cross block: Σ_c Xaᵀ Xb over channels where both
2629        // sides are present (skipping channel 1 entirely).
2630        let mut expected_cross = Array2::<f64>::zeros((p_a, p_b));
2631        for c in 0..4 {
2632            if let (Some(xa), Some(xb)) = (block_a[c].as_ref(), block_b[c].as_ref()) {
2633                expected_cross += &fast_atb(xa, xb);
2634            }
2635        }
2636        let cross = gram.slice(s![0..p_a, p_a..(p_a + p_b)]).to_owned();
2637        let diff = &cross - &expected_cross;
2638        let max_err = diff.iter().fold(0.0_f64, |acc, &v| acc.max(v.abs()));
2639        assert!(
2640            max_err < 1e-12,
2641            "structural cross-block must equal Σ_c Xaᵀ·Xb: max_err={max_err:e}"
2642        );
2643
2644        // Symmetry.
2645        for i in 0..(p_a + p_b) {
2646            for j in 0..(p_a + p_b) {
2647                assert!(
2648                    (gram[[i, j]] - gram[[j, i]]).abs() < 1e-12,
2649                    "structural Gram not symmetric at ({i},{j})"
2650                );
2651            }
2652        }
2653    }
2654
2655    // Per-row Hessian (K=1) sourced from an arbitrary positive vector —
2656    // used by the dual-metric sanity test to drive both structural and
2657    // curvature passes with the *same* non-identity weights.
2658    fn diag_hess(w: Array1<f64>) -> DiagonalScalarRowHessian {
2659        DiagonalScalarRowHessian::new(w)
2660    }
2661
2662    /// L#1: dual-metric with structural = curvature reproduces single-metric
2663    /// `compile()` exactly. The two passes degenerate to one because the
2664    /// structural-anchor and curvature-anchor are the same matrix.
2665    #[test]
2666    fn dual_metric_with_equal_metrics_matches_single_metric() {
2667        let n = 36;
2668        let a = Array2::from_shape_fn((n, 2), |(i, j)| (i as f64 * 0.13 + j as f64).sin());
2669        // B partially aliases A's first column.
2670        let b = Array2::from_shape_fn((n, 2), |(i, j)| {
2671            0.4 * a[[i, 0]] + (i as f64 * 0.07 + j as f64).cos()
2672        });
2673        let w = Array1::from_shape_fn(n, |i| 0.5 + (i as f64 * 0.17).sin().abs());
2674        let curvature = diag_hess(w.clone());
2675        let ordering = [BlockOrder::Marginal, BlockOrder::Logslope];
2676
2677        let ops_single = vec![op(a.clone()), op(b.clone())];
2678        let single = compile(&ops_single, &curvature, &ordering)
2679            .expect("single-metric compile should succeed");
2680
2681        // Dual-metric with structural = curvature (same `RowHessian` on both
2682        // sides). The structural pass collapses to the curvature pass.
2683        let structural_same = diag_hess(w.clone());
2684        let ops_dual = vec![op(a.clone()), op(b.clone())];
2685        let dual = compile_with_dual_metric(&ops_dual, &curvature, &structural_same, &ordering)
2686            .expect("dual-metric compile should succeed");
2687
2688        assert_eq!(single.blocks.len(), dual.blocks.len());
2689        for (idx, (sb, db)) in single.blocks.iter().zip(dual.blocks.iter()).enumerate() {
2690            assert_eq!(sb.t_lw.dim(), db.t_lw.dim(), "block {idx}: V dims differ");
2691            let max_v = (&sb.t_lw - &db.t_lw)
2692                .iter()
2693                .fold(0.0_f64, |acc, &v| acc.max(v.abs()));
2694            assert!(max_v < 1e-10, "block {idx}: V mismatch {max_v:e}");
2695            match (sb.anchor_correction.as_ref(), db.anchor_correction.as_ref()) {
2696                (None, None) => {}
2697                (Some(s), Some(d)) => {
2698                    assert_eq!(s.dim(), d.dim());
2699                    let max_m = (s - d).iter().fold(0.0_f64, |acc, &v| acc.max(v.abs()));
2700                    assert!(max_m < 1e-10, "block {idx}: M mismatch {max_m:e}");
2701                }
2702                _ => panic!("block {idx}: one side has anchor correction, the other does not"),
2703            }
2704        }
2705        assert_eq!(single.joint_rank, dual.joint_rank);
2706    }
2707
2708    /// L#2: the pilot-curvature trap. A 2-block toy where the pilot
2709    /// curvature `H` has a zero direction that is NOT a real gauge — the
2710    /// dual-metric path keeps it (identity-structural sees it as a full-
2711    /// rank structural direction), while a single-metric path through the
2712    /// same H would drop it.
2713    ///
2714    /// Construction: two K=1 blocks `A` (n × 1) and `B` (n × 1). Choose H
2715    /// (diagonal row weights) so that `H · B` happens to be a scalar
2716    /// multiple of `H · A` (curvature alias) but `B` is *not* a scalar
2717    /// multiple of `A` in the unweighted metric. Specifically, pick rows
2718    /// where `w_i` is non-zero only on a handful of rows where A and B
2719    /// happen to be proportional, and zero on the rows where they differ.
2720    /// Under identity-structural this is structurally-independent; under H
2721    /// it is a (spurious) curvature alias.
2722    #[test]
2723    fn dual_metric_resists_pilot_curvature_alias() {
2724        let n = 12;
2725        // A: x_i = i+1 (no zeros). B: equals 2·A on rows 0..6 only; the
2726        // remaining rows are uncorrelated (linear vs trigonometric).
2727        let a = Array2::from_shape_fn((n, 1), |(i, _)| (i as f64) + 1.0);
2728        let b = Array2::from_shape_fn((n, 1), |(i, _)| {
2729            if i < 6 {
2730                2.0 * a[[i, 0]]
2731            } else {
2732                ((i as f64) * 0.3).cos() + 0.5
2733            }
2734        });
2735
2736        // Curvature weights are non-zero ONLY on the rows where B == 2A.
2737        // Under curvature metric, B is exactly 2·A → curvature-rank drops
2738        // B fully. Under identity-structural, B is independent of A across
2739        // all rows → structural-rank is 1 (kept).
2740        let mut w_vec = vec![0.0_f64; n];
2741        for w in &mut w_vec[..6] {
2742            *w = 1.0;
2743        }
2744        let w = Array1::from(w_vec);
2745        let curvature = diag_hess(w.clone());
2746
2747        // Reference single-metric compile (uses identity by `compile()` —
2748        // which now routes through identity-structural). For this test we
2749        // explicitly invoke the dual-metric API both ways.
2750        let id_struct = IdentityRowHessian::new(n, 1);
2751        let ordering = [BlockOrder::Marginal, BlockOrder::Logslope];
2752
2753        // Path 1: dual-metric with identity-structural (the new default).
2754        // Structural pass: B is independent of A across all rows → keep
2755        // B's single column.
2756        let ops_dual = vec![op(a.clone()), op(b.clone())];
2757        let dual = compile_with_dual_metric(&ops_dual, &curvature, &id_struct, &ordering);
2758
2759        // Path 2: dual-metric with structural = curvature (the "H decides
2760        // everything" trap). On the curvature-only rows, B ≡ 2A, so
2761        // structural pass sees zero residual span and rejects the block.
2762        let ops_h_only = vec![op(a.clone()), op(b.clone())];
2763        let h_only = compile_with_dual_metric(&ops_h_only, &curvature, &curvature, &ordering);
2764
2765        // The H-only path must fail (FullyAliased) or strip B's column.
2766        // The dual (identity-structural) path must keep B.
2767        match h_only {
2768            Err(CompilerError::FullyAliased { block_idx, .. }) => {
2769                assert_eq!(block_idx, 1, "H-only path must alias block 1");
2770            }
2771            Ok(out) => {
2772                // If the H-only path somehow compiled, it must have
2773                // either dropped B's column to zero width or audited it
2774                // out. Either way B's V must be empty after the audit
2775                // attributes the drop.
2776                let v1_cols = out.blocks[1].t_lw.ncols();
2777                assert!(
2778                    v1_cols == 0 || !out.dropped.is_empty(),
2779                    "H-only path should reject B's curvature-aliased column; v1_cols={v1_cols}, dropped={dropped:?}",
2780                    dropped = out.dropped,
2781                );
2782            }
2783            Err(other) => panic!("unexpected H-only error: {other:?}"),
2784        }
2785
2786        let dual =
2787            dual.expect("dual-metric must succeed: identity-structural sees B as independent");
2788        // The dual path may still drop B's column at the joint audit step
2789        // because the joint H-scaled design is rank-1 (only the first
2790        // block contributes non-zero rows under the curvature weights).
2791        // What matters is that the *structural* decision did NOT drop B
2792        // — verified by the structural pass not raising FullyAliased and
2793        // by B's `t_lw` having the full structural width before the audit
2794        // demotes it. After audit, B's V may shrink because the curvature
2795        // joint design is rank-deficient, and that is expected.
2796        assert_eq!(dual.blocks.len(), 2);
2797        assert_eq!(dual.blocks[0].t_lw.ncols(), 1, "A must keep its column");
2798        // Block 1 either keeps its structural rank-1 column or is audited
2799        // away by the joint H-rank check, but in either case the per-block
2800        // pre-audit width must reflect that the structural pass kept the
2801        // column (i.e. the function did not return FullyAliased).
2802        let v1_post_audit = dual.blocks[1].t_lw.ncols();
2803        let dropped_count = dual.dropped.len();
2804        assert_eq!(
2805            v1_post_audit + dropped_count,
2806            1,
2807            "structural pass kept B's column; audit may demote it but the pre-audit width was 1"
2808        );
2809    }
2810
2811    /// L#3: identity-structural lets the compiler keep a direction even
2812    /// when the pilot curvature has reduced rank. This is the same
2813    /// scenario as L#2 but with a curvature `H` whose row weights are all
2814    /// strictly positive — so the *only* aliasing source is the structural
2815    /// pass deciding to keep or drop. The dual-metric path with non-trivial
2816    /// `H` and identity-structural must agree with the dual-metric path
2817    /// with identity on both sides whenever the blocks are structurally
2818    /// non-aliased.
2819    #[test]
2820    fn dual_metric_identity_structural_preserves_full_rank() {
2821        let n = 24;
2822        let a = Array2::from_shape_fn((n, 2), |(i, j)| ((i + 1) as f64 + j as f64).sqrt());
2823        let b = Array2::from_shape_fn((n, 2), |(i, j)| {
2824            ((i + 1) as f64).ln() + (i as f64 * 0.1 + j as f64).cos()
2825        });
2826        let w = Array1::from_shape_fn(n, |i| 0.4 + (i as f64 * 0.05).sin().powi(2));
2827        let curvature = diag_hess(w.clone());
2828        let id_struct = IdentityRowHessian::new(n, 1);
2829        let ordering = [BlockOrder::Marginal, BlockOrder::Logslope];
2830
2831        let ops = vec![op(a.clone()), op(b.clone())];
2832        let out =
2833            compile_with_dual_metric(&ops, &curvature, &id_struct, &ordering).expect("compile");
2834        // Both blocks structurally independent → both keep full width.
2835        assert_eq!(out.blocks[0].t_lw.ncols(), 2);
2836        assert_eq!(out.blocks[1].t_lw.ncols(), 2);
2837        assert_eq!(out.dropped.len(), 0);
2838        assert_eq!(out.joint_rank, 4);
2839    }
2840
2841    /// Smoke test for the GPU-or-CPU dispatch helper. On non-CUDA hosts
2842    /// (or when the runtime is unavailable) the helper falls back to the
2843    /// CPU closed-form builders; the result must match the CPU builders
2844    /// called directly. When a CUDA runtime is live, parity vs. CPU is
2845    /// verified to tight tolerance.
2846    #[test]
2847    fn build_primary_grams_gpu_or_cpu_two_block_k4_matches_cpu() {
2848        let n = 11;
2849        let k = 4;
2850        let p_a = 2;
2851        let p_b = 3;
2852
2853        let make_block = |seed: f64, n: usize, p: usize| -> Vec<Option<Array2<f64>>> {
2854            (0..4)
2855                .map(|c| {
2856                    let m = Array2::from_shape_fn((n, p), |(i, j)| {
2857                        ((i as f64 + 1.0) * (j as f64 + 1.0) * (c as f64 + 1.0) + seed).sin()
2858                    });
2859                    Some(m)
2860                })
2861                .collect()
2862        };
2863        let block_a = make_block(0.7, n, p_a);
2864        let block_b = make_block(-0.4, n, p_b);
2865
2866        let h = Array3::from_shape_fn((n, k, k), |(i, c, d)| {
2867            let mut acc = 0.0;
2868            for r in 0..k {
2869                let mc = ((i + 1) as f64 * (c + 1) as f64 * (r + 1) as f64 * 0.11).cos();
2870                let md = ((i + 1) as f64 * (d + 1) as f64 * (r + 1) as f64 * 0.11).cos();
2871                acc += mc * md;
2872            }
2873            acc + if c == d { 0.25 } else { 0.0 }
2874        });
2875        let row_hess = DenseRowHessian { h: h.clone() };
2876
2877        let channel_blocks = PrimaryChannelBlocks {
2878            blocks: vec![block_a, block_b],
2879        };
2880        let raw_ranges = vec![0..p_a, p_a..(p_a + p_b)];
2881
2882        let (gram_h, gram_struct) =
2883            build_primary_grams_gpu_or_cpu(&channel_blocks, &row_hess, &raw_ranges)
2884                .expect("dispatch helper should succeed");
2885
2886        let cpu_h = build_raw_grams_from_channel_blocks(&channel_blocks, &row_hess, &raw_ranges)
2887            .expect("CPU curvature Gram should succeed");
2888        let cpu_s = build_raw_grams_structural(&channel_blocks, &raw_ranges);
2889
2890        let tol = 1e-9_f64;
2891        for idx in cpu_h.indexed_iter().map(|(i, _)| i) {
2892            let diff = (gram_h[idx] - cpu_h[idx]).abs();
2893            let scale = cpu_h[idx].abs().max(1.0);
2894            assert!(
2895                diff <= tol * scale,
2896                "gram_h mismatch at {idx:?}: helper={} cpu={}",
2897                gram_h[idx],
2898                cpu_h[idx]
2899            );
2900        }
2901        for idx in cpu_s.indexed_iter().map(|(i, _)| i) {
2902            let diff = (gram_struct[idx] - cpu_s[idx]).abs();
2903            let scale = cpu_s[idx].abs().max(1.0);
2904            assert!(
2905                diff <= tol * scale,
2906                "gram_struct mismatch at {idx:?}: helper={} cpu={}",
2907                gram_struct[idx],
2908                cpu_s[idx]
2909            );
2910        }
2911    }
2912
2913    // ---- compile_from_raw_grams tests ----
2914
2915    /// Build (gram_h, gram_struct) for a K=1 scalar two-block toy via the
2916    /// per-block channel-block builders. Used by the closed-form tests
2917    /// below.
2918    fn scalar_grams_two_block(
2919        a: &Array2<f64>,
2920        b: &Array2<f64>,
2921        w: &Array1<f64>,
2922    ) -> (Array2<f64>, Array2<f64>, Vec<std::ops::Range<usize>>) {
2923        let p_a = a.ncols();
2924        let p_b = b.ncols();
2925        let raw_ranges = vec![0..p_a, p_a..(p_a + p_b)];
2926        let channel_blocks = PrimaryChannelBlocks {
2927            blocks: vec![vec![Some(a.clone())], vec![Some(b.clone())]],
2928        };
2929        let row_hess = DiagonalScalarRowHessian::new(w.clone());
2930        let gram_h =
2931            build_raw_grams_from_channel_blocks(&channel_blocks, &row_hess, &raw_ranges).unwrap();
2932        let gram_struct = build_raw_grams_structural(&channel_blocks, &raw_ranges);
2933        (gram_h, gram_struct, raw_ranges)
2934    }
2935
2936    /// Block B is a column-duplicate of block A in the structural metric
2937    /// → the lower-priority block compiles to zero width instead of making
2938    /// callers skip reduced-coordinate construction.
2939    #[test]
2940    fn compile_from_raw_grams_full_structural_alias() {
2941        let n = 10;
2942        let a = Array2::from_shape_fn((n, 2), |(i, j)| ((i + 1) as f64 * (j + 1) as f64).sin());
2943        // Block B = A · L for some 2×2 invertible L → same column span.
2944        let l = Array2::from_shape_vec((2, 2), vec![1.0, 0.5, -0.25, 1.0]).unwrap();
2945        let b = a.dot(&l);
2946        let w = Array1::ones(n);
2947        let (gram_h, gram_struct, raw_ranges) = scalar_grams_two_block(&a, &b, &w);
2948        let res = compile_from_raw_grams(
2949            &gram_h,
2950            &gram_struct,
2951            &raw_ranges,
2952            &[BlockOrder::Marginal, BlockOrder::Logslope],
2953        )
2954        .expect("lower-priority full alias should compile to zero width");
2955        assert_eq!(res.compiled_block_ranges[0].len(), 2);
2956        assert_eq!(res.compiled_block_ranges[1].len(), 0);
2957        assert_eq!(res.raw_from_compiled.dim(), (4, 2));
2958        assert!(
2959            res.raw_from_compiled
2960                .slice(s![raw_ranges[1].clone(), ..])
2961                .iter()
2962                .all(|v| v.abs() <= 1.0e-12),
2963            "zero-width block must not retain raw coefficient directions in T"
2964        );
2965    }
2966
2967    /// A zero-width *first* block has no columns to alias and must compile to
2968    /// an empty range with the remaining blocks intact — not abort with
2969    /// `FullyAliased`. Regression for the survival location-scale lognormal AFT
2970    /// pre-fit channel-aware audit, whose `time_transform` block collapses to
2971    /// zero free coefficients under the parametric AFT reduction and previously
2972    /// crashed the fit ("block of width 0 has zero structural span").
2973    #[test]
2974    fn compile_from_raw_grams_zero_width_first_block_is_identifiable() {
2975        let n = 12;
2976        let empty = Array2::<f64>::zeros((n, 0));
2977        let b = Array2::from_shape_fn((n, 2), |(i, j)| {
2978            ((i + 1) as f64 * (j + 1) as f64 * 0.23).cos()
2979        });
2980        let w = Array1::ones(n);
2981        let (gram_h, gram_struct, raw_ranges) = scalar_grams_two_block(&empty, &b, &w);
2982        let map = compile_from_raw_grams(
2983            &gram_h,
2984            &gram_struct,
2985            &raw_ranges,
2986            &[BlockOrder::Marginal, BlockOrder::Logslope],
2987        )
2988        .expect("zero-width first block must be trivially identifiable, not FullyAliased");
2989        assert_eq!(
2990            map.compiled_block_ranges[0].len(),
2991            0,
2992            "empty first block keeps zero columns"
2993        );
2994        assert_eq!(
2995            map.compiled_block_ranges[1].len(),
2996            2,
2997            "the second block keeps its full structural rank"
2998        );
2999        assert_eq!(map.raw_from_compiled.dim(), (2, 2));
3000    }
3001
3002    /// A `protected` first block keeps every raw column even when it is
3003    /// internally rank-deficient (a duplicate-column structural null that the
3004    /// unprotected path drops), and later blocks still reduce against the full
3005    /// raw anchor. Regression for the survival marginal-slope monotone
3006    /// time-wiggle time block: its chain-rule Jacobian recomputes a fixed
3007    /// `p_tw`-column wiggle basis on every evaluation, so a reduced (`p_time <
3008    /// p_tw`) time design made that Jacobian write past its buffer — an
3009    /// out-of-bounds panic in the phase-4b compiled-map path.
3010    #[test]
3011    fn compile_from_raw_grams_protected_keeps_full_rank_deficient_first_block() {
3012        let n = 14;
3013        // Block A (first, highest priority): two IDENTICAL columns → structural
3014        // rank 1, i.e. one within-block null direction the unprotected filter
3015        // drops. Stands in for the wiggle time block whose raw width must be
3016        // preserved.
3017        // Column value depends only on the row → both columns are identical.
3018        let a = Array2::from_shape_fn((n, 2), |(i, _)| ((i + 1) as f64 * 0.37).sin());
3019        // Block B: genuinely independent, so it survives at full width.
3020        let b = Array2::from_shape_fn((n, 2), |(i, j)| ((i + 1) as f64 * (0.29 + j as f64 * 0.11)).cos());
3021        let w = Array1::ones(n);
3022        let (gram_h, gram_struct, raw_ranges) = scalar_grams_two_block(&a, &b, &w);
3023        let ordering = [BlockOrder::Time, BlockOrder::Marginal];
3024
3025        // Unprotected: the duplicate column is dropped → block 0 reduces to 1.
3026        let unprotected =
3027            compile_from_raw_grams(&gram_h, &gram_struct, &raw_ranges, &ordering)
3028                .expect("unprotected compile");
3029        assert_eq!(
3030            unprotected.compiled_block_ranges[0].len(),
3031            1,
3032            "unprotected first block drops its structural-null direction"
3033        );
3034
3035        // Protected: block 0 keeps both raw columns; T's block-0 diagonal is the
3036        // 2×2 identity (raw coords == compiled coords for the protected block).
3037        let protected =
3038            compile_from_raw_grams_protected(&gram_h, &gram_struct, &raw_ranges, &ordering, &[true, false])
3039                .expect("protected compile");
3040        assert_eq!(
3041            protected.compiled_block_ranges[0].len(),
3042            2,
3043            "protected first block retains its full raw width"
3044        );
3045        let t_block0 = protected
3046            .raw_from_compiled
3047            .slice(s![0..2, protected.compiled_block_ranges[0].clone()])
3048            .to_owned();
3049        for i in 0..2 {
3050            for j in 0..2 {
3051                let expect = if i == j { 1.0 } else { 0.0 };
3052                assert!(
3053                    (t_block0[[i, j]] - expect).abs() <= 1e-12,
3054                    "protected first block map must be identity, got [{i},{j}]={}",
3055                    t_block0[[i, j]]
3056                );
3057            }
3058        }
3059    }
3060
3061    #[test]
3062    fn orthogonalization_annotates_independent_and_fully_absorbed_blocks() {
3063        let n = 18;
3064        let anchor = Array2::from_shape_fn((n, 2), |(i, j)| {
3065            ((i + 1) as f64 * (0.19 + j as f64 * 0.07)).sin()
3066        });
3067        let duplicate = anchor.clone();
3068        let independent = Array2::from_shape_fn((n, 1), |(i, _)| ((i + 1) as f64 * 0.43).cos());
3069        let weight = vec![1.0; n];
3070        let ortho = orthogonalize_design_blocks(
3071            &[anchor, duplicate, independent],
3072            &[200, 100, 50],
3073            &weight,
3074        )
3075        .expect("structural annotation compile");
3076
3077        assert_eq!(
3078            ortho.direction_annotations[0].kind,
3079            PenalizedDirectionAnnotationKind::Independent
3080        );
3081        assert_eq!(ortho.direction_annotations[0].absorbed_width, 0);
3082        assert_eq!(
3083            ortho.direction_annotations[1].kind,
3084            PenalizedDirectionAnnotationKind::FullyAbsorbedByHigherPriority,
3085            "a duplicated lower-priority block is the same realized-design direction"
3086        );
3087        assert_eq!(ortho.direction_annotations[1].raw_width, 2);
3088        assert_eq!(ortho.direction_annotations[1].kept_width, 0);
3089        assert_eq!(ortho.direction_annotations[1].absorbed_width, 2);
3090        assert_eq!(
3091            ortho.direction_annotations[2].kind,
3092            PenalizedDirectionAnnotationKind::Independent,
3093            "a genuinely new realized-design direction keeps its own penalty block"
3094        );
3095        assert_eq!(ortho.direction_annotations[2].raw_width, 1);
3096        assert_eq!(ortho.direction_annotations[2].kept_width, 1);
3097        assert_eq!(ortho.dropped, vec![(1, 2)]);
3098    }
3099
3100    #[test]
3101    fn compile_from_raw_grams_three_block_full_logslope_alias_keeps_fast_path() {
3102        let n = 24;
3103        let time = Array2::from_shape_fn((n, 2), |(i, j)| {
3104            ((i + 1) as f64 * (j + 2) as f64 * 0.17).sin()
3105        });
3106        let marginal = Array2::from_shape_fn((n, 1), |(i, _)| ((i + 3) as f64 * 0.11).cos());
3107        let logslope = marginal.clone();
3108        let p_time = time.ncols();
3109        let p_marg = marginal.ncols();
3110        let p_log = logslope.ncols();
3111        let raw_ranges = vec![
3112            0..p_time,
3113            p_time..(p_time + p_marg),
3114            (p_time + p_marg)..(p_time + p_marg + p_log),
3115        ];
3116        let channel_blocks = PrimaryChannelBlocks {
3117            blocks: vec![
3118                vec![Some(time.clone())],
3119                vec![Some(marginal.clone())],
3120                vec![Some(logslope.clone())],
3121            ],
3122        };
3123        let row_hess = DiagonalScalarRowHessian::new(Array1::ones(n));
3124        let gram_h =
3125            build_raw_grams_from_channel_blocks(&channel_blocks, &row_hess, &raw_ranges).unwrap();
3126        let gram_struct = build_raw_grams_structural(&channel_blocks, &raw_ranges);
3127
3128        let map = compile_from_raw_grams(
3129            &gram_h,
3130            &gram_struct,
3131            &raw_ranges,
3132            &[BlockOrder::Time, BlockOrder::Marginal, BlockOrder::Logslope],
3133        )
3134        .expect("fully aliased logslope block should not skip the compiled-map path");
3135
3136        assert_eq!(map.compiled_block_ranges[0].len(), p_time);
3137        assert_eq!(map.compiled_block_ranges[1].len(), p_marg);
3138        assert_eq!(map.compiled_block_ranges[2].len(), 0);
3139        assert_eq!(
3140            map.raw_from_compiled.dim(),
3141            (p_time + p_marg + p_log, p_time + p_marg)
3142        );
3143        let x_raw = {
3144            let mut out = Array2::<f64>::zeros((n, p_time + p_marg + p_log));
3145            out.slice_mut(s![.., raw_ranges[0].clone()]).assign(&time);
3146            out.slice_mut(s![.., raw_ranges[1].clone()])
3147                .assign(&marginal);
3148            out.slice_mut(s![.., raw_ranges[2].clone()])
3149                .assign(&logslope);
3150            out
3151        };
3152        let x_compiled = fast_ab(&x_raw, &map.raw_from_compiled);
3153        let rrqr = rrqr_with_permutation(&x_compiled, default_rrqr_rank_alpha()).unwrap();
3154        assert_eq!(rrqr.rank, x_compiled.ncols());
3155    }
3156
3157    /// Partial alias: block B's first column duplicates A; second column is
3158    /// independent. Closed-form `T` must have shape `(p_raw × (p_a + 1))`
3159    /// — block 1's compiled width is exactly the independent direction —
3160    /// and the joint design `X_raw · T` must span the same column space as
3161    /// the W-based reference compile result.
3162    #[test]
3163    fn compile_from_raw_grams_partial_alias_matches_w_reference() {
3164        let n = 25;
3165        let a = Array2::from_shape_fn((n, 2), |(i, j)| {
3166            ((i + 1) as f64 * (j + 1) as f64 * 0.3).sin()
3167        });
3168        // B = [a_0  +  independent]
3169        let mut b = Array2::<f64>::zeros((n, 2));
3170        for i in 0..n {
3171            b[[i, 0]] = a[[i, 0]];
3172            b[[i, 1]] = ((i + 1) as f64 * 0.7).cos();
3173        }
3174        let w = Array1::from_shape_fn(n, |i| 1.0 + 0.1 * (i as f64));
3175        let (gram_h, gram_struct, raw_ranges) = scalar_grams_two_block(&a, &b, &w);
3176        let compiled = compile_from_raw_grams(
3177            &gram_h,
3178            &gram_struct,
3179            &raw_ranges,
3180            &[BlockOrder::Marginal, BlockOrder::Logslope],
3181        )
3182        .expect("closed-form compile must succeed");
3183        let p_a = a.ncols();
3184        let p_b = b.ncols();
3185        assert_eq!(compiled.raw_from_compiled.shape()[0], p_a + p_b);
3186        assert_eq!(
3187            compiled.raw_from_compiled.shape()[1],
3188            p_a + 1,
3189            "partial alias should leave compiled width = p_a + 1 (one column dropped from B)"
3190        );
3191        // Block ranges sum to compiled width.
3192        assert_eq!(compiled.compiled_block_ranges[0], 0..p_a);
3193        assert_eq!(
3194            compiled.compiled_block_ranges[1].end - compiled.compiled_block_ranges[1].start,
3195            1
3196        );
3197
3198        // Column-span equality vs. W-reference: stack the raw design
3199        // X_raw = [A | B] and check that range(X_raw · T) ⊆ range(X_raw)
3200        // and has the same rank as the W-based compile.
3201        let mut x_raw = Array2::<f64>::zeros((n, p_a + p_b));
3202        for i in 0..n {
3203            for j in 0..p_a {
3204                x_raw[[i, j]] = a[[i, j]];
3205            }
3206            for j in 0..p_b {
3207                x_raw[[i, p_a + j]] = b[[i, j]];
3208            }
3209        }
3210        let x_compiled = fast_ab(&x_raw, &compiled.raw_from_compiled);
3211        // Rank of compiled design via Gram eigvals.
3212        let g_compiled = fast_ata(&x_compiled);
3213        let (evals, _) = g_compiled.eigh(Side::Lower).unwrap();
3214        let lam_max = evals.iter().cloned().fold(0.0_f64, f64::max);
3215        let tol = lam_max * 64.0 * (g_compiled.nrows() as f64) * f64::EPSILON;
3216        let rank_compiled = evals.iter().filter(|&&l| l > tol).count();
3217        assert_eq!(
3218            rank_compiled,
3219            p_a + 1,
3220            "compiled design column rank must equal p_a + 1 after dropping the alias"
3221        );
3222
3223        // Reference compile via the W-based dual-metric path on the same
3224        // scalar blocks; compiled total width should also be p_a + 1.
3225        let ops_dual: Vec<Arc<dyn RowJacobianOperator>> = vec![op(a.clone()), op(b.clone())];
3226        let curvature = DiagonalScalarRowHessian::new(w.clone());
3227        let id_struct = IdentityRowHessian::new(n, 1);
3228        let dual = compile_with_dual_metric(
3229            &ops_dual,
3230            &curvature,
3231            &id_struct,
3232            &[BlockOrder::Marginal, BlockOrder::Logslope],
3233        )
3234        .expect("dual metric compile should succeed");
3235        let dual_total: usize = dual.blocks.iter().map(|b| b.t_lw.ncols()).sum();
3236        assert_eq!(dual_total, p_a + 1, "W-reference total width should match");
3237    }
3238
3239    /// Three-block toy: changing the ordering changes the per-block
3240    /// compiled widths (later blocks absorb the alias instead of earlier).
3241    #[test]
3242    fn compile_from_raw_grams_three_block_ordering_matters() {
3243        let n = 30;
3244        let a = Array2::from_shape_fn((n, 2), |(i, j)| {
3245            ((i + 1) as f64 * (j + 2) as f64 * 0.2).sin()
3246        });
3247        // B has 2 cols: col 0 independent, col 1 = a[:, 0]
3248        let mut b = Array2::<f64>::zeros((n, 2));
3249        for i in 0..n {
3250            b[[i, 0]] = ((i + 1) as f64 * 0.4).cos();
3251            b[[i, 1]] = a[[i, 0]];
3252        }
3253        // C has 2 cols: col 0 independent, col 1 = a[:, 1]
3254        let mut c = Array2::<f64>::zeros((n, 2));
3255        for i in 0..n {
3256            c[[i, 0]] = ((i + 1) as f64 * 0.55).sin();
3257            c[[i, 1]] = a[[i, 1]];
3258        }
3259        let w = Array1::ones(n);
3260
3261        let build = |b0: &Array2<f64>, b1: &Array2<f64>, b2: &Array2<f64>| {
3262            let raw_ranges = vec![
3263                0..b0.ncols(),
3264                b0.ncols()..(b0.ncols() + b1.ncols()),
3265                (b0.ncols() + b1.ncols())..(b0.ncols() + b1.ncols() + b2.ncols()),
3266            ];
3267            let channel_blocks = PrimaryChannelBlocks {
3268                blocks: vec![
3269                    vec![Some(b0.clone())],
3270                    vec![Some(b1.clone())],
3271                    vec![Some(b2.clone())],
3272                ],
3273            };
3274            let row_hess = DiagonalScalarRowHessian::new(w.clone());
3275            let gram_h =
3276                build_raw_grams_from_channel_blocks(&channel_blocks, &row_hess, &raw_ranges)
3277                    .unwrap();
3278            let gram_struct = build_raw_grams_structural(&channel_blocks, &raw_ranges);
3279            (gram_h, gram_struct, raw_ranges)
3280        };
3281
3282        // Order 1: A, B, C — B drops 1 (col 1 aliased to A), C drops 1.
3283        let (gh, gs, rr) = build(&a, &b, &c);
3284        let order_abc = compile_from_raw_grams(
3285            &gh,
3286            &gs,
3287            &rr,
3288            &[
3289                BlockOrder::Marginal,
3290                BlockOrder::Logslope,
3291                BlockOrder::LinkDev,
3292            ],
3293        )
3294        .expect("ABC compile");
3295        assert_eq!(order_abc.compiled_block_ranges[0].len(), 2);
3296        assert_eq!(order_abc.compiled_block_ranges[1].len(), 1);
3297        assert_eq!(order_abc.compiled_block_ranges[2].len(), 1);
3298
3299        // Order 2: B, A, C — A's col 0 is aliased by B's col 1 now; A's
3300        // col 1 is independent. So A drops 1; C still drops 1.
3301        let (gh2, gs2, rr2) = build(&b, &a, &c);
3302        let order_bac = compile_from_raw_grams(
3303            &gh2,
3304            &gs2,
3305            &rr2,
3306            &[
3307                BlockOrder::Marginal,
3308                BlockOrder::Logslope,
3309                BlockOrder::LinkDev,
3310            ],
3311        )
3312        .expect("BAC compile");
3313        assert_eq!(order_bac.compiled_block_ranges[0].len(), 2);
3314        assert_eq!(order_bac.compiled_block_ranges[1].len(), 1);
3315        // Total rank invariant under permutation: 4.
3316        let total_abc: usize = order_abc
3317            .compiled_block_ranges
3318            .iter()
3319            .map(|r| r.len())
3320            .sum();
3321        let total_bac: usize = order_bac
3322            .compiled_block_ranges
3323            .iter()
3324            .map(|r| r.len())
3325            .sum();
3326        assert_eq!(total_abc, total_bac);
3327        assert_eq!(total_abc, 4);
3328    }
3329
3330    /// Build a K=1 raw `(gram_h, gram_struct)` pair for a single stacked design
3331    /// `X` with per-row curvature weights `w`: `gram_struct = Xᵀ X`,
3332    /// `gram_h = Xᵀ diag(w) X`. Mirrors the closed-form definitions the
3333    /// production Gram builders implement for the scalar-channel case.
3334    fn k1_grams(x: &Array2<f64>, w: &Array1<f64>) -> (Array2<f64>, Array2<f64>) {
3335        let gram_struct = fast_atb(x, x);
3336        let xw = fast_xt_diag_y(x, w, x);
3337        (xw, gram_struct)
3338    }
3339
3340    /// Full-rank reduction: when the two blocks are jointly independent the
3341    /// compiled width equals the raw width and the lift `T` reproduces a raw
3342    /// coefficient exactly from its compiled image `θ = T⁺ β` (here, with no
3343    /// aliasing, `lift_coefficients(θ)` of any compiled `θ` lands in the raw
3344    /// design's column interpretation: applying `T` then comparing the induced
3345    /// raw predictor `X·Tθ` to `X·β_raw` for the `θ` solving `Tθ=β_raw`).
3346    #[test]
3347    fn compiled_map_lift_coefficients_roundtrips_full_rank() {
3348        let n = 21;
3349        let p_a = 2;
3350        let p_b = 2;
3351        // Distinct per-column frequencies make the four sinusoidal columns
3352        // genuinely linearly independent over the sample grid. (A shared phase
3353        // offset varying only by column would collapse every column into
3354        // span{sin θ, cos θ, 1}, i.e. rank 3, and the compiler would correctly
3355        // absorb a column — defeating the full-rank premise of this test.)
3356        let x = Array2::from_shape_fn((n, p_a + p_b), |(i, j)| {
3357            ((i as f64 + 1.0) * (0.21 + 0.17 * j as f64)).sin() + 0.11 * (j as f64)
3358        });
3359        let w = Array1::from_shape_fn(n, |i| 0.5 + 0.5 * ((i as f64) * 0.3).cos().abs());
3360        let (gh, gs) = k1_grams(&x, &w);
3361        let raw_ranges = vec![0..p_a, p_a..(p_a + p_b)];
3362        let map = compile_from_raw_grams(
3363            &gh,
3364            &gs,
3365            &raw_ranges,
3366            &[BlockOrder::Marginal, BlockOrder::Logslope],
3367        )
3368        .expect("full-rank compile");
3369        // Jointly independent ⇒ no columns absorbed.
3370        assert_eq!(map.p_compiled(), p_a + p_b);
3371        assert_eq!(map.p_raw(), p_a + p_b);
3372        // For a target raw coefficient, solve T θ = β_raw (T square invertible
3373        // here) and confirm lift_coefficients(θ) == β_raw.
3374        let beta_raw = Array1::from_shape_fn(p_a + p_b, |j| 0.4 * (j as f64) - 0.7);
3375        // T is (p × p); recover θ by a least-squares solve via the normal
3376        // equations TᵀT θ = Tᵀ β.
3377        let tt = fast_atb(&map.raw_from_compiled, &map.raw_from_compiled);
3378        let tb = map.raw_from_compiled.t().dot(&beta_raw);
3379        let theta = solve_psd_system(&tt, &tb.insert_axis(Axis(1)))
3380            .expect("normal-equation solve")
3381            .column(0)
3382            .to_owned();
3383        let lifted = map.lift_coefficients(&theta).expect("lift");
3384        let max_err = (&lifted - &beta_raw)
3385            .iter()
3386            .fold(0.0_f64, |a, &v| a.max(v.abs()));
3387        assert!(
3388            max_err < 1e-8,
3389            "lift round-trip error {max_err:e} (full-rank reduction must be exactly invertible)"
3390        );
3391    }
3392
3393    /// Design reparameterisation exactness: the compiled design predicts
3394    /// identically to the raw design on every lifted coefficient, i.e.
3395    /// `X_compiled · θ == X_raw · (T θ)`. This is the contract that lets a
3396    /// family fit in reduced coordinates and still produce raw-design
3397    /// predictions.
3398    #[test]
3399    fn compiled_map_reduce_design_matches_lifted_raw_predictor() {
3400        let n = 23;
3401        let p_a = 3;
3402        let p_b = 3;
3403        let mut x = Array2::from_shape_fn((n, p_a + p_b), |(i, j)| {
3404            ((i as f64 + 1.0) * 0.41 + (j as f64 + 1.0) * 0.7).sin() + 0.05 * (i % 3) as f64
3405        });
3406        // Alias one B column onto an A column so the reduction is non-trivial.
3407        for i in 0..n {
3408            x[[i, p_a + 1]] = x[[i, 1]];
3409        }
3410        let w = Array1::from_shape_fn(n, |i| 0.6 + 0.4 * ((i as f64) * 0.25).cos().abs());
3411        let (gh, gs) = k1_grams(&x, &w);
3412        let raw_ranges = vec![0..p_a, p_a..(p_a + p_b)];
3413        let map = compile_from_raw_grams(
3414            &gh,
3415            &gs,
3416            &raw_ranges,
3417            &[BlockOrder::Marginal, BlockOrder::Logslope],
3418        )
3419        .expect("compile");
3420        let x_compiled = map.reduce_design(&x).expect("reduce_design");
3421        assert_eq!(x_compiled.ncols(), map.p_compiled());
3422        let theta = Array1::from_shape_fn(map.p_compiled(), |j| 0.3 * (j as f64) - 0.5);
3423        let pred_compiled = x_compiled.dot(&theta);
3424        let beta_raw = map.lift_coefficients(&theta).expect("lift");
3425        let pred_raw = x.dot(&beta_raw);
3426        let max_err = (&pred_compiled - &pred_raw)
3427            .iter()
3428            .fold(0.0_f64, |a, &v| a.max(v.abs()));
3429        assert!(
3430            max_err < 1e-9,
3431            "compiled-design predictor diverges from lifted raw predictor: {max_err:e}"
3432        );
3433    }
3434
3435    /// Penalty-energy preservation: the reduced penalty `Tᵀ Ŝ_b T` reproduces
3436    /// the raw penalty energy `βᵀ Ŝ_b β` on every lifted point `β = T θ`. This
3437    /// is the exactness contract the lift map must satisfy for REML/inference
3438    /// to be invariant to the quotient reparameterisation.
3439    #[test]
3440    fn reduce_penalties_with_map_preserves_energy_on_lift() {
3441        let n = 19;
3442        let p_a = 3;
3443        let p_b = 2;
3444        // Make block B partly aliased with A so the reduction actually drops a
3445        // column — the penalty reduction must still preserve energy on the
3446        // surviving compiled directions.
3447        let mut x = Array2::from_shape_fn((n, p_a + p_b), |(i, j)| {
3448            ((i as f64 + 1.0) * 0.29 + (j as f64 + 1.0) * 0.9).cos()
3449        });
3450        // Column (p_a+0) := column 0 (exact alias) ⇒ B loses one direction.
3451        for i in 0..n {
3452            x[[i, p_a]] = x[[i, 0]];
3453        }
3454        let w = Array1::from_shape_fn(n, |i| 0.7 + 0.3 * ((i as f64) * 0.2).sin().abs());
3455        let (gh, gs) = k1_grams(&x, &w);
3456        let raw_ranges = vec![0..p_a, p_a..(p_a + p_b)];
3457        let map = compile_from_raw_grams(
3458            &gh,
3459            &gs,
3460            &raw_ranges,
3461            &[BlockOrder::Marginal, BlockOrder::Logslope],
3462        )
3463        .expect("compile with alias");
3464        assert!(
3465            map.p_compiled() < p_a + p_b,
3466            "expected at least one absorbed column, got p_compiled={}",
3467            map.p_compiled()
3468        );
3469        // A simple per-block raw penalty: ridge on each block.
3470        let s_a = Array2::<f64>::eye(p_a);
3471        let s_b = Array2::<f64>::eye(p_b);
3472        let reduced = reduce_penalties_with_map(&map, &[Some(s_a.clone()), Some(s_b.clone())])
3473            .expect("reduce penalties");
3474        // For random compiled θ, raw β = T θ. Raw energy for block b is
3475        // β[range_b]ᵀ S_b β[range_b]; reduced energy is θᵀ S_reduced_b θ.
3476        let theta = Array1::from_shape_fn(map.p_compiled(), |j| {
3477            0.6 * (j as f64) - 0.3 + 0.05 * (j % 2) as f64
3478        });
3479        let beta = map.lift_coefficients(&theta).expect("lift");
3480        for (block_idx, s_raw) in [(0usize, &s_a), (1usize, &s_b)] {
3481            let range = &map.raw_block_ranges[block_idx];
3482            let beta_b = beta.slice(s![range.start..range.end]).to_owned();
3483            let raw_energy = beta_b.dot(&s_raw.dot(&beta_b));
3484            let s_reduced = reduced[block_idx]
3485                .as_ref()
3486                .expect("reduced penalty present");
3487            let reduced_energy = theta.dot(&s_reduced.dot(&theta));
3488            assert!(
3489                (raw_energy - reduced_energy).abs() < 1e-8 * raw_energy.abs().max(1.0),
3490                "block {block_idx} energy mismatch: raw={raw_energy:e} reduced={reduced_energy:e}"
3491            );
3492        }
3493    }
3494}