//! Analytic penalty primitives for the three-tier (beta / ext-coord / rho) engine.
//!
//! This module implements the structured penalties identified as the minimal
//! identifiability tools needed by an SAE / principal-manifold /
//! latent-coordinate workflow:
//!
//! * [`IsometryPenalty`] — pulls the decoder pullback metric
//! `J(t)^T W J(t)` toward a reference metric on the latent manifold. Lives
//! on the extension-coordinate tier (specifically on a
//! [`crate::latent::LatentCoordValues`] slice). Breaks the
//! diffeomorphism gauge so the inner Hessian on `t` is full-rank and the
//! implicit function theorem (IFT) applies.
//! * [`SparsityPenalty`] — smoothed L¹ (`sqrt(x² + ε²)`), Hoyer, or Log
//! sparsifier. Applied to a `β` slice (SAE codes) or an extension-coordinate
//! slice (soft atom amplitudes). Differentiable everywhere; the smoothing
//! parameter `ε` may itself live in `ρ` so REML shrinks it.
//! * [`IBPAssignmentPenalty`] — deterministic continuous-relaxation
//! Beta-Bernoulli/IBP prior over per-row SAE-manifold active sets.
//! * [`ARDPenalty`] — one penalty parameter per latent axis. The marginal
//! likelihood's Occam factor sends unused axes' precision to infinity,
//! discovering intrinsic dimension only after a separate gauge fix
//! (`AuxPrior` or `Isometry`) pins rotations / reparameterisations.
//! * [`TotalVariationPenalty`] — smoothed L¹ on first differences of a
//! latent coefficient block. This is coordinatewise/anisotropic TV: each
//! latent axis is penalized independently on every edge. Promotes
//! piecewise-constant atom maps.
//! * [`NuclearNormPenalty`] — smoothed L¹ on singular values of a matrix
//! latent block, `Σ_i (sqrt(σ_i² + ε²) - ε)`. Promotes low intrinsic rank
//! without choosing a canonical axis basis; in SAE wiring this is the
//! decoder-embedding rank-selection lever.
//! * [`BlockSparsityPenalty`] — group-lasso smoothed L¹ over predefined
//! latent-axis blocks. Unlike per-element L¹ or per-axis L² ARD, it
//! shrinks whole semantic groups together; pair with
//! `LatentIdMode::AuxPriorDimSelection` when aux classes define the active
//! group subset.
//! * [`RowPrecisionPriorPenalty`] — zero-mean Gaussian row-precision
//! prior on latent rows. This fixed-precomputed variant accepts one
//! precision matrix per row. It is not an iVAE conditional-mean gauge;
//! use `LatentIdMode::AuxPrior` for the ridge/linear projection residual.
//! * [`IvaeRidgeMeanGauge`] — iVAE-style conditional-mean gauge fixing:
//! penalizes the component of the latent field not explained by auxiliary
//! covariates via the ridge projection `U(UᵀU + εI)⁻¹Uᵀ`.
//! * [`ParametricRowPrecisionPriorPenalty`] — zero-mean Gaussian
//! row-precision prior with a learnable distance-kernel map from auxiliary
//! rows to diagonal per-row precision. It changes shrinkage strength, not
//! the conditional mean.
//! * [`OrthogonalityPenalty`] — fixes the rotation gauge inside a latent
//! block by penalizing cross-axis correlations. Pair with ARD when
//! intrinsic dimension should be identifiable.
//! * [`BlockOrthogonalityPenalty`] — penalizes only between-block
//! cross-products of latent axes, leaving within-block structure free.
//! * [`ScadMcpPenalty`] — elementwise nonconvex SCAD/MCP sparsity on
//! extension-coordinate latent blocks. Tapers the shrinkage derivative to
//! zero beyond the SCAD/MCP cutoff so large coefficients are not L¹-biased.
//! * [`DecoderIncoherencePenalty`] — β-tier SAE decoder penalty
//! `½·w·Σ_{j<k} W[j,k]·‖B_j B_k^T‖²_F` for stored decoder blocks
//! `B_k ∈ R^{M_k×p_out}`, with `W[j,k]` coming from co-activation.
//! Pushes co-firing atom decoder column spaces apart.
//!
//! All shipped primitives are **analytic**: no autograd, no finite
//! differencing. Each exposes:
//!
//! * `value(target, rho) -> f64`
//! * `grad_target(target, rho) -> Array1<f64>`
//! * `hessian_diag(target, rho) -> Array1<f64>` (when block-diagonal) or
//! `hvp(target, rho, v) -> Array1<f64>` (when not)
//! * `grad_rho(target, rho) -> Array1<f64>` (one entry per ρ-axis owned)
//!
//! The signatures are deliberately uniform with the existing smoothness path:
//! the quadratic ARD penalty produces a [`crate::smooth::BlockwisePenalty`]
//! that slots directly into the canonical-penalty pipeline, while the
//! non-quadratic Sparsity, TV, NuclearNorm, SCAD/MCP, Orthogonality,
//! DecoderIncoherence, and Isometry penalties produce [`AnalyticPenaltyOp`]
//! handles that downstream PIRLS / REML consumers query through the same
//! `value / gradient / hvp` interface they already use for smoothness.
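//!
//! As a minimal sketch of that interface, the block below writes out the
//! smoothed-L¹ case by hand. It is an illustration under stated assumptions,
//! not this module's `SparsityPenalty`: the `SmoothedL1` type is hypothetical,
//! plain slices and `Vec<f64>` stand in for `Array1<f64>`, and the strength is
//! assumed to enter as `μ = exp(ρ)`.
//!
//! ```rust
//! /// Hypothetical stand-in for a smoothed-L1 penalty, `μ · Σ sqrt(x² + ε²)`.
//! /// Assumption: one ρ-axis holding `log μ`; `ε` is fixed.
//! struct SmoothedL1 {
//!     eps: f64,
//! }
//!
//! impl SmoothedL1 {
//!     /// Penalty value `μ · Σ_i sqrt(x_i² + ε²)`.
//!     fn value(&self, target: &[f64], log_mu: f64) -> f64 {
//!         let e2 = self.eps * self.eps;
//!         log_mu.exp() * target.iter().map(|x| (x * x + e2).sqrt()).sum::<f64>()
//!     }
//!
//!     /// Gradient in the target: `μ · x_i / sqrt(x_i² + ε²)`.
//!     fn grad_target(&self, target: &[f64], log_mu: f64) -> Vec<f64> {
//!         let (mu, e2) = (log_mu.exp(), self.eps * self.eps);
//!         target.iter().map(|x| mu * x / (x * x + e2).sqrt()).collect()
//!     }
//!
//!     /// Diagonal Hessian in the target: `μ · ε² / (x_i² + ε²)^{3/2}`.
//!     fn hessian_diag(&self, target: &[f64], log_mu: f64) -> Vec<f64> {
//!         let (mu, e2) = (log_mu.exp(), self.eps * self.eps);
//!         target.iter().map(|x| mu * e2 / (x * x + e2).powf(1.5)).collect()
//!     }
//!
//!     /// Gradient in ρ. With `μ = exp(ρ)` the derivative equals the value.
//!     fn grad_rho(&self, target: &[f64], log_mu: f64) -> Vec<f64> {
//!         vec![self.value(target, log_mu)]
//!     }
//! }
//! ```
//!
//! The Hessian is diagonal in this case, so `hessian_diag` is enough; a
//! penalty that couples entries of the target would expose `hvp` instead.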
//!
//! ## Registration with REML
//!
//! Each penalty owns a (possibly empty) sub-range of the global `ρ` vector.
//! See [`AnalyticPenaltyKind::rho_count`]. The outer REML loop concatenates
//! these onto the existing per-smooth `ρ`s, exactly the way anisotropic
//! kernel-shape paths append ext-coords. The IsometryPenalty owns one `ρ`; the
//! SparsityPenalty owns one strength, plus one more axis when `ε` is
//! REML-selected rather than fixed; the ARDPenalty owns `d` (one per latent
//! axis); NuclearNorm, BlockSparsity, BlockOrthogonality, ScadMcp,
//! DecoderIncoherence, RowPrecisionPrior, IvaeRidgeMeanGauge, and Orthogonality
//! each own one strength only when their weight is learnable.
//! ParametricRowPrecisionPrior owns its log-baseline precision, raw distance
//! sensitivity, and reference point coordinates, plus one strength axis when
//! requested.
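//!
//! The concatenation described above can be sketched as follows. The helper
//! `assign_rho_ranges` and its arguments are hypothetical and only illustrate
//! the bookkeeping; in this module the per-penalty counts would come from
//! [`AnalyticPenaltyKind::rho_count`].
//!
//! ```rust
//! use std::ops::Range;
//!
//! /// Hypothetical helper: gives each penalty a contiguous sub-range of the
//! /// global ρ vector, appended after `n_smooth_rho` existing per-smooth
//! /// entries. A penalty with a count of zero receives an empty range.
//! fn assign_rho_ranges(n_smooth_rho: usize, rho_counts: &[usize]) -> Vec<Range<usize>> {
//!     let mut offset = n_smooth_rho;
//!     rho_counts
//!         .iter()
//!         .map(|&count| {
//!             let range = offset..offset + count;
//!             offset += count;
//!             range
//!         })
//!         .collect()
//! }
//! ```
//!
//! For example, with two per-smooth `ρ`s followed by an Isometry penalty (one
//! axis), a fixed-weight Orthogonality penalty (zero axes), and an ARD penalty
//! with `d = 3`, the counts `[1, 0, 3]` map to the ranges `2..3`, `3..3`, and
//! `3..6`.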
//!
//! ## Three-tier landings
//!
//! | Penalty                     | Target tier            | ρ-axes owned                 |
//! |-----------------------------|------------------------|------------------------------|
//! | Isometry                    | ext-coord (latent t)   | 1 (log μ_iso)                |
//! | Sparsity                    | β or ext-coord         | 1 (strength) [+1 ε]          |
//! | IBP                         | ext-coord (logits)     | 0 or 1 (log α)               |
//! | ARD                         | ext-coord (latent t)   | d (one per axis)             |
//! | TV                          | ext-coord (latent t)   | 0 or 1 (log μ_tv)            |
//! | NuclearNorm                 | ext-coord (latent t)   | 0 or 1 (log μ_nuc)           |
//! | BlockSparsity               | ext-coord (latent t)   | 0 or 1 (log μ_group)         |
//! | MechanismSparsity           | β (decoder W)          | 0 or 1 (log μ_mech)          |
//! | ScadMcp                     | ext-coord (latent t)   | 0 or 1 (log μ_scad_mcp)      |
//! | DecoderIncoherence          | β (SAE decoder blocks) | 0 or 1 (log μ_decoder_incoh) |
//! | RowPrecisionPrior           | ext-coord (latent t)   | 0 or 1 (log μ_aux)           |
//! | IvaeRidgeMeanGauge          | ext-coord (latent t)   | 0 or 1 (log μ_ivae_mean)     |
//! | ParametricRowPrecisionPrior | ext-coord (latent t)   | d + d + d·du [+1 log μ_aux]  |
//! | Orthogonality               | ext-coord (latent t)   | 0 or 1 (log μ_orth)          |
//! | BlockOrthogonality          | ext-coord (latent t)   | 0 or 1 (log μ_block_orth)    |
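//!
//! The `d + d + d·du [+1]` entry for ParametricRowPrecisionPrior can be read
//! as the count sketched below. This is one interpretation, assuming the three
//! groups are the log-baseline precisions (`d`), the raw distance
//! sensitivities (`d`), and the reference point coordinates (`d·du`), with
//! `du` the auxiliary-row width. The function name is hypothetical.
//!
//! ```rust
//! /// Hypothetical count of ρ-axes for a parametric row-precision prior with
//! /// `d` latent axes and `du` auxiliary covariates per row.
//! fn parametric_row_precision_rho_count(d: usize, du: usize, learn_strength: bool) -> usize {
//!     let log_baseline = d; // assumed: one log-baseline precision per axis
//!     let sensitivity = d; // assumed: one raw distance sensitivity per axis
//!     let reference = d * du; // assumed: one reference point in aux space per axis
//!     log_baseline + sensitivity + reference + usize::from(learn_strength)
//! }
//! ```
//!
//! With `d = 3` and `du = 2` this gives 12 axes, or 13 when the strength axis
//! is requested.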
// Re-exported so every concern submodule can pull the shared external imports
// through `use super::*;` without re-listing them.
pub use Side;
pub use ;
pub use ;
pub use PenaltyManifest;
pub use ;
pub use ;
pub use ;
pub use ;
pub use crate;
pub use ;
pub use crateBlockwisePenalty;
pub use *;
pub use *;
pub use *;
pub use *;
pub use *;
pub use *;
pub use *;
pub use *;
pub use *;
pub use *;
pub use *;
pub use *;
pub use *;
pub