//! Analytic penalty primitives for the three-tier (beta / ext-coord / rho) engine.
//!
//! This module implements the structured penalties identified as the minimal
//! identifiability tools needed by an SAE / principal-manifold /
//! latent-coordinate workflow:
//!
//! * [`IsometryPenalty`] — pulls the decoder pullback metric
//! `J(t)^T W J(t)` toward a reference metric on the latent manifold. Lives
//! on the extension-coordinate tier (specifically on a
//! [`crate::terms::latent_coord::LatentCoordValues`] slice). Breaks the
//! diffeomorphism gauge so the inner Hessian on `t` is full-rank and the
//! IFT is well-defined.
//! * [`SparsityPenalty`] — smoothed L¹ (`sqrt(x² + ε²)`), Hoyer, or Log
//! sparsifier. Applied to a `β` slice (SAE codes) or an extension-coordinate
//! slice (soft atom amplitudes). Differentiable everywhere; the smoothing
//! parameter `ε` may itself live in `ρ` so that REML can select it.
//! * [`IBPAssignmentPenalty`] — deterministic continuous-relaxation
//! Beta-Bernoulli/IBP prior over per-row SAE-manifold active sets.
//! * [`ARDPenalty`] — one penalty parameter per latent axis. The marginal
//! likelihood's Occam factor sends the precision of unused axes to infinity,
//! which reveals the intrinsic dimension, but only once a separate gauge fix
//! (`AuxPrior` or `Isometry`) pins rotations / reparameterisations.
//! * [`TotalVariationPenalty`] — smoothed L¹ on first differences of a
//! latent coefficient block. This is coordinatewise/anisotropic TV: each
//! latent axis is penalized independently on every edge. Promotes
//! piecewise-constant atom maps.
//! * [`NuclearNormPenalty`] — smoothed L¹ on singular values of a matrix
//! latent block, `Σ_i (sqrt(σ_i² + ε²) - ε)`. Promotes low intrinsic rank
//! without choosing a canonical axis basis; in SAE wiring this is the
//! decoder-embedding rank-selection lever.
//! * [`BlockSparsityPenalty`] — group-lasso smoothed L¹ over predefined
//! latent-axis blocks. Unlike per-element L¹ or per-axis L² ARD, it
//! shrinks whole semantic groups together; pair with
//! `LatentIdMode::AuxPriorDimSelection` when aux classes define the active
//! group subset.
//! * [`RowPrecisionPriorPenalty`] — zero-mean Gaussian row-precision
//! prior on latent rows. This fixed-precomputed variant accepts one
//! precision matrix per row. It is not an iVAE conditional-mean gauge;
//! use `LatentIdMode::AuxPrior` for the ridge/linear projection residual.
//! * [`IvaeRidgeMeanGauge`] — iVAE-style conditional-mean gauge fixing:
//! penalizes the component of the latent field not explained by auxiliary
//! covariates via the ridge projection `U(UᵀU + εI)⁻¹Uᵀ`.
//! * [`ParametricRowPrecisionPriorPenalty`] — zero-mean Gaussian
//! row-precision prior with a learnable distance-kernel map from auxiliary
//! rows to diagonal per-row precision. It changes shrinkage strength, not
//! the conditional mean.
//! * [`OrthogonalityPenalty`] — fixes the rotation gauge inside a latent
//! block by penalizing cross-axis correlations. Pair with ARD when
//! intrinsic dimension should be identifiable.
//! * [`BlockOrthogonalityPenalty`] — penalizes only between-block
//! cross-products of latent axes, leaving within-block structure free.
//! * [`ScadMcpPenalty`] — elementwise nonconvex SCAD/MCP sparsity on
//! extension-coordinate latent blocks. Tapers the shrinkage derivative to
//! zero beyond the SCAD/MCP cutoff so large coefficients are not L¹-biased.
//! * [`DecoderIncoherencePenalty`] — β-tier SAE decoder penalty
//! `½·w·Σ_{j<k} W[j,k]·‖B_j B_k^T‖²_F` for stored decoder blocks
//! `B_k ∈ R^{M_k×p_out}`, with `W[j,k]` coming from co-activation.
//! Pushes co-firing atom decoder column spaces apart.
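//!
//! To make the "analytic" claim concrete for the smoothed L¹ form
//! `sqrt(x² + ε²)` listed above, here is a minimal sketch of its value,
//! gradient, and diagonal Hessian. The function names are hypothetical, plain
//! slices stand in for the crate's array types, and `strength` is assumed to
//! enter as a plain multiplier; none of this is the shipped implementation.
//!
//! ```rust
//! /// Value: strength * Σ_i sqrt(x_i² + ε²).
//! fn smoothed_l1_value(x: &[f64], eps: f64, strength: f64) -> f64 {
//!     strength * x.iter().map(|&xi| (xi * xi + eps * eps).sqrt()).sum::<f64>()
//! }
//!
//! /// Gradient entry i: strength * x_i / sqrt(x_i² + ε²).
//! fn smoothed_l1_grad(x: &[f64], eps: f64, strength: f64) -> Vec<f64> {
//!     x.iter()
//!         .map(|&xi| strength * xi / (xi * xi + eps * eps).sqrt())
//!         .collect()
//! }
//!
//! /// Diagonal Hessian entry i: strength * ε² / (x_i² + ε²)^(3/2).
//! /// Strictly positive for ε > 0, so the penalty is smooth at x_i = 0.
//! fn smoothed_l1_hessian_diag(x: &[f64], eps: f64, strength: f64) -> Vec<f64> {
//!     x.iter()
//!         .map(|&xi| {
//!             let s2 = xi * xi + eps * eps;
//!             strength * eps * eps / (s2 * s2.sqrt())
//!         })
//!         .collect()
//! }
//! ```
//!
//! The Hessian is diagonal because the penalty is a sum of per-element terms,
//! which is the case where a `hessian_diag`-style accessor suffices.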
//!
//! All shipped primitives are **analytic**: no autograd, no finite
//! differencing. Each exposes:
//!
//! * `value(target, rho) -> f64`
//! * `grad_target(target, rho) -> Array1<f64>`
//! * `hessian_diag(target, rho) -> Array1<f64>` (when block-diagonal) or
//! `hvp(target, rho, v) -> Array1<f64>` (when not)
//! * `grad_rho(target, rho) -> Array1<f64>` (one entry per ρ-axis owned)
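//!
//! The four accessors above can be sketched for a per-axis quadratic (ARD-like)
//! penalty. Everything here is an assumption made for illustration: the trait
//! `AnalyticPenaltySketch` and struct `ArdSketch` are hypothetical, `Vec<f64>`
//! stands in for `Array1<f64>`, the target is taken to be a row-major `n × d`
//! block, and the penalty is assumed to be `½ Σ_a exp(ρ_a) Σ_n t[n,a]²` with
//! no normalisation term.
//!
//! ```rust
//! /// Hypothetical stand-in for the uniform penalty interface.
//! trait AnalyticPenaltySketch {
//!     fn value(&self, target: &[f64], rho: &[f64]) -> f64;
//!     fn grad_target(&self, target: &[f64], rho: &[f64]) -> Vec<f64>;
//!     fn hessian_diag(&self, target: &[f64], rho: &[f64]) -> Vec<f64>;
//!     fn grad_rho(&self, target: &[f64], rho: &[f64]) -> Vec<f64>;
//! }
//!
//! /// One log-precision per latent axis; `d` is the number of axes.
//! struct ArdSketch {
//!     d: usize,
//! }
//!
//! impl AnalyticPenaltySketch for ArdSketch {
//!     fn value(&self, target: &[f64], rho: &[f64]) -> f64 {
//!         target
//!             .iter()
//!             .enumerate()
//!             .map(|(i, &t)| 0.5 * rho[i % self.d].exp() * t * t)
//!             .sum()
//!     }
//!
//!     fn grad_target(&self, target: &[f64], rho: &[f64]) -> Vec<f64> {
//!         target
//!             .iter()
//!             .enumerate()
//!             .map(|(i, &t)| rho[i % self.d].exp() * t)
//!             .collect()
//!     }
//!
//!     fn hessian_diag(&self, target: &[f64], rho: &[f64]) -> Vec<f64> {
//!         (0..target.len()).map(|i| rho[i % self.d].exp()).collect()
//!     }
//!
//!     fn grad_rho(&self, target: &[f64], rho: &[f64]) -> Vec<f64> {
//!         // d(value)/d(rho_a) = ½ exp(rho_a) Σ_n t[n,a]²
//!         let mut g = vec![0.0; self.d];
//!         for (i, &t) in target.iter().enumerate() {
//!             let a = i % self.d;
//!             g[a] += 0.5 * rho[a].exp() * t * t;
//!         }
//!         g
//!     }
//! }
//! ```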
//!
//! The signatures are deliberately uniform with the existing smoothness path:
//! the quadratic ARD penalty produces a [`crate::terms::smooth::BlockwisePenalty`]
//! that slots directly into the canonical-penalty pipeline, while the
//! non-quadratic Sparsity, TV, NuclearNorm, SCAD/MCP, Orthogonality,
//! DecoderIncoherence, and Isometry penalties produce [`AnalyticPenaltyOp`]
//! handles that downstream PIRLS / REML consumers query through the same
//! `value / gradient / hvp` interface they already use for smoothness.
//!
//! ## Registration with REML
//!
//! Each penalty owns a (possibly empty) sub-range of the global `ρ` vector.
//! See [`AnalyticPenaltyKind::rho_count`]. The outer REML loop concatenates
//! these onto the existing per-smooth `ρ`s, exactly the way anisotropic
//! kernel-shape paths append ext-coords. The IsometryPenalty owns one `ρ`; the
//! SparsityPenalty owns one strength plus either zero (`ε` fixed) or one
//! (`ε` REML-selected) smoothing axis; the ARDPenalty owns `d` (one per latent
//! axis); NuclearNorm, BlockSparsity, BlockOrthogonality, ScadMcp,
//! DecoderIncoherence, RowPrecisionPrior, IvaeRidgeMeanGauge, and Orthogonality
//! each own one strength only when their weight is learnable.
//! ParametricRowPrecisionPrior owns its log-baseline precision, raw distance
//! sensitivity, and reference point coordinates, plus one strength axis when
//! requested.
//!
//! ## Three-tier landings
//!
//! | Penalty | Target tier | ρ-axes owned |
//! |-----------|-------------|----------------------|
//! | Isometry | ext-coord (latent t) | 1 (log μ_iso) |
//! | Sparsity | β or ext-coord | 1 (strength) [+1 ε] |
//! | IBP | ext-coord (logits) | 0 or 1 (log α) |
//! | ARD | ext-coord (latent t) | d (one per axis) |
//! | TV | ext-coord (latent t) | 0 or 1 (log μ_tv) |
//! | NuclearNorm | ext-coord (latent t) | 0 or 1 (log μ_nuc) |
//! | BlockSparsity | ext-coord (latent t) | 0 or 1 (log μ_group) |
//! | MechanismSparsity | β (decoder W) | 0 or 1 (log μ_mech) |
//! | ScadMcp | ext-coord (latent t) | 0 or 1 (log μ_scad_mcp) |
//! | DecoderIncoherence | β (SAE decoder blocks) | 0 or 1 (log μ_decoder_incoh) |
//! | RowPrecisionPrior | ext-coord (latent t) | 0 or 1 (log μ_aux) |
//! | IvaeRidgeMeanGauge | ext-coord (latent t) | 0 or 1 (log μ_ivae_mean) |
//! | ParametricRowPrecisionPrior | ext-coord (latent t) | d + d + d·du [+1 log μ_aux] |
//! | Orthogonality | ext-coord (latent t) | 0 or 1 (log μ_orth) |
//! | BlockOrthogonality | ext-coord (latent t) | 0 or 1 (log μ_block_orth) |
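//!
//! The ρ bookkeeping in the table can be sketched as follows. This is a
//! simplified illustration, not the crate's [`AnalyticPenaltyKind`]: the enum
//! `PenaltyKindSketch`, its variants, and the helper `rho_ranges` are
//! hypothetical, and `WeightOnly` lumps together every row whose count is
//! "0 or 1". Sub-ranges are appended after the per-smooth `ρ`s, as described
//! under "Registration with REML".
//!
//! ```rust
//! use std::ops::Range;
//!
//! /// Hypothetical, reduced set of penalty kinds for illustration only.
//! enum PenaltyKindSketch {
//!     Isometry,
//!     Sparsity { eps_learnable: bool },
//!     Ard { d: usize },
//!     WeightOnly { learnable: bool },
//!     ParametricRowPrecision { d: usize, du: usize, learnable_strength: bool },
//! }
//!
//! impl PenaltyKindSketch {
//!     /// Number of ρ-axes owned, following the table above.
//!     fn rho_count(&self) -> usize {
//!         match *self {
//!             Self::Isometry => 1,
//!             Self::Sparsity { eps_learnable } => 1 + usize::from(eps_learnable),
//!             Self::Ard { d } => d,
//!             Self::WeightOnly { learnable } => usize::from(learnable),
//!             Self::ParametricRowPrecision { d, du, learnable_strength } => {
//!                 d + d + d * du + usize::from(learnable_strength)
//!             }
//!         }
//!     }
//! }
//!
//! /// Assigns each penalty a contiguous sub-range of the global ρ vector,
//! /// starting right after the `n_smooth_rho` per-smooth entries.
//! fn rho_ranges(n_smooth_rho: usize, kinds: &[PenaltyKindSketch]) -> Vec<Range<usize>> {
//!     let mut start = n_smooth_rho;
//!     kinds
//!         .iter()
//!         .map(|k| {
//!             let end = start + k.rho_count();
//!             let r = start..end;
//!             start = end;
//!             r
//!         })
//!         .collect()
//! }
//! ```
//!
//! A penalty with a fixed weight gets an empty range, so consumers can slice
//! `ρ` uniformly without special-casing it.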
// Split from the original oversized module; keep included in order.
include!;
include!;