Skip to main content

gam_identifiability/families/
compiler.rs

1//! Family-agnostic identifiability compiler.
2//!
3//! Single source of truth for cross-block W-metric residualisation across
4//! every blockwise family (BMS, SMGS, …). Row-Jacobian compiler that
5//! orthogonalises parameter blocks in the *row primary-state* metric `H_i`. Each block
6//! exposes a [`RowJacobianOperator`] that maps a coefficient perturbation
7//! `δβ ∈ R^p` to its contribution to the per-row primary state
8//! `u_i ∈ R^K`. The compiler walks the supplied ordering left-to-right,
9//! solves the weighted Gram system against the cumulative anchor, and
10//! emits a [`CompiledBlock`] per stage. A post-walk column-pivoted QR
11//! audit on the joint primary-state design deterministically drops
12//! trailing pivots from the latest block when joint rank is lost.
13
14use std::ops::Range;
15use std::sync::Arc;
16
17use ndarray::{Array1, Array2, Array3, Axis, s};
18
19use faer::Side;
20use gam_linalg::faer_ndarray::{
21    FaerEigh, default_rrqr_rank_alpha, fast_ab, fast_ata, fast_atb, fast_xt_diag_y,
22    rrqr_with_permutation,
23};
24
/// Slack factor (multiples of machine ε) for the rank-revealing eigenvalue
/// threshold used when pseudo-inverting a Gram matrix or selecting the
/// positive eigenspace of a residual Gram. The retain threshold is
/// `scale · RANK_REVEAL_EPS_SLACK · size · ε`, where `scale` is the dominant
/// eigenvalue and `size` is the matrix size (accounting for the worst-case
/// roundoff accumulation in the `O(size)` inner products forming each Gram
/// entry). 64× keeps numerically-zero directions out of the kept subspace
/// while preserving every genuinely identified direction at large-scale
/// conditioning.
///
/// The factor is dimensionless. Because it multiplies the threshold, a larger
/// value classifies more directions as numerically zero.
const RANK_REVEAL_EPS_SLACK: f64 = 64.0;
34
35/// Maps a coefficient perturbation `δβ ∈ R^p` for one parameter block into
36/// its contribution to the per-row primary state `u_i ∈ R^K`.
37///
38/// For affine blocks (everything in this compiler), `J_i = ∂u_i/∂β_block` is
39/// independent of `β` and equals the transposed row of the block's effective
40/// design matrix lifted into `R^K`.
41pub trait RowJacobianOperator: Send + Sync {
42    /// Dimension of the row primary state (survival: 4, Bernoulli: 1).
43    fn k(&self) -> usize;
44
45    /// Number of coefficients in this block (= width of `J_i`).
46    fn ncols(&self) -> usize;
47
48    /// Number of training rows.
49    fn nrows(&self) -> usize;
50
51    /// Apply the row Jacobian: writes `J_i · δβ ∈ R^K` for `row` into `out`.
52    fn apply_row(&self, row: usize, delta_beta: &[f64], out: &mut [f64]);
53
54    /// Materialise the full operator as an `(n_rows × ncols × K)` tensor.
55    fn evaluate_full(&self) -> Array3<f64>;
56
57    /// Build the sqrt(H)-scaled design `W = stack_i sqrt(H_i) · J_i`, flattened
58    /// channel-major to `(n_rows·K × ncols)`.
59    ///
60    /// This is the representation the identifiability *compiler*
61    /// ([`compile_with_dual_metric`]) actually consumes — it residualises and
62    /// eigendecomposes Grams of `W`, and never indexes the per-row `(n, p, K)`
63    /// tensor element-wise. Requesting the scaled design directly lets an
64    /// operator with a structured / streaming form supply it without
65    /// materialising and cloning the whole `O(n·p·K)` tensor; the default
66    /// implementation routes through [`evaluate_full`] so existing operators
67    /// remain correct unchanged. (#738: a capability is not a representation —
68    /// the compiler asks for the scaled design it needs, not the dense tensor.)
69    ///
70    /// [`evaluate_full`]: RowJacobianOperator::evaluate_full
71    /// [`compile_with_dual_metric`]: crate::families::compiler::compile_with_dual_metric
72    fn scaled_design_by_sqrt_h(&self, h_full: &Array3<f64>) -> Array2<f64> {
73        scale_block_by_sqrt_h(&self.evaluate_full(), h_full)
74    }
75
76    /// Write the channel-flattened column `col` — the `(n_rows · K)` vector
77    /// whose entry `i·K + ch` is `J[i, col, ch]` — into `out`.
78    ///
79    /// This is the representation the identifiability *audit* actually consumes
80    /// (per-column leverage statistics and pairwise overlaps), as opposed to the
81    /// dense `(n, p, K)` tensor. Requesting a column directly lets an operator
82    /// that has a structured / streaming form supply it without materialising
83    /// and cloning the whole `O(n·p·K)` tensor on every audit pass; the default
84    /// implementation routes through [`evaluate_full`] so existing operators
85    /// remain correct unchanged. (#738: a capability is not a representation —
86    /// the audit asks for the column view it needs, not the tensor.)
87    ///
88    /// [`evaluate_full`]: RowJacobianOperator::evaluate_full
89    fn channel_flattened_column(&self, col: usize, out: &mut [f64]) {
90        let k = self.k();
91        let n = self.nrows();
92        assert!(
93            col < self.ncols(),
94            "channel_flattened_column col {col} out of range {}",
95            self.ncols()
96        );
97        assert_eq!(
98            out.len(),
99            n * k,
100            "channel_flattened_column out length {} != n*k = {}*{}",
101            out.len(),
102            n,
103            k
104        );
105        let full = self.evaluate_full();
106        for i in 0..n {
107            for ch in 0..k {
108                out[i * k + ch] = full[[i, col, ch]];
109            }
110        }
111    }
112
113    /// Write channel-flattened rows for `rows` into `out`.
114    ///
115    /// `out` has shape `(rows.len() * K, ncols)`, with row
116    /// `local_row * K + channel` holding `J[row, :, channel]`. The default
117    /// implementation materialises the full tensor for legacy operators; large
118    /// construction-time adapters override this to stream row chunks.
119    fn channel_flattened_rows(&self, rows: Range<usize>, out: &mut Array2<f64>) {
120        let n = self.nrows();
121        let start = rows.start.min(n);
122        let end = rows.end.min(n);
123        let chunk = end - start;
124        let k = self.k();
125        let p = self.ncols();
126        assert_eq!(out.shape(), &[chunk * k, p]);
127        let full = self.evaluate_full();
128        for local_i in 0..chunk {
129            let row = start + local_i;
130            for ch in 0..k {
131                for col in 0..p {
132                    out[[local_i * k + ch, col]] = full[[row, col, ch]];
133                }
134            }
135        }
136    }
137}
138
/// Per-row `K × K` PSD Hessian of `−log L_i(u_i)` evaluated at a pilot β.
///
/// Serves as the row metric in which the compiler residualises blocks. The
/// compiler rejects a metric whose `k()` or `nrows()` disagrees with the
/// operators it is paired with.
pub trait RowHessian: Send + Sync {
    /// Dimension `K` of the row primary state; every row block is `K × K`.
    fn k(&self) -> usize;
    /// Number of training rows the metric covers.
    fn nrows(&self) -> usize;
    /// Fill the `K × K` block at `row` into `out` (row-major), so `out` is
    /// expected to hold `K · K` entries.
    fn fill_row(&self, row: usize, out: &mut [f64]);
    /// Materialise full `(n_rows × K × K)` tensor.
    fn evaluate_full(&self) -> Array3<f64>;
}
148
/// Row metric that is the identity on every row: `K^S_i = I_K`.
///
/// This is the default structural metric for [`compile_with_dual_metric`].
/// Keeping the "which directions are real structural columns" decision
/// separate from a possibly rank-deficient pilot curvature `H` stops the
/// compiler from wrongly dropping columns whose curvature happens to vanish at
/// the pilot β but which would be kept at the optimum.
pub struct IdentityRowHessian {
    n: usize,
    k: usize,
}

impl IdentityRowHessian {
    /// Build an identity row metric covering `n` rows, each with a
    /// `K`-channel row primary state (`k` = `K`).
    pub fn new(n: usize, k: usize) -> Self {
        IdentityRowHessian { n, k }
    }
}
167
168impl RowHessian for IdentityRowHessian {
169    fn k(&self) -> usize {
170        self.k
171    }
172    fn nrows(&self) -> usize {
173        self.n
174    }
175    fn fill_row(&self, row: usize, out: &mut [f64]) {
176        assert!(
177            row < self.n,
178            "IdentityRowHessian::fill_row row {row} out of range {n}",
179            n = self.n
180        );
181        assert_eq!(out.len(), self.k * self.k);
182        for i in 0..self.k {
183            for j in 0..self.k {
184                out[i * self.k + j] = if i == j { 1.0 } else { 0.0 };
185            }
186        }
187    }
188    fn evaluate_full(&self) -> Array3<f64> {
189        let mut out = Array3::<f64>::zeros((self.n, self.k, self.k));
190        for i in 0..self.n {
191            for c in 0..self.k {
192                out[[i, c, c]] = 1.0;
193            }
194        }
195        out
196    }
197}
198
/// One compiled block: reparam matrix `V` (`t_lw`) and the optional anchor
/// correction matrix `M` that downstream blocks consume as a first-class
/// anchor.
///
/// As produced by [`compile_with_dual_metric`], `anchor_correction` and
/// `r_lw` always hold the same matrix (one is a clone of the other).
pub struct CompiledBlock {
    /// Orthogonal-complement reparam matrix `V ∈ R^{p × p'}` (right-selector),
    /// composed as `V = D · T_inner` from the structural selector `D` and the
    /// curvature rotation/selector `T_inner`. A block that compiles to zero
    /// width has `p' = 0`.
    pub t_lw: Array2<f64>,
    /// Residualised anchor correction `M ∈ R^{d_raw × p'}` at the compiled
    /// width, expressed in *raw* cumulative-anchor-column coordinates: `d_raw`
    /// is the sum of the raw column counts of every prior block, NOT the
    /// (possibly smaller) count of kept anchor directions. The predict-time
    /// row contribution is `(C(x)·V − A_raw(x)·M)·β`, where `A_raw(x)` is the
    /// raw anchor evaluation.
    ///
    /// `None` when the curvature residualisation yields no correction — per
    /// the original contract, the first block in the ordering (no anchor yet).
    /// Zero-width and fully absorbed blocks carry `Some` of a `d_raw × 0`
    /// matrix rather than `None`.
    /// Synonymous with `r_lw`.
    pub anchor_correction: Option<Array2<f64>>,
    /// Residualised reparam — what the residualised row evaluator uses to
    /// subtract the anchor portion. Equal to `anchor_correction`, including
    /// its `None` / `Some` state.
    pub r_lw: Option<Array2<f64>>,
}
218
/// Output of [`compile`] / [`compile_with_dual_metric`]: one
/// [`CompiledBlock`] per input block plus the joint pre-fit audit verdict.
pub struct CompiledBlocks {
    /// Compiled blocks, in the same order as the input operators.
    pub blocks: Vec<CompiledBlock>,
    /// Joint rank after the post-walk column-pivoted QR audit, computed as the
    /// sum of the compiled widths (`t_lw.ncols()`) of all blocks.
    pub joint_rank: usize,
    /// Every column the compiler removed, as `(block_idx, local_col)`.
    ///
    /// In-walk demotions come first, followed by the joint-audit drops:
    /// - a block fully absorbed in the structural pass contributes one entry
    ///   per raw column (`local_col` in `0..raw_width`);
    /// - a block fully absorbed in the curvature pass contributes one entry per
    ///   structurally kept direction (`local_col` in `0..structural_width`,
    ///   which indexes kept directions rather than raw columns);
    /// - the joint audit drops only from the latest block.
    pub dropped: Vec<(usize, usize)>,
}
229
/// Structural relationship between one raw penalized block and the higher-priority
/// anchor already accepted by the identifiability compiler.
///
/// The three variants partition the outcome by how much of the block's raw
/// span survived: all of it, part of it, or none of it.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum PenalizedDirectionAnnotationKind {
    /// The block kept its full realized-design span; none of its penalized
    /// directions were already represented by a higher-priority block.
    Independent,
    /// Some, but not all, raw directions were absorbed by the higher-priority
    /// anchor. The kept width is the independent residual span.
    PartiallyAbsorbedByHigherPriority,
    /// The entire block was the same realized-design direction/span as the
    /// higher-priority anchor and therefore contributes no independent
    /// coefficients or smoothing parameter directions.
    FullyAbsorbedByHigherPriority,
}
245
/// Per-block structural annotation emitted by [`orthogonalize_design_blocks`].
///
/// NOTE(review): the emitter is outside the visible part of this file; the
/// field descriptions below follow the field names and the variant docs on
/// [`PenalizedDirectionAnnotationKind`] and should be confirmed against it.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct PenalizedDirectionAnnotation {
    /// Index of the annotated block — presumably its position in the
    /// residualisation order; TODO confirm.
    pub block_idx: usize,
    /// Column count of the raw block before residualisation (presumed).
    pub raw_width: usize,
    /// Width of the independent residual span that was kept (presumed).
    pub kept_width: usize,
    /// Number of directions absorbed by the higher-priority anchor —
    /// presumably `raw_width − kept_width`; TODO confirm.
    pub absorbed_width: usize,
    /// Classification of the block relative to the higher-priority anchor.
    pub kind: PenalizedDirectionAnnotationKind,
}
255
/// Failure modes of the identifiability compiler ([`compile`] and
/// [`compile_with_dual_metric`]).
#[derive(Debug)]
pub enum CompilerError {
    /// The operators, row Hessians, or ordering disagree on a dimension.
    DimensionMismatch(String),
    /// A block collapsed to zero residual span, i.e. it is fully aliased by
    /// the cumulative anchor in the row metric.
    FullyAliased { block_idx: usize, reason: String },
    /// A linear-algebra primitive (Gram solve, eigendecomposition, QR) failed.
    LinalgFailure(String),
}

impl std::fmt::Display for CompilerError {
    /// Renders a one-line, human-readable description of the error.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            Self::LinalgFailure(msg) => write!(f, "linalg failure: {msg}"),
            Self::DimensionMismatch(msg) => write!(f, "dimension mismatch: {msg}"),
            Self::FullyAliased { block_idx, reason } => {
                write!(f, "block {block_idx} fully aliased: {reason}")
            }
        }
    }
}

