gam_sae/identifiability.rs
1//! SAE identifiability primitives and partial-supervision gauge fixing.
2//!
3//! # Object 4 — the Certificate ([`residual_gauge`])
4//!
5//! The partial-supervision solver above *removes* gauge freedom by aligning to
6//! auxiliary supervision. The certificate answers the dual question: after a fit
7//! has converged, **which gauge group is the model identified up to?** It does
8//! so by running the same penalty-aware RRQR rank machinery the cross-block
9//! identifiability audit uses
10//! ([`gam_identifiability::audit::audit_identifiability`] /
11//! [`gam_linalg::faer_ndarray::rrqr_with_permutation`]) — but on the
12//! **symmetry generators** of the fitted model rather than on stacked design
13//! columns.
14//!
15//! Each candidate symmetry of the SAE-manifold model (an isometry of an atom's
16//! latent manifold, a rotation inside an ARD-equal eigenspace, a rotation of the
17//! decoder output frame, an exchange of two topology-identical atoms) is
18//! realised as a **tangent direction** `ξ` in the model's free-parameter space.
19//! A generator is an *unpinned residual gauge freedom* iff the converged
20//! objective is flat along it — i.e. `ξ` lies in the null space of the total
21//! curvature operator `H = H_data + H_isometry` (data/likelihood curvature plus
22//! the isometry-penalty curvature). It is *pinned* (broken by the data or the
23//! isometry penalty) iff `ξ` has a component in `range(H)`.
24//!
25//! The RRQR supplies the pinning RANK via the same penalty-aware,
26//! leverage-scaled rank decision the audit uses. Each generator's verdict,
27//! however, keeps the curvature **magnitudes**: the relative curvature
28//! fraction `‖R ξ̂‖² / σ_max(R)²` measures how much objective curvature the
29//! unit generator carries, relative to the model's stiffest direction. A
30//! generator is **unpinned** iff that fraction is within the calibrated
31//! tolerance `max(`[`GENERATOR_FLAT_ENERGY_TOL`]`, lowering_error_scale)` —
32//! genuinely flat up to numerical noise and up to the mean-frame lowering's
33//! own resolution ([`FittedAtom::lowering_error`], #995). Anything larger
34//! means the orbit costs objective, so the exact symmetry is broken and the
35//! generator is **pinned** — including the *mixed* case (partly curved,
36//! partly flat), where replicate fits do NOT differ by that group element
37//! even though some flat directions remain nearby. Magnitudes (not span
38//! membership) keep the statistic informative when `range(H)` is full-rank,
39//! which production fits always are. The fraction and the calibration scale
40//! are reported per generator so partial flatness stays visible instead of
41//! being collapsed into the boolean.
42//!
43//! The whole computation is performed in the inner product carried by the fit's
44//! [`gam_problem::RowMetric`]: the curvature root `R` is built
45//! from the metric-whitened Jacobian, so the certificate's "computed in metric
46//! X" line reads straight off [`gam_problem::RowMetric::provenance`]
47//! ([`gam_problem::MetricProvenance`]) and cannot misreport —
48//! there is only one metric object.
49
50use crate::inference::layer_transport::{ChartTopology, TransportLadderReport, transport_ladder};
51use crate::inference::probe_runner::{ProbeRunner, RealizedProbe};
52use crate::inference::riesz::{RieszInput, SmoothFunctional, debias_with_dense_hessian};
53use gam_problem::{MetricProvenance, RowMetric};
54use gam_terms::inference::structure_evidence::{StructureCertificate, StructureLedger};
55use gam_linalg::faer_ndarray::{
56 FaerCholesky, FaerEigh, FaerQr, FaerSvd, default_rrqr_rank_alpha, rrqr_with_permutation,
57};
58use crate::chart_canonicalization::CanonicalChartTopology;
59use crate::manifold::SaeManifoldTerm;
60use faer::Side;
61use ndarray::{Array1, Array2, Array3, Array4, ArrayView1, ArrayView2, s};
62use std::f64::consts::TAU;
63
64/// Smoothed column-2-norm of the decoder Jacobian.
65///
66/// Returns `(value, grad)` where `value = Σ_k √(Σ_d W[d,k]² + ε²) − ε`
67/// scaled by `weight`, and `grad[d, k] = weight · W[d, k] / √(Σ_d W[d,k]² + ε²)`.
68#[derive(Debug, Clone)]
69pub struct MechanismSparsityJacobian {
70 pub weight: f64,
71 pub epsilon: f64,
72}
73
74impl MechanismSparsityJacobian {
75 pub fn new(weight: f64, epsilon: f64) -> Result<Self, String> {
76 if !(weight.is_finite() && weight > 0.0) {
77 return Err(format!(
78 "MechanismSparsityJacobian: weight must be finite and >0, got {weight}"
79 ));
80 }
81 if !(epsilon.is_finite() && epsilon > 0.0) {
82 return Err(format!(
83 "MechanismSparsityJacobian: epsilon must be finite and >0, got {epsilon}"
84 ));
85 }
86 Ok(Self { weight, epsilon })
87 }
88
89 /// Evaluate value and gradient on a (d_obs, k_latent) decoder weight matrix.
90 pub fn value_and_grad(&self, w: ArrayView2<f64>) -> (f64, Array2<f64>) {
91 let (d, k) = w.dim();
92 let eps2 = self.epsilon * self.epsilon;
93 let mut grad = Array2::<f64>::zeros((d, k));
94 let mut value = 0.0;
95 for col in 0..k {
96 let mut sq = 0.0;
97 for row in 0..d {
98 sq += w[[row, col]] * w[[row, col]];
99 }
100 let denom = (sq + eps2).sqrt();
101 value += denom - self.epsilon;
102 let factor = self.weight / denom;
103 for row in 0..d {
104 grad[[row, col]] = factor * w[[row, col]];
105 }
106 }
107 (self.weight * value, grad)
108 }
109
110 /// Diagonal of the Hessian wrt vec(W). Used as a Newton preconditioner.
111 pub fn hessian_diag(&self, w: ArrayView2<f64>) -> Array2<f64> {
112 let (d, k) = w.dim();
113 let eps2 = self.epsilon * self.epsilon;
114 let mut out = Array2::<f64>::zeros((d, k));
115 for col in 0..k {
116 let mut sq = 0.0;
117 for row in 0..d {
118 sq += w[[row, col]] * w[[row, col]];
119 }
120 let denom = (sq + eps2).sqrt();
121 let inv = 1.0 / denom;
122 let inv3 = inv * inv * inv;
123 for row in 0..d {
124 // ∂² / ∂W[d,k]² of √(||·||²+ε²) = 1/r − W[d,k]²/r³
125 out[[row, col]] = self.weight * (inv - w[[row, col]] * w[[row, col]] * inv3);
126 }
127 }
128 out
129 }
130}
131
132/// iVAE-style auxiliary-conditional Gaussian log-prior on the latent block.
133///
134/// Stores per-row conditional means `μ` of shape `(n_rows, latent_dim)` and
135/// scales `σ` of shape `(n_rows, latent_dim)`, where `(μ_{n,i}, σ_{n,i})` are
136/// presumed evaluated by some external Smooth at the auxiliary `u_n`. The
137/// negative log-prior contribution to the latent objective is
138///
139/// `½ Σ_n Σ_i [ ((t_{n,i} − μ_{n,i}) / σ_{n,i})²
140/// + 2 log σ_{n,i} + log 2π ]`
141///
142/// scaled by `weight`. The gradient w.r.t. `t` is `(t − μ) / σ²` (times
143/// `weight`); the gradient w.r.t. `μ` is its negative. Per-row scales make
144/// this strictly more general than a fixed `N(0, I)`, which is recovered by
145/// `μ ≡ 0`, `σ ≡ 1`.
146#[derive(Debug, Clone)]
147pub struct ConditionalPriorIvae {
148 pub mean: Array2<f64>,
149 pub scale: Array2<f64>,
150 pub weight: f64,
151}
152
153impl ConditionalPriorIvae {
154 pub fn new(mean: Array2<f64>, scale: Array2<f64>, weight: f64) -> Result<Self, String> {
155 if mean.dim() != scale.dim() {
156 return Err(format!(
157 "ConditionalPriorIvae: mean shape {:?} != scale shape {:?}",
158 mean.dim(),
159 scale.dim()
160 ));
161 }
162 if !(weight.is_finite() && weight > 0.0) {
163 return Err(format!(
164 "ConditionalPriorIvae: weight must be finite and >0, got {weight}"
165 ));
166 }
167 for &v in scale.iter() {
168 if !(v.is_finite() && v > 0.0) {
169 return Err(format!(
170 "ConditionalPriorIvae: every scale must be finite and >0, got {v}"
171 ));
172 }
173 }
174 for &v in mean.iter() {
175 if !v.is_finite() {
176 return Err("ConditionalPriorIvae: mean contains non-finite entry".to_string());
177 }
178 }
179
180 // Khemakhem et al. (arXiv:2107.10098) Theorem 1 identifiability
181 // precondition for the exponential-family conditional prior:
182 // the auxiliary index `u` must yield 2k+1 distinct conditional
183 // priors `p(t|u)` whose sufficient-statistic parameters
184 // `(η_1(u), η_2(u)) = (μ(u)/σ(u)², −1/(2σ(u)²))` span a
185 // 2k-dimensional set. For the diagonal Gaussian family this is
186 // equivalent (an invertible reparameterisation) to requiring that
187 // the stacked signature `S = [μ(u) ‖ log σ(u)]` of shape
188 // (n_rows, 2k) have rank 2k, with at least 2k+1 distinct rows.
189 let (n_rows, latent_dim) = mean.dim();
190 let needed_rows = 2 * latent_dim + 1;
191 if n_rows < needed_rows {
192 return Err(format!(
193 "ConditionalPriorIvae: Khemakhem (arXiv:2107.10098) Theorem 1 \
194 precondition violated: need at least 2k+1 = {needed_rows} distinct \
195 auxiliary states for latent_dim k = {latent_dim}, got n_rows = {n_rows}"
196 ));
197 }
198 let signature = {
199 let mut s = Array2::<f64>::zeros((n_rows, 2 * latent_dim));
200 for r in 0..n_rows {
201 for c in 0..latent_dim {
202 s[[r, c]] = mean[[r, c]];
203 s[[r, latent_dim + c]] = scale[[r, c]].ln();
204 }
205 }
206 s
207 };
208 let first = signature.row(0).to_owned();
209 let all_identical = signature
210 .outer_iter()
211 .all(|row| row.iter().zip(first.iter()).all(|(a, b)| a == b));
212 if all_identical {
213 return Err(format!(
214 "ConditionalPriorIvae: Khemakhem (arXiv:2107.10098) Theorem 1 \
215 precondition violated: all {n_rows} rows of the stacked auxiliary \
216 signature [μ ‖ log σ] are identical, so the conditional prior is the \
217 trivial unconditional N(μ, σ²) — provably non-identifiable (no \
218 auxiliary information)"
219 ));
220 }
221 let (_u, sv, _vt) = signature
222 .svd(false, false)
223 .map_err(|e| format!("ConditionalPriorIvae: SVD of auxiliary signature failed: {e}"))?;
224 let max_sv = sv.iter().cloned().fold(0.0_f64, f64::max);
225 let tol = max_sv * (n_rows.max(2 * latent_dim) as f64) * f64::EPSILON;
226 let numerical_rank = sv.iter().filter(|&&s| s > tol).count();
227 let required = 2 * latent_dim;
228 if numerical_rank < required {
229 return Err(format!(
230 "ConditionalPriorIvae: Khemakhem (arXiv:2107.10098) Theorem 1 \
231 precondition violated: stacked auxiliary signature [μ ‖ log σ] has \
232 numerical rank {numerical_rank} < 2·latent_dim = {required} \
233 (tolerance {tol:.3e}); the family `p(t|u)` does not span a \
234 2k-dimensional set of natural parameters"
235 ));
236 }
237
238 Ok(Self {
239 mean,
240 scale,
241 weight,
242 })
243 }
244
245 /// Evaluate negative-log-prior value and gradient w.r.t. latent t.
246 pub fn value_and_grad(&self, t: ArrayView2<f64>) -> (f64, Array2<f64>) {
247 assert_eq!(
248 t.dim(),
249 self.mean.dim(),
250 "ConditionalPriorIvae: t/mean shape mismatch"
251 );
252 let (n, d) = t.dim();
253 let log_2pi = (2.0 * std::f64::consts::PI).ln();
254 let mut grad = Array2::<f64>::zeros((n, d));
255 let mut value = 0.0;
256 for row in 0..n {
257 for col in 0..d {
258 let mu = self.mean[[row, col]];
259 let sigma = self.scale[[row, col]];
260 let z = (t[[row, col]] - mu) / sigma;
261 value += 0.5 * (z * z + 2.0 * sigma.ln() + log_2pi);
262 grad[[row, col]] = self.weight * z / sigma;
263 }
264 }
265 (self.weight * value, grad)
266 }
267
268 /// Evaluate value only — useful when only the loss is needed.
269 pub fn value(&self, t: ArrayView2<f64>) -> f64 {
270 self.value_and_grad(t).0
271 }
272}
273
274/// Helper: evaluate a piecewise-linear "smooth" `f(u)` columnwise, given a
275/// (k_centres, latent_dim) coefficient table and a (n_rows,) auxiliary vector
276/// `u`. Used by the Python wrapper to back the iVAE per-latent (μ_i(u), σ_i(u))
277/// without having to round-trip through gam's full Smooth machinery for the
278/// minimal experiments. Centres are assumed evenly spaced in [u_min, u_max].
279pub fn piecewise_linear_eval(
280 u: ArrayView1<f64>,
281 coeffs: ArrayView2<f64>,
282 u_min: f64,
283 u_max: f64,
284) -> Array2<f64> {
285 let (k, d) = coeffs.dim();
286 assert!(k >= 2, "piecewise_linear_eval: need ≥2 centres");
287 let n = u.len();
288 let mut out = Array2::<f64>::zeros((n, d));
289 let step = (u_max - u_min) / (k - 1) as f64;
290 for (row, &val) in u.iter().enumerate() {
291 // Clamp `pos` to the exact endpoint `(k-1)`, not `(k-1) - 1e-12`,
292 // so `val = u_max` evaluates to exactly `coeffs[k-1, col]` instead
293 // of `coeffs[k-1, col] + 1e-12 · (coeffs[k-2, col] − coeffs[k-1,
294 // col])`. The historical `1e-12` shift was there to keep `lo + 1`
295 // in range, but capping `lo` at `k − 2` achieves the same
296 // structural guarantee without perturbing the endpoint value.
297 let pos = ((val - u_min) / step).clamp(0.0, (k - 1) as f64);
298 let lo = (pos.floor() as usize).min(k - 2);
299 let hi = lo + 1;
300 let frac = pos - lo as f64;
301 for col in 0..d {
302 out[[row, col]] = coeffs[[lo, col]] * (1.0 - frac) + coeffs[[hi, col]] * frac;
303 }
304 }
305 out
306}
307
308/// Outcome of a 2D log-λ grid-search weight selection.
309///
310/// `evidence_grid[i, j]` is the Laplace-style log marginal-likelihood proxy
311/// at `(lam1_grid[i], lam2_grid[j])`:
312/// `evidence = −½ N log(RSS/N) − ½ (penalty)` with `RSS = rss_grid[i, j]`
313/// and `penalty = penalty_grid[i, j]`.
314///
315/// The winner is `argmax` over the grid; ties are broken by selecting the
316/// `(i, j)` with the smallest `i + j` (i.e. smallest log-weight sum on a
317/// log-spaced grid), then by smallest `i`, then smallest `j` — a fully
318/// deterministic, reproducible policy.
319#[derive(Debug, Clone)]
320pub struct WeightSearchResult {
321 pub best_i: usize,
322 pub best_j: usize,
323 pub best_lam1: f64,
324 pub best_lam2: f64,
325 pub best_evidence: f64,
326 pub evidence_grid: Array2<f64>,
327}
328
329/// Generic 2D log-λ weight-selection driver.
330///
331/// Given a precomputed `(G1, G2)` grid of residual sums-of-squares
332/// `rss_grid`, a matching grid of total-penalty values `penalty_grid`, and
333/// the two 1D weight grids `lam1_grid` / `lam2_grid`, computes the Laplace
334/// log marginal-likelihood proxy on every cell and returns the maximising
335/// cell with deterministic tie-breaking.
336///
337/// The primitive is intentionally agnostic to *what* the two penalty
338/// weights regularise — it takes only the RSS and penalty surfaces, so it
339/// can drive weight selection for any two-penalty model (identifiable
340/// factor model, double-penalty smooths, IBP + sparsity, etc.).
341pub fn identifiable_factor_select_weights(
342 rss_grid: ArrayView2<'_, f64>,
343 penalty_grid: ArrayView2<'_, f64>,
344 lam1_grid: ArrayView1<'_, f64>,
345 lam2_grid: ArrayView1<'_, f64>,
346 n_obs: usize,
347) -> Result<WeightSearchResult, String> {
348 let (g1, g2) = rss_grid.dim();
349 if penalty_grid.dim() != (g1, g2) {
350 return Err(format!(
351 "identifiable_factor_select_weights: penalty_grid shape {:?} \
352 must match rss_grid shape ({}, {})",
353 penalty_grid.dim(),
354 g1,
355 g2
356 ));
357 }
358 if lam1_grid.len() != g1 {
359 return Err(format!(
360 "identifiable_factor_select_weights: lam1_grid len {} must \
361 equal rss_grid rows {}",
362 lam1_grid.len(),
363 g1
364 ));
365 }
366 if lam2_grid.len() != g2 {
367 return Err(format!(
368 "identifiable_factor_select_weights: lam2_grid len {} must \
369 equal rss_grid cols {}",
370 lam2_grid.len(),
371 g2
372 ));
373 }
374 if g1 == 0 || g2 == 0 {
375 return Err("identifiable_factor_select_weights: grids must be non-empty".to_string());
376 }
377 if n_obs == 0 {
378 return Err("identifiable_factor_select_weights: n_obs must be > 0".to_string());
379 }
380 for v in rss_grid.iter() {
381 if !v.is_finite() || *v < 0.0 {
382 return Err(format!(
383 "identifiable_factor_select_weights: rss_grid contains non-finite or \
384 negative value {v}"
385 ));
386 }
387 }
388 for v in penalty_grid.iter() {
389 if !v.is_finite() {
390 return Err(format!(
391 "identifiable_factor_select_weights: penalty_grid contains non-finite value {v}"
392 ));
393 }
394 }
395 for v in lam1_grid.iter().chain(lam2_grid.iter()) {
396 if !v.is_finite() || *v <= 0.0 {
397 return Err(format!(
398 "identifiable_factor_select_weights: λ grids must contain finite positive \
399 values, got {v}"
400 ));
401 }
402 }
403
404 let n = n_obs as f64;
405 let rss_floor = 1.0e-300_f64;
406 let mut evidence_grid = Array2::<f64>::zeros((g1, g2));
407 let mut best: Option<(usize, usize, f64)> = None;
408 for i in 0..g1 {
409 for j in 0..g2 {
410 let rss = rss_grid[[i, j]];
411 let pen = penalty_grid[[i, j]];
412 let mean_sq = (rss / n).max(rss_floor);
413 let ev = -0.5 * n * mean_sq.ln() - 0.5 * pen;
414 evidence_grid[[i, j]] = ev;
415 let better = match best {
416 None => true,
417 Some((bi, bj, bev)) => {
418 if ev > bev {
419 true
420 } else if ev == bev {
421 let cur_sum = i + j;
422 let best_sum = bi + bj;
423 if cur_sum < best_sum {
424 true
425 } else if cur_sum == best_sum && i < bi {
426 true
427 } else {
428 cur_sum == best_sum && i == bi && j < bj
429 }
430 } else {
431 false
432 }
433 }
434 };
435 if better {
436 best = Some((i, j, ev));
437 }
438 }
439 }
440 let (best_i, best_j, best_evidence) = best.ok_or_else(|| {
441 "identifiable_factor_select_weights: empty search (this is a bug)".to_string()
442 })?;
443 Ok(WeightSearchResult {
444 best_i,
445 best_j,
446 best_lam1: lam1_grid[best_i],
447 best_lam2: lam2_grid[best_j],
448 best_evidence,
449 evidence_grid,
450 })
451}
452
453/// Column-centred thin-SVD scores: returns the leading `k` columns of
454/// `U Σ` for the centred predictor matrix `X − mean(X, axis=0)`.
455///
456/// Used to seed `T_init` for the partial-supervision recipe when the
457/// caller does not supply one. Pure-Rust path (faer SVD via the
458/// `FaerSvd` bridge) so the seeding math lives in the same crate as the
459/// gauge-fix solver.
460pub fn thin_svd_scores(x: ArrayView2<f64>, k: usize) -> Result<Array2<f64>, String> {
461 let (n, p) = x.dim();
462 if k == 0 {
463 return Ok(Array2::<f64>::zeros((n, 0)));
464 }
465 if k > n.min(p) {
466 return Err(format!(
467 "thin_svd_scores: requested {k} components but min(n={n}, p={p}) limits to {}",
468 n.min(p)
469 ));
470 }
471 let mut mean_row = Array1::<f64>::zeros(p);
472 for row in 0..n {
473 for col in 0..p {
474 mean_row[col] += x[[row, col]];
475 }
476 }
477 if n > 0 {
478 let inv_n = 1.0 / (n as f64);
479 for col in 0..p {
480 mean_row[col] *= inv_n;
481 }
482 }
483 let mut xc = Array2::<f64>::zeros((n, p));
484 for row in 0..n {
485 for col in 0..p {
486 xc[[row, col]] = x[[row, col]] - mean_row[col];
487 }
488 }
489 let (u_opt, sigma, _vt_opt) = xc
490 .svd(true, false)
491 .map_err(|e| format!("thin_svd_scores: SVD failed: {e}"))?;
492 let u = u_opt.ok_or_else(|| "thin_svd_scores: SVD did not return U".to_string())?;
493 let mut out = Array2::<f64>::zeros((n, k));
494 for row in 0..n {
495 for col in 0..k {
496 out[[row, col]] = u[[row, col]] * sigma[col];
497 }
498 }
499 Ok(out)
500}
501
502/// Method for tying the supervised block to the auxiliary signal.
503#[derive(Debug, Clone, Copy, PartialEq, Eq)]
504pub enum PartialSupervisionSupMethod {
505 /// Orthogonal Procrustes: `min_{RᵀR=I} ‖T_sup R - aux‖_F²`.
506 Procrustes,
507 /// Affine least-squares pinned to `anchor_idx`.
508 Anchor,
509 /// Ridge map `A_λ = (TᵀT + λI)⁻¹ Tᵀaux` with REML-selected λ.
510 SoftL2,
511}
512
513/// Free-block decorrelation rule.
514#[derive(Debug, Clone, Copy, PartialEq, Eq)]
515pub enum PartialSupervisionFreeConstraint {
516 /// QR-based projection onto the orthogonal complement of `col(T_sup)`.
517 OrthogonalToSup,
518 /// No projection.
519 None,
520}
521
522/// Result of [`partial_supervision_solve`].
523///
524/// `alignment_score = 1 - ‖T_sup_aligned - aux‖_F² / ‖aux‖_F²` for every
525/// method (1.0 = perfect, 0.0 = no better than the constant-zero predictor).
526/// The fitted gauge map lives in the variant-specific fields:
527///
528/// * Procrustes → `map_r = R` (`d × d` orthogonal).
529/// * Anchor → `map_a = A` (`d × d`), `map_b` (`d`).
530/// * SoftL2 → `map_a = A_λ` (`d × d`), `selected_weight = λ`.
531#[derive(Debug, Clone)]
532pub struct PartialSupervisionResult {
533 pub t_supervised: Array2<f64>,
534 pub t_free: Array2<f64>,
535 pub alignment_score: f64,
536 pub selected_weight: Option<f64>,
537 pub map_r: Option<Array2<f64>>,
538 pub map_a: Option<Array2<f64>>,
539 pub map_b: Option<Array1<f64>>,
540}
541
542/// Library-level partial-supervision gauge-fix solver.
543///
544/// Solves the supervised-block alignment problem and applies the chosen
545/// free-block decorrelation rule. Pure numerical linear algebra: SVD,
546/// symmetric eigendecomposition (`Side::Lower`), and thin QR are routed
547/// through the faer bridge in `gam_linalg::faer_ndarray`.
548///
549/// This is the single Rust source-of-math for the gauge-fix step; it is
550/// language-agnostic so the CLI, R, and Julia bindings can reuse it
551/// through their own marshaling layers.
552///
553/// Shape requirements:
554/// * `t_sup` is `(N, d_sup)`; `aux` must equal that shape.
555/// * `t_free` is `(N, d_free)` — `d_free` may be 0.
556/// * `anchor_idx` is consulted only when `method == Anchor`; it must be
557/// non-empty and every index must be `< N`.
558pub fn partial_supervision_solve(
559 t_sup: ArrayView2<f64>,
560 aux: ArrayView2<f64>,
561 t_free: ArrayView2<f64>,
562 method: PartialSupervisionSupMethod,
563 anchor_idx: &[usize],
564 free_constraint: PartialSupervisionFreeConstraint,
565) -> Result<PartialSupervisionResult, String> {
566 let (n, d_sup) = t_sup.dim();
567 if aux.dim() != (n, d_sup) {
568 return Err(format!(
569 "partial_supervision_solve: aux shape {:?} must equal t_sup shape ({}, {})",
570 aux.dim(),
571 n,
572 d_sup
573 ));
574 }
575 if t_free.nrows() != n {
576 return Err(format!(
577 "partial_supervision_solve: t_free has {} rows, expected {}",
578 t_free.nrows(),
579 n
580 ));
581 }
582 let aux_norm_sq: f64 = aux.iter().map(|x| x * x).sum();
583 if !(aux_norm_sq.is_finite() && aux_norm_sq > 0.0) {
584 return Err(
585 "partial_supervision_solve: aux has zero or non-finite Frobenius norm".to_string(),
586 );
587 }
588
589 let mut t_sup_aligned = Array2::<f64>::zeros((n, d_sup));
590 let mut map_r: Option<Array2<f64>> = None;
591 let mut map_a: Option<Array2<f64>> = None;
592 let mut map_b: Option<Array1<f64>> = None;
593 let mut selected_weight: Option<f64> = None;
594
595 match method {
596 PartialSupervisionSupMethod::Procrustes => {
597 // R = U Vᵀ where T_supᵀ aux = U Σ Vᵀ.
598 let m = t_sup.t().dot(&aux);
599 let (u_opt, _sigma, vt_opt) = m
600 .svd(true, true)
601 .map_err(|e| format!("partial_supervision_solve: Procrustes SVD failed: {e}"))?;
602 let u = u_opt
603 .ok_or_else(|| "partial_supervision_solve: SVD did not return U".to_string())?;
604 let vt = vt_opt
605 .ok_or_else(|| "partial_supervision_solve: SVD did not return Vᵀ".to_string())?;
606 let r = u.dot(&vt);
607 t_sup_aligned = t_sup.dot(&r);
608 map_r = Some(r);
609 }
610 PartialSupervisionSupMethod::Anchor => {
611 if anchor_idx.is_empty() {
612 return Err(
613 "partial_supervision_solve: anchor method requires anchor_idx with at \
614 least one row"
615 .to_string(),
616 );
617 }
618 for &idx in anchor_idx {
619 if idx >= n {
620 return Err(format!(
621 "partial_supervision_solve: anchor index {idx} out of bounds (n={n})"
622 ));
623 }
624 }
625 // Stack design [Ta | 1] of shape (m, d_sup+1); solve via SVD pseudo-inverse.
626 let m_rows = anchor_idx.len();
627 let mut design = Array2::<f64>::zeros((m_rows, d_sup + 1));
628 let mut targets = Array2::<f64>::zeros((m_rows, d_sup));
629 for (row_out, &row_in) in anchor_idx.iter().enumerate() {
630 for c in 0..d_sup {
631 design[[row_out, c]] = t_sup[[row_in, c]];
632 targets[[row_out, c]] = aux[[row_in, c]];
633 }
634 design[[row_out, d_sup]] = 1.0;
635 }
636 let (u_opt, sigma, vt_opt) = design
637 .svd(true, true)
638 .map_err(|e| format!("partial_supervision_solve: Anchor SVD failed: {e}"))?;
639 let u = u_opt
640 .ok_or_else(|| "partial_supervision_solve: anchor SVD lacked U".to_string())?;
641 let vt = vt_opt
642 .ok_or_else(|| "partial_supervision_solve: anchor SVD lacked Vᵀ".to_string())?;
643 // Tikhonov cutoff matches numpy.linalg.lstsq's default rcond policy.
644 let leading = sigma.iter().cloned().fold(0.0_f64, f64::max);
645 let cutoff = leading * f64::EPSILON * (m_rows.max(d_sup + 1) as f64);
646 let rank = sigma.len();
647 let ut_targets = u.t().dot(&targets);
648 let mut scaled = Array2::<f64>::zeros((rank, d_sup));
649 for r in 0..rank {
650 let s = sigma[r];
651 if s > cutoff {
652 let inv = 1.0 / s;
653 for c in 0..d_sup {
654 scaled[[r, c]] = inv * ut_targets[[r, c]];
655 }
656 }
657 }
658 let coef = vt.t().dot(&scaled);
659 let a = coef.slice(s![..d_sup, ..]).to_owned();
660 let b_vec = coef.slice(s![d_sup, ..]).to_owned();
661 for row in 0..n {
662 for c in 0..d_sup {
663 let mut acc = b_vec[c];
664 for k in 0..d_sup {
665 acc += t_sup[[row, k]] * a[[k, c]];
666 }
667 t_sup_aligned[[row, c]] = acc;
668 }
669 }
670 map_a = Some(a);
671 map_b = Some(b_vec);
672 }
673 PartialSupervisionSupMethod::SoftL2 => {
674 // Symmetric eigendecomposition of G = T_supᵀ T_sup.
675 let g = t_sup.t().dot(&t_sup);
676 let (eigvals, eigvecs) = g
677 .eigh(Side::Lower)
678 .map_err(|e| format!("partial_supervision_solve: eigh on Gram failed: {e}"))?;
679 let rhs = t_sup.t().dot(&aux);
680 let ut_aux = eigvecs.t().dot(&rhs);
681 // Per-eigenvector signal energy m_r = ‖row_r(Vᵀ Tᵀaux)‖²; the
682 // multi-response RSS at weight λ is then
683 // S(λ) = ‖aux‖_F² − Σ_r m_r/(γ_r+λ)
684 // with γ_r the eigenvalues of G = TᵀT (`eigvals`).
685 let m_row: Array1<f64> = Array1::from_vec(
686 (0..d_sup)
687 .map(|r| (0..d_sup).map(|c| ut_aux[[r, c]] * ut_aux[[r, c]]).sum())
688 .collect(),
689 );
690 let lam_max = eigvals.iter().cloned().fold(0.0_f64, f64::max);
691 let floor = (lam_max * 1.0e-10).max(1.0e-12);
692 let top = (lam_max * 1.0e3).max(floor * 1.0e6);
693 let grid_n: usize = 64;
694 let log_floor = floor.ln();
695 let log_top = top.ln();
696 // Select λ by REML, never GCV. The ridge map is the linear mixed
697 // model aux_j = T β_j + ε with β_j ~ N(0, σ²/λ I), ε ~ N(0, σ² I)
698 // applied to each of the d columns sharing λ. The map carries no
699 // unpenalized fixed effect, so REML coincides with the marginal
700 // likelihood, whose profile (σ² concentrated out) criterion to
701 // MINIMIZE is
702 // reml(λ) = n·log S(λ) + Σ_r log(1 + γ_r/λ),
703 // the exact analogue of the smoothing-parameter REML used
704 // everywhere else in gam.
705 let mut best_score = f64::INFINITY;
706 let mut best_lam = floor;
707 for k in 0..grid_n {
708 let frac = if grid_n == 1 {
709 0.0
710 } else {
711 (k as f64) / ((grid_n - 1) as f64)
712 };
713 let lam = (log_floor + frac * (log_top - log_floor)).exp();
714 let mut shrunk = 0.0_f64; // Σ_r m_r/(γ_r+λ)
715 let mut logdet = 0.0_f64; // Σ_r log(1 + γ_r/λ)
716 for r in 0..d_sup {
717 let g = eigvals[r].max(0.0);
718 shrunk += m_row[r] / (g + lam);
719 logdet += (1.0 + g / lam).ln();
720 }
721 let s = aux_norm_sq - shrunk;
722 if !(s.is_finite() && s > 0.0) {
723 continue;
724 }
725 let score = (n as f64) * s.ln() + logdet;
726 if score < best_score {
727 best_score = score;
728 best_lam = lam;
729 }
730 }
731 if !best_score.is_finite() {
732 return Err(
733 "partial_supervision_solve: REML grid did not find a finite-score weight"
734 .to_string(),
735 );
736 }
737 // Build the ridge map A_λ = (G + λI)⁻¹ Tᵀaux at the REML weight.
738 let denom: Array1<f64> = eigvals.mapv(|v| v + best_lam);
739 let mut a_eig = Array2::<f64>::zeros((d_sup, d_sup));
740 for r in 0..d_sup {
741 for c in 0..d_sup {
742 a_eig[[r, c]] = ut_aux[[r, c]] / denom[r];
743 }
744 }
745 let best_a = eigvecs.dot(&a_eig);
746 t_sup_aligned = t_sup.dot(&best_a);
747 map_a = Some(best_a);
748 selected_weight = Some(best_lam);
749 }
750 }
751
752 // Single source of truth for alignment_score.
753 let mut sq_resid = 0.0_f64;
754 for row in 0..n {
755 for c in 0..d_sup {
756 let r = t_sup_aligned[[row, c]] - aux[[row, c]];
757 sq_resid += r * r;
758 }
759 }
760 let alignment_score = 1.0 - sq_resid / aux_norm_sq;
761
762 let t_free_out = match free_constraint {
763 PartialSupervisionFreeConstraint::None => t_free.to_owned(),
764 PartialSupervisionFreeConstraint::OrthogonalToSup => {
765 if t_sup_aligned.ncols() == 0 || t_free.ncols() == 0 {
766 t_free.to_owned()
767 } else {
768 let qr_pair = t_sup_aligned
769 .qr()
770 .map_err(|e| format!("partial_supervision_solve: QR on T_sup failed: {e}"))?;
771 let q = qr_pair.0;
772 let qt_free = q.t().dot(&t_free);
773 let proj = q.dot(&qt_free);
774 let mut out = t_free.to_owned();
775 out -= &proj;
776 out
777 }
778 }
779 };
780
781 Ok(PartialSupervisionResult {
782 t_supervised: t_sup_aligned,
783 t_free: t_free_out,
784 alignment_score,
785 selected_weight,
786 map_r,
787 map_a,
788 map_b,
789 })
790}
791
792// ============================================================================
793// Object 4 — the Certificate: `residual_gauge()`
794// ============================================================================
795
796/// The latent-manifold topology of one fitted atom, as far as the certificate
797/// needs it to enumerate the atom's isometry-group generators. This mirrors the
798/// user-facing [`crate::manifold::SaeAtomBasisKind`] choice but
799/// carries only what is required to build `Isom(M_k)` tangent directions, so the
800/// certificate is decoupled from the full `SaeManifoldAtom` machinery.
801#[derive(Debug, Clone, PartialEq, Eq)]
802pub enum AtomTopology {
803 /// `S¹` (periodic 1-D). `Isom(S¹) = O(2)`: a single continuous rotation
804 /// generator (shift of the circular coordinate) plus a reflection.
805 Circle,
806 /// `S²` (intrinsic sphere chart). `Isom(S²) = O(3)`: three rotation
807 /// generators (so(3) basis) plus the antipodal/reflection component.
808 Sphere,
809 /// `Tᵈ` (product of `latent_dim` circles). `Isom` contains the `d`
810 /// independent circle shifts (a maximal torus of rotations).
811 Torus { latent_dim: usize },
812 /// A `latent_dim`-dimensional Euclidean patch / Duchon patch. Its connected
813 /// isometry group `SE(d)` is generated by `d` translations and
814 /// `d(d−1)/2` rotations of the latent coordinate frame.
815 EuclideanPatch { latent_dim: usize },
816}
817
818impl AtomTopology {
819 /// Intrinsic latent dimensionality of the atom's manifold.
820 fn latent_dim(&self) -> usize {
821 match self {
822 AtomTopology::Circle => 1,
823 AtomTopology::Sphere => 2,
824 AtomTopology::Torus { latent_dim } => *latent_dim,
825 AtomTopology::EuclideanPatch { latent_dim } => *latent_dim,
826 }
827 }
828}
829
/// One fitted atom as the certificate sees it.
///
/// `frame` is the fitted decoder frame whose columns the isometry generators
/// rotate: an `(output_dim, latent_dim)` matrix whose column `a` is the fitted
/// image of latent axis `a` in output space (e.g. the decoder Jacobian columns
/// at the atom's centroid, or the leading decoder directions). The isometry
/// generators of `Isom(M_k)` act on these columns; the certificate lifts that
/// action to a tangent direction on the flattened (row-major) decoder frame.
///
/// The frame's column count must equal the intrinsic latent dimension of
/// `topology`. An atom whose frame disagrees is treated as structurally
/// inconsistent and contributes no isometry generators to the certificate.
#[derive(Debug, Clone)]
pub struct FittedAtom {
    /// Display name. It prefixes every generator description that touches
    /// this atom.
    pub name: String,
    /// Topology of the atom's latent manifold. It selects which `Isom(M_k)`
    /// generators are enumerated for the atom.
    pub topology: AtomTopology,
    /// `(output_dim, latent_dim)` fitted decoder frame.
    pub frame: Array2<f64>,
    /// ARD prior variances (one per latent axis of this atom), used to detect
    /// equal-ARD eigenspaces inside which a rotation is unpinned by the prior.
    /// `None` ⇒ no ARD prior on this atom (every within-frame rotation is then
    /// a candidate generator, pinned-or-not decided solely by the data + the
    /// isometry penalty). A vector whose length differs from the frame's
    /// column count yields no equal-ARD generators.
    pub ard_variances: Option<Array1<f64>>,
    /// **Lowering-error scale** (#995), in `[0, 1]`: the mass-weighted relative
    /// dispersion of the atom's per-row decoder tangents around the mean
    /// `frame` the certificate compresses them into,
    /// `Σ_n a_n Σ_ax ‖t_ax(n) − frame[:,ax]‖² / Σ_n a_n Σ_ax ‖t_ax(n)‖²`.
    ///
    /// `0` ⇒ the frame represents every row exactly (hand-built fixtures, flat
    /// decoders) and the certificate's verdicts within this atom are at full
    /// resolution. Values toward `1` ⇒ a curved decoder whose tangent field
    /// disperses strongly (e.g. a full circle, whose tangents average to ≈ 0).
    /// The mean-frame lowering then cannot distinguish gauge motion from
    /// genuine curvature, so the verdict tolerance for generators touching
    /// this atom is *calibrated up to this scale*: the certificate refuses to
    /// claim a pin it cannot resolve, the same honesty contract as the
    /// `diffeomorphism-unpinned` escalation.
    pub lowering_error: f64,
    /// Whether the atom's continuous chart (reparameterization) freedom was
    /// removed post-fit by canonicalization (#1019).
    ///
    /// Stage 1: `true` when the atom's `d = 1` latent chart was pinned post-fit
    /// to its arc-length (unit-speed) canonical representative. Stage 2:
    /// `true` as well when a `d = 2` torus atom's chart was pinned post-fit to
    /// the minimum-isometry-defect flow representative.
    ///
    /// The certificate then records two things. First, this atom's continuous
    /// chart freedom is **pinned by canonicalization**, a provenance distinct
    /// from curvature/penalty pinning
    /// ([`VerdictProvenance::PinnedByCanonicalization`]). Second, the residual
    /// chart freedom is the isometry group of the reference manifold:
    /// rotation + reflection (`O(2)`) on the circle and reflection +
    /// translation on the interval (the `d = 1` charts), or
    /// `Isom(T², flat) = U(1)² ⋊ D₄` for a `d = 2` torus. These residual groups
    /// are not finite in general (`O(2)` and `U(1)²` are continuous).
    pub chart_canonicalized: bool,
    /// Per-atom inner-decoder-smooth byproducts harvested at fit time. They
    /// are the single source the post-PIRLS atom inference reports
    /// ([`AtomFunctionalReport`] #1097, [`AtomSmoothSignificance`] #1103)
    /// consume in [`dictionary_report`].
    ///
    /// The certificate path that builds `FittedSaeManifold` does so *without* a
    /// fit harness in scope, so it leaves this `None`. Callers that own the
    /// fitted term attach it through [`FittedAtom::with_inner_fit`] (the term
    /// builder fills it from the live per-atom basis, decoder, assignment mass,
    /// and smoothness Gram). When `None`, both post-PIRLS reports for this atom
    /// (the `functionals` and `smooth_significance` of its
    /// [`AtomInferenceReport`]) are `None`. The genuine prerequisite (the
    /// post-fit inner-smooth design, penalized Hessian, and row scores) is
    /// simply not present on a bare certificate-only `FittedSaeManifold`.
    pub inner_fit: Option<AtomInnerFit>,
}
893
/// The fitted per-atom inner-decoder smooth, captured once at fit time so the
/// post-PIRLS atom-inference reports reuse the *same* design, penalized Hessian,
/// and per-row scores the identifiability certificate's curvature sees.
///
/// The SAE decoder reconstructs `Z_i ≈ Σ_k a_ik Φ_k(t_ik) B_k`. Hold all other
/// atoms and the assignment fixed at the fitted optimum. Atom `k`'s own
/// contribution along a single output channel `j` is then the
/// Gaussian-identity penalized smooth `a_ik · Φ_k(t_ik)ᵀ β_{k,j}` with:
///
/// * roughness penalty `S_k`;
/// * Gauss–Newton observation weight `w_i = a_ik²` (the assignment mass enters
///   the channel linearly, so the normal-equation weight is its square);
/// * dispersion equal to the fitted reconstruction dispersion.
///
/// That is an ordinary penalized WLS smooth — exactly what
/// [`crate::inference::riesz`], [`gam_terms::inference::lawley`], and the
/// κ-profile machinery consume. The channel `j` is the atom's dominant decoder
/// output direction (largest column norm of `B_k`), i.e. the channel that
/// carries the atom's signal.
///
/// NOTE(review): [`AtomSmoothSignificance`] documents that the
/// Lawley–Bartlett test was replaced (#1103). [`AtomInferenceReport`]
/// documents that the per-atom curvature interval was removed (#1099/#1115).
/// Confirm the Lawley and κ-profile consumers listed above still exist.
///
/// Shape invariants: `design`, `derivative_design`, and `row_scores` share the
/// same `(n_active, M_k)` shape; `weights` has length `n_active`; `beta`,
/// `peak_design_row`, and `mode_design_row` have length `M_k`; `penalty` and
/// `penalized_hessian` are `(M_k, M_k)`.
#[derive(Debug, Clone)]
pub struct AtomInnerFit {
    /// `Φ_k` evaluated on the atom's active rows, `(n_active, M_k)`. The inner
    /// GAM smooth design. Column 0 is the constant/intercept basis column.
    pub design: Array2<f64>,
    /// `∂Φ_k/∂t` along the atom's leading latent axis on the active rows,
    /// `(n_active, M_k)`: the derivative design the average-derivative
    /// functional integrates.
    pub derivative_design: Array2<f64>,
    /// The fitted decoder coefficients for the captured output channel,
    /// `β_{k,j} ∈ ℝ^{M_k}`.
    pub beta: Array1<f64>,
    /// The atom roughness Gram `S_k`, `(M_k, M_k)`.
    pub penalty: Array2<f64>,
    /// The penalized Hessian `H = ΦᵀWΦ + S_k` at the fitted state, `(M_k, M_k)`.
    pub penalized_hessian: Array2<f64>,
    /// Per-row Gaussian-identity scores `s_i = ∂nll_i/∂β = −w_i r_i Φ_i / φ`,
    /// `(n_active, M_k)`, on the captured channel.
    pub row_scores: Array2<f64>,
    /// Per-row Gauss–Newton weights `w_i = a_ik²` on the captured channel.
    pub weights: Array1<f64>,
    /// Fitted reconstruction dispersion `φ` (Gaussian σ²).
    pub dispersion: f64,
    /// Design row at the latent peak `t_peak` (largest fitted `|g_k|`).
    pub peak_design_row: Array1<f64>,
    /// Design row at the latent mode `t_mode` (largest assignment mass).
    pub mode_design_row: Array1<f64>,
}
937
938impl FittedAtom {
939 /// Attach the inner-decoder-smooth byproducts harvested at fit time. The
940 /// term builder calls this so [`dictionary_report`] can produce the three
941 /// post-PIRLS atom inference reports.
942 pub fn with_inner_fit(mut self, inner_fit: AtomInnerFit) -> Self {
943 self.inner_fit = Some(inner_fit);
944 self
945 }
946}
947
/// Descriptive penalty-debiased POINT summaries of one fitted atom's decoder
/// curve (#1097, narrowed under #1115).
///
/// Each field is a scalar functional of the atom's inner smooth `g_k(t)`,
/// reported as a plug-in value and a one-step penalty-debiased value. The
/// debiasing removes the regularization bias relative to the conditional
/// target through the atom fit's penalized Hessian. See
/// [`AtomFunctionalEstimate`] for the per-functional fields. No standard error
/// and no confidence interval are reported, by design (see below).
///
/// # Why these carry NO coverage claim (#1115)
///
/// Conditional on the fitted latent coordinates `t̂` and assignment `â`, each
/// functional is an ordinary linear functional of the penalized-WLS
/// coefficients `β` with a well-defined *conditional* population value. One-step
/// debiasing validly removes the penalty bias for that conditional target, so
/// the point estimates are meaningful.
///
/// A *standard error*, however, would only be honest if `t̂` and `â` were
/// fixed/known. They are not: they are **generated regressors** estimated from
/// the very activations that also form the response `Z`. So `Z` enters both the
/// design (via `t̂(Z), â(Z)`) and the response. An influence-function SE built
/// from the β-only Hessian and row scores carries no `∂t̂/∂Z` / `∂â/∂Z`
/// channel. That channel is exactly the generated-regressor correction the
/// marginal-slope family (#461 Stage 2) is *defined* by. The SE therefore omits
/// a first-order variance term and is generally anti-conservative.
///
/// Rather than ship an SE/CI that silently under-covers, this report exposes
/// only the debiased point summaries. A coverage-valid interval would require
/// either freezing the dictionary on a held-out split or propagating the
/// generated-regressor Jacobian. The fixed inner-fit snapshot supports
/// neither.
#[derive(Debug, Clone)]
pub struct AtomFunctionalReport {
    /// `g(t_peak) − g(t_mode)`: the peak-vs-baseline contrast of the fitted
    /// decoder, penalty-debiased through the inner-fit Hessian. Point summary
    /// only (no coverage claim — see the type doc).
    pub peak_contrast: Option<AtomFunctionalEstimate>,
    /// `E_data[g(t_i)]`: the data-averaged decoder value over the atom's active
    /// rows, penalty-debiased. Point summary only.
    pub average_value: Option<AtomFunctionalEstimate>,
    /// `E_data[∂g/∂t]` along the atom's leading latent axis: how much the
    /// fitted decoder curve varies across the data distribution, **conditional
    /// on the fit**. A descriptive variation measure of the fitted curve, NOT a
    /// population "marginal slope" (the latent coordinate is itself a fitted,
    /// generated regressor). Point summary only.
    ///
    /// Despite the historical `_norm` suffix this is the **signed**
    /// mass-weighted mean derivative `E_data[∂g/∂t]` over the single leading
    /// axis, not a magnitude. It can be negative, and a value near 0 means the
    /// average slope cancels (a symmetric bump), not that the curve is flat.
    /// Use [`AtomSmoothSignificance::log_e_nonconstant`] for an honest
    /// non-constancy test; this field only describes the average local slope.
    pub decoder_variation_norm: Option<AtomFunctionalEstimate>,
}
996
/// One atom decoder-functional point summary: the plug-in value, the one-step
/// penalty-debiased value, and the penalty bias that was removed.
///
/// The three fields are related by
/// `theta_onestep = theta_plugin − penalty_bias`.
///
/// Deliberately carries NO standard error / confidence interval. The
/// conditional-on-generated-regressors variance channel is unmodelled, so any
/// SE would under-cover (#1115). Use [`AtomSmoothSignificance`] for an honest
/// any-n-valid structure test instead.
#[derive(Debug, Clone, Copy)]
pub struct AtomFunctionalEstimate {
    /// The raw plug-in functional value `θ̂ = g·β̂`.
    pub theta_plugin: f64,
    /// The one-step penalty-debiased value `θ̂ − bias`, removing the
    /// regularization bias relative to the conditional target.
    pub theta_onestep: f64,
    /// The removed penalty bias `(H⁻¹ g)·(Sβ̂)`.
    pub penalty_bias: f64,
}
1013
/// Any-n-valid structure evidence that one atom's inner smooth `h_k(t)` is
/// genuinely non-constant (#1103).
///
/// It uses the same split-likelihood-ratio e-value as the atom-birth gate
/// ([`gam_terms::inference::structure_evidence`]), under the null
/// H0 = "the atom's decoder curve is constant in its latent coordinate".
///
/// This replaces the earlier Lawley–Bartlett-corrected χ² test, which was a
/// category error here. The penalized smooth's null is effectively rank ≈ n,
/// so the first-order χ² is the wrong reference entirely. An O(1/n) Bartlett
/// factor (whose own stated size shift is ≈0.15%, flipping no admit/demote
/// decision) does not rescue it. The split-LRT e-value is finite-sample valid
/// with NO regularity conditions — exactly the instrument for "does this atom
/// earn a latent dimension".
#[derive(Debug, Clone)]
pub struct AtomSmoothSignificance {
    /// `log E` for "the atom's smooth is non-constant" (null = constant).
    ///
    /// This is a universal-inference split-likelihood-ratio e-value with
    /// `E_{H0}[E] ≤ 1` exactly. So `E ≥ 1/α` (equivalently
    /// `log E ≥ ln(1/α)`) certifies the non-constant alternative at level α,
    /// at any data-dependent stopping time. `None` when the split is
    /// degenerate (too few active rows / a fold with no curvature column).
    pub log_e_nonconstant: Option<f64>,
}
1035
/// The post-PIRLS inference reports for one atom, paired by atom index.
///
/// Two reports survive #1115:
///
/// * the descriptive penalty-debiased point summaries of the fitted decoder
///   curve ([`AtomFunctionalReport`], no coverage claim);
/// * the any-n-valid split-LRT smooth-structure e-value
///   ([`AtomSmoothSignificance`], a genuine finite-sample-valid test).
///
/// The #1099 per-atom curvature *confidence interval* was removed for two
/// reasons. Its target (a sup-norm extrinsic-curvature BOUND read off the
/// fitted decoder) is not an estimand with a profiled criterion. Its
/// delta-method SE also conditioned on the generated latent coordinates as if
/// they were known.
///
/// The plug-in curvature point estimate itself survives as the per-atom
/// `kappa_hat` entries of
/// [`crate::manifold::CertificateInputs::per_atom_kappa_hat`] (the #1008
/// empirical curved-dictionary report, surfaced to Python as
/// `ManifoldSAE.curvature_report`). That is the single source of truth for the
/// bound, and it is deliberately *not* duplicated onto this report: a
/// descriptive geometry bound is a property of the fitted decoder frames, not
/// of the post-PIRLS inner-smooth inference snapshot this type carries.
#[derive(Debug, Clone)]
pub struct AtomInferenceReport {
    /// Position of the atom these reports describe (the pairing key).
    pub atom_index: usize,
    /// Name of the atom (presumably its [`FittedAtom::name`] — confirm
    /// against the producer).
    pub atom_name: String,
    /// Debiased decoder-curve point summaries. `None` when they were not
    /// produced, e.g. the atom carries no [`FittedAtom::inner_fit`].
    pub functionals: Option<AtomFunctionalReport>,
    /// Split-LRT non-constancy e-value. `None` when it was not produced,
    /// e.g. the atom carries no [`FittedAtom::inner_fit`].
    pub smooth_significance: Option<AtomSmoothSignificance>,
}
1060
/// The fitted SAE-manifold model the certificate consumes.
///
/// Self-contained on purpose: it carries exactly the objects the residual-gauge
/// computation needs. These are the atoms (with topology + fitted frames +
/// ARD), the curvature/Jacobian row-blocks that pin directions, and the one
/// [`RowMetric`] whose provenance the report reads.
///
/// The flattened free-parameter vector the generators live in is
/// `vec(frame_0) ⊕ vec(frame_1) ⊕ …` in atom order, and `param_dim()` is its
/// length. Consequently every `jacobian_rows` entry has length
/// `p · param_dim()`, and `isometry_penalty_root` has `param_dim()` columns.
pub struct FittedSaeManifold {
    /// The fitted atoms, in the order their frames appear in the joint
    /// parameter vector.
    pub atoms: Vec<FittedAtom>,
    /// Per-row decoder Jacobian blocks `J_n ∈ ℝ^{p × param_dim}` flattened
    /// row-major (`J_n[i, c] = jacobian_rows[n][i * param_dim + c]`), one entry
    /// per metric row. These are the directions the *data* gives cost to; the
    /// certificate whitens them through [`RowMetric`] and orthonormalizes to
    /// obtain the data part of the pinning span `range(H_data)`.
    pub jacobian_rows: Vec<Vec<f64>>,
    /// The isometry-penalty curvature root `R ∈ ℝ^{r × param_dim}` (so the
    /// penalty Hessian is `RᵀR`). Its row space is `range(H_isometry)` — the
    /// directions the isometry pin gives cost to. Empty (`0 × param_dim`) when
    /// the isometry pin is inactive, which is exactly the condition that
    /// escalates the verdict to `diffeomorphism-unpinned`.
    pub isometry_penalty_root: Array2<f64>,
    /// The single provenance-carrying per-row inner product. It is read for
    /// the report's "computed in metric X" line, and used to whiten the
    /// Jacobian rows so the rank decision happens in the fit's actual metric.
    pub metric: RowMetric,
}
1088
1089impl FittedSaeManifold {
1090 /// Total flattened free-parameter dimension `Σ_k output_dim_k · latent_dim_k`
1091 /// (the decoder-frame coordinates the generators are tangent directions in).
1092 pub fn param_dim(&self) -> usize {
1093 self.atoms.iter().map(|a| a.frame.len()).sum()
1094 }
1095
1096 /// Column offset of atom `k`'s flattened frame inside the joint parameter
1097 /// vector.
1098 fn atom_offset(&self, k: usize) -> usize {
1099 self.atoms[..k].iter().map(|a| a.frame.len()).sum()
1100 }
1101}
1102
/// Which symmetry family a generator belongs to. Carried per-generator so the
/// report names the group the residual freedom (or pin) lives in.
///
/// Each variant has a compact display name (used when counting unpinned
/// families in the group signature), provided by the type's `label` method.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum GeneratorFamily {
    /// A continuous generator of `Isom(M_k)` for a single atom: a frame
    /// rotation or a periodic phase shift realising the atom's own manifold
    /// isometry. Discrete isometries (reflections) are not enumerated as
    /// generators, since they have no tangent direction.
    IsomAtom,
    /// A rotation inside an ARD-equal eigenspace (the ARD prior cannot
    /// distinguish the two axes, so the prior does not pin this rotation).
    EqualArdRotation,
    /// A rotation of the global decoder output frame `O(output_dim)`.
    FrameRotation,
    /// An exchange of two topology-identical atoms (`Sym(F)` permutation, built
    /// as the antisymmetric transposition direction).
    AtomPermutation,
    /// The continuous chart (reparameterization) freedom `Diff(M_k)` of one
    /// `d = 1` atom (arc-length canonicalization) or `d = 2` torus atom
    /// (isometry-flow canonicalization, #1019 stage 2).
    ///
    /// Always reported **pinned** with
    /// [`VerdictProvenance::PinnedByCanonicalization`]. The verdict's
    /// description names the surviving residual group: rotation + reflection
    /// on `S¹`, reflection + translation on the interval, or
    /// `Isom(T², flat) = U(1)² ⋊ D₄` for a `d = 2` torus.
    ChartReparameterization,
}
1128
1129impl GeneratorFamily {
1130 fn label(self) -> &'static str {
1131 match self {
1132 GeneratorFamily::IsomAtom => "Isom(M_k)",
1133 GeneratorFamily::EqualArdRotation => "equal-ARD rotation",
1134 GeneratorFamily::FrameRotation => "frame rotation O(output_dim)",
1135 GeneratorFamily::AtomPermutation => "Sym(F) atom permutation",
1136 GeneratorFamily::ChartReparameterization => "Diff(M_k) chart reparameterization",
1137 }
1138 }
1139}
1140
/// How a generator's pinned/unpinned verdict was decided.
///
/// Carried per-generator so the report can tell two cases apart:
///
/// * a chart fixed **by convention**: the #1019 post-fit chart
///   canonicalization (arc-length for `d = 1` charts, isometry-flow for a
///   `d = 2` torus), an exact, image-frozen representative choice;
/// * a direction pinned **by curvature**: data or the isometry penalty
///   giving the orbit genuine objective cost.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum VerdictProvenance {
    /// Decided by the relative-curvature flatness test against the stacked
    /// pinning root (data + isometry penalty, in the fit's metric) — the
    /// historical path for every enumerated generator.
    CurvatureTest,
    /// Pinned by the post-fit arc-length chart canonicalization (#1019) or the
    /// `d = 2` torus isometry-flow canonicalization (#1019 stage 2).
    ///
    /// The atom's chart is the selected representative of its `Diff(M)` orbit,
    /// so the continuous reparameterization freedom is fixed by construction;
    /// no curvature was (or needed to be) measured. This is deliberately
    /// distinct from penalty-pinning: the certificate must not claim the
    /// objective resists chart motion when it is the canonicalization that
    /// removed it.
    PinnedByCanonicalization,
}
1161
/// Noise floor of the per-generator flatness verdict.
///
/// A generator is certified **unpinned** exactly when its relative curvature
/// fraction `‖R ξ̂‖² / σ_max(R)²` does not exceed the verdict tolerance
/// `max(GENERATOR_FLAT_ENERGY_TOL, lowering_error_scale)`. The fraction is the
/// curvature carried by the unit generator, measured against the stiffest
/// direction of the stacked curvature root `R`.
///
/// Along an exact residual symmetry of the converged objective the fraction is
/// zero up to roundoff. Any genuinely curved component, however partial, means
/// moving along the orbit costs objective, so the exact group element is
/// broken. A *mixed* generator must therefore be reported pinned, never as a
/// surviving freedom. An example is a frame rotation that the anisotropic
/// output-Fisher isometry pin curves only partially (the #980 Theorem-2
/// situation).
///
/// The `lowering_error_scale` arm of the tolerance is the #995 calibration:
/// curvature attributable to the mean-frame compression of a curved decoder
/// must not be read as a pin.
pub const GENERATOR_FLAT_ENERGY_TOL: f64 = 0.001;
1177
/// One enumerated symmetry generator and the certificate's verdict on it.
#[derive(Debug, Clone)]
pub struct GeneratorVerdict {
    /// Which symmetry family this generator realises.
    pub family: GeneratorFamily,
    /// Human-readable description (which atom(s) / axes it acts on).
    pub description: String,
    /// Whether the generator is a genuine residual gauge freedom.
    ///
    /// `true` ⇒ the converged objective is flat along this generator
    /// (`ξ ∈ ker(H)`): a genuine residual gauge freedom the data + isometry
    /// penalty leave unbroken. `false` ⇒ the generator is pinned.
    /// Either the data or the isometry penalty gives it curvature (a
    /// pinned-energy fraction above the verdict tolerance
    /// `max(`[`GENERATOR_FLAT_ENERGY_TOL`]`, lowering_error_scale)`), or it was
    /// pinned by canonicalization (see `provenance`).
    pub unpinned: bool,
    /// `‖ξ‖₂` of the realised tangent direction. `0` ⇒ the generator was
    /// structurally trivial (e.g. a rotation of a rank-deficient frame) and
    /// is reported as pinned/absent, never as a spurious freedom.
    pub generator_norm: f64,
    /// `‖R ξ̂‖² / σ_max(R)²` ∈ [0, 1]: curvature along the unit generator,
    /// relative to the stiffest direction of the stacked curvature root `R`
    /// (data + isometry penalty, in the metric).
    ///
    /// `0` ⇒ exactly flat; `1` ⇒ as stiff as the stiffest direction.
    /// Strictly-interior values are the *mixed* regime: partial curvature that
    /// breaks the exact symmetry (verdict pinned when above the tolerance)
    /// while leaving nearby flat directions. The value is kept visible here
    /// rather than collapsed into the boolean.
    ///
    /// The measure is relative to σ_max rather than based on span membership,
    /// so the statistic stays informative when the pinning span is full-rank,
    /// which production fits always are. Structurally trivial generators
    /// (zero norm) report `1.0`.
    pub pinned_energy_fraction: f64,
    /// The #995 lowering-error arm of this generator's verdict tolerance.
    ///
    /// It is the largest [`FittedAtom::lowering_error`] over the atoms the
    /// generator touches: its own atom for within-atom families, the exchanged
    /// pair for permutations, all atoms for global output-frame rotations.
    /// The verdict is `unpinned ⇔ pinned_energy_fraction ≤
    /// max(GENERATOR_FLAT_ENERGY_TOL, lowering_error_scale)`. Curvature the
    /// mean-frame compression cannot distinguish from gauge motion is never
    /// read as a pin.
    pub lowering_error_scale: f64,
    /// How this verdict was decided: by the curvature flatness test, or pinned
    /// by the #1019 post-fit chart canonicalization (arc-length for `d = 1`,
    /// isometry-flow for a `d = 2` torus). See [`VerdictProvenance`].
    pub provenance: VerdictProvenance,
}
1220
/// The #972 decoder-frame **inner-rotation gauge**, enumerated for the
/// certificate.
///
/// A frame-factored atom `B_k = U_k C_k` is *exactly* invariant under
/// `U_k → U_k R`, `C_k → Rᵀ C_k` for any `R ∈ O(r_k)`. The reconstruction,
/// the likelihood, the penalty — every objective term — sees only the product.
///
/// Unlike the latent-isometry / ARD-rotation / permutation generators, this
/// freedom is therefore **not** a candidate to be pinned by data or penalty
/// curvature: its orbit direction is identically zero in function space.
/// Running it through the pinning-span test would be a category error. It
/// would always come back "unpinned" and pollute the verdict list with
/// freedoms the parameterization already handles.
///
/// The honest certificate treatment is this struct. It *enumerates* the group
/// and its dimension `Σ_k r_k(r_k−1)/2`, and records how the group is fixed:
/// by the canonical orientation gauge ([`crate::manifold::GrassmannFrame`]'s
/// SVD-ordered representative), which picks one point per `O(r_k)` orbit for
/// serialization/comparison stability.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct FrameInnerRotationGauge {
    /// Active frame rank `r_k` per frame-factored atom (atoms on the full-`B`
    /// path contribute no entry).
    pub per_atom_ranks: Vec<usize>,
    /// Total group dimension `Σ_k r_k (r_k − 1) / 2` (`dim O(r) = r(r−1)/2`).
    pub dim: usize,
}
1247
1248impl FrameInnerRotationGauge {
1249 /// Enumerate the gauge from the active frame ranks.
1250 pub fn from_ranks(per_atom_ranks: Vec<usize>) -> Self {
1251 let dim = frame_inner_rotation_dim(&per_atom_ranks);
1252 Self {
1253 per_atom_ranks,
1254 dim,
1255 }
1256 }
1257}
1258
/// Dimension of the #972 inner-rotation gauge group `∏_k O(r_k)` over the
/// active frame ranks: `Σ_k r_k (r_k − 1) / 2`.
///
/// Frames of rank 0 or 1 add nothing. `O(1)` is just a finite sign flip,
/// already absorbed by the orientation gauge. So a dictionary made only of
/// single-direction atoms has a zero-dimensional inner gauge, in line with the
/// intuition that a lone direction has no inner rotation to fix.
pub fn frame_inner_rotation_dim(ranks: &[usize]) -> usize {
    ranks
        .iter()
        .fold(0, |total, &rank| total + rank * rank.saturating_sub(1) / 2)
}
1268
/// The certificate produced by [`residual_gauge`].
#[derive(Debug, Clone)]
pub struct ResidualGaugeReport {
    /// "computed in metric X", read straight off [`RowMetric::provenance`].
    /// The single metric object guarantees this matches the inner product the
    /// fit actually used.
    pub metric_provenance: MetricProvenance,
    /// Per-generator pinned/unpinned verdict, in enumeration order.
    pub generators: Vec<GeneratorVerdict>,
    /// Rank of the pinning span `range(H)` (data + isometry penalty) the
    /// generators were tested against, in the metric.
    pub pinning_rank: usize,
    /// Number of generators certified as unpinned residual gauge freedoms.
    pub residual_gauge_dim: usize,
    /// `true` when the isometry pin is inactive (`isometry_penalty_root` has no
    /// rows). The model is then only identified up to an arbitrary
    /// diffeomorphism of the latent manifolds, and every isometry generator is
    /// reported as a residual freedom. This is the escalation flag.
    pub diffeomorphism_unpinned: bool,
    /// Whether the `Sym(F)` permutation subgroup is trivially pinned under
    /// output-Fisher provenance.
    ///
    /// Under [`MetricProvenance::OutputFisher`] the subgroup is expected to be
    /// *trivially pinned*: the output-Fisher metric distinguishes the atoms
    /// behaviorally, so no atom exchange can be a residual freedom.
    ///
    /// * `Some(true)` ⇒ that triviality holds (every
    ///   [`GeneratorFamily::AtomPermutation`] generator is pinned).
    /// * `Some(false)` ⇒ a permutation survived as a residual freedom, which
    ///   under OutputFisher provenance is a certificate violation the caller
    ///   must surface.
    /// * `None` ⇒ provenance is not `OutputFisher`, so the check does not
    ///   apply.
    pub sym_f_trivial_under_output_fisher: Option<bool>,
    /// The #972 decoder-frame inner-rotation gauge `∏_k O(r_k)`: enumerated,
    /// never curvature-tested (see [`FrameInnerRotationGauge`] for why).
    /// `None` when the caller declared no frame factorization (full-`B`
    /// dictionaries, or a pre-#972 caller using [`residual_gauge`] directly).
    /// Attach via [`ResidualGaugeReport::with_frame_inner_rotation`].
    pub frame_inner_rotation: Option<FrameInnerRotationGauge>,
    /// Human-readable one-line summary.
    /// [`ResidualGaugeReport::with_frame_inner_rotation`] appends to it when
    /// the enumerated inner-rotation gauge has positive dimension.
    pub summary: String,
}
1307
1308impl ResidualGaugeReport {
1309 /// The certified residual gauge group, as a compact string naming the
1310 /// surviving generator families and their multiplicities. Two replicate
1311 /// fits are "identified up to the same group" iff this string is equal.
1312 ///
1313 /// When a frame inner-rotation gauge is enumerated it is appended with its
1314 /// dimension and its `[canonical-fixed]` marker — it is part of the group
1315 /// two replicate fits must agree on, even though it is fixed by
1316 /// convention rather than by curvature.
1317 pub fn group_signature(&self) -> String {
1318 let base = group_signature_of(&self.generators, self.diffeomorphism_unpinned);
1319 match &self.frame_inner_rotation {
1320 Some(gauge) if gauge.dim > 0 => format!(
1321 "{base} ⊕ frame-inner ∏O(r_k)×{} [dim {}, canonical-fixed]",
1322 gauge.per_atom_ranks.len(),
1323 gauge.dim
1324 ),
1325 _ => base,
1326 }
1327 }
1328
1329 /// Attach the #972 frame inner-rotation enumeration to the certificate
1330 /// (consumed by frame-factored dictionaries; `ranks` are the active frame
1331 /// ranks `r_k`, one per factored atom). Extends the summary so the
1332 /// one-line report names the enumerated-but-convention-fixed gauge.
1333 pub fn with_frame_inner_rotation(mut self, ranks: Vec<usize>) -> Self {
1334 let gauge = FrameInnerRotationGauge::from_ranks(ranks);
1335 if gauge.dim > 0 {
1336 self.summary.push_str(&format!(
1337 "; frame inner-rotation gauge ∏O(r_k) of dim {} enumerated \
1338 (exact reparameterization, fixed by the canonical orientation gauge)",
1339 gauge.dim
1340 ));
1341 }
1342 self.frame_inner_rotation = Some(gauge);
1343 self
1344 }
1345}
1346
1347/// Compact, order-independent signature of the unpinned generator families and
1348/// multiplicities. Two replicate fits agree on their residual gauge group iff
1349/// these strings are equal.
1350fn group_signature_of(generators: &[GeneratorVerdict], diffeomorphism_unpinned: bool) -> String {
1351 let mut counts: std::collections::BTreeMap<&'static str, usize> =
1352 std::collections::BTreeMap::new();
1353 for g in generators {
1354 if g.unpinned {
1355 *counts.entry(g.family.label()).or_insert(0) += 1;
1356 }
1357 }
1358 let body = if counts.is_empty() {
1359 "{e} [fully pinned: rigid up to nothing]".to_string()
1360 } else {
1361 counts
1362 .iter()
1363 .map(|(name, mult)| format!("{name}×{mult}"))
1364 .collect::<Vec<_>>()
1365 .join(" ⊕ ")
1366 };
1367 if diffeomorphism_unpinned {
1368 // With the isometry pin inactive the residual gauge is at least the
1369 // manifold reparametrization (diffeomorphism) group modulo whatever the
1370 // data alone still pins — the surviving generators below are the
1371 // isometry slice of that larger freedom.
1372 format!("Diff(M) ⊇ {{ {body} }} [diffeomorphism-unpinned: isometry pin inactive]")
1373 } else {
1374 body
1375 }
1376}
1377
1378/// Build the atom-local isometry generators for one atom as tangent directions
1379/// on the atom's flattened decoder frame.
1380///
1381/// An isometry of the latent manifold acts on the latent coordinate frame; we
1382/// lift it to the decoder output by acting on the frame columns. For a rotation
1383/// generator `A ∈ so(latent_dim)` (antisymmetric), the induced tangent direction
1384/// on `frame ∈ ℝ^{p × d}` is `frame · Aᵀ` (the first-order motion of the frame
1385/// columns under the one-parameter rotation `exp(tA)`), flattened row-major. For
1386/// the circle this is the single `so(2)` generator; for the sphere the three
1387/// `so(3)` generators; for the torus the `d` independent axis shifts (which on
1388/// the flat product manifold are translations of each circle coordinate —
1389/// realised as the unit tangent along each frame column).
1390fn atom_isometry_generators(atom: &FittedAtom) -> Vec<(Array1<f64>, String)> {
1391 let (p, d) = atom.frame.dim();
1392 // The intrinsic latent dimension of the manifold fixes `dim Isom(M_k)` (the
1393 // number of independent isometry generators we must enumerate). The fitted
1394 // decoder frame's column count `d` must realise exactly that many latent
1395 // axes; a frame whose column count disagrees with the topology's intrinsic
1396 // dimension is a structurally inconsistent atom and we refuse to fabricate
1397 // generators for it (returning none, so it cannot masquerade as either
1398 // pinned or a spurious residual freedom in the certificate).
1399 if d != atom.topology.latent_dim() {
1400 return Vec::new();
1401 }
1402 let mut out: Vec<(Array1<f64>, String)> = Vec::new();
1403 match &atom.topology {
1404 AtomTopology::Circle => {
1405 // so(2): A = [[0,-1],[1,0]] on the 1 circle, but a Circle atom has a
1406 // single latent axis whose isometry is a *shift* of the periodic
1407 // coordinate. The first-order motion of the (cos,sin) frame columns
1408 // under a shift is the orthogonal frame column. With latent_dim == 1
1409 // the decoder frame's single column moves along its own
1410 // 90°-rotated image, which (lacking a second column) is realised as
1411 // the tangent that advances the periodic phase: the unit direction
1412 // along the frame column itself (the generator of the U(1) shift).
1413 if d >= 1 {
1414 let mut g = Array1::<f64>::zeros(p * d);
1415 for i in 0..p {
1416 g[i * d] = atom.frame[[i, 0]];
1417 }
1418 out.push((g, format!("{}: S¹ U(1) phase shift", atom.name)));
1419 }
1420 }
1421 AtomTopology::Sphere | AtomTopology::EuclideanPatch { .. } | AtomTopology::Torus { .. } => {
1422 // so(d) rotation generators: one per unordered axis pair (a < b).
1423 // The induced frame motion is frame · A_{ab}ᵀ, i.e. column a picks
1424 // up −column b and column b picks up +column a.
1425 for a in 0..d {
1426 for b in (a + 1)..d {
1427 let mut g = Array1::<f64>::zeros(p * d);
1428 for i in 0..p {
1429 // (frame · Aᵀ)[i, a] = −frame[i, b]; [i, b] = +frame[i, a].
1430 g[i * d + a] = -atom.frame[[i, b]];
1431 g[i * d + b] = atom.frame[[i, a]];
1432 }
1433 out.push((
1434 g,
1435 format!(
1436 "{}: {} rotation axes ({a},{b})",
1437 atom.name,
1438 match &atom.topology {
1439 AtomTopology::Sphere => "S² so(3)",
1440 AtomTopology::Torus { .. } => "Tᵈ frame",
1441 _ => "patch so(d)",
1442 }
1443 ),
1444 ));
1445 }
1446 }
1447 // Torus additionally carries `d` independent circle shifts: the unit
1448 // tangent advancing each axis's periodic phase (translation of that
1449 // circle coordinate), realised as motion along each frame column.
1450 if let AtomTopology::Torus { .. } = atom.topology {
1451 for a in 0..d {
1452 let mut g = Array1::<f64>::zeros(p * d);
1453 for i in 0..p {
1454 g[i * d + a] = atom.frame[[i, a]];
1455 }
1456 out.push((g, format!("{}: Tᵈ circle shift axis {a}", atom.name)));
1457 }
1458 }
1459 }
1460 }
1461 out
1462}
1463
1464/// Build equal-ARD rotation generators for one atom: a rotation between two
1465/// latent axes whose ARD variances are equal (within `rel_tol`) is not pinned by
1466/// the ARD prior, so it is a candidate residual gauge freedom (the data +
1467/// isometry penalty decide). Returns the antisymmetric frame-rotation tangent
1468/// for each such equal pair.
1469fn equal_ard_rotation_generators(atom: &FittedAtom) -> Vec<(Array1<f64>, String)> {
1470 let mut out: Vec<(Array1<f64>, String)> = Vec::new();
1471 let (p, d) = atom.frame.dim();
1472 let Some(ard) = atom.ard_variances.as_ref() else {
1473 return out;
1474 };
1475 if ard.len() != d {
1476 return out;
1477 }
1478 const ARD_EQUAL_REL_TOL: f64 = 1.0e-9;
1479 for a in 0..d {
1480 for b in (a + 1)..d {
1481 let va = ard[a];
1482 let vb = ard[b];
1483 let scale = va.abs().max(vb.abs()).max(f64::MIN_POSITIVE);
1484 if (va - vb).abs() <= ARD_EQUAL_REL_TOL * scale {
1485 let mut g = Array1::<f64>::zeros(p * d);
1486 for i in 0..p {
1487 g[i * d + a] = -atom.frame[[i, b]];
1488 g[i * d + b] = atom.frame[[i, a]];
1489 }
1490 out.push((
1491 g,
1492 format!("{}: equal-ARD rotation axes ({a},{b})", atom.name),
1493 ));
1494 }
1495 }
1496 }
1497 out
1498}
1499
1500/// Build global decoder output-frame rotation generators `O(output_dim)`: a
1501/// rotation `B ∈ so(output_dim)` acts on every atom's frame from the left
1502/// (`B · frame`). The induced tangent on the joint parameter vector stacks
1503/// `B · frame_k` per atom. We enumerate the full `so(output_dim)` basis — one
1504/// generator per unordered output-axis pair `(oi < oj)`, count
1505/// `output_dim·(output_dim−1)/2` — since the per-generator rank test treats each
1506/// independently and we want the certificate to find every output-frame freedom,
1507/// not a subset. `output_dim` is taken as the maximum frame row-count across
1508/// atoms; an atom whose frame lacks one of the two axes contributes nothing to
1509/// that generator.
1510fn frame_rotation_generators(model: &FittedSaeManifold) -> Vec<(Array1<f64>, String)> {
1511 let mut out: Vec<(Array1<f64>, String)> = Vec::new();
1512 let p = model
1513 .atoms
1514 .iter()
1515 .map(|a| a.frame.nrows())
1516 .max()
1517 .unwrap_or(0);
1518 let param_dim = model.param_dim();
1519 for oi in 0..p {
1520 for oj in (oi + 1)..p {
1521 let mut g = Array1::<f64>::zeros(param_dim);
1522 for (k, atom) in model.atoms.iter().enumerate() {
1523 let (ap, ad) = atom.frame.dim();
1524 if oi >= ap || oj >= ap {
1525 continue;
1526 }
1527 let base = model.atom_offset(k);
1528 // (B · frame)[oi, c] = −frame[oj, c]; [oj, c] = +frame[oi, c].
1529 for c in 0..ad {
1530 g[base + oi * ad + c] = -atom.frame[[oj, c]];
1531 g[base + oj * ad + c] = atom.frame[[oi, c]];
1532 }
1533 }
1534 out.push((g, format!("output-frame rotation axes ({oi},{oj})")));
1535 }
1536 }
1537 out
1538}
1539
1540/// Build exchangeable-atom permutation generators: for every pair of atoms with
1541/// identical topology and matching frame shape, the transposition that swaps
1542/// their decoder frames is a candidate `Sym(F)` symmetry. Realised as the
1543/// antisymmetric "swap" tangent `(frame_b − frame_a)` placed on atom a's slot and
1544/// `(frame_a − frame_b)` on atom b's slot — the first-order direction of the
1545/// one-parameter family interpolating the swap.
1546/// Embed an atom-local generator (length = that atom's flattened frame length)
1547/// into the joint parameter vector at the atom's column offset. The per-atom
1548/// generator builders do not know the joint layout; the certificate does, and
1549/// mixing the two coordinate systems is a shape error for every model with more
1550/// than one atom.
1551fn embed_local_generator(offset: usize, local: &Array1<f64>, param_dim: usize) -> Array1<f64> {
1552 let mut g = Array1::<f64>::zeros(param_dim);
1553 g.slice_mut(s![offset..offset + local.len()]).assign(local);
1554 g
1555}
1556
1557fn atom_permutation_generators(
1558 model: &FittedSaeManifold,
1559) -> Vec<(Array1<f64>, String, usize, usize)> {
1560 let mut out: Vec<(Array1<f64>, String, usize, usize)> = Vec::new();
1561 let param_dim = model.param_dim();
1562 for ka in 0..model.atoms.len() {
1563 for kb in (ka + 1)..model.atoms.len() {
1564 let a = &model.atoms[ka];
1565 let b = &model.atoms[kb];
1566 if a.topology != b.topology || a.frame.dim() != b.frame.dim() {
1567 continue;
1568 }
1569 let (ap, ad) = a.frame.dim();
1570 let base_a = model.atom_offset(ka);
1571 let base_b = model.atom_offset(kb);
1572 let mut g = Array1::<f64>::zeros(param_dim);
1573 for i in 0..ap {
1574 for c in 0..ad {
1575 let diff = b.frame[[i, c]] - a.frame[[i, c]];
1576 g[base_a + i * ad + c] = diff;
1577 g[base_b + i * ad + c] = -diff;
1578 }
1579 }
1580 out.push((g, format!("atom-exchange {} ↔ {}", a.name, b.name), ka, kb));
1581 }
1582 }
1583 out
1584}
1585
1586// ============================================================================
1587// #998 — the full-resolution certificate: exact gauge orbits in the model's
1588// own (decoder, coordinate) parameter space.
1589// ============================================================================
1590
1591/// One atom's exact parameter-space view (#998): the raw objects the fit
1592/// actually optimizes, in which the model-class gauge orbits live.
1593///
1594/// The mean-frame certificate ([`FittedAtom::frame`]) is a lossy compression:
1595/// the true gauge orbits are **compensated** motions — the latent coordinates
1596/// move AND the decoder counter-rotates (e.g. `Φ(t+ε)·R(−ε)B = Φ(t)B` for the
1597/// harmonic circle) — whose net action on the mean frame is identically zero,
1598/// so no frame-space realisation can measure them (#995's calibrated tolerance
1599/// is the honest *floor* there). With this view the certificate realises each
1600/// orbit exactly: the coordinate motion field `δt` comes from the group
1601/// action, and the decoder compensation `δB` is **profiled out by least
1602/// squares** against the data motion. The leftover residual is the orbit's
1603/// true data cost — exactly zero when the basis family is closed under the
1604/// action (harmonics under shifts, linear charts under rotations), genuinely
1605/// positive when it is not (a Duchon patch under so(d)). Basis closure is
1606/// therefore a *computed* per-generator quantity, not a declared flag.
1607#[derive(Debug, Clone)]
1608pub struct AtomParameterView {
1609 /// Basis values `Φ`, `(n, M)`.
1610 pub basis_values: Array2<f64>,
1611 /// Basis first-derivative jet `Φ'`, `(n, M, latent_dim)`.
1612 pub basis_jacobian: Array3<f64>,
1613 /// Decoder coefficients `B`, `(M, p)`.
1614 pub decoder: Array2<f64>,
1615 /// Latent coordinates `t`, `(n, latent_dim)` — the chart the group acts on.
1616 pub coords: Array2<f64>,
1617 /// Per-row assignment mass `a_nk`, length `n`.
1618 pub activations: Array1<f64>,
1619 /// Basis second-derivative jet `Φ''`, `(n, M, latent_dim, latent_dim)`.
1620 /// Required only to lower an isometry [`OrbitPenaltyOperator`] for a
1621 /// *pin-active* fit (#998): the penalty is a function of the pullback
1622 /// metric `g_n = J_nᵀ W_n J_n`, and the first-order change of `g_n` under a
1623 /// coordinate motion `δt` differentiates `J_n = Φ'_n B` through `t`, which
1624 /// needs `Φ''`. `None` keeps the data-only orbit verdict (no pin), exactly
1625 /// as before; absence never errors.
1626 pub basis_second_jet: Option<Array4<f64>>,
1627}
1628
1629/// The penalty/prior channel of the exact certificate: an operator returning
1630/// the penalty curvature root's image of an orbit direction `(δB, δt)`,
1631/// together with its stiffness scale `σ_max²`. With exact orbits the data can
1632/// never pin a model-class symmetry (the LS-compensated motion is a data-null
1633/// by construction for closed bases), so **all** pinning of such symmetries
1634/// flows through this channel — exactly where the #981 gauge-reduction ladder
1635/// says identification lives (the isometry pin does the collapsing, rungs 2
1636/// and 4, in whichever metric it is computed). `None` ⇒ no pin installed on
1637/// this atom; the orbit's verdict is then decided by the data residual alone.
1638pub struct OrbitPenaltyOperator {
1639 /// Maps an orbit direction `(δB (M, p), δt (n, latent_dim))` to the
1640 /// penalty curvature root's image (any length); the penalty cost along the
1641 /// direction is the squared norm of the image.
1642 pub apply: Box<dyn Fn(ArrayView2<f64>, ArrayView2<f64>) -> Array1<f64> + Send + Sync>,
1643 /// `σ_max²` of the penalty curvature root — the stiffness scale the
1644 /// orbit's penalty cost is reported relative to (the same
1645 /// relative-curvature convention as the frame certificate).
1646 pub stiffness_sq: f64,
1647}
1648
1649/// Build the isometry-pin [`OrbitPenaltyOperator`] for one viewed atom from its
1650/// second jet (#998 — the orbit-space pin operator the pin-active exact path
1651/// needs).
1652///
1653/// The isometry penalty is `P = ½ μ Σ_n ‖g_n − g_ref‖²_F` with the pullback
1654/// first-fundamental-form gram `g_n = J_nᵀ J_n`, `J_n[i,c] = Σ_m Φ'_n[m,c] B[m,i]`
1655/// (Euclidean metric — the default isometry reference; an output-Fisher metric
1656/// rides the same operator once its factors are threaded, which only re-weights
1657/// the `i`-sum). At a converged isometric fit the residual `g_n − g_ref ≈ 0`, so
1658/// the penalty's curvature along an orbit direction `(δB, δt)` is the
1659/// Gauss-Newton term `μ Σ_n ‖δg_n‖²_F`, and the curvature-root image is
1660/// `√μ · {δg_n[a,b]}` — its squared norm is exactly that cost. The first-order
1661/// gram change
1662///
1663/// `δJ_n[i,c] = Σ_m Φ'_n[m,c] δB[m,i] + Σ_{m,e} Φ''_n[m,c,e] δt_n[e] B[m,i]`
1664/// `δg_n[a,b] = Σ_i ( δJ_n[i,a] J_n[i,b] + J_n[i,a] δJ_n[i,b] )`
1665///
1666/// differentiates `J_n` through `t` via the **second jet** `Φ''` — which is why
1667/// the pin-active path needs it and the frame path (no second jet) could not
1668/// supply it. A model-class symmetry that preserves the metric (e.g. a circle
1669/// phase shift on a closed harmonic basis) yields `δg_n = 0` → the operator
1670/// gives it zero cost → it stays a certified freedom even under the pin; a
1671/// non-isometric orbit (a Duchon/quadratic patch under rotation) yields
1672/// `δg_n ≠ 0` → genuine pinning. The verdict is therefore conservative: the
1673/// operator can only *cost* an orbit, never spuriously free one.
1674///
1675/// `weight` is the penalty strength `μ`. Returns `None` when the view carries no
1676/// second jet (the atom's basis exposes no analytic Hessian): with no orbit-space
1677/// operator the atom's verdict falls back to the data residual, never an error.
1678/// The stiffness `σ_max²` is `μ` times the largest unit-coordinate-motion gram
1679/// curvature `max_n σ_max(∂g_n/∂t)²`, so the reported relative fraction is on the
1680/// same convention as the frame certificate.
1681pub fn isometry_orbit_penalty_operator(
1682 view: &AtomParameterView,
1683 weight: f64,
1684) -> Option<OrbitPenaltyOperator> {
1685 let second = view.basis_second_jet.as_ref()?.clone();
1686 let (n, m) = view.basis_values.dim();
1687 let d = view.coords.ncols();
1688 let p = view.decoder.ncols();
1689 if second.dim() != (n, m, d, d) || view.basis_jacobian.dim() != (n, m, d) {
1690 return None;
1691 }
1692 if !(weight.is_finite() && weight > 0.0) {
1693 return None;
1694 }
1695 let sqrt_w = weight.sqrt();
1696 let jac = view.basis_jacobian.clone();
1697 let decoder = view.decoder.clone();
1698
1699 // Base pullback Jacobian J_n[i,c] = Σ_m Φ'_n[m,c] B[m,i] and its per-row
1700 // first-fundamental gram σ_max scale (stiffness), computed once.
1701 let mut j_base = Array3::<f64>::zeros((n, p, d));
1702 for row in 0..n {
1703 for i in 0..p {
1704 for c in 0..d {
1705 let mut acc = 0.0;
1706 for mm in 0..m {
1707 acc += jac[[row, mm, c]] * decoder[[mm, i]];
1708 }
1709 j_base[[row, i, c]] = acc;
1710 }
1711 }
1712 }
1713
1714 // Stiffness: σ_max over rows of the gram derivative ∂g_n/∂t along a unit
1715 // coordinate motion. ∂g_n/∂t_e [a,b] = Σ_i ( H_n[i,a,e] J_n[i,b]
1716 // + J_n[i,a] H_n[i,b,e] ), H_n[i,c,e] = Σ_m Φ''_n[m,c,e] B[m,i]. The
1717 // stiffest unit δt direction's gram change drives the relative-curvature
1718 // denominator; we take the largest ‖∂g/∂t_e‖_F over axes e and rows as a
1719 // conservative (≤ true σ_max) scale, so the reported fraction never
1720 // under-states the pin.
1721 let mut max_curv_sq = 0.0_f64;
1722 for row in 0..n {
1723 // H_n[i, c, e] = Σ_m Φ''_n[m, c, e] B[m, i].
1724 let mut hn = vec![0.0_f64; p * d * d];
1725 for i in 0..p {
1726 for c in 0..d {
1727 for e in 0..d {
1728 let mut acc = 0.0;
1729 for mm in 0..m {
1730 acc += second[[row, mm, c, e]] * decoder[[mm, i]];
1731 }
1732 hn[(i * d + c) * d + e] = acc;
1733 }
1734 }
1735 }
1736 for e in 0..d {
1737 let mut frob = 0.0_f64;
1738 for a in 0..d {
1739 for b in 0..d {
1740 let mut g = 0.0;
1741 for i in 0..p {
1742 g += hn[(i * d + a) * d + e] * j_base[[row, i, b]];
1743 g += j_base[[row, i, a]] * hn[(i * d + b) * d + e];
1744 }
1745 frob += g * g;
1746 }
1747 }
1748 max_curv_sq = max_curv_sq.max(frob);
1749 }
1750 }
1751 let stiffness_sq = (weight * max_curv_sq).max(f64::MIN_POSITIVE);
1752
1753 let apply = move |delta_b: ArrayView2<f64>, delta_t: ArrayView2<f64>| -> Array1<f64> {
1754 let mut image = Array1::<f64>::zeros(n * d * d);
1755 // δJ_n[i,c] = Σ_m Φ'_n[m,c] δB[m,i] + Σ_{m,e} Φ''_n[m,c,e] δt_n[e] B[m,i].
1756 let valid_b = delta_b.dim() == (m, p);
1757 let valid_t = delta_t.dim() == (n, d);
1758 if !valid_t {
1759 return image;
1760 }
1761 for row in 0..n {
1762 let mut dj = vec![0.0_f64; p * d];
1763 for i in 0..p {
1764 for c in 0..d {
1765 let mut acc = 0.0;
1766 if valid_b {
1767 for mm in 0..m {
1768 acc += jac[[row, mm, c]] * delta_b[[mm, i]];
1769 }
1770 }
1771 for e in 0..d {
1772 let dte = delta_t[[row, e]];
1773 if dte == 0.0 {
1774 continue;
1775 }
1776 for mm in 0..m {
1777 acc += second[[row, mm, c, e]] * dte * decoder[[mm, i]];
1778 }
1779 }
1780 dj[i * d + c] = acc;
1781 }
1782 }
1783 // δg_n[a,b] = Σ_i ( δJ[i,a] J[i,b] + J[i,a] δJ[i,b] ).
1784 for a in 0..d {
1785 for b in 0..d {
1786 let mut dg = 0.0;
1787 for i in 0..p {
1788 dg += dj[i * d + a] * j_base[[row, i, b]];
1789 dg += j_base[[row, i, a]] * dj[i * d + b];
1790 }
1791 image[(row * d + a) * d + b] = sqrt_w * dg;
1792 }
1793 }
1794 }
1795 image
1796 };
1797
1798 Some(OrbitPenaltyOperator {
1799 apply: Box::new(apply),
1800 stiffness_sq,
1801 })
1802}
1803
1804/// Enumerate one atom's exact orbit coordinate-motion fields `δt ∈ ℝ^{n×d}`.
1805///
1806/// Supported charts are the ones the group acts on **linearly** (so the
1807/// first-order field is exact, not a linearisation): circle/torus axis shifts
1808/// (`δt = e_ax`, chart-free) and flat-patch `so(d)` rotations
1809/// (`δt_n = A_{ab} t_n`). The sphere's `so(3)` action on an intrinsic chart is
1810/// nonlinear, so sphere atoms stay on the frame path (the caller must not
1811/// build a view for them). Equal-ARD rotations reuse the rotation field for
1812/// the tied axis pairs (the ARD prior is their pinning channel).
1813fn exact_orbit_fields(
1814 atom: &FittedAtom,
1815 view: &AtomParameterView,
1816) -> Vec<(GeneratorFamily, Array2<f64>, String)> {
1817 let n = view.coords.nrows();
1818 let d = view.coords.ncols();
1819 let mut out: Vec<(GeneratorFamily, Array2<f64>, String)> = Vec::new();
1820 let rotation_field = |a: usize, b: usize| -> Array2<f64> {
1821 let mut dt = Array2::<f64>::zeros((n, d));
1822 for row in 0..n {
1823 dt[[row, a]] = -view.coords[[row, b]];
1824 dt[[row, b]] = view.coords[[row, a]];
1825 }
1826 dt
1827 };
1828 match &atom.topology {
1829 AtomTopology::Circle => {
1830 out.push((
1831 GeneratorFamily::IsomAtom,
1832 Array2::<f64>::ones((n, 1)),
1833 format!("{}: S¹ U(1) phase shift [exact orbit]", atom.name),
1834 ));
1835 }
1836 AtomTopology::Torus { .. } => {
1837 for ax in 0..d {
1838 let mut dt = Array2::<f64>::zeros((n, d));
1839 dt.column_mut(ax).fill(1.0);
1840 out.push((
1841 GeneratorFamily::IsomAtom,
1842 dt,
1843 format!("{}: Tᵈ circle shift axis {ax} [exact orbit]", atom.name),
1844 ));
1845 }
1846 }
1847 AtomTopology::EuclideanPatch { .. } => {
1848 for a in 0..d {
1849 for b in (a + 1)..d {
1850 out.push((
1851 GeneratorFamily::IsomAtom,
1852 rotation_field(a, b),
1853 format!(
1854 "{}: patch so(d) rotation axes ({a},{b}) [exact orbit]",
1855 atom.name
1856 ),
1857 ));
1858 }
1859 }
1860 }
1861 AtomTopology::Sphere => {}
1862 }
1863 // Equal-ARD rotations between tied axes, on linearly-acting charts only.
1864 if !matches!(atom.topology, AtomTopology::Circle | AtomTopology::Sphere) {
1865 if let Some(ard) = atom.ard_variances.as_ref() {
1866 if ard.len() == d {
1867 const ARD_EQUAL_REL_TOL: f64 = 1.0e-9;
1868 for a in 0..d {
1869 for b in (a + 1)..d {
1870 let scale = ard[a].abs().max(ard[b].abs()).max(f64::MIN_POSITIVE);
1871 if (ard[a] - ard[b]).abs() <= ARD_EQUAL_REL_TOL * scale {
1872 out.push((
1873 GeneratorFamily::EqualArdRotation,
1874 rotation_field(a, b),
1875 format!(
1876 "{}: equal-ARD rotation axes ({a},{b}) [exact orbit]",
1877 atom.name
1878 ),
1879 ));
1880 }
1881 }
1882 }
1883 }
1884 }
1885 }
1886 out
1887}
1888
1889/// Exact-orbit verdicts for one viewed atom (#998).
1890///
1891/// For each orbit field `δt`: the uncompensated data motion is
1892/// `u_n = a_n · (Φ'_n B) δt_n ∈ ℝ^p`; the decoder compensation `δB` minimizing
1893/// `Σ_n ‖a_n Φ_n δB + u_n‖²` is profiled out through one shared SVD
1894/// pseudo-inverse of the activation-weighted basis `D = diag(a) Φ`; and the
1895/// **compensation residual fraction** `r²/‖u‖²` is the orbit's true relative
1896/// data cost — exactly 0 for a basis closed under the group action, genuinely
1897/// positive otherwise (computed closure). The penalty channel, when installed,
1898/// contributes `‖penalty_root(δB, δt)‖² / σ_max²` on the same
1899/// relative-curvature convention. The verdict needs **no lowering-error
1900/// calibration** (`lowering_error_scale = 0`): nothing here is compressed.
1901///
1902/// The data likelihood this measures against is the activation-reconstruction
1903/// objective in its own (Euclidean) inner product — which per the amended #980
1904/// dispatch rule is the only thing that ever whitens the likelihood unless a
1905/// `WhitenedStructured` noise model is installed; the output-Fisher metric
1906/// reaches gauge verdicts only through the penalty operator.
1907fn exact_orbit_verdicts(
1908 atom: &FittedAtom,
1909 view: &AtomParameterView,
1910 penalty: Option<&OrbitPenaltyOperator>,
1911) -> Result<Vec<GeneratorVerdict>, String> {
1912 let (n, m) = view.basis_values.dim();
1913 let d = view.coords.ncols();
1914 let p = view.decoder.ncols();
1915 if view.basis_jacobian.dim() != (n, m, d) {
1916 return Err(format!(
1917 "exact_orbit_verdicts({}): basis_jacobian shape {:?} must be ({n}, {m}, {d})",
1918 atom.name,
1919 view.basis_jacobian.dim()
1920 ));
1921 }
1922 if view.decoder.nrows() != m {
1923 return Err(format!(
1924 "exact_orbit_verdicts({}): decoder has {} rows but basis has {m} columns",
1925 atom.name,
1926 view.decoder.nrows()
1927 ));
1928 }
1929 if view.coords.nrows() != n || view.activations.len() != n {
1930 return Err(format!(
1931 "exact_orbit_verdicts({}): coords/activations rows must match basis rows {n}",
1932 atom.name
1933 ));
1934 }
1935
1936 let fields = exact_orbit_fields(atom, view);
1937 if fields.is_empty() {
1938 return Ok(Vec::new());
1939 }
1940
1941 // Shared compensation operator: thin SVD of D = diag(a)·Φ, computed once.
1942 let mut design = Array2::<f64>::zeros((n, m));
1943 for row in 0..n {
1944 let a = view.activations[row];
1945 for c in 0..m {
1946 design[[row, c]] = a * view.basis_values[[row, c]];
1947 }
1948 }
1949 let (u_opt, sigma, vt_opt) = design
1950 .svd(true, true)
1951 .map_err(|e| format!("exact_orbit_verdicts({}): SVD of D failed: {e}", atom.name))?;
1952 let u_svd =
1953 u_opt.ok_or_else(|| format!("exact_orbit_verdicts({}): SVD lacked U", atom.name))?;
1954 let vt = vt_opt.ok_or_else(|| format!("exact_orbit_verdicts({}): SVD lacked Vᵀ", atom.name))?;
1955 let smax = sigma.iter().cloned().fold(0.0_f64, f64::max);
1956 let cutoff = smax * f64::EPSILON * (n.max(m) as f64);
1957
1958 let mut out: Vec<GeneratorVerdict> = Vec::with_capacity(fields.len());
1959 for (family, dt, description) in fields {
1960 // Uncompensated data motion u_n = a_n (Φ'_n B) δt_n.
1961 let mut u_mot = Array2::<f64>::zeros((n, p));
1962 for row in 0..n {
1963 let a = view.activations[row];
1964 if !(a != 0.0) {
1965 continue;
1966 }
1967 for ax in 0..d {
1968 let step = dt[[row, ax]];
1969 if step == 0.0 {
1970 continue;
1971 }
1972 for bm in 0..m {
1973 let dphi = view.basis_jacobian[[row, bm, ax]];
1974 if dphi == 0.0 {
1975 continue;
1976 }
1977 let w = a * step * dphi;
1978 for j in 0..p {
1979 u_mot[[row, j]] += w * view.decoder[[bm, j]];
1980 }
1981 }
1982 }
1983 }
1984 let raw: f64 = u_mot.iter().map(|v| v * v).sum();
1985 if raw <= f64::MIN_POSITIVE {
1986 // The orbit does not move the fit at all (zero tangents / zero
1987 // mass): structurally trivial, reported pinned with zero norm,
1988 // mirroring the frame certificate's convention.
1989 out.push(GeneratorVerdict {
1990 family,
1991 description,
1992 unpinned: false,
1993 generator_norm: 0.0,
1994 pinned_energy_fraction: 1.0,
1995 lowering_error_scale: 0.0,
1996 provenance: VerdictProvenance::CurvatureTest,
1997 });
1998 continue;
1999 }
2000 // Profile out the decoder compensation: c = Uᵀu, keep σ > cutoff.
2001 // Residual cost r² = ‖u‖² − ‖c_kept‖² (Pythagoras on the projection).
2002 let coeffs = u_svd.t().dot(&u_mot);
2003 let mut kept_sq = 0.0_f64;
2004 let mut scaled = Array2::<f64>::zeros((sigma.len(), p));
2005 for r in 0..sigma.len() {
2006 if sigma[r] > cutoff {
2007 let inv = 1.0 / sigma[r];
2008 for j in 0..p {
2009 kept_sq += coeffs[[r, j]] * coeffs[[r, j]];
2010 scaled[[r, j]] = -inv * coeffs[[r, j]];
2011 }
2012 }
2013 }
2014 let resid_sq = (raw - kept_sq).max(0.0);
2015 let data_fraction = (resid_sq / raw).clamp(0.0, 1.0);
2016
2017 let penalty_fraction = match penalty {
2018 Some(op) if op.stiffness_sq > f64::MIN_POSITIVE => {
2019 let delta_b = vt.t().dot(&scaled); // δB = −V Σ⁺ Uᵀ u, (M, p)
2020 let image = (op.apply)(delta_b.view(), dt.view());
2021 let cost: f64 = image.iter().map(|v| v * v).sum();
2022 (cost / op.stiffness_sq).clamp(0.0, 1.0)
2023 }
2024 _ => 0.0,
2025 };
2026
2027 let pinned_energy_fraction = data_fraction.max(penalty_fraction);
2028 out.push(GeneratorVerdict {
2029 family,
2030 description,
2031 unpinned: pinned_energy_fraction <= GENERATOR_FLAT_ENERGY_TOL,
2032 generator_norm: raw.sqrt(),
2033 pinned_energy_fraction,
2034 lowering_error_scale: 0.0,
2035 provenance: VerdictProvenance::CurvatureTest,
2036 });
2037 }
2038 Ok(out)
2039}
2040
2041/// The stacked curvature root `R` of the pinning operator, in the fit's
2042/// metric: `(m, param_dim)` with `H = H_data + H_isometry = RᵀR`.
2043///
2044/// We assemble `R = [ W^{½} J ; R_isom ]` whose row space is
2045/// `range(H_data) + range(H_isometry)`, where `W^{½} J` is the metric-whitened
2046/// decoder Jacobian (the metric whitening is the `RowMetric`'s
2047/// `whiten_residual_row` applied to each output residual basis vector — i.e.
2048/// each Jacobian row is whitened in the same inner product the likelihood
2049/// sums). The caller derives both faces from this one object: the pinning
2050/// RANK (RRQR on `Rᵀ`, the audit's leverage-scaled rank decision) and the
2051/// per-generator relative curvature `‖R ξ̂‖² / σ_max(R)²` — magnitudes kept,
2052/// not orthonormalized away, so the statistic survives a full-rank span.
2053fn stacked_curvature_root(model: &FittedSaeManifold) -> Result<Array2<f64>, String> {
2054 let param_dim = model.param_dim();
2055 if param_dim == 0 {
2056 return Ok(Array2::<f64>::zeros((0, 0)));
2057 }
2058 let p = model.metric.p_out();
2059 // Metric-whitened Jacobian rows: each row's Jacobian J_n ∈ ℝ^{p × param_dim}
2060 // is whitened to U_nᵀ J_n ∈ ℝ^{rank × param_dim} so that the resulting rows
2061 // span the same directions the metric-whitened residual gives cost to. We
2062 // build the stacked matrix `R` with one block of whitened rows per metric
2063 // row, then the isometry-penalty root beneath it.
2064 let mut stacked_rows: Vec<Array1<f64>> = Vec::new();
2065 for (n, j_flat) in model.jacobian_rows.iter().enumerate() {
2066 if j_flat.len() != p * param_dim {
2067 return Err(format!(
2068 "stacked_curvature_root: jacobian_rows[{n}] has len {} but expected p*param_dim = {}*{} = {}",
2069 j_flat.len(),
2070 p,
2071 param_dim,
2072 p * param_dim
2073 ));
2074 }
2075 // Whiten each parameter column's p-vector of output sensitivities.
2076 // Column c of J_n is the p-vector (j_flat[i*param_dim + c])_i. Whitening
2077 // it through the metric row (U_nᵀ ·) maps each column to a
2078 // `whit_len`-vector; the resulting `whit_len × param_dim` block's rows
2079 // are the metric-whitened Jacobian rows whose span the data gives cost
2080 // to. For Euclidean provenance `whiten_residual_row` is the identity, so
2081 // `whit_len == p` and the block is J_n unchanged (bit-for-bit the
2082 // isotropic data span).
2083 let mut cols_whitened: Vec<Vec<f64>> = Vec::with_capacity(param_dim);
2084 for c in 0..param_dim {
2085 let mut col = vec![0.0_f64; p];
2086 for i in 0..p {
2087 col[i] = j_flat[i * param_dim + c];
2088 }
2089 cols_whitened.push(model.metric.whiten_residual_row(n, ArrayView1::from(&col)));
2090 }
2091 let whit_len = cols_whitened.first().map_or(0, |c| c.len());
2092 for r in 0..whit_len {
2093 let mut row = Array1::<f64>::zeros(param_dim);
2094 for (c, col) in cols_whitened.iter().enumerate() {
2095 row[c] = col[r];
2096 }
2097 stacked_rows.push(row);
2098 }
2099 }
2100 // Append isometry-penalty root rows.
2101 if model.isometry_penalty_root.ncols() != 0 {
2102 if model.isometry_penalty_root.ncols() != param_dim {
2103 return Err(format!(
2104 "stacked_curvature_root: isometry_penalty_root has {} cols but param_dim = {param_dim}",
2105 model.isometry_penalty_root.ncols()
2106 ));
2107 }
2108 for r in 0..model.isometry_penalty_root.nrows() {
2109 stacked_rows.push(model.isometry_penalty_root.row(r).to_owned());
2110 }
2111 }
2112 if stacked_rows.is_empty() {
2113 return Ok(Array2::<f64>::zeros((0, param_dim)));
2114 }
2115 let m = stacked_rows.len();
2116 let mut r_mat = Array2::<f64>::zeros((m, param_dim));
2117 for (i, row) in stacked_rows.iter().enumerate() {
2118 r_mat.row_mut(i).assign(row);
2119 }
2120 Ok(r_mat)
2121}
2122
2123enum CurvatureReduction {
2124 Root {
2125 pinning_rank: usize,
2126 sigma_max_sq: f64,
2127 root: Array2<f64>,
2128 },
2129 Gram {
2130 pinning_rank: usize,
2131 sigma_max_sq: f64,
2132 gram: Array2<f64>,
2133 },
2134}
2135
2136impl CurvatureReduction {
2137 fn from_model(model: &FittedSaeManifold) -> Result<Self, String> {
2138 let root = stacked_curvature_root(model)?;
2139 if root.nrows() == 0 {
2140 return Ok(Self::Root {
2141 pinning_rank: 0,
2142 sigma_max_sq: 0.0,
2143 root,
2144 });
2145 }
2146 let r_t = root.t().to_owned();
2147 let rrqr = rrqr_with_permutation(&r_t, default_rrqr_rank_alpha())
2148 .map_err(|e| format!("residual_gauge: RRQR on Rᵀ failed: {e:?}"))?;
2149 let (_u, sv, _vt) = root
2150 .svd(false, false)
2151 .map_err(|e| format!("residual_gauge: SVD of curvature root failed: {e}"))?;
2152 let smax = sv.iter().cloned().fold(0.0_f64, f64::max);
2153 Ok(Self::Root {
2154 pinning_rank: rrqr.rank,
2155 sigma_max_sq: smax * smax,
2156 root,
2157 })
2158 }
2159
2160 fn from_gram(gram: Array2<f64>, root_rows: usize, param_dim: usize) -> Result<Self, String> {
2161 if gram.nrows() != param_dim || gram.ncols() != param_dim {
2162 return Err(format!(
2163 "residual_gauge: curvature gram has shape ({}, {}) but param_dim = {param_dim}",
2164 gram.nrows(),
2165 gram.ncols()
2166 ));
2167 }
2168 if param_dim == 0 || root_rows == 0 {
2169 return Ok(Self::Gram {
2170 pinning_rank: 0,
2171 sigma_max_sq: 0.0,
2172 gram,
2173 });
2174 }
2175 let (evals, _) = gram.eigh(Side::Lower).map_err(|e| {
2176 format!("residual_gauge: eigendecomposition of curvature gram failed: {e}")
2177 })?;
2178 let sigma_max_sq = evals.iter().cloned().fold(0.0_f64, f64::max).max(0.0);
2179 let sigma_max = sigma_max_sq.sqrt();
2180 let rank_tol = default_rrqr_rank_alpha()
2181 * f64::EPSILON
2182 * (root_rows.max(param_dim).max(1) as f64)
2183 * sigma_max.max(1.0);
2184 let lambda_tol = rank_tol * rank_tol;
2185 let pinning_rank = evals
2186 .iter()
2187 .filter(|&&lambda| lambda.max(0.0) > lambda_tol)
2188 .count();
2189 Ok(Self::Gram {
2190 pinning_rank,
2191 sigma_max_sq,
2192 gram,
2193 })
2194 }
2195
2196 fn pinning_rank(&self) -> usize {
2197 match self {
2198 Self::Root { pinning_rank, .. } | Self::Gram { pinning_rank, .. } => *pinning_rank,
2199 }
2200 }
2201
2202 fn sigma_max_sq(&self) -> f64 {
2203 match self {
2204 Self::Root { sigma_max_sq, .. } | Self::Gram { sigma_max_sq, .. } => *sigma_max_sq,
2205 }
2206 }
2207
2208 fn unit_generator_energy(&self, unit: &Array1<f64>) -> f64 {
2209 match self {
2210 Self::Root { root, .. } => {
2211 let r_xi = root.dot(unit);
2212 r_xi.iter().map(|c| c * c).sum::<f64>()
2213 }
2214 Self::Gram { gram, .. } => {
2215 let h_xi = gram.dot(unit);
2216 unit.dot(&h_xi).max(0.0)
2217 }
2218 }
2219 }
2220}
2221
2222/// Evaluate the identifiability rank machinery on the symmetry generators of a
2223/// fitted SAE-manifold model and certify which gauge group the fit is identified
2224/// up to.
2225///
2226/// # Method
2227///
2228/// 1. Enumerate the symmetry generators as tangent directions on the flattened
2229/// decoder frames: per-atom `Isom(M_k)` generators
2230/// ([`atom_isometry_generators`]), equal-ARD rotations
2231/// ([`equal_ard_rotation_generators`]), global output-frame rotations
2232/// ([`frame_rotation_generators`]), and exchangeable-atom permutations
2233/// ([`atom_permutation_generators`]).
2234/// 2. Build the stacked curvature root `R` of the pinning operator
2235/// `H = H_data + H_isometry = RᵀR` in the fit's [`RowMetric`]
2236/// ([`stacked_curvature_root`]); the pinning RANK is the audit's RRQR rank
2237/// of `R`, reported alongside.
2238/// 3. For each generator `ξ`, the **relative curvature fraction**
2239/// `‖R ξ̂‖² / σ_max(R)²` measures the curvature the converged objective has
2240/// along the unit generator, relative to the model's stiffest direction.
2241/// `ξ` is **unpinned** (a residual gauge freedom) iff that fraction is at
2242/// or below the calibrated tolerance
2243/// `max(`[`GENERATOR_FLAT_ENERGY_TOL`]`, lowering_error_scale)` — flat up
2244/// to numerical noise and the mean-frame lowering's own resolution
2245/// ([`FittedAtom::lowering_error`], #995). Any larger fraction — including
2246/// the *mixed* regime where `ξ` carries both a curved and a flat component
2247/// — means the orbit costs objective, the exact group element is broken,
2248/// and the generator is **pinned**. (A span-membership or rank-increase
2249/// test degenerates when `R` is full-rank, which production fits always
2250/// are: every direction is "in the span", so verdicts would collapse to
2251/// all-pinned regardless of magnitudes. Keeping the curvature magnitudes
2252/// is what lets a genuinely flat direction stay visible inside a full-rank
2253/// span.) The fraction and the calibration scale are reported per
2254/// generator so partial flatness stays visible.
2255///
2256/// # Escalations
2257///
2258/// * When the isometry pin is inactive (`isometry_penalty_root` has no rows) the
2259/// report sets `diffeomorphism_unpinned = true`: with no metric pin the model
2260/// is only identified up to an arbitrary diffeomorphism of the latent
2261/// manifolds, so every isometry generator is a residual freedom.
2262/// * Under [`MetricProvenance::OutputFisher`] the `Sym(F)` permutation subgroup
2263/// is checked for triviality: every atom-exchange generator must be pinned
2264/// (the output-Fisher metric separates the atoms behaviorally). The result is
2265/// carried in `sym_f_trivial_under_output_fisher`.
2266pub fn residual_gauge(model: &FittedSaeManifold) -> Result<ResidualGaugeReport, String> {
2267 residual_gauge_inner(model, None, None)
2268}
2269
2270/// The #998 full-resolution certificate: within-atom gauge families are
2271/// realised as **exact orbits** in the model's own (decoder, coordinate)
2272/// parameter space for every atom that supplies an [`AtomParameterView`],
2273/// while cross-atom families (output-frame rotations, atom permutations) and
2274/// any unviewed atom (e.g. spheres, whose chart action is nonlinear) keep the
2275/// frame-space path with its #995 lowering-error calibration.
2276///
2277/// For a viewed atom the compensated orbit is a data-null **by construction**
2278/// when the basis family is closed under the group action — the verdict
2279/// carries no calibration (`lowering_error_scale = 0`), the compensation
2280/// residual is the computed closure, and all pinning of true model-class
2281/// symmetries flows through the per-atom [`OrbitPenaltyOperator`] channel
2282/// (the isometry pin / ARD prior — rungs 2 and 4 of the #981 ladder).
2283///
2284/// `views` and `penalty_ops` are aligned with `model.atoms`; a `None` view
2285/// keeps that atom entirely on the frame path. Supplying a view for an atom
2286/// whose pin is active without also supplying its penalty operator would
2287/// over-claim freedom, so callers must pass the operator (or no view) for
2288/// pinned atoms.
2289pub fn residual_gauge_exact(
2290 model: &FittedSaeManifold,
2291 views: &[Option<AtomParameterView>],
2292 penalty_ops: &[Option<OrbitPenaltyOperator>],
2293) -> Result<ResidualGaugeReport, String> {
2294 let exact = residual_gauge_exact_inputs(model, views, penalty_ops)?;
2295 residual_gauge_inner(model, Some(exact), None)
2296}
2297
2298/// Exact-orbit residual-gauge certificate with a pre-reduced streamed curvature
2299/// Gram `RᵀR`.
2300///
2301/// This is the memory-scaled entry point for callers that can stream their
2302/// metric-whitened Jacobian rows into the reductions the certificate consumes,
2303/// instead of retaining every per-row `p × param_dim` Jacobian block. The Gram
2304/// must include the same rows [`stacked_curvature_root`] would have placed in
2305/// `R`; `root_rows` is that row count for the rank tolerance scale.
2306pub fn residual_gauge_exact_from_curvature_gram(
2307 model: &FittedSaeManifold,
2308 views: &[Option<AtomParameterView>],
2309 penalty_ops: &[Option<OrbitPenaltyOperator>],
2310 curvature_gram: Array2<f64>,
2311 root_rows: usize,
2312) -> Result<ResidualGaugeReport, String> {
2313 let param_dim = model.param_dim();
2314 let curvature = CurvatureReduction::from_gram(curvature_gram, root_rows, param_dim)?;
2315 let exact = residual_gauge_exact_inputs(model, views, penalty_ops)?;
2316 residual_gauge_inner(model, Some(exact), Some(curvature))
2317}
2318
2319fn residual_gauge_exact_inputs(
2320 model: &FittedSaeManifold,
2321 views: &[Option<AtomParameterView>],
2322 penalty_ops: &[Option<OrbitPenaltyOperator>],
2323) -> Result<(Vec<bool>, Vec<GeneratorVerdict>), String> {
2324 if views.len() != model.atoms.len() || penalty_ops.len() != model.atoms.len() {
2325 return Err(format!(
2326 "residual_gauge_exact: views ({}) and penalty_ops ({}) must align with atoms ({})",
2327 views.len(),
2328 penalty_ops.len(),
2329 model.atoms.len()
2330 ));
2331 }
2332 let mut mask = vec![false; model.atoms.len()];
2333 let mut exact_verdicts: Vec<GeneratorVerdict> = Vec::new();
2334 for (k, (atom, view)) in model.atoms.iter().zip(views.iter()).enumerate() {
2335 let Some(view) = view else { continue };
2336 // Sphere charts: nonlinear group action — refuse exactness, keep the
2337 // calibrated frame path for this atom rather than pretending.
2338 if matches!(atom.topology, AtomTopology::Sphere) {
2339 continue;
2340 }
2341 exact_verdicts.extend(exact_orbit_verdicts(atom, view, penalty_ops[k].as_ref())?);
2342 mask[k] = true;
2343 }
2344 Ok((mask, exact_verdicts))
2345}
2346
/// Shared core of [`residual_gauge`], [`residual_gauge_exact`], and
/// [`residual_gauge_exact_from_curvature_gram`].
///
/// # Parameters
///
/// * `exact` — `Some((mask, verdicts))` as produced by
///   `residual_gauge_exact_inputs`. For atoms with `mask[k] == true`, the
///   frame-space within-atom generators (`IsomAtom`, `EqualArdRotation`) are
///   not enumerated; the supplied exact-orbit `verdicts` are appended instead.
///   `None` keeps every atom on the frame-space path. The mask is indexed by
///   atom, so it must have one entry per element of `model.atoms`.
/// * `precomputed_curvature` — a caller-built curvature reduction, such as one
///   from a streamed Gram. `None` builds it via
///   `CurvatureReduction::from_model`.
///
/// # Verdict order
///
/// The report's `generators` are pushed in this order:
/// 1. frame-path curvature-tested generators: per atom, isometry generators
///    then equal-ARD rotations; then output-frame rotations; then atom
///    permutations;
/// 2. any exact-orbit verdicts;
/// 3. one `ChartReparameterization` record per canonicalized atom.
///
/// # Errors
///
/// Propagates the error from `CurvatureReduction::from_model` when no
/// precomputed reduction is supplied. No other step in this function returns
/// an error.
fn residual_gauge_inner(
    model: &FittedSaeManifold,
    exact: Option<(Vec<bool>, Vec<GeneratorVerdict>)>,
    precomputed_curvature: Option<CurvatureReduction>,
) -> Result<ResidualGaugeReport, String> {
    let metric_provenance = model.metric.provenance();
    let param_dim = model.param_dim();
    let (exact_mask, exact_verdicts) = match exact {
        Some((mask, verdicts)) => (Some(mask), verdicts),
        None => (None, Vec::new()),
    };

    // 1. Enumerate generators, tagged by family. The per-atom builders speak
    // the atom's LOCAL flattened-frame coordinates (length `frame.len()`); the
    // certificate's rank arithmetic runs in the joint parameter vector, so each
    // local generator is embedded at its atom's offset here. (Single-atom
    // models have local == joint, which is why only multi-atom models can
    // expose a missed embedding.)
    // Each generator carries its #995 lowering-error tolerance scale: the
    // largest `lowering_error` over the atoms it touches.
    // The scale is clamped into [0, 1], so it can never loosen the tolerance
    // beyond "everything flat". `f64::max` in the fold drops a NaN operand.
    let scale_of = |k: usize| -> f64 { model.atoms[k].lowering_error.clamp(0.0, 1.0) };
    let global_scale = (0..model.atoms.len()).map(scale_of).fold(0.0_f64, f64::max);
    let mut gens: Vec<(GeneratorFamily, Array1<f64>, String, f64)> = Vec::new();
    for (k, atom) in model.atoms.iter().enumerate() {
        // Atoms whose within-atom families are realised exactly (#998) are
        // skipped here: the frame-space lift of a compensated orbit measures
        // compression, not the symmetry, and the report must not carry both a
        // lossy and an exact verdict for the same group element.
        if exact_mask.as_ref().is_some_and(|mask| mask[k]) {
            continue;
        }
        let base = model.atom_offset(k);
        for (g, desc) in atom_isometry_generators(atom) {
            gens.push((
                GeneratorFamily::IsomAtom,
                embed_local_generator(base, &g, param_dim),
                desc,
                scale_of(k),
            ));
        }
        for (g, desc) in equal_ard_rotation_generators(atom) {
            gens.push((
                GeneratorFamily::EqualArdRotation,
                embed_local_generator(base, &g, param_dim),
                desc,
                scale_of(k),
            ));
        }
    }
    // Cross-atom families below are pushed without an embedding step:
    // `frame_rotation_generators` and `atom_permutation_generators` receive the
    // whole model. They presumably already speak joint coordinates — confirm
    // against their definitions.
    for (g, desc) in frame_rotation_generators(model) {
        // A global output rotation moves every atom's frame at once.
        gens.push((GeneratorFamily::FrameRotation, g, desc, global_scale));
    }
    for (g, desc, ka, kb) in atom_permutation_generators(model) {
        gens.push((
            GeneratorFamily::AtomPermutation,
            g,
            desc,
            scale_of(ka).max(scale_of(kb)),
        ));
    }

    // 2. Stacked curvature root in the metric; pinning rank via the audit's
    // RRQR on Rᵀ, stiffness scale σ_max via SVD (magnitudes kept).
    // This is the only fallible step in the function (see `# Errors`).
    let curvature = match precomputed_curvature {
        Some(curvature) => curvature,
        None => CurvatureReduction::from_model(model)?,
    };
    let pinning_rank = curvature.pinning_rank();
    let sigma_max_sq = curvature.sigma_max_sq();

    // The isometry pin is inactive ⇒ diffeomorphism-unpinned escalation.
    // The flag feeds the group signature and summary only; it does not
    // change any per-generator verdict below.
    let diffeomorphism_unpinned = model.isometry_penalty_root.nrows() == 0;

    // 3. Per-generator flatness verdict: relative curvature vs the calibrated
    // tolerance.
    let mut verdicts: Vec<GeneratorVerdict> = Vec::with_capacity(gens.len());
    for (family, g, description, lowering_error_scale) in &gens {
        let norm = g.iter().map(|v| v * v).sum::<f64>().sqrt();
        // A structurally trivial generator (rotation of a rank-deficient frame,
        // zero swap) carries no direction — it cannot be a residual freedom.
        // Report it pinned with zero norm rather than as a spurious gauge.
        if norm <= f64::MIN_POSITIVE {
            verdicts.push(GeneratorVerdict {
                family: *family,
                description: description.clone(),
                unpinned: false,
                generator_norm: 0.0,
                pinned_energy_fraction: 1.0,
                lowering_error_scale: *lowering_error_scale,
                provenance: VerdictProvenance::CurvatureTest,
            });
            continue;
        }
        // Relative curvature fraction ‖R ξ̂‖² / σ_max(R)² of the unit
        // generator ξ̂ = ξ/‖ξ‖. Exactly flat directions score 0 even inside a
        // full-rank span (production fits!), where the previous
        // span-membership rule degenerated to all-pinned. A MIXED generator
        // (strictly interior fraction) above the tolerance is pinned: its
        // orbit costs objective, so the exact symmetry does not survive
        // (#980 Theorem-2 arm). The tolerance is calibrated by the #995
        // lowering-error scale: curvature the mean-frame compression cannot
        // distinguish from gauge motion must not be read as a pin — the
        // certificate refuses to claim resolution it does not have.
        // A numerically zero σ_max means no direction is stiff at all, so
        // every non-trivial generator scores 0 (flat).
        let pinned_energy_fraction = if sigma_max_sq <= f64::MIN_POSITIVE {
            0.0
        } else {
            let unit = g.mapv(|v| v / norm);
            (curvature.unit_generator_energy(&unit) / sigma_max_sq).clamp(0.0, 1.0)
        };
        let tolerance = GENERATOR_FLAT_ENERGY_TOL.max(*lowering_error_scale);
        let unpinned = pinned_energy_fraction <= tolerance;
        verdicts.push(GeneratorVerdict {
            family: *family,
            description: description.clone(),
            unpinned,
            generator_norm: norm,
            pinned_energy_fraction,
            lowering_error_scale: *lowering_error_scale,
            provenance: VerdictProvenance::CurvatureTest,
        });
    }

    // Exact-orbit verdicts (#998) join the report on equal footing: the
    // group signature, residual dimension, and Sym(F) check all range over
    // the union.
    verdicts.extend(exact_verdicts);

    // #1019 — post-fit arc-length chart canonicalization records: for every
    // canonicalized d = 1 atom the continuous chart (reparameterization)
    // freedom is pinned BY CONSTRUCTION (the unit-speed representative of the
    // Diff(M) orbit was selected post-fit, image-frozen), so the certificate
    // records it pinned with the PinnedByCanonicalization provenance —
    // distinct from curvature/penalty pinning, since no objective resistance
    // was measured — and names the surviving FINITE isometry group of the
    // reference manifold. The group's continuous part (the circle's U(1)
    // shift) is still enumerated and curvature-tested above; this record is
    // the chart-freedom downgrade itself.
    // The per-kind counters below feed only the trailing summary clauses.
    let mut canonicalized_charts = 0usize;
    let mut canonicalized_torus_charts = 0usize;
    let mut canonicalized_patch_charts = 0usize;
    let mut canonicalized_sphere_charts = 0usize;
    for atom in &model.atoms {
        if !atom.chart_canonicalized {
            continue;
        }
        let (pinned_to, residual_group) = match &atom.topology {
            AtomTopology::Circle | AtomTopology::Torus { latent_dim: 1 } => {
                canonicalized_charts += 1;
                ("arc length", "O(2) on S¹ (rotation + reflection)")
            }
            AtomTopology::EuclideanPatch { latent_dim: 1 } => {
                canonicalized_charts += 1;
                (
                    "arc length",
                    "reflection + translation of the unit interval",
                )
            }
            // #1019 stage 2: d = 2 torus charts are pinned post-fit to the
            // minimum-isometry-defect flow representative; the surviving chart
            // freedom is the isometry group of the flat square torus.
            AtomTopology::Torus { latent_dim: 2 } => {
                canonicalized_torus_charts += 1;
                (
                    "the isometry-flow canonical chart",
                    "Isom(T², flat) = U(1)² ⋊ D₄ (axis translations + axis swap/reflections)",
                )
            }
            // #1019 free-chart arm: d = 2 free/patch (Euclidean-patch) charts
            // are pinned post-fit to the flat-reference minimum-anisotropy-
            // defect flow representative; the surviving chart freedom is the
            // isometry group of the flat plane.
            AtomTopology::EuclideanPatch { latent_dim: 2 } => {
                canonicalized_patch_charts += 1;
                (
                    "the flat-reference isometry-flow canonical chart",
                    "Isom(ℝ², flat) = O(2) ⋉ ℝ² (rotation + reflection + translation)",
                )
            }
            // #1019 sphere arm: d = 2 sphere (S²) charts are pinned post-fit to
            // the round-sphere conformal-boost minimum-isometry-defect flow,
            // which breaks the conformal (Möbius) moduli down to the round
            // sphere's isometry group; the surviving chart freedom is O(3).
            AtomTopology::Sphere => {
                canonicalized_sphere_charts += 1;
                (
                    "the round-sphere conformal-boost isometry-flow canonical chart",
                    "Isom(S², round) = O(3) (rotations + reflection)",
                )
            }
            // Canonicalization only ever applies to d = 1 charts, d = 2 torus,
            // d = 2 free/patch, and d = 2 sphere charts; a flag on any other
            // topology is structurally inconsistent and must not fabricate a
            // record.
            _ => continue,
        };
        verdicts.push(GeneratorVerdict {
            family: GeneratorFamily::ChartReparameterization,
            description: format!(
                "{}: chart pinned to {pinned_to} by post-fit canonicalization; \
                 residual chart freedom = {residual_group}",
                atom.name
            ),
            unpinned: false,
            generator_norm: 0.0,
            pinned_energy_fraction: 1.0,
            lowering_error_scale: 0.0,
            provenance: VerdictProvenance::PinnedByCanonicalization,
        });
    }

    // Counted over the union: frame-path, exact-orbit, and canonicalization
    // records (the latter are always pinned, so they never add to the count).
    let residual_gauge_dim = verdicts.iter().filter(|v| v.unpinned).count();

    // Sym(F)-triviality under any output-Fisher provenance — same-position
    // (`OutputFisher`) or downstream-influence (`OutputFisherDownstream`, #980).
    // Both behaviorally separate the atoms (the downstream metric strictly more,
    // since it sees far-future coupling the same-position metric misses), so the
    // permutation subgroup must be trivially pinned under either.
    let sym_f_trivial_under_output_fisher = if matches!(
        metric_provenance,
        MetricProvenance::OutputFisher { .. } | MetricProvenance::OutputFisherDownstream { .. }
    ) {
        let any_perm_unpinned = verdicts
            .iter()
            .any(|v| v.family == GeneratorFamily::AtomPermutation && v.unpinned);
        Some(!any_perm_unpinned)
    } else {
        None
    };

    // NOTE: the Sym(F) clause below reads "under OutputFisher" for both
    // output-Fisher provenances; the exact provenance is printed at the start
    // of the summary via `{metric_provenance:?}`.
    let summary = format!(
        "residual gauge certificate (computed in metric {metric_provenance:?}): \
         pinning rank {pinning_rank}, {residual_gauge_dim} unpinned residual gauge \
         generator(s) of {} enumerated; group = {}{}{}",
        verdicts.len(),
        group_signature_of(&verdicts, diffeomorphism_unpinned),
        match sym_f_trivial_under_output_fisher {
            Some(true) => "; Sym(F) trivially pinned under OutputFisher",
            Some(false) => "; ⚠ Sym(F) NON-trivial under OutputFisher (certificate violation)",
            None => "",
        },
        if diffeomorphism_unpinned {
            "; ⚠ isometry pin inactive"
        } else {
            ""
        },
    );
    // Append one clause per canonicalization kind that actually fired.
    let summary = if canonicalized_charts > 0 {
        format!(
            "{summary}; {canonicalized_charts} chart(s) pinned to arc length by post-fit \
             canonicalization (residual chart freedom = finite isometry group)"
        )
    } else {
        summary
    };
    let summary = if canonicalized_torus_charts > 0 {
        format!(
            "{summary}; {canonicalized_torus_charts} torus chart(s) pinned to the \
             isometry-flow canonical chart by post-fit canonicalization (residual chart \
             freedom = Isom(T², flat))"
        )
    } else {
        summary
    };
    let summary = if canonicalized_patch_charts > 0 {
        format!(
            "{summary}; {canonicalized_patch_charts} free/patch chart(s) pinned to the \
             flat-reference isometry-flow canonical chart by post-fit canonicalization \
             (residual chart freedom = Isom(ℝ², flat) = O(2) ⋉ ℝ²)"
        )
    } else {
        summary
    };
    let summary = if canonicalized_sphere_charts > 0 {
        format!(
            "{summary}; {canonicalized_sphere_charts} sphere chart(s) pinned to the \
             round-sphere conformal-boost isometry-flow canonical chart by post-fit \
             canonicalization (residual chart freedom = Isom(S², round) = O(3))"
        )
    } else {
        summary
    };

    Ok(ResidualGaugeReport {
        metric_provenance,
        generators: verdicts,
        pinning_rank,
        residual_gauge_dim,
        diffeomorphism_unpinned,
        sym_f_trivial_under_output_fisher,
        // The #972 inner-rotation gauge is declared by the caller (it lives in
        // the (U_k, C_k) parameterization, not in the latent-frame coordinates
        // this certificate's generators are tangent to); frame-factored
        // dictionaries attach it via `with_frame_inner_rotation`.
        frame_inner_rotation: None,
        summary,
    })
}
2645
/// The model's two certificates, shipped together (#984 work-plan step 2).
///
/// * The residual-gauge report says what NO data could distinguish: the
///   symmetry group the fit is identified up to. It is a statement about the
///   model class.
/// * The structure certificate says what THIS data established: the
///   e-BH-confirmed subset of the dictionary's structural claims, with
///   FDR ≤ α, valid at the caller's stopping time. It is a statement about the
///   world.
///
/// A claim can fail either way, and the two failure modes are independent. An
/// atom can be perfectly identified yet statistically unestablished, or
/// strongly evidenced yet gauge-ambiguous with a twin.
///
/// The report also carries optional per-atom transport ladders and per-atom
/// post-PIRLS inference reports (see the field docs).
#[derive(Debug, Clone)]
pub struct DictionaryReport {
    /// What cannot be distinguished in principle ([`residual_gauge`]).
    pub gauge: ResidualGaugeReport,
    /// What the data established
    /// ([`gam_terms::inference::structure_evidence::StructureLedger::certify`]).
    pub structure: StructureCertificate,
    /// Per-atom inter-layer transport ladders (#1096).
    ///
    /// Empty when the caller has not supplied at least one atom's canonical
    /// coordinates across two or more layers. [`dictionary_report`] always
    /// leaves it empty.
    ///
    /// These reports use the transport module's chart convention: circle
    /// coordinates are radians on `[0, 2π)`. SAE canonical circle charts may
    /// use an arbitrary period, so
    /// [`dictionary_report_with_transport_ladders`] rescales them before
    /// fitting.
    pub transport_ladders: Vec<AtomTransportLadderReport>,
    /// Per-atom post-PIRLS inference reports, one entry per atom in
    /// [`FittedSaeManifold::atoms`] order.
    ///
    /// Each report holds the #1097 penalty-debiased functional POINT
    /// summaries and the #1103 split-LRT smooth-structure e-value.
    ///
    /// The #1099 per-atom curvature CI was removed under #1115: a curvature
    /// BOUND is not an estimand, and its SE conditioned on generated
    /// regressors. The surviving plug-in curvature point estimate lives on
    /// [`crate::manifold::CertificateInputs::per_atom_kappa_hat`], not here.
    ///
    /// A report's fields are computed when the atom carries its fit-time
    /// [`AtomInnerFit`] byproducts and the relevant numerics succeed;
    /// otherwise the field is `None`. A bare certificate-only
    /// `FittedSaeManifold` (one built by the residual-gauge path with no fit
    /// harness) leaves every `inner_fit` `None`, so both fields are `None`.
    pub atom_inference: Vec<AtomInferenceReport>,
}
2684
/// Canonical per-layer coordinates for one atom, ready for the #1096 transport
/// ladder integration.
///
/// The caller owns extraction from the SAE fit. `layers[i]`, `coords[i]`, and
/// `topologies[i]` describe the same atom at the same layer, so the three
/// vectors are meant to have equal length.
/// NOTE(review): the consuming wiring seam
/// ([`dictionary_report_with_transport_ladders`]) is documented to validate
/// indices and lengths; nothing here enforces it.
///
/// Keeping this type separate keeps extraction outside
/// [`dictionary_report`], so the core certificate can be wired without
/// reaching into `SaeManifoldTerm`.
#[derive(Debug, Clone)]
pub struct AtomTransportLadderInput {
    /// Index into [`FittedSaeManifold::atoms`].
    pub atom_index: usize,
    /// Layer labels in ladder order.
    pub layers: Vec<usize>,
    /// One canonical coordinate vector per layer, all over the same rows.
    pub coords: Vec<Array1<f64>>,
    /// One canonical chart topology per layer.
    pub topologies: Vec<CanonicalChartTopology>,
}
2703
/// One atom's fitted inter-layer transport ladder.
#[derive(Debug, Clone)]
pub struct AtomTransportLadderReport {
    /// Index of the atom this ladder belongs to. This is presumably an index
    /// into [`FittedSaeManifold::atoms`], as in
    /// [`AtomTransportLadderInput::atom_index`]; confirm at the construction
    /// site.
    pub atom_index: usize,
    /// Name of that atom, carried for display.
    pub atom_name: String,
    /// The fitted transport-ladder report for the atom's layers.
    pub report: TransportLadderReport,
}
2711
2712/// #1097 penalty-debiased smooth-functional POINT summaries for one atom's
2713/// captured inner-decoder smooth (narrowed under #1115).
2714///
2715/// All three functionals are *linear* in the atom's fitted coefficient vector
2716/// `β_{k,j}`, so each is one-step penalty-debiased through the SAME penalized
2717/// Hessian the identifiability certificate's curvature sees
2718/// ([`AtomInnerFit::penalized_hessian`]) by routing the functional gradient,
2719/// the per-row scores, and the penalty gradient `S̃_k β` through
2720/// [`debias_with_dense_hessian`]. Only the resulting POINT estimates (plug-in,
2721/// penalty-debiased, removed bias) are kept; the influence-function SE is
2722/// discarded because it conditions on the generated latent coordinates `t̂` /
2723/// assignment `â` as if known and so under-covers (see
2724/// [`AtomFunctionalReport`] for the full argument). A non-SPD Hessian or a
2725/// degenerate functional (empty design, non-finite gradient) leaves the
2726/// offending field `None`; the other two still report.
2727fn atom_functional_report(fit: &AtomInnerFit) -> AtomFunctionalReport {
2728 let penalty_beta = fit.penalty.dot(&fit.beta);
2729
2730 // A small closed-form helper: build the Riesz input for a functional
2731 // gradient and penalty-debias it through the fitted penalized Hessian, then
2732 // KEEP ONLY the point estimates (the SE is not honest here — #1115). The
2733 // Riesz layer's own `EstimationError` is collapsed into `None` — a numerical
2734 // refusal is a missing field, not a poisoned report.
2735 let debias = |functional_gradient: Array1<f64>| -> Option<AtomFunctionalEstimate> {
2736 let input = RieszInput {
2737 beta: fit.beta.view(),
2738 functional_gradient: functional_gradient.view(),
2739 row_scores: fit.row_scores.view(),
2740 penalty_beta: penalty_beta.view(),
2741 leverage: None,
2742 };
2743 debias_with_dense_hessian(&input, fit.penalized_hessian.view())
2744 .ok()
2745 .map(|r| AtomFunctionalEstimate {
2746 theta_plugin: r.theta_plugin,
2747 theta_onestep: r.theta_onestep,
2748 penalty_bias: r.penalty_bias,
2749 })
2750 };
2751
2752 // Peak-vs-mode contrast g(t_peak) − g(t_mode): the linear functional whose
2753 // gradient is the difference of the two design rows.
2754 let peak_contrast = SmoothFunctional::Contrast {
2755 design_row_a: fit.peak_design_row.view(),
2756 design_row_b: fit.mode_design_row.view(),
2757 }
2758 .gradient()
2759 .ok()
2760 .and_then(debias);
2761
2762 // E_data[g(t_i)]: the mass-weighted average decoder value over active rows.
2763 let average_value = SmoothFunctional::AverageValue {
2764 value_design: fit.design.view(),
2765 weights: Some(fit.weights.view()),
2766 }
2767 .gradient()
2768 .ok()
2769 .and_then(debias);
2770
2771 // ‖E_data[∂g/∂t]‖ along the leading latent axis: the mass-weighted average
2772 // of the derivative-design rows (the Gauss–Newton weights `w_i = a_ik²` are
2773 // the data measure over the atom's active rows). This is the conditional-
2774 // on-fit decoder-VARIATION norm, not a population marginal slope.
2775 let decoder_variation_norm = SmoothFunctional::AverageDerivative {
2776 derivative_design: fit.derivative_design.view(),
2777 weights: Some(fit.weights.view()),
2778 }
2779 .gradient()
2780 .ok()
2781 .and_then(debias);
2782
2783 AtomFunctionalReport {
2784 peak_contrast,
2785 average_value,
2786 decoder_variation_norm,
2787 }
2788}
2789
2790/// #1103 Any-n-valid structure evidence that one atom's inner smooth is
2791/// non-constant, via the split-likelihood-ratio e-value.
2792///
2793/// The inner decoder smooth is the Gaussian-identity penalized WLS fit
2794/// `a_ik · Φ_k(t)ᵀ β_{k,j}` with dispersion `φ = `[`AtomInnerFit::dispersion`],
2795/// working response `z_i` reconstructed from the captured per-row scores. H0 is
2796/// "the smooth is constant": only the intercept column 0 is free.
2797///
2798/// We compute the universal-inference e-value the atom-birth gate
2799/// ([`gam_terms::inference::structure_evidence::split_likelihood_log_e_value`]) uses:
2800///
2801/// * Split the active rows deterministically into an ESTIMATION fold (even
2802/// index) and an EVALUATION fold (odd index).
2803/// * On the estimation fold, fit the penalized smooth (the alternative) by
2804/// `β̂ = (ΦᵀWΦ + S)⁻¹ ΦᵀW z` — any fitter is admissible; zero conditions.
2805/// * On the evaluation fold, score the Gaussian log-likelihood under that
2806/// prefit alternative, and the SUPREMUM of the evaluation-fold log-likelihood
2807/// over the null class (the constant fit = weighted-mean response refit on the
2808/// eval fold — the honest constrained sup on D₀).
2809/// * `log E = ℓ_alt(D₀) − sup_{H0} ℓ(D₀)`, with `E_{H0}[E] ≤ 1` exactly.
2810///
2811/// The dispersion `φ` is held fixed at the fitted reconstruction dispersion in
2812/// both log-likelihoods so it cancels structurally and the e-value isolates the
2813/// mean-curvature evidence. Returns `None` when the design has no curvature
2814/// column (`M_k ≤ 1`), either fold is empty, or the inner Gram is not SPD.
2815fn atom_smooth_significance(fit: &AtomInnerFit) -> Option<AtomSmoothSignificance> {
2816 let m = fit.design.ncols();
2817 if m <= 1 || fit.beta.len() != m {
2818 // No curvature column: the constant null IS the full model — there is no
2819 // non-constant alternative to earn an e-value.
2820 return None;
2821 }
2822 let n = fit.design.nrows();
2823 if n == 0 || fit.weights.len() != n || fit.row_scores.nrows() != n {
2824 return None;
2825 }
2826 let phi = if fit.dispersion.is_finite() && fit.dispersion > 0.0 {
2827 fit.dispersion
2828 } else {
2829 return None;
2830 };
2831
2832 // Per-row working response z_i = μ̂_i + r_i, reconstructing the scalar
2833 // residual r_i from the captured score projected onto the design row
2834 // (s_iᵀ Φ_i = −w_i r_i ‖Φ_i‖² / φ ⇒ r_i). Same reconstruction the previous
2835 // deviance path used; here it feeds the two folds' likelihoods.
2836 let mut z = Array1::<f64>::zeros(n);
2837 for i in 0..n {
2838 let mu_hat = fit.design.row(i).dot(&fit.beta);
2839 let w_i = fit.weights[i];
2840 let phi_row = fit.design.row(i);
2841 let phi_norm_sq = phi_row.dot(&phi_row);
2842 let r_i = if w_i > 0.0 && phi_norm_sq > 0.0 {
2843 let s_dot_phi = fit.row_scores.row(i).dot(&phi_row);
2844 -phi * s_dot_phi / (w_i * phi_norm_sq)
2845 } else {
2846 0.0
2847 };
2848 z[i] = mu_hat + r_i;
2849 }
2850
2851 // Deterministic estimation/evaluation split by row parity.
2852 let est: Vec<usize> = (0..n).filter(|i| i % 2 == 0).collect();
2853 let eval: Vec<usize> = (0..n).filter(|i| i % 2 == 1).collect();
2854 if est.is_empty() || eval.is_empty() {
2855 return None;
2856 }
2857
2858 // Penalized smooth fit on the estimation fold: β̂ = (ΦᵀWΦ + S)⁻¹ ΦᵀW z.
2859 let mut a_gram = fit.penalty.clone();
2860 let mut b = Array1::<f64>::zeros(m);
2861 for &i in &est {
2862 let w_i = fit.weights[i];
2863 if !(w_i > 0.0) {
2864 continue;
2865 }
2866 let row = fit.design.row(i);
2867 for r in 0..m {
2868 let xr = row[r];
2869 if xr == 0.0 {
2870 continue;
2871 }
2872 b[r] += w_i * xr * z[i];
2873 for c in 0..m {
2874 a_gram[[r, c]] += w_i * xr * row[c];
2875 }
2876 }
2877 }
2878 let beta_alt = a_gram.cholesky(Side::Lower).ok()?.solvevec(&b);
2879
2880 // Null sup on the EVALUATION fold: the weighted-mean response (the constant
2881 // fit's MLE on D₀, the honest constrained sup over the null class).
2882 let mut eval_mass = 0.0_f64;
2883 let mut eval_wz = 0.0_f64;
2884 for &i in &eval {
2885 let w_i = fit.weights[i];
2886 eval_mass += w_i;
2887 eval_wz += w_i * z[i];
2888 }
2889 if !(eval_mass > 0.0) {
2890 return None;
2891 }
2892 let null_mean = eval_wz / eval_mass;
2893
2894 // Gaussian log-likelihoods on the evaluation fold at fixed dispersion φ;
2895 // the −½ log(2πφ) and weight-log terms are identical under both models, so
2896 // log E = −(½/φ) [ Σ w(z − μ_alt)² − Σ w(z − μ_null)² ].
2897 let mut sse_alt = 0.0_f64;
2898 let mut sse_null = 0.0_f64;
2899 for &i in &eval {
2900 let w_i = fit.weights[i];
2901 let mu_alt = fit.design.row(i).dot(&beta_alt);
2902 let r_alt = z[i] - mu_alt;
2903 let r_null = z[i] - null_mean;
2904 sse_alt += w_i * r_alt * r_alt;
2905 sse_null += w_i * r_null * r_null;
2906 }
2907 let log_lik_alt = -0.5 * sse_alt / phi;
2908 let log_lik_null_sup = -0.5 * sse_null / phi;
2909 let log_e = gam_terms::inference::structure_evidence::split_likelihood_log_e_value(
2910 log_lik_alt,
2911 log_lik_null_sup,
2912 );
2913 if !log_e.is_finite() {
2914 return None;
2915 }
2916
2917 Some(AtomSmoothSignificance {
2918 log_e_nonconstant: Some(log_e),
2919 })
2920}
2921
2922/// Assemble the post-PIRLS inference reports for every atom, reusing the
2923/// per-atom [`AtomInnerFit`] harvested at fit time.
2924///
2925/// * #1097 penalty-debiased functional POINT summaries and the #1103 split-LRT
2926/// smooth-structure e-value are computed from the captured inner-decoder
2927/// smooth (design, penalized Hessian, row scores, roughness Gram) — they need
2928/// only the fixed fitted snapshot.
2929/// * The #1099 per-atom curvature *confidence interval* was removed under #1115:
2930/// a sup-norm curvature BOUND is not an estimand with a profiled criterion,
2931/// and its delta-method SE conditioned on generated latent coordinates as if
2932/// known. The plug-in curvature point estimate survives on
2933/// [`crate::manifold::CertificateInputs::per_atom_kappa_hat`] (the
2934/// #1008 empirical curved-dictionary report), not on this report.
2935pub(crate) fn atom_inference_reports(model: &FittedSaeManifold) -> Vec<AtomInferenceReport> {
2936 model
2937 .atoms
2938 .iter()
2939 .enumerate()
2940 .map(|(atom_index, atom)| {
2941 let (functionals, smooth_significance) = match &atom.inner_fit {
2942 Some(fit) => (
2943 Some(atom_functional_report(fit)),
2944 atom_smooth_significance(fit),
2945 ),
2946 None => (None, None),
2947 };
2948 AtomInferenceReport {
2949 atom_index,
2950 atom_name: atom.name.clone(),
2951 functionals,
2952 smooth_significance,
2953 }
2954 })
2955 .collect()
2956}
2957
2958/// Produce the paired certificate for a fitted model: the residual-gauge
2959/// report computed here plus the anytime-valid structure certificate from
2960/// the discovery run's evidence ledger at level `alpha`. The ledger is the
2961/// one the structure search absorbed its shard evidence into
2962/// (`structure_evidence::StructureLedger`); certifying at any
2963/// data-dependent stopping time is sound — that is the ledger's whole
2964/// design.
2965pub fn dictionary_report(
2966 model: &FittedSaeManifold,
2967 ledger: &StructureLedger,
2968 alpha: f64,
2969) -> Result<DictionaryReport, String> {
2970 Ok(DictionaryReport {
2971 gauge: residual_gauge(model)?,
2972 structure: ledger.certify(alpha),
2973 transport_ladders: Vec::new(),
2974 atom_inference: atom_inference_reports(model),
2975 })
2976}
2977
2978// --- #1100: closed-loop probe runner FFI ---------------------------------
2979// Top-level entry points exposing the steering→structure-evidence probe loop
2980// (`crate::inference::probe_runner::ProbeRunner`) beside `dictionary_report`, so
2981// the Python driver can design and absorb interventional probes against the same
2982// fitted term and evidence ledger the certificate is built from.
2983
2984/// Design the next interventional probe for the most contested steerable claim
2985/// in `ledger`, against the fitted SAE-manifold `term` read through its per-row
2986/// output-Fisher `metric`.
2987///
2988/// Thin top-level wrapper over [`crate::inference::probe_runner::ProbeRunner::design_next`]:
2989/// it selects the contested claim furthest from certification, realizes candidate
2990/// latent moves of its atom through `crate::inference::steering::steer_delta`,
2991/// and routes their doses through
2992/// `gam_terms::inference::structure_evidence::plan_probe_for_contested_claim` to pick
2993/// the most discriminating one. The returned
2994/// [`crate::inference::probe_runner::RealizedProbe`] carries both the experiment
2995/// plan and the chosen intervention's on-manifold activation delta with its
2996/// dosimetry and validity radius.
2997pub fn design_probe(
2998 term: &SaeManifoldTerm,
2999 metric: &RowMetric,
3000 ledger: &StructureLedger,
3001) -> Result<RealizedProbe, String> {
3002 ProbeRunner { term, metric }.design_next(ledger)
3003}
3004
3005/// Absorb a realized probe outcome into `ledger`, banking the delivered
3006/// behavioral dose (`realized_nats`, the observed output-Fisher KL of the steered
3007/// response) as anytime-valid evidence for the probe's claim.
3008///
3009/// Thin top-level wrapper over [`crate::inference::probe_runner::ProbeRunner::absorb`].
3010pub fn absorb_probe(
3011 term: &SaeManifoldTerm,
3012 metric: &RowMetric,
3013 ledger: &mut StructureLedger,
3014 probe: &RealizedProbe,
3015 realized_nats: f64,
3016) {
3017 ProbeRunner { term, metric }.absorb(ledger, probe, realized_nats);
3018}
3019
3020/// Produce the paired certificate plus #1096 per-atom layer-transport ladders.
3021///
3022/// This is the strict wiring seam for callers that already have canonical
3023/// per-layer atom coordinates. It validates atom indices, topology/coordinate
3024/// lengths, finite coordinates, and the circle-period convention before calling
3025/// [`transport_ladder`]. Single-layer inputs are refused: no transport estimand
3026/// exists without at least one adjacent layer pair.
3027pub fn dictionary_report_with_transport_ladders(
3028 model: &FittedSaeManifold,
3029 ledger: &StructureLedger,
3030 alpha: f64,
3031 ladders: &[AtomTransportLadderInput],
3032) -> Result<DictionaryReport, String> {
3033 let mut report = dictionary_report(model, ledger, alpha)?;
3034 report.transport_ladders = atom_transport_ladder_reports(model, ladders)?;
3035 Ok(report)
3036}
3037
3038/// Fit #1096 transport ladders for the supplied atom/layer coordinate blocks.
3039pub fn atom_transport_ladder_reports(
3040 model: &FittedSaeManifold,
3041 ladders: &[AtomTransportLadderInput],
3042) -> Result<Vec<AtomTransportLadderReport>, String> {
3043 let mut out = Vec::with_capacity(ladders.len());
3044 for input in ladders {
3045 let atom = model.atoms.get(input.atom_index).ok_or_else(|| {
3046 format!(
3047 "atom transport ladder index {} out of range for {} fitted atoms",
3048 input.atom_index,
3049 model.atoms.len()
3050 )
3051 })?;
3052 let depth = input.layers.len();
3053 if depth < 2 {
3054 return Err(format!(
3055 "atom transport ladder for atom {} ('{}') needs at least two layers, got {depth}",
3056 input.atom_index, atom.name
3057 ));
3058 }
3059 if input.coords.len() != depth || input.topologies.len() != depth {
3060 return Err(format!(
3061 "atom transport ladder for atom {} ('{}') has {} layers, {} coordinate blocks, {} topologies",
3062 input.atom_index,
3063 atom.name,
3064 depth,
3065 input.coords.len(),
3066 input.topologies.len()
3067 ));
3068 }
3069
3070 let mut coords = Vec::with_capacity(depth);
3071 let mut topologies = Vec::with_capacity(depth);
3072 for (layer_pos, (coord, topology)) in
3073 input.coords.iter().zip(input.topologies.iter()).enumerate()
3074 {
3075 coords.push(canonical_coords_for_transport(
3076 coord,
3077 topology,
3078 input.atom_index,
3079 &atom.name,
3080 input.layers[layer_pos],
3081 )?);
3082 topologies.push(ChartTopology::from(topology));
3083 }
3084
3085 let report = transport_ladder(&input.layers, &coords, &topologies).map_err(|e| {
3086 format!(
3087 "atom transport ladder for atom {} ('{}') failed: {e}",
3088 input.atom_index, atom.name
3089 )
3090 })?;
3091 out.push(AtomTransportLadderReport {
3092 atom_index: input.atom_index,
3093 atom_name: atom.name.clone(),
3094 report,
3095 });
3096 }
3097 Ok(out)
3098}
3099
3100fn canonical_coords_for_transport(
3101 coords: &Array1<f64>,
3102 topology: &CanonicalChartTopology,
3103 atom_index: usize,
3104 atom_name: &str,
3105 layer: usize,
3106) -> Result<Array1<f64>, String> {
3107 if coords.iter().any(|v| !v.is_finite()) {
3108 return Err(format!(
3109 "atom transport ladder for atom {atom_index} ('{atom_name}') layer {layer} has non-finite coordinates"
3110 ));
3111 }
3112 match topology {
3113 CanonicalChartTopology::Circle { period } => {
3114 if !(period.is_finite() && *period > 0.0) {
3115 return Err(format!(
3116 "atom transport ladder for atom {atom_index} ('{atom_name}') layer {layer} has invalid circle period {period}"
3117 ));
3118 }
3119 Ok(coords.mapv(|t| (t / *period) * TAU))
3120 }
3121 CanonicalChartTopology::Interval => Ok(coords.clone()),
3122 }
3123}
3124
3125// ----------------------------------------------------------------------------
3126// #1102 cross-checkpoint atom-dynamics FFI entry (new top-level block).
3127// ----------------------------------------------------------------------------
3128
3129/// Run #1102 cross-checkpoint Riesz-debiased atom-trajectory dynamics for the
3130/// fitted dictionary's atoms.
3131///
3132/// `decoder_grid` is `[n_checkpoints, n_atoms, n_grid, ambient_dim]` and
3133/// `atom_names`/`checkpoint_ids`/`latent_grid` label its axes; see
3134/// [`crate::inference::checkpoint_dynamics`] for the estimator and the honest
3135/// accounting of which Riesz inputs the bare grid supports. This entry binds
3136/// the atom axis to the fitted model: `atom_names` must name exactly the
3137/// model's atoms in order, so trajectories are reported against real atoms.
3138pub fn atom_checkpoint_dynamics(
3139 model: &FittedSaeManifold,
3140 decoder_grid: ndarray::ArrayView4<'_, f64>,
3141 checkpoint_ids: &[String],
3142 atom_names: &[String],
3143 latent_grid: ArrayView1<'_, f64>,
3144) -> Result<Vec<crate::inference::checkpoint_dynamics::AtomTrajectory>, String> {
3145 if atom_names.len() != model.atoms.len() {
3146 return Err(format!(
3147 "atom_checkpoint_dynamics: {} atom names supplied for {} fitted atoms",
3148 atom_names.len(),
3149 model.atoms.len()
3150 ));
3151 }
3152 for (idx, (supplied, fitted)) in atom_names.iter().zip(model.atoms.iter()).enumerate() {
3153 if supplied != &fitted.name {
3154 return Err(format!(
3155 "atom_checkpoint_dynamics: atom {idx} name '{supplied}' does not match fitted atom '{}'",
3156 fitted.name
3157 ));
3158 }
3159 }
3160 crate::inference::checkpoint_dynamics::checkpoint_atom_dynamics(
3161 &crate::inference::checkpoint_dynamics::CheckpointDynamicsInput {
3162 decoder_grid,
3163 checkpoint_ids,
3164 atom_names,
3165 latent_grid,
3166 },
3167 )
3168}
3169
3170#[cfg(test)]
3171mod tests {
3172 use super::*;
3173 use ndarray::{Array1, array};
3174
3175 /// #1097: the per-atom penalty-debiased functional point summaries must
3176 /// reproduce the exact linear functionals of the fitted decoder smooth
3177 /// (plug-in) and a finite debiased value, on a synthetic atom whose inner
3178 /// smooth is an analytic polynomial. No SE/CI is asserted — none is reported
3179 /// (#1115).
3180 #[test]
3181 fn atom_functional_report_recovers_known_functionals() {
3182 use ndarray::{Array1 as A1, Array2 as A2};
3183 // Polynomial basis Φ(t) = [1, t, t²] on a uniform active grid; the atom's
3184 // fitted smooth is g(t) = β·Φ(t) with a known β. We assemble a genuine
3185 // penalized-WLS AtomInnerFit (unit weights, identity-ish penalty) so the
3186 // Riesz path runs end to end.
3187 let n = 40usize;
3188 let m = 3usize;
3189 let beta = A1::from(vec![0.5_f64, -1.0, 2.0]);
3190 let mut design = A2::<f64>::zeros((n, m));
3191 let mut derivative_design = A2::<f64>::zeros((n, m));
3192 let mut weights = A1::<f64>::ones(n);
3193 let mut t = vec![0.0_f64; n];
3194 for i in 0..n {
3195 let ti = i as f64 / (n - 1) as f64;
3196 t[i] = ti;
3197 design[[i, 0]] = 1.0;
3198 design[[i, 1]] = ti;
3199 design[[i, 2]] = ti * ti;
3200 // dΦ/dt = [0, 1, 2t].
3201 derivative_design[[i, 0]] = 0.0;
3202 derivative_design[[i, 1]] = 1.0;
3203 derivative_design[[i, 2]] = 2.0 * ti;
3204 weights[i] = 1.0;
3205 }
3206 let dispersion = 1.0_f64;
3207 // Working response equals the fitted curve so residuals are zero → the
3208 // plug-in is exactly the analytic functional of β; scores are zero.
3209 let row_scores = A2::<f64>::zeros((n, m));
3210 // Penalty S = small ridge on curvature column only; penalized Hessian
3211 // H = ΦᵀWΦ + S.
3212 let mut penalty = A2::<f64>::zeros((m, m));
3213 penalty[[2, 2]] = 1e-3;
3214 let mut xtwx = A2::<f64>::zeros((m, m));
3215 for i in 0..n {
3216 for a in 0..m {
3217 for b in 0..m {
3218 xtwx[[a, b]] += weights[i] * design[[i, a]] * design[[i, b]];
3219 }
3220 }
3221 }
3222 let penalized_hessian = &xtwx + &penalty;
3223 // Peak: |g| largest; mode: pick endpoints to give a known contrast.
3224 let mut peak_slot = 0usize;
3225 let mut peak_val = -1.0;
3226 for i in 0..n {
3227 let g = design.row(i).dot(&beta).abs();
3228 if g > peak_val {
3229 peak_val = g;
3230 peak_slot = i;
3231 }
3232 }
3233 let peak_design_row = design.row(peak_slot).to_owned();
3234 let mode_design_row = design.row(0).to_owned();
3235
3236 let fit = AtomInnerFit {
3237 design: design.clone(),
3238 derivative_design: derivative_design.clone(),
3239 beta: beta.clone(),
3240 penalty,
3241 penalized_hessian,
3242 row_scores,
3243 weights: weights.clone(),
3244 dispersion,
3245 peak_design_row: peak_design_row.clone(),
3246 mode_design_row: mode_design_row.clone(),
3247 };
3248
3249 let report = atom_functional_report(&fit);
3250
3251 // Average value E_w[g] = mean_i β·Φ(t_i): exact plug-in match.
3252 let av = report.average_value.expect("average value");
3253 let expected_av: f64 = (0..n).map(|i| design.row(i).dot(&beta)).sum::<f64>() / n as f64;
3254 assert!(
3255 (av.theta_plugin - expected_av).abs() < 1e-9,
3256 "average value plug-in {} vs expected {}",
3257 av.theta_plugin,
3258 expected_av
3259 );
3260 // Point summary only: the debiased value is finite (no SE/CI is
3261 // reported by design — #1115).
3262 assert!(
3263 av.theta_onestep.is_finite(),
3264 "average-value debiased finite"
3265 );
3266
3267 // Decoder-variation norm (conditional on fit): g'(t) = β1 + 2β2 t, mean
3268 // over the grid is β1 + 2β2 * mean(t). The functional gradient is the
3269 // mean derivative row; its plug-in is exactly that scalar. This is the
3270 // descriptive variation of the fitted curve, not a population marginal
3271 // slope.
3272 let ad = report
3273 .decoder_variation_norm
3274 .expect("decoder variation norm");
3275 let mean_t: f64 = t.iter().sum::<f64>() / n as f64;
3276 let expected_ad = beta[1] + 2.0 * beta[2] * mean_t;
3277 assert!(
3278 (ad.theta_plugin - expected_ad).abs() < 1e-9,
3279 "decoder variation plug-in {} vs expected {}",
3280 ad.theta_plugin,
3281 expected_ad
3282 );
3283
3284 // Peak-vs-mode contrast g(t_peak) − g(t_mode): exact plug-in.
3285 let pc = report.peak_contrast.expect("peak contrast");
3286 let expected_pc = peak_design_row.dot(&beta) - mode_design_row.dot(&beta);
3287 assert!(
3288 (pc.theta_plugin - expected_pc).abs() < 1e-9,
3289 "peak contrast plug-in {} vs expected {}",
3290 pc.theta_plugin,
3291 expected_pc
3292 );
3293 }
3294
3295 #[test]
3296 fn mechanism_sparsity_jacobian_value_matches_closed_form() {
3297 let w = array![[3.0_f64, 0.0], [4.0, 0.0]]; // col0 norm=5, col1 norm=0
3298 let pen = MechanismSparsityJacobian::new(1.0, 1.0e-8).unwrap();
3299 let (v, _g) = pen.value_and_grad(w.view());
3300 assert!((v - 5.0).abs() < 1e-6, "value {v} expected ≈5");
3301 }
3302
3303 #[test]
3304 fn mechanism_sparsity_jacobian_grad_matches_finite_diff() {
3305 let w = array![[0.5_f64, -1.2, 0.3], [1.1, 0.4, -0.7]];
3306 let pen = MechanismSparsityJacobian::new(2.5, 1.0e-6).unwrap();
3307 let (_, g) = pen.value_and_grad(w.view());
3308 let h = 1.0e-5;
3309 for i in 0..w.nrows() {
3310 for j in 0..w.ncols() {
3311 let mut wp = w.clone();
3312 let mut wm = w.clone();
3313 wp[[i, j]] += h;
3314 wm[[i, j]] -= h;
3315 let (vp, _) = pen.value_and_grad(wp.view());
3316 let (vm, _) = pen.value_and_grad(wm.view());
3317 let fd = (vp - vm) / (2.0 * h);
3318 assert!(
3319 (g[[i, j]] - fd).abs() < 1e-4,
3320 "grad[{i},{j}] = {} vs fd {}",
3321 g[[i, j]],
3322 fd
3323 );
3324 }
3325 }
3326 }
3327
3328 #[test]
3329 fn mechanism_sparsity_jacobian_rejects_bad_input() {
3330 assert!(MechanismSparsityJacobian::new(-1.0, 1e-6).is_err());
3331 assert!(MechanismSparsityJacobian::new(1.0, 0.0).is_err());
3332 }
3333
3334 #[test]
3335 fn frame_inner_rotation_dim_is_sum_of_so_r_dims() {
3336 // dim O(r) = r(r−1)/2 per factored atom; rank-1 frames contribute 0.
3337 assert_eq!(frame_inner_rotation_dim(&[]), 0);
3338 assert_eq!(frame_inner_rotation_dim(&[1]), 0);
3339 assert_eq!(frame_inner_rotation_dim(&[2]), 1);
3340 assert_eq!(frame_inner_rotation_dim(&[4]), 6);
3341 assert_eq!(frame_inner_rotation_dim(&[1, 4, 8]), 0 + 6 + 28);
3342 assert_eq!(
3343 FrameInnerRotationGauge::from_ranks(vec![3, 3]).dim,
3344 6,
3345 "two rank-3 frames carry 2·3 inner-rotation dims"
3346 );
3347 }
3348
3349 /// The #972 inner-rotation gauge is enumerated in the certificate, never
3350 /// curvature-tested: attaching it must not change any generator verdict
3351 /// or the residual_gauge_dim, but it MUST change the group signature and
3352 /// the summary — two replicate frame-factored fits agree on their gauge
3353 /// iff they also agree on this enumerated, convention-fixed part.
3354 #[test]
3355 fn frame_inner_rotation_attaches_to_the_certificate_without_verdict_change() {
3356 let base = ResidualGaugeReport {
3357 metric_provenance: MetricProvenance::Euclidean,
3358 generators: Vec::new(),
3359 pinning_rank: 5,
3360 residual_gauge_dim: 0,
3361 diffeomorphism_unpinned: false,
3362 sym_f_trivial_under_output_fisher: None,
3363 frame_inner_rotation: None,
3364 summary: "base".to_string(),
3365 };
3366 let sig_before = base.group_signature();
3367 let report = base.with_frame_inner_rotation(vec![1, 4, 8]);
3368 assert_eq!(
3369 report.frame_inner_rotation,
3370 Some(FrameInnerRotationGauge {
3371 per_atom_ranks: vec![1, 4, 8],
3372 dim: 34,
3373 })
3374 );
3375 // Verdict-side facts untouched.
3376 assert_eq!(report.residual_gauge_dim, 0);
3377 assert!(report.generators.is_empty());
3378 // Signature and summary carry the enumeration.
3379 let sig_after = report.group_signature();
3380 assert_ne!(sig_before, sig_after);
3381 assert!(sig_after.contains("frame-inner"), "got: {sig_after}");
3382 assert!(sig_after.contains("dim 34"), "got: {sig_after}");
3383 assert!(sig_after.contains("canonical-fixed"), "got: {sig_after}");
3384 assert!(report.summary.contains("inner-rotation gauge"));
3385
3386 // A dictionary of rank-1 atoms has a zero-dimensional inner gauge:
3387 // enumerated (Some), but the signature is unchanged — there is
3388 // nothing to fix beyond the orientation sign convention.
3389 let trivial = ResidualGaugeReport {
3390 metric_provenance: MetricProvenance::Euclidean,
3391 generators: Vec::new(),
3392 pinning_rank: 0,
3393 residual_gauge_dim: 0,
3394 diffeomorphism_unpinned: false,
3395 sym_f_trivial_under_output_fisher: None,
3396 frame_inner_rotation: None,
3397 summary: "base".to_string(),
3398 };
3399 let sig_trivial_before = trivial.group_signature();
3400 let trivial = trivial.with_frame_inner_rotation(vec![1, 1, 1]);
3401 assert_eq!(
3402 trivial.frame_inner_rotation.as_ref().map(|g| g.dim),
3403 Some(0)
3404 );
3405 assert_eq!(trivial.group_signature(), sig_trivial_before);
3406 assert_eq!(trivial.summary, "base");
3407 }
3408
3409 /// Build a `(n, d)` `(mean, scale)` pair whose stacked signature
3410 /// `[μ ‖ log σ]` has full rank `2d` (so it satisfies the Khemakhem
3411 /// Theorem 1 precondition baked into `ConditionalPriorIvae::new`).
3412 ///
3413 /// Each per-column function is given a distinct *frequency* (not a
3414 /// shared frequency with a column-dependent phase) so the resulting
3415 /// `2d` columns are genuinely linearly independent. `sin(ω·t + φ)`
3416 /// with a shared `ω` lives in the 2-dimensional span of `{sin(ω t),
3417 /// cos(ω t)}`, so the earlier `sin(0.7t + 0.3c)` / `cos(0.5t + 0.9c)`
3418 /// fixture only ever produced rank `≤ 4`, no matter how many `d`
3419 /// columns it built. Distinct frequencies push each column into its
3420 /// own subspace, so for `n ≥ 2d + 1` the SVD of `[μ ‖ log σ]` has
3421 /// `2d` non-trivial singular values.
3422 fn ivae_precondition_pair(n: usize, d: usize) -> (Array2<f64>, Array2<f64>) {
3423 assert!(n >= 2 * d + 1, "need at least 2d+1 rows");
3424 let mut mean = Array2::<f64>::zeros((n, d));
3425 let mut scale = Array2::<f64>::from_elem((n, d), 1.0);
3426 for r in 0..n {
3427 let t = r as f64 / (n as f64 - 1.0);
3428 for c in 0..d {
3429 let omega = (c + 1) as f64;
3430 mean[[r, c]] = (std::f64::consts::PI * omega * t).sin();
3431 scale[[r, c]] = (0.4 * (std::f64::consts::PI * omega * t).cos()).exp();
3432 }
3433 }
3434 (mean, scale)
3435 }
3436
3437 #[test]
3438 fn conditional_prior_ivae_zero_mean_unit_scale_matches_standard_gaussian() {
3439 // Use varying (μ, log σ) so the identifiability precondition holds,
3440 // then evaluate at a `t` that matches `μ` to recover the closed-form
3441 // Gaussian normaliser ½·n·d·log 2π + Σ log σ.
3442 let n = 7;
3443 let d = 3;
3444 let (mean, scale) = ivae_precondition_pair(n, d);
3445 let t = mean.clone();
3446 let log_norm: f64 = scale.iter().map(|s| s.ln()).sum();
3447 let pen = ConditionalPriorIvae::new(mean, scale, 1.0).unwrap();
3448 let (v, g) = pen.value_and_grad(t.view());
3449 let expected = log_norm + 0.5 * (n * d) as f64 * (2.0 * std::f64::consts::PI).ln();
3450 assert!(
3451 (v - expected).abs() < 1e-9,
3452 "value {v} vs expected {expected}"
3453 );
3454 for &gv in g.iter() {
3455 assert!(gv.abs() < 1e-12);
3456 }
3457 }
3458
3459 #[test]
3460 fn conditional_prior_ivae_grad_matches_finite_diff() {
3461 let (mean, scale) = ivae_precondition_pair(5, 2);
3462 let mut t = mean.clone();
3463 for r in 0..5 {
3464 t[[r, 0]] += 0.4;
3465 t[[r, 1]] -= 0.3;
3466 }
3467 let pen = ConditionalPriorIvae::new(mean, scale, 1.7).unwrap();
3468 let (_, g) = pen.value_and_grad(t.view());
3469 let h = 1.0e-5;
3470 for i in 0..t.nrows() {
3471 for j in 0..t.ncols() {
3472 let mut tp = t.clone();
3473 let mut tm = t.clone();
3474 tp[[i, j]] += h;
3475 tm[[i, j]] -= h;
3476 let vp = pen.value(tp.view());
3477 let vm = pen.value(tm.view());
3478 let fd = (vp - vm) / (2.0 * h);
3479 assert!((g[[i, j]] - fd).abs() < 1e-5);
3480 }
3481 }
3482 }
3483
3484 #[test]
3485 fn conditional_prior_ivae_rejects_nonpositive_scale() {
3486 let mean = Array2::<f64>::zeros((2, 2));
3487 let mut scale = Array2::<f64>::ones((2, 2));
3488 scale[[0, 0]] = -0.1;
3489 assert!(ConditionalPriorIvae::new(mean, scale, 1.0).is_err());
3490 }
3491
3492 #[test]
3493 fn conditional_prior_ivae_accepts_when_signature_full_rank() {
3494 let (mean, scale) = ivae_precondition_pair(7, 3);
3495 let result = ConditionalPriorIvae::new(mean, scale, 1.0);
3496 assert!(
3497 result.is_ok(),
3498 "full-rank signature should satisfy Khemakhem Theorem 1, got {:?}",
3499 result.err(),
3500 );
3501 }
3502
3503 #[test]
3504 fn conditional_prior_ivae_rejects_trivial_constant_prior() {
3505 // All rows identical → unconditional N(μ, σ²), non-identifiable.
3506 let n = 9;
3507 let d = 3;
3508 let mean = Array2::<f64>::from_elem((n, d), 0.25);
3509 let scale = Array2::<f64>::from_elem((n, d), 1.5);
3510 let err = ConditionalPriorIvae::new(mean, scale, 1.0).unwrap_err();
3511 assert!(
3512 err.contains("trivial unconditional") && err.contains("Khemakhem"),
3513 "unexpected error: {err}"
3514 );
3515 }
3516
3517 #[test]
3518 fn conditional_prior_ivae_rejects_too_few_auxiliary_states() {
3519 // n_rows = 4, latent_dim = 3 → need ≥ 2·3+1 = 7 rows.
3520 let (full_mean, full_scale) = ivae_precondition_pair(7, 3);
3521 let mean = full_mean.slice(s![..4, ..]).to_owned();
3522 let scale = full_scale.slice(s![..4, ..]).to_owned();
3523 let err = ConditionalPriorIvae::new(mean, scale, 1.0).unwrap_err();
3524 assert!(
3525 err.contains("2k+1") && err.contains("Khemakhem"),
3526 "unexpected error: {err}"
3527 );
3528 }
3529
3530 #[test]
3531 fn conditional_prior_ivae_rejects_rank_deficient_signature() {
3532 // Enough rows (n = 9 ≥ 2·3+1 = 7) and rows are NOT all identical,
3533 // but the stacked [μ ‖ log σ] matrix lies in a strict subspace of
3534 // ℝ^{2d}: column 0 of μ equals column 0 of log σ, and columns 1,2
3535 // of both μ and σ are zero / one. So the signature has rank 1, far
3536 // below the required 2·3 = 6.
3537 let n = 9;
3538 let d = 3;
3539 let mut mean = Array2::<f64>::zeros((n, d));
3540 let mut scale = Array2::<f64>::from_elem((n, d), 1.0);
3541 for r in 0..n {
3542 let v = ((r as f64) * 0.5).sin();
3543 mean[[r, 0]] = v;
3544 scale[[r, 0]] = v.exp(); // log σ column 0 = v = μ column 0
3545 }
3546 let err = ConditionalPriorIvae::new(mean, scale, 1.0).unwrap_err();
3547 assert!(
3548 err.contains("numerical rank") && err.contains("Khemakhem"),
3549 "unexpected error: {err}"
3550 );
3551 }
3552
3553 #[test]
3554 fn piecewise_linear_eval_endpoints_and_midpoint() {
3555 let coeffs = array![[0.0_f64, 10.0], [1.0, 20.0], [2.0, 30.0]];
3556 let u = Array1::from(vec![0.0, 0.5, 1.0]);
3557 let out = piecewise_linear_eval(u.view(), coeffs.view(), 0.0, 1.0);
3558 assert!((out[[0, 0]] - 0.0).abs() < 1e-12);
3559 assert!((out[[1, 0]] - 1.0).abs() < 1e-12);
3560 assert!((out[[2, 0]] - 2.0).abs() < 1e-12);
3561 assert!((out[[1, 1]] - 20.0).abs() < 1e-12);
3562 }
3563
3564 #[test]
3565 fn select_weights_picks_max_evidence() {
3566 let rss = array![[10.0, 9.0, 9.5], [8.0, 4.0, 5.0], [9.0, 6.0, 7.0]];
3567 let pen = Array2::<f64>::zeros((3, 3));
3568 let l1 = Array1::from(vec![0.1, 1.0, 10.0]);
3569 let l2 = Array1::from(vec![0.1, 1.0, 10.0]);
3570 let res =
3571 identifiable_factor_select_weights(rss.view(), pen.view(), l1.view(), l2.view(), 80)
3572 .unwrap();
3573 assert_eq!((res.best_i, res.best_j), (1, 1));
3574 assert!((res.best_lam1 - 1.0).abs() < 1e-12);
3575 assert!((res.best_lam2 - 1.0).abs() < 1e-12);
3576 assert!(res.best_evidence.is_finite());
3577 }
3578
3579 #[test]
3580 fn select_weights_breaks_ties_by_smallest_log_weight_sum() {
3581 let rss = Array2::<f64>::from_elem((2, 2), 4.0);
3582 let pen = Array2::<f64>::from_elem((2, 2), 1.0);
3583 let l1 = Array1::from(vec![0.1, 10.0]);
3584 let l2 = Array1::from(vec![0.1, 10.0]);
3585 let res =
3586 identifiable_factor_select_weights(rss.view(), pen.view(), l1.view(), l2.view(), 8)
3587 .unwrap();
3588 assert_eq!((res.best_i, res.best_j), (0, 0));
3589 }
3590
3591 #[test]
3592 fn select_weights_rejects_shape_mismatch() {
3593 let rss = Array2::<f64>::zeros((2, 3));
3594 let pen = Array2::<f64>::zeros((2, 2));
3595 let l1 = Array1::from(vec![1.0, 1.0]);
3596 let l2 = Array1::from(vec![1.0, 1.0, 1.0]);
3597 let err =
3598 identifiable_factor_select_weights(rss.view(), pen.view(), l1.view(), l2.view(), 8)
3599 .unwrap_err();
3600 assert!(err.contains("penalty_grid"));
3601 }
3602
3603 #[test]
3604 fn partial_supervision_procrustes_recovers_rotation_and_orthogonalizes_free() {
3605 // Construct a known orthogonal rotation Q, supervised slice = aux @ Qᵀ.
3606 let aux = array![
3607 [1.0_f64, 0.0, 0.0],
3608 [0.0, 1.0, 0.0],
3609 [0.0, 0.0, 1.0],
3610 [1.0, 1.0, 0.0],
3611 [-1.0, 1.0, 2.0],
3612 ];
3613 // 90° rotation in the (0,1) plane.
3614 let q = array![[0.0_f64, -1.0, 0.0], [1.0, 0.0, 0.0], [0.0, 0.0, 1.0]];
3615 let t_sup = aux.dot(&q.t());
3616 let t_free = array![
3617 [1.5_f64, 0.0],
3618 [0.0, 1.0],
3619 [-1.0, 2.0],
3620 [0.3, -0.7],
3621 [2.0, 1.0],
3622 ];
3623 let result = partial_supervision_solve(
3624 t_sup.view(),
3625 aux.view(),
3626 t_free.view(),
3627 PartialSupervisionSupMethod::Procrustes,
3628 &[],
3629 PartialSupervisionFreeConstraint::OrthogonalToSup,
3630 )
3631 .expect("procrustes solve should succeed");
3632 // Aligned supervised block should equal aux exactly (noise-free).
3633 for r in 0..aux.nrows() {
3634 for c in 0..aux.ncols() {
3635 assert!(
3636 (result.t_supervised[[r, c]] - aux[[r, c]]).abs() < 1.0e-10,
3637 "sup[{r},{c}] = {} vs aux {}",
3638 result.t_supervised[[r, c]],
3639 aux[[r, c]]
3640 );
3641 }
3642 }
3643 // Cross-Gram T_freeᵀ T_sup should be near zero after orthogonalization.
3644 let cross = result.t_free.t().dot(&result.t_supervised);
3645 let frob: f64 = cross.iter().map(|x| x * x).sum::<f64>().sqrt();
3646 assert!(frob < 1.0e-8, "cross frobenius = {frob}");
3647 assert!(result.alignment_score > 1.0 - 1.0e-10);
3648 assert!(result.map_r.is_some());
3649 }
3650
3651 #[test]
3652 fn partial_supervision_anchor_pins_exact_anchors_when_full_rank() {
3653 let aux = array![[1.0_f64, 2.0], [-1.0, 0.5], [3.0, -2.0], [0.7, 1.2],];
3654 let t_sup = array![[0.5_f64, 1.0], [-0.5, 0.25], [1.5, -1.0], [0.35, 0.6],];
3655 let t_free = Array2::<f64>::zeros((4, 1));
3656 let result = partial_supervision_solve(
3657 t_sup.view(),
3658 aux.view(),
3659 t_free.view(),
3660 PartialSupervisionSupMethod::Anchor,
3661 &[0, 1, 2],
3662 PartialSupervisionFreeConstraint::None,
3663 )
3664 .expect("anchor solve should succeed");
3665 for &row in &[0, 1, 2] {
3666 for c in 0..2 {
3667 assert!(
3668 (result.t_supervised[[row, c]] - aux[[row, c]]).abs() < 1.0e-9,
3669 "anchor row {row} col {c} not pinned: {} vs {}",
3670 result.t_supervised[[row, c]],
3671 aux[[row, c]]
3672 );
3673 }
3674 }
3675 assert!(result.map_a.is_some() && result.map_b.is_some());
3676 }
3677
3678 #[test]
3679 fn partial_supervision_softl2_selects_a_finite_weight() {
3680 let aux = array![
3681 [1.0_f64, 0.0],
3682 [0.0, 1.0],
3683 [1.0, 1.0],
3684 [-1.0, 1.0],
3685 [0.5, -0.5],
3686 ];
3687 let t_sup = array![
3688 [1.0_f64, 0.1],
3689 [0.1, 1.0],
3690 [1.0, 1.0],
3691 [-1.0, 1.0],
3692 [0.5, -0.5],
3693 ];
3694 let t_free = array![[0.5_f64], [0.5], [0.5], [0.5], [0.5]];
3695 let result = partial_supervision_solve(
3696 t_sup.view(),
3697 aux.view(),
3698 t_free.view(),
3699 PartialSupervisionSupMethod::SoftL2,
3700 &[],
3701 PartialSupervisionFreeConstraint::OrthogonalToSup,
3702 )
3703 .expect("soft_l2 solve should succeed");
3704 let lam = result.selected_weight.unwrap();
3705 assert!(lam.is_finite() && lam > 0.0, "lam={lam}");
3706 assert!(result.map_a.is_some());
3707 }
3708}