Skip to main content

gam_sae/
identifiability.rs

1//! SAE identifiability primitives and partial-supervision gauge fixing.
2//!
3//! # Object 4 — the Certificate ([`residual_gauge`])
4//!
//! The partial-supervision solver ([`partial_supervision_solve`], defined
//! later in this file) *removes* gauge freedom by aligning to auxiliary
//! supervision. The certificate answers the dual question: after a fit
//! has converged, **which gauge group is the model identified up to?** It does
8//! so by running the same penalty-aware RRQR rank machinery the cross-block
9//! identifiability audit uses
10//! ([`gam_identifiability::audit::audit_identifiability`] /
11//! [`gam_linalg::faer_ndarray::rrqr_with_permutation`]) — but on the
12//! **symmetry generators** of the fitted model rather than on stacked design
13//! columns.
14//!
15//! Each candidate symmetry of the SAE-manifold model (an isometry of an atom's
16//! latent manifold, a rotation inside an ARD-equal eigenspace, a rotation of the
17//! decoder output frame, an exchange of two topology-identical atoms) is
18//! realised as a **tangent direction** `ξ` in the model's free-parameter space.
19//! A generator is an *unpinned residual gauge freedom* iff the converged
20//! objective is flat along it — i.e. `ξ` lies in the null space of the total
21//! curvature operator `H = H_data + H_isometry` (data/likelihood curvature plus
22//! the isometry-penalty curvature). It is *pinned* (broken by the data or the
23//! isometry penalty) iff `ξ` has a component in `range(H)`.
24//!
25//! The RRQR supplies the pinning RANK via the same penalty-aware,
26//! leverage-scaled rank decision the audit uses. Each generator's verdict,
27//! however, keeps the curvature **magnitudes**: the relative curvature
28//! fraction `‖R ξ̂‖² / σ_max(R)²` measures how much objective curvature the
29//! unit generator carries, relative to the model's stiffest direction. A
30//! generator is **unpinned** iff that fraction is within the calibrated
31//! tolerance `max(`[`GENERATOR_FLAT_ENERGY_TOL`]`, lowering_error_scale)` —
32//! genuinely flat up to numerical noise and up to the mean-frame lowering's
33//! own resolution ([`FittedAtom::lowering_error`], #995). Anything larger
34//! means the orbit costs objective, so the exact symmetry is broken and the
35//! generator is **pinned** — including the *mixed* case (partly curved,
36//! partly flat), where replicate fits do NOT differ by that group element
37//! even though some flat directions remain nearby. Magnitudes (not span
38//! membership) keep the statistic informative when `range(H)` is full-rank,
39//! which production fits always are. The fraction and the calibration scale
40//! are reported per generator so partial flatness stays visible instead of
41//! being collapsed into the boolean.
42//!
43//! The whole computation is performed in the inner product carried by the fit's
44//! [`gam_problem::RowMetric`]: the curvature root `R` is built
45//! from the metric-whitened Jacobian, so the certificate's "computed in metric
46//! X" line reads straight off [`gam_problem::RowMetric::provenance`]
47//! ([`gam_problem::MetricProvenance`]) and cannot misreport —
48//! there is only one metric object.
49
50use crate::inference::layer_transport::{ChartTopology, TransportLadderReport, transport_ladder};
51use crate::inference::probe_runner::{ProbeRunner, RealizedProbe};
52use crate::inference::riesz::{RieszInput, SmoothFunctional, debias_with_dense_hessian};
53use gam_problem::{MetricProvenance, RowMetric};
54use gam_terms::inference::structure_evidence::{StructureCertificate, StructureLedger};
55use gam_linalg::faer_ndarray::{
56    FaerCholesky, FaerEigh, FaerQr, FaerSvd, default_rrqr_rank_alpha, rrqr_with_permutation,
57};
58use crate::chart_canonicalization::CanonicalChartTopology;
59use crate::manifold::SaeManifoldTerm;
60use faer::Side;
61use ndarray::{Array1, Array2, Array3, Array4, ArrayView1, ArrayView2, s};
62use std::f64::consts::TAU;
63
/// Smoothed column-2-norm penalty on a decoder weight (Jacobian) matrix.
///
/// For a `(d, k)` matrix `W` the penalty value is
/// `weight · Σ_k ( √(Σ_d W[d,k]² + ε²) − ε )`: each column contributes its
/// `ε`-smoothed Euclidean norm, shifted by `ε` so that an all-zero column
/// contributes (numerically) nothing. The matching gradient is
/// `grad[d, k] = weight · W[d, k] / √(Σ_d W[d,k]² + ε²)`; both are computed
/// by [`MechanismSparsityJacobian::value_and_grad`].
///
/// The fields are public, so the positivity checks performed by
/// [`MechanismSparsityJacobian::new`] only hold for values built through it.
#[derive(Debug, Clone)]
pub struct MechanismSparsityJacobian {
    /// Multiplier applied to the value, the gradient and the Hessian diagonal.
    pub weight: f64,
    /// Smoothing constant `ε` added (squared) under the square root.
    pub epsilon: f64,
}
73
74impl MechanismSparsityJacobian {
75    pub fn new(weight: f64, epsilon: f64) -> Result<Self, String> {
76        if !(weight.is_finite() && weight > 0.0) {
77            return Err(format!(
78                "MechanismSparsityJacobian: weight must be finite and >0, got {weight}"
79            ));
80        }
81        if !(epsilon.is_finite() && epsilon > 0.0) {
82            return Err(format!(
83                "MechanismSparsityJacobian: epsilon must be finite and >0, got {epsilon}"
84            ));
85        }
86        Ok(Self { weight, epsilon })
87    }
88
89    /// Evaluate value and gradient on a (d_obs, k_latent) decoder weight matrix.
90    pub fn value_and_grad(&self, w: ArrayView2<f64>) -> (f64, Array2<f64>) {
91        let (d, k) = w.dim();
92        let eps2 = self.epsilon * self.epsilon;
93        let mut grad = Array2::<f64>::zeros((d, k));
94        let mut value = 0.0;
95        for col in 0..k {
96            let mut sq = 0.0;
97            for row in 0..d {
98                sq += w[[row, col]] * w[[row, col]];
99            }
100            let denom = (sq + eps2).sqrt();
101            value += denom - self.epsilon;
102            let factor = self.weight / denom;
103            for row in 0..d {
104                grad[[row, col]] = factor * w[[row, col]];
105            }
106        }
107        (self.weight * value, grad)
108    }
109
110    /// Diagonal of the Hessian wrt vec(W). Used as a Newton preconditioner.
111    pub fn hessian_diag(&self, w: ArrayView2<f64>) -> Array2<f64> {
112        let (d, k) = w.dim();
113        let eps2 = self.epsilon * self.epsilon;
114        let mut out = Array2::<f64>::zeros((d, k));
115        for col in 0..k {
116            let mut sq = 0.0;
117            for row in 0..d {
118                sq += w[[row, col]] * w[[row, col]];
119            }
120            let denom = (sq + eps2).sqrt();
121            let inv = 1.0 / denom;
122            let inv3 = inv * inv * inv;
123            for row in 0..d {
124                // ∂² / ∂W[d,k]² of √(||·||²+ε²) = 1/r − W[d,k]²/r³
125                out[[row, col]] = self.weight * (inv - w[[row, col]] * w[[row, col]] * inv3);
126            }
127        }
128        out
129    }
130}
131
/// iVAE-style auxiliary-conditional Gaussian log-prior on the latent block.
///
/// Stores per-row conditional means `μ` of shape `(n_rows, latent_dim)` and
/// scales `σ` of shape `(n_rows, latent_dim)`, where `(μ_{n,i}, σ_{n,i})` are
/// presumed evaluated by some external Smooth at the auxiliary `u_n`. The
/// negative log-prior contribution to the latent objective is
///
///   `½ Σ_n Σ_i [ ((t_{n,i} − μ_{n,i}) / σ_{n,i})²
///                + 2 log σ_{n,i} + log 2π ]`
///
/// scaled by `weight`. The gradient w.r.t. `t` is `(t − μ) / σ²` (times
/// `weight`); the gradient w.r.t. `μ` is its negative. Per-row scales make
/// this strictly more general than a fixed `N(0, I)`, which is recovered by
/// `μ ≡ 0`, `σ ≡ 1`.
///
/// The fields are public, so the shape, positivity and rank checks performed
/// by [`ConditionalPriorIvae::new`] only hold for values built through it.
#[derive(Debug, Clone)]
pub struct ConditionalPriorIvae {
    /// Conditional means `μ`, one row per observation.
    pub mean: Array2<f64>,
    /// Conditional standard deviations `σ`, same shape as `mean`.
    pub scale: Array2<f64>,
    /// Multiplier applied to the negative log-prior value and its gradient.
    pub weight: f64,
}
152
153impl ConditionalPriorIvae {
154    pub fn new(mean: Array2<f64>, scale: Array2<f64>, weight: f64) -> Result<Self, String> {
155        if mean.dim() != scale.dim() {
156            return Err(format!(
157                "ConditionalPriorIvae: mean shape {:?} != scale shape {:?}",
158                mean.dim(),
159                scale.dim()
160            ));
161        }
162        if !(weight.is_finite() && weight > 0.0) {
163            return Err(format!(
164                "ConditionalPriorIvae: weight must be finite and >0, got {weight}"
165            ));
166        }
167        for &v in scale.iter() {
168            if !(v.is_finite() && v > 0.0) {
169                return Err(format!(
170                    "ConditionalPriorIvae: every scale must be finite and >0, got {v}"
171                ));
172            }
173        }
174        for &v in mean.iter() {
175            if !v.is_finite() {
176                return Err("ConditionalPriorIvae: mean contains non-finite entry".to_string());
177            }
178        }
179
180        // Khemakhem et al. (arXiv:2107.10098) Theorem 1 identifiability
181        // precondition for the exponential-family conditional prior:
182        // the auxiliary index `u` must yield 2k+1 distinct conditional
183        // priors `p(t|u)` whose sufficient-statistic parameters
184        // `(η_1(u), η_2(u)) = (μ(u)/σ(u)², −1/(2σ(u)²))` span a
185        // 2k-dimensional set. For the diagonal Gaussian family this is
186        // equivalent (an invertible reparameterisation) to requiring that
187        // the stacked signature `S = [μ(u) ‖ log σ(u)]` of shape
188        // (n_rows, 2k) have rank 2k, with at least 2k+1 distinct rows.
189        let (n_rows, latent_dim) = mean.dim();
190        let needed_rows = 2 * latent_dim + 1;
191        if n_rows < needed_rows {
192            return Err(format!(
193                "ConditionalPriorIvae: Khemakhem (arXiv:2107.10098) Theorem 1 \
194                 precondition violated: need at least 2k+1 = {needed_rows} distinct \
195                 auxiliary states for latent_dim k = {latent_dim}, got n_rows = {n_rows}"
196            ));
197        }
198        let signature = {
199            let mut s = Array2::<f64>::zeros((n_rows, 2 * latent_dim));
200            for r in 0..n_rows {
201                for c in 0..latent_dim {
202                    s[[r, c]] = mean[[r, c]];
203                    s[[r, latent_dim + c]] = scale[[r, c]].ln();
204                }
205            }
206            s
207        };
208        let first = signature.row(0).to_owned();
209        let all_identical = signature
210            .outer_iter()
211            .all(|row| row.iter().zip(first.iter()).all(|(a, b)| a == b));
212        if all_identical {
213            return Err(format!(
214                "ConditionalPriorIvae: Khemakhem (arXiv:2107.10098) Theorem 1 \
215                 precondition violated: all {n_rows} rows of the stacked auxiliary \
216                 signature [μ ‖ log σ] are identical, so the conditional prior is the \
217                 trivial unconditional N(μ, σ²) — provably non-identifiable (no \
218                 auxiliary information)"
219            ));
220        }
221        let (_u, sv, _vt) = signature
222            .svd(false, false)
223            .map_err(|e| format!("ConditionalPriorIvae: SVD of auxiliary signature failed: {e}"))?;
224        let max_sv = sv.iter().cloned().fold(0.0_f64, f64::max);
225        let tol = max_sv * (n_rows.max(2 * latent_dim) as f64) * f64::EPSILON;
226        let numerical_rank = sv.iter().filter(|&&s| s > tol).count();
227        let required = 2 * latent_dim;
228        if numerical_rank < required {
229            return Err(format!(
230                "ConditionalPriorIvae: Khemakhem (arXiv:2107.10098) Theorem 1 \
231                 precondition violated: stacked auxiliary signature [μ ‖ log σ] has \
232                 numerical rank {numerical_rank} < 2·latent_dim = {required} \
233                 (tolerance {tol:.3e}); the family `p(t|u)` does not span a \
234                 2k-dimensional set of natural parameters"
235            ));
236        }
237
238        Ok(Self {
239            mean,
240            scale,
241            weight,
242        })
243    }
244
245    /// Evaluate negative-log-prior value and gradient w.r.t. latent t.
246    pub fn value_and_grad(&self, t: ArrayView2<f64>) -> (f64, Array2<f64>) {
247        assert_eq!(
248            t.dim(),
249            self.mean.dim(),
250            "ConditionalPriorIvae: t/mean shape mismatch"
251        );
252        let (n, d) = t.dim();
253        let log_2pi = (2.0 * std::f64::consts::PI).ln();
254        let mut grad = Array2::<f64>::zeros((n, d));
255        let mut value = 0.0;
256        for row in 0..n {
257            for col in 0..d {
258                let mu = self.mean[[row, col]];
259                let sigma = self.scale[[row, col]];
260                let z = (t[[row, col]] - mu) / sigma;
261                value += 0.5 * (z * z + 2.0 * sigma.ln() + log_2pi);
262                grad[[row, col]] = self.weight * z / sigma;
263            }
264        }
265        (self.weight * value, grad)
266    }
267
268    /// Evaluate value only — useful when only the loss is needed.
269    pub fn value(&self, t: ArrayView2<f64>) -> f64 {
270        self.value_and_grad(t).0
271    }
272}
273
274/// Helper: evaluate a piecewise-linear "smooth" `f(u)` columnwise, given a
275/// (k_centres, latent_dim) coefficient table and a (n_rows,) auxiliary vector
276/// `u`. Used by the Python wrapper to back the iVAE per-latent (μ_i(u), σ_i(u))
277/// without having to round-trip through gam's full Smooth machinery for the
278/// minimal experiments. Centres are assumed evenly spaced in [u_min, u_max].
279pub fn piecewise_linear_eval(
280    u: ArrayView1<f64>,
281    coeffs: ArrayView2<f64>,
282    u_min: f64,
283    u_max: f64,
284) -> Array2<f64> {
285    let (k, d) = coeffs.dim();
286    assert!(k >= 2, "piecewise_linear_eval: need ≥2 centres");
287    let n = u.len();
288    let mut out = Array2::<f64>::zeros((n, d));
289    let step = (u_max - u_min) / (k - 1) as f64;
290    for (row, &val) in u.iter().enumerate() {
291        // Clamp `pos` to the exact endpoint `(k-1)`, not `(k-1) - 1e-12`,
292        // so `val = u_max` evaluates to exactly `coeffs[k-1, col]` instead
293        // of `coeffs[k-1, col] + 1e-12 · (coeffs[k-2, col] − coeffs[k-1,
294        // col])`. The historical `1e-12` shift was there to keep `lo + 1`
295        // in range, but capping `lo` at `k − 2` achieves the same
296        // structural guarantee without perturbing the endpoint value.
297        let pos = ((val - u_min) / step).clamp(0.0, (k - 1) as f64);
298        let lo = (pos.floor() as usize).min(k - 2);
299        let hi = lo + 1;
300        let frac = pos - lo as f64;
301        for col in 0..d {
302            out[[row, col]] = coeffs[[lo, col]] * (1.0 - frac) + coeffs[[hi, col]] * frac;
303        }
304    }
305    out
306}
307
/// Outcome of a 2D log-λ grid-search weight selection.
///
/// `evidence_grid[i, j]` is the Laplace-style log marginal-likelihood proxy
/// at `(lam1_grid[i], lam2_grid[j])`:
/// `evidence = −½ N log(RSS/N) − ½ (penalty)` with `RSS = rss_grid[i, j]`
/// and `penalty = penalty_grid[i, j]`.
///
/// The winner is `argmax` over the grid; ties are broken by selecting the
/// `(i, j)` with the smallest `i + j` (i.e. smallest log-weight sum on a
/// log-spaced grid), then by smallest `i`, then smallest `j` — a fully
/// deterministic, reproducible policy.
#[derive(Debug, Clone)]
pub struct WeightSearchResult {
    /// Row index of the winning cell (index into `lam1_grid`).
    pub best_i: usize,
    /// Column index of the winning cell (index into `lam2_grid`).
    pub best_j: usize,
    /// First penalty weight at the winning cell, `lam1_grid[best_i]`.
    pub best_lam1: f64,
    /// Second penalty weight at the winning cell, `lam2_grid[best_j]`.
    pub best_lam2: f64,
    /// Evidence value at the winning cell.
    pub best_evidence: f64,
    /// Evidence for every cell, same shape as the input RSS grid.
    pub evidence_grid: Array2<f64>,
}
328
329/// Generic 2D log-λ weight-selection driver.
330///
331/// Given a precomputed `(G1, G2)` grid of residual sums-of-squares
332/// `rss_grid`, a matching grid of total-penalty values `penalty_grid`, and
333/// the two 1D weight grids `lam1_grid` / `lam2_grid`, computes the Laplace
334/// log marginal-likelihood proxy on every cell and returns the maximising
335/// cell with deterministic tie-breaking.
336///
337/// The primitive is intentionally agnostic to *what* the two penalty
338/// weights regularise — it takes only the RSS and penalty surfaces, so it
339/// can drive weight selection for any two-penalty model (identifiable
340/// factor model, double-penalty smooths, IBP + sparsity, etc.).
341pub fn identifiable_factor_select_weights(
342    rss_grid: ArrayView2<'_, f64>,
343    penalty_grid: ArrayView2<'_, f64>,
344    lam1_grid: ArrayView1<'_, f64>,
345    lam2_grid: ArrayView1<'_, f64>,
346    n_obs: usize,
347) -> Result<WeightSearchResult, String> {
348    let (g1, g2) = rss_grid.dim();
349    if penalty_grid.dim() != (g1, g2) {
350        return Err(format!(
351            "identifiable_factor_select_weights: penalty_grid shape {:?} \
352             must match rss_grid shape ({}, {})",
353            penalty_grid.dim(),
354            g1,
355            g2
356        ));
357    }
358    if lam1_grid.len() != g1 {
359        return Err(format!(
360            "identifiable_factor_select_weights: lam1_grid len {} must \
361             equal rss_grid rows {}",
362            lam1_grid.len(),
363            g1
364        ));
365    }
366    if lam2_grid.len() != g2 {
367        return Err(format!(
368            "identifiable_factor_select_weights: lam2_grid len {} must \
369             equal rss_grid cols {}",
370            lam2_grid.len(),
371            g2
372        ));
373    }
374    if g1 == 0 || g2 == 0 {
375        return Err("identifiable_factor_select_weights: grids must be non-empty".to_string());
376    }
377    if n_obs == 0 {
378        return Err("identifiable_factor_select_weights: n_obs must be > 0".to_string());
379    }
380    for v in rss_grid.iter() {
381        if !v.is_finite() || *v < 0.0 {
382            return Err(format!(
383                "identifiable_factor_select_weights: rss_grid contains non-finite or \
384                 negative value {v}"
385            ));
386        }
387    }
388    for v in penalty_grid.iter() {
389        if !v.is_finite() {
390            return Err(format!(
391                "identifiable_factor_select_weights: penalty_grid contains non-finite value {v}"
392            ));
393        }
394    }
395    for v in lam1_grid.iter().chain(lam2_grid.iter()) {
396        if !v.is_finite() || *v <= 0.0 {
397            return Err(format!(
398                "identifiable_factor_select_weights: λ grids must contain finite positive \
399                 values, got {v}"
400            ));
401        }
402    }
403
404    let n = n_obs as f64;
405    let rss_floor = 1.0e-300_f64;
406    let mut evidence_grid = Array2::<f64>::zeros((g1, g2));
407    let mut best: Option<(usize, usize, f64)> = None;
408    for i in 0..g1 {
409        for j in 0..g2 {
410            let rss = rss_grid[[i, j]];
411            let pen = penalty_grid[[i, j]];
412            let mean_sq = (rss / n).max(rss_floor);
413            let ev = -0.5 * n * mean_sq.ln() - 0.5 * pen;
414            evidence_grid[[i, j]] = ev;
415            let better = match best {
416                None => true,
417                Some((bi, bj, bev)) => {
418                    if ev > bev {
419                        true
420                    } else if ev == bev {
421                        let cur_sum = i + j;
422                        let best_sum = bi + bj;
423                        if cur_sum < best_sum {
424                            true
425                        } else if cur_sum == best_sum && i < bi {
426                            true
427                        } else {
428                            cur_sum == best_sum && i == bi && j < bj
429                        }
430                    } else {
431                        false
432                    }
433                }
434            };
435            if better {
436                best = Some((i, j, ev));
437            }
438        }
439    }
440    let (best_i, best_j, best_evidence) = best.ok_or_else(|| {
441        "identifiable_factor_select_weights: empty search (this is a bug)".to_string()
442    })?;
443    Ok(WeightSearchResult {
444        best_i,
445        best_j,
446        best_lam1: lam1_grid[best_i],
447        best_lam2: lam2_grid[best_j],
448        best_evidence,
449        evidence_grid,
450    })
451}
452
453/// Column-centred thin-SVD scores: returns the leading `k` columns of
454/// `U Σ` for the centred predictor matrix `X − mean(X, axis=0)`.
455///
456/// Used to seed `T_init` for the partial-supervision recipe when the
457/// caller does not supply one. Pure-Rust path (faer SVD via the
458/// `FaerSvd` bridge) so the seeding math lives in the same crate as the
459/// gauge-fix solver.
460pub fn thin_svd_scores(x: ArrayView2<f64>, k: usize) -> Result<Array2<f64>, String> {
461    let (n, p) = x.dim();
462    if k == 0 {
463        return Ok(Array2::<f64>::zeros((n, 0)));
464    }
465    if k > n.min(p) {
466        return Err(format!(
467            "thin_svd_scores: requested {k} components but min(n={n}, p={p}) limits to {}",
468            n.min(p)
469        ));
470    }
471    let mut mean_row = Array1::<f64>::zeros(p);
472    for row in 0..n {
473        for col in 0..p {
474            mean_row[col] += x[[row, col]];
475        }
476    }
477    if n > 0 {
478        let inv_n = 1.0 / (n as f64);
479        for col in 0..p {
480            mean_row[col] *= inv_n;
481        }
482    }
483    let mut xc = Array2::<f64>::zeros((n, p));
484    for row in 0..n {
485        for col in 0..p {
486            xc[[row, col]] = x[[row, col]] - mean_row[col];
487        }
488    }
489    let (u_opt, sigma, _vt_opt) = xc
490        .svd(true, false)
491        .map_err(|e| format!("thin_svd_scores: SVD failed: {e}"))?;
492    let u = u_opt.ok_or_else(|| "thin_svd_scores: SVD did not return U".to_string())?;
493    let mut out = Array2::<f64>::zeros((n, k));
494    for row in 0..n {
495        for col in 0..k {
496            out[[row, col]] = u[[row, col]] * sigma[col];
497        }
498    }
499    Ok(out)
500}
501
/// Method for tying the supervised block to the auxiliary signal.
///
/// Selects which map [`partial_supervision_solve`] fits from `T_sup` to
/// `aux` before scoring the alignment.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum PartialSupervisionSupMethod {
    /// Orthogonal Procrustes: `min_{RᵀR=I} ‖T_sup R - aux‖_F²`.
    ///
    /// The solver takes `R = U Vᵀ` from the SVD `T_supᵀ aux = U Σ Vᵀ`.
    Procrustes,
    /// Affine least-squares pinned to `anchor_idx`.
    ///
    /// The affine map `(A, b)` is fitted on the anchor rows only (design
    /// `[T_a | 1]`) and then applied to every row of `T_sup`.
    Anchor,
    /// Ridge map `A_λ = (TᵀT + λI)⁻¹ Tᵀaux` with REML-selected λ.
    ///
    /// The solver picks λ by minimising the REML criterion over a 64-point
    /// log-spaced grid.
    SoftL2,
}
512
/// Free-block decorrelation rule.
///
/// Controls what [`partial_supervision_solve`] does to `t_free` after the
/// supervised block has been aligned.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum PartialSupervisionFreeConstraint {
    /// QR-based projection onto the orthogonal complement of `col(T_sup)`.
    ///
    /// The solver takes `Q` from a QR of the *aligned* supervised block and
    /// forms `Q Qᵀ T_free`; presumably that projection is then subtracted
    /// from `T_free` (verify against the tail of the solver). The step is
    /// skipped, returning `t_free` unchanged, when either block has zero
    /// columns.
    OrthogonalToSup,
    /// No projection.
    ///
    /// `t_free` is returned as an unmodified copy.
    None,
}
521
/// Result of [`partial_supervision_solve`].
///
/// `alignment_score = 1 - ‖T_sup_aligned - aux‖_F² / ‖aux‖_F²` for every
/// method (1.0 = perfect, 0.0 = no better than the constant-zero predictor).
/// The score is not clamped, so it goes negative when the residual exceeds
/// `‖aux‖_F²`.
/// The fitted gauge map lives in the variant-specific fields:
///
/// * Procrustes → `map_r = R` (`d × d` orthogonal).
/// * Anchor    → `map_a = A` (`d × d`), `map_b` (`d`).
/// * SoftL2    → `map_a = A_λ` (`d × d`), `selected_weight = λ`.
///
/// Fields that do not belong to the chosen method are expected to be `None`.
#[derive(Debug, Clone)]
pub struct PartialSupervisionResult {
    /// Supervised block — presumably after the fitted map has been applied
    /// (verify against the solver's return statement).
    pub t_supervised: Array2<f64>,
    /// Free block after the chosen free-block constraint.
    pub t_free: Array2<f64>,
    /// Fraction of `‖aux‖_F²` explained by the aligned supervised block.
    pub alignment_score: f64,
    /// Ridge weight λ chosen by the REML grid (SoftL2 only).
    pub selected_weight: Option<f64>,
    /// Orthogonal map `R` (Procrustes only).
    pub map_r: Option<Array2<f64>>,
    /// Linear map `A` (Anchor) or `A_λ` (SoftL2).
    pub map_a: Option<Array2<f64>>,
    /// Offset `b` of the affine map (Anchor only).
    pub map_b: Option<Array1<f64>>,
}
541
542/// Library-level partial-supervision gauge-fix solver.
543///
544/// Solves the supervised-block alignment problem and applies the chosen
545/// free-block decorrelation rule. Pure numerical linear algebra: SVD,
546/// symmetric eigendecomposition (`Side::Lower`), and thin QR are routed
547/// through the faer bridge in `gam_linalg::faer_ndarray`.
548///
549/// This is the single Rust source-of-math for the gauge-fix step; it is
550/// language-agnostic so the CLI, R, and Julia bindings can reuse it
551/// through their own marshaling layers.
552///
553/// Shape requirements:
554/// * `t_sup` is `(N, d_sup)`; `aux` must equal that shape.
555/// * `t_free` is `(N, d_free)` — `d_free` may be 0.
556/// * `anchor_idx` is consulted only when `method == Anchor`; it must be
557///   non-empty and every index must be `< N`.
558pub fn partial_supervision_solve(
559    t_sup: ArrayView2<f64>,
560    aux: ArrayView2<f64>,
561    t_free: ArrayView2<f64>,
562    method: PartialSupervisionSupMethod,
563    anchor_idx: &[usize],
564    free_constraint: PartialSupervisionFreeConstraint,
565) -> Result<PartialSupervisionResult, String> {
566    let (n, d_sup) = t_sup.dim();
567    if aux.dim() != (n, d_sup) {
568        return Err(format!(
569            "partial_supervision_solve: aux shape {:?} must equal t_sup shape ({}, {})",
570            aux.dim(),
571            n,
572            d_sup
573        ));
574    }
575    if t_free.nrows() != n {
576        return Err(format!(
577            "partial_supervision_solve: t_free has {} rows, expected {}",
578            t_free.nrows(),
579            n
580        ));
581    }
582    let aux_norm_sq: f64 = aux.iter().map(|x| x * x).sum();
583    if !(aux_norm_sq.is_finite() && aux_norm_sq > 0.0) {
584        return Err(
585            "partial_supervision_solve: aux has zero or non-finite Frobenius norm".to_string(),
586        );
587    }
588
589    let mut t_sup_aligned = Array2::<f64>::zeros((n, d_sup));
590    let mut map_r: Option<Array2<f64>> = None;
591    let mut map_a: Option<Array2<f64>> = None;
592    let mut map_b: Option<Array1<f64>> = None;
593    let mut selected_weight: Option<f64> = None;
594
595    match method {
596        PartialSupervisionSupMethod::Procrustes => {
597            // R = U Vᵀ where T_supᵀ aux = U Σ Vᵀ.
598            let m = t_sup.t().dot(&aux);
599            let (u_opt, _sigma, vt_opt) = m
600                .svd(true, true)
601                .map_err(|e| format!("partial_supervision_solve: Procrustes SVD failed: {e}"))?;
602            let u = u_opt
603                .ok_or_else(|| "partial_supervision_solve: SVD did not return U".to_string())?;
604            let vt = vt_opt
605                .ok_or_else(|| "partial_supervision_solve: SVD did not return Vᵀ".to_string())?;
606            let r = u.dot(&vt);
607            t_sup_aligned = t_sup.dot(&r);
608            map_r = Some(r);
609        }
610        PartialSupervisionSupMethod::Anchor => {
611            if anchor_idx.is_empty() {
612                return Err(
613                    "partial_supervision_solve: anchor method requires anchor_idx with at \
614                     least one row"
615                        .to_string(),
616                );
617            }
618            for &idx in anchor_idx {
619                if idx >= n {
620                    return Err(format!(
621                        "partial_supervision_solve: anchor index {idx} out of bounds (n={n})"
622                    ));
623                }
624            }
625            // Stack design [Ta | 1] of shape (m, d_sup+1); solve via SVD pseudo-inverse.
626            let m_rows = anchor_idx.len();
627            let mut design = Array2::<f64>::zeros((m_rows, d_sup + 1));
628            let mut targets = Array2::<f64>::zeros((m_rows, d_sup));
629            for (row_out, &row_in) in anchor_idx.iter().enumerate() {
630                for c in 0..d_sup {
631                    design[[row_out, c]] = t_sup[[row_in, c]];
632                    targets[[row_out, c]] = aux[[row_in, c]];
633                }
634                design[[row_out, d_sup]] = 1.0;
635            }
636            let (u_opt, sigma, vt_opt) = design
637                .svd(true, true)
638                .map_err(|e| format!("partial_supervision_solve: Anchor SVD failed: {e}"))?;
639            let u = u_opt
640                .ok_or_else(|| "partial_supervision_solve: anchor SVD lacked U".to_string())?;
641            let vt = vt_opt
642                .ok_or_else(|| "partial_supervision_solve: anchor SVD lacked Vᵀ".to_string())?;
643            // Tikhonov cutoff matches numpy.linalg.lstsq's default rcond policy.
644            let leading = sigma.iter().cloned().fold(0.0_f64, f64::max);
645            let cutoff = leading * f64::EPSILON * (m_rows.max(d_sup + 1) as f64);
646            let rank = sigma.len();
647            let ut_targets = u.t().dot(&targets);
648            let mut scaled = Array2::<f64>::zeros((rank, d_sup));
649            for r in 0..rank {
650                let s = sigma[r];
651                if s > cutoff {
652                    let inv = 1.0 / s;
653                    for c in 0..d_sup {
654                        scaled[[r, c]] = inv * ut_targets[[r, c]];
655                    }
656                }
657            }
658            let coef = vt.t().dot(&scaled);
659            let a = coef.slice(s![..d_sup, ..]).to_owned();
660            let b_vec = coef.slice(s![d_sup, ..]).to_owned();
661            for row in 0..n {
662                for c in 0..d_sup {
663                    let mut acc = b_vec[c];
664                    for k in 0..d_sup {
665                        acc += t_sup[[row, k]] * a[[k, c]];
666                    }
667                    t_sup_aligned[[row, c]] = acc;
668                }
669            }
670            map_a = Some(a);
671            map_b = Some(b_vec);
672        }
673        PartialSupervisionSupMethod::SoftL2 => {
674            // Symmetric eigendecomposition of G = T_supᵀ T_sup.
675            let g = t_sup.t().dot(&t_sup);
676            let (eigvals, eigvecs) = g
677                .eigh(Side::Lower)
678                .map_err(|e| format!("partial_supervision_solve: eigh on Gram failed: {e}"))?;
679            let rhs = t_sup.t().dot(&aux);
680            let ut_aux = eigvecs.t().dot(&rhs);
681            // Per-eigenvector signal energy m_r = ‖row_r(Vᵀ Tᵀaux)‖²; the
682            // multi-response RSS at weight λ is then
683            //   S(λ) = ‖aux‖_F² − Σ_r m_r/(γ_r+λ)
684            // with γ_r the eigenvalues of G = TᵀT (`eigvals`).
685            let m_row: Array1<f64> = Array1::from_vec(
686                (0..d_sup)
687                    .map(|r| (0..d_sup).map(|c| ut_aux[[r, c]] * ut_aux[[r, c]]).sum())
688                    .collect(),
689            );
690            let lam_max = eigvals.iter().cloned().fold(0.0_f64, f64::max);
691            let floor = (lam_max * 1.0e-10).max(1.0e-12);
692            let top = (lam_max * 1.0e3).max(floor * 1.0e6);
693            let grid_n: usize = 64;
694            let log_floor = floor.ln();
695            let log_top = top.ln();
696            // Select λ by REML, never GCV. The ridge map is the linear mixed
697            // model aux_j = T β_j + ε with β_j ~ N(0, σ²/λ I), ε ~ N(0, σ² I)
698            // applied to each of the d columns sharing λ. The map carries no
699            // unpenalized fixed effect, so REML coincides with the marginal
700            // likelihood, whose profile (σ² concentrated out) criterion to
701            // MINIMIZE is
702            //   reml(λ) = n·log S(λ) + Σ_r log(1 + γ_r/λ),
703            // the exact analogue of the smoothing-parameter REML used
704            // everywhere else in gam.
705            let mut best_score = f64::INFINITY;
706            let mut best_lam = floor;
707            for k in 0..grid_n {
708                let frac = if grid_n == 1 {
709                    0.0
710                } else {
711                    (k as f64) / ((grid_n - 1) as f64)
712                };
713                let lam = (log_floor + frac * (log_top - log_floor)).exp();
714                let mut shrunk = 0.0_f64; // Σ_r m_r/(γ_r+λ)
715                let mut logdet = 0.0_f64; // Σ_r log(1 + γ_r/λ)
716                for r in 0..d_sup {
717                    let g = eigvals[r].max(0.0);
718                    shrunk += m_row[r] / (g + lam);
719                    logdet += (1.0 + g / lam).ln();
720                }
721                let s = aux_norm_sq - shrunk;
722                if !(s.is_finite() && s > 0.0) {
723                    continue;
724                }
725                let score = (n as f64) * s.ln() + logdet;
726                if score < best_score {
727                    best_score = score;
728                    best_lam = lam;
729                }
730            }
731            if !best_score.is_finite() {
732                return Err(
733                    "partial_supervision_solve: REML grid did not find a finite-score weight"
734                        .to_string(),
735                );
736            }
737            // Build the ridge map A_λ = (G + λI)⁻¹ Tᵀaux at the REML weight.
738            let denom: Array1<f64> = eigvals.mapv(|v| v + best_lam);
739            let mut a_eig = Array2::<f64>::zeros((d_sup, d_sup));
740            for r in 0..d_sup {
741                for c in 0..d_sup {
742                    a_eig[[r, c]] = ut_aux[[r, c]] / denom[r];
743                }
744            }
745            let best_a = eigvecs.dot(&a_eig);
746            t_sup_aligned = t_sup.dot(&best_a);
747            map_a = Some(best_a);
748            selected_weight = Some(best_lam);
749        }
750    }
751
752    // Single source of truth for alignment_score.
753    let mut sq_resid = 0.0_f64;
754    for row in 0..n {
755        for c in 0..d_sup {
756            let r = t_sup_aligned[[row, c]] - aux[[row, c]];
757            sq_resid += r * r;
758        }
759    }
760    let alignment_score = 1.0 - sq_resid / aux_norm_sq;
761
762    let t_free_out = match free_constraint {
763        PartialSupervisionFreeConstraint::None => t_free.to_owned(),
764        PartialSupervisionFreeConstraint::OrthogonalToSup => {
765            if t_sup_aligned.ncols() == 0 || t_free.ncols() == 0 {
766                t_free.to_owned()
767            } else {
768                let qr_pair = t_sup_aligned
769                    .qr()
770                    .map_err(|e| format!("partial_supervision_solve: QR on T_sup failed: {e}"))?;
771                let q = qr_pair.0;
772                let qt_free = q.t().dot(&t_free);
773                let proj = q.dot(&qt_free);
774                let mut out = t_free.to_owned();
775                out -= &proj;
776                out
777            }
778        }
779    };
780
781    Ok(PartialSupervisionResult {
782        t_supervised: t_sup_aligned,
783        t_free: t_free_out,
784        alignment_score,
785        selected_weight,
786        map_r,
787        map_a,
788        map_b,
789    })
790}
791
792// ============================================================================
793// Object 4 — the Certificate: `residual_gauge()`
794// ============================================================================
795
/// Latent-manifold topology of a single fitted atom, reduced to what the
/// certificate needs in order to enumerate that atom's isometry-group
/// generators.
///
/// It parallels the user-facing [`crate::manifold::SaeAtomBasisKind`] choice
/// but keeps only the information required to construct `Isom(M_k)` tangent
/// directions, which leaves the certificate independent of the full
/// `SaeManifoldAtom` machinery.
#[derive(Clone, Debug, Eq, PartialEq)]
pub enum AtomTopology {
    /// The circle `S¹` (periodic, one-dimensional). `Isom(S¹) = O(2)`: one
    /// continuous rotation generator (a shift of the circular coordinate)
    /// together with a reflection.
    Circle,
    /// The sphere `S²` (intrinsic sphere chart). `Isom(S²) = O(3)`: three
    /// rotation generators (a basis of so(3)) together with the
    /// antipodal/reflection component.
    Sphere,
    /// The torus `Tᵈ`, a product of `latent_dim` circles. Its isometry group
    /// contains the `d` independent circle shifts (a maximal torus of
    /// rotations).
    Torus { latent_dim: usize },
    /// A Euclidean patch / Duchon patch of dimension `latent_dim`. The
    /// connected isometry group `SE(d)` is generated by `d` translations and
    /// `d(d−1)/2` rotations of the latent coordinate frame.
    EuclideanPatch { latent_dim: usize },
}
817
818impl AtomTopology {
819    /// Intrinsic latent dimensionality of the atom's manifold.
820    fn latent_dim(&self) -> usize {
821        match self {
822            AtomTopology::Circle => 1,
823            AtomTopology::Sphere => 2,
824            AtomTopology::Torus { latent_dim } => *latent_dim,
825            AtomTopology::EuclideanPatch { latent_dim } => *latent_dim,
826        }
827    }
828}
829
830/// One fitted atom as the certificate sees it.
831///
832/// `frame` is the fitted decoder frame whose columns the isometry generators
833/// rotate: an `(output_dim, latent_dim)` matrix whose column `a` is the fitted
834/// image of latent axis `a` in output space (e.g. the decoder Jacobian columns
835/// at the atom's centroid, or the leading decoder directions). The isometry
836/// generators of `Isom(M_k)` act on these columns; the certificate lifts that
837/// action to a tangent direction on the flattened decoder frame.
838#[derive(Debug, Clone)]
839pub struct FittedAtom {
840    pub name: String,
841    pub topology: AtomTopology,
842    /// `(output_dim, latent_dim)` fitted decoder frame.
843    pub frame: Array2<f64>,
844    /// ARD prior variances (one per latent axis of this atom), used to detect
845    /// equal-ARD eigenspaces inside which a rotation is unpinned by the prior.
846    /// `None` ⇒ no ARD prior on this atom (every within-frame rotation is then
847    /// a candidate generator, pinned-or-not decided solely by the data + the
848    /// isometry penalty).
849    pub ard_variances: Option<Array1<f64>>,
850    /// **Lowering-error scale** (#995), in `[0, 1]`: the mass-weighted relative
851    /// dispersion of the atom's per-row decoder tangents around the mean
852    /// `frame` the certificate compresses them into,
853    /// `Σ_n a_n Σ_ax ‖t_ax(n) − frame[:,ax]‖² / Σ_n a_n Σ_ax ‖t_ax(n)‖²`.
854    ///
855    /// `0` ⇒ the frame represents every row exactly (hand-built fixtures, flat
856    /// decoders) and the certificate's verdicts within this atom are at full
857    /// resolution. Values toward `1` ⇒ a curved decoder whose tangent field
858    /// disperses strongly (e.g. a full circle, whose tangents average to ≈ 0):
859    /// the mean-frame lowering then cannot distinguish gauge motion from
860    /// genuine curvature, so the verdict tolerance for generators touching
861    /// this atom is *calibrated up to this scale* — the certificate refuses to
862    /// claim a pin it cannot resolve, the same honesty contract as the
863    /// `diffeomorphism-unpinned` escalation.
864    pub lowering_error: f64,
865    /// #1019 stage 1: `true` when the atom's `d = 1` latent chart was pinned
866    /// post-fit to its arc-length (unit-speed) canonical representative. #1019
867    /// stage 2: `true` as well when a `d = 2` torus atom's chart was pinned
868    /// post-fit to the minimum-isometry-defect flow representative, in which
869    /// case the residual chart freedom is `Isom(T², flat) = U(1)² ⋊ D₄`. The
870    /// certificate then records that this atom's continuous chart
871    /// (reparameterization) freedom is **pinned by canonicalization** — a
872    /// provenance distinct from curvature/penalty pinning
873    /// ([`VerdictProvenance::PinnedByCanonicalization`]) — and that the
874    /// residual chart freedom is the finite isometry group of the reference
875    /// manifold for `d = 1` charts: rotation + reflection (`O(2)`) on the
876    /// circle, reflection + translation on the interval.
877    pub chart_canonicalized: bool,
878    /// Per-atom inner-decoder-smooth byproducts harvested at fit time, the
879    /// single source the post-PIRLS atom inference reports
880    /// ([`AtomFunctionalReport`] #1097, [`AtomSmoothSignificance`] #1103)
881    /// consume in [`dictionary_report`].
882    ///
883    /// The certificate path that builds `FittedSaeManifold` does so *without* a
884    /// fit harness in scope, so it leaves this `None`; callers that own the
885    /// fitted term attach it through [`FittedAtom::with_inner_fit`] (the term
886    /// builder fills it from the live per-atom basis, decoder, assignment mass,
887    /// and smoothness Gram). When `None`, both reports below are `None`: the
888    /// genuine prerequisite — the post-fit inner-smooth design, penalized
889    /// Hessian, and row scores — is simply not present on a bare
890    /// certificate-only `FittedSaeManifold`.
891    pub inner_fit: Option<AtomInnerFit>,
892}
893
894/// The fitted per-atom inner-decoder smooth, captured once at fit time so the
895/// post-PIRLS atom-inference reports reuse the *same* design, penalized Hessian,
896/// and per-row scores the identifiability certificate's curvature sees.
897///
898/// The SAE decoder reconstructs `Z_i ≈ Σ_k a_ik Φ_k(t_ik) B_k`. Holding all
899/// other atoms and the assignment fixed at the fitted optimum, atom `k`'s own
900/// contribution along a single output channel `j` is the Gaussian-identity
901/// penalized smooth `a_ik · Φ_k(t_ik)ᵀ β_{k,j}` with roughness penalty `S_k`,
902/// Gauss–Newton observation weight `w_i = a_ik²` (the assignment mass enters the
903/// channel linearly, so the normal-equation weight is its square), and
904/// dispersion the fitted reconstruction dispersion. That is an ordinary
905/// penalized WLS smooth — exactly what [`crate::inference::riesz`],
906/// [`gam_terms::inference::lawley`], and the κ-profile machinery consume. The
907/// channel `j` is the atom's dominant decoder output direction (largest column
908/// norm of `B_k`), i.e. the channel that carries the atom's signal.
909#[derive(Debug, Clone)]
910pub struct AtomInnerFit {
911    /// `Φ_k` evaluated on the atom's active rows, `(n_active, M_k)`. The inner
912    /// GAM smooth design. Column 0 is the constant/intercept basis column.
913    pub design: Array2<f64>,
914    /// `∂Φ_k/∂t` along the atom's leading latent axis on the active rows,
915    /// `(n_active, M_k)`: the derivative design the average-derivative
916    /// functional integrates.
917    pub derivative_design: Array2<f64>,
918    /// The fitted decoder coefficients for the captured output channel,
919    /// `β_{k,j} ∈ ℝ^{M_k}`.
920    pub beta: Array1<f64>,
921    /// The atom roughness Gram `S_k`, `(M_k, M_k)`.
922    pub penalty: Array2<f64>,
923    /// The penalized Hessian `H = ΦᵀWΦ + S_k` at the fitted state, `(M_k, M_k)`.
924    pub penalized_hessian: Array2<f64>,
925    /// Per-row Gaussian-identity scores `s_i = ∂nll_i/∂β = −w_i r_i Φ_i / φ`,
926    /// `(n_active, M_k)`, on the captured channel.
927    pub row_scores: Array2<f64>,
928    /// Per-row Gauss–Newton weights `w_i = a_ik²` on the captured channel.
929    pub weights: Array1<f64>,
930    /// Fitted reconstruction dispersion `φ` (Gaussian σ²).
931    pub dispersion: f64,
932    /// Design row at the latent peak `t_peak` (largest fitted `|g_k|`).
933    pub peak_design_row: Array1<f64>,
934    /// Design row at the latent mode `t_mode` (largest assignment mass).
935    pub mode_design_row: Array1<f64>,
936}
937
938impl FittedAtom {
939    /// Attach the inner-decoder-smooth byproducts harvested at fit time. The
940    /// term builder calls this so [`dictionary_report`] can produce the three
941    /// post-PIRLS atom inference reports.
942    pub fn with_inner_fit(mut self, inner_fit: AtomInnerFit) -> Self {
943        self.inner_fit = Some(inner_fit);
944        self
945    }
946}
947
948/// Descriptive penalty-debiased POINT summaries of one fitted atom's decoder
949/// curve (#1097, narrowed under #1115). Each field is a scalar functional of the
950/// atom's inner smooth `g_k(t)`, reported as a plug-in value and a one-step
951/// penalty-debiased value (the regularization bias relative to the conditional
952/// target is removed through the atom fit's penalized Hessian). No standard
953/// error and no confidence interval are reported — by design (see below).
954///
955/// # Why these carry NO coverage claim (#1115)
956///
957/// Conditional on the fitted latent coordinates `t̂` and assignment `â`, each
958/// functional is an ordinary linear functional of the penalized-WLS coefficients
959/// `β` with a well-defined *conditional* population value, and one-step debiasing
960/// validly removes the penalty bias for that conditional target. The point
961/// estimates are therefore meaningful. A *standard error*, however, would only be
962/// honest if `t̂` and `â` were fixed/known. They are not: they are **generated
963/// regressors** estimated from the very activations that also form the response
964/// `Z`, so `Z` enters both the design (via `t̂(Z), â(Z)`) and the response. An
965/// influence-function SE built from the β-only Hessian and row scores carries no
966/// `∂t̂/∂Z` / `∂â/∂Z` channel — exactly the generated-regressor correction the
967/// marginal-slope family (#461 Stage 2) is *defined* by — so it omits a
968/// first-order variance term and is generally anti-conservative. Rather than ship
969/// an SE/CI that silently under-covers, this report exposes only the debiased
970/// point summaries; a coverage-valid interval would require either freezing the
971/// dictionary on a held-out split or propagating the generated-regressor
972/// Jacobian, neither of which the fixed inner-fit snapshot supports.
973#[derive(Debug, Clone)]
974pub struct AtomFunctionalReport {
975    /// `g(t_peak) − g(t_mode)`: the peak-vs-baseline contrast of the fitted
976    /// decoder, penalty-debiased through the inner-fit Hessian. Point summary
977    /// only (no coverage claim — see the type doc).
978    pub peak_contrast: Option<AtomFunctionalEstimate>,
979    /// `E_data[g(t_i)]`: the data-averaged decoder value over the atom's active
980    /// rows, penalty-debiased. Point summary only.
981    pub average_value: Option<AtomFunctionalEstimate>,
982    /// `E_data[∂g/∂t]` along the atom's leading latent axis: how much the fitted
983    /// decoder curve varies across the data distribution, **conditional on the
984    /// fit**. A descriptive variation measure of the fitted curve, NOT a
985    /// population "marginal slope" (the latent coordinate is itself a fitted,
986    /// generated regressor). Point summary only.
987    ///
988    /// Despite the historical `_norm` suffix this is the **signed** mass-weighted
989    /// mean derivative `E_data[∂g/∂t]` over the single leading axis, not a
990    /// magnitude — it can be negative, and a value near 0 means the average slope
991    /// cancels (a symmetric bump), not that the curve is flat. Use
992    /// [`AtomSmoothSignificance::log_e_nonconstant`] for an honest non-constancy
993    /// test; this field only describes the average local slope.
994    pub decoder_variation_norm: Option<AtomFunctionalEstimate>,
995}
996
/// Point summary of a single atom decoder functional: the plug-in value, the
/// one-step penalty-debiased value, and the penalty bias that was removed.
///
/// It deliberately carries NO standard error / confidence interval: the
/// conditional-on-generated-regressors variance channel is unmodelled, so any
/// SE would under-cover (#1115). For an honest any-n-valid structure test use
/// [`AtomSmoothSignificance`] instead.
#[derive(Clone, Copy, Debug)]
pub struct AtomFunctionalEstimate {
    /// Raw plug-in functional value `θ̂ = g·β̂`.
    pub theta_plugin: f64,
    /// One-step penalty-debiased value `θ̂ − bias`, i.e. the plug-in value with
    /// the regularization bias relative to the conditional target removed.
    pub theta_onestep: f64,
    /// The penalty bias that was removed, `(H⁻¹ g)·(Sβ̂)`.
    pub penalty_bias: f64,
}
1013
/// Any-n-valid structure evidence that an atom's inner smooth `h_k(t)` is
/// genuinely non-constant (#1103). It is the same split-likelihood-ratio
/// e-value the atom-birth gate uses
/// ([`gam_terms::inference::structure_evidence`]), evaluated under the null
/// H0 = "the atom's decoder curve is constant in its latent coordinate".
///
/// This supersedes the earlier Lawley–Bartlett-corrected χ² test, which was a
/// category error in this setting: the penalized smooth's null is effectively
/// rank ≈ n, the first-order χ² is the wrong reference entirely, and an O(1/n)
/// Bartlett factor (whose own stated size shift is ≈0.15%, flipping no
/// admit/demote decision) does not rescue it. The split-LRT e-value is
/// finite-sample valid with NO regularity conditions — exactly the instrument
/// for "does this atom earn a latent dimension".
#[derive(Clone, Debug)]
pub struct AtomSmoothSignificance {
    /// `log E` for "the atom's smooth is non-constant" (null = constant). `E`
    /// is a universal-inference split-likelihood-ratio e-value with
    /// `E_{H0}[E] ≤ 1` exactly, so `E ≥ 1/α` certifies the non-constant
    /// alternative at level α at any data-dependent stopping time. `None` when
    /// the split is degenerate (too few active rows / a fold with no curvature
    /// column).
    pub log_e_nonconstant: Option<f64>,
}
1035
1036/// The post-PIRLS inference reports for one atom, paired by atom index.
1037///
1038/// Two reports survive #1115: the descriptive penalty-debiased point summaries
1039/// of the fitted decoder curve ([`AtomFunctionalReport`], no coverage claim) and
1040/// the any-n-valid split-LRT smooth-structure e-value ([`AtomSmoothSignificance`],
1041/// a genuine finite-sample-valid test). The #1099 per-atom curvature *confidence
1042/// interval* was removed: its target (a sup-norm extrinsic-curvature BOUND read
1043/// off the fitted decoder) is not an estimand with a profiled criterion, and its
1044/// delta-method SE conditioned on the generated latent coordinates as if known.
1045/// The plug-in curvature point estimate itself survives — as the per-atom
1046/// `kappa_hat` entries of
1047/// [`crate::manifold::CertificateInputs::per_atom_kappa_hat`] (the
1048/// #1008 empirical curved-dictionary report, surfaced to Python as
1049/// `ManifoldSAE.curvature_report`), the single source of truth for the bound.
1050/// It is deliberately *not* duplicated onto this report: a descriptive geometry
1051/// bound is a property of the fitted decoder frames, not of the post-PIRLS
1052/// inner-smooth inference snapshot this type carries.
1053#[derive(Debug, Clone)]
1054pub struct AtomInferenceReport {
1055    pub atom_index: usize,
1056    pub atom_name: String,
1057    pub functionals: Option<AtomFunctionalReport>,
1058    pub smooth_significance: Option<AtomSmoothSignificance>,
1059}
1060
/// The fitted SAE-manifold model consumed by the certificate.
///
/// Self-contained on purpose: it holds exactly the objects the residual-gauge
/// computation needs — the atoms (topology, fitted frames, ARD), the
/// curvature/Jacobian row-blocks that pin directions, and the one
/// [`RowMetric`] whose provenance the report reads. The generators live in the
/// flattened free-parameter vector `vec(frame_0) ⊕ vec(frame_1) ⊕ …`, taken in
/// atom order; `param_dim()` is its length.
pub struct FittedSaeManifold {
    /// The fitted atoms, in the order that defines the flattened parameter
    /// layout.
    pub atoms: Vec<FittedAtom>,
    /// Per-row decoder Jacobian blocks `J_n ∈ ℝ^{p × param_dim}`, flattened
    /// row-major (`J_n[i, c] = jacobian_rows[n][i * param_dim + c]`), with one
    /// entry per metric row. These are the directions the *data* gives cost
    /// to; the certificate whitens them through [`RowMetric`] and
    /// orthonormalizes them to obtain the data part of the pinning span
    /// `range(H_data)`.
    pub jacobian_rows: Vec<Vec<f64>>,
    /// The isometry-penalty curvature root `R ∈ ℝ^{r × param_dim}` (so the
    /// penalty Hessian is `RᵀR`). Its row space is `range(H_isometry)` — the
    /// directions the isometry pin gives cost to. It is empty
    /// (`0 × param_dim`) when the isometry pin is inactive, which is exactly
    /// the condition that escalates the verdict to `diffeomorphism-unpinned`.
    pub isometry_penalty_root: Array2<f64>,
    /// The single provenance-carrying per-row inner product. It is read for
    /// the report's "computed in metric X" line and used to whiten the
    /// Jacobian rows, so the rank decision happens in the fit's actual metric.
    pub metric: RowMetric,
}
1088
1089impl FittedSaeManifold {
1090    /// Total flattened free-parameter dimension `Σ_k output_dim_k · latent_dim_k`
1091    /// (the decoder-frame coordinates the generators are tangent directions in).
1092    pub fn param_dim(&self) -> usize {
1093        self.atoms.iter().map(|a| a.frame.len()).sum()
1094    }
1095
1096    /// Column offset of atom `k`'s flattened frame inside the joint parameter
1097    /// vector.
1098    fn atom_offset(&self, k: usize) -> usize {
1099        self.atoms[..k].iter().map(|a| a.frame.len()).sum()
1100    }
1101}
1102
/// The symmetry family a generator belongs to. It is stored on every
/// generator so that the report can name the group in which the residual
/// freedom (or the pin) lives.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum GeneratorFamily {
    /// A generator of `Isom(M_k)` for one atom: a frame rotation/reflection
    /// that realises the atom's own manifold isometry.
    IsomAtom,
    /// A rotation inside an ARD-equal eigenspace. The ARD prior cannot tell
    /// the two axes apart, so the prior does not pin this rotation.
    EqualArdRotation,
    /// A rotation of the global decoder output frame `O(output_dim)`.
    FrameRotation,
    /// An exchange of two topology-identical atoms (a `Sym(F)` permutation,
    /// built as the antisymmetric transposition direction).
    AtomPermutation,
    /// The continuous chart (reparameterization) freedom `Diff(M_k)` of one
    /// `d = 1` atom (arc-length canonicalization) or one `d = 2` torus atom
    /// (isometry-flow canonicalization, #1019 stage 2). It is always reported
    /// **pinned**, with
    /// [`VerdictProvenance::PinnedByCanonicalization`]; the verdict's
    /// description names the surviving residual group (rotation + reflection
    /// on `S¹`, reflection + translation on the interval, or
    /// `Isom(T², flat) = U(1)² ⋊ D₄` for a `d = 2` torus).
    ChartReparameterization,
}
1128
1129impl GeneratorFamily {
1130    fn label(self) -> &'static str {
1131        match self {
1132            GeneratorFamily::IsomAtom => "Isom(M_k)",
1133            GeneratorFamily::EqualArdRotation => "equal-ARD rotation",
1134            GeneratorFamily::FrameRotation => "frame rotation O(output_dim)",
1135            GeneratorFamily::AtomPermutation => "Sym(F) atom permutation",
1136            GeneratorFamily::ChartReparameterization => "Diff(M_k) chart reparameterization",
1137        }
1138    }
1139}
1140
/// The way a generator's pinned/unpinned verdict was reached. It is stored on
/// every generator so that the report can tell a chart fixed **by
/// convention** (the #1019 post-fit arc-length canonicalization — an exact,
/// image-frozen choice of representative) apart from a direction pinned **by
/// curvature** (the data or the isometry penalty giving the orbit genuine
/// objective cost).
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum VerdictProvenance {
    /// Decided by the relative-curvature flatness test against the stacked
    /// pinning root (data + isometry penalty, in the fit's metric). This is
    /// the historical path for every enumerated generator.
    CurvatureTest,
    /// Pinned by the post-fit arc-length chart canonicalization (#1019) or by
    /// the `d = 2` torus isometry-flow canonicalization (#1019 stage 2). The
    /// atom's chart is the selected representative of its `Diff(M)` orbit, so
    /// the continuous reparameterization freedom is fixed by construction — no
    /// curvature was (or needed to be) measured. Kept distinct from
    /// penalty-pinning on purpose: the certificate must not claim that the
    /// objective resists chart motion when it was the canonicalization that
    /// removed it.
    PinnedByCanonicalization,
}
1161
/// Noise floor of the per-generator flatness verdict.
///
/// A generator is certified **unpinned** iff its relative curvature fraction
/// `‖R ξ̂‖² / σ_max(R)²` — the curvature along the unit generator, measured
/// against the stiffest direction of the stacked curvature root `R` — does not
/// exceed the verdict tolerance
/// `max(GENERATOR_FLAT_ENERGY_TOL, lowering_error_scale)`.
///
/// For an exact residual symmetry of the converged objective the fraction is 0
/// up to roundoff. Any genuinely curved component, however partial, means the
/// orbit costs objective and the exact group element is broken. A *mixed*
/// generator (e.g. a frame rotation to which the anisotropic output-Fisher
/// isometry pin gives partial curvature, the #980 Theorem-2 situation) must
/// therefore be reported pinned, never as a surviving freedom. The
/// `lowering_error_scale` arm of the tolerance is the #995 calibration:
/// curvature attributable to the mean-frame compression of a curved decoder
/// must not be read as a pin.
pub const GENERATOR_FLAT_ENERGY_TOL: f64 = 1e-3;
1177
1178/// One enumerated symmetry generator and the certificate's verdict on it.
1179#[derive(Debug, Clone)]
1180pub struct GeneratorVerdict {
1181    /// Which symmetry family this generator realises.
1182    pub family: GeneratorFamily,
1183    /// Human-readable description (which atom(s) / axes it acts on).
1184    pub description: String,
1185    /// `true` ⇒ the converged objective is flat along this generator
1186    /// (`ξ ∈ ker(H)`): a genuine residual gauge freedom the data + isometry
1187    /// penalty leave unbroken. `false` ⇒ the generator is pinned — the data or
1188    /// the isometry penalty gives it curvature (a pinned-energy fraction above
1189    /// [`GENERATOR_FLAT_ENERGY_TOL`]).
1190    pub unpinned: bool,
1191    /// `‖ξ‖₂` of the realised tangent direction (0 ⇒ the generator was
1192    /// structurally trivial — e.g. a rotation of a rank-deficient frame — and
1193    /// is reported as pinned/absent, never as a spurious freedom).
1194    pub generator_norm: f64,
1195    /// `‖R ξ̂‖² / σ_max(R)²` ∈ [0, 1]: curvature along the unit generator,
1196    /// relative to the stiffest direction of the stacked curvature root `R`
1197    /// (data + isometry penalty, in the metric). `0` ⇒ exactly flat, `1` ⇒ as
1198    /// stiff as the stiffest direction; strictly-interior values are the
1199    /// *mixed* regime — partial curvature that breaks the exact symmetry
1200    /// (verdict pinned when above the tolerance) while leaving nearby flat
1201    /// directions, kept visible here rather than collapsed into the boolean.
1202    /// Relative-to-σ_max (not span membership) so the statistic stays
1203    /// informative when the pinning span is full-rank, which production fits
1204    /// always are. Structurally trivial generators (zero norm) report `1.0`.
1205    pub pinned_energy_fraction: f64,
1206    /// The #995 lowering-error arm of this generator's verdict tolerance: the
1207    /// largest [`FittedAtom::lowering_error`] over the atoms the generator
1208    /// touches (its own atom for within-atom families, the exchanged pair for
1209    /// permutations, all atoms for global output-frame rotations). The verdict
1210    /// is `unpinned ⇔ pinned_energy_fraction ≤
1211    /// max(GENERATOR_FLAT_ENERGY_TOL, lowering_error_scale)` — curvature the
1212    /// mean-frame compression cannot distinguish from gauge motion is never
1213    /// read as a pin.
1214    pub lowering_error_scale: f64,
1215    /// How this verdict was decided: by the curvature flatness test, or
1216    /// pinned by the #1019 post-fit arc-length chart canonicalization
1217    /// (see [`VerdictProvenance`]).
1218    pub provenance: VerdictProvenance,
1219}
1220
/// The #972 decoder-frame **inner-rotation gauge**, as enumerated for the
/// certificate.
///
/// A frame-factored atom `B_k = U_k C_k` is *exactly* invariant under
/// `U_k → U_k R`, `C_k → Rᵀ C_k` for every `R ∈ O(r_k)`: the reconstruction,
/// the likelihood, the penalty — every objective term — sees only the product.
/// In contrast to the latent-isometry / ARD-rotation / permutation generators,
/// this freedom is therefore **not** a candidate for pinning by data or
/// penalty curvature (its orbit direction is identically zero in function
/// space). Pushing it through the pinning-span test would be a category error:
/// it would always come back "unpinned" and clutter the verdict list with
/// freedoms the parameterization already handles.
///
/// The honest certificate treatment is what this struct provides: *enumerate*
/// the group and its dimension `Σ_k r_k(r_k−1)/2`, and record how it is
/// fixed — by the canonical orientation gauge
/// ([`crate::manifold::GrassmannFrame`]'s SVD-ordered representative), which
/// selects one point per `O(r_k)` orbit for serialization/comparison
/// stability.
#[derive(Clone, Debug, Eq, PartialEq)]
pub struct FrameInnerRotationGauge {
    /// Active frame rank `r_k` of each frame-factored atom. Atoms on the
    /// full-`B` path contribute no entry.
    pub per_atom_ranks: Vec<usize>,
    /// Total group dimension `Σ_k r_k (r_k − 1) / 2`
    /// (`dim O(r) = r(r−1)/2`).
    pub dim: usize,
}
1247
1248impl FrameInnerRotationGauge {
1249    /// Enumerate the gauge from the active frame ranks.
1250    pub fn from_ranks(per_atom_ranks: Vec<usize>) -> Self {
1251        let dim = frame_inner_rotation_dim(&per_atom_ranks);
1252        Self {
1253            per_atom_ranks,
1254            dim,
1255        }
1256    }
1257}
1258
/// Dimension `Σ_k r_k (r_k − 1) / 2` of the #972 inner-rotation gauge group
/// `∏_k O(r_k)`, taken over the active frame ranks.
///
/// A rank-1 frame contributes `0` (`O(1)` is finite, a sign, and is absorbed
/// by the orientation gauge), so a dictionary of single-direction atoms
/// reports a zero-dimensional inner gauge — one direction has no inner
/// rotation to fix. A rank of `0` likewise contributes `0`, and an empty slice
/// yields `0`.
pub fn frame_inner_rotation_dim(ranks: &[usize]) -> usize {
    ranks.iter().fold(0, |total, &rank| {
        let pairs = match rank {
            // No rotation plane exists below rank 2.
            0 | 1 => 0,
            _ => rank * (rank - 1) / 2,
        };
        total + pairs
    })
}
1268
/// The certificate produced by [`residual_gauge`].
///
/// Carries the per-generator verdicts alongside aggregate fields (pinning
/// rank, residual gauge dimension, escalation flags) and a one-line summary.
/// [`ResidualGaugeReport::group_signature`] renders the certified group as a
/// string that can be compared between replicate fits.
#[derive(Debug, Clone)]
pub struct ResidualGaugeReport {
    /// "computed in metric X" — read straight off
    /// [`RowMetric::provenance`]; the single metric object guarantees this
    /// matches the inner product the fit actually used.
    pub metric_provenance: MetricProvenance,
    /// Per-generator pinned/unpinned verdict, in enumeration order. Only the
    /// entries flagged `unpinned` contribute to
    /// [`ResidualGaugeReport::group_signature`].
    pub generators: Vec<GeneratorVerdict>,
    /// Rank of the pinning span `range(H)` (data + isometry penalty) the
    /// generators were tested against, in the metric.
    pub pinning_rank: usize,
    /// Number of generators certified as unpinned residual gauge freedoms.
    pub residual_gauge_dim: usize,
    /// `true` when the isometry pin is inactive (`isometry_penalty_root` has no
    /// rows): the model is then only identified up to an arbitrary
    /// diffeomorphism of the latent manifolds, and every isometry generator is
    /// reported as a residual freedom. This is the escalation flag; when set,
    /// the group signature is wrapped in a `Diff(M) ⊇ { … }` marker.
    pub diffeomorphism_unpinned: bool,
    /// Under [`MetricProvenance::OutputFisher`] the `Sym(F)` permutation
    /// subgroup is expected to be *trivially pinned* — the output-Fisher metric
    /// distinguishes the atoms behaviorally so no atom-exchange can be a
    /// residual freedom.
    ///
    /// * `Some(true)` ⇒ that triviality holds (every
    ///   [`GeneratorFamily::AtomPermutation`] generator is pinned).
    /// * `Some(false)` ⇒ a permutation survived as a residual freedom, which
    ///   under OutputFisher provenance is a certificate violation the caller
    ///   must surface.
    /// * `None` ⇒ provenance is not `OutputFisher`, so the check does not
    ///   apply.
    pub sym_f_trivial_under_output_fisher: Option<bool>,
    /// The #972 decoder-frame inner-rotation gauge `∏_k O(r_k)` — enumerated,
    /// never curvature-tested (see [`FrameInnerRotationGauge`] for why).
    /// `None` when the caller declared no frame factorization (full-`B`
    /// dictionaries, or a pre-#972 caller using [`residual_gauge`] directly);
    /// attach via [`ResidualGaugeReport::with_frame_inner_rotation`].
    pub frame_inner_rotation: Option<FrameInnerRotationGauge>,
    /// Human-readable one-line summary.
    /// [`ResidualGaugeReport::with_frame_inner_rotation`] appends to it when
    /// the attached gauge has positive dimension.
    pub summary: String,
}
1307
1308impl ResidualGaugeReport {
1309    /// The certified residual gauge group, as a compact string naming the
1310    /// surviving generator families and their multiplicities. Two replicate
1311    /// fits are "identified up to the same group" iff this string is equal.
1312    ///
1313    /// When a frame inner-rotation gauge is enumerated it is appended with its
1314    /// dimension and its `[canonical-fixed]` marker — it is part of the group
1315    /// two replicate fits must agree on, even though it is fixed by
1316    /// convention rather than by curvature.
1317    pub fn group_signature(&self) -> String {
1318        let base = group_signature_of(&self.generators, self.diffeomorphism_unpinned);
1319        match &self.frame_inner_rotation {
1320            Some(gauge) if gauge.dim > 0 => format!(
1321                "{base} ⊕ frame-inner ∏O(r_k)×{} [dim {}, canonical-fixed]",
1322                gauge.per_atom_ranks.len(),
1323                gauge.dim
1324            ),
1325            _ => base,
1326        }
1327    }
1328
1329    /// Attach the #972 frame inner-rotation enumeration to the certificate
1330    /// (consumed by frame-factored dictionaries; `ranks` are the active frame
1331    /// ranks `r_k`, one per factored atom). Extends the summary so the
1332    /// one-line report names the enumerated-but-convention-fixed gauge.
1333    pub fn with_frame_inner_rotation(mut self, ranks: Vec<usize>) -> Self {
1334        let gauge = FrameInnerRotationGauge::from_ranks(ranks);
1335        if gauge.dim > 0 {
1336            self.summary.push_str(&format!(
1337                "; frame inner-rotation gauge ∏O(r_k) of dim {} enumerated \
1338                 (exact reparameterization, fixed by the canonical orientation gauge)",
1339                gauge.dim
1340            ));
1341        }
1342        self.frame_inner_rotation = Some(gauge);
1343        self
1344    }
1345}
1346
1347/// Compact, order-independent signature of the unpinned generator families and
1348/// multiplicities. Two replicate fits agree on their residual gauge group iff
1349/// these strings are equal.
1350fn group_signature_of(generators: &[GeneratorVerdict], diffeomorphism_unpinned: bool) -> String {
1351    let mut counts: std::collections::BTreeMap<&'static str, usize> =
1352        std::collections::BTreeMap::new();
1353    for g in generators {
1354        if g.unpinned {
1355            *counts.entry(g.family.label()).or_insert(0) += 1;
1356        }
1357    }
1358    let body = if counts.is_empty() {
1359        "{e} [fully pinned: rigid up to nothing]".to_string()
1360    } else {
1361        counts
1362            .iter()
1363            .map(|(name, mult)| format!("{name}×{mult}"))
1364            .collect::<Vec<_>>()
1365            .join(" ⊕ ")
1366    };
1367    if diffeomorphism_unpinned {
1368        // With the isometry pin inactive the residual gauge is at least the
1369        // manifold reparametrization (diffeomorphism) group modulo whatever the
1370        // data alone still pins — the surviving generators below are the
1371        // isometry slice of that larger freedom.
1372        format!("Diff(M) ⊇ {{ {body} }} [diffeomorphism-unpinned: isometry pin inactive]")
1373    } else {
1374        body
1375    }
1376}
1377
1378/// Build the atom-local isometry generators for one atom as tangent directions
1379/// on the atom's flattened decoder frame.
1380///
1381/// An isometry of the latent manifold acts on the latent coordinate frame; we
1382/// lift it to the decoder output by acting on the frame columns. For a rotation
1383/// generator `A ∈ so(latent_dim)` (antisymmetric), the induced tangent direction
1384/// on `frame ∈ ℝ^{p × d}` is `frame · Aᵀ` (the first-order motion of the frame
1385/// columns under the one-parameter rotation `exp(tA)`), flattened row-major. For
1386/// the circle this is the single `so(2)` generator; for the sphere the three
1387/// `so(3)` generators; for the torus the `d` independent axis shifts (which on
1388/// the flat product manifold are translations of each circle coordinate —
1389/// realised as the unit tangent along each frame column).
1390fn atom_isometry_generators(atom: &FittedAtom) -> Vec<(Array1<f64>, String)> {
1391    let (p, d) = atom.frame.dim();
1392    // The intrinsic latent dimension of the manifold fixes `dim Isom(M_k)` (the
1393    // number of independent isometry generators we must enumerate). The fitted
1394    // decoder frame's column count `d` must realise exactly that many latent
1395    // axes; a frame whose column count disagrees with the topology's intrinsic
1396    // dimension is a structurally inconsistent atom and we refuse to fabricate
1397    // generators for it (returning none, so it cannot masquerade as either
1398    // pinned or a spurious residual freedom in the certificate).
1399    if d != atom.topology.latent_dim() {
1400        return Vec::new();
1401    }
1402    let mut out: Vec<(Array1<f64>, String)> = Vec::new();
1403    match &atom.topology {
1404        AtomTopology::Circle => {
1405            // so(2): A = [[0,-1],[1,0]] on the 1 circle, but a Circle atom has a
1406            // single latent axis whose isometry is a *shift* of the periodic
1407            // coordinate. The first-order motion of the (cos,sin) frame columns
1408            // under a shift is the orthogonal frame column. With latent_dim == 1
1409            // the decoder frame's single column moves along its own
1410            // 90°-rotated image, which (lacking a second column) is realised as
1411            // the tangent that advances the periodic phase: the unit direction
1412            // along the frame column itself (the generator of the U(1) shift).
1413            if d >= 1 {
1414                let mut g = Array1::<f64>::zeros(p * d);
1415                for i in 0..p {
1416                    g[i * d] = atom.frame[[i, 0]];
1417                }
1418                out.push((g, format!("{}: S¹ U(1) phase shift", atom.name)));
1419            }
1420        }
1421        AtomTopology::Sphere | AtomTopology::EuclideanPatch { .. } | AtomTopology::Torus { .. } => {
1422            // so(d) rotation generators: one per unordered axis pair (a < b).
1423            // The induced frame motion is frame · A_{ab}ᵀ, i.e. column a picks
1424            // up −column b and column b picks up +column a.
1425            for a in 0..d {
1426                for b in (a + 1)..d {
1427                    let mut g = Array1::<f64>::zeros(p * d);
1428                    for i in 0..p {
1429                        // (frame · Aᵀ)[i, a] = −frame[i, b]; [i, b] = +frame[i, a].
1430                        g[i * d + a] = -atom.frame[[i, b]];
1431                        g[i * d + b] = atom.frame[[i, a]];
1432                    }
1433                    out.push((
1434                        g,
1435                        format!(
1436                            "{}: {} rotation axes ({a},{b})",
1437                            atom.name,
1438                            match &atom.topology {
1439                                AtomTopology::Sphere => "S² so(3)",
1440                                AtomTopology::Torus { .. } => "Tᵈ frame",
1441                                _ => "patch so(d)",
1442                            }
1443                        ),
1444                    ));
1445                }
1446            }
1447            // Torus additionally carries `d` independent circle shifts: the unit
1448            // tangent advancing each axis's periodic phase (translation of that
1449            // circle coordinate), realised as motion along each frame column.
1450            if let AtomTopology::Torus { .. } = atom.topology {
1451                for a in 0..d {
1452                    let mut g = Array1::<f64>::zeros(p * d);
1453                    for i in 0..p {
1454                        g[i * d + a] = atom.frame[[i, a]];
1455                    }
1456                    out.push((g, format!("{}: Tᵈ circle shift axis {a}", atom.name)));
1457                }
1458            }
1459        }
1460    }
1461    out
1462}
1463
1464/// Build equal-ARD rotation generators for one atom: a rotation between two
1465/// latent axes whose ARD variances are equal (within `rel_tol`) is not pinned by
1466/// the ARD prior, so it is a candidate residual gauge freedom (the data +
1467/// isometry penalty decide). Returns the antisymmetric frame-rotation tangent
1468/// for each such equal pair.
1469fn equal_ard_rotation_generators(atom: &FittedAtom) -> Vec<(Array1<f64>, String)> {
1470    let mut out: Vec<(Array1<f64>, String)> = Vec::new();
1471    let (p, d) = atom.frame.dim();
1472    let Some(ard) = atom.ard_variances.as_ref() else {
1473        return out;
1474    };
1475    if ard.len() != d {
1476        return out;
1477    }
1478    const ARD_EQUAL_REL_TOL: f64 = 1.0e-9;
1479    for a in 0..d {
1480        for b in (a + 1)..d {
1481            let va = ard[a];
1482            let vb = ard[b];
1483            let scale = va.abs().max(vb.abs()).max(f64::MIN_POSITIVE);
1484            if (va - vb).abs() <= ARD_EQUAL_REL_TOL * scale {
1485                let mut g = Array1::<f64>::zeros(p * d);
1486                for i in 0..p {
1487                    g[i * d + a] = -atom.frame[[i, b]];
1488                    g[i * d + b] = atom.frame[[i, a]];
1489                }
1490                out.push((
1491                    g,
1492                    format!("{}: equal-ARD rotation axes ({a},{b})", atom.name),
1493                ));
1494            }
1495        }
1496    }
1497    out
1498}
1499
1500/// Build global decoder output-frame rotation generators `O(output_dim)`: a
1501/// rotation `B ∈ so(output_dim)` acts on every atom's frame from the left
1502/// (`B · frame`). The induced tangent on the joint parameter vector stacks
1503/// `B · frame_k` per atom. We enumerate the full `so(output_dim)` basis — one
1504/// generator per unordered output-axis pair `(oi < oj)`, count
1505/// `output_dim·(output_dim−1)/2` — since the per-generator rank test treats each
1506/// independently and we want the certificate to find every output-frame freedom,
1507/// not a subset. `output_dim` is taken as the maximum frame row-count across
1508/// atoms; an atom whose frame lacks one of the two axes contributes nothing to
1509/// that generator.
1510fn frame_rotation_generators(model: &FittedSaeManifold) -> Vec<(Array1<f64>, String)> {
1511    let mut out: Vec<(Array1<f64>, String)> = Vec::new();
1512    let p = model
1513        .atoms
1514        .iter()
1515        .map(|a| a.frame.nrows())
1516        .max()
1517        .unwrap_or(0);
1518    let param_dim = model.param_dim();
1519    for oi in 0..p {
1520        for oj in (oi + 1)..p {
1521            let mut g = Array1::<f64>::zeros(param_dim);
1522            for (k, atom) in model.atoms.iter().enumerate() {
1523                let (ap, ad) = atom.frame.dim();
1524                if oi >= ap || oj >= ap {
1525                    continue;
1526                }
1527                let base = model.atom_offset(k);
1528                // (B · frame)[oi, c] = −frame[oj, c]; [oj, c] = +frame[oi, c].
1529                for c in 0..ad {
1530                    g[base + oi * ad + c] = -atom.frame[[oj, c]];
1531                    g[base + oj * ad + c] = atom.frame[[oi, c]];
1532                }
1533            }
1534            out.push((g, format!("output-frame rotation axes ({oi},{oj})")));
1535        }
1536    }
1537    out
1538}
1539
1540/// Build exchangeable-atom permutation generators: for every pair of atoms with
1541/// identical topology and matching frame shape, the transposition that swaps
1542/// their decoder frames is a candidate `Sym(F)` symmetry. Realised as the
1543/// antisymmetric "swap" tangent `(frame_b − frame_a)` placed on atom a's slot and
1544/// `(frame_a − frame_b)` on atom b's slot — the first-order direction of the
1545/// one-parameter family interpolating the swap.
1546/// Embed an atom-local generator (length = that atom's flattened frame length)
1547/// into the joint parameter vector at the atom's column offset. The per-atom
1548/// generator builders do not know the joint layout; the certificate does, and
1549/// mixing the two coordinate systems is a shape error for every model with more
1550/// than one atom.
1551fn embed_local_generator(offset: usize, local: &Array1<f64>, param_dim: usize) -> Array1<f64> {
1552    let mut g = Array1::<f64>::zeros(param_dim);
1553    g.slice_mut(s![offset..offset + local.len()]).assign(local);
1554    g
1555}
1556
1557fn atom_permutation_generators(
1558    model: &FittedSaeManifold,
1559) -> Vec<(Array1<f64>, String, usize, usize)> {
1560    let mut out: Vec<(Array1<f64>, String, usize, usize)> = Vec::new();
1561    let param_dim = model.param_dim();
1562    for ka in 0..model.atoms.len() {
1563        for kb in (ka + 1)..model.atoms.len() {
1564            let a = &model.atoms[ka];
1565            let b = &model.atoms[kb];
1566            if a.topology != b.topology || a.frame.dim() != b.frame.dim() {
1567                continue;
1568            }
1569            let (ap, ad) = a.frame.dim();
1570            let base_a = model.atom_offset(ka);
1571            let base_b = model.atom_offset(kb);
1572            let mut g = Array1::<f64>::zeros(param_dim);
1573            for i in 0..ap {
1574                for c in 0..ad {
1575                    let diff = b.frame[[i, c]] - a.frame[[i, c]];
1576                    g[base_a + i * ad + c] = diff;
1577                    g[base_b + i * ad + c] = -diff;
1578                }
1579            }
1580            out.push((g, format!("atom-exchange {} ↔ {}", a.name, b.name), ka, kb));
1581        }
1582    }
1583    out
1584}
1585
1586// ============================================================================
1587// #998 — the full-resolution certificate: exact gauge orbits in the model's
1588// own (decoder, coordinate) parameter space.
1589// ============================================================================
1590
/// One atom's exact parameter-space view (#998): the raw objects the fit
/// actually optimizes, in which the model-class gauge orbits live.
///
/// The mean-frame certificate ([`FittedAtom::frame`]) is a lossy compression:
/// the true gauge orbits are **compensated** motions — the latent coordinates
/// move AND the decoder counter-rotates (e.g. `Φ(t+ε)·R(−ε)B = Φ(t)B` for the
/// harmonic circle) — whose net action on the mean frame is identically zero,
/// so no frame-space realisation can measure them (#995's calibrated tolerance
/// is the honest *floor* there). With this view the certificate realises each
/// orbit exactly: the coordinate motion field `δt` comes from the group
/// action, and the decoder compensation `δB` is **profiled out by least
/// squares** against the data motion. The leftover residual is the orbit's
/// true data cost — exactly zero when the basis family is closed under the
/// action (harmonics under shifts, linear charts under rotations), genuinely
/// positive when it is not (a Duchon patch under so(d)). Basis closure is
/// therefore a *computed* per-generator quantity, not a declared flag.
///
/// # Shape conventions
///
/// Below, `n` is the row count and `M` the column count of `basis_values`,
/// `latent_dim` is `coords.ncols()`, and `p` is `decoder.ncols()`. The
/// consumers in this file read the dimensions exactly that way and reject (or
/// decline to build an operator for) a view whose jets disagree with them.
#[derive(Debug, Clone)]
pub struct AtomParameterView {
    /// Basis values `Φ`, `(n, M)`.
    pub basis_values: Array2<f64>,
    /// Basis first-derivative jet `Φ'`, `(n, M, latent_dim)`.
    pub basis_jacobian: Array3<f64>,
    /// Decoder coefficients `B`, `(M, p)`.
    pub decoder: Array2<f64>,
    /// Latent coordinates `t`, `(n, latent_dim)` — the chart the group acts on.
    pub coords: Array2<f64>,
    /// Per-row assignment mass `a_nk`, length `n`. Used as the row weight of
    /// both the data motion and the compensation design `diag(a)·Φ`.
    pub activations: Array1<f64>,
    /// Basis second-derivative jet `Φ''`, `(n, M, latent_dim, latent_dim)`.
    /// Required only to lower an isometry [`OrbitPenaltyOperator`] for a
    /// *pin-active* fit (#998): the penalty is a function of the pullback
    /// metric `g_n = J_nᵀ W_n J_n`, and the first-order change of `g_n` under a
    /// coordinate motion `δt` differentiates `J_n = Φ'_n B` through `t`, which
    /// needs `Φ''`. `None` keeps the data-only orbit verdict (no pin), exactly
    /// as before; absence never errors.
    pub basis_second_jet: Option<Array4<f64>>,
}
1628
/// The penalty/prior channel of the exact certificate: an operator returning
/// the penalty curvature root's image of an orbit direction `(δB, δt)`,
/// together with its stiffness scale `σ_max²`. With exact orbits the data can
/// never pin a model-class symmetry (the LS-compensated motion is a data-null
/// by construction for closed bases), so **all** pinning of such symmetries
/// flows through this channel — exactly where the #981 gauge-reduction ladder
/// says identification lives (the isometry pin does the collapsing, rungs 2
/// and 4, in whichever metric it is computed). `None` ⇒ no pin installed on
/// this atom; the orbit's verdict is then decided by the data residual alone.
///
/// [`isometry_orbit_penalty_operator`] builds the isometry-pin instance of
/// this operator. The struct derives neither `Debug` nor `Clone` because
/// `apply` is a boxed closure.
pub struct OrbitPenaltyOperator {
    /// Maps an orbit direction `(δB (M, p), δt (n, latent_dim))` to the
    /// penalty curvature root's image (any length); the penalty cost along the
    /// direction is the squared norm of the image. The closure is `Send + Sync`
    /// and owns whatever state it needs.
    pub apply: Box<dyn Fn(ArrayView2<f64>, ArrayView2<f64>) -> Array1<f64> + Send + Sync>,
    /// `σ_max²` of the penalty curvature root — the stiffness scale the
    /// orbit's penalty cost is reported relative to (the same
    /// relative-curvature convention as the frame certificate).
    pub stiffness_sq: f64,
}
1648
1649/// Build the isometry-pin [`OrbitPenaltyOperator`] for one viewed atom from its
1650/// second jet (#998 — the orbit-space pin operator the pin-active exact path
1651/// needs).
1652///
1653/// The isometry penalty is `P = ½ μ Σ_n ‖g_n − g_ref‖²_F` with the pullback
1654/// first-fundamental-form gram `g_n = J_nᵀ J_n`, `J_n[i,c] = Σ_m Φ'_n[m,c] B[m,i]`
1655/// (Euclidean metric — the default isometry reference; an output-Fisher metric
1656/// rides the same operator once its factors are threaded, which only re-weights
1657/// the `i`-sum). At a converged isometric fit the residual `g_n − g_ref ≈ 0`, so
1658/// the penalty's curvature along an orbit direction `(δB, δt)` is the
1659/// Gauss-Newton term `μ Σ_n ‖δg_n‖²_F`, and the curvature-root image is
1660/// `√μ · {δg_n[a,b]}` — its squared norm is exactly that cost. The first-order
1661/// gram change
1662///
1663///   `δJ_n[i,c] = Σ_m Φ'_n[m,c] δB[m,i] + Σ_{m,e} Φ''_n[m,c,e] δt_n[e] B[m,i]`
1664///   `δg_n[a,b] = Σ_i ( δJ_n[i,a] J_n[i,b] + J_n[i,a] δJ_n[i,b] )`
1665///
1666/// differentiates `J_n` through `t` via the **second jet** `Φ''` — which is why
1667/// the pin-active path needs it and the frame path (no second jet) could not
1668/// supply it. A model-class symmetry that preserves the metric (e.g. a circle
1669/// phase shift on a closed harmonic basis) yields `δg_n = 0` → the operator
1670/// gives it zero cost → it stays a certified freedom even under the pin; a
1671/// non-isometric orbit (a Duchon/quadratic patch under rotation) yields
1672/// `δg_n ≠ 0` → genuine pinning. The verdict is therefore conservative: the
1673/// operator can only *cost* an orbit, never spuriously free one.
1674///
1675/// `weight` is the penalty strength `μ`. Returns `None` when the view carries no
1676/// second jet (the atom's basis exposes no analytic Hessian): with no orbit-space
1677/// operator the atom's verdict falls back to the data residual, never an error.
1678/// The stiffness `σ_max²` is `μ` times the largest unit-coordinate-motion gram
1679/// curvature `max_n σ_max(∂g_n/∂t)²`, so the reported relative fraction is on the
1680/// same convention as the frame certificate.
1681pub fn isometry_orbit_penalty_operator(
1682    view: &AtomParameterView,
1683    weight: f64,
1684) -> Option<OrbitPenaltyOperator> {
1685    let second = view.basis_second_jet.as_ref()?.clone();
1686    let (n, m) = view.basis_values.dim();
1687    let d = view.coords.ncols();
1688    let p = view.decoder.ncols();
1689    if second.dim() != (n, m, d, d) || view.basis_jacobian.dim() != (n, m, d) {
1690        return None;
1691    }
1692    if !(weight.is_finite() && weight > 0.0) {
1693        return None;
1694    }
1695    let sqrt_w = weight.sqrt();
1696    let jac = view.basis_jacobian.clone();
1697    let decoder = view.decoder.clone();
1698
1699    // Base pullback Jacobian J_n[i,c] = Σ_m Φ'_n[m,c] B[m,i] and its per-row
1700    // first-fundamental gram σ_max scale (stiffness), computed once.
1701    let mut j_base = Array3::<f64>::zeros((n, p, d));
1702    for row in 0..n {
1703        for i in 0..p {
1704            for c in 0..d {
1705                let mut acc = 0.0;
1706                for mm in 0..m {
1707                    acc += jac[[row, mm, c]] * decoder[[mm, i]];
1708                }
1709                j_base[[row, i, c]] = acc;
1710            }
1711        }
1712    }
1713
1714    // Stiffness: σ_max over rows of the gram derivative ∂g_n/∂t along a unit
1715    // coordinate motion. ∂g_n/∂t_e [a,b] = Σ_i ( H_n[i,a,e] J_n[i,b]
1716    // + J_n[i,a] H_n[i,b,e] ), H_n[i,c,e] = Σ_m Φ''_n[m,c,e] B[m,i]. The
1717    // stiffest unit δt direction's gram change drives the relative-curvature
1718    // denominator; we take the largest ‖∂g/∂t_e‖_F over axes e and rows as a
1719    // conservative (≤ true σ_max) scale, so the reported fraction never
1720    // under-states the pin.
1721    let mut max_curv_sq = 0.0_f64;
1722    for row in 0..n {
1723        // H_n[i, c, e] = Σ_m Φ''_n[m, c, e] B[m, i].
1724        let mut hn = vec![0.0_f64; p * d * d];
1725        for i in 0..p {
1726            for c in 0..d {
1727                for e in 0..d {
1728                    let mut acc = 0.0;
1729                    for mm in 0..m {
1730                        acc += second[[row, mm, c, e]] * decoder[[mm, i]];
1731                    }
1732                    hn[(i * d + c) * d + e] = acc;
1733                }
1734            }
1735        }
1736        for e in 0..d {
1737            let mut frob = 0.0_f64;
1738            for a in 0..d {
1739                for b in 0..d {
1740                    let mut g = 0.0;
1741                    for i in 0..p {
1742                        g += hn[(i * d + a) * d + e] * j_base[[row, i, b]];
1743                        g += j_base[[row, i, a]] * hn[(i * d + b) * d + e];
1744                    }
1745                    frob += g * g;
1746                }
1747            }
1748            max_curv_sq = max_curv_sq.max(frob);
1749        }
1750    }
1751    let stiffness_sq = (weight * max_curv_sq).max(f64::MIN_POSITIVE);
1752
1753    let apply = move |delta_b: ArrayView2<f64>, delta_t: ArrayView2<f64>| -> Array1<f64> {
1754        let mut image = Array1::<f64>::zeros(n * d * d);
1755        // δJ_n[i,c] = Σ_m Φ'_n[m,c] δB[m,i] + Σ_{m,e} Φ''_n[m,c,e] δt_n[e] B[m,i].
1756        let valid_b = delta_b.dim() == (m, p);
1757        let valid_t = delta_t.dim() == (n, d);
1758        if !valid_t {
1759            return image;
1760        }
1761        for row in 0..n {
1762            let mut dj = vec![0.0_f64; p * d];
1763            for i in 0..p {
1764                for c in 0..d {
1765                    let mut acc = 0.0;
1766                    if valid_b {
1767                        for mm in 0..m {
1768                            acc += jac[[row, mm, c]] * delta_b[[mm, i]];
1769                        }
1770                    }
1771                    for e in 0..d {
1772                        let dte = delta_t[[row, e]];
1773                        if dte == 0.0 {
1774                            continue;
1775                        }
1776                        for mm in 0..m {
1777                            acc += second[[row, mm, c, e]] * dte * decoder[[mm, i]];
1778                        }
1779                    }
1780                    dj[i * d + c] = acc;
1781                }
1782            }
1783            // δg_n[a,b] = Σ_i ( δJ[i,a] J[i,b] + J[i,a] δJ[i,b] ).
1784            for a in 0..d {
1785                for b in 0..d {
1786                    let mut dg = 0.0;
1787                    for i in 0..p {
1788                        dg += dj[i * d + a] * j_base[[row, i, b]];
1789                        dg += j_base[[row, i, a]] * dj[i * d + b];
1790                    }
1791                    image[(row * d + a) * d + b] = sqrt_w * dg;
1792                }
1793            }
1794        }
1795        image
1796    };
1797
1798    Some(OrbitPenaltyOperator {
1799        apply: Box::new(apply),
1800        stiffness_sq,
1801    })
1802}
1803
1804/// Enumerate one atom's exact orbit coordinate-motion fields `δt ∈ ℝ^{n×d}`.
1805///
1806/// Supported charts are the ones the group acts on **linearly** (so the
1807/// first-order field is exact, not a linearisation): circle/torus axis shifts
1808/// (`δt = e_ax`, chart-free) and flat-patch `so(d)` rotations
1809/// (`δt_n = A_{ab} t_n`). The sphere's `so(3)` action on an intrinsic chart is
1810/// nonlinear, so sphere atoms stay on the frame path (the caller must not
1811/// build a view for them). Equal-ARD rotations reuse the rotation field for
1812/// the tied axis pairs (the ARD prior is their pinning channel).
1813fn exact_orbit_fields(
1814    atom: &FittedAtom,
1815    view: &AtomParameterView,
1816) -> Vec<(GeneratorFamily, Array2<f64>, String)> {
1817    let n = view.coords.nrows();
1818    let d = view.coords.ncols();
1819    let mut out: Vec<(GeneratorFamily, Array2<f64>, String)> = Vec::new();
1820    let rotation_field = |a: usize, b: usize| -> Array2<f64> {
1821        let mut dt = Array2::<f64>::zeros((n, d));
1822        for row in 0..n {
1823            dt[[row, a]] = -view.coords[[row, b]];
1824            dt[[row, b]] = view.coords[[row, a]];
1825        }
1826        dt
1827    };
1828    match &atom.topology {
1829        AtomTopology::Circle => {
1830            out.push((
1831                GeneratorFamily::IsomAtom,
1832                Array2::<f64>::ones((n, 1)),
1833                format!("{}: S¹ U(1) phase shift [exact orbit]", atom.name),
1834            ));
1835        }
1836        AtomTopology::Torus { .. } => {
1837            for ax in 0..d {
1838                let mut dt = Array2::<f64>::zeros((n, d));
1839                dt.column_mut(ax).fill(1.0);
1840                out.push((
1841                    GeneratorFamily::IsomAtom,
1842                    dt,
1843                    format!("{}: Tᵈ circle shift axis {ax} [exact orbit]", atom.name),
1844                ));
1845            }
1846        }
1847        AtomTopology::EuclideanPatch { .. } => {
1848            for a in 0..d {
1849                for b in (a + 1)..d {
1850                    out.push((
1851                        GeneratorFamily::IsomAtom,
1852                        rotation_field(a, b),
1853                        format!(
1854                            "{}: patch so(d) rotation axes ({a},{b}) [exact orbit]",
1855                            atom.name
1856                        ),
1857                    ));
1858                }
1859            }
1860        }
1861        AtomTopology::Sphere => {}
1862    }
1863    // Equal-ARD rotations between tied axes, on linearly-acting charts only.
1864    if !matches!(atom.topology, AtomTopology::Circle | AtomTopology::Sphere) {
1865        if let Some(ard) = atom.ard_variances.as_ref() {
1866            if ard.len() == d {
1867                const ARD_EQUAL_REL_TOL: f64 = 1.0e-9;
1868                for a in 0..d {
1869                    for b in (a + 1)..d {
1870                        let scale = ard[a].abs().max(ard[b].abs()).max(f64::MIN_POSITIVE);
1871                        if (ard[a] - ard[b]).abs() <= ARD_EQUAL_REL_TOL * scale {
1872                            out.push((
1873                                GeneratorFamily::EqualArdRotation,
1874                                rotation_field(a, b),
1875                                format!(
1876                                    "{}: equal-ARD rotation axes ({a},{b}) [exact orbit]",
1877                                    atom.name
1878                                ),
1879                            ));
1880                        }
1881                    }
1882                }
1883            }
1884        }
1885    }
1886    out
1887}
1888
1889/// Exact-orbit verdicts for one viewed atom (#998).
1890///
1891/// For each orbit field `δt`: the uncompensated data motion is
1892/// `u_n = a_n · (Φ'_n B) δt_n ∈ ℝ^p`; the decoder compensation `δB` minimizing
1893/// `Σ_n ‖a_n Φ_n δB + u_n‖²` is profiled out through one shared SVD
1894/// pseudo-inverse of the activation-weighted basis `D = diag(a) Φ`; and the
1895/// **compensation residual fraction** `r²/‖u‖²` is the orbit's true relative
1896/// data cost — exactly 0 for a basis closed under the group action, genuinely
1897/// positive otherwise (computed closure). The penalty channel, when installed,
1898/// contributes `‖penalty_root(δB, δt)‖² / σ_max²` on the same
1899/// relative-curvature convention. The verdict needs **no lowering-error
1900/// calibration** (`lowering_error_scale = 0`): nothing here is compressed.
1901///
1902/// The data likelihood this measures against is the activation-reconstruction
1903/// objective in its own (Euclidean) inner product — which per the amended #980
1904/// dispatch rule is the only thing that ever whitens the likelihood unless a
1905/// `WhitenedStructured` noise model is installed; the output-Fisher metric
1906/// reaches gauge verdicts only through the penalty operator.
1907fn exact_orbit_verdicts(
1908    atom: &FittedAtom,
1909    view: &AtomParameterView,
1910    penalty: Option<&OrbitPenaltyOperator>,
1911) -> Result<Vec<GeneratorVerdict>, String> {
1912    let (n, m) = view.basis_values.dim();
1913    let d = view.coords.ncols();
1914    let p = view.decoder.ncols();
1915    if view.basis_jacobian.dim() != (n, m, d) {
1916        return Err(format!(
1917            "exact_orbit_verdicts({}): basis_jacobian shape {:?} must be ({n}, {m}, {d})",
1918            atom.name,
1919            view.basis_jacobian.dim()
1920        ));
1921    }
1922    if view.decoder.nrows() != m {
1923        return Err(format!(
1924            "exact_orbit_verdicts({}): decoder has {} rows but basis has {m} columns",
1925            atom.name,
1926            view.decoder.nrows()
1927        ));
1928    }
1929    if view.coords.nrows() != n || view.activations.len() != n {
1930        return Err(format!(
1931            "exact_orbit_verdicts({}): coords/activations rows must match basis rows {n}",
1932            atom.name
1933        ));
1934    }
1935
1936    let fields = exact_orbit_fields(atom, view);
1937    if fields.is_empty() {
1938        return Ok(Vec::new());
1939    }
1940
1941    // Shared compensation operator: thin SVD of D = diag(a)·Φ, computed once.
1942    let mut design = Array2::<f64>::zeros((n, m));
1943    for row in 0..n {
1944        let a = view.activations[row];
1945        for c in 0..m {
1946            design[[row, c]] = a * view.basis_values[[row, c]];
1947        }
1948    }
1949    let (u_opt, sigma, vt_opt) = design
1950        .svd(true, true)
1951        .map_err(|e| format!("exact_orbit_verdicts({}): SVD of D failed: {e}", atom.name))?;
1952    let u_svd =
1953        u_opt.ok_or_else(|| format!("exact_orbit_verdicts({}): SVD lacked U", atom.name))?;
1954    let vt = vt_opt.ok_or_else(|| format!("exact_orbit_verdicts({}): SVD lacked Vᵀ", atom.name))?;
1955    let smax = sigma.iter().cloned().fold(0.0_f64, f64::max);
1956    let cutoff = smax * f64::EPSILON * (n.max(m) as f64);
1957
1958    let mut out: Vec<GeneratorVerdict> = Vec::with_capacity(fields.len());
1959    for (family, dt, description) in fields {
1960        // Uncompensated data motion u_n = a_n (Φ'_n B) δt_n.
1961        let mut u_mot = Array2::<f64>::zeros((n, p));
1962        for row in 0..n {
1963            let a = view.activations[row];
1964            if !(a != 0.0) {
1965                continue;
1966            }
1967            for ax in 0..d {
1968                let step = dt[[row, ax]];
1969                if step == 0.0 {
1970                    continue;
1971                }
1972                for bm in 0..m {
1973                    let dphi = view.basis_jacobian[[row, bm, ax]];
1974                    if dphi == 0.0 {
1975                        continue;
1976                    }
1977                    let w = a * step * dphi;
1978                    for j in 0..p {
1979                        u_mot[[row, j]] += w * view.decoder[[bm, j]];
1980                    }
1981                }
1982            }
1983        }
1984        let raw: f64 = u_mot.iter().map(|v| v * v).sum();
1985        if raw <= f64::MIN_POSITIVE {
1986            // The orbit does not move the fit at all (zero tangents / zero
1987            // mass): structurally trivial, reported pinned with zero norm,
1988            // mirroring the frame certificate's convention.
1989            out.push(GeneratorVerdict {
1990                family,
1991                description,
1992                unpinned: false,
1993                generator_norm: 0.0,
1994                pinned_energy_fraction: 1.0,
1995                lowering_error_scale: 0.0,
1996                provenance: VerdictProvenance::CurvatureTest,
1997            });
1998            continue;
1999        }
2000        // Profile out the decoder compensation: c = Uᵀu, keep σ > cutoff.
2001        // Residual cost r² = ‖u‖² − ‖c_kept‖² (Pythagoras on the projection).
2002        let coeffs = u_svd.t().dot(&u_mot);
2003        let mut kept_sq = 0.0_f64;
2004        let mut scaled = Array2::<f64>::zeros((sigma.len(), p));
2005        for r in 0..sigma.len() {
2006            if sigma[r] > cutoff {
2007                let inv = 1.0 / sigma[r];
2008                for j in 0..p {
2009                    kept_sq += coeffs[[r, j]] * coeffs[[r, j]];
2010                    scaled[[r, j]] = -inv * coeffs[[r, j]];
2011                }
2012            }
2013        }
2014        let resid_sq = (raw - kept_sq).max(0.0);
2015        let data_fraction = (resid_sq / raw).clamp(0.0, 1.0);
2016
2017        let penalty_fraction = match penalty {
2018            Some(op) if op.stiffness_sq > f64::MIN_POSITIVE => {
2019                let delta_b = vt.t().dot(&scaled); // δB = −V Σ⁺ Uᵀ u, (M, p)
2020                let image = (op.apply)(delta_b.view(), dt.view());
2021                let cost: f64 = image.iter().map(|v| v * v).sum();
2022                (cost / op.stiffness_sq).clamp(0.0, 1.0)
2023            }
2024            _ => 0.0,
2025        };
2026
2027        let pinned_energy_fraction = data_fraction.max(penalty_fraction);
2028        out.push(GeneratorVerdict {
2029            family,
2030            description,
2031            unpinned: pinned_energy_fraction <= GENERATOR_FLAT_ENERGY_TOL,
2032            generator_norm: raw.sqrt(),
2033            pinned_energy_fraction,
2034            lowering_error_scale: 0.0,
2035            provenance: VerdictProvenance::CurvatureTest,
2036        });
2037    }
2038    Ok(out)
2039}
2040
2041/// The stacked curvature root `R` of the pinning operator, in the fit's
2042/// metric: `(m, param_dim)` with `H = H_data + H_isometry = RᵀR`.
2043///
2044/// We assemble `R = [ W^{½} J ; R_isom ]` whose row space is
2045/// `range(H_data) + range(H_isometry)`, where `W^{½} J` is the metric-whitened
2046/// decoder Jacobian (the metric whitening is the `RowMetric`'s
2047/// `whiten_residual_row` applied to each output residual basis vector — i.e.
2048/// each Jacobian row is whitened in the same inner product the likelihood
2049/// sums). The caller derives both faces from this one object: the pinning
2050/// RANK (RRQR on `Rᵀ`, the audit's leverage-scaled rank decision) and the
2051/// per-generator relative curvature `‖R ξ̂‖² / σ_max(R)²` — magnitudes kept,
2052/// not orthonormalized away, so the statistic survives a full-rank span.
2053fn stacked_curvature_root(model: &FittedSaeManifold) -> Result<Array2<f64>, String> {
2054    let param_dim = model.param_dim();
2055    if param_dim == 0 {
2056        return Ok(Array2::<f64>::zeros((0, 0)));
2057    }
2058    let p = model.metric.p_out();
2059    // Metric-whitened Jacobian rows: each row's Jacobian J_n ∈ ℝ^{p × param_dim}
2060    // is whitened to U_nᵀ J_n ∈ ℝ^{rank × param_dim} so that the resulting rows
2061    // span the same directions the metric-whitened residual gives cost to. We
2062    // build the stacked matrix `R` with one block of whitened rows per metric
2063    // row, then the isometry-penalty root beneath it.
2064    let mut stacked_rows: Vec<Array1<f64>> = Vec::new();
2065    for (n, j_flat) in model.jacobian_rows.iter().enumerate() {
2066        if j_flat.len() != p * param_dim {
2067            return Err(format!(
2068                "stacked_curvature_root: jacobian_rows[{n}] has len {} but expected p*param_dim = {}*{} = {}",
2069                j_flat.len(),
2070                p,
2071                param_dim,
2072                p * param_dim
2073            ));
2074        }
2075        // Whiten each parameter column's p-vector of output sensitivities.
2076        // Column c of J_n is the p-vector (j_flat[i*param_dim + c])_i. Whitening
2077        // it through the metric row (U_nᵀ ·) maps each column to a
2078        // `whit_len`-vector; the resulting `whit_len × param_dim` block's rows
2079        // are the metric-whitened Jacobian rows whose span the data gives cost
2080        // to. For Euclidean provenance `whiten_residual_row` is the identity, so
2081        // `whit_len == p` and the block is J_n unchanged (bit-for-bit the
2082        // isotropic data span).
2083        let mut cols_whitened: Vec<Vec<f64>> = Vec::with_capacity(param_dim);
2084        for c in 0..param_dim {
2085            let mut col = vec![0.0_f64; p];
2086            for i in 0..p {
2087                col[i] = j_flat[i * param_dim + c];
2088            }
2089            cols_whitened.push(model.metric.whiten_residual_row(n, ArrayView1::from(&col)));
2090        }
2091        let whit_len = cols_whitened.first().map_or(0, |c| c.len());
2092        for r in 0..whit_len {
2093            let mut row = Array1::<f64>::zeros(param_dim);
2094            for (c, col) in cols_whitened.iter().enumerate() {
2095                row[c] = col[r];
2096            }
2097            stacked_rows.push(row);
2098        }
2099    }
2100    // Append isometry-penalty root rows.
2101    if model.isometry_penalty_root.ncols() != 0 {
2102        if model.isometry_penalty_root.ncols() != param_dim {
2103            return Err(format!(
2104                "stacked_curvature_root: isometry_penalty_root has {} cols but param_dim = {param_dim}",
2105                model.isometry_penalty_root.ncols()
2106            ));
2107        }
2108        for r in 0..model.isometry_penalty_root.nrows() {
2109            stacked_rows.push(model.isometry_penalty_root.row(r).to_owned());
2110        }
2111    }
2112    if stacked_rows.is_empty() {
2113        return Ok(Array2::<f64>::zeros((0, param_dim)));
2114    }
2115    let m = stacked_rows.len();
2116    let mut r_mat = Array2::<f64>::zeros((m, param_dim));
2117    for (i, row) in stacked_rows.iter().enumerate() {
2118        r_mat.row_mut(i).assign(row);
2119    }
2120    Ok(r_mat)
2121}
2122
/// The pinning curvature in the reduced form the residual-gauge certificate
/// consumes: either the explicit stacked root `R` or its Gram `RᵀR`.
///
/// Both variants cache the pinning rank and the squared stiffness scale so
/// the per-generator loop only has to evaluate one quadratic form.
enum CurvatureReduction {
    /// Explicit root `R`; a unit generator's energy is `‖R ξ̂‖²`.
    Root {
        /// Rank reported by the RRQR of `Rᵀ` (0 when `R` has no rows).
        pinning_rank: usize,
        /// Square of the largest singular value of `R` (0 when `R` has no rows).
        sigma_max_sq: f64,
        /// The stacked curvature root itself.
        root: Array2<f64>,
    },
    /// Pre-reduced Gram `RᵀR`; a unit generator's energy is `ξ̂ᵀ (RᵀR) ξ̂`.
    Gram {
        /// Number of Gram eigenvalues above the rank tolerance.
        pinning_rank: usize,
        /// Largest Gram eigenvalue, floored at zero.
        sigma_max_sq: f64,
        /// The `param_dim × param_dim` curvature Gram.
        gram: Array2<f64>,
    },
}
2135
2136impl CurvatureReduction {
2137    fn from_model(model: &FittedSaeManifold) -> Result<Self, String> {
2138        let root = stacked_curvature_root(model)?;
2139        if root.nrows() == 0 {
2140            return Ok(Self::Root {
2141                pinning_rank: 0,
2142                sigma_max_sq: 0.0,
2143                root,
2144            });
2145        }
2146        let r_t = root.t().to_owned();
2147        let rrqr = rrqr_with_permutation(&r_t, default_rrqr_rank_alpha())
2148            .map_err(|e| format!("residual_gauge: RRQR on Rᵀ failed: {e:?}"))?;
2149        let (_u, sv, _vt) = root
2150            .svd(false, false)
2151            .map_err(|e| format!("residual_gauge: SVD of curvature root failed: {e}"))?;
2152        let smax = sv.iter().cloned().fold(0.0_f64, f64::max);
2153        Ok(Self::Root {
2154            pinning_rank: rrqr.rank,
2155            sigma_max_sq: smax * smax,
2156            root,
2157        })
2158    }
2159
2160    fn from_gram(gram: Array2<f64>, root_rows: usize, param_dim: usize) -> Result<Self, String> {
2161        if gram.nrows() != param_dim || gram.ncols() != param_dim {
2162            return Err(format!(
2163                "residual_gauge: curvature gram has shape ({}, {}) but param_dim = {param_dim}",
2164                gram.nrows(),
2165                gram.ncols()
2166            ));
2167        }
2168        if param_dim == 0 || root_rows == 0 {
2169            return Ok(Self::Gram {
2170                pinning_rank: 0,
2171                sigma_max_sq: 0.0,
2172                gram,
2173            });
2174        }
2175        let (evals, _) = gram.eigh(Side::Lower).map_err(|e| {
2176            format!("residual_gauge: eigendecomposition of curvature gram failed: {e}")
2177        })?;
2178        let sigma_max_sq = evals.iter().cloned().fold(0.0_f64, f64::max).max(0.0);
2179        let sigma_max = sigma_max_sq.sqrt();
2180        let rank_tol = default_rrqr_rank_alpha()
2181            * f64::EPSILON
2182            * (root_rows.max(param_dim).max(1) as f64)
2183            * sigma_max.max(1.0);
2184        let lambda_tol = rank_tol * rank_tol;
2185        let pinning_rank = evals
2186            .iter()
2187            .filter(|&&lambda| lambda.max(0.0) > lambda_tol)
2188            .count();
2189        Ok(Self::Gram {
2190            pinning_rank,
2191            sigma_max_sq,
2192            gram,
2193        })
2194    }
2195
2196    fn pinning_rank(&self) -> usize {
2197        match self {
2198            Self::Root { pinning_rank, .. } | Self::Gram { pinning_rank, .. } => *pinning_rank,
2199        }
2200    }
2201
2202    fn sigma_max_sq(&self) -> f64 {
2203        match self {
2204            Self::Root { sigma_max_sq, .. } | Self::Gram { sigma_max_sq, .. } => *sigma_max_sq,
2205        }
2206    }
2207
2208    fn unit_generator_energy(&self, unit: &Array1<f64>) -> f64 {
2209        match self {
2210            Self::Root { root, .. } => {
2211                let r_xi = root.dot(unit);
2212                r_xi.iter().map(|c| c * c).sum::<f64>()
2213            }
2214            Self::Gram { gram, .. } => {
2215                let h_xi = gram.dot(unit);
2216                unit.dot(&h_xi).max(0.0)
2217            }
2218        }
2219    }
2220}
2221
/// Evaluate the identifiability rank machinery on the symmetry generators of a
/// fitted SAE-manifold model and certify which gauge group the fit is identified
/// up to.
///
/// This is the frame-space entry point: it forwards to the shared
/// implementation with no exact-orbit inputs and no precomputed curvature, so
/// every atom is certified on the frame path and the curvature is built from
/// the model itself.
///
/// # Method
///
/// 1. Enumerate the symmetry generators as tangent directions on the flattened
///    decoder frames: per-atom `Isom(M_k)` generators
///    ([`atom_isometry_generators`]), equal-ARD rotations
///    ([`equal_ard_rotation_generators`]), global output-frame rotations
///    ([`frame_rotation_generators`]), and exchangeable-atom permutations
///    ([`atom_permutation_generators`]).
/// 2. Build the stacked curvature root `R` of the pinning operator
///    `H = H_data + H_isometry = RᵀR` in the fit's [`RowMetric`]
///    ([`stacked_curvature_root`]); the pinning RANK is the audit's RRQR rank
///    of `R`, reported alongside.
/// 3. For each generator `ξ`, the **relative curvature fraction**
///    `‖R ξ̂‖² / σ_max(R)²` measures the curvature the converged objective has
///    along the unit generator, relative to the model's stiffest direction.
///    `ξ` is **unpinned** (a residual gauge freedom) iff that fraction is at
///    or below the calibrated tolerance
///    `max(`[`GENERATOR_FLAT_ENERGY_TOL`]`, lowering_error_scale)` — flat up
///    to numerical noise and the mean-frame lowering's own resolution
///    ([`FittedAtom::lowering_error`], #995). Any larger fraction — including
///    the *mixed* regime where `ξ` carries both a curved and a flat component
///    — means the orbit costs objective, the exact group element is broken,
///    and the generator is **pinned**. (A span-membership or rank-increase
///    test degenerates when `R` is full-rank, which production fits always
///    are: every direction is "in the span", so verdicts would collapse to
///    all-pinned regardless of magnitudes. Keeping the curvature magnitudes
///    is what lets a genuinely flat direction stay visible inside a full-rank
///    span.) The fraction and the calibration scale are reported per
///    generator so partial flatness stays visible.
///
/// # Escalations
///
/// * When the isometry pin is inactive (`isometry_penalty_root` has no rows) the
///   report sets `diffeomorphism_unpinned = true`: with no metric pin the model
///   is only identified up to an arbitrary diffeomorphism of the latent
///   manifolds, so every isometry generator is a residual freedom.
/// * Under [`MetricProvenance::OutputFisher`] or
///   [`MetricProvenance::OutputFisherDownstream`] the `Sym(F)` permutation
///   subgroup is checked for triviality: every atom-exchange generator must be
///   pinned (the output-Fisher metric separates the atoms behaviorally). The
///   result is carried in `sym_f_trivial_under_output_fisher`; under any other
///   provenance that field is `None`.
///
/// # Errors
///
/// Propagates any error raised while building the curvature reduction from
/// the model (see [`stacked_curvature_root`]).
pub fn residual_gauge(model: &FittedSaeManifold) -> Result<ResidualGaugeReport, String> {
    residual_gauge_inner(model, None, None)
}
2269
2270/// The #998 full-resolution certificate: within-atom gauge families are
2271/// realised as **exact orbits** in the model's own (decoder, coordinate)
2272/// parameter space for every atom that supplies an [`AtomParameterView`],
2273/// while cross-atom families (output-frame rotations, atom permutations) and
2274/// any unviewed atom (e.g. spheres, whose chart action is nonlinear) keep the
2275/// frame-space path with its #995 lowering-error calibration.
2276///
2277/// For a viewed atom the compensated orbit is a data-null **by construction**
2278/// when the basis family is closed under the group action — the verdict
2279/// carries no calibration (`lowering_error_scale = 0`), the compensation
2280/// residual is the computed closure, and all pinning of true model-class
2281/// symmetries flows through the per-atom [`OrbitPenaltyOperator`] channel
2282/// (the isometry pin / ARD prior — rungs 2 and 4 of the #981 ladder).
2283///
2284/// `views` and `penalty_ops` are aligned with `model.atoms`; a `None` view
2285/// keeps that atom entirely on the frame path. Supplying a view for an atom
2286/// whose pin is active without also supplying its penalty operator would
2287/// over-claim freedom, so callers must pass the operator (or no view) for
2288/// pinned atoms.
2289pub fn residual_gauge_exact(
2290    model: &FittedSaeManifold,
2291    views: &[Option<AtomParameterView>],
2292    penalty_ops: &[Option<OrbitPenaltyOperator>],
2293) -> Result<ResidualGaugeReport, String> {
2294    let exact = residual_gauge_exact_inputs(model, views, penalty_ops)?;
2295    residual_gauge_inner(model, Some(exact), None)
2296}
2297
2298/// Exact-orbit residual-gauge certificate with a pre-reduced streamed curvature
2299/// Gram `RᵀR`.
2300///
2301/// This is the memory-scaled entry point for callers that can stream their
2302/// metric-whitened Jacobian rows into the reductions the certificate consumes,
2303/// instead of retaining every per-row `p × param_dim` Jacobian block. The Gram
2304/// must include the same rows [`stacked_curvature_root`] would have placed in
2305/// `R`; `root_rows` is that row count for the rank tolerance scale.
2306pub fn residual_gauge_exact_from_curvature_gram(
2307    model: &FittedSaeManifold,
2308    views: &[Option<AtomParameterView>],
2309    penalty_ops: &[Option<OrbitPenaltyOperator>],
2310    curvature_gram: Array2<f64>,
2311    root_rows: usize,
2312) -> Result<ResidualGaugeReport, String> {
2313    let param_dim = model.param_dim();
2314    let curvature = CurvatureReduction::from_gram(curvature_gram, root_rows, param_dim)?;
2315    let exact = residual_gauge_exact_inputs(model, views, penalty_ops)?;
2316    residual_gauge_inner(model, Some(exact), Some(curvature))
2317}
2318
2319fn residual_gauge_exact_inputs(
2320    model: &FittedSaeManifold,
2321    views: &[Option<AtomParameterView>],
2322    penalty_ops: &[Option<OrbitPenaltyOperator>],
2323) -> Result<(Vec<bool>, Vec<GeneratorVerdict>), String> {
2324    if views.len() != model.atoms.len() || penalty_ops.len() != model.atoms.len() {
2325        return Err(format!(
2326            "residual_gauge_exact: views ({}) and penalty_ops ({}) must align with atoms ({})",
2327            views.len(),
2328            penalty_ops.len(),
2329            model.atoms.len()
2330        ));
2331    }
2332    let mut mask = vec![false; model.atoms.len()];
2333    let mut exact_verdicts: Vec<GeneratorVerdict> = Vec::new();
2334    for (k, (atom, view)) in model.atoms.iter().zip(views.iter()).enumerate() {
2335        let Some(view) = view else { continue };
2336        // Sphere charts: nonlinear group action — refuse exactness, keep the
2337        // calibrated frame path for this atom rather than pretending.
2338        if matches!(atom.topology, AtomTopology::Sphere) {
2339            continue;
2340        }
2341        exact_verdicts.extend(exact_orbit_verdicts(atom, view, penalty_ops[k].as_ref())?);
2342        mask[k] = true;
2343    }
2344    Ok((mask, exact_verdicts))
2345}
2346
2347fn residual_gauge_inner(
2348    model: &FittedSaeManifold,
2349    exact: Option<(Vec<bool>, Vec<GeneratorVerdict>)>,
2350    precomputed_curvature: Option<CurvatureReduction>,
2351) -> Result<ResidualGaugeReport, String> {
2352    let metric_provenance = model.metric.provenance();
2353    let param_dim = model.param_dim();
2354    let (exact_mask, exact_verdicts) = match exact {
2355        Some((mask, verdicts)) => (Some(mask), verdicts),
2356        None => (None, Vec::new()),
2357    };
2358
2359    // 1. Enumerate generators, tagged by family. The per-atom builders speak
2360    // the atom's LOCAL flattened-frame coordinates (length `frame.len()`); the
2361    // certificate's rank arithmetic runs in the joint parameter vector, so each
2362    // local generator is embedded at its atom's offset here. (Single-atom
2363    // models have local == joint, which is why only multi-atom models can
2364    // expose a missed embedding.)
2365    // Each generator carries its #995 lowering-error tolerance scale: the
2366    // largest `lowering_error` over the atoms it touches.
2367    let scale_of = |k: usize| -> f64 { model.atoms[k].lowering_error.clamp(0.0, 1.0) };
2368    let global_scale = (0..model.atoms.len()).map(scale_of).fold(0.0_f64, f64::max);
2369    let mut gens: Vec<(GeneratorFamily, Array1<f64>, String, f64)> = Vec::new();
2370    for (k, atom) in model.atoms.iter().enumerate() {
2371        // Atoms whose within-atom families are realised exactly (#998) are
2372        // skipped here: the frame-space lift of a compensated orbit measures
2373        // compression, not the symmetry, and the report must not carry both a
2374        // lossy and an exact verdict for the same group element.
2375        if exact_mask.as_ref().is_some_and(|mask| mask[k]) {
2376            continue;
2377        }
2378        let base = model.atom_offset(k);
2379        for (g, desc) in atom_isometry_generators(atom) {
2380            gens.push((
2381                GeneratorFamily::IsomAtom,
2382                embed_local_generator(base, &g, param_dim),
2383                desc,
2384                scale_of(k),
2385            ));
2386        }
2387        for (g, desc) in equal_ard_rotation_generators(atom) {
2388            gens.push((
2389                GeneratorFamily::EqualArdRotation,
2390                embed_local_generator(base, &g, param_dim),
2391                desc,
2392                scale_of(k),
2393            ));
2394        }
2395    }
2396    for (g, desc) in frame_rotation_generators(model) {
2397        // A global output rotation moves every atom's frame at once.
2398        gens.push((GeneratorFamily::FrameRotation, g, desc, global_scale));
2399    }
2400    for (g, desc, ka, kb) in atom_permutation_generators(model) {
2401        gens.push((
2402            GeneratorFamily::AtomPermutation,
2403            g,
2404            desc,
2405            scale_of(ka).max(scale_of(kb)),
2406        ));
2407    }
2408
2409    // 2. Stacked curvature root in the metric; pinning rank via the audit's
2410    // RRQR on Rᵀ, stiffness scale σ_max via SVD (magnitudes kept).
2411    let curvature = match precomputed_curvature {
2412        Some(curvature) => curvature,
2413        None => CurvatureReduction::from_model(model)?,
2414    };
2415    let pinning_rank = curvature.pinning_rank();
2416    let sigma_max_sq = curvature.sigma_max_sq();
2417
2418    // The isometry pin is inactive ⇒ diffeomorphism-unpinned escalation.
2419    let diffeomorphism_unpinned = model.isometry_penalty_root.nrows() == 0;
2420
2421    // 3. Per-generator flatness verdict: relative curvature vs the calibrated
2422    // tolerance.
2423    let mut verdicts: Vec<GeneratorVerdict> = Vec::with_capacity(gens.len());
2424    for (family, g, description, lowering_error_scale) in &gens {
2425        let norm = g.iter().map(|v| v * v).sum::<f64>().sqrt();
2426        // A structurally trivial generator (rotation of a rank-deficient frame,
2427        // zero swap) carries no direction — it cannot be a residual freedom.
2428        // Report it pinned with zero norm rather than as a spurious gauge.
2429        if norm <= f64::MIN_POSITIVE {
2430            verdicts.push(GeneratorVerdict {
2431                family: *family,
2432                description: description.clone(),
2433                unpinned: false,
2434                generator_norm: 0.0,
2435                pinned_energy_fraction: 1.0,
2436                lowering_error_scale: *lowering_error_scale,
2437                provenance: VerdictProvenance::CurvatureTest,
2438            });
2439            continue;
2440        }
2441        // Relative curvature fraction ‖R ξ̂‖² / σ_max(R)² of the unit
2442        // generator ξ̂ = ξ/‖ξ‖. Exactly flat directions score 0 even inside a
2443        // full-rank span (production fits!), where the previous
2444        // span-membership rule degenerated to all-pinned. A MIXED generator
2445        // (strictly interior fraction) above the tolerance is pinned: its
2446        // orbit costs objective, so the exact symmetry does not survive
2447        // (#980 Theorem-2 arm). The tolerance is calibrated by the #995
2448        // lowering-error scale: curvature the mean-frame compression cannot
2449        // distinguish from gauge motion must not be read as a pin — the
2450        // certificate refuses to claim resolution it does not have.
2451        let pinned_energy_fraction = if sigma_max_sq <= f64::MIN_POSITIVE {
2452            0.0
2453        } else {
2454            let unit = g.mapv(|v| v / norm);
2455            (curvature.unit_generator_energy(&unit) / sigma_max_sq).clamp(0.0, 1.0)
2456        };
2457        let tolerance = GENERATOR_FLAT_ENERGY_TOL.max(*lowering_error_scale);
2458        let unpinned = pinned_energy_fraction <= tolerance;
2459        verdicts.push(GeneratorVerdict {
2460            family: *family,
2461            description: description.clone(),
2462            unpinned,
2463            generator_norm: norm,
2464            pinned_energy_fraction,
2465            lowering_error_scale: *lowering_error_scale,
2466            provenance: VerdictProvenance::CurvatureTest,
2467        });
2468    }
2469
2470    // Exact-orbit verdicts (#998) join the report on equal footing: the
2471    // group signature, residual dimension, and Sym(F) check all range over
2472    // the union.
2473    verdicts.extend(exact_verdicts);
2474
2475    // #1019 — post-fit arc-length chart canonicalization records: for every
2476    // canonicalized d = 1 atom the continuous chart (reparameterization)
2477    // freedom is pinned BY CONSTRUCTION (the unit-speed representative of the
2478    // Diff(M) orbit was selected post-fit, image-frozen), so the certificate
2479    // records it pinned with the PinnedByCanonicalization provenance —
2480    // distinct from curvature/penalty pinning, since no objective resistance
2481    // was measured — and names the surviving FINITE isometry group of the
2482    // reference manifold. The group's continuous part (the circle's U(1)
2483    // shift) is still enumerated and curvature-tested above; this record is
2484    // the chart-freedom downgrade itself.
2485    let mut canonicalized_charts = 0usize;
2486    let mut canonicalized_torus_charts = 0usize;
2487    let mut canonicalized_patch_charts = 0usize;
2488    let mut canonicalized_sphere_charts = 0usize;
2489    for atom in &model.atoms {
2490        if !atom.chart_canonicalized {
2491            continue;
2492        }
2493        let (pinned_to, residual_group) = match &atom.topology {
2494            AtomTopology::Circle | AtomTopology::Torus { latent_dim: 1 } => {
2495                canonicalized_charts += 1;
2496                ("arc length", "O(2) on S¹ (rotation + reflection)")
2497            }
2498            AtomTopology::EuclideanPatch { latent_dim: 1 } => {
2499                canonicalized_charts += 1;
2500                (
2501                    "arc length",
2502                    "reflection + translation of the unit interval",
2503                )
2504            }
2505            // #1019 stage 2: d = 2 torus charts are pinned post-fit to the
2506            // minimum-isometry-defect flow representative; the surviving chart
2507            // freedom is the isometry group of the flat square torus.
2508            AtomTopology::Torus { latent_dim: 2 } => {
2509                canonicalized_torus_charts += 1;
2510                (
2511                    "the isometry-flow canonical chart",
2512                    "Isom(T², flat) = U(1)² ⋊ D₄ (axis translations + axis swap/reflections)",
2513                )
2514            }
2515            // #1019 free-chart arm: d = 2 free/patch (Euclidean-patch) charts
2516            // are pinned post-fit to the flat-reference minimum-anisotropy-
2517            // defect flow representative; the surviving chart freedom is the
2518            // isometry group of the flat plane.
2519            AtomTopology::EuclideanPatch { latent_dim: 2 } => {
2520                canonicalized_patch_charts += 1;
2521                (
2522                    "the flat-reference isometry-flow canonical chart",
2523                    "Isom(ℝ², flat) = O(2) ⋉ ℝ² (rotation + reflection + translation)",
2524                )
2525            }
2526            // #1019 sphere arm: d = 2 sphere (S²) charts are pinned post-fit to
2527            // the round-sphere conformal-boost minimum-isometry-defect flow,
2528            // which breaks the conformal (Möbius) moduli down to the round
2529            // sphere's isometry group; the surviving chart freedom is O(3).
2530            AtomTopology::Sphere => {
2531                canonicalized_sphere_charts += 1;
2532                (
2533                    "the round-sphere conformal-boost isometry-flow canonical chart",
2534                    "Isom(S², round) = O(3) (rotations + reflection)",
2535                )
2536            }
2537            // Canonicalization only ever applies to d = 1 charts, d = 2 torus,
2538            // d = 2 free/patch, and d = 2 sphere charts; a flag on any other
2539            // topology is structurally inconsistent and must not fabricate a
2540            // record.
2541            _ => continue,
2542        };
2543        verdicts.push(GeneratorVerdict {
2544            family: GeneratorFamily::ChartReparameterization,
2545            description: format!(
2546                "{}: chart pinned to {pinned_to} by post-fit canonicalization; \
2547                 residual chart freedom = {residual_group}",
2548                atom.name
2549            ),
2550            unpinned: false,
2551            generator_norm: 0.0,
2552            pinned_energy_fraction: 1.0,
2553            lowering_error_scale: 0.0,
2554            provenance: VerdictProvenance::PinnedByCanonicalization,
2555        });
2556    }
2557
2558    let residual_gauge_dim = verdicts.iter().filter(|v| v.unpinned).count();
2559
2560    // Sym(F)-triviality under any output-Fisher provenance — same-position
2561    // (`OutputFisher`) or downstream-influence (`OutputFisherDownstream`, #980).
2562    // Both behaviorally separate the atoms (the downstream metric strictly more,
2563    // since it sees far-future coupling the same-position metric misses), so the
2564    // permutation subgroup must be trivially pinned under either.
2565    let sym_f_trivial_under_output_fisher = if matches!(
2566        metric_provenance,
2567        MetricProvenance::OutputFisher { .. } | MetricProvenance::OutputFisherDownstream { .. }
2568    ) {
2569        let any_perm_unpinned = verdicts
2570            .iter()
2571            .any(|v| v.family == GeneratorFamily::AtomPermutation && v.unpinned);
2572        Some(!any_perm_unpinned)
2573    } else {
2574        None
2575    };
2576
2577    let summary = format!(
2578        "residual gauge certificate (computed in metric {metric_provenance:?}): \
2579         pinning rank {pinning_rank}, {residual_gauge_dim} unpinned residual gauge \
2580         generator(s) of {} enumerated; group = {}{}{}",
2581        verdicts.len(),
2582        group_signature_of(&verdicts, diffeomorphism_unpinned),
2583        match sym_f_trivial_under_output_fisher {
2584            Some(true) => "; Sym(F) trivially pinned under OutputFisher",
2585            Some(false) => "; ⚠ Sym(F) NON-trivial under OutputFisher (certificate violation)",
2586            None => "",
2587        },
2588        if diffeomorphism_unpinned {
2589            "; ⚠ isometry pin inactive"
2590        } else {
2591            ""
2592        },
2593    );
2594    let summary = if canonicalized_charts > 0 {
2595        format!(
2596            "{summary}; {canonicalized_charts} chart(s) pinned to arc length by post-fit \
2597             canonicalization (residual chart freedom = finite isometry group)"
2598        )
2599    } else {
2600        summary
2601    };
2602    let summary = if canonicalized_torus_charts > 0 {
2603        format!(
2604            "{summary}; {canonicalized_torus_charts} torus chart(s) pinned to the \
2605             isometry-flow canonical chart by post-fit canonicalization (residual chart \
2606             freedom = Isom(T², flat))"
2607        )
2608    } else {
2609        summary
2610    };
2611    let summary = if canonicalized_patch_charts > 0 {
2612        format!(
2613            "{summary}; {canonicalized_patch_charts} free/patch chart(s) pinned to the \
2614             flat-reference isometry-flow canonical chart by post-fit canonicalization \
2615             (residual chart freedom = Isom(ℝ², flat) = O(2) ⋉ ℝ²)"
2616        )
2617    } else {
2618        summary
2619    };
2620    let summary = if canonicalized_sphere_charts > 0 {
2621        format!(
2622            "{summary}; {canonicalized_sphere_charts} sphere chart(s) pinned to the \
2623             round-sphere conformal-boost isometry-flow canonical chart by post-fit \
2624             canonicalization (residual chart freedom = Isom(S², round) = O(3))"
2625        )
2626    } else {
2627        summary
2628    };
2629
2630    Ok(ResidualGaugeReport {
2631        metric_provenance,
2632        generators: verdicts,
2633        pinning_rank,
2634        residual_gauge_dim,
2635        diffeomorphism_unpinned,
2636        sym_f_trivial_under_output_fisher,
2637        // The #972 inner-rotation gauge is declared by the caller (it lives in
2638        // the (U_k, C_k) parameterization, not in the latent-frame coordinates
2639        // this certificate's generators are tangent to); frame-factored
2640        // dictionaries attach it via `with_frame_inner_rotation`.
2641        frame_inner_rotation: None,
2642        summary,
2643    })
2644}
2645
/// The model's two certificates, shipped together (#984 work-plan step 2):
/// the residual-gauge report says what NO data could distinguish (the
/// symmetry group the fit is identified up to — a statement about the
/// model class), the structure certificate says what THIS data
/// established (the e-BH-confirmed subset of the dictionary's structural
/// claims, FDR ≤ α, valid at the caller's stopping time — a statement
/// about the world). A claim can fail both ways, and the failure modes
/// are independent: an atom can be perfectly identified yet statistically
/// unestablished, or strongly evidenced yet gauge-ambiguous with a twin.
#[derive(Debug, Clone)]
pub struct DictionaryReport {
    /// What cannot be distinguished in principle ([`residual_gauge`]).
    pub gauge: ResidualGaugeReport,
    /// What the data established
    /// ([`gam_terms::inference::structure_evidence::StructureLedger::certify`]).
    pub structure: StructureCertificate,
    /// Per-atom inter-layer transport ladders (#1096). Empty when the caller
    /// has not supplied at least one atom's canonical coordinates across two or
    /// more layers: [`dictionary_report`] always leaves this empty, and
    /// [`dictionary_report_with_transport_ladders`] fills it. These reports are
    /// computed in the transport module's chart convention: circle coordinates
    /// are radians on `[0, 2π)`, while SAE canonical circle charts may use an
    /// arbitrary period and are rescaled by
    /// [`dictionary_report_with_transport_ladders`] before fitting.
    pub transport_ladders: Vec<AtomTransportLadderReport>,
    /// Per-atom post-PIRLS inference reports (#1097 penalty-debiased functional
    /// POINT summaries, #1103 split-LRT smooth-structure e-value), one entry
    /// per atom in [`FittedSaeManifold::atoms`] order.
    ///
    /// The #1099 per-atom curvature CI was removed under #1115 (a curvature
    /// BOUND is not an estimand and its SE conditioned on generated
    /// regressors); the surviving plug-in curvature point estimate lives on
    /// [`crate::manifold::CertificateInputs::per_atom_kappa_hat`], not here.
    ///
    /// Each report's fields are computed when the atom carries its fit-time
    /// [`AtomInnerFit`] byproducts and the relevant numerics succeed; otherwise
    /// the field is `None`. A bare certificate-only `FittedSaeManifold` — one
    /// built by the residual-gauge path with no fit harness — leaves every
    /// `inner_fit` `None`, so both `functionals` and `smooth_significance` are
    /// `None` on every entry.
    pub atom_inference: Vec<AtomInferenceReport>,
}
2684
/// Canonical per-layer coordinates for one atom, ready for the #1096 transport
/// ladder integration.
///
/// The caller owns extraction from the SAE fit: `layers[i]`, `coords[i]`, and
/// `topologies[i]` describe the same atom at the same layer. This type keeps
/// that extraction outside [`dictionary_report`] so the core certificate can be
/// wired without reaching into `SaeManifoldTerm`.
///
/// [`atom_transport_ladder_reports`] rejects an input whose three vectors do
/// not all have the same length, or whose length is below two.
#[derive(Debug, Clone)]
pub struct AtomTransportLadderInput {
    /// Index into [`FittedSaeManifold::atoms`]; out-of-range indices are
    /// rejected by [`atom_transport_ladder_reports`].
    pub atom_index: usize,
    /// Layer labels in ladder order.
    pub layers: Vec<usize>,
    /// One canonical coordinate vector per layer, all over the same rows.
    /// Every entry must be finite.
    pub coords: Vec<Array1<f64>>,
    /// One canonical chart topology per layer. A circle chart's period must be
    /// finite and strictly positive.
    pub topologies: Vec<CanonicalChartTopology>,
}
2703
/// One atom's fitted inter-layer transport ladder.
///
/// Produced by [`atom_transport_ladder_reports`], one per supplied
/// [`AtomTransportLadderInput`], in input order.
#[derive(Debug, Clone)]
pub struct AtomTransportLadderReport {
    /// Index of the atom in [`FittedSaeManifold::atoms`], copied from the input.
    pub atom_index: usize,
    /// Name of the fitted atom at `atom_index`.
    pub atom_name: String,
    /// The ladder returned by [`transport_ladder`] for this atom's layers.
    pub report: TransportLadderReport,
}
2711
2712/// #1097 penalty-debiased smooth-functional POINT summaries for one atom's
2713/// captured inner-decoder smooth (narrowed under #1115).
2714///
2715/// All three functionals are *linear* in the atom's fitted coefficient vector
2716/// `β_{k,j}`, so each is one-step penalty-debiased through the SAME penalized
2717/// Hessian the identifiability certificate's curvature sees
2718/// ([`AtomInnerFit::penalized_hessian`]) by routing the functional gradient,
2719/// the per-row scores, and the penalty gradient `S̃_k β` through
2720/// [`debias_with_dense_hessian`]. Only the resulting POINT estimates (plug-in,
2721/// penalty-debiased, removed bias) are kept; the influence-function SE is
2722/// discarded because it conditions on the generated latent coordinates `t̂` /
2723/// assignment `â` as if known and so under-covers (see
2724/// [`AtomFunctionalReport`] for the full argument). A non-SPD Hessian or a
2725/// degenerate functional (empty design, non-finite gradient) leaves the
2726/// offending field `None`; the other two still report.
2727fn atom_functional_report(fit: &AtomInnerFit) -> AtomFunctionalReport {
2728    let penalty_beta = fit.penalty.dot(&fit.beta);
2729
2730    // A small closed-form helper: build the Riesz input for a functional
2731    // gradient and penalty-debias it through the fitted penalized Hessian, then
2732    // KEEP ONLY the point estimates (the SE is not honest here — #1115). The
2733    // Riesz layer's own `EstimationError` is collapsed into `None` — a numerical
2734    // refusal is a missing field, not a poisoned report.
2735    let debias = |functional_gradient: Array1<f64>| -> Option<AtomFunctionalEstimate> {
2736        let input = RieszInput {
2737            beta: fit.beta.view(),
2738            functional_gradient: functional_gradient.view(),
2739            row_scores: fit.row_scores.view(),
2740            penalty_beta: penalty_beta.view(),
2741            leverage: None,
2742        };
2743        debias_with_dense_hessian(&input, fit.penalized_hessian.view())
2744            .ok()
2745            .map(|r| AtomFunctionalEstimate {
2746                theta_plugin: r.theta_plugin,
2747                theta_onestep: r.theta_onestep,
2748                penalty_bias: r.penalty_bias,
2749            })
2750    };
2751
2752    // Peak-vs-mode contrast g(t_peak) − g(t_mode): the linear functional whose
2753    // gradient is the difference of the two design rows.
2754    let peak_contrast = SmoothFunctional::Contrast {
2755        design_row_a: fit.peak_design_row.view(),
2756        design_row_b: fit.mode_design_row.view(),
2757    }
2758    .gradient()
2759    .ok()
2760    .and_then(debias);
2761
2762    // E_data[g(t_i)]: the mass-weighted average decoder value over active rows.
2763    let average_value = SmoothFunctional::AverageValue {
2764        value_design: fit.design.view(),
2765        weights: Some(fit.weights.view()),
2766    }
2767    .gradient()
2768    .ok()
2769    .and_then(debias);
2770
2771    // ‖E_data[∂g/∂t]‖ along the leading latent axis: the mass-weighted average
2772    // of the derivative-design rows (the Gauss–Newton weights `w_i = a_ik²` are
2773    // the data measure over the atom's active rows). This is the conditional-
2774    // on-fit decoder-VARIATION norm, not a population marginal slope.
2775    let decoder_variation_norm = SmoothFunctional::AverageDerivative {
2776        derivative_design: fit.derivative_design.view(),
2777        weights: Some(fit.weights.view()),
2778    }
2779    .gradient()
2780    .ok()
2781    .and_then(debias);
2782
2783    AtomFunctionalReport {
2784        peak_contrast,
2785        average_value,
2786        decoder_variation_norm,
2787    }
2788}
2789
2790/// #1103 Any-n-valid structure evidence that one atom's inner smooth is
2791/// non-constant, via the split-likelihood-ratio e-value.
2792///
2793/// The inner decoder smooth is the Gaussian-identity penalized WLS fit
2794/// `a_ik · Φ_k(t)ᵀ β_{k,j}` with dispersion `φ = `[`AtomInnerFit::dispersion`],
2795/// working response `z_i` reconstructed from the captured per-row scores. H0 is
2796/// "the smooth is constant": only the intercept column 0 is free.
2797///
2798/// We compute the universal-inference e-value the atom-birth gate
2799/// ([`gam_terms::inference::structure_evidence::split_likelihood_log_e_value`]) uses:
2800///
2801/// * Split the active rows deterministically into an ESTIMATION fold (even
2802///   index) and an EVALUATION fold (odd index).
2803/// * On the estimation fold, fit the penalized smooth (the alternative) by
2804///   `β̂ = (ΦᵀWΦ + S)⁻¹ ΦᵀW z` — any fitter is admissible; zero conditions.
2805/// * On the evaluation fold, score the Gaussian log-likelihood under that
2806///   prefit alternative, and the SUPREMUM of the evaluation-fold log-likelihood
2807///   over the null class (the constant fit = weighted-mean response refit on the
2808///   eval fold — the honest constrained sup on D₀).
2809/// * `log E = ℓ_alt(D₀) − sup_{H0} ℓ(D₀)`, with `E_{H0}[E] ≤ 1` exactly.
2810///
2811/// The dispersion `φ` is held fixed at the fitted reconstruction dispersion in
2812/// both log-likelihoods so it cancels structurally and the e-value isolates the
2813/// mean-curvature evidence. Returns `None` when the design has no curvature
2814/// column (`M_k ≤ 1`), either fold is empty, or the inner Gram is not SPD.
2815fn atom_smooth_significance(fit: &AtomInnerFit) -> Option<AtomSmoothSignificance> {
2816    let m = fit.design.ncols();
2817    if m <= 1 || fit.beta.len() != m {
2818        // No curvature column: the constant null IS the full model — there is no
2819        // non-constant alternative to earn an e-value.
2820        return None;
2821    }
2822    let n = fit.design.nrows();
2823    if n == 0 || fit.weights.len() != n || fit.row_scores.nrows() != n {
2824        return None;
2825    }
2826    let phi = if fit.dispersion.is_finite() && fit.dispersion > 0.0 {
2827        fit.dispersion
2828    } else {
2829        return None;
2830    };
2831
2832    // Per-row working response z_i = μ̂_i + r_i, reconstructing the scalar
2833    // residual r_i from the captured score projected onto the design row
2834    // (s_iᵀ Φ_i = −w_i r_i ‖Φ_i‖² / φ ⇒ r_i). Same reconstruction the previous
2835    // deviance path used; here it feeds the two folds' likelihoods.
2836    let mut z = Array1::<f64>::zeros(n);
2837    for i in 0..n {
2838        let mu_hat = fit.design.row(i).dot(&fit.beta);
2839        let w_i = fit.weights[i];
2840        let phi_row = fit.design.row(i);
2841        let phi_norm_sq = phi_row.dot(&phi_row);
2842        let r_i = if w_i > 0.0 && phi_norm_sq > 0.0 {
2843            let s_dot_phi = fit.row_scores.row(i).dot(&phi_row);
2844            -phi * s_dot_phi / (w_i * phi_norm_sq)
2845        } else {
2846            0.0
2847        };
2848        z[i] = mu_hat + r_i;
2849    }
2850
2851    // Deterministic estimation/evaluation split by row parity.
2852    let est: Vec<usize> = (0..n).filter(|i| i % 2 == 0).collect();
2853    let eval: Vec<usize> = (0..n).filter(|i| i % 2 == 1).collect();
2854    if est.is_empty() || eval.is_empty() {
2855        return None;
2856    }
2857
2858    // Penalized smooth fit on the estimation fold: β̂ = (ΦᵀWΦ + S)⁻¹ ΦᵀW z.
2859    let mut a_gram = fit.penalty.clone();
2860    let mut b = Array1::<f64>::zeros(m);
2861    for &i in &est {
2862        let w_i = fit.weights[i];
2863        if !(w_i > 0.0) {
2864            continue;
2865        }
2866        let row = fit.design.row(i);
2867        for r in 0..m {
2868            let xr = row[r];
2869            if xr == 0.0 {
2870                continue;
2871            }
2872            b[r] += w_i * xr * z[i];
2873            for c in 0..m {
2874                a_gram[[r, c]] += w_i * xr * row[c];
2875            }
2876        }
2877    }
2878    let beta_alt = a_gram.cholesky(Side::Lower).ok()?.solvevec(&b);
2879
2880    // Null sup on the EVALUATION fold: the weighted-mean response (the constant
2881    // fit's MLE on D₀, the honest constrained sup over the null class).
2882    let mut eval_mass = 0.0_f64;
2883    let mut eval_wz = 0.0_f64;
2884    for &i in &eval {
2885        let w_i = fit.weights[i];
2886        eval_mass += w_i;
2887        eval_wz += w_i * z[i];
2888    }
2889    if !(eval_mass > 0.0) {
2890        return None;
2891    }
2892    let null_mean = eval_wz / eval_mass;
2893
2894    // Gaussian log-likelihoods on the evaluation fold at fixed dispersion φ;
2895    // the −½ log(2πφ) and weight-log terms are identical under both models, so
2896    // log E = −(½/φ) [ Σ w(z − μ_alt)² − Σ w(z − μ_null)² ].
2897    let mut sse_alt = 0.0_f64;
2898    let mut sse_null = 0.0_f64;
2899    for &i in &eval {
2900        let w_i = fit.weights[i];
2901        let mu_alt = fit.design.row(i).dot(&beta_alt);
2902        let r_alt = z[i] - mu_alt;
2903        let r_null = z[i] - null_mean;
2904        sse_alt += w_i * r_alt * r_alt;
2905        sse_null += w_i * r_null * r_null;
2906    }
2907    let log_lik_alt = -0.5 * sse_alt / phi;
2908    let log_lik_null_sup = -0.5 * sse_null / phi;
2909    let log_e = gam_terms::inference::structure_evidence::split_likelihood_log_e_value(
2910        log_lik_alt,
2911        log_lik_null_sup,
2912    );
2913    if !log_e.is_finite() {
2914        return None;
2915    }
2916
2917    Some(AtomSmoothSignificance {
2918        log_e_nonconstant: Some(log_e),
2919    })
2920}
2921
2922/// Assemble the post-PIRLS inference reports for every atom, reusing the
2923/// per-atom [`AtomInnerFit`] harvested at fit time.
2924///
2925/// * #1097 penalty-debiased functional POINT summaries and the #1103 split-LRT
2926///   smooth-structure e-value are computed from the captured inner-decoder
2927///   smooth (design, penalized Hessian, row scores, roughness Gram) — they need
2928///   only the fixed fitted snapshot.
2929/// * The #1099 per-atom curvature *confidence interval* was removed under #1115:
2930///   a sup-norm curvature BOUND is not an estimand with a profiled criterion,
2931///   and its delta-method SE conditioned on generated latent coordinates as if
2932///   known. The plug-in curvature point estimate survives on
2933///   [`crate::manifold::CertificateInputs::per_atom_kappa_hat`] (the
2934///   #1008 empirical curved-dictionary report), not on this report.
2935pub(crate) fn atom_inference_reports(model: &FittedSaeManifold) -> Vec<AtomInferenceReport> {
2936    model
2937        .atoms
2938        .iter()
2939        .enumerate()
2940        .map(|(atom_index, atom)| {
2941            let (functionals, smooth_significance) = match &atom.inner_fit {
2942                Some(fit) => (
2943                    Some(atom_functional_report(fit)),
2944                    atom_smooth_significance(fit),
2945                ),
2946                None => (None, None),
2947            };
2948            AtomInferenceReport {
2949                atom_index,
2950                atom_name: atom.name.clone(),
2951                functionals,
2952                smooth_significance,
2953            }
2954        })
2955        .collect()
2956}
2957
2958/// Produce the paired certificate for a fitted model: the residual-gauge
2959/// report computed here plus the anytime-valid structure certificate from
2960/// the discovery run's evidence ledger at level `alpha`. The ledger is the
2961/// one the structure search absorbed its shard evidence into
2962/// (`structure_evidence::StructureLedger`); certifying at any
2963/// data-dependent stopping time is sound — that is the ledger's whole
2964/// design.
2965pub fn dictionary_report(
2966    model: &FittedSaeManifold,
2967    ledger: &StructureLedger,
2968    alpha: f64,
2969) -> Result<DictionaryReport, String> {
2970    Ok(DictionaryReport {
2971        gauge: residual_gauge(model)?,
2972        structure: ledger.certify(alpha),
2973        transport_ladders: Vec::new(),
2974        atom_inference: atom_inference_reports(model),
2975    })
2976}
2977
2978// --- #1100: closed-loop probe runner FFI ---------------------------------
2979// Top-level entry points exposing the steering→structure-evidence probe loop
2980// (`crate::inference::probe_runner::ProbeRunner`) beside `dictionary_report`, so
2981// the Python driver can design and absorb interventional probes against the same
2982// fitted term and evidence ledger the certificate is built from.
2983
2984/// Design the next interventional probe for the most contested steerable claim
2985/// in `ledger`, against the fitted SAE-manifold `term` read through its per-row
2986/// output-Fisher `metric`.
2987///
2988/// Thin top-level wrapper over [`crate::inference::probe_runner::ProbeRunner::design_next`]:
2989/// it selects the contested claim furthest from certification, realizes candidate
2990/// latent moves of its atom through `crate::inference::steering::steer_delta`,
2991/// and routes their doses through
2992/// `gam_terms::inference::structure_evidence::plan_probe_for_contested_claim` to pick
2993/// the most discriminating one. The returned
2994/// [`crate::inference::probe_runner::RealizedProbe`] carries both the experiment
2995/// plan and the chosen intervention's on-manifold activation delta with its
2996/// dosimetry and validity radius.
2997pub fn design_probe(
2998    term: &SaeManifoldTerm,
2999    metric: &RowMetric,
3000    ledger: &StructureLedger,
3001) -> Result<RealizedProbe, String> {
3002    ProbeRunner { term, metric }.design_next(ledger)
3003}
3004
3005/// Absorb a realized probe outcome into `ledger`, banking the delivered
3006/// behavioral dose (`realized_nats`, the observed output-Fisher KL of the steered
3007/// response) as anytime-valid evidence for the probe's claim.
3008///
3009/// Thin top-level wrapper over [`crate::inference::probe_runner::ProbeRunner::absorb`].
3010pub fn absorb_probe(
3011    term: &SaeManifoldTerm,
3012    metric: &RowMetric,
3013    ledger: &mut StructureLedger,
3014    probe: &RealizedProbe,
3015    realized_nats: f64,
3016) {
3017    ProbeRunner { term, metric }.absorb(ledger, probe, realized_nats);
3018}
3019
3020/// Produce the paired certificate plus #1096 per-atom layer-transport ladders.
3021///
3022/// This is the strict wiring seam for callers that already have canonical
3023/// per-layer atom coordinates. It validates atom indices, topology/coordinate
3024/// lengths, finite coordinates, and the circle-period convention before calling
3025/// [`transport_ladder`]. Single-layer inputs are refused: no transport estimand
3026/// exists without at least one adjacent layer pair.
3027pub fn dictionary_report_with_transport_ladders(
3028    model: &FittedSaeManifold,
3029    ledger: &StructureLedger,
3030    alpha: f64,
3031    ladders: &[AtomTransportLadderInput],
3032) -> Result<DictionaryReport, String> {
3033    let mut report = dictionary_report(model, ledger, alpha)?;
3034    report.transport_ladders = atom_transport_ladder_reports(model, ladders)?;
3035    Ok(report)
3036}
3037
3038/// Fit #1096 transport ladders for the supplied atom/layer coordinate blocks.
3039pub fn atom_transport_ladder_reports(
3040    model: &FittedSaeManifold,
3041    ladders: &[AtomTransportLadderInput],
3042) -> Result<Vec<AtomTransportLadderReport>, String> {
3043    let mut out = Vec::with_capacity(ladders.len());
3044    for input in ladders {
3045        let atom = model.atoms.get(input.atom_index).ok_or_else(|| {
3046            format!(
3047                "atom transport ladder index {} out of range for {} fitted atoms",
3048                input.atom_index,
3049                model.atoms.len()
3050            )
3051        })?;
3052        let depth = input.layers.len();
3053        if depth < 2 {
3054            return Err(format!(
3055                "atom transport ladder for atom {} ('{}') needs at least two layers, got {depth}",
3056                input.atom_index, atom.name
3057            ));
3058        }
3059        if input.coords.len() != depth || input.topologies.len() != depth {
3060            return Err(format!(
3061                "atom transport ladder for atom {} ('{}') has {} layers, {} coordinate blocks, {} topologies",
3062                input.atom_index,
3063                atom.name,
3064                depth,
3065                input.coords.len(),
3066                input.topologies.len()
3067            ));
3068        }
3069
3070        let mut coords = Vec::with_capacity(depth);
3071        let mut topologies = Vec::with_capacity(depth);
3072        for (layer_pos, (coord, topology)) in
3073            input.coords.iter().zip(input.topologies.iter()).enumerate()
3074        {
3075            coords.push(canonical_coords_for_transport(
3076                coord,
3077                topology,
3078                input.atom_index,
3079                &atom.name,
3080                input.layers[layer_pos],
3081            )?);
3082            topologies.push(ChartTopology::from(topology));
3083        }
3084
3085        let report = transport_ladder(&input.layers, &coords, &topologies).map_err(|e| {
3086            format!(
3087                "atom transport ladder for atom {} ('{}') failed: {e}",
3088                input.atom_index, atom.name
3089            )
3090        })?;
3091        out.push(AtomTransportLadderReport {
3092            atom_index: input.atom_index,
3093            atom_name: atom.name.clone(),
3094            report,
3095        });
3096    }
3097    Ok(out)
3098}
3099
3100fn canonical_coords_for_transport(
3101    coords: &Array1<f64>,
3102    topology: &CanonicalChartTopology,
3103    atom_index: usize,
3104    atom_name: &str,
3105    layer: usize,
3106) -> Result<Array1<f64>, String> {
3107    if coords.iter().any(|v| !v.is_finite()) {
3108        return Err(format!(
3109            "atom transport ladder for atom {atom_index} ('{atom_name}') layer {layer} has non-finite coordinates"
3110        ));
3111    }
3112    match topology {
3113        CanonicalChartTopology::Circle { period } => {
3114            if !(period.is_finite() && *period > 0.0) {
3115                return Err(format!(
3116                    "atom transport ladder for atom {atom_index} ('{atom_name}') layer {layer} has invalid circle period {period}"
3117                ));
3118            }
3119            Ok(coords.mapv(|t| (t / *period) * TAU))
3120        }
3121        CanonicalChartTopology::Interval => Ok(coords.clone()),
3122    }
3123}
3124
3125// ----------------------------------------------------------------------------
3126// #1102 cross-checkpoint atom-dynamics FFI entry (new top-level block).
3127// ----------------------------------------------------------------------------
3128
3129/// Run #1102 cross-checkpoint Riesz-debiased atom-trajectory dynamics for the
3130/// fitted dictionary's atoms.
3131///
3132/// `decoder_grid` is `[n_checkpoints, n_atoms, n_grid, ambient_dim]` and
3133/// `atom_names`/`checkpoint_ids`/`latent_grid` label its axes; see
3134/// [`crate::inference::checkpoint_dynamics`] for the estimator and the honest
3135/// accounting of which Riesz inputs the bare grid supports. This entry binds
3136/// the atom axis to the fitted model: `atom_names` must name exactly the
3137/// model's atoms in order, so trajectories are reported against real atoms.
3138pub fn atom_checkpoint_dynamics(
3139    model: &FittedSaeManifold,
3140    decoder_grid: ndarray::ArrayView4<'_, f64>,
3141    checkpoint_ids: &[String],
3142    atom_names: &[String],
3143    latent_grid: ArrayView1<'_, f64>,
3144) -> Result<Vec<crate::inference::checkpoint_dynamics::AtomTrajectory>, String> {
3145    if atom_names.len() != model.atoms.len() {
3146        return Err(format!(
3147            "atom_checkpoint_dynamics: {} atom names supplied for {} fitted atoms",
3148            atom_names.len(),
3149            model.atoms.len()
3150        ));
3151    }
3152    for (idx, (supplied, fitted)) in atom_names.iter().zip(model.atoms.iter()).enumerate() {
3153        if supplied != &fitted.name {
3154            return Err(format!(
3155                "atom_checkpoint_dynamics: atom {idx} name '{supplied}' does not match fitted atom '{}'",
3156                fitted.name
3157            ));
3158        }
3159    }
3160    crate::inference::checkpoint_dynamics::checkpoint_atom_dynamics(
3161        &crate::inference::checkpoint_dynamics::CheckpointDynamicsInput {
3162            decoder_grid,
3163            checkpoint_ids,
3164            atom_names,
3165            latent_grid,
3166        },
3167    )
3168}
3169
3170#[cfg(test)]
3171mod tests {
3172    use super::*;
3173    use ndarray::{Array1, array};
3174
3175    /// #1097: the per-atom penalty-debiased functional point summaries must
3176    /// reproduce the exact linear functionals of the fitted decoder smooth
3177    /// (plug-in) and a finite debiased value, on a synthetic atom whose inner
3178    /// smooth is an analytic polynomial. No SE/CI is asserted — none is reported
3179    /// (#1115).
3180    #[test]
3181    fn atom_functional_report_recovers_known_functionals() {
3182        use ndarray::{Array1 as A1, Array2 as A2};
3183        // Polynomial basis Φ(t) = [1, t, t²] on a uniform active grid; the atom's
3184        // fitted smooth is g(t) = β·Φ(t) with a known β. We assemble a genuine
3185        // penalized-WLS AtomInnerFit (unit weights, identity-ish penalty) so the
3186        // Riesz path runs end to end.
3187        let n = 40usize;
3188        let m = 3usize;
3189        let beta = A1::from(vec![0.5_f64, -1.0, 2.0]);
3190        let mut design = A2::<f64>::zeros((n, m));
3191        let mut derivative_design = A2::<f64>::zeros((n, m));
3192        let mut weights = A1::<f64>::ones(n);
3193        let mut t = vec![0.0_f64; n];
3194        for i in 0..n {
3195            let ti = i as f64 / (n - 1) as f64;
3196            t[i] = ti;
3197            design[[i, 0]] = 1.0;
3198            design[[i, 1]] = ti;
3199            design[[i, 2]] = ti * ti;
3200            // dΦ/dt = [0, 1, 2t].
3201            derivative_design[[i, 0]] = 0.0;
3202            derivative_design[[i, 1]] = 1.0;
3203            derivative_design[[i, 2]] = 2.0 * ti;
3204            weights[i] = 1.0;
3205        }
3206        let dispersion = 1.0_f64;
3207        // Working response equals the fitted curve so residuals are zero → the
3208        // plug-in is exactly the analytic functional of β; scores are zero.
3209        let row_scores = A2::<f64>::zeros((n, m));
3210        // Penalty S = small ridge on curvature column only; penalized Hessian
3211        // H = ΦᵀWΦ + S.
3212        let mut penalty = A2::<f64>::zeros((m, m));
3213        penalty[[2, 2]] = 1e-3;
3214        let mut xtwx = A2::<f64>::zeros((m, m));
3215        for i in 0..n {
3216            for a in 0..m {
3217                for b in 0..m {
3218                    xtwx[[a, b]] += weights[i] * design[[i, a]] * design[[i, b]];
3219                }
3220            }
3221        }
3222        let penalized_hessian = &xtwx + &penalty;
3223        // Peak: |g| largest; mode: pick endpoints to give a known contrast.
3224        let mut peak_slot = 0usize;
3225        let mut peak_val = -1.0;
3226        for i in 0..n {
3227            let g = design.row(i).dot(&beta).abs();
3228            if g > peak_val {
3229                peak_val = g;
3230                peak_slot = i;
3231            }
3232        }
3233        let peak_design_row = design.row(peak_slot).to_owned();
3234        let mode_design_row = design.row(0).to_owned();
3235
3236        let fit = AtomInnerFit {
3237            design: design.clone(),
3238            derivative_design: derivative_design.clone(),
3239            beta: beta.clone(),
3240            penalty,
3241            penalized_hessian,
3242            row_scores,
3243            weights: weights.clone(),
3244            dispersion,
3245            peak_design_row: peak_design_row.clone(),
3246            mode_design_row: mode_design_row.clone(),
3247        };
3248
3249        let report = atom_functional_report(&fit);
3250
3251        // Average value E_w[g] = mean_i β·Φ(t_i): exact plug-in match.
3252        let av = report.average_value.expect("average value");
3253        let expected_av: f64 = (0..n).map(|i| design.row(i).dot(&beta)).sum::<f64>() / n as f64;
3254        assert!(
3255            (av.theta_plugin - expected_av).abs() < 1e-9,
3256            "average value plug-in {} vs expected {}",
3257            av.theta_plugin,
3258            expected_av
3259        );
3260        // Point summary only: the debiased value is finite (no SE/CI is
3261        // reported by design — #1115).
3262        assert!(
3263            av.theta_onestep.is_finite(),
3264            "average-value debiased finite"
3265        );
3266
3267        // Decoder-variation norm (conditional on fit): g'(t) = β1 + 2β2 t, mean
3268        // over the grid is β1 + 2β2 * mean(t). The functional gradient is the
3269        // mean derivative row; its plug-in is exactly that scalar. This is the
3270        // descriptive variation of the fitted curve, not a population marginal
3271        // slope.
3272        let ad = report
3273            .decoder_variation_norm
3274            .expect("decoder variation norm");
3275        let mean_t: f64 = t.iter().sum::<f64>() / n as f64;
3276        let expected_ad = beta[1] + 2.0 * beta[2] * mean_t;
3277        assert!(
3278            (ad.theta_plugin - expected_ad).abs() < 1e-9,
3279            "decoder variation plug-in {} vs expected {}",
3280            ad.theta_plugin,
3281            expected_ad
3282        );
3283
3284        // Peak-vs-mode contrast g(t_peak) − g(t_mode): exact plug-in.
3285        let pc = report.peak_contrast.expect("peak contrast");
3286        let expected_pc = peak_design_row.dot(&beta) - mode_design_row.dot(&beta);
3287        assert!(
3288            (pc.theta_plugin - expected_pc).abs() < 1e-9,
3289            "peak contrast plug-in {} vs expected {}",
3290            pc.theta_plugin,
3291            expected_pc
3292        );
3293    }
3294
3295    #[test]
3296    fn mechanism_sparsity_jacobian_value_matches_closed_form() {
3297        let w = array![[3.0_f64, 0.0], [4.0, 0.0]]; // col0 norm=5, col1 norm=0
3298        let pen = MechanismSparsityJacobian::new(1.0, 1.0e-8).unwrap();
3299        let (v, _g) = pen.value_and_grad(w.view());
3300        assert!((v - 5.0).abs() < 1e-6, "value {v} expected ≈5");
3301    }
3302
/// The analytic gradient returned by `value_and_grad` must agree, entry by
/// entry, with a central finite difference of the penalty value.
#[test]
fn mechanism_sparsity_jacobian_grad_matches_finite_diff() {
    let w = array![[0.5_f64, -1.2, 0.3], [1.1, 0.4, -0.7]];
    let pen = MechanismSparsityJacobian::new(2.5, 1.0e-6).unwrap();
    let (_, g) = pen.value_and_grad(w.view());
    let h = 1.0e-5;
    // Visit every entry of `w` in row-major order.
    for ((i, j), _) in w.indexed_iter() {
        // Penalty value with entry (i, j) shifted by `delta`.
        let value_at = |delta: f64| {
            let mut shifted = w.clone();
            shifted[[i, j]] += delta;
            pen.value_and_grad(shifted.view()).0
        };
        let forward = value_at(h);
        let backward = value_at(-h);
        let fd = (forward - backward) / (2.0 * h);
        let analytic = g[[i, j]];
        assert!(
            (analytic - fd).abs() < 1e-4,
            "grad[{i},{j}] = {} vs fd {}",
            analytic,
            fd
        );
    }
}
3327
/// The constructor must refuse a negative first argument and a zero second
/// argument.
#[test]
fn mechanism_sparsity_jacobian_rejects_bad_input() {
    let negative_first = MechanismSparsityJacobian::new(-1.0, 1e-6);
    assert!(negative_first.is_err());

    let zero_second = MechanismSparsityJacobian::new(1.0, 0.0);
    assert!(zero_second.is_err());
}
3333
/// Each factored atom of rank `r` contributes `dim O(r) = r(r−1)/2` to the
/// inner-rotation gauge dimension, so rank-1 frames contribute nothing and
/// the total is the sum over atoms.
#[test]
fn frame_inner_rotation_dim_is_sum_of_so_r_dims() {
    // Empty dictionary and single-atom dictionaries.
    assert_eq!(frame_inner_rotation_dim(&[]), 0);
    assert_eq!(frame_inner_rotation_dim(&[1]), 0);
    assert_eq!(frame_inner_rotation_dim(&[2]), 1);
    assert_eq!(frame_inner_rotation_dim(&[4]), 6);

    // Mixed ranks: 0 (r = 1) + 6 (r = 4) + 28 (r = 8).
    assert_eq!(frame_inner_rotation_dim(&[1, 4, 8]), 34);

    // The gauge constructor must report the same summed dimension.
    let paired = FrameInnerRotationGauge::from_ranks(vec![3, 3]);
    assert_eq!(
        paired.dim, 6,
        "two rank-3 frames carry 2·3 inner-rotation dims"
    );
}
3348
/// The #972 inner-rotation gauge is an *enumerated* part of the certificate,
/// not a curvature-tested one. Attaching it has to leave every generator
/// verdict and `residual_gauge_dim` alone, yet it MUST show up in the group
/// signature and in the summary: two replicate frame-factored fits agree on
/// their gauge iff they also agree on this enumerated, convention-fixed part.
#[test]
fn frame_inner_rotation_attaches_to_the_certificate_without_verdict_change() {
    // Minimal report with no generators and no inner-rotation gauge attached;
    // only the pinning rank differs between the two scenarios below.
    let bare_report = |pinning_rank| ResidualGaugeReport {
        metric_provenance: MetricProvenance::Euclidean,
        generators: Vec::new(),
        pinning_rank,
        residual_gauge_dim: 0,
        diffeomorphism_unpinned: false,
        sym_f_trivial_under_output_fisher: None,
        frame_inner_rotation: None,
        summary: "base".to_string(),
    };

    // --- Non-trivial inner gauge: ranks (1, 4, 8) → dim 0 + 6 + 28 = 34. ---
    let base = bare_report(5);
    let sig_before = base.group_signature();
    let report = base.with_frame_inner_rotation(vec![1, 4, 8]);
    let expected_gauge = FrameInnerRotationGauge {
        per_atom_ranks: vec![1, 4, 8],
        dim: 34,
    };
    assert_eq!(report.frame_inner_rotation, Some(expected_gauge));

    // The verdict side of the report is not touched.
    assert_eq!(report.residual_gauge_dim, 0);
    assert!(report.generators.is_empty());

    // The enumeration is visible in both the signature and the summary.
    let sig_after = report.group_signature();
    assert_ne!(sig_before, sig_after);
    assert!(sig_after.contains("frame-inner"), "got: {sig_after}");
    assert!(sig_after.contains("dim 34"), "got: {sig_after}");
    assert!(sig_after.contains("canonical-fixed"), "got: {sig_after}");
    assert!(report.summary.contains("inner-rotation gauge"));

    // --- Trivial inner gauge: a dictionary of rank-1 atoms. ---
    // The gauge is still enumerated (`Some`) but zero-dimensional, so the
    // signature and summary stay as they were: there is nothing to fix
    // beyond the orientation sign convention.
    let trivial = bare_report(0);
    let sig_trivial_before = trivial.group_signature();
    let trivial = trivial.with_frame_inner_rotation(vec![1, 1, 1]);
    let trivial_dim = trivial.frame_inner_rotation.as_ref().map(|g| g.dim);
    assert_eq!(trivial_dim, Some(0));
    assert_eq!(trivial.group_signature(), sig_trivial_before);
    assert_eq!(trivial.summary, "base");
}
3408
3409    /// Build a `(n, d)` `(mean, scale)` pair whose stacked signature
3410    /// `[μ ‖ log σ]` has full rank `2d` (so it satisfies the Khemakhem
3411    /// Theorem 1 precondition baked into `ConditionalPriorIvae::new`).
3412    ///
3413    /// Each per-column function is given a distinct *frequency* (not a
3414    /// shared frequency with a column-dependent phase) so the resulting
3415    /// `2d` columns are genuinely linearly independent. `sin(ω·t + φ)`
3416    /// with a shared `ω` lives in the 2-dimensional span of `{sin(ω t),
3417    /// cos(ω t)}`, so the earlier `sin(0.7t + 0.3c)` / `cos(0.5t + 0.9c)`
3418    /// fixture only ever produced rank `≤ 4`, no matter how many `d`
3419    /// columns it built. Distinct frequencies push each column into its
3420    /// own subspace, so for `n ≥ 2d + 1` the SVD of `[μ ‖ log σ]` has
3421    /// `2d` non-trivial singular values.
3422    fn ivae_precondition_pair(n: usize, d: usize) -> (Array2<f64>, Array2<f64>) {
3423        assert!(n >= 2 * d + 1, "need at least 2d+1 rows");
3424        let mut mean = Array2::<f64>::zeros((n, d));
3425        let mut scale = Array2::<f64>::from_elem((n, d), 1.0);
3426        for r in 0..n {
3427            let t = r as f64 / (n as f64 - 1.0);
3428            for c in 0..d {
3429                let omega = (c + 1) as f64;
3430                mean[[r, c]] = (std::f64::consts::PI * omega * t).sin();
3431                scale[[r, c]] = (0.4 * (std::f64::consts::PI * omega * t).cos()).exp();
3432            }
3433        }
3434        (mean, scale)
3435    }
3436
/// At `t = μ` the conditional-prior penalty must reduce to the closed-form
/// Gaussian normaliser `½·n·d·log 2π + Σ log σ`, with a vanishing gradient.
///
/// The `(μ, log σ)` pair varies across rows so that the identifiability
/// precondition of the constructor is met.
#[test]
fn conditional_prior_ivae_zero_mean_unit_scale_matches_standard_gaussian() {
    use std::f64::consts::PI;

    let (n, d) = (7, 3);
    let (mean, scale) = ivae_precondition_pair(n, d);
    // Evaluation point coincides with the prior mean.
    let t = mean.clone();
    // Σ log σ has to be taken before `scale` is moved into the constructor.
    let log_norm: f64 = scale.iter().map(|s| s.ln()).sum();
    let pen = ConditionalPriorIvae::new(mean, scale, 1.0).unwrap();

    let (v, g) = pen.value_and_grad(t.view());

    let gaussian_const = 0.5 * (n * d) as f64 * (2.0 * PI).ln();
    let expected = log_norm + gaussian_const;
    assert!(
        (v - expected).abs() < 1e-9,
        "value {v} vs expected {expected}"
    );
    assert!(g.iter().all(|gv| gv.abs() < 1e-12));
}
3458
/// The analytic gradient of the conditional-prior penalty must agree, entry
/// by entry, with a central finite difference of `value`.
///
/// The evaluation point is the prior mean shifted by `+0.4` in column 0 and
/// `−0.3` in column 1, so it does not coincide with the mean.
#[test]
fn conditional_prior_ivae_grad_matches_finite_diff() {
    let (mean, scale) = ivae_precondition_pair(5, 2);
    let mut t = mean.clone();
    // Row bound comes from the fixture itself rather than a repeated literal,
    // so resizing the fixture cannot leave rows unshifted or index out of
    // bounds.
    for r in 0..t.nrows() {
        t[[r, 0]] += 0.4;
        t[[r, 1]] -= 0.3;
    }
    let pen = ConditionalPriorIvae::new(mean, scale, 1.7).unwrap();
    let (_, g) = pen.value_and_grad(t.view());
    let h = 1.0e-5;
    for i in 0..t.nrows() {
        for j in 0..t.ncols() {
            let mut tp = t.clone();
            let mut tm = t.clone();
            tp[[i, j]] += h;
            tm[[i, j]] -= h;
            let vp = pen.value(tp.view());
            let vm = pen.value(tm.view());
            let fd = (vp - vm) / (2.0 * h);
            // Report the offending entry on failure, in the same format as
            // the mechanism-sparsity finite-difference test.
            assert!(
                (g[[i, j]] - fd).abs() < 1e-5,
                "grad[{i},{j}] = {} vs fd {}",
                g[[i, j]],
                fd
            );
        }
    }
}
3483
/// A scale matrix containing a negative entry must be rejected by the
/// constructor.
#[test]
fn conditional_prior_ivae_rejects_nonpositive_scale() {
    let mean = Array2::<f64>::zeros((2, 2));
    // All ones except a single negative entry at (0, 0).
    let scale = array![[-0.1_f64, 1.0], [1.0, 1.0]];
    let outcome = ConditionalPriorIvae::new(mean, scale, 1.0);
    assert!(outcome.is_err());
}
3491
/// The full-rank fixture from `ivae_precondition_pair` must be accepted by
/// the constructor.
#[test]
fn conditional_prior_ivae_accepts_when_signature_full_rank() {
    let (mean, scale) = ivae_precondition_pair(7, 3);
    let built = ConditionalPriorIvae::new(mean, scale, 1.0);
    let accepted = built.is_ok();
    assert!(
        accepted,
        "full-rank signature should satisfy Khemakhem Theorem 1, got {:?}",
        built.err(),
    );
}
3502
/// A prior whose rows are all identical is an unconditional `N(μ, σ²)` and
/// therefore non-identifiable; the constructor must say so.
#[test]
fn conditional_prior_ivae_rejects_trivial_constant_prior() {
    let (n, d) = (9, 3);
    // Every row carries the same mean and the same scale.
    let mean = Array2::<f64>::from_elem((n, d), 0.25);
    let scale = Array2::<f64>::from_elem((n, d), 1.5);

    let err = ConditionalPriorIvae::new(mean, scale, 1.0).unwrap_err();

    let names_the_cause = err.contains("trivial unconditional");
    let cites_the_theorem = err.contains("Khemakhem");
    assert!(
        names_the_cause && cites_the_theorem,
        "unexpected error: {err}"
    );
}
3516
/// With latent dimension 3 at least `2·3 + 1 = 7` rows are required; a
/// 4-row prior must be rejected with an error naming the `2k+1` bound.
#[test]
fn conditional_prior_ivae_rejects_too_few_auxiliary_states() {
    let (full_mean, full_scale) = ivae_precondition_pair(7, 3);
    // Keep only the first four rows of the otherwise valid fixture.
    let first_four_rows = |full: &Array2<f64>| full.slice(s![..4, ..]).to_owned();
    let mean = first_four_rows(&full_mean);
    let scale = first_four_rows(&full_scale);

    let err = ConditionalPriorIvae::new(mean, scale, 1.0).unwrap_err();

    let names_the_bound = err.contains("2k+1");
    let cites_the_theorem = err.contains("Khemakhem");
    assert!(
        names_the_bound && cites_the_theorem,
        "unexpected error: {err}"
    );
}
3529
/// A prior with enough, non-identical rows can still be rank-deficient.
///
/// Here `n = 9 ≥ 2·3 + 1 = 7`, but `[μ ‖ log σ]` lives in a strict subspace
/// of ℝ^{2d}: column 0 of `μ` equals column 0 of `log σ`, and columns 1, 2
/// are zero in `μ` and one in `σ` (so zero in `log σ`). The signature has
/// rank 1, far below the required `2·3 = 6`.
#[test]
fn conditional_prior_ivae_rejects_rank_deficient_signature() {
    let (n, d) = (9, 3);
    // The single varying signal shared by μ[:, 0] and log σ[:, 0].
    let wave = |r: usize| ((r as f64) * 0.5).sin();

    let mean = Array2::<f64>::from_shape_fn((n, d), |(r, c)| {
        if c == 0 {
            wave(r)
        } else {
            0.0
        }
    });
    let scale = Array2::<f64>::from_shape_fn((n, d), |(r, c)| {
        if c == 0 {
            wave(r).exp() // log σ column 0 = wave = μ column 0
        } else {
            1.0
        }
    });

    let err = ConditionalPriorIvae::new(mean, scale, 1.0).unwrap_err();

    let names_the_rank = err.contains("numerical rank");
    let cites_the_theorem = err.contains("Khemakhem");
    assert!(
        names_the_rank && cites_the_theorem,
        "unexpected error: {err}"
    );
}
3552
/// Evaluating a three-row coefficient table on `[0, 1]` at `u = 0, 0.5, 1`
/// must return the first, middle and last rows respectively.
#[test]
fn piecewise_linear_eval_endpoints_and_midpoint() {
    let coeffs = array![[0.0_f64, 10.0], [1.0, 20.0], [2.0, 30.0]];
    let u = array![0.0_f64, 0.5, 1.0];
    let out = piecewise_linear_eval(u.view(), coeffs.view(), 0.0, 1.0);

    let close = |got: f64, want: f64| (got - want).abs() < 1e-12;
    // Column 0 at both endpoints and at the midpoint.
    assert!(close(out[[0, 0]], 0.0));
    assert!(close(out[[1, 0]], 1.0));
    assert!(close(out[[2, 0]], 2.0));
    // Column 1 at the midpoint.
    assert!(close(out[[1, 1]], 20.0));
}
3563
/// With an all-zero penalty grid the selector is expected to pick cell
/// `(1, 1)` — the smallest entry of the RSS grid — and to report the
/// matching weights `(1.0, 1.0)` together with a finite evidence.
#[test]
fn select_weights_picks_max_evidence() {
    let rss = array![[10.0, 9.0, 9.5], [8.0, 4.0, 5.0], [9.0, 6.0, 7.0]];
    let pen = Array2::<f64>::zeros((3, 3));
    // Both axes share the same candidate weights.
    let grid = array![0.1, 1.0, 10.0];

    let res =
        identifiable_factor_select_weights(rss.view(), pen.view(), grid.view(), grid.view(), 80)
            .unwrap();

    let chosen = (res.best_i, res.best_j);
    assert_eq!(chosen, (1, 1));
    assert!((res.best_lam1 - 1.0).abs() < 1e-12);
    assert!((res.best_lam2 - 1.0).abs() < 1e-12);
    assert!(res.best_evidence.is_finite());
}
3578
/// When every grid cell has the same RSS and the same penalty, the selector
/// is expected to return cell `(0, 0)`, the one with the smallest weights
/// on both axes.
#[test]
fn select_weights_breaks_ties_by_smallest_log_weight_sum() {
    // Constant grids: no cell is preferred by the data.
    let rss = Array2::<f64>::from_elem((2, 2), 4.0);
    let pen = Array2::<f64>::from_elem((2, 2), 1.0);
    let grid = array![0.1, 10.0];

    let res =
        identifiable_factor_select_weights(rss.view(), pen.view(), grid.view(), grid.view(), 8)
            .unwrap();

    let chosen = (res.best_i, res.best_j);
    assert_eq!(chosen, (0, 0));
}
3590
/// A `2×2` penalty grid paired with a `2×3` RSS grid must be rejected with
/// an error that mentions `penalty_grid`.
#[test]
fn select_weights_rejects_shape_mismatch() {
    let rss = Array2::<f64>::zeros((2, 3));
    // Deliberately one column short of `rss`.
    let pen = Array2::<f64>::zeros((2, 2));
    let l1 = array![1.0, 1.0];
    let l2 = array![1.0, 1.0, 1.0];

    let outcome =
        identifiable_factor_select_weights(rss.view(), pen.view(), l1.view(), l2.view(), 8);

    let err = outcome.unwrap_err();
    assert!(err.contains("penalty_grid"));
}
3602
/// Noise-free Procrustes alignment: when the supervised slice is the
/// auxiliary block rotated by a known orthogonal `Q`, the solver must map it
/// back onto the auxiliary block, and the free block must come out
/// orthogonal to the aligned supervised block.
#[test]
fn partial_supervision_procrustes_recovers_rotation_and_orthogonalizes_free() {
    let aux = array![
        [1.0_f64, 0.0, 0.0],
        [0.0, 1.0, 0.0],
        [0.0, 0.0, 1.0],
        [1.0, 1.0, 0.0],
        [-1.0, 1.0, 2.0],
    ];
    // Known orthogonal Q: a 90° rotation in the (0,1) plane.
    let q = array![[0.0_f64, -1.0, 0.0], [1.0, 0.0, 0.0], [0.0, 0.0, 1.0]];
    // Supervised slice = aux @ Qᵀ.
    let t_sup = aux.dot(&q.t());
    let t_free = array![
        [1.5_f64, 0.0],
        [0.0, 1.0],
        [-1.0, 2.0],
        [0.3, -0.7],
        [2.0, 1.0],
    ];

    let result = partial_supervision_solve(
        t_sup.view(),
        aux.view(),
        t_free.view(),
        PartialSupervisionSupMethod::Procrustes,
        &[],
        PartialSupervisionFreeConstraint::OrthogonalToSup,
    )
    .expect("procrustes solve should succeed");

    // Without noise the aligned supervised block reproduces aux exactly.
    for ((r, c), &want) in aux.indexed_iter() {
        let got = result.t_supervised[[r, c]];
        assert!(
            (got - want).abs() < 1.0e-10,
            "sup[{r},{c}] = {} vs aux {}",
            got,
            want
        );
    }

    // After orthogonalization the cross-Gram T_freeᵀ T_sup is near zero.
    let cross = result.t_free.t().dot(&result.t_supervised);
    let frob: f64 = cross.iter().map(|x| x * x).sum::<f64>().sqrt();
    assert!(frob < 1.0e-8, "cross frobenius = {frob}");

    assert!(result.alignment_score > 1.0 - 1.0e-10);
    assert!(result.map_r.is_some());
}
3650
/// Anchor supervision: the rows listed as anchors must be mapped onto the
/// corresponding auxiliary rows, and the solver must report both `map_a`
/// and `map_b`.
#[test]
fn partial_supervision_anchor_pins_exact_anchors_when_full_rank() {
    let aux = array![[1.0_f64, 2.0], [-1.0, 0.5], [3.0, -2.0], [0.7, 1.2]];
    let t_sup = array![[0.5_f64, 1.0], [-0.5, 0.25], [1.5, -1.0], [0.35, 0.6]];
    let t_free = Array2::<f64>::zeros((4, 1));

    let result = partial_supervision_solve(
        t_sup.view(),
        aux.view(),
        t_free.view(),
        PartialSupervisionSupMethod::Anchor,
        &[0, 1, 2],
        PartialSupervisionFreeConstraint::None,
    )
    .expect("anchor solve should succeed");

    // Every anchored row has to land on its auxiliary target.
    let pinned_rows: [usize; 3] = [0, 1, 2];
    for &row in &pinned_rows {
        for c in 0..2 {
            let got = result.t_supervised[[row, c]];
            let want = aux[[row, c]];
            assert!(
                (got - want).abs() < 1.0e-9,
                "anchor row {row} col {c} not pinned: {} vs {}",
                got,
                want
            );
        }
    }

    let has_both_maps = result.map_a.is_some() && result.map_b.is_some();
    assert!(has_both_maps);
}
3677
/// Soft-L2 supervision must succeed on a lightly perturbed supervised block
/// and report a finite, strictly positive selected weight along with
/// `map_a`.
#[test]
fn partial_supervision_softl2_selects_a_finite_weight() {
    let aux = array![
        [1.0_f64, 0.0],
        [0.0, 1.0],
        [1.0, 1.0],
        [-1.0, 1.0],
        [0.5, -0.5],
    ];
    // Same as `aux` except for 0.1 off-diagonal perturbations in rows 0, 1.
    let t_sup = array![
        [1.0_f64, 0.1],
        [0.1, 1.0],
        [1.0, 1.0],
        [-1.0, 1.0],
        [0.5, -0.5],
    ];
    // Constant single-column free block.
    let t_free = Array2::<f64>::from_elem((5, 1), 0.5);

    let result = partial_supervision_solve(
        t_sup.view(),
        aux.view(),
        t_free.view(),
        PartialSupervisionSupMethod::SoftL2,
        &[],
        PartialSupervisionFreeConstraint::OrthogonalToSup,
    )
    .expect("soft_l2 solve should succeed");

    let lam = result.selected_weight.unwrap();
    let usable = lam.is_finite() && lam > 0.0;
    assert!(usable, "lam={lam}");
    assert!(result.map_a.is_some());
}
3708}