impl std::error::Error for CompilerError {}
281
/// Semantic block label. The compiler does not need to know what the block
/// *is*, only its relative order — but downstream consumers (per-family
/// install paths) tag the input operators with these labels so that the
/// compiled output can be routed back to the right runtime slot.
///
/// [`compile_with_dual_metric`] treats the labels as metadata: it only checks
/// that the label slice has the same length as the operator slice, and the
/// residualisation order is the position in that slice, not the variant.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum BlockOrder {
    Time,
    Marginal,
    Logslope,
    ScoreWarp,
    LinkDev,
}
294
295/// Compile a sequence of row-Jacobian operators against a shared row
296/// Hessian. Walks `ordering` left-to-right, residualising each block
297/// against the cumulative anchor in the `H_i`-weighted row metric, then
298/// performs a joint-design audit and emits one [`CompiledBlock`] per
299/// input (in the same order as `operators`).
300///
301/// `ordering` parallels `operators` and supplies the semantic label for
302/// each block. The compiler treats `ordering[i]` purely as metadata —
303/// the *position* `i` is the residualisation order.
304pub fn compile(
305    operators: &[Arc<dyn RowJacobianOperator>],
306    row_hess: &dyn RowHessian,
307    ordering: &[BlockOrder],
308) -> Result<CompiledBlocks, CompilerError> {
309    // Default structural metric is the per-row identity `K^S_i = I_K`.
310    // A pilot-curvature `H` can collapse a direction (zero eigenvalue) at
311    // a bad β even though the optimum keeps that direction; routing the
312    // rank decision through the structural metric and reserving `H` for
313    // *within-kept-subspace* curvature handling prevents that mis-drop.
314    let n = row_hess.nrows();
315    let k = row_hess.k();
316    let id_struct = IdentityRowHessian::new(n, k);
317    compile_with_dual_metric(operators, row_hess, &id_struct, ordering)
318}
319
320/// Compile a sequence of row-Jacobian operators using *separate* metrics
321/// for structural rank decisions and curvature-aware orthogonalisation.
322///
323/// - `row_hess` is the curvature row metric `K^H_i` (a PSD-clamped Hessian
324///   of `−log L_i(u_i)` at a pilot β).
325/// - `row_structural` is the structural row metric `K^S_i` — typically an
326///   [`IdentityRowHessian`] — used only to decide which columns survive
327///   block-against-block residualisation. A direction that the curvature
328///   `K^H` happens to see as zero at a bad pilot β is *not* dropped here
329///   as long as it is structurally non-degenerate.
330///
331/// Per-block algorithm (left-to-right walk over `ordering`):
332///
333/// 1. Residualise the block in the structural metric against the
334///    cumulative structural anchor; eigendecompose the structural residual
335///    Gram and drop only structural-zero eigenvalues → kept basis `D`
336///    (raw-block selector).
337/// 2. Residualise `W^H_b · D` in the curvature metric against the
338///    cumulative curvature anchor → curvature anchor correction
339///    `M^H_inner` and residual `R^H`.
340/// 3. Eigendecompose the curvature Gram of `R^H` and drop curvature-zero
341///    directions (a *within*-structurally-kept curvature alias is a true
342///    redundancy) → rotation/selector `T_inner`.
343/// 4. Compose: `V = D · T_inner`; compiled anchor correction is
344///    `M^H_inner · T_inner` so the predict-time row contribution stays
345///    `(C(x) · V − A(x) · anchor_correction) · β`.
346///
347/// When `row_structural` and `row_hess` represent the same metric (e.g.
348/// `compile()` with an identity row Hessian on both sides), the two
349/// passes collapse to the single-metric loop.
350pub fn compile_with_dual_metric(
351    operators: &[Arc<dyn RowJacobianOperator>],
352    row_hess: &dyn RowHessian,
353    row_structural: &dyn RowHessian,
354    ordering: &[BlockOrder],
355) -> Result<CompiledBlocks, CompilerError> {
356    if operators.len() != ordering.len() {
357        return Err(CompilerError::DimensionMismatch(format!(
358            "operators ({}) and ordering ({}) length mismatch",
359            operators.len(),
360            ordering.len()
361        )));
362    }
363    if operators.is_empty() {
364        return Ok(CompiledBlocks {
365            blocks: Vec::new(),
366            joint_rank: 0,
367            dropped: Vec::new(),
368        });
369    }
370
371    let k = row_hess.k();
372    let n = row_hess.nrows();
373    if row_structural.k() != k {
374        return Err(CompilerError::DimensionMismatch(format!(
375            "structural row metric has K={} but curvature row Hessian has K={k}",
376            row_structural.k()
377        )));
378    }
379    if row_structural.nrows() != n {
380        return Err(CompilerError::DimensionMismatch(format!(
381            "structural row metric has nrows={} but curvature row Hessian has nrows={n}",
382            row_structural.nrows()
383        )));
384    }
385    for (idx, op) in operators.iter().enumerate() {
386        if op.k() != k {
387            return Err(CompilerError::DimensionMismatch(format!(
388                "operator {idx} has K={} but row Hessian has K={k}",
389                op.k()
390            )));
391        }
392        if op.nrows() != n {
393            return Err(CompilerError::DimensionMismatch(format!(
394                "operator {idx} has nrows={} but row Hessian has nrows={n}",
395                op.nrows()
396            )));
397        }
398    }
399
400    // Materialise once per metric. K is tiny (1 or 4) so the K×K
401    // symmetric-sqrt cost is dominated by the joint-design audit below.
402    let h_full = row_hess.evaluate_full();
403    let s_full = row_structural.evaluate_full();
404
405    // Request each block's sqrt(H)-scaled design directly through the intent
406    // accessor — the `(n·K, p)` representation the compiler actually consumes —
407    // instead of first materialising the dense `(n, p, K)` per-row tensor and
408    // scaling it. The default `scaled_design_by_sqrt_h` impl still routes
409    // through `evaluate_full()`, so operators without a structured form stay
410    // correct unchanged; a streaming operator (e.g. `BlockJacobianAsRowOp`)
411    // overrides it to scale straight out of its stored layout, dropping the
412    // `O(n·p·K)` tensor clone that `evaluate_full()` performs per block at
413    // large-scale `n`. (#738: a capability is not a representation — the compiler
414    // asks for the scaled design it needs, never the dense tensor.)
415    let scaled_h: Vec<Array2<f64>> = operators
416        .iter()
417        .map(|op| op.scaled_design_by_sqrt_h(&h_full))
418        .collect();
419    let scaled_s: Vec<Array2<f64>> = operators
420        .iter()
421        .map(|op| op.scaled_design_by_sqrt_h(&s_full))
422        .collect();
423
424    let mut compiled: Vec<CompiledBlock> = Vec::with_capacity(operators.len());
425    // Demotions that happen *inside* the per-block walk (a structurally-kept
426    // block losing all its directions to a higher-priority anchor in the
427    // structural or curvature pass) are recorded here, one entry per demoted
428    // raw column, in the same `(block_idx, local_col)` convention that
429    // `audit_and_drop_trailing_pivots` emits at the joint-audit step. Without
430    // this, a zero-width demotion vanished from `dropped`, breaking the
431    // `kept_width + dropped_count == structural_pre_audit_width` accounting.
432    let mut walk_demotions: Vec<(usize, usize)> = Vec::new();
433    let mut anchor_h: Array2<f64> = Array2::zeros((n * k, 0));
434    let mut anchor_s: Array2<f64> = Array2::zeros((n * k, 0));
435    // Cumulative *raw* (un-residualised) curvature-scaled anchor: the
436    // horizontal stack of `sqrt(H)·J_b` for every block already walked,
437    // keeping one column per raw block column. Where `anchor_h` carries the
438    // residualised, kept-direction anchor (its width shrinks whenever a block
439    // sheds an aliased column), this matrix keeps the full raw column count so
440    // the emitted `anchor_correction` can be expressed in raw-anchor-column
441    // coordinates — exactly the basis the predict-time subtraction
442    // `A_raw(x)·M` evaluates against. See the `M_raw` derivation below.
443    let mut raw_anchor_h: Array2<f64> = Array2::zeros((n * k, 0));
444
445    for idx in 0..operators.len() {
446        let w_h = &scaled_h[idx];
447        let w_s = &scaled_s[idx];
448        let p_b = w_h.ncols();
449
450        // A zero-width block owns no raw columns, so it cannot alias against any
451        // anchor and is trivially identifiable. Emit an empty compiled block and
452        // skip the structural/curvature passes: their residual Grams are 0×0 and
453        // yield no positive eigenspace, which the `anchor_h.ncols() == 0`
454        // first-block guards below would otherwise mis-report as `FullyAliased`
455        // even though there is nothing to alias. This mirrors the empty block a
456        // fully-absorbed later block compiles to, with no demotions to record
457        // (there are no columns) and no change to the running anchors.
458        if p_b == 0 {
459            compiled.push(CompiledBlock {
460                t_lw: Array2::<f64>::zeros((0, 0)),
461                anchor_correction: Some(Array2::<f64>::zeros((raw_anchor_h.ncols(), 0))),
462                r_lw: Some(Array2::<f64>::zeros((raw_anchor_h.ncols(), 0))),
463            });
464            continue;
465        }
466
467        // Pass 1 (structural): residualise W^S_b against cumulative
468        // structural anchor; eigendecompose the structural residual Gram
469        // and keep only directions with non-zero structural mass → D
470        // (raw-block selector).
471        // Only the structural residual is consumed downstream; the
472        // structural-metric correction M^S is intentionally discarded —
473        // predict-time subtraction uses the curvature metric correction
474        // (`M^H_inner` below), not the structural one.
475        let (residual_s, _) = residualise_in_metric(&anchor_s, w_s)?;
476        let g_s = fast_atb(&residual_s, &residual_s);
477        // Scale reference for the kept-eigenspace tolerance: the *original*
478        // (pre-residualisation) structural block Gram trace. A fully-absorbed
479        // block's residual collapses to ~ε² noise; anchoring tau to that would
480        // keep the noise directions and wrongly treat the block as
481        // structurally independent. The original-block trace is invariant to
482        // absorption, so a near-zero residual is rejected as fully absorbed.
483        let g_s_bb = fast_atb(w_s, w_s);
484        let g_s_trace: f64 = (0..p_b).map(|i| g_s_bb[[i, i]].max(0.0)).sum();
485        let d = keep_positive_eigenspace(&g_s, n, k, g_s_trace)?;
486        if d.ncols() == 0 {
487            if anchor_h.ncols() == 0 {
488                return Err(CompilerError::FullyAliased {
489                    block_idx: idx,
490                    reason: format!(
491                        "structural residual Gram has no positive eigenspace (block of width {p_b} has zero structural span before any anchor exists)"
492                    ),
493                });
494            }
495            compiled.push(CompiledBlock {
496                t_lw: Array2::<f64>::zeros((p_b, 0)),
497                anchor_correction: Some(Array2::<f64>::zeros((raw_anchor_h.ncols(), 0))),
498                r_lw: Some(Array2::<f64>::zeros((raw_anchor_h.ncols(), 0))),
499            });
500            // The structural pass fully absorbed all `p_b` raw columns into the
501            // higher-priority anchor: record each as a drop so the per-block
502            // width accounting (kept + dropped == raw width) stays exact.
503            for c in 0..p_b {
504                walk_demotions.push((idx, c));
505            }
506            raw_anchor_h = concat_cols(&raw_anchor_h, w_h);
507            continue;
508        }
509
510        // Pass 2 (curvature): form W^H_b · D and residualise against the
511        // cumulative curvature anchor. Eigendecompose the curvature
512        // residual Gram and drop curvature-zero directions inside D →
513        // T_inner. A direction kept by the structural pass but degenerate
514        // here is genuinely curvature-redundant *within* the
515        // structurally-kept basis, so dropping it is correct.
516        let w_h_d = fast_ab(w_h, &d);
517        let (residual_h, m_h_inner_opt) = residualise_in_metric(&anchor_h, &w_h_d)?;
518        let g_h = fast_atb(&residual_h, &residual_h);
519        let p_d = d.ncols();
520        // Scale reference: the *unresidualised* curvature block Gram trace of
521        // `W^H_b · D` (the same convention the closed-form `compile_from_raw_grams`
522        // path uses with `d_t_kh_d`). Anchoring to the residual trace would
523        // collapse to ~ε² when the block is fully curvature-absorbed and keep
524        // its noise directions.
525        let g_h_dd = fast_atb(&w_h_d, &w_h_d);
526        let g_h_trace: f64 = (0..p_d).map(|i| g_h_dd[[i, i]].max(0.0)).sum();
527        let t_inner = keep_positive_eigenspace(&g_h, n, k, g_h_trace)?;
528        if t_inner.ncols() == 0 {
529            if anchor_h.ncols() == 0 {
530                return Err(CompilerError::FullyAliased {
531                    block_idx: idx,
532                    reason: format!(
533                        "curvature residual Gram has no positive eigenspace within structurally-kept basis (block of width {p_b}, structural-kept {p_d}) before any anchor exists"
534                    ),
535                });
536            }
537            compiled.push(CompiledBlock {
538                t_lw: Array2::<f64>::zeros((p_b, 0)),
539                anchor_correction: Some(Array2::<f64>::zeros((raw_anchor_h.ncols(), 0))),
540                r_lw: Some(Array2::<f64>::zeros((raw_anchor_h.ncols(), 0))),
541            });
542            // The structural pass kept `p_d` directions, but the curvature pass
543            // absorbed all of them into the higher-priority anchor. Record each
544            // structurally-kept-but-curvature-demoted direction as a drop so the
545            // pre-audit structural width is fully accounted for.
546            for c in 0..p_d {
547                walk_demotions.push((idx, c));
548            }
549            raw_anchor_h = concat_cols(&raw_anchor_h, w_h);
550            continue;
551        }
552
553        // Compose V = D · T_inner (raw-block → kept).
554        let v = fast_ab(&d, &t_inner);
555
556        // `m_h_inner_opt` was residualised against `anchor_h` as it stands
557        // *here*, i.e. the cumulative kept-direction anchor of all PRIOR
558        // blocks. Snapshot that pre-append anchor and its raw counterpart
559        // before this block's residual columns are appended below; the
560        // change-of-basis for this block's correction must be expressed
561        // against the prior-block anchor that `m` is indexed against, not the
562        // post-append anchor that already carries this block's own columns.
563        let prior_anchor_h = anchor_h.clone();
564        let prior_raw_anchor_h = raw_anchor_h.clone();
565
566        // Append residual-V columns to both cumulative anchors so future
567        // blocks see the structurally-orthogonal and curvature-orthogonal
568        // residual designs of this block, never the raw scaled block.
569        let residual_h_t = fast_ab(&residual_h, &t_inner);
570        anchor_h = concat_cols(&anchor_h, &residual_h_t);
571        // The structural anchor needs the structural-residual restricted
572        // to the kept directions: residual_s · v gives (W^S_b − A^S · M^S)·V.
573        let residual_s_v = fast_ab(&residual_s, &v);
574        anchor_s = concat_cols(&anchor_s, &residual_s_v);
575
576        // Compiled anchor correction lives in the curvature metric — the
577        // predict-time row contribution is `(C(x) · V − A(x) · M)·β`, where
578        // the subtraction makes residuals H-orthogonal at training and `A(x)`
579        // is the *raw* anchor evaluation (one column per raw anchor column).
580        //
581        // `m_h_inner_opt · t_inner` (call it `M_kept`) lives in the
582        // *kept-direction* anchor coordinates of the PRIOR-block anchor
583        // `prior_anchor_h` (the value `anchor_h` held when `m` was produced at
584        // `residualise_in_metric` above, before this block's residual columns
585        // were appended). Its row count is `prior_anchor_h.ncols()`, which
586        // equals the prior-block raw anchor width only when no upstream block
587        // shed an aliased column. The predict path multiplies by the raw
588        // anchor matrix `A_raw` (one column per raw anchor column of the prior
589        // blocks), so we must re-express `M_kept` in raw-anchor-column
590        // coordinates.
591        //
592        // `prior_anchor_h` and `prior_raw_anchor_h` span the same column space
593        // in the curvature metric (the residualisation/rotation only drops
594        // directions that lie inside that span), so there is an exact `Z` with
595        // `prior_raw_anchor_h · Z = prior_anchor_h`. Then
596        //   `prior_anchor_h · M_kept = prior_raw_anchor_h · (Z · M_kept)`,
597        // and the raw-coordinate correction is `M_raw = Z · M_kept`, with row
598        // count `prior_raw_anchor_h.ncols()` = the sum of prior raw anchor
599        // block widths. `Z = (Aᵀ A)⁺ Aᵀ prior_anchor_h` (with
600        // `A = prior_raw_anchor_h`) is the metric-exact least-squares change of
601        // basis (`solve_psd_system`).
602        let m_compiled = match m_h_inner_opt.as_ref() {
603            Some(m) => {
604                let m_kept = fast_ab(m, &t_inner);
605                if m_kept.nrows() != prior_anchor_h.ncols() {
606                    return Err(CompilerError::DimensionMismatch(format!(
607                        "anchor correction must be indexed by prior-block kept anchor directions: \
608                         m_kept has {} rows but prior_anchor_h has {} columns",
609                        m_kept.nrows(),
610                        prior_anchor_h.ncols()
611                    )));
612                }
613                let g_raw = fast_atb(&prior_raw_anchor_h, &prior_raw_anchor_h);
614                let z_rhs = fast_atb(&prior_raw_anchor_h, &prior_anchor_h);
615                let z = solve_psd_system(&g_raw, &z_rhs)?;
616                Some(fast_ab(&z, &m_kept))
617            }
618            None => None,
619        };
620        compiled.push(CompiledBlock {
621            t_lw: v,
622            anchor_correction: m_compiled.clone(),
623            r_lw: m_compiled,
624        });
625
626        // Append this block's raw curvature-scaled columns to the raw anchor
627        // accumulator so the *next* block's `M_raw` is expressed against the
628        // full raw column set of all blocks walked so far.
629        raw_anchor_h = concat_cols(&raw_anchor_h, w_h);
630    }
631
632    // Joint-design audit on the curvature-scaled cumulative anchor: the
633    // identifiability question the fit cares about is curvature-rank.
634    let audit_dropped = audit_and_drop_trailing_pivots(&anchor_h, &mut compiled)?;
635    // Combine in-walk demotions (structural / curvature full absorption of a
636    // block) with the joint-audit trailing-pivot drops so `dropped` accounts
637    // for *every* column the compiler removed, not just the joint-audit ones.
638    let mut dropped = walk_demotions;
639    dropped.extend(audit_dropped);
640    let joint_rank: usize = compiled.iter().map(|b| b.t_lw.ncols()).sum();
641
642    Ok(CompiledBlocks {
643        blocks: compiled,
644        joint_rank,
645        dropped,
646    })
647}
648
649/// Build `W_b = stack_i sqrt(H_i) · J_b,i` flattened to `(n*K, ncols)` from a
650/// materialised `(n, p, K)` tensor. Thin wrapper over
651/// [`scale_jacobian_by_sqrt_h_with`] that reads the tensor element-wise.
652fn scale_block_by_sqrt_h(jb: &Array3<f64>, h_full: &Array3<f64>) -> Array2<f64> {
653    let n = jb.shape()[0];
654    let p = jb.shape()[1];
655    let k = jb.shape()[2];
656    scale_jacobian_by_sqrt_h_with(n, p, k, h_full, |i, a, c| jb[[i, a, c]])
657}
658
659/// Build `W_b = stack_i sqrt(H_i) · J_b,i` flattened to `(n*K, ncols)` without
660/// ever requiring a materialised `(n, p, K)` tensor.
661///
662/// The Jacobian entries are pulled through the `jac` closure
663/// (`jac(i, a, c) = J_b,i[a, c]`), so a structured operator that stores its
664/// Jacobian in a compact / streaming form can supply the sqrt(H)-scaled design
665/// directly — the representation the compiler actually consumes — rather than
666/// being forced to clone a dense `(n, p, K)` tensor first. (#738: a capability
667/// is not a representation — the compiler asks for the scaled `(n·K, p)` design
668/// it needs, not the dense per-row tensor.)
669///
670/// `K` is tiny (1 or 4), so the per-row symmetric sqrt is negligible relative
671/// to the overall compile.
672pub fn scale_jacobian_by_sqrt_h_with(
673    n: usize,
674    p: usize,
675    k: usize,
676    h_full: &Array3<f64>,
677    jac: impl Fn(usize, usize, usize) -> f64,
678) -> Array2<f64> {
679    assert_eq!(h_full.shape(), &[n, k, k]);
680    let mut out = Array2::<f64>::zeros((n * k, p));
681    let mut sqrt_h = Array2::<f64>::zeros((k, k));
682    let mut scratch_jrow = Array2::<f64>::zeros((p, k));
683    for i in 0..n {
684        // Symmetric square root of H_i via eigendecomposition.
685        let h_i = h_full.index_axis(Axis(0), i).to_owned();
686        sqrt_h.fill(0.0);
687        symmetric_sqrt_into(&h_i, &mut sqrt_h);
688        // scratch_jrow[a, c] = J_b,i[a, c] (transpose-friendly layout for
689        // the GEMV below: we want (p × k) · (k,) = (p,) for each column of
690        // sqrt_h, but we batch by writing out[(i*k+c), a] = (sqrt_h · J_b,iᵀ)[c, a].
691        for a in 0..p {
692            for c in 0..k {
693                scratch_jrow[[a, c]] = jac(i, a, c);
694            }
695        }
696        for c in 0..k {
697            for a in 0..p {
698                let mut acc = 0.0;
699                for cp in 0..k {
700                    acc += sqrt_h[[c, cp]] * scratch_jrow[[a, cp]];
701                }
702                out[[i * k + c, a]] = acc;
703            }
704        }
705    }
706    out
707}
708
709/// Symmetric matrix square root via eigendecomposition with negative
710/// eigenvalues clamped to zero (PSD projection guard).
711pub(crate) fn symmetric_sqrt_into(m: &Array2<f64>, out: &mut Array2<f64>) {
712    let k = m.nrows();
713    assert_eq!(m.ncols(), k);
714    assert_eq!(out.shape(), &[k, k]);
715    if k == 1 {
716        out[[0, 0]] = m[[0, 0]].max(0.0).sqrt();
717        return;
718    }
719    let (evals, evecs) = match m.eigh(Side::Lower) {
720        Ok(pair) => pair,
721        Err(_) => {
722            // Fall back to clipped diagonal — extremely defensive for the
723            // K=4 row Hessian which is already PSD-clamped by the caller.
724            out.fill(0.0);
725            for i in 0..k {
726                out[[i, i]] = m[[i, i]].max(0.0).sqrt();
727            }
728            return;
729        }
730    };
731    // out = U · diag(sqrt(max(0, λ))) · Uᵀ
732    let mut scaled = evecs.clone();
733    for j in 0..k {
734        let s = evals[j].max(0.0).sqrt();
735        for i in 0..k {
736            scaled[[i, j]] *= s;
737        }
738    }
739    out.assign(&fast_atb(&evecs.t().to_owned(), &scaled.t().to_owned()));
740    // The above fast_atb computed (Uᵀ)ᵀ · (Uᵀ·diag(s)) = U · diag(s) · Uᵀ
741    // when the inputs are owned. To be safe and avoid layout surprises,
742    // re-do the small multiplication explicitly for K ≤ 4.
743    out.fill(0.0);
744    for i in 0..k {
745        for j in 0..k {
746            let mut acc = 0.0;
747            for l in 0..k {
748                acc += evecs[[i, l]] * evals[l].max(0.0).sqrt() * evecs[[j, l]];
749            }
750            out[[i, j]] = acc;
751        }
752    }
753}
754
755/// Solve `Aᵀ A · M = Aᵀ B` and return `(B − A·M, Some(M))`. With `A`
756/// having zero columns, returns `(B, None)` — the first block needs no
757/// anchor correction.
758fn residualise_in_metric(
759    a_scaled: &Array2<f64>,
760    b_scaled: &Array2<f64>,
761) -> Result<(Array2<f64>, Option<Array2<f64>>), CompilerError> {
762    let d = a_scaled.ncols();
763    if d == 0 {
764        return Ok((b_scaled.clone(), None));
765    }
766    let g_aa = fast_atb(a_scaled, a_scaled);
767    let g_ab = fast_atb(a_scaled, b_scaled);
768    let m = solve_psd_system(&g_aa, &g_ab)?;
769    let a_m = fast_ab(a_scaled, &m);
770    let residual = b_scaled - &a_m;
771    Ok((residual, Some(m)))
772}
773
774/// Solve a PSD linear system `G · M = R` for `M`. Tries the eigen-based
775/// pseudoinverse with a relative threshold and falls back to a damped
776/// solve if the spectrum is ill-conditioned beyond what the threshold
777/// can clean.
778fn solve_psd_system(g: &Array2<f64>, r: &Array2<f64>) -> Result<Array2<f64>, CompilerError> {
779    let n = g.nrows();
780    if n == 0 {
781        return Ok(Array2::zeros((0, r.ncols())));
782    }
783    let (evals, evecs) = g
784        .eigh(Side::Lower)
785        .map_err(|err| CompilerError::LinalgFailure(format!("Gram eigh failed: {err:?}")))?;
786    let lambda_max = evals.iter().cloned().fold(0.0_f64, f64::max).max(0.0);
787    let tol = lambda_max * RANK_REVEAL_EPS_SLACK * (n.max(1) as f64) * f64::EPSILON;
788    // M = U · diag(1/λ_kept) · Uᵀ · R
789    let u_t_r = fast_atb(&evecs, r);
790    let mut scaled = u_t_r.clone();
791    for i in 0..n {
792        let lam = evals[i];
793        let inv = if lam > tol { 1.0 / lam } else { 0.0 };
794        for j in 0..scaled.ncols() {
795            scaled[[i, j]] *= inv;
796        }
797    }
798    let m = fast_ab(&evecs, &scaled);
799    Ok(m)
800}
801
802/// Eigendecompose the residual Gram `G̃` and return `V` made of the
803/// eigenvectors whose eigenvalues exceed
804/// `τ = max(λ_max(G̃), tr(G_BB)) · RANK_REVEAL_EPS_SLACK · n · K · ε`.
805fn keep_positive_eigenspace(
806    g_tilde: &Array2<f64>,
807    n: usize,
808    k: usize,
809    g_bb_trace: f64,
810) -> Result<Array2<f64>, CompilerError> {
811    let p = g_tilde.nrows();
812    if p == 0 {
813        return Ok(Array2::zeros((0, 0)));
814    }
815    let (evals, evecs) = g_tilde.eigh(Side::Lower).map_err(|err| {
816        CompilerError::LinalgFailure(format!("residual Gram eigh failed: {err:?}"))
817    })?;
818    let lambda_max = evals.iter().cloned().fold(0.0_f64, f64::max).max(0.0);
819    let scale = lambda_max.max(g_bb_trace);
820    let nk = (n.saturating_mul(k)).max(p).max(1) as f64;
821    let tau = scale * RANK_REVEAL_EPS_SLACK * nk * f64::EPSILON;
822    // Collect kept column indices.
823    let mut kept: Vec<usize> = (0..p).filter(|&i| evals[i] > tau).collect();
824    // Sort kept indices by descending eigenvalue for a stable column order.
825    kept.sort_by(|&a, &b| {
826        evals[b]
827            .partial_cmp(&evals[a])
828            .unwrap_or(std::cmp::Ordering::Equal)
829    });
830    let mut v = Array2::<f64>::zeros((p, kept.len()));
831    for (out_col, &src_col) in kept.iter().enumerate() {
832        for row in 0..p {
833            v[[row, out_col]] = evecs[[row, src_col]];
834        }
835    }
836    Ok(v)
837}
838
839/// Concatenate two matrices column-wise. Both must have the same row count.
840fn concat_cols(left: &Array2<f64>, right: &Array2<f64>) -> Array2<f64> {
841    let nrows = left.nrows().max(right.nrows());
842    let lc = left.ncols();
843    let rc = right.ncols();
844    let mut out = Array2::<f64>::zeros((nrows, lc + rc));
845    if lc > 0 {
846        out.slice_mut(s![.., ..lc]).assign(left);
847    }
848    if rc > 0 {
849        out.slice_mut(s![.., lc..]).assign(right);
850    }
851    out
852}
853
854/// Post-walk audit: column-pivoted QR on the cumulative scaled design.
855/// If rank < p_total, deterministically drop trailing pivots from the
856/// latest block's `V`. Earlier blocks are never modified.
857fn audit_and_drop_trailing_pivots(
858    w_joint: &Array2<f64>,
859    compiled: &mut [CompiledBlock],
860) -> Result<Vec<(usize, usize)>, CompilerError> {
861    let p_total: usize = compiled.iter().map(|b| b.t_lw.ncols()).sum();
862    if p_total == 0 || w_joint.nrows() == 0 {
863        return Ok(Vec::new());
864    }
865
866    // RRQR rank with the codebase's default α.
867    let rrqr = rrqr_with_permutation(w_joint, default_rrqr_rank_alpha())
868        .map_err(|err| CompilerError::LinalgFailure(format!("audit RRQR failed: {err:?}")))?;
869    let rank = rrqr.rank;
870    if rank >= p_total {
871        return Ok(Vec::new());
872    }
873
874    // Trailing pivots are the redundant columns. Attribute every demoted
875    // global column to the *latest* block by truncating its V; earlier
876    // blocks keep their full V. The demoted suffix is sorted only by
877    // pivot order, but we drop deterministically: take the count of
878    // demoted columns and truncate that many trailing columns of the
879    // latest block.
880    let drop_count = p_total - rank;
881    let latest_idx = compiled.len() - 1;
882    let latest = &mut compiled[latest_idx];
883    let kept_local = latest.t_lw.ncols().saturating_sub(drop_count);
884    let dropped_locals: Vec<(usize, usize)> = (kept_local..latest.t_lw.ncols())
885        .map(|c| (latest_idx, c))
886        .collect();
887    // Truncate ALL kept-direction-indexed matrices in lockstep so the
888    // shape contract (`anchor_correction: d_total × k_kept`, `r_lw:
889    // d_total × k_kept`, `t_lw: p_raw × k_kept`) holds after the audit
890    // drops trailing pivots. Forgetting these two left
891    // `anchor_correction.ncols() == pre_truncation_k_kept` while
892    // `t_lw.ncols() == post_truncation_k_kept`, surfaced downstream as
893    // `cross-block identifiability: anchor_correction shape D×P does
894    // not match expected d_total=D × k_kept=K`.
895    latest.t_lw = latest.t_lw.slice(s![.., ..kept_local]).to_owned();
896    if let Some(m) = latest.anchor_correction.as_ref() {
897        latest.anchor_correction = Some(m.slice(s![.., ..kept_local]).to_owned());
898    }
899    if let Some(r) = latest.r_lw.as_ref() {
900        latest.r_lw = Some(r.slice(s![.., ..kept_local]).to_owned());
901    }
902    Ok(dropped_locals)
903}
904
/// Channel-pair decomposition of every parameter block's row Jacobian.
///
/// For families with `K` primary-state channels (survival: K=4), each block
/// `b` contributes a (n × p_b) channel matrix `X_b^(c)` per channel `c` that
/// it touches. Blocks that do not contribute to a channel store `None` in
/// that slot. The closed-form Gram compiler consumes this view directly to
/// build the joint Gram `K^H` without ever materialising the full
/// `(n·K) × p_total` weighted design `W = sqrt(H) · J`.
///
/// Shape contract, as checked by [`build_raw_grams_from_channel_blocks`]:
/// every block carries exactly `K` slots, and every populated slot has `n`
/// rows and as many columns as the block's raw width `p_b`.
pub struct PrimaryChannelBlocks {
    /// Outer index: block. Inner index: channel `c ∈ 0..K`. `None` means the
    /// block does not contribute to that channel.
    pub blocks: Vec<Vec<Option<Array2<f64>>>>,
}
918
919/// Closed-form Gram builder: `K^H[a, b] = Σ_{c,d} (X_a^(c))ᵀ · diag(h_{cd}) · X_b^(d)`.
920///
921/// Inputs:
922/// - `channel_blocks`: per-block channel decomposition of the row Jacobian.
923/// - `row_hess`: `(n × K × K)` per-row PSD Hessian (typically clamped to PSD
924///   by the family upstream).
925/// - `raw_block_ranges`: `[start, end)` column ranges of each block inside
926///   the full `p_total`-wide coefficient vector. Must be contiguous and
927///   non-overlapping; their union spans `0..p_total`.
928///
929/// Returns the symmetric `(p_total × p_total)` Gram matrix.
930pub fn build_raw_grams_from_channel_blocks(
931    channel_blocks: &PrimaryChannelBlocks,
932    row_hess: &dyn RowHessian,
933    raw_block_ranges: &[std::ops::Range<usize>],
934) -> Result<Array2<f64>, CompilerError> {
935    let num_blocks = channel_blocks.blocks.len();
936    if num_blocks != raw_block_ranges.len() {
937        return Err(CompilerError::DimensionMismatch(format!(
938            "channel_blocks ({num_blocks}) and raw_block_ranges ({}) length mismatch",
939            raw_block_ranges.len()
940        )));
941    }
942    if num_blocks == 0 {
943        return Ok(Array2::<f64>::zeros((0, 0)));
944    }
945    let k = row_hess.k();
946    let n = row_hess.nrows();
947    let p_total: usize = raw_block_ranges.iter().map(|r| r.end - r.start).sum();
948    let expected_total = raw_block_ranges.last().map(|r| r.end).unwrap_or(0);
949    if expected_total != p_total {
950        return Err(CompilerError::DimensionMismatch(format!(
951            "raw_block_ranges must be contiguous from 0; got p_total={p_total} but last end={expected_total}"
952        )));
953    }
954    // Per-block channel-slot shape sanity.
955    for (b, slots) in channel_blocks.blocks.iter().enumerate() {
956        if slots.len() != k {
957            return Err(CompilerError::DimensionMismatch(format!(
958                "block {b}: expected {k} channel slots, got {}",
959                slots.len()
960            )));
961        }
962        let p_b = raw_block_ranges[b].end - raw_block_ranges[b].start;
963        for (c, mat) in slots.iter().enumerate() {
964            if let Some(x) = mat.as_ref() {
965                if x.nrows() != n {
966                    return Err(CompilerError::DimensionMismatch(format!(
967                        "block {b} channel {c}: nrows={} but row Hessian nrows={n}",
968                        x.nrows()
969                    )));
970                }
971                if x.ncols() != p_b {
972                    return Err(CompilerError::DimensionMismatch(format!(
973                        "block {b} channel {c}: ncols={} but block width={p_b}",
974                        x.ncols()
975                    )));
976                }
977            }
978        }
979    }
980
981    // Materialise H once and slice it into K·K length-n vectors h_{cd}.
982    let h_full = row_hess.evaluate_full();
983    if h_full.shape() != &[n, k, k] {
984        return Err(CompilerError::DimensionMismatch(format!(
985            "row Hessian evaluate_full shape {:?} != [n={n}, k={k}, k={k}]",
986            h_full.shape()
987        )));
988    }
989    // h_pairs[c * k + d] = length-n vector of H_i[c, d].
990    let mut h_pairs: Vec<Array1<f64>> = Vec::with_capacity(k * k);
991    for c in 0..k {
992        for d in 0..k {
993            let mut v = Array1::<f64>::zeros(n);
994            for i in 0..n {
995                v[i] = h_full[[i, c, d]];
996            }
997            h_pairs.push(v);
998        }
999    }
1000
1001    let mut gram = Array2::<f64>::zeros((p_total, p_total));
1002    // Accumulate upper triangle (a ≤ b) then symmetrise.
1003    for a in 0..num_blocks {
1004        let range_a = raw_block_ranges[a].clone();
1005        for b in a..num_blocks {
1006            let range_b = raw_block_ranges[b].clone();
1007            let mut block_acc =
1008                Array2::<f64>::zeros((range_a.end - range_a.start, range_b.end - range_b.start));
1009            for c in 0..k {
1010                let Some(x_a_c) = channel_blocks.blocks[a][c].as_ref() else {
1011                    continue;
1012                };
1013                for d in 0..k {
1014                    let Some(x_b_d) = channel_blocks.blocks[b][d].as_ref() else {
1015                        continue;
1016                    };
1017                    let h_cd = &h_pairs[c * k + d];
1018                    // (X_a^(c))ᵀ · diag(h_cd) · X_b^(d)  →  (p_a × p_b).
1019                    let contrib = fast_xt_diag_y(x_a_c, h_cd, x_b_d);
1020                    block_acc += &contrib;
1021                }
1022            }
1023            // Write into upper triangle (and the diagonal block itself).
1024            gram.slice_mut(s![range_a.start..range_a.end, range_b.start..range_b.end])
1025                .assign(&block_acc);
1026        }
1027    }
1028    // Symmetrise: copy upper triangle to lower. Diagonal blocks are
1029    // themselves p_a × p_a — symmetrise within them too.
1030    for i in 0..p_total {
1031        for j in 0..i {
1032            let v = gram[[j, i]];
1033            gram[[i, j]] = v;
1034        }
1035    }
1036    Ok(gram)
1037}
1038
1039/// Structural Gram `K^S`: same shape as [`build_raw_grams_from_channel_blocks`]
1040/// but with the per-row Hessian replaced by the K×K identity. Used by the
1041/// dual-metric compiler as the un-weighted reference geometry.
1042///
1043/// `K^S[a, b] = Σ_c (X_a^(c))ᵀ · X_b^(c)` (cross-channel terms vanish under
1044/// `H_i = I_K`).
1045pub fn build_raw_grams_structural(
1046    channel_blocks: &PrimaryChannelBlocks,
1047    raw_block_ranges: &[std::ops::Range<usize>],
1048) -> Array2<f64> {
1049    let num_blocks = channel_blocks.blocks.len();
1050    assert_eq!(
1051        num_blocks,
1052        raw_block_ranges.len(),
1053        "channel_blocks ({num_blocks}) and raw_block_ranges ({}) length mismatch",
1054        raw_block_ranges.len()
1055    );
1056    if num_blocks == 0 {
1057        return Array2::<f64>::zeros((0, 0));
1058    }
1059    let p_total = raw_block_ranges.last().map(|r| r.end).unwrap_or(0);
1060    let mut gram = Array2::<f64>::zeros((p_total, p_total));
1061    for a in 0..num_blocks {
1062        let range_a = raw_block_ranges[a].clone();
1063        for b in a..num_blocks {
1064            let range_b = raw_block_ranges[b].clone();
1065            let p_a = range_a.end - range_a.start;
1066            let p_b = range_b.end - range_b.start;
1067            let k_a = channel_blocks.blocks[a].len();
1068            let k_b = channel_blocks.blocks[b].len();
1069            assert_eq!(
1070                k_a, k_b,
1071                "structural Gram: block {a} has {k_a} channels but block {b} has {k_b}",
1072            );
1073            let mut block_acc = Array2::<f64>::zeros((p_a, p_b));
1074            for c in 0..k_a {
1075                let (Some(x_a_c), Some(x_b_c)) = (
1076                    channel_blocks.blocks[a][c].as_ref(),
1077                    channel_blocks.blocks[b][c].as_ref(),
1078                ) else {
1079                    continue;
1080                };
1081                let contrib = if a == b {
1082                    // Diagonal block, same channel — symmetric XᵀX.
1083                    fast_ata(x_a_c)
1084                } else {
1085                    fast_atb(x_a_c, x_b_c)
1086                };
1087                block_acc += &contrib;
1088            }
1089            gram.slice_mut(s![range_a.start..range_a.end, range_b.start..range_b.end])
1090                .assign(&block_acc);
1091        }
1092    }
1093    for i in 0..p_total {
1094        for j in 0..i {
1095            let v = gram[[j, i]];
1096            gram[[i, j]] = v;
1097        }
1098    }
1099    gram
1100}
1101
1102/// Build the primary-state curvature Gram `K^H` and structural Gram `K^S`
1103/// for a block decomposition, preferring the device (GPU) path when
1104/// available and falling back to the CPU closed-form builders otherwise.
1105///
1106/// The GPU path is only attempted for survival-family geometry
1107/// (`K = CHANNELS = 4`) — that is the case the GPU kernel
1108/// ([`crate::families::gpu::try_primary_state_gram_cuda`])
1109/// is specialised for via the packed-symmetric `n × 10` weight layout.
1110/// For any other `K` the CPU builders are used unconditionally.
1111///
1112/// Returns `(gram_h, gram_struct)` with the same shape and semantics as
1113/// [`build_raw_grams_from_channel_blocks`] + [`build_raw_grams_structural`].
1114pub fn build_primary_grams_gpu_or_cpu(
1115    channel_blocks: &PrimaryChannelBlocks,
1116    row_hess: &dyn RowHessian,
1117    raw_block_ranges: &[std::ops::Range<usize>],
1118) -> Result<(Array2<f64>, Array2<f64>), CompilerError> {
1119    let k = row_hess.k();
1120    if k == crate::families::gpu::CHANNELS {
1121        let gpu_blocks: Vec<Vec<Option<Array2<f64>>>> = channel_blocks
1122            .blocks
1123            .iter()
1124            .map(|slots| slots.iter().cloned().collect())
1125            .collect();
1126        if let Some(h_packed) = pack_row_hessian_symmetric(row_hess) {
1127            if let Some(bundle) = crate::families::gpu::try_primary_state_gram_cuda(
1128                &gpu_blocks,
1129                &h_packed,
1130                raw_block_ranges,
1131            ) {
1132                log::info!("[identifiability_compile] gram path = gpu");
1133                return Ok((bundle.gram_h, bundle.gram_struct));
1134            }
1135        }
1136    }
1137    log::info!("[identifiability_compile] gram path = cpu");
1138    let gram_h = build_raw_grams_from_channel_blocks(channel_blocks, row_hess, raw_block_ranges)?;
1139    let gram_struct = build_raw_grams_structural(channel_blocks, raw_block_ranges);
1140    Ok((gram_h, gram_struct))
1141}
1142
1143/// Pack a per-row symmetric `K = 4` Hessian into the `n × 10`
1144/// upper-triangular row-major layout consumed by the GPU kernel
1145/// (`packed_index(c, d)` for `c ≤ d`). Returns `None` when `K != 4`.
1146fn pack_row_hessian_symmetric(row_hess: &dyn RowHessian) -> Option<Array2<f64>> {
1147    use crate::families::gpu::{CHANNELS, PACKED_LEN, packed_index};
1148    if row_hess.k() != CHANNELS {
1149        return None;
1150    }
1151    let n = row_hess.nrows();
1152    let h_full = row_hess.evaluate_full();
1153    if h_full.shape() != [n, CHANNELS, CHANNELS] {
1154        return None;
1155    }
1156    let mut packed = Array2::<f64>::zeros((n, PACKED_LEN));
1157    for i in 0..n {
1158        for c in 0..CHANNELS {
1159            for d in c..CHANNELS {
1160                packed[[i, packed_index(c, d)]] = h_full[[i, c, d]];
1161            }
1162        }
1163    }
1164    Some(packed)
1165}
1166
/// Closed-form Gram-based compile output: a single `p_raw × p_compiled`
/// reparam matrix `T` mapping compiled coordinates back to raw width.
/// `T · θ` lifts a fitted compiled-width β back to raw width; predict-time
/// row contribution is `X_raw · T · θ` where `X_raw` is the full raw design.
///
/// `compiled_block_ranges[b]` gives the column range inside `T` (and inside
/// the compiled-width coefficient vector) attributable to raw block `b`.
/// `raw_block_ranges[b]` gives the corresponding raw-width column range.
///
/// As produced by [`compile_from_raw_grams`], a block that keeps no compiled
/// directions (zero raw width, or fully absorbed by earlier blocks) is
/// represented by an empty compiled range `at..at` rather than being omitted,
/// so the two range vectors stay index-aligned.
#[derive(Debug)]
pub struct CompiledMap {
    /// `(p_raw × p_compiled)` raw-from-compiled reparam matrix.
    pub raw_from_compiled: Array2<f64>,
    /// Per-block compiled-width column ranges, parallel to
    /// `raw_block_ranges`. Same length as the input `ordering`.
    pub compiled_block_ranges: Vec<std::ops::Range<usize>>,
    /// Per-block raw-width column ranges (copied through from input).
    pub raw_block_ranges: Vec<std::ops::Range<usize>>,
}
1185
1186/// Neutral view of this compiled reparametrisation for the gauge layer
1187/// (#1521): `Gauge::from_compiled_map` lives DOWN in `gam-problem` and
1188/// names only the `CompiledBlockMap` trait, never the concrete
1189/// `CompiledMap` (which lives ABOVE `gam-problem`). This `impl` supplies
1190/// the inverted dependency edge.
1191impl gam_problem::gauge::CompiledBlockMap for CompiledMap {
1192    fn raw_from_compiled(&self) -> &Array2<f64> {
1193        &self.raw_from_compiled
1194    }
1195    fn raw_block_ranges(&self) -> &[std::ops::Range<usize>] {
1196        &self.raw_block_ranges
1197    }
1198    fn compiled_block_ranges(&self) -> &[std::ops::Range<usize>] {
1199        &self.compiled_block_ranges
1200    }
1201}
1202
1203/// Closed-form Gram-based identifiability compile.
1204///
1205/// Sequential algorithm operating purely on the raw-width Grams
1206/// `K^H = Σ_i J_iᵀ H_i J_i` (curvature) and `K^S = Σ_i J_iᵀ J_i`
1207/// (structural). Walks `ordering` left-to-right; for each block `b` with
1208/// raw-width selector `P_b` (columns of the identity selecting that
1209/// block) and cumulative compiled map `T = [T_0, …, T_{b-1}]`:
1210///
1211/// 1. Structural rank step (drop true gauges):
1212///    `G^S_AA = Tᵀ K^S T`, `G^S_Ab = Tᵀ K^S P_b`, `G^S_bb = P_bᵀ K^S P_b`,
1213///    `R_S = (G^S_AA)^+ G^S_Ab`, `G^S_res = G^S_bb − G^S_Abᵀ R_S`.
1214///    Eigendecompose `G^S_res`; keep positive eigvecs `Q+`. Then
1215///    `D = (P_b − T R_S) · Q+` (raw-space cols, structurally independent
1216///    of `T`).
1217/// 2. Curvature step (within-block conditioning):
1218///    `G^H_AA = Tᵀ K^H T`, `G^H_AD = Tᵀ K^H D`,
1219///    `R_H = (G^H_AA)^+ G^H_AD`, `E = D − T R_H` (raw-space).
1220///    Curvature Gram `G^H_res = Dᵀ K^H D − G^H_ADᵀ R_H`. Eigendecompose
1221///    and keep positive eigvecs `U`. Then `T_b = E · U`.
1222/// 3. Append: `T ← [T, T_b]`.
1223///
1224/// Returns [`CompilerError::FullyAliased`] only when the first block has no
1225/// usable structural/curvature span. Later fully absorbed blocks compile to a
1226/// zero-width block range, which is the reduced-coordinate representation of
1227/// the lower-priority block owning no degrees of freedom.
1228pub fn compile_from_raw_grams(
1229    gram_h: &Array2<f64>,
1230    gram_struct: &Array2<f64>,
1231    raw_block_ranges: &[std::ops::Range<usize>],
1232    ordering: &[BlockOrder],
1233) -> Result<CompiledMap, CompilerError> {
1234    if raw_block_ranges.len() != ordering.len() {
1235        return Err(CompilerError::DimensionMismatch(format!(
1236            "raw_block_ranges ({}) and ordering ({}) length mismatch",
1237            raw_block_ranges.len(),
1238            ordering.len()
1239        )));
1240    }
1241    let p_raw = raw_block_ranges.last().map(|r| r.end).unwrap_or(0);
1242    if gram_h.shape() != [p_raw, p_raw] {
1243        return Err(CompilerError::DimensionMismatch(format!(
1244            "gram_h shape {:?} != [p_raw={p_raw}, p_raw={p_raw}]",
1245            gram_h.shape()
1246        )));
1247    }
1248    if gram_struct.shape() != [p_raw, p_raw] {
1249        return Err(CompilerError::DimensionMismatch(format!(
1250            "gram_struct shape {:?} != [p_raw={p_raw}, p_raw={p_raw}]",
1251            gram_struct.shape()
1252        )));
1253    }
1254    if raw_block_ranges.is_empty() {
1255        return Ok(CompiledMap {
1256            raw_from_compiled: Array2::<f64>::zeros((0, 0)),
1257            compiled_block_ranges: Vec::new(),
1258            raw_block_ranges: Vec::new(),
1259        });
1260    }
1261    // Validate contiguous ranges from 0.
1262    let mut expected_start = 0usize;
1263    for (b, r) in raw_block_ranges.iter().enumerate() {
1264        if r.start != expected_start {
1265            return Err(CompilerError::DimensionMismatch(format!(
1266                "raw_block_ranges must be contiguous from 0; block {b} starts at {} expected {expected_start}",
1267                r.start
1268            )));
1269        }
1270        expected_start = r.end;
1271    }
1272
1273    // Cumulative raw-from-compiled map. Starts empty (zero compiled cols).
1274    let mut t_cum: Array2<f64> = Array2::<f64>::zeros((p_raw, 0));
1275    let mut compiled_block_ranges: Vec<std::ops::Range<usize>> =
1276        Vec::with_capacity(raw_block_ranges.len());
1277
1278    for (idx, range_b) in raw_block_ranges.iter().enumerate() {
1279        let p_b = range_b.end - range_b.start;
1280        // A zero-width block owns no raw columns. It contributes no compiled
1281        // degrees of freedom and — having no columns — cannot alias against any
1282        // anchor, so it is trivially identifiable. Emit an empty compiled range
1283        // and skip the structural/curvature analysis: a 0×0 residual Gram has no
1284        // positive eigenspace, which the first-block guard below would otherwise
1285        // mis-report as `FullyAliased` even though there is literally nothing to
1286        // alias. This mirrors the empty range a fully-absorbed later block
1287        // already compiles to (see the `q_plus.ncols() == 0` / `u_mat.ncols() == 0`
1288        // branches), keeping `kept_width + dropped_count == raw_width` exact.
1289        if p_b == 0 {
1290            let at = t_cum.ncols();
1291            compiled_block_ranges.push(at..at);
1292            continue;
1293        }
1294        // Slice gram columns/rows by raw block range. P_bᵀ K X = rows
1295        // range_b of K X. K^S T and K^H T are full-rows products.
1296        // 1) Structural rank step.
1297        // K^S · T (p_raw × p_compiled)
1298        let ks_t = fast_ab(gram_struct, &t_cum);
1299        // G^S_AA = Tᵀ K^S T (p_compiled × p_compiled)
1300        let g_s_aa = fast_atb(&t_cum, &ks_t);
1301        // G^S_Ab = Tᵀ K^S P_b = Tᵀ · K^S[:, range_b] (p_compiled × p_b)
1302        let ks_pb = gram_struct
1303            .slice(s![.., range_b.start..range_b.end])
1304            .to_owned();
1305        let g_s_ab = fast_atb(&t_cum, &ks_pb);
1306        // G^S_bb = P_bᵀ K^S P_b = K^S[range_b, range_b] (p_b × p_b)
1307        let g_s_bb = gram_struct
1308            .slice(s![range_b.start..range_b.end, range_b.start..range_b.end])
1309            .to_owned();
1310        // R_S = (G^S_AA)^+ G^S_Ab (p_compiled × p_b)
1311        let r_s = solve_psd_system(&g_s_aa, &g_s_ab)?;
1312        // G^S_res = G^S_bb − G^S_Abᵀ R_S (p_b × p_b), symmetrise.
1313        let g_s_res_raw = &g_s_bb - &fast_atb(&g_s_ab, &r_s);
1314        let g_s_res = symmetrise(&g_s_res_raw);
1315        // Trace of the unresidualised diagonal block (scale ref).
1316        let g_s_bb_trace: f64 = (0..p_b).map(|i| g_s_bb[[i, i]].max(0.0)).sum();
1317        // p_raw stands in as the "n*K" scale for the closed-form tolerance.
1318        let q_plus = keep_positive_eigenspace(&g_s_res, p_raw, 1, g_s_bb_trace)?;
1319        if q_plus.ncols() == 0 {
1320            if t_cum.ncols() == 0 {
1321                return Err(CompilerError::FullyAliased {
1322                    block_idx: idx,
1323                    reason: format!(
1324                        "structural residual Gram has no positive eigenspace (block of width {p_b} has zero structural span before any anchor exists)"
1325                    ),
1326                });
1327            }
1328            let at = t_cum.ncols();
1329            compiled_block_ranges.push(at..at);
1330            continue;
1331        }
1332        // D = (P_b − T R_S) · Q+ (p_raw × k_kept). Build (P_b − T R_S)
1333        // explicitly as a p_raw × p_b matrix: columns of P_b are columns
1334        // range_b of I_p_raw, so (P_b − T R_S) places −T R_S in all rows
1335        // and adds the identity on rows range_b.
1336        let mut diff = Array2::<f64>::zeros((p_raw, p_b));
1337        if t_cum.ncols() > 0 {
1338            // diff = −T · R_S
1339            let t_rs = fast_ab(&t_cum, &r_s);
1340            for i in 0..p_raw {
1341                for j in 0..p_b {
1342                    diff[[i, j]] = -t_rs[[i, j]];
1343                }
1344            }
1345        }
1346        for j in 0..p_b {
1347            diff[[range_b.start + j, j]] += 1.0;
1348        }
1349        let d_mat = fast_ab(&diff, &q_plus);
1350
1351        // 2) Curvature step.
1352        // K^H · T (p_raw × p_compiled), K^H · D (p_raw × k_kept)
1353        let kh_t = fast_ab(gram_h, &t_cum);
1354        let g_h_aa = fast_atb(&t_cum, &kh_t);
1355        let kh_d = fast_ab(gram_h, &d_mat);
1356        let g_h_ad = fast_atb(&t_cum, &kh_d);
1357        let r_h = solve_psd_system(&g_h_aa, &g_h_ad)?;
1358        // G^H_res = Dᵀ K^H D − G^H_ADᵀ R_H (k_kept × k_kept)
1359        let d_t_kh_d = fast_atb(&d_mat, &kh_d);
1360        let g_h_res_raw = &d_t_kh_d - &fast_atb(&g_h_ad, &r_h);
1361        let g_h_res = symmetrise(&g_h_res_raw);
1362        let k_kept = q_plus.ncols();
1363        let g_h_dd_trace: f64 = (0..k_kept).map(|i| d_t_kh_d[[i, i]].max(0.0)).sum();
1364        let u_mat = keep_positive_eigenspace(&g_h_res, p_raw, 1, g_h_dd_trace)?;
1365        if u_mat.ncols() == 0 {
1366            if t_cum.ncols() == 0 {
1367                return Err(CompilerError::FullyAliased {
1368                    block_idx: idx,
1369                    reason: format!(
1370                        "curvature residual Gram has no positive eigenspace within structurally-kept basis (block of width {p_b}, structural-kept {k_kept}) before any anchor exists"
1371                    ),
1372                });
1373            }
1374            let at = t_cum.ncols();
1375            compiled_block_ranges.push(at..at);
1376            continue;
1377        }
1378        // E = D − T · R_H (p_raw × k_kept); T_b = E · U.
1379        let mut e_mat = d_mat.clone();
1380        if t_cum.ncols() > 0 {
1381            let t_rh = fast_ab(&t_cum, &r_h);
1382            e_mat = &e_mat - &t_rh;
1383        }
1384        let t_b = fast_ab(&e_mat, &u_mat);
1385
1386        let start = t_cum.ncols();
1387        let end = start + t_b.ncols();
1388        compiled_block_ranges.push(start..end);
1389        t_cum = concat_cols(&t_cum, &t_b);
1390    }
1391
1392    // Finite check.
1393    for v in t_cum.iter() {
1394        if !v.is_finite() {
1395            return Err(CompilerError::LinalgFailure(
1396                "compile_from_raw_grams produced non-finite entry in raw_from_compiled".to_string(),
1397            ));
1398        }
1399    }
1400
1401    Ok(CompiledMap {
1402        raw_from_compiled: t_cum,
1403        compiled_block_ranges,
1404        raw_block_ranges: raw_block_ranges.to_vec(),
1405    })
1406}
1407
1408impl CompiledMap {
1409    /// Raw coefficient width (`p_raw`).
1410    pub fn p_raw(&self) -> usize {
1411        self.raw_from_compiled.nrows()
1412    }
1413
1414    /// Compiled (reduced) coefficient width (`p_compiled`).
1415    pub fn p_compiled(&self) -> usize {
1416        self.raw_from_compiled.ncols()
1417    }
1418
1419    /// Reparameterise a raw design into compiled coordinates:
1420    /// `X_compiled = X_raw · T` (`n × p_compiled`). Because the lift is
1421    /// `β_raw = T β_compiled`, the compiled design predicts identically to the
1422    /// raw design on every compiled coefficient: `X_compiled · θ = X_raw · (T θ)`.
1423    /// Families that build directly in reduced coordinates feed this compiled
1424    /// design (and the [`reduce_penalties_with_map`] penalties) to the solver;
1425    /// the rank-deficient raw basis never reaches Newton.
1426    pub fn reduce_design(&self, raw_design: &Array2<f64>) -> Result<Array2<f64>, String> {
1427        if raw_design.ncols() != self.p_raw() {
1428            return Err(format!(
1429                "CompiledMap::reduce_design: raw_design has {} columns, expected p_raw {}",
1430                raw_design.ncols(),
1431                self.p_raw()
1432            ));
1433        }
1434        Ok(fast_ab(raw_design, &self.raw_from_compiled))
1435    }
1436
1437    /// Lift a fitted compiled-width coefficient vector back to raw width:
1438    /// `β_raw = T · β_compiled`. This is the exact inverse direction of the
1439    /// quotient reduction — the reduced coordinates are what Newton/REML
1440    /// operate in, and this map carries the final estimate (and any linear
1441    /// functional of it) back to the original parameterisation so reported
1442    /// coefficients and predictions match the raw design.
1443    pub fn lift_coefficients(&self, beta_compiled: &Array1<f64>) -> Result<Array1<f64>, String> {
1444        if beta_compiled.len() != self.p_compiled() {
1445            return Err(format!(
1446                "CompiledMap::lift_coefficients: beta_compiled len {} != p_compiled {}",
1447                beta_compiled.len(),
1448                self.p_compiled()
1449            ));
1450        }
1451        Ok(self.raw_from_compiled.dot(beta_compiled))
1452    }
1453
1454    /// The rows of `T` belonging to raw block `b` (`T[raw_block_ranges[b], :]`,
1455    /// shape `p_b_raw × p_compiled`). A raw-block penalty `S_b` acts only on
1456    /// these raw columns, so the penalty's reduced-coordinate form depends on
1457    /// `T` only through this slice.
1458    fn raw_block_rows(&self, block_idx: usize) -> Result<Array2<f64>, String> {
1459        let range = self.raw_block_ranges.get(block_idx).ok_or_else(|| {
1460            format!(
1461                "CompiledMap::raw_block_rows: block {block_idx} out of range {}",
1462                self.raw_block_ranges.len()
1463            )
1464        })?;
1465        Ok(self
1466            .raw_from_compiled
1467            .slice(s![range.start..range.end, ..])
1468            .to_owned())
1469    }
1470}
1471
1472/// Transform a per-block raw-width penalty into the compiled (reduced)
1473/// coordinate frame defined by `map`.
1474///
1475/// `raw_penalties[b]` is the penalty matrix `S_b` acting on raw block `b`
1476/// (shape `p_b_raw × p_b_raw`), or `None` for an unpenalised block. The
1477/// returned `reduced[b]` is the **full** `(p_compiled × p_compiled)` penalty
1478/// `Tᵀ Ŝ_b T`, where `Ŝ_b` embeds `S_b` into the `p_raw × p_raw` zero matrix
1479/// at block `b`'s position. Because `Ŝ_b` is zero outside block `b`'s rows and
1480/// columns, this equals `T_bᵀ S_b T_b` with `T_b = T[raw_block_ranges[b], :]`,
1481/// so the reduced penalty is computed from the block's lift rows alone — no
1482/// dense `p_raw × p_raw` embedding is materialised.
1483///
1484/// Exactness: for any compiled coefficient `θ` with raw lift `β = T θ`, the raw
1485/// penalty energy `βᵀ Ŝ_b β = (T θ)ᵀ Ŝ_b (T θ) = θᵀ (Tᵀ Ŝ_b T) θ`, so the
1486/// reduced penalty reproduces the raw penalty energy on every lifted point.
1487/// A compiled block that absorbed to zero width simply contributes a zero
1488/// column range; its raw penalty (if any) projects onto the surviving
1489/// compiled directions through `T_b`, never lost.
1490pub fn reduce_penalties_with_map(
1491    map: &CompiledMap,
1492    raw_penalties: &[Option<Array2<f64>>],
1493) -> Result<Vec<Option<Array2<f64>>>, String> {
1494    if raw_penalties.len() != map.raw_block_ranges.len() {
1495        return Err(format!(
1496            "reduce_penalties_with_map: raw_penalties ({}) != blocks ({})",
1497            raw_penalties.len(),
1498            map.raw_block_ranges.len()
1499        ));
1500    }
1501    let p_compiled = map.p_compiled();
1502    let mut reduced: Vec<Option<Array2<f64>>> = Vec::with_capacity(raw_penalties.len());
1503    for (block_idx, raw_penalty) in raw_penalties.iter().enumerate() {
1504        let Some(s_b) = raw_penalty.as_ref() else {
1505            reduced.push(None);
1506            continue;
1507        };
1508        let p_b_raw = map.raw_block_ranges[block_idx].len();
1509        if s_b.shape() != [p_b_raw, p_b_raw] {
1510            return Err(format!(
1511                "reduce_penalties_with_map: block {block_idx} penalty shape {:?} != [{p_b_raw}, {p_b_raw}]",
1512                s_b.shape()
1513            ));
1514        }
1515        // T_b = T[raw rows of block b, :]  (p_b_raw × p_compiled)
1516        let t_b = map.raw_block_rows(block_idx)?;
1517        // S_compiled = T_bᵀ S_b T_b  (p_compiled × p_compiled)
1518        let s_t_b = fast_ab(s_b, &t_b); // (p_b_raw × p_compiled)
1519        let s_compiled_raw = fast_atb(&t_b, &s_t_b); // (p_compiled × p_compiled)
1520        let mut s_compiled = symmetrise(&s_compiled_raw);
1521        if s_compiled.shape() != [p_compiled, p_compiled] {
1522            return Err(format!(
1523                "reduce_penalties_with_map: block {block_idx} reduced penalty shape {:?} != [{p_compiled}, {p_compiled}]",
1524                s_compiled.shape()
1525            ));
1526        }
1527        for v in s_compiled.iter_mut() {
1528            if !v.is_finite() {
1529                return Err(format!(
1530                    "reduce_penalties_with_map: block {block_idx} reduced penalty has non-finite entry"
1531                ));
1532            }
1533        }
1534        reduced.push(Some(s_compiled));
1535    }
1536    Ok(reduced)
1537}
1538
1539/// Per-block exact orthogonal reparameterisation of structural confounds.
1540///
1541/// `block_transforms[b]` is a dense `(p_b × r_b)` reparam `V_b` mapping raw
1542/// block-`b` coefficients to reduced coordinates: the orthogonalised block
1543/// design is `X_b · V_b`, and a fitted reduced coefficient lifts back to raw
1544/// space exactly via `β_b_raw = V_b · θ_b`. `r_b ≤ p_b`; `r_b < p_b` exactly
1545/// when block `b` carries `p_b − r_b` directions already spanned (in the
1546/// pilot W-metric) by the cumulative anchor of all higher-priority blocks —
1547/// those directions are removed (not penalised), so the joint design
1548/// `[X_0 V_0 | X_1 V_1 | …]` has the overlap excised exactly.
1549pub struct BlockOrthogonalization {
1550    /// `block_transforms[b]`: the `(p_b × r_b)` reparam `V_b` for raw block `b`,
1551    /// in the **original block order** (parallel to the `block_designs` input).
1552    pub block_transforms: Vec<Array2<f64>>,
1553    /// `(block_idx, local_raw_col_count_dropped)` for every block whose
1554    /// reduced width is strictly smaller than its raw width — i.e. the blocks
1555    /// that shed overlap directions against the anchor. Empty when no block
1556    /// overlapped (every `V_b` is then a `p_b × p_b` rotation/identity).
1557    pub dropped: Vec<(usize, usize)>,
1558    /// One structural annotation per input block, in original block order.
1559    ///
1560    /// This is the explicit "same direction vs independent direction" verdict:
1561    /// `Independent` means the block kept its full realized-design rank, while
1562    /// `PartiallyAbsorbed...` / `FullyAbsorbed...` mean the lower-priority block
1563    /// shared realized-design directions with the cumulative anchor and those
1564    /// directions were removed rather than assigned a separate penalty.
1565    pub direction_annotations: Vec<PenalizedDirectionAnnotation>,
1566}
1567
1568/// Build per-block exact W-metric orthogonalising reparameterisations.
1569///
1570/// `block_designs[b]` is the raw `(n × p_b)` design of block `b`.
1571/// `priority[b]` is the block's gauge priority — blocks are residualised in
1572/// **descending** priority order, so the highest-priority block keeps its full
1573/// column span and lower-priority blocks shed only the directions already
1574/// explained by the cumulative higher-priority anchor. `weight` is the pilot
1575/// W-metric row weight `w_i ≥ 0` (the diagonal of the working GLM/GAM Hessian
1576/// at the pilot β); pass an all-ones vector for the plain Euclidean metric.
1577///
1578/// The returned `block_transforms` are in the **original** block order. For a
1579/// block whose columns are all W-orthogonal to the anchor, `V_b` is a square
1580/// `p_b × p_b` orthonormal rotation (rank preserved, round-trip exact). For a
1581/// block with an overlap of dimension `d`, `V_b` is `p_b × (p_b − d)` and the
1582/// `d` overlap directions are removed exactly.
1583///
1584/// Exactness / round-trip: `X_b · V_b` is the reduced design and
1585/// `β_b_raw = V_b · θ_b` lifts a reduced fit back to raw coordinates. `V_b` has
1586/// orthonormal columns (eigenvectors of the residual Gram), so the lift is the
1587/// minimum-norm raw representative of the reduced fit.
1588pub fn orthogonalize_design_blocks(
1589    block_designs: &[Array2<f64>],
1590    priority: &[u32],
1591    weight: &[f64],
1592) -> Result<BlockOrthogonalization, CompilerError> {
1593    if block_designs.len() != priority.len() {
1594        return Err(CompilerError::DimensionMismatch(format!(
1595            "block_designs ({}) and priority ({}) length mismatch",
1596            block_designs.len(),
1597            priority.len()
1598        )));
1599    }
1600    if block_designs.is_empty() {
1601        return Ok(BlockOrthogonalization {
1602            block_transforms: Vec::new(),
1603            dropped: Vec::new(),
1604            direction_annotations: Vec::new(),
1605        });
1606    }
1607    let n = block_designs[0].nrows();
1608    for (b, x) in block_designs.iter().enumerate() {
1609        if x.nrows() != n {
1610            return Err(CompilerError::DimensionMismatch(format!(
1611                "block {b} design has {} rows but block 0 has {n}",
1612                x.nrows()
1613            )));
1614        }
1615    }
1616    if weight.len() != n {
1617        return Err(CompilerError::DimensionMismatch(format!(
1618            "weight length {} != n {n}",
1619            weight.len()
1620        )));
1621    }
1622    // sqrt(W) row scale (clamp tiny-negative to zero — the pilot Hessian
1623    // diagonal is PSD-clamped upstream, but guard against round-off).
1624    let mut sqrt_w = Array1::<f64>::zeros(n);
1625    for i in 0..n {
1626        let wi = weight[i].max(0.0);
1627        sqrt_w[i] = wi.sqrt();
1628    }
1629
1630    // Descending-priority visitation order over the original block indices.
1631    // Stable on ties (preserves input order) so the anchor build is
1632    // deterministic.
1633    let mut order: Vec<usize> = (0..block_designs.len()).collect();
1634    order.sort_by(|&a, &b| priority[b].cmp(&priority[a]));
1635
1636    // Cumulative weighted anchor `A = sqrt(W) · [kept block designs]`.
1637    let mut anchor: Array2<f64> = Array2::<f64>::zeros((n, 0));
1638
1639    // Output transforms indexed by ORIGINAL block index (filled out of order).
1640    let mut block_transforms: Vec<Option<Array2<f64>>> = vec![None; block_designs.len()];
1641    let mut direction_annotations: Vec<Option<PenalizedDirectionAnnotation>> =
1642        vec![None; block_designs.len()];
1643    let mut dropped: Vec<(usize, usize)> = Vec::new();
1644
1645    for &b in order.iter() {
1646        let x_b = &block_designs[b];
1647        let p_b = x_b.ncols();
1648        // Weighted block design `W_b = sqrt(W) · X_b`.
1649        let mut w_b = x_b.clone();
1650        for i in 0..n {
1651            let s = sqrt_w[i];
1652            for j in 0..p_b {
1653                w_b[[i, j]] *= s;
1654            }
1655        }
1656        // Residualise `W_b` against the cumulative anchor in the W-metric and
1657        // eigendecompose the residual Gram. Eigenvectors with positive
1658        // eigenvalues span block `b`'s W-orthogonal-to-anchor column space;
1659        // the zero-eigenvalue directions are exactly the overlap with the
1660        // anchor and are removed.
1661        let (residual, _correction) = residualise_in_metric(&anchor, &w_b)?;
1662        let g_res = symmetrise(&fast_atb(&residual, &residual));
1663        // Scale reference for `keep_positive_eigenspace` must be the
1664        // *original* (pre-residualisation) weighted block Gram trace, NOT the
1665        // residual's. When `b` is fully absorbed by a higher-priority anchor
1666        // the residual collapses to floating-point noise (~ε² of the original
1667        // O(1) data); anchoring tau to that noise floor would keep the noise
1668        // eigenvalues and misreport a fully-absorbed block as `Independent`.
1669        // The original-block trace is invariant to absorption, so a near-zero
1670        // residual is correctly rejected as fully absorbed.
1671        let g_bb = fast_atb(&w_b, &w_b);
1672        let g_bb_trace: f64 = (0..p_b).map(|i| g_bb[[i, i]].max(0.0)).sum();
1673        let v_b = keep_positive_eigenspace(&g_res, n, 1, g_bb_trace)?;
1674        let r_b = v_b.ncols();
1675        let absorbed_width = p_b - r_b;
1676        let kind = if absorbed_width == 0 {
1677            PenalizedDirectionAnnotationKind::Independent
1678        } else if r_b == 0 {
1679            PenalizedDirectionAnnotationKind::FullyAbsorbedByHigherPriority
1680        } else {
1681            PenalizedDirectionAnnotationKind::PartiallyAbsorbedByHigherPriority
1682        };
1683        direction_annotations[b] = Some(PenalizedDirectionAnnotation {
1684            block_idx: b,
1685            raw_width: p_b,
1686            kept_width: r_b,
1687            absorbed_width,
1688            kind,
1689        });
1690        if absorbed_width > 0 {
1691            dropped.push((b, absorbed_width));
1692        }
1693        // Append this block's kept, W-orthogonalised weighted columns to the
1694        // anchor so lower-priority blocks residualise against them too. The
1695        // residual (already anchor-orthogonal) projected onto the kept basis
1696        // is `residual · V_b` — these are mutually orthogonal in the W-metric
1697        // by construction of `keep_positive_eigenspace`.
1698        let kept_weighted = fast_ab(&residual, &v_b);
1699        anchor = concat_cols(&anchor, &kept_weighted);
1700        block_transforms[b] = Some(v_b);
1701    }
1702
1703    let block_transforms: Vec<Array2<f64>> = block_transforms
1704        .into_iter()
1705        .enumerate()
1706        .map(|(b, t)| {
1707            t.ok_or_else(|| {
1708                CompilerError::LinalgFailure(format!(
1709                    "orthogonalize_design_blocks: block {b} transform was never assigned"
1710                ))
1711            })
1712        })
1713        .collect::<Result<Vec<_>, _>>()?;
1714    let direction_annotations: Vec<PenalizedDirectionAnnotation> = direction_annotations
1715        .into_iter()
1716        .enumerate()
1717        .map(|(b, annotation)| {
1718            annotation.ok_or_else(|| {
1719                CompilerError::LinalgFailure(format!(
1720                    "orthogonalize_design_blocks: block {b} direction annotation was never assigned"
1721                ))
1722            })
1723        })
1724        .collect::<Result<Vec<_>, _>>()?;
1725
1726    // Finite check on every transform.
1727    for (b, v) in block_transforms.iter().enumerate() {
1728        for value in v.iter() {
1729            if !value.is_finite() {
1730                return Err(CompilerError::LinalgFailure(format!(
1731                    "orthogonalize_design_blocks: block {b} transform has a non-finite entry"
1732                )));
1733            }
1734        }
1735    }
1736
1737    Ok(BlockOrthogonalization {
1738        block_transforms,
1739        dropped,
1740        direction_annotations,
1741    })
1742}
1743
1744/// Symmetrise a (nearly-symmetric) matrix by averaging with its transpose.
1745fn symmetrise(m: &Array2<f64>) -> Array2<f64> {
1746    let (r, c) = m.dim();
1747    assert_eq!(r, c, "symmetrise expects square matrix");
1748    let mut out = Array2::<f64>::zeros((r, c));
1749    for i in 0..r {
1750        for j in 0..c {
1751            out[[i, j]] = 0.5 * (m[[i, j]] + m[[j, i]]);
1752        }
1753    }
1754    out
1755}
1756
1757#[cfg(test)]
1758mod tests {
1759    use super::*;
1760    use ndarray::{Array1, Array2};
1761
1762    /// Convenience: wrap a dense `(n × p)` block design as a `K=1`
1763    /// row-Jacobian operator. Used by tests; production families ship their
1764    /// own concrete operators.
1765    struct DenseScalarOperator {
1766        design: Array2<f64>,
1767    }
1768
1769    impl DenseScalarOperator {
1770        fn new(design: Array2<f64>) -> Self {
1771            Self { design }
1772        }
1773    }
1774
1775    impl RowJacobianOperator for DenseScalarOperator {
1776        fn k(&self) -> usize {
1777            1
1778        }
1779        fn ncols(&self) -> usize {
1780            self.design.ncols()
1781        }
1782        fn nrows(&self) -> usize {
1783            self.design.nrows()
1784        }
1785        fn apply_row(&self, row: usize, delta_beta: &[f64], out: &mut [f64]) {
1786            assert_eq!(out.len(), 1);
1787            let mut acc = 0.0;
1788            for (j, &b) in delta_beta.iter().enumerate() {
1789                acc += self.design[[row, j]] * b;
1790            }
1791            out[0] = acc;
1792        }
1793        fn evaluate_full(&self) -> Array3<f64> {
1794            let n = self.design.nrows();
1795            let p = self.design.ncols();
1796            let mut out = Array3::<f64>::zeros((n, p, 1));
1797            for i in 0..n {
1798                for j in 0..p {
1799                    out[[i, j, 0]] = self.design[[i, j]];
1800                }
1801            }
1802            out
1803        }
1804    }
1805
1806    // `IdentityRowHessian` is re-exported from the parent module's `use
1807    // super::*;` above (now a public struct so the dual-metric API can
1808    // share the default structural metric with callers).
1809
1810    /// Diagonal row Hessian with per-row scalar weights (K=1 case).
1811    struct DiagonalScalarRowHessian {
1812        w: Array1<f64>,
1813    }
1814
1815    impl DiagonalScalarRowHessian {
1816        fn new(w: Array1<f64>) -> Self {
1817            Self { w }
1818        }
1819    }
1820
1821    impl RowHessian for DiagonalScalarRowHessian {
1822        fn k(&self) -> usize {
1823            1
1824        }
1825        fn nrows(&self) -> usize {
1826            self.w.len()
1827        }
1828        fn fill_row(&self, row: usize, out: &mut [f64]) {
1829            assert_eq!(out.len(), 1);
1830            out[0] = self.w[row];
1831        }
1832        fn evaluate_full(&self) -> Array3<f64> {
1833            let n = self.w.len();
1834            let mut out = Array3::<f64>::zeros((n, 1, 1));
1835            for i in 0..n {
1836                out[[i, 0, 0]] = self.w[i];
1837            }
1838            out
1839        }
1840    }
1841
1842    fn op(design: Array2<f64>) -> Arc<dyn RowJacobianOperator> {
1843        Arc::new(DenseScalarOperator::new(design))
1844    }
1845
1846    /// §10 test #1: two affine blocks, identity row Hessian. The compiled
1847    /// second-block design must be orthogonal to the first block under the
1848    /// (identity) row metric to machine epsilon.
1849    #[test]
1850    fn compile_two_block_orthogonalises_under_metric() {
1851        let n = 50;
1852        let a = Array2::from_shape_fn((n, 3), |(i, j)| ((i + 1) as f64).sin().powi((j + 1) as i32));
1853        // B partly aliases A's first column.
1854        let b = Array2::from_shape_fn((n, 2), |(i, j)| {
1855            0.5 * a[[i, 0]] + ((i as f64) * 0.13 + j as f64).cos()
1856        });
1857        let hess = IdentityRowHessian::new(n, 1);
1858        let ops = vec![op(a.clone()), op(b.clone())];
1859        let compiled = compile(&ops, &hess, &[BlockOrder::Marginal, BlockOrder::Logslope])
1860            .expect("compile should succeed");
1861        // Build A's design (no rotation) and B's compiled design B·V − A·M.
1862        let v_b = &compiled.blocks[1].t_lw;
1863        let m_b = compiled.blocks[1]
1864            .anchor_correction
1865            .as_ref()
1866            .expect("second block must carry an anchor correction");
1867        let b_v = b.dot(v_b);
1868        let a_m = a.dot(m_b);
1869        let b_compiled = &b_v - &a_m;
1870        // <A, B_compiled>_I = Aᵀ · B_compiled should be ≈ 0.
1871        let cross = a.t().dot(&b_compiled);
1872        let max_err = cross.iter().fold(0.0_f64, |acc, &v| acc.max(v.abs()));
1873        assert!(
1874            max_err < 1e-10,
1875            "orthogonality residual too large: {max_err:e}"
1876        );
1877    }
1878
1879    /// §10 test #2: three-block chain with sequential aliases.
1880    #[test]
1881    fn compile_three_block_chain() {
1882        let n = 80;
1883        let a = Array2::from_shape_fn((n, 2), |(i, j)| (i as f64 * 0.1 + j as f64).sin());
1884        let b = Array2::from_shape_fn((n, 2), |(i, j)| {
1885            0.3 * a[[i, 0]] + (j as f64) * (i as f64).cos()
1886        });
1887        let c = Array2::from_shape_fn((n, 2), |(i, j)| {
1888            0.2 * a[[i, 1]] + 0.4 * b[[i, 0]] + ((i + j) as f64).tan().min(5.0).max(-5.0)
1889        });
1890        let hess = IdentityRowHessian::new(n, 1);
1891        let ops = vec![op(a), op(b), op(c)];
1892        let compiled = compile(
1893            &ops,
1894            &hess,
1895            &[
1896                BlockOrder::Marginal,
1897                BlockOrder::Logslope,
1898                BlockOrder::LinkDev,
1899            ],
1900        )
1901        .expect("compile should succeed");
1902        let total: usize = compiled.blocks.iter().map(|b| b.t_lw.ncols()).sum();
1903        assert_eq!(
1904            compiled.joint_rank, total,
1905            "audit must report full rank on synthetic full-rank design"
1906        );
1907    }
1908
1909    /// §10 test #3: non-identity row Hessian. With K=1 and weights `w`,
1910    /// the projection of a 1-col block `b` onto a 1-col block `a` is
1911    /// `Σ w·a·b / Σ w·a²`. Verify the Gram solve recovers this scalar.
1912    #[test]
1913    fn compile_weighted_metric_nontrivial() {
1914        let n = 32;
1915        let a: Array2<f64> = Array2::from_shape_fn((n, 1), |(i, _)| (i as f64 + 1.0).sqrt());
1916        let b: Array2<f64> =
1917            Array2::from_shape_fn((n, 1), |(i, _)| 0.7 * a[[i, 0]] + (i as f64 * 0.05).cos());
1918        let w = Array1::from_shape_fn(n, |i| 0.5 + (i as f64 * 0.2).sin().abs());
1919        let hess = DiagonalScalarRowHessian::new(w.clone());
1920        let ops = vec![op(a.clone()), op(b.clone())];
1921        let compiled = compile(&ops, &hess, &[BlockOrder::Marginal, BlockOrder::Logslope])
1922            .expect("compile should succeed");
1923        let m = compiled.blocks[1]
1924            .anchor_correction
1925            .as_ref()
1926            .expect("anchor correction present");
1927        let analytic_num: f64 = (0..n).map(|i| w[i] * a[[i, 0]] * b[[i, 0]]).sum();
1928        let analytic_den: f64 = (0..n).map(|i| w[i] * a[[i, 0]] * a[[i, 0]]).sum();
1929        let analytic = analytic_num / analytic_den;
1930        assert!(m.dim() == (1, 1));
1931        assert!(
1932            (m[[0, 0]] - analytic).abs() < 1e-10,
1933            "weighted projection mismatch: got {got}, analytic {analytic}",
1934            got = m[[0, 0]]
1935        );
1936    }
1937
1938    /// Regression for #372: an anchor block that internally sheds an aliased
1939    /// column makes the residualised kept-anchor width (`anchor_h.ncols()`)
1940    /// strictly smaller than the raw anchor width (`d_total`). The emitted
1941    /// `anchor_correction` must be expressed in *raw* anchor-column
1942    /// coordinates so the predict-time / install-time subtraction
1943    /// `A_raw(x)·M` is dimensionally and metrically correct. Previously the
1944    /// correction was indexed by kept directions, producing a (d_total−1)×k
1945    /// matrix and the failure
1946    /// `anchor_correction shape 36x6 does not match d_total=37`.
1947    #[test]
1948    fn compile_emits_anchor_correction_in_raw_column_coordinates() {
1949        let n = 64;
1950        // Anchor block A has 3 raw columns but only rank 2: col 2 is an exact
1951        // linear combination of cols 0 and 1, so the compiler keeps just two
1952        // anchor directions (kept width 2 < raw width 3).
1953        let a: Array2<f64> = Array2::from_shape_fn((n, 3), |(i, j)| {
1954            let c0 = (i as f64 * 0.07 + 1.0).ln();
1955            let c1 = (i as f64 * 0.13).sin();
1956            match j {
1957                0 => c0,
1958                1 => c1,
1959                _ => 2.0 * c0 - 0.5 * c1,
1960            }
1961        });
1962        // Candidate block C: partly aliases A's span plus genuine signal.
1963        let c: Array2<f64> = Array2::from_shape_fn((n, 2), |(i, j)| {
1964            0.4 * a[[i, 0]] + (j as f64) * (i as f64 * 0.05).cos() + (i as f64 * 0.011).tanh()
1965        });
1966        let w = Array1::from_shape_fn(n, |i| 0.3 + (i as f64 * 0.17).sin().abs());
1967        let hess = DiagonalScalarRowHessian::new(w.clone());
1968        let ops = vec![op(a.clone()), op(c.clone())];
1969        let compiled = compile(&ops, &hess, &[BlockOrder::Marginal, BlockOrder::LinkDev])
1970            .expect("compile should succeed");
1971
1972        let v = &compiled.blocks[1].t_lw;
1973        let m = compiled.blocks[1]
1974            .anchor_correction
1975            .as_ref()
1976            .expect("candidate block must carry an anchor correction");
1977        let k_kept = v.ncols();
1978        assert!(k_kept >= 1, "candidate must keep at least one direction");
1979
1980        // The off-by-one the issue tripped on: M must have one row per *raw*
1981        // anchor column (3), not per kept anchor direction (2).
1982        assert_eq!(
1983            m.nrows(),
1984            a.ncols(),
1985            "anchor_correction must be indexed by raw anchor columns (d_total), \
1986             got {} rows for {} raw anchor columns",
1987            m.nrows(),
1988            a.ncols(),
1989        );
1990        assert_eq!(m.ncols(), k_kept, "anchor_correction width must match V");
1991
1992        // Metric correctness: the raw-coordinate subtraction A_raw·M must make
1993        // the compiled candidate design W-orthogonal to the full raw anchor
1994        // span. C̃ = C·V − A·M; require Aᵀ W C̃ ≈ 0 column-wise.
1995        let c_v = c.dot(v);
1996        let a_m = a.dot(m);
1997        let c_tilde = &c_v - &a_m;
1998        let mut max_cross = 0.0_f64;
1999        for ac in 0..a.ncols() {
2000            for cc in 0..c_tilde.ncols() {
2001                let mut acc = 0.0;
2002                for i in 0..n {
2003                    acc += w[i] * a[[i, ac]] * c_tilde[[i, cc]];
2004                }
2005                max_cross = max_cross.max(acc.abs());
2006            }
2007        }
2008        assert!(
2009            max_cross < 1e-9,
2010            "raw-coordinate anchor correction must W-orthogonalise the candidate \
2011             against the raw anchor span; max |Aᵀ W C̃| = {max_cross:e}"
2012        );
2013    }
2014
2015    /// §10 test #4: deliberately rank-deficient joint design. The trailing
2016    /// pivot drop must come from the *latest* block in the ordering.
2017    #[test]
2018    fn compile_drops_trailing_pivots_from_latest_block() {
2019        let n = 40;
2020        let a = Array2::from_shape_fn((n, 2), |(i, j)| (i as f64 + 1.0).ln() * (j as f64 + 1.0));
2021        // c is exactly a's first column → after residualising c against a,
2022        // the residual span is zero in that direction, but a non-zero
2023        // independent column also exists. Add an extra exact-alias column
2024        // to force trailing-pivot drop at the audit stage.
2025        let c = Array2::from_shape_fn((n, 2), |(i, j)| {
2026            if j == 0 {
2027                a[[i, 0]]
2028            } else {
2029                (i as f64 * 0.1).cos()
2030            }
2031        });
2032        let hess = IdentityRowHessian::new(n, 1);
2033        let ops = vec![op(a), op(c)];
2034        // Manually inject a known alias: pass a second block whose
2035        // residualised columns will themselves be linearly dependent on
2036        // the first block after metric projection — already covered by the
2037        // eigenvalue threshold inside `compile`. Verify either drop path
2038        // (eigen-threshold or audit) attributes loss to block index 1.
2039        let compiled = compile(&ops, &hess, &[BlockOrder::Marginal, BlockOrder::Logslope])
2040            .expect("compile should succeed");
2041        // Either the eigen-threshold dropped a column from block 1, or
2042        // the audit did. In both cases block 1's V must have fewer than
2043        // its 2 input columns.
2044        let v1_cols = compiled.blocks[1].t_lw.ncols();
2045        assert!(
2046            v1_cols < 2 || !compiled.dropped.is_empty(),
2047            "expected rank loss attributed to block 1, got v1_cols={v1_cols}, dropped={dropped:?}",
2048            dropped = compiled.dropped
2049        );
2050        for (block_idx, _) in &compiled.dropped {
2051            assert_eq!(
2052                *block_idx, 1,
2053                "audit drops must come from the latest block only"
2054            );
2055        }
2056    }
2057
2058    /// Regression: when `audit_and_drop_trailing_pivots` truncates the
2059    /// latest block's `t_lw`, the sibling `anchor_correction` and `r_lw`
2060    /// matrices must be truncated to the same `k_kept` so the trailing-
2061    /// block install path sees a coherent
2062    /// `t_lw.ncols() == anchor_correction.ncols() == r_lw.ncols()` shape.
2063    ///
2064    /// Pre-fix bug: only `t_lw` got truncated. Downstream callers
2065    /// asserting `anchor_correction.ncols() == k_kept` then failed with
2066    /// `cross-block identifiability: anchor_correction shape D×P does
2067    /// not match expected d_total=D × k_kept=K` — surfaced via the
2068    /// large-scale V+M repro test.
2069    #[test]
2070    fn audit_truncation_keeps_t_lw_and_anchor_correction_in_lockstep() {
2071        let n = 40;
2072        let a = Array2::from_shape_fn((n, 2), |(i, j)| (i as f64 + 1.0).ln() * (j as f64 + 1.0));
2073        let c = Array2::from_shape_fn((n, 2), |(i, j)| {
2074            if j == 0 {
2075                a[[i, 0]]
2076            } else {
2077                (i as f64 * 0.1).cos()
2078            }
2079        });
2080        let hess = IdentityRowHessian::new(n, 1);
2081        let ops = vec![op(a), op(c)];
2082        let compiled = compile(&ops, &hess, &[BlockOrder::Marginal, BlockOrder::Logslope])
2083            .expect("compile should succeed");
2084        for (idx, block) in compiled.blocks.iter().enumerate() {
2085            let k_kept = block.t_lw.ncols();
2086            if let Some(m) = block.anchor_correction.as_ref() {
2087                assert_eq!(
2088                    m.ncols(),
2089                    k_kept,
2090                    "block {idx}: anchor_correction.ncols()={ac} must equal t_lw.ncols()={k_kept} \
2091                     after audit truncation",
2092                    ac = m.ncols(),
2093                );
2094            }
2095            if let Some(r) = block.r_lw.as_ref() {
2096                assert_eq!(
2097                    r.ncols(),
2098                    k_kept,
2099                    "block {idx}: r_lw.ncols()={r_cols} must equal t_lw.ncols()={k_kept} \
2100                     after audit truncation",
2101                    r_cols = r.ncols(),
2102                );
2103            }
2104        }
2105    }
2106
2107    /// §10 test #5: regression test for the deleted FlexEvaluation skip
2108    /// bug. A flex anchor (represented by a dense scalar operator with the
2109    /// same column span as the parametric reference) must receive the same
2110    /// residualisation as the parametric anchor.
2111    #[test]
2112    fn compile_flex_anchor_is_first_class() {
2113        let n = 60;
2114        // Two parametric blocks A, B; a third "flex" block C whose
2115        // operator is dense (modelling a compiled flex anchor's column
2116        // span). All-parametric reference vs. mixed parametric+flex must
2117        // produce identical compiled blocks B (residualised against A)
2118        // because the compiler treats every input as a `RowJacobianOperator`.
2119        let a = Array2::from_shape_fn((n, 2), |(i, j)| (i as f64 * 0.07 + j as f64).sin());
2120        let b = Array2::from_shape_fn((n, 2), |(i, j)| {
2121            0.4 * a[[i, 0]] + (j as f64) * (i as f64 + 1.0).ln()
2122        });
2123        let hess = IdentityRowHessian::new(n, 1);
2124
2125        let ops_param = vec![op(a.clone()), op(b.clone())];
2126        let compiled_param = compile(
2127            &ops_param,
2128            &hess,
2129            &[BlockOrder::Marginal, BlockOrder::Logslope],
2130        )
2131        .expect("compile should succeed");
2132
2133        // Now wrap A's design behind a mock anchor evaluator and feed it
2134        // to the compiler as a `DenseScalarOperator` with the same span.
2135        // The B-block result must match the parametric reference.
2136        let ops_flex = vec![op(a.clone()), op(b.clone())];
2137        let compiled_flex = compile(
2138            &ops_flex,
2139            &hess,
2140            &[BlockOrder::ScoreWarp, BlockOrder::LinkDev],
2141        )
2142        .expect("compile should succeed");
2143
2144        let m_param = compiled_param.blocks[1].anchor_correction.as_ref().unwrap();
2145        let m_flex = compiled_flex.blocks[1].anchor_correction.as_ref().unwrap();
2146        assert_eq!(m_param.dim(), m_flex.dim());
2147        let max_diff = (m_param - m_flex)
2148            .iter()
2149            .fold(0.0_f64, |acc, &v| acc.max(v.abs()));
2150        assert!(
2151            max_diff < 1e-12,
2152            "flex vs parametric anchor correction mismatch: {max_diff:e}"
2153        );
2154    }
2155
2156    /// §10 test #7: Bernoulli row Hessian = IRLS weight. Verified at the
2157    /// trait level — a `DiagonalScalarRowHessian` round-trips through
2158    /// `evaluate_full` to the same per-row scalar.
2159    #[test]
2160    fn bernoulli_row_hessian_matches_irls_weight() {
2161        let w = Array1::from(vec![0.1, 0.5, 0.9, 0.25, 0.75]);
2162        let hess = DiagonalScalarRowHessian::new(w.clone());
2163        let full = hess.evaluate_full();
2164        assert_eq!(full.shape(), &[5, 1, 1]);
2165        for i in 0..5 {
2166            assert_eq!(full[[i, 0, 0]], w[i]);
2167            let mut buf = [0.0_f64; 1];
2168            hess.fill_row(i, &mut buf);
2169            assert_eq!(buf[0], w[i]);
2170        }
2171    }
2172
2173    /// §10 test #8: predict-path roundtrip. With the parametric setting,
2174    /// the row-application of `(C(x)·V − A(x)·M)` at training rows must
2175    /// equal the in-metric residual computed during `compile`.
2176    #[test]
2177    fn compiler_predict_path_roundtrip() {
2178        let n = 24;
2179        let a = Array2::from_shape_fn((n, 2), |(i, j)| (i as f64 * 0.21).cos() + j as f64);
2180        let b = Array2::from_shape_fn((n, 2), |(i, j)| {
2181            0.3 * a[[i, 0]] + (i as f64 + j as f64).sqrt()
2182        });
2183        let hess = IdentityRowHessian::new(n, 1);
2184        let ops = vec![op(a.clone()), op(b.clone())];
2185        let compiled = compile(&ops, &hess, &[BlockOrder::Marginal, BlockOrder::Logslope])
2186            .expect("compile should succeed");
2187        let v_b = &compiled.blocks[1].t_lw;
2188        let m_b = compiled.blocks[1].anchor_correction.as_ref().unwrap();
2189        // Training-time residual: B · V − A · M.
2190        let predict_design = b.dot(v_b) - a.dot(m_b);
2191        // Compare to the algebraic in-metric residual: same expression
2192        // (identity row Hessian collapses sqrt(H) = I), so this is a
2193        // self-consistency / shape check ensuring V and M compose to the
2194        // promised predict-time operator.
2195        assert_eq!(predict_design.nrows(), n);
2196        assert_eq!(predict_design.ncols(), v_b.ncols());
2197        // Finite-value gate.
2198        for &val in predict_design.iter() {
2199            assert!(val.is_finite(), "predict design produced non-finite entry");
2200        }
2201    }
2202
2203    /// `r_lw` and `anchor_correction` are populated on every non-first
2204    /// block as `M_b · V_b` at compiled width. The first block carries
2205    /// `None`. Also verifies the H-orthogonality invariant that the
2206    /// cumulative anchor for the next iteration is orthogonal (in the row
2207    /// metric) to the prior block's design.
2208    #[test]
2209    fn compile_exposes_r_lw_equal_to_m_dot_v() {
2210        let n = 40;
2211        let a = Array2::from_shape_fn((n, 2), |(i, j)| (i as f64 * 0.17 + j as f64).sin());
2212        // B partially aliases A's first column, so anchor correction is non-trivial.
2213        let b = Array2::from_shape_fn((n, 2), |(i, j)| {
2214            0.6 * a[[i, 0]] + ((i as f64) * 0.11 + j as f64).cos()
2215        });
2216        let hess = IdentityRowHessian::new(n, 1);
2217        let ops = vec![op(a.clone()), op(b.clone())];
2218        let compiled = compile(&ops, &hess, &[BlockOrder::Marginal, BlockOrder::Logslope])
2219            .expect("compile should succeed");
2220
2221        // First block: no anchor → both fields None.
2222        assert!(compiled.blocks[0].r_lw.is_none());
2223        assert!(compiled.blocks[0].anchor_correction.is_none());
2224
2225        // Second block: r_lw and anchor_correction must both equal M·V at
2226        // compiled width (p_a_kept × p_b_kept).
2227        let v_a = &compiled.blocks[0].t_lw;
2228        let v_b = &compiled.blocks[1].t_lw;
2229        let m_compiled = compiled.blocks[1]
2230            .anchor_correction
2231            .as_ref()
2232            .expect("second block must carry an anchor correction");
2233        let r_lw = compiled.blocks[1]
2234            .r_lw
2235            .as_ref()
2236            .expect("second block must expose r_lw");
2237        let p_a_kept = v_a.ncols();
2238        let p_b_kept = v_b.ncols();
2239        assert_eq!(
2240            m_compiled.dim(),
2241            (p_a_kept, p_b_kept),
2242            "anchor_correction must be at compiled width"
2243        );
2244        assert_eq!(r_lw.dim(), (p_a_kept, p_b_kept));
2245        // r_lw and anchor_correction are synonymous.
2246        let diff = r_lw - m_compiled;
2247        let max_diff = diff.iter().fold(0.0_f64, |acc, &x| acc.max(x.abs()));
2248        assert!(
2249            max_diff == 0.0,
2250            "r_lw and anchor_correction must be identical"
2251        );
2252
2253        // H-orthogonality (identity row metric): the residualised
2254        // compiled B-design `B·V − A·(M·V)` must be orthogonal to A in
2255        // the column-inner-product sense. This validates that the
2256        // cumulative anchor build uses `(W_b − A·M)·V` rather than `W_b·V`.
2257        let b_compiled = b.dot(v_b) - a.dot(m_compiled);
2258        let cross = a.t().dot(&b_compiled);
2259        let max_cross = cross.iter().fold(0.0_f64, |acc, &x| acc.max(x.abs()));
2260        assert!(
2261            max_cross < 1e-10,
2262            "compiled B-design must be H-orthogonal to A: max cross = {max_cross:e}"
2263        );
2264    }
2265
2266    /// `K=4` dense row Hessian: per-row PSD matrix supplied directly.
2267    struct DenseRowHessian {
2268        h: Array3<f64>,
2269    }
2270
2271    impl RowHessian for DenseRowHessian {
2272        fn k(&self) -> usize {
2273            self.h.shape()[1]
2274        }
2275        fn nrows(&self) -> usize {
2276            self.h.shape()[0]
2277        }
2278        fn fill_row(&self, row: usize, out: &mut [f64]) {
2279            let k = self.k();
2280            assert_eq!(out.len(), k * k);
2281            for c in 0..k {
2282                for d in 0..k {
2283                    out[c * k + d] = self.h[[row, c, d]];
2284                }
2285            }
2286        }
2287        fn evaluate_full(&self) -> Array3<f64> {
2288            self.h.clone()
2289        }
2290    }
2291
2292    /// Reference W-based Gram for verification: build `W = sqrt(H) · J` then
2293    /// return `Wᵀ W`. Mirrors the in-walk path in [`compile`].
2294    fn reference_gram_from_w(j_full: &Array3<f64>, h_full: &Array3<f64>) -> Array2<f64> {
2295        let w = scale_block_by_sqrt_h(j_full, h_full);
2296        fast_ata(&w)
2297    }
2298
2299    /// Two-block toy at K=4: build per-channel (n × p_b) blocks and verify
2300    /// the closed-form Gram matches the reference W-based Gram.
2301    #[test]
2302    fn closed_form_gram_matches_reference_two_block_k4() {
2303        let n = 17;
2304        let k = 4;
2305        let p_a = 3;
2306        let p_b = 2;
2307
2308        // Random-ish per-channel design matrices for each block.
2309        let make_block = |seed: f64, n: usize, p: usize| -> Vec<Option<Array2<f64>>> {
2310            (0..4)
2311                .map(|c| {
2312                    let m = Array2::from_shape_fn((n, p), |(i, j)| {
2313                        ((i as f64 + 1.0) * (j as f64 + 1.0) * (c as f64 + 1.0) + seed).sin()
2314                    });
2315                    Some(m)
2316                })
2317                .collect()
2318        };
2319        let block_a = make_block(0.3, n, p_a);
2320        let block_b = make_block(1.1, n, p_b);
2321
2322        // Per-row PSD H: random symmetric PSD via Mᵀ M.
2323        let h = Array3::from_shape_fn((n, k, k), |(i, c, d)| {
2324            let mut acc = 0.0;
2325            for r in 0..k {
2326                let mc = ((i + 1) as f64 * (c + 1) as f64 * (r + 1) as f64 * 0.13).cos();
2327                let md = ((i + 1) as f64 * (d + 1) as f64 * (r + 1) as f64 * 0.13).cos();
2328                acc += mc * md;
2329            }
2330            acc + if c == d { 0.5 } else { 0.0 }
2331        });
2332        let row_hess = DenseRowHessian { h: h.clone() };
2333
2334        let channel_blocks = PrimaryChannelBlocks {
2335            blocks: vec![block_a.clone(), block_b.clone()],
2336        };
2337        let raw_ranges = vec![0..p_a, p_a..(p_a + p_b)];
2338
2339        let gram = build_raw_grams_from_channel_blocks(&channel_blocks, &row_hess, &raw_ranges)
2340            .expect("closed-form Gram should succeed");
2341
2342        // Reference: assemble full row Jacobian J as (n × p_total × K) by
2343        // placing per-block, per-channel slices at the right columns.
2344        let p_total = p_a + p_b;
2345        let mut j_full = Array3::<f64>::zeros((n, p_total, k));
2346        for c in 0..k {
2347            if let Some(xa) = block_a[c].as_ref() {
2348                for i in 0..n {
2349                    for j in 0..p_a {
2350                        j_full[[i, j, c]] = xa[[i, j]];
2351                    }
2352                }
2353            }
2354            if let Some(xb) = block_b[c].as_ref() {
2355                for i in 0..n {
2356                    for j in 0..p_b {
2357                        j_full[[i, p_a + j, c]] = xb[[i, j]];
2358                    }
2359                }
2360            }
2361        }
2362        let ref_gram = reference_gram_from_w(&j_full, &h);
2363
2364        let diff = &gram - &ref_gram;
2365        let max_err = diff.iter().fold(0.0_f64, |acc, &v| acc.max(v.abs()));
2366        let scale = ref_gram.iter().fold(0.0_f64, |acc, &v| acc.max(v.abs()));
2367        assert!(
2368            max_err < 1e-9 * scale.max(1.0),
2369            "closed-form Gram mismatches reference: max_err={max_err:e}, scale={scale:e}"
2370        );
2371
2372        // Symmetry of the result.
2373        for i in 0..p_total {
2374            for j in 0..p_total {
2375                assert!(
2376                    (gram[[i, j]] - gram[[j, i]]).abs() < 1e-12,
2377                    "closed-form Gram not symmetric at ({i},{j})"
2378                );
2379            }
2380        }
2381    }
2382
2383    /// Channel sparsity test: block A contributes only to channel 0, block B
2384    /// only to channel 3. Cross-block contribution must be exactly
2385    /// `(X_A^(0))ᵀ · diag(h_{03}) · X_B^(3)` — zero when `h_03 ≡ 0`,
2386    /// non-zero otherwise.
2387    #[test]
2388    fn closed_form_gram_channel_sparsity() {
2389        let n = 13;
2390        let k = 4;
2391        let p_a = 2;
2392        let p_b = 2;
2393
2394        let xa = Array2::from_shape_fn((n, p_a), |(i, j)| ((i + 1) as f64 * 0.21 + j as f64).cos());
2395        let xb = Array2::from_shape_fn((n, p_b), |(i, j)| {
2396            ((i + 1) as f64 * 0.17 + j as f64).sin() + 0.5
2397        });
2398
2399        let block_a: Vec<Option<Array2<f64>>> = vec![Some(xa.clone()), None, None, None];
2400        let block_b: Vec<Option<Array2<f64>>> = vec![None, None, None, Some(xb.clone())];
2401
2402        // Case 1: H with non-zero h_{03} (and h_{30}). The cross-block
2403        // (A, B) entries must equal `Xaᵀ · diag(h_03) · Xb`.
2404        let h_03_vec = Array1::from_shape_fn(n, |i| 0.7 + 0.3 * ((i as f64) * 0.4).sin());
2405        let h = Array3::from_shape_fn((n, k, k), |(i, c, d)| {
2406            // Symmetric: only the (0,3)/(3,0) off-diagonal carries weight,
2407            // plus a strong PSD diagonal so per-row H is PSD.
2408            if (c, d) == (0, 3) || (c, d) == (3, 0) {
2409                h_03_vec[i]
2410            } else if c == d {
2411                2.0
2412            } else {
2413                0.0
2414            }
2415        });
2416        let row_hess = DenseRowHessian { h: h.clone() };
2417
2418        let channel_blocks = PrimaryChannelBlocks {
2419            blocks: vec![block_a.clone(), block_b.clone()],
2420        };
2421        let raw_ranges = vec![0..p_a, p_a..(p_a + p_b)];
2422        let gram = build_raw_grams_from_channel_blocks(&channel_blocks, &row_hess, &raw_ranges)
2423            .expect("closed-form Gram should succeed");
2424
2425        // Cross-block submatrix.
2426        let cross = gram.slice(s![0..p_a, p_a..(p_a + p_b)]).to_owned();
2427        // Expected: only the (c=0, d=3) channel-pair survives.
2428        let expected = fast_xt_diag_y(&xa, &h_03_vec, &xb);
2429        let diff = &cross - &expected;
2430        let max_err = diff.iter().fold(0.0_f64, |acc, &v| acc.max(v.abs()));
2431        assert!(
2432            max_err < 1e-12,
2433            "cross-block Gram must equal Xaᵀ·diag(h_03)·Xb: max_err={max_err:e}"
2434        );
2435
2436        // Case 2: zero out h_{03} → cross-block must be zero.
2437        let h_zero = Array3::from_shape_fn((n, k, k), |(_, c, d)| if c == d { 2.0 } else { 0.0 });
2438        let row_hess_zero = DenseRowHessian { h: h_zero };
2439        let gram_zero =
2440            build_raw_grams_from_channel_blocks(&channel_blocks, &row_hess_zero, &raw_ranges)
2441                .expect("closed-form Gram should succeed");
2442        let cross_zero = gram_zero.slice(s![0..p_a, p_a..(p_a + p_b)]);
2443        let max_zero = cross_zero.iter().fold(0.0_f64, |acc, &v| acc.max(v.abs()));
2444        assert!(
2445            max_zero < 1e-12,
2446            "cross-block Gram must vanish when coupling channel pair is zero: got {max_zero:e}"
2447        );
2448    }
2449
2450    /// Structural Gram: identity per-row Hessian collapses the channel-pair
2451    /// sum to within-channel `XᵀX`. Validates [`build_raw_grams_structural`].
2452    #[test]
2453    fn structural_gram_matches_within_channel_sum() {
2454        let n = 11;
2455        let p_a = 2;
2456        let p_b = 3;
2457        let make_block = |seed: f64, n: usize, p: usize| -> Vec<Option<Array2<f64>>> {
2458            (0..4)
2459                .map(|c| {
2460                    if c == 1 {
2461                        // Sparse channel for variety.
2462                        return None;
2463                    }
2464                    Some(Array2::from_shape_fn((n, p), |(i, j)| {
2465                        ((i as f64 + 1.0) * (j as f64 + 1.0) + seed * (c as f64 + 1.0)).sin()
2466                    }))
2467                })
2468                .collect()
2469        };
2470        let block_a = make_block(0.1, n, p_a);
2471        let block_b = make_block(0.7, n, p_b);
2472        let channel_blocks = PrimaryChannelBlocks {
2473            blocks: vec![block_a.clone(), block_b.clone()],
2474        };
2475        let raw_ranges = vec![0..p_a, p_a..(p_a + p_b)];
2476        let gram = build_raw_grams_structural(&channel_blocks, &raw_ranges);
2477
2478        // Hand-compute cross block: Σ_c Xaᵀ Xb over channels where both
2479        // sides are present (skipping channel 1 entirely).
2480        let mut expected_cross = Array2::<f64>::zeros((p_a, p_b));
2481        for c in 0..4 {
2482            if let (Some(xa), Some(xb)) = (block_a[c].as_ref(), block_b[c].as_ref()) {
2483                expected_cross += &fast_atb(xa, xb);
2484            }
2485        }
2486        let cross = gram.slice(s![0..p_a, p_a..(p_a + p_b)]).to_owned();
2487        let diff = &cross - &expected_cross;
2488        let max_err = diff.iter().fold(0.0_f64, |acc, &v| acc.max(v.abs()));
2489        assert!(
2490            max_err < 1e-12,
2491            "structural cross-block must equal Σ_c Xaᵀ·Xb: max_err={max_err:e}"
2492        );
2493
2494        // Symmetry.
2495        for i in 0..(p_a + p_b) {
2496            for j in 0..(p_a + p_b) {
2497                assert!(
2498                    (gram[[i, j]] - gram[[j, i]]).abs() < 1e-12,
2499                    "structural Gram not symmetric at ({i},{j})"
2500                );
2501            }
2502        }
2503    }
2504
2505    // Per-row Hessian (K=1) sourced from an arbitrary positive vector —
2506    // used by the dual-metric sanity test to drive both structural and
2507    // curvature passes with the *same* non-identity weights.
2508    fn diag_hess(w: Array1<f64>) -> DiagonalScalarRowHessian {
2509        DiagonalScalarRowHessian::new(w)
2510    }
2511
2512    /// L#1: dual-metric with structural = curvature reproduces single-metric
2513    /// `compile()` exactly. The two passes degenerate to one because the
2514    /// structural-anchor and curvature-anchor are the same matrix.
2515    #[test]
2516    fn dual_metric_with_equal_metrics_matches_single_metric() {
2517        let n = 36;
2518        let a = Array2::from_shape_fn((n, 2), |(i, j)| (i as f64 * 0.13 + j as f64).sin());
2519        // B partially aliases A's first column.
2520        let b = Array2::from_shape_fn((n, 2), |(i, j)| {
2521            0.4 * a[[i, 0]] + (i as f64 * 0.07 + j as f64).cos()
2522        });
2523        let w = Array1::from_shape_fn(n, |i| 0.5 + (i as f64 * 0.17).sin().abs());
2524        let curvature = diag_hess(w.clone());
2525        let ordering = [BlockOrder::Marginal, BlockOrder::Logslope];
2526
2527        let ops_single = vec![op(a.clone()), op(b.clone())];
2528        let single = compile(&ops_single, &curvature, &ordering)
2529            .expect("single-metric compile should succeed");
2530
2531        // Dual-metric with structural = curvature (same `RowHessian` on both
2532        // sides). The structural pass collapses to the curvature pass.
2533        let structural_same = diag_hess(w.clone());
2534        let ops_dual = vec![op(a.clone()), op(b.clone())];
2535        let dual = compile_with_dual_metric(&ops_dual, &curvature, &structural_same, &ordering)
2536            .expect("dual-metric compile should succeed");
2537
2538        assert_eq!(single.blocks.len(), dual.blocks.len());
2539        for (idx, (sb, db)) in single.blocks.iter().zip(dual.blocks.iter()).enumerate() {
2540            assert_eq!(sb.t_lw.dim(), db.t_lw.dim(), "block {idx}: V dims differ");
2541            let max_v = (&sb.t_lw - &db.t_lw)
2542                .iter()
2543                .fold(0.0_f64, |acc, &v| acc.max(v.abs()));
2544            assert!(max_v < 1e-10, "block {idx}: V mismatch {max_v:e}");
2545            match (sb.anchor_correction.as_ref(), db.anchor_correction.as_ref()) {
2546                (None, None) => {}
2547                (Some(s), Some(d)) => {
2548                    assert_eq!(s.dim(), d.dim());
2549                    let max_m = (s - d).iter().fold(0.0_f64, |acc, &v| acc.max(v.abs()));
2550                    assert!(max_m < 1e-10, "block {idx}: M mismatch {max_m:e}");
2551                }
2552                _ => panic!("block {idx}: one side has anchor correction, the other does not"),
2553            }
2554        }
2555        assert_eq!(single.joint_rank, dual.joint_rank);
2556    }
2557
2558    /// L#2: the pilot-curvature trap. A 2-block toy where the pilot
2559    /// curvature `H` has a zero direction that is NOT a real gauge — the
2560    /// dual-metric path keeps it (identity-structural sees it as a full-
2561    /// rank structural direction), while a single-metric path through the
2562    /// same H would drop it.
2563    ///
2564    /// Construction: two K=1 blocks `A` (n × 1) and `B` (n × 1). Choose H
2565    /// (diagonal row weights) so that `H · B` happens to be a scalar
2566    /// multiple of `H · A` (curvature alias) but `B` is *not* a scalar
2567    /// multiple of `A` in the unweighted metric. Specifically, pick rows
2568    /// where `w_i` is non-zero only on a handful of rows where A and B
2569    /// happen to be proportional, and zero on the rows where they differ.
2570    /// Under identity-structural this is structurally-independent; under H
2571    /// it is a (spurious) curvature alias.
2572    #[test]
2573    fn dual_metric_resists_pilot_curvature_alias() {
2574        let n = 12;
2575        // A: x_i = i+1 (no zeros). B: equals 2·A on rows 0..6 only; the
2576        // remaining rows are uncorrelated (linear vs trigonometric).
2577        let a = Array2::from_shape_fn((n, 1), |(i, _)| (i as f64) + 1.0);
2578        let b = Array2::from_shape_fn((n, 1), |(i, _)| {
2579            if i < 6 {
2580                2.0 * a[[i, 0]]
2581            } else {
2582                ((i as f64) * 0.3).cos() + 0.5
2583            }
2584        });
2585
2586        // Curvature weights are non-zero ONLY on the rows where B == 2A.
2587        // Under curvature metric, B is exactly 2·A → curvature-rank drops
2588        // B fully. Under identity-structural, B is independent of A across
2589        // all rows → structural-rank is 1 (kept).
2590        let mut w_vec = vec![0.0_f64; n];
2591        for w in &mut w_vec[..6] {
2592            *w = 1.0;
2593        }
2594        let w = Array1::from(w_vec);
2595        let curvature = diag_hess(w.clone());
2596
2597        // Reference single-metric compile (uses identity by `compile()` —
2598        // which now routes through identity-structural). For this test we
2599        // explicitly invoke the dual-metric API both ways.
2600        let id_struct = IdentityRowHessian::new(n, 1);
2601        let ordering = [BlockOrder::Marginal, BlockOrder::Logslope];
2602
2603        // Path 1: dual-metric with identity-structural (the new default).
2604        // Structural pass: B is independent of A across all rows → keep
2605        // B's single column.
2606        let ops_dual = vec![op(a.clone()), op(b.clone())];
2607        let dual = compile_with_dual_metric(&ops_dual, &curvature, &id_struct, &ordering);
2608
2609        // Path 2: dual-metric with structural = curvature (the "H decides
2610        // everything" trap). On the curvature-only rows, B ≡ 2A, so
2611        // structural pass sees zero residual span and rejects the block.
2612        let ops_h_only = vec![op(a.clone()), op(b.clone())];
2613        let h_only = compile_with_dual_metric(&ops_h_only, &curvature, &curvature, &ordering);
2614
2615        // The H-only path must fail (FullyAliased) or strip B's column.
2616        // The dual (identity-structural) path must keep B.
2617        match h_only {
2618            Err(CompilerError::FullyAliased { block_idx, .. }) => {
2619                assert_eq!(block_idx, 1, "H-only path must alias block 1");
2620            }
2621            Ok(out) => {
2622                // If the H-only path somehow compiled, it must have
2623                // either dropped B's column to zero width or audited it
2624                // out. Either way B's V must be empty after the audit
2625                // attributes the drop.
2626                let v1_cols = out.blocks[1].t_lw.ncols();
2627                assert!(
2628                    v1_cols == 0 || !out.dropped.is_empty(),
2629                    "H-only path should reject B's curvature-aliased column; v1_cols={v1_cols}, dropped={dropped:?}",
2630                    dropped = out.dropped,
2631                );
2632            }
2633            Err(other) => panic!("unexpected H-only error: {other:?}"),
2634        }
2635
2636        let dual =
2637            dual.expect("dual-metric must succeed: identity-structural sees B as independent");
2638        // The dual path may still drop B's column at the joint audit step
2639        // because the joint H-scaled design is rank-1 (only the first
2640        // block contributes non-zero rows under the curvature weights).
2641        // What matters is that the *structural* decision did NOT drop B
2642        // — verified by the structural pass not raising FullyAliased and
2643        // by B's `t_lw` having the full structural width before the audit
2644        // demotes it. After audit, B's V may shrink because the curvature
2645        // joint design is rank-deficient, and that is expected.
2646        assert_eq!(dual.blocks.len(), 2);
2647        assert_eq!(dual.blocks[0].t_lw.ncols(), 1, "A must keep its column");
2648        // Block 1 either keeps its structural rank-1 column or is audited
2649        // away by the joint H-rank check, but in either case the per-block
2650        // pre-audit width must reflect that the structural pass kept the
2651        // column (i.e. the function did not return FullyAliased).
2652        let v1_post_audit = dual.blocks[1].t_lw.ncols();
2653        let dropped_count = dual.dropped.len();
2654        assert_eq!(
2655            v1_post_audit + dropped_count,
2656            1,
2657            "structural pass kept B's column; audit may demote it but the pre-audit width was 1"
2658        );
2659    }
2660
2661    /// L#3: identity-structural lets the compiler keep a direction even
2662    /// when the pilot curvature has reduced rank. This is the same
2663    /// scenario as L#2 but with a curvature `H` whose row weights are all
2664    /// strictly positive — so the *only* aliasing source is the structural
2665    /// pass deciding to keep or drop. The dual-metric path with non-trivial
2666    /// `H` and identity-structural must agree with the dual-metric path
2667    /// with identity on both sides whenever the blocks are structurally
2668    /// non-aliased.
2669    #[test]
2670    fn dual_metric_identity_structural_preserves_full_rank() {
2671        let n = 24;
2672        let a = Array2::from_shape_fn((n, 2), |(i, j)| ((i + 1) as f64 + j as f64).sqrt());
2673        let b = Array2::from_shape_fn((n, 2), |(i, j)| {
2674            ((i + 1) as f64).ln() + (i as f64 * 0.1 + j as f64).cos()
2675        });
2676        let w = Array1::from_shape_fn(n, |i| 0.4 + (i as f64 * 0.05).sin().powi(2));
2677        let curvature = diag_hess(w.clone());
2678        let id_struct = IdentityRowHessian::new(n, 1);
2679        let ordering = [BlockOrder::Marginal, BlockOrder::Logslope];
2680
2681        let ops = vec![op(a.clone()), op(b.clone())];
2682        let out =
2683            compile_with_dual_metric(&ops, &curvature, &id_struct, &ordering).expect("compile");
2684        // Both blocks structurally independent → both keep full width.
2685        assert_eq!(out.blocks[0].t_lw.ncols(), 2);
2686        assert_eq!(out.blocks[1].t_lw.ncols(), 2);
2687        assert_eq!(out.dropped.len(), 0);
2688        assert_eq!(out.joint_rank, 4);
2689    }
2690
2691    /// Smoke test for the GPU-or-CPU dispatch helper. On non-CUDA hosts
2692    /// (or when the runtime is unavailable) the helper falls back to the
2693    /// CPU closed-form builders; the result must match the CPU builders
2694    /// called directly. When a CUDA runtime is live, parity vs. CPU is
2695    /// verified to tight tolerance.
2696    #[test]
2697    fn build_primary_grams_gpu_or_cpu_two_block_k4_matches_cpu() {
2698        let n = 11;
2699        let k = 4;
2700        let p_a = 2;
2701        let p_b = 3;
2702
2703        let make_block = |seed: f64, n: usize, p: usize| -> Vec<Option<Array2<f64>>> {
2704            (0..4)
2705                .map(|c| {
2706                    let m = Array2::from_shape_fn((n, p), |(i, j)| {
2707                        ((i as f64 + 1.0) * (j as f64 + 1.0) * (c as f64 + 1.0) + seed).sin()
2708                    });
2709                    Some(m)
2710                })
2711                .collect()
2712        };
2713        let block_a = make_block(0.7, n, p_a);
2714        let block_b = make_block(-0.4, n, p_b);
2715
2716        let h = Array3::from_shape_fn((n, k, k), |(i, c, d)| {
2717            let mut acc = 0.0;
2718            for r in 0..k {
2719                let mc = ((i + 1) as f64 * (c + 1) as f64 * (r + 1) as f64 * 0.11).cos();
2720                let md = ((i + 1) as f64 * (d + 1) as f64 * (r + 1) as f64 * 0.11).cos();
2721                acc += mc * md;
2722            }
2723            acc + if c == d { 0.25 } else { 0.0 }
2724        });
2725        let row_hess = DenseRowHessian { h: h.clone() };
2726
2727        let channel_blocks = PrimaryChannelBlocks {
2728            blocks: vec![block_a, block_b],
2729        };
2730        let raw_ranges = vec![0..p_a, p_a..(p_a + p_b)];
2731
2732        let (gram_h, gram_struct) =
2733            build_primary_grams_gpu_or_cpu(&channel_blocks, &row_hess, &raw_ranges)
2734                .expect("dispatch helper should succeed");
2735
2736        let cpu_h = build_raw_grams_from_channel_blocks(&channel_blocks, &row_hess, &raw_ranges)
2737            .expect("CPU curvature Gram should succeed");
2738        let cpu_s = build_raw_grams_structural(&channel_blocks, &raw_ranges);
2739
2740        let tol = 1e-9_f64;
2741        for idx in cpu_h.indexed_iter().map(|(i, _)| i) {
2742            let diff = (gram_h[idx] - cpu_h[idx]).abs();
2743            let scale = cpu_h[idx].abs().max(1.0);
2744            assert!(
2745                diff <= tol * scale,
2746                "gram_h mismatch at {idx:?}: helper={} cpu={}",
2747                gram_h[idx],
2748                cpu_h[idx]
2749            );
2750        }
2751        for idx in cpu_s.indexed_iter().map(|(i, _)| i) {
2752            let diff = (gram_struct[idx] - cpu_s[idx]).abs();
2753            let scale = cpu_s[idx].abs().max(1.0);
2754            assert!(
2755                diff <= tol * scale,
2756                "gram_struct mismatch at {idx:?}: helper={} cpu={}",
2757                gram_struct[idx],
2758                cpu_s[idx]
2759            );
2760        }
2761    }
2762
2763    // ---- compile_from_raw_grams tests ----
2764
2765    /// Build (gram_h, gram_struct) for a K=1 scalar two-block toy via the
2766    /// per-block channel-block builders. Used by the closed-form tests
2767    /// below.
2768    fn scalar_grams_two_block(
2769        a: &Array2<f64>,
2770        b: &Array2<f64>,
2771        w: &Array1<f64>,
2772    ) -> (Array2<f64>, Array2<f64>, Vec<std::ops::Range<usize>>) {
2773        let p_a = a.ncols();
2774        let p_b = b.ncols();
2775        let raw_ranges = vec![0..p_a, p_a..(p_a + p_b)];
2776        let channel_blocks = PrimaryChannelBlocks {
2777            blocks: vec![vec![Some(a.clone())], vec![Some(b.clone())]],
2778        };
2779        let row_hess = DiagonalScalarRowHessian::new(w.clone());
2780        let gram_h =
2781            build_raw_grams_from_channel_blocks(&channel_blocks, &row_hess, &raw_ranges).unwrap();
2782        let gram_struct = build_raw_grams_structural(&channel_blocks, &raw_ranges);
2783        (gram_h, gram_struct, raw_ranges)
2784    }
2785
2786    /// Block B is a column-duplicate of block A in the structural metric
2787    /// → the lower-priority block compiles to zero width instead of making
2788    /// callers skip reduced-coordinate construction.
2789    #[test]
2790    fn compile_from_raw_grams_full_structural_alias() {
2791        let n = 10;
2792        let a = Array2::from_shape_fn((n, 2), |(i, j)| ((i + 1) as f64 * (j + 1) as f64).sin());
2793        // Block B = A · L for some 2×2 invertible L → same column span.
2794        let l = Array2::from_shape_vec((2, 2), vec![1.0, 0.5, -0.25, 1.0]).unwrap();
2795        let b = a.dot(&l);
2796        let w = Array1::ones(n);
2797        let (gram_h, gram_struct, raw_ranges) = scalar_grams_two_block(&a, &b, &w);
2798        let res = compile_from_raw_grams(
2799            &gram_h,
2800            &gram_struct,
2801            &raw_ranges,
2802            &[BlockOrder::Marginal, BlockOrder::Logslope],
2803        )
2804        .expect("lower-priority full alias should compile to zero width");
2805        assert_eq!(res.compiled_block_ranges[0].len(), 2);
2806        assert_eq!(res.compiled_block_ranges[1].len(), 0);
2807        assert_eq!(res.raw_from_compiled.dim(), (4, 2));
2808        assert!(
2809            res.raw_from_compiled
2810                .slice(s![raw_ranges[1].clone(), ..])
2811                .iter()
2812                .all(|v| v.abs() <= 1.0e-12),
2813            "zero-width block must not retain raw coefficient directions in T"
2814        );
2815    }
2816
2817    /// A zero-width *first* block has no columns to alias and must compile to
2818    /// an empty range with the remaining blocks intact — not abort with
2819    /// `FullyAliased`. Regression for the survival location-scale lognormal AFT
2820    /// pre-fit channel-aware audit, whose `time_transform` block collapses to
2821    /// zero free coefficients under the parametric AFT reduction and previously
2822    /// crashed the fit ("block of width 0 has zero structural span").
2823    #[test]
2824    fn compile_from_raw_grams_zero_width_first_block_is_identifiable() {
2825        let n = 12;
2826        let empty = Array2::<f64>::zeros((n, 0));
2827        let b = Array2::from_shape_fn((n, 2), |(i, j)| {
2828            ((i + 1) as f64 * (j + 1) as f64 * 0.23).cos()
2829        });
2830        let w = Array1::ones(n);
2831        let (gram_h, gram_struct, raw_ranges) = scalar_grams_two_block(&empty, &b, &w);
2832        let map = compile_from_raw_grams(
2833            &gram_h,
2834            &gram_struct,
2835            &raw_ranges,
2836            &[BlockOrder::Marginal, BlockOrder::Logslope],
2837        )
2838        .expect("zero-width first block must be trivially identifiable, not FullyAliased");
2839        assert_eq!(
2840            map.compiled_block_ranges[0].len(),
2841            0,
2842            "empty first block keeps zero columns"
2843        );
2844        assert_eq!(
2845            map.compiled_block_ranges[1].len(),
2846            2,
2847            "the second block keeps its full structural rank"
2848        );
2849        assert_eq!(map.raw_from_compiled.dim(), (2, 2));
2850    }
2851
2852    #[test]
2853    fn orthogonalization_annotates_independent_and_fully_absorbed_blocks() {
2854        let n = 18;
2855        let anchor = Array2::from_shape_fn((n, 2), |(i, j)| {
2856            ((i + 1) as f64 * (0.19 + j as f64 * 0.07)).sin()
2857        });
2858        let duplicate = anchor.clone();
2859        let independent = Array2::from_shape_fn((n, 1), |(i, _)| ((i + 1) as f64 * 0.43).cos());
2860        let weight = vec![1.0; n];
2861        let ortho = orthogonalize_design_blocks(
2862            &[anchor, duplicate, independent],
2863            &[200, 100, 50],
2864            &weight,
2865        )
2866        .expect("structural annotation compile");
2867
2868        assert_eq!(
2869            ortho.direction_annotations[0].kind,
2870            PenalizedDirectionAnnotationKind::Independent
2871        );
2872        assert_eq!(ortho.direction_annotations[0].absorbed_width, 0);
2873        assert_eq!(
2874            ortho.direction_annotations[1].kind,
2875            PenalizedDirectionAnnotationKind::FullyAbsorbedByHigherPriority,
2876            "a duplicated lower-priority block is the same realized-design direction"
2877        );
2878        assert_eq!(ortho.direction_annotations[1].raw_width, 2);
2879        assert_eq!(ortho.direction_annotations[1].kept_width, 0);
2880        assert_eq!(ortho.direction_annotations[1].absorbed_width, 2);
2881        assert_eq!(
2882            ortho.direction_annotations[2].kind,
2883            PenalizedDirectionAnnotationKind::Independent,
2884            "a genuinely new realized-design direction keeps its own penalty block"
2885        );
2886        assert_eq!(ortho.direction_annotations[2].raw_width, 1);
2887        assert_eq!(ortho.direction_annotations[2].kept_width, 1);
2888        assert_eq!(ortho.dropped, vec![(1, 2)]);
2889    }
2890
2891    #[test]
2892    fn compile_from_raw_grams_three_block_full_logslope_alias_keeps_fast_path() {
2893        let n = 24;
2894        let time = Array2::from_shape_fn((n, 2), |(i, j)| {
2895            ((i + 1) as f64 * (j + 2) as f64 * 0.17).sin()
2896        });
2897        let marginal = Array2::from_shape_fn((n, 1), |(i, _)| ((i + 3) as f64 * 0.11).cos());
2898        let logslope = marginal.clone();
2899        let p_time = time.ncols();
2900        let p_marg = marginal.ncols();
2901        let p_log = logslope.ncols();
2902        let raw_ranges = vec![
2903            0..p_time,
2904            p_time..(p_time + p_marg),
2905            (p_time + p_marg)..(p_time + p_marg + p_log),
2906        ];
2907        let channel_blocks = PrimaryChannelBlocks {
2908            blocks: vec![
2909                vec![Some(time.clone())],
2910                vec![Some(marginal.clone())],
2911                vec![Some(logslope.clone())],
2912            ],
2913        };
2914        let row_hess = DiagonalScalarRowHessian::new(Array1::ones(n));
2915        let gram_h =
2916            build_raw_grams_from_channel_blocks(&channel_blocks, &row_hess, &raw_ranges).unwrap();
2917        let gram_struct = build_raw_grams_structural(&channel_blocks, &raw_ranges);
2918
2919        let map = compile_from_raw_grams(
2920            &gram_h,
2921            &gram_struct,
2922            &raw_ranges,
2923            &[BlockOrder::Time, BlockOrder::Marginal, BlockOrder::Logslope],
2924        )
2925        .expect("fully aliased logslope block should not skip the compiled-map path");
2926
2927        assert_eq!(map.compiled_block_ranges[0].len(), p_time);
2928        assert_eq!(map.compiled_block_ranges[1].len(), p_marg);
2929        assert_eq!(map.compiled_block_ranges[2].len(), 0);
2930        assert_eq!(
2931            map.raw_from_compiled.dim(),
2932            (p_time + p_marg + p_log, p_time + p_marg)
2933        );
2934        let x_raw = {
2935            let mut out = Array2::<f64>::zeros((n, p_time + p_marg + p_log));
2936            out.slice_mut(s![.., raw_ranges[0].clone()]).assign(&time);
2937            out.slice_mut(s![.., raw_ranges[1].clone()])
2938                .assign(&marginal);
2939            out.slice_mut(s![.., raw_ranges[2].clone()])
2940                .assign(&logslope);
2941            out
2942        };
2943        let x_compiled = fast_ab(&x_raw, &map.raw_from_compiled);
2944        let rrqr = rrqr_with_permutation(&x_compiled, default_rrqr_rank_alpha()).unwrap();
2945        assert_eq!(rrqr.rank, x_compiled.ncols());
2946    }
2947
2948    /// Partial alias: block B's first column duplicates A; second column is
2949    /// independent. Closed-form `T` must have shape `(p_raw × (p_a + 1))`
2950    /// — block 1's compiled width is exactly the independent direction —
2951    /// and the joint design `X_raw · T` must span the same column space as
2952    /// the W-based reference compile result.
2953    #[test]
2954    fn compile_from_raw_grams_partial_alias_matches_w_reference() {
2955        let n = 25;
2956        let a = Array2::from_shape_fn((n, 2), |(i, j)| {
2957            ((i + 1) as f64 * (j + 1) as f64 * 0.3).sin()
2958        });
2959        // B = [a_0  +  independent]
2960        let mut b = Array2::<f64>::zeros((n, 2));
2961        for i in 0..n {
2962            b[[i, 0]] = a[[i, 0]];
2963            b[[i, 1]] = ((i + 1) as f64 * 0.7).cos();
2964        }
2965        let w = Array1::from_shape_fn(n, |i| 1.0 + 0.1 * (i as f64));
2966        let (gram_h, gram_struct, raw_ranges) = scalar_grams_two_block(&a, &b, &w);
2967        let compiled = compile_from_raw_grams(
2968            &gram_h,
2969            &gram_struct,
2970            &raw_ranges,
2971            &[BlockOrder::Marginal, BlockOrder::Logslope],
2972        )
2973        .expect("closed-form compile must succeed");
2974        let p_a = a.ncols();
2975        let p_b = b.ncols();
2976        assert_eq!(compiled.raw_from_compiled.shape()[0], p_a + p_b);
2977        assert_eq!(
2978            compiled.raw_from_compiled.shape()[1],
2979            p_a + 1,
2980            "partial alias should leave compiled width = p_a + 1 (one column dropped from B)"
2981        );
2982        // Block ranges sum to compiled width.
2983        assert_eq!(compiled.compiled_block_ranges[0], 0..p_a);
2984        assert_eq!(
2985            compiled.compiled_block_ranges[1].end - compiled.compiled_block_ranges[1].start,
2986            1
2987        );
2988
2989        // Column-span equality vs. W-reference: stack the raw design
2990        // X_raw = [A | B] and check that range(X_raw · T) ⊆ range(X_raw)
2991        // and has the same rank as the W-based compile.
2992        let mut x_raw = Array2::<f64>::zeros((n, p_a + p_b));
2993        for i in 0..n {
2994            for j in 0..p_a {
2995                x_raw[[i, j]] = a[[i, j]];
2996            }
2997            for j in 0..p_b {
2998                x_raw[[i, p_a + j]] = b[[i, j]];
2999            }
3000        }
3001        let x_compiled = fast_ab(&x_raw, &compiled.raw_from_compiled);
3002        // Rank of compiled design via Gram eigvals.
3003        let g_compiled = fast_ata(&x_compiled);
3004        let (evals, _) = g_compiled.eigh(Side::Lower).unwrap();
3005        let lam_max = evals.iter().cloned().fold(0.0_f64, f64::max);
3006        let tol = lam_max * 64.0 * (g_compiled.nrows() as f64) * f64::EPSILON;
3007        let rank_compiled = evals.iter().filter(|&&l| l > tol).count();
3008        assert_eq!(
3009            rank_compiled,
3010            p_a + 1,
3011            "compiled design column rank must equal p_a + 1 after dropping the alias"
3012        );
3013
3014        // Reference compile via the W-based dual-metric path on the same
3015        // scalar blocks; compiled total width should also be p_a + 1.
3016        let ops_dual: Vec<Arc<dyn RowJacobianOperator>> = vec![op(a.clone()), op(b.clone())];
3017        let curvature = DiagonalScalarRowHessian::new(w.clone());
3018        let id_struct = IdentityRowHessian::new(n, 1);
3019        let dual = compile_with_dual_metric(
3020            &ops_dual,
3021            &curvature,
3022            &id_struct,
3023            &[BlockOrder::Marginal, BlockOrder::Logslope],
3024        )
3025        .expect("dual metric compile should succeed");
3026        let dual_total: usize = dual.blocks.iter().map(|b| b.t_lw.ncols()).sum();
3027        assert_eq!(dual_total, p_a + 1, "W-reference total width should match");
3028    }
3029
3030    /// Three-block toy: changing the ordering changes the per-block
3031    /// compiled widths (later blocks absorb the alias instead of earlier).
3032    #[test]
3033    fn compile_from_raw_grams_three_block_ordering_matters() {
3034        let n = 30;
3035        let a = Array2::from_shape_fn((n, 2), |(i, j)| {
3036            ((i + 1) as f64 * (j + 2) as f64 * 0.2).sin()
3037        });
3038        // B has 2 cols: col 0 independent, col 1 = a[:, 0]
3039        let mut b = Array2::<f64>::zeros((n, 2));
3040        for i in 0..n {
3041            b[[i, 0]] = ((i + 1) as f64 * 0.4).cos();
3042            b[[i, 1]] = a[[i, 0]];
3043        }
3044        // C has 2 cols: col 0 independent, col 1 = a[:, 1]
3045        let mut c = Array2::<f64>::zeros((n, 2));
3046        for i in 0..n {
3047            c[[i, 0]] = ((i + 1) as f64 * 0.55).sin();
3048            c[[i, 1]] = a[[i, 1]];
3049        }
3050        let w = Array1::ones(n);
3051
3052        let build = |b0: &Array2<f64>, b1: &Array2<f64>, b2: &Array2<f64>| {
3053            let raw_ranges = vec![
3054                0..b0.ncols(),
3055                b0.ncols()..(b0.ncols() + b1.ncols()),
3056                (b0.ncols() + b1.ncols())..(b0.ncols() + b1.ncols() + b2.ncols()),
3057            ];
3058            let channel_blocks = PrimaryChannelBlocks {
3059                blocks: vec![
3060                    vec![Some(b0.clone())],
3061                    vec![Some(b1.clone())],
3062                    vec![Some(b2.clone())],
3063                ],
3064            };
3065            let row_hess = DiagonalScalarRowHessian::new(w.clone());
3066            let gram_h =
3067                build_raw_grams_from_channel_blocks(&channel_blocks, &row_hess, &raw_ranges)
3068                    .unwrap();
3069            let gram_struct = build_raw_grams_structural(&channel_blocks, &raw_ranges);
3070            (gram_h, gram_struct, raw_ranges)
3071        };
3072
3073        // Order 1: A, B, C — B drops 1 (col 1 aliased to A), C drops 1.
3074        let (gh, gs, rr) = build(&a, &b, &c);
3075        let order_abc = compile_from_raw_grams(
3076            &gh,
3077            &gs,
3078            &rr,
3079            &[
3080                BlockOrder::Marginal,
3081                BlockOrder::Logslope,
3082                BlockOrder::LinkDev,
3083            ],
3084        )
3085        .expect("ABC compile");
3086        assert_eq!(order_abc.compiled_block_ranges[0].len(), 2);
3087        assert_eq!(order_abc.compiled_block_ranges[1].len(), 1);
3088        assert_eq!(order_abc.compiled_block_ranges[2].len(), 1);
3089
3090        // Order 2: B, A, C — A's col 0 is aliased by B's col 1 now; A's
3091        // col 1 is independent. So A drops 1; C still drops 1.
3092        let (gh2, gs2, rr2) = build(&b, &a, &c);
3093        let order_bac = compile_from_raw_grams(
3094            &gh2,
3095            &gs2,
3096            &rr2,
3097            &[
3098                BlockOrder::Marginal,
3099                BlockOrder::Logslope,
3100                BlockOrder::LinkDev,
3101            ],
3102        )
3103        .expect("BAC compile");
3104        assert_eq!(order_bac.compiled_block_ranges[0].len(), 2);
3105        assert_eq!(order_bac.compiled_block_ranges[1].len(), 1);
3106        // Total rank invariant under permutation: 4.
3107        let total_abc: usize = order_abc
3108            .compiled_block_ranges
3109            .iter()
3110            .map(|r| r.len())
3111            .sum();
3112        let total_bac: usize = order_bac
3113            .compiled_block_ranges
3114            .iter()
3115            .map(|r| r.len())
3116            .sum();
3117        assert_eq!(total_abc, total_bac);
3118        assert_eq!(total_abc, 4);
3119    }
3120
3121    /// Build a K=1 raw `(gram_h, gram_struct)` pair for a single stacked design
3122    /// `X` with per-row curvature weights `w`: `gram_struct = Xᵀ X`,
3123    /// `gram_h = Xᵀ diag(w) X`. Mirrors the closed-form definitions the
3124    /// production Gram builders implement for the scalar-channel case.
3125    fn k1_grams(x: &Array2<f64>, w: &Array1<f64>) -> (Array2<f64>, Array2<f64>) {
3126        let gram_struct = fast_atb(x, x);
3127        let xw = fast_xt_diag_y(x, w, x);
3128        (xw, gram_struct)
3129    }
3130
3131    /// Full-rank reduction: when the two blocks are jointly independent the
3132    /// compiled width equals the raw width and the lift `T` reproduces a raw
3133    /// coefficient exactly from its compiled image `θ = T⁺ β` (here, with no
3134    /// aliasing, `lift_coefficients(θ)` of any compiled `θ` lands in the raw
3135    /// design's column interpretation: applying `T` then comparing the induced
3136    /// raw predictor `X·Tθ` to `X·β_raw` for the `θ` solving `Tθ=β_raw`).
3137    #[test]
3138    fn compiled_map_lift_coefficients_roundtrips_full_rank() {
3139        let n = 21;
3140        let p_a = 2;
3141        let p_b = 2;
3142        // Distinct per-column frequencies make the four sinusoidal columns
3143        // genuinely linearly independent over the sample grid. (A shared phase
3144        // offset varying only by column would collapse every column into
3145        // span{sin θ, cos θ, 1}, i.e. rank 3, and the compiler would correctly
3146        // absorb a column — defeating the full-rank premise of this test.)
3147        let x = Array2::from_shape_fn((n, p_a + p_b), |(i, j)| {
3148            ((i as f64 + 1.0) * (0.21 + 0.17 * j as f64)).sin() + 0.11 * (j as f64)
3149        });
3150        let w = Array1::from_shape_fn(n, |i| 0.5 + 0.5 * ((i as f64) * 0.3).cos().abs());
3151        let (gh, gs) = k1_grams(&x, &w);
3152        let raw_ranges = vec![0..p_a, p_a..(p_a + p_b)];
3153        let map = compile_from_raw_grams(
3154            &gh,
3155            &gs,
3156            &raw_ranges,
3157            &[BlockOrder::Marginal, BlockOrder::Logslope],
3158        )
3159        .expect("full-rank compile");
3160        // Jointly independent ⇒ no columns absorbed.
3161        assert_eq!(map.p_compiled(), p_a + p_b);
3162        assert_eq!(map.p_raw(), p_a + p_b);
3163        // For a target raw coefficient, solve T θ = β_raw (T square invertible
3164        // here) and confirm lift_coefficients(θ) == β_raw.
3165        let beta_raw = Array1::from_shape_fn(p_a + p_b, |j| 0.4 * (j as f64) - 0.7);
3166        // T is (p × p); recover θ by a least-squares solve via the normal
3167        // equations TᵀT θ = Tᵀ β.
3168        let tt = fast_atb(&map.raw_from_compiled, &map.raw_from_compiled);
3169        let tb = map.raw_from_compiled.t().dot(&beta_raw);
3170        let theta = solve_psd_system(&tt, &tb.insert_axis(Axis(1)))
3171            .expect("normal-equation solve")
3172            .column(0)
3173            .to_owned();
3174        let lifted = map.lift_coefficients(&theta).expect("lift");
3175        let max_err = (&lifted - &beta_raw)
3176            .iter()
3177            .fold(0.0_f64, |a, &v| a.max(v.abs()));
3178        assert!(
3179            max_err < 1e-8,
3180            "lift round-trip error {max_err:e} (full-rank reduction must be exactly invertible)"
3181        );
3182    }
3183
3184    /// Design reparameterisation exactness: the compiled design predicts
3185    /// identically to the raw design on every lifted coefficient, i.e.
3186    /// `X_compiled · θ == X_raw · (T θ)`. This is the contract that lets a
3187    /// family fit in reduced coordinates and still produce raw-design
3188    /// predictions.
3189    #[test]
3190    fn compiled_map_reduce_design_matches_lifted_raw_predictor() {
3191        let n = 23;
3192        let p_a = 3;
3193        let p_b = 3;
3194        let mut x = Array2::from_shape_fn((n, p_a + p_b), |(i, j)| {
3195            ((i as f64 + 1.0) * 0.41 + (j as f64 + 1.0) * 0.7).sin() + 0.05 * (i % 3) as f64
3196        });
3197        // Alias one B column onto an A column so the reduction is non-trivial.
3198        for i in 0..n {
3199            x[[i, p_a + 1]] = x[[i, 1]];
3200        }
3201        let w = Array1::from_shape_fn(n, |i| 0.6 + 0.4 * ((i as f64) * 0.25).cos().abs());
3202        let (gh, gs) = k1_grams(&x, &w);
3203        let raw_ranges = vec![0..p_a, p_a..(p_a + p_b)];
3204        let map = compile_from_raw_grams(
3205            &gh,
3206            &gs,
3207            &raw_ranges,
3208            &[BlockOrder::Marginal, BlockOrder::Logslope],
3209        )
3210        .expect("compile");
3211        let x_compiled = map.reduce_design(&x).expect("reduce_design");
3212        assert_eq!(x_compiled.ncols(), map.p_compiled());
3213        let theta = Array1::from_shape_fn(map.p_compiled(), |j| 0.3 * (j as f64) - 0.5);
3214        let pred_compiled = x_compiled.dot(&theta);
3215        let beta_raw = map.lift_coefficients(&theta).expect("lift");
3216        let pred_raw = x.dot(&beta_raw);
3217        let max_err = (&pred_compiled - &pred_raw)
3218            .iter()
3219            .fold(0.0_f64, |a, &v| a.max(v.abs()));
3220        assert!(
3221            max_err < 1e-9,
3222            "compiled-design predictor diverges from lifted raw predictor: {max_err:e}"
3223        );
3224    }
3225
3226    /// Penalty-energy preservation: the reduced penalty `Tᵀ Ŝ_b T` reproduces
3227    /// the raw penalty energy `βᵀ Ŝ_b β` on every lifted point `β = T θ`. This
3228    /// is the exactness contract the lift map must satisfy for REML/inference
3229    /// to be invariant to the quotient reparameterisation.
3230    #[test]
3231    fn reduce_penalties_with_map_preserves_energy_on_lift() {
3232        let n = 19;
3233        let p_a = 3;
3234        let p_b = 2;
3235        // Make block B partly aliased with A so the reduction actually drops a
3236        // column — the penalty reduction must still preserve energy on the
3237        // surviving compiled directions.
3238        let mut x = Array2::from_shape_fn((n, p_a + p_b), |(i, j)| {
3239            ((i as f64 + 1.0) * 0.29 + (j as f64 + 1.0) * 0.9).cos()
3240        });
3241        // Column (p_a+0) := column 0 (exact alias) ⇒ B loses one direction.
3242        for i in 0..n {
3243            x[[i, p_a]] = x[[i, 0]];
3244        }
3245        let w = Array1::from_shape_fn(n, |i| 0.7 + 0.3 * ((i as f64) * 0.2).sin().abs());
3246        let (gh, gs) = k1_grams(&x, &w);
3247        let raw_ranges = vec![0..p_a, p_a..(p_a + p_b)];
3248        let map = compile_from_raw_grams(
3249            &gh,
3250            &gs,
3251            &raw_ranges,
3252            &[BlockOrder::Marginal, BlockOrder::Logslope],
3253        )
3254        .expect("compile with alias");
3255        assert!(
3256            map.p_compiled() < p_a + p_b,
3257            "expected at least one absorbed column, got p_compiled={}",
3258            map.p_compiled()
3259        );
3260        // A simple per-block raw penalty: ridge on each block.
3261        let s_a = Array2::<f64>::eye(p_a);
3262        let s_b = Array2::<f64>::eye(p_b);
3263        let reduced = reduce_penalties_with_map(&map, &[Some(s_a.clone()), Some(s_b.clone())])
3264            .expect("reduce penalties");
3265        // For random compiled θ, raw β = T θ. Raw energy for block b is
3266        // β[range_b]ᵀ S_b β[range_b]; reduced energy is θᵀ S_reduced_b θ.
3267        let theta = Array1::from_shape_fn(map.p_compiled(), |j| {
3268            0.6 * (j as f64) - 0.3 + 0.05 * (j % 2) as f64
3269        });
3270        let beta = map.lift_coefficients(&theta).expect("lift");
3271        for (block_idx, s_raw) in [(0usize, &s_a), (1usize, &s_b)] {
3272            let range = &map.raw_block_ranges[block_idx];
3273            let beta_b = beta.slice(s![range.start..range.end]).to_owned();
3274            let raw_energy = beta_b.dot(&s_raw.dot(&beta_b));
3275            let s_reduced = reduced[block_idx]
3276                .as_ref()
3277                .expect("reduced penalty present");
3278            let reduced_energy = theta.dot(&s_reduced.dot(&theta));
3279            assert!(
3280                (raw_energy - reduced_energy).abs() < 1e-8 * raw_energy.abs().max(1.0),
3281                "block {block_idx} energy mismatch: raw={raw_energy:e} reduced={reduced_energy:e}"
3282            );
3283        }
3284    }
3285}