Skip to main content

gam_terms/analytic_penalties/
sparsity.rs

1use super::*;
2
3// ---------------------------------------------------------------------------
4// Sparsity penalty
5// ---------------------------------------------------------------------------
6
/// Sparsifier kernel.
///
/// * `SmoothedL1 { eps }` — `Σ_i sqrt(x_i² + ε²)`. The smoothing scale `ε`
///   may be REML-selected (`eps_rho_index = Some(_)`), in which case the
///   shrink rate `ε → 0` is governed by the marginal likelihood (Occam keeps
///   `ε` large when the data don't demand sharpness).
/// * `Hoyer` — `(‖x‖_1 / ‖x‖_2 − 1) / (√n − 1)`, the normalized L¹/L² ratio
///   evaluated by `SparsityPenalty::value`: `0` for a 1-sparse vector, `1`
///   for a dense equal-magnitude vector, and `0` when `‖x‖_2 = 0`.
///   Scale-invariant; encourages sparsity even when the global scale of `x`
///   drifts.
/// * `Log { delta }` — `Σ_i log(1 + x_i² / δ²)`. Nonconvex (its second
///   derivative is negative for `|x_i| > δ`); aggressive sparsifier suitable
///   for active-set / iterative-reweighted paths.
#[derive(Debug, Clone, Copy)]
pub enum SparsityKind {
    /// Smoothed-L¹ kernel. `eps` is the fixed smoothing scale `ε`, used when
    /// the penalty has no learnable `eps_rho_index`.
    SmoothedL1 { eps: f64 },
    /// Normalized L¹/L² ratio; carries no smoothing parameter.
    Hoyer,
    /// Log sparsifier. `delta` is the fixed scale `δ`, used when the penalty
    /// has no learnable `eps_rho_index`.
    Log { delta: f64 },
}
23
/// Sparsity penalty on a slice of β (SAE codes) or ext-coords (soft atom assignments).
///
/// The smoothed-L¹ default `Σ_i sqrt(x_i² + ε²)` is the simplest analytic
/// option. Its gradient is `x_i / sqrt(x_i² + ε²)` (a smooth sign function),
/// and its Hessian is diagonal with entries `ε² / (x_i² + ε²)^{3/2}` — so
/// `hvp` is cheap and the inner Newton step inherits a benign block-diagonal
/// regularizer.
///
/// When to use: any time a parameter block carries a "this should be sparse"
/// prior — SAE atom codes (β slice), soft-routing weights on a latent
/// ext-coordinate slice. For SAE codes specifically, smoothed-L¹ with REML-selected `ε`
/// gives the principled relaxation of the L¹ objective without giving up
/// differentiability.
#[derive(Debug, Clone)]
pub struct SparsityPenalty {
    /// Tier reported by `AnalyticPenalty::tier` for this penalty.
    pub target_tier: PenaltyTier,
    /// Which sparsifier kernel is evaluated, together with its fixed
    /// smoothing parameter (if the kernel has one).
    pub kind: SparsityKind,
    /// Base weight; combined with `rho[strength_rho_index]` through
    /// `resolve_learnable_weight` to obtain the effective strength. The
    /// constructors in this file initialise it to `1.0`.
    pub weight: f64,
    /// Optional schedule for `weight`.
    // NOTE(review): the schedule semantics live in the `impl_with_weight_schedule!`
    // / schedule macros, which are not visible here — confirm before relying on them.
    pub weight_schedule: Option<ScalarWeightSchedule>,
    /// Index of `log strength` inside this penalty's local ρ view.
    pub strength_rho_index: usize,
    /// If `Some`, the index of `log ε` (or `log δ`) inside this penalty's
    /// local ρ view. If `None`, `ε` / `δ` is held fixed at the value baked
    /// into [`SparsityKind`].
    pub eps_rho_index: Option<usize>,
}
50
/// Entropy sparsity over row-wise softmax assignment logits.
///
/// This is the SAE-manifold soft-assignment penalty. The target is a flat
/// row-major `(N, K)` logit matrix. Assignments are
/// `a_i = softmax(logits_i / temperature)`, and the penalty is
///
/// ```text
///   lambda_sparse * sum_i H(a_i)
///   H(a_i) = -sum_k a_ik log a_ik
/// ```
///
/// Minimizing entropy drives each row toward a small active support while the
/// softmax keeps `a_ik >= 0` and `sum_k a_ik = 1`. The exact Hessian is dense
/// in each row and can be indefinite because entropy is concave in assignment
/// space, so callers must use the HVP rather than a diagonal Hessian shortcut.
#[derive(Debug, Clone)]
pub struct SoftmaxAssignmentSparsityPenalty {
    /// Row width `K`: the number of logits (atoms) per row of the flat
    /// row-major target. `new` requires it to be positive.
    pub k_atoms: usize,
    /// Softmax temperature `τ`; logits are divided by it before the softmax.
    /// `new` requires `temperature > 0.0`.
    pub temperature: f64,
    /// Base weight; combined with `rho[0]` through `resolve_learnable_weight`
    /// to obtain the effective `lambda_sparse`. `new` initialises it to `1.0`.
    pub weight: f64,
    /// Optional schedule for `weight`.
    // NOTE(review): the schedule semantics live in the `impl_with_weight_schedule!`
    // / `impl_scalar_apply_schedule!` macros, which are not visible here — confirm.
    pub weight_schedule: Option<ScalarWeightSchedule>,
}
73
74impl SoftmaxAssignmentSparsityPenalty {
75    #[must_use]
76    pub fn new(k_atoms: usize, temperature: f64) -> Self {
77        assert!(k_atoms > 0);
78        assert!(temperature > 0.0);
79        Self {
80            k_atoms,
81            temperature,
82            weight: 1.0,
83            weight_schedule: None,
84        }
85    }
86
87    impl_with_weight_schedule!(weight);
88
89    fn softmax_row(&self, row: &[f64]) -> Vec<f64> {
90        let inv_tau = 1.0 / self.temperature;
91        let mut max_logit = f64::NEG_INFINITY;
92        for (idx, &v) in row.iter().enumerate() {
93            assert!(
94                v.is_finite(),
95                "SoftmaxAssignmentSparsityPenalty: non-finite logit at atom {idx}: {v}"
96            );
97            max_logit = max_logit.max(v);
98        }
99        let mut out = vec![0.0; self.k_atoms];
100        let mut sum = 0.0;
101        for i in 0..self.k_atoms {
102            let v = ((row[i] - max_logit) * inv_tau).exp();
103            out[i] = v;
104            sum += v;
105        }
106        assert!(
107            sum.is_finite() && sum > 0.0,
108            "SoftmaxAssignmentSparsityPenalty: non-finite softmax normalizer"
109        );
110        for v in out.iter_mut() {
111            *v /= sum;
112        }
113        out
114    }
115
116    /// Absolute row sums of the exact per-row dense entropy Hessian, used as a
117    /// Gershgorin / diagonal-dominance PSD majorizer.
118    ///
119    /// The exact per-row Hessian wrt logits (symmetric, dense) is
120    ///
121    /// ```text
122    ///   H_kj = (λ/τ²)·a_k·[ δ_kj·(m − L_k − 1) + a_j·(L_k + L_j + 1 − 2m) ],
123    ///   L_k = ln a_k + 1,   m = Σ_j a_j L_j,
124    /// ```
125    ///
126    /// whose diagonal coincides with [`AnalyticPenalty::hessian_diag`]. Entropy
127    /// is concave in assignment space, so this block is indefinite (negative on
128    /// near-uniform rows). Setting `D_kk = Σ_j |H_kj|` makes `D − H` symmetric
129    /// with nonnegative diagonal and diagonally dominant
130    /// (`D_kk − H_kk = |H_kk| − H_kk + Σ_{j≠k}|H_kj| ≥ Σ_{j≠k}|(D−H)_kj|`),
131    /// hence PSD: `D ⪰ H` and `D ⪰ 0` both hold. `D` is a genuine PSD diagonal
132    /// operator that dominates the dense Hessian's quadratic form — unlike the
133    /// raw indefinite diagonal, which is neither PSD nor a faithful stand-in for
134    /// the dense operator.
135    pub fn psd_majorizer_abs_row_sums(&self, row: &[f64], scale: f64) -> Vec<f64> {
136        let a = self.softmax_row(row);
137        let k = self.k_atoms;
138        let l: Vec<f64> = (0..k)
139            .map(|i| a[i].max(ENTROPY_LOG_PROBABILITY_FLOOR).ln() + 1.0)
140            .collect();
141        let m: f64 = (0..k).map(|i| a[i] * l[i]).sum();
142        let mut d = vec![0.0_f64; k];
143        for kk in 0..k {
144            // Diagonal entry H_kk.
145            let h_kk = scale * a[kk] * ((m - l[kk] - 1.0) + a[kk] * (2.0 * l[kk] + 1.0 - 2.0 * m));
146            let mut acc = h_kk.abs();
147            // Off-diagonal entries H_kj, j ≠ k.
148            for jj in 0..k {
149                if jj == kk {
150                    continue;
151                }
152                let h_kj = scale * a[kk] * a[jj] * (l[kk] + l[jj] + 1.0 - 2.0 * m);
153                acc += h_kj.abs();
154            }
155            d[kk] = acc;
156        }
157        d
158    }
159
160    /// Exact per-row dense softmax-entropy Hessian wrt the row's logits (#1038),
161    /// scaled by `scale = λ/τ²`. Returns the symmetric `K×K` block
162    ///
163    /// ```text
164    ///   H_kj = scale·a_k·[ δ_kj·(m − L_k − 1) + a_j·(L_k + L_j + 1 − 2m) ],
165    ///   L_k = ln a_k + 1,   m = Σ_r a_r L_r,
166    /// ```
167    ///
168    /// whose diagonal coincides with [`AnalyticPenalty::hessian_diag`] and whose
169    /// quadratic form coincides with [`AnalyticPenalty::hvp`]. This is the dense
170    /// block the Arrow-Schur row factor stores so the criterion's `log|H|` and
171    /// the #1006 θ-adjoint differentiate the SAME operator (not just its
172    /// diagonal). The entropy block alone is gauge-null (`H·𝟙 = 0`, softmax
173    /// shift-invariance); callers must add it to the gauge-breaking data-fit
174    /// row block before factoring — never factor it in isolation.
175    #[must_use]
176    pub fn row_dense_hessian(&self, row_logits: &[f64], scale: f64) -> Array2<f64> {
177        let k = self.k_atoms;
178        let a = self.softmax_row(row_logits);
179        let l: Vec<f64> = (0..k)
180            .map(|i| a[i].max(ENTROPY_LOG_PROBABILITY_FLOOR).ln() + 1.0)
181            .collect();
182        let m: f64 = (0..k).map(|i| a[i] * l[i]).sum();
183        let mut h = Array2::<f64>::zeros((k, k));
184        for kk in 0..k {
185            for jj in 0..k {
186                let indicator = if kk == jj { 1.0 } else { 0.0 };
187                h[[kk, jj]] = scale
188                    * a[kk]
189                    * (indicator * (m - l[kk] - 1.0) + a[jj] * (l[kk] + l[jj] + 1.0 - 2.0 * m));
190            }
191        }
192        h
193    }
194
195    /// Derivative of the exact per-row dense entropy Hessian
196    /// [`Self::row_dense_hessian`] with respect to a single row logit `z_w`,
197    /// scaled by `scale = λ/τ²`. Returns the symmetric `K×K` block
198    /// `∂H_kj/∂z_w`, the third-derivative tensor slice the #1006 θ-adjoint
199    /// contracts against the row's selected inverse. Built from the SAME
200    /// `(a, L, m)` as [`Self::row_dense_hessian`] (`∂a_r/∂z_w = a_r(δ_rw − a_w)/τ`),
201    /// so value, logdet and adjoint stay on one branch.
202    #[must_use]
203    pub fn row_dense_hessian_logit_derivative(
204        &self,
205        row_logits: &[f64],
206        scale: f64,
207        w: usize,
208    ) -> Array2<f64> {
209        let k = self.k_atoms;
210        let inv_tau = 1.0 / self.temperature;
211        let a = self.softmax_row(row_logits);
212        let l: Vec<f64> = (0..k)
213            .map(|i| a[i].max(ENTROPY_LOG_PROBABILITY_FLOOR).ln() + 1.0)
214            .collect();
215        let m: f64 = (0..k).map(|i| a[i] * l[i]).sum();
216        // ∂a_r/∂z_w = a_r (δ_rw − a_w)/τ ; ∂L_r/∂z_w = (∂a_r/∂z_w)/a_r.
217        let da: Vec<f64> = (0..k)
218            .map(|r| a[r] * (if r == w { 1.0 } else { 0.0 } - a[w]) * inv_tau)
219            .collect();
220        let dl: Vec<f64> = (0..k)
221            .map(|r| da[r] / a[r].max(ENTROPY_LOG_PROBABILITY_FLOOR))
222            .collect();
223        let dm: f64 = (0..k).map(|r| da[r] * l[r] + a[r] * dl[r]).sum();
224        let mut dh = Array2::<f64>::zeros((k, k));
225        for kk in 0..k {
226            for jj in 0..k {
227                let indicator = if kk == jj { 1.0 } else { 0.0 };
228                // bracket = δ_kj(m − L_k − 1) + a_j(L_k + L_j + 1 − 2m).
229                let bracket =
230                    indicator * (m - l[kk] - 1.0) + a[jj] * (l[kk] + l[jj] + 1.0 - 2.0 * m);
231                let dbracket = indicator * (dm - dl[kk])
232                    + da[jj] * (l[kk] + l[jj] + 1.0 - 2.0 * m)
233                    + a[jj] * (dl[kk] + dl[jj] - 2.0 * dm);
234                dh[[kk, jj]] = scale * (da[kk] * bracket + a[kk] * dbracket);
235            }
236        }
237        dh
238    }
239
240    /// Per-row **Gershgorin diagonal majorizer** `D` of the exact softmax-entropy
241    /// Hessian [`Self::row_dense_hessian`], scaled by `scale = λ/τ²`. Returns the
242    /// `K×K` diagonal block `diag(D_0, …, D_{K−1})` with
243    /// `D_kk = Σ_j |H_kj|` (#1419).
244    ///
245    /// Unlike the Fisher metric [`Self::row_fisher_metric`] — which is PSD but
246    /// does NOT satisfy `G ⪰ H_entropy` (counterexample `a=(0.95,0.05)`,
247    /// `λ=τ=1`: `G₁₁=0.0475 < H₁₁=0.0784`) — this `D` is a genuine Loewner
248    /// majorizer: it is diagonally dominant over `H` (`D_kk − H_kk =
249    /// |H_kk|−H_kk + Σ_{j≠k}|H_kj| ≥ Σ_{j≠k}|(D−H)_kj|`), so `D − H ⪰ 0`, and
250    /// every `D_kk ≥ 0`, so `D ⪰ 0`. It therefore both keeps the assembled
251    /// evidence block PD (the property the entropy block needs so the
252    /// Faddeev–Popov deflation never fires) AND actually majorizes the entropy
253    /// curvature, which the Fisher surrogate did not. The criterion's `log|H|`,
254    /// its θ-adjoint [`Self::row_psd_majorizer_logit_derivative`], and the
255    /// assembled Hessian all differentiate this SAME operator `D`, keeping value
256    /// and adjoint on one exact branch.
257    #[must_use]
258    pub fn row_psd_majorizer(&self, row_logits: &[f64], scale: f64) -> Array2<f64> {
259        let k = self.k_atoms;
260        let d = self.psd_majorizer_abs_row_sums(row_logits, scale);
261        let mut out = Array2::<f64>::zeros((k, k));
262        for kk in 0..k {
263            out[[kk, kk]] = d[kk];
264        }
265        out
266    }
267
268    /// Derivative of the per-row Gershgorin majorizer [`Self::row_psd_majorizer`]
269    /// with respect to a single row logit `z_w`, scaled by `scale = λ/τ²`.
270    /// Returns the `K×K` diagonal block `diag(∂D_0/∂z_w, …)` with
271    /// `∂D_kk/∂z_w = Σ_j sign(H_kj)·(∂H_kj/∂z_w)` (#1419), where `H` is the exact
272    /// entropy Hessian [`Self::row_dense_hessian`] and `∂H_kj/∂z_w` is
273    /// [`Self::row_dense_hessian_logit_derivative`]. `sign(0)=0` (a zero entry
274    /// contributes no first-order change to its own magnitude). Built from the
275    /// SAME `(a, L, m)` derivative convention as the dense Hessian derivative, so
276    /// the θ-adjoint differentiates the SAME `D` the assembly added.
277    #[must_use]
278    pub fn row_psd_majorizer_logit_derivative(
279        &self,
280        row_logits: &[f64],
281        scale: f64,
282        w: usize,
283    ) -> Array2<f64> {
284        let k = self.k_atoms;
285        let h = self.row_dense_hessian(row_logits, scale);
286        let dh = self.row_dense_hessian_logit_derivative(row_logits, scale, w);
287        let mut out = Array2::<f64>::zeros((k, k));
288        for kk in 0..k {
289            let mut acc = 0.0_f64;
290            for jj in 0..k {
291                let s = h[[kk, jj]].signum();
292                if h[[kk, jj]] != 0.0 {
293                    acc += s * dh[[kk, jj]];
294                }
295            }
296            out[[kk, kk]] = acc;
297        }
298        out
299    }
300
301    /// Per-row softmax **Fisher-information metric** `G = scale·(diag(a) − a aᵀ)`
302    /// over the row's logits, with `a = softmax(row_logits)` and
303    /// `scale = λ/τ²` (#1190). Returns the symmetric `K×K` block
304    ///
305    /// ```text
306    ///   G_kj = scale·a_k·(δ_kj − a_j).
307    /// ```
308    ///
309    /// `G` is a covariance/Gram matrix, hence exactly PSD and smooth in the
310    /// logits. It is the Fisher-information metric of the row softmax, NOT a
311    /// curvature majorizer of the entropy Hessian: `G − H_entropy` can be
312    /// indefinite (#1419: `K=2`, `a=(0.95,0.05)`, `λ=τ=1` gives `G₁₁=0.0475 <
313    /// H₁₁=0.0784`, so `G ⋡ H`). The genuine Loewner majorizer the assembled
314    /// evidence block now uses is [`Self::row_psd_majorizer`]
315    /// (`D_kk = Σ_j|H_kj|`, which DOES satisfy `D ⪰ H` and `D ⪰ 0`); this
316    /// Fisher metric is retained only as a smooth PSD conditioning reference and
317    /// its derivative [`Self::row_fisher_metric_logit_derivative`], and must not
318    /// be presented or used as a curvature majorizer.
319    #[must_use]
320    pub fn row_fisher_metric(&self, row_logits: &[f64], scale: f64) -> Array2<f64> {
321        let k = self.k_atoms;
322        let a = self.softmax_row(row_logits);
323        let mut g = Array2::<f64>::zeros((k, k));
324        for kk in 0..k {
325            for jj in 0..k {
326                let indicator = if kk == jj { 1.0 } else { 0.0 };
327                g[[kk, jj]] = scale * a[kk] * (indicator - a[jj]);
328            }
329        }
330        g
331    }
332
333    /// Derivative of the per-row softmax Fisher metric
334    /// [`Self::row_fisher_metric`] with respect to a single row logit `z_w`,
335    /// scaled by `scale = λ/τ²` (#1190). Returns the symmetric `K×K` block
336    /// `∂G_kj/∂z_w`, the third-derivative tensor slice the θ-adjoint contracts
337    /// against the row's selected inverse so the adjoint differentiates the SAME
338    /// PSD `G = scale·(diag(a) − a aᵀ)` the assembly added (value/adjoint on one
339    /// branch, no deflation needed). Built from the SAME softmax derivative
340    /// convention as [`Self::row_dense_hessian_logit_derivative`]
341    /// (`∂a_r/∂z_w = a_r(δ_rw − a_w)/τ`). For `G_kj = scale·a_k(δ_kj − a_j)`,
342    /// the product rule gives
343    /// `∂G_kj/∂z_w = scale·[ (∂a_k/∂z_w)(δ_kj − a_j) − a_k(∂a_j/∂z_w) ]`.
344    #[must_use]
345    pub fn row_fisher_metric_logit_derivative(
346        &self,
347        row_logits: &[f64],
348        scale: f64,
349        w: usize,
350    ) -> Array2<f64> {
351        let k = self.k_atoms;
352        let inv_tau = 1.0 / self.temperature;
353        let a = self.softmax_row(row_logits);
354        // ∂a_r/∂z_w = a_r (δ_rw − a_w)/τ — identical convention to the entropy
355        // Hessian derivative above.
356        let da: Vec<f64> = (0..k)
357            .map(|r| a[r] * (if r == w { 1.0 } else { 0.0 } - a[w]) * inv_tau)
358            .collect();
359        let mut dg = Array2::<f64>::zeros((k, k));
360        for kk in 0..k {
361            for jj in 0..k {
362                let indicator = if kk == jj { 1.0 } else { 0.0 };
363                dg[[kk, jj]] = scale * (da[kk] * (indicator - a[jj]) - a[kk] * da[jj]);
364            }
365        }
366        dg
367    }
368}
369
370impl AnalyticPenalty for SoftmaxAssignmentSparsityPenalty {
    /// This penalty always reports the `Psi` tier.
    fn tier(&self) -> PenaltyTier {
        PenaltyTier::Psi
    }
374
375    fn value(&self, target: ArrayView1<'_, f64>, rho: ArrayView1<'_, f64>) -> f64 {
376        let lambda = resolve_learnable_weight(self.weight, rho[0]);
377        let n = target.len() / self.k_atoms;
378        let values: Vec<f64> = target.iter().copied().collect();
379        let mut acc = 0.0;
380        for row in 0..n {
381            let start = row * self.k_atoms;
382            let a = self.softmax_row(&values[start..start + self.k_atoms]);
383            for v in a {
384                if v > 0.0 {
385                    acc += -v * v.ln();
386                }
387            }
388        }
389        lambda * acc
390    }
391
392    fn grad_target(&self, target: ArrayView1<'_, f64>, rho: ArrayView1<'_, f64>) -> Array1<f64> {
393        let lambda = resolve_learnable_weight(self.weight, rho[0]);
394        let n = target.len() / self.k_atoms;
395        let values: Vec<f64> = target.iter().copied().collect();
396        let mut out = Array1::<f64>::zeros(target.len());
397        let inv_tau = 1.0 / self.temperature;
398        for row in 0..n {
399            let start = row * self.k_atoms;
400            let a = self.softmax_row(&values[start..start + self.k_atoms]);
401            let mut d_h_da = vec![0.0; self.k_atoms];
402            let mut mean = 0.0;
403            for k in 0..self.k_atoms {
404                let ak = a[k].max(ENTROPY_LOG_PROBABILITY_FLOOR);
405                d_h_da[k] = -lambda * (ak.ln() + 1.0);
406                mean += a[k] * d_h_da[k];
407            }
408            for k in 0..self.k_atoms {
409                out[start + k] = a[k] * (d_h_da[k] - mean) * inv_tau;
410            }
411        }
412        out
413    }
414
415    fn hessian_diag(
416        &self,
417        target: ArrayView1<'_, f64>,
418        rho: ArrayView1<'_, f64>,
419    ) -> Option<Array1<f64>> {
420        assert_eq!(rho.len(), 1, "softmax entropy expects one rho parameter");
421        assert!(
422            rho.iter().all(|value| value.is_finite()),
423            "softmax entropy rho must be finite"
424        );
425        assert_eq!(
426            target.len() % self.k_atoms,
427            0,
428            "softmax entropy target length must be divisible by k_atoms"
429        );
430        // Closed-form diagonal of the softmax-entropy Hessian wrt logits.
431        // Derived by probing the row-dense HVP with the unit vector e_k:
432        // for a row with softmax weights a_k and L_k = ln a_k + 1,
433        //   H_kk = (lambda / tau^2) * a_k *
434        //          ((1 - 2 a_k) * (E_a[L] - L_k) + a_k - 1).
435        // This matches `hvp(...) . e_k` analytically (see derivation in the
436        // bug-fix comment on `hvp`) and gives Newton/Arrow-Schur callers a
437        // principled diagonal surrogate without per-row dense factorization.
438        let lambda = resolve_learnable_weight(self.weight, rho[0]);
439        let inv_tau = 1.0 / self.temperature;
440        let scale = lambda * inv_tau * inv_tau;
441        let n = target.len() / self.k_atoms;
442        let values: Vec<f64> = target.iter().copied().collect();
443        let mut out = Array1::<f64>::zeros(target.len());
444        for row in 0..n {
445            let start = row * self.k_atoms;
446            let a = self.softmax_row(&values[start..start + self.k_atoms]);
447            let mut mean_log_plus_one = 0.0;
448            for k in 0..self.k_atoms {
449                mean_log_plus_one += a[k] * (a[k].max(ENTROPY_LOG_PROBABILITY_FLOOR).ln() + 1.0);
450            }
451            for k in 0..self.k_atoms {
452                let log_plus_one = a[k].max(ENTROPY_LOG_PROBABILITY_FLOOR).ln() + 1.0;
453                let term = (1.0 - 2.0 * a[k]) * (mean_log_plus_one - log_plus_one) + a[k] - 1.0;
454                out[start + k] = scale * a[k] * term;
455            }
456        }
457        Some(out)
458    }
459
460    fn hvp(
461        &self,
462        target: ArrayView1<'_, f64>,
463        rho: ArrayView1<'_, f64>,
464        v: ArrayView1<'_, f64>,
465    ) -> Array1<f64> {
466        /*
467        Softmax entropy is not coordinate-separable in logits. The old
468        `hessian_diag` returned λ p_k(1-p_k)/τ², which is only the softmax
469        Jacobian diagonal and omits the entropy curvature and all cross-logit
470        terms. For H(p(z)), p'=p*(v-E_p[v])/τ and
471        (log p_k + 1)'=(v_k-E_p[v])/τ. Differentiating
472        g_k=λ p_k(E_p[log p + 1]-(log p_k+1))/τ gives the row-dense product
473        below. `hessian_diag` returns the analytic diagonal extracted from
474        this HVP by setting v = e_k row-by-row.
475        */
476        let lambda = resolve_learnable_weight(self.weight, rho[0]);
477        assert_eq!(target.len(), v.len(), "hvp dimension mismatch");
478        let n = target.len() / self.k_atoms;
479        let values: Vec<f64> = target.iter().copied().collect();
480        let mut out = Array1::<f64>::zeros(target.len());
481        let inv_tau = 1.0 / self.temperature;
482        let scale = lambda * inv_tau * inv_tau;
483        for row in 0..n {
484            let start = row * self.k_atoms;
485            let a = self.softmax_row(&values[start..start + self.k_atoms]);
486            let mut mean_log_plus_one = 0.0;
487            let mut mean_v = 0.0;
488            for k in 0..self.k_atoms {
489                mean_log_plus_one += a[k] * (a[k].max(ENTROPY_LOG_PROBABILITY_FLOOR).ln() + 1.0);
490                mean_v += a[k] * v[start + k];
491            }
492            let mut mean_centered_v_log_plus_one = 0.0;
493            for k in 0..self.k_atoms {
494                let centered_v = v[start + k] - mean_v;
495                mean_centered_v_log_plus_one +=
496                    a[k] * centered_v * (a[k].max(ENTROPY_LOG_PROBABILITY_FLOOR).ln() + 1.0);
497            }
498            for k in 0..self.k_atoms {
499                let log_plus_one = a[k].max(ENTROPY_LOG_PROBABILITY_FLOOR).ln() + 1.0;
500                let centered_v = v[start + k] - mean_v;
501                out[start + k] = scale
502                    * a[k]
503                    * (centered_v * (mean_log_plus_one - log_plus_one - 1.0)
504                        + mean_centered_v_log_plus_one);
505            }
506        }
507        out
508    }
509
510    fn psd_majorizer_diag(
511        &self,
512        target: ArrayView1<'_, f64>,
513        rho: ArrayView1<'_, f64>,
514    ) -> Option<Array1<f64>> {
515        assert_eq!(rho.len(), 1, "softmax entropy expects one rho parameter");
516        assert_eq!(
517            target.len() % self.k_atoms,
518            0,
519            "softmax entropy target length must be divisible by k_atoms"
520        );
521        // Entropy minimization is nonconvex: the exact per-row Hessian is dense
522        // and indefinite, so the convex-only trait default (which returns the
523        // raw indefinite `hessian_diag`) violates the `B ⪰ 0` contract and is a
524        // diagonal masquerading as a dense operator. Replace it with the
525        // Gershgorin / diagonal-dominance majorizer of the dense per-row block
526        // (see `psd_majorizer_abs_row_sums`): a genuine PSD diagonal with
527        // `D ⪰ H` and `D ⪰ 0`. Coordinate-indexed, so the inherited
528        // `psd_majorizer_hvp` applies `D` as a diagonal operator consistently.
529        let lambda = resolve_learnable_weight(self.weight, rho[0]);
530        let inv_tau = 1.0 / self.temperature;
531        let scale = lambda * inv_tau * inv_tau;
532        let n = target.len() / self.k_atoms;
533        let values: Vec<f64> = target.iter().copied().collect();
534        let mut out = Array1::<f64>::zeros(target.len());
535        for row in 0..n {
536            let start = row * self.k_atoms;
537            let d = self.psd_majorizer_abs_row_sums(&values[start..start + self.k_atoms], scale);
538            for k in 0..self.k_atoms {
539                out[start + k] = d[k];
540            }
541        }
542        Some(out)
543    }
544
545    fn grad_rho(&self, target: ArrayView1<'_, f64>, rho: ArrayView1<'_, f64>) -> Array1<f64> {
546        Array1::from_vec(vec![self.value(target, rho)])
547    }
548
    /// Number of ρ parameters this penalty reads: one (`rho[0]`, the
    /// learnable weight).
    fn rho_count(&self) -> usize {
        1
    }
552
    /// Fixed identifier string for this penalty.
    fn name(&self) -> &str {
        "softmax_assignment_sparsity"
    }
556
557    impl_scalar_apply_schedule!(weight);
558}
559
560impl SparsityPenalty {
561    #[must_use = "build error must be handled"]
562    pub fn smoothed_l1(target_tier: PenaltyTier, eps: f64) -> Result<Self, String> {
563        if !(eps.is_finite() && eps > 0.0) {
564            return Err(format!(
565                "SparsityPenalty::smoothed_l1 requires eps > 0 \
566                 (Hessian / gradient have a `1/sqrt(x² + eps²)` factor that needs eps > 0 \
567                 for differentiability at x = 0); got eps = {eps}"
568            ));
569        }
570        Ok(Self {
571            target_tier,
572            kind: SparsityKind::SmoothedL1 { eps },
573            weight: 1.0,
574            weight_schedule: None,
575            strength_rho_index: 0,
576            eps_rho_index: None,
577        })
578    }
579
580    #[must_use = "build error must be handled"]
581    pub fn log(target_tier: PenaltyTier, delta: f64) -> Result<Self, String> {
582        if !(delta.is_finite() && delta > 0.0) {
583            return Err(format!(
584                "SparsityPenalty::log requires delta > 0 \
585                 (the log-sparsifier is log(1 + x²/δ²), undefined at δ = 0); \
586                 got delta = {delta}"
587            ));
588        }
589        Ok(Self {
590            target_tier,
591            kind: SparsityKind::Log { delta },
592            weight: 1.0,
593            weight_schedule: None,
594            strength_rho_index: 0,
595            eps_rho_index: None,
596        })
597    }
598
599    /// Hoyer scale-invariant sparsifier. Requires a target of length > 1
600    /// because the normalized form divides by `sqrt(n) - 1`.
601    #[must_use]
602    pub fn hoyer(target_tier: PenaltyTier) -> Self {
603        Self {
604            target_tier,
605            kind: SparsityKind::Hoyer,
606            weight: 1.0,
607            weight_schedule: None,
608            strength_rho_index: 0,
609            eps_rho_index: None,
610        }
611    }
612
613    impl_with_weight_schedule!(weight);
614
615    #[must_use]
616    pub fn with_eps_reml(mut self, eps_rho_index: usize) -> Self {
617        self.eps_rho_index = Some(eps_rho_index);
618        self
619    }
620
621    /// Resolve `(strength, eps_or_delta)` from the current ρ view.
622    fn resolved(&self, rho: ArrayView1<'_, f64>) -> (f64, f64) {
623        let strength = resolve_learnable_weight(self.weight, rho[self.strength_rho_index]);
624        let smoothing = match (self.eps_rho_index, self.kind) {
625            // A learnable smoothing `exp(rho)` underflows to exact `0.0` for
626            // `rho ≲ -745`, which reintroduces a non-differentiable kink and a
627            // `0/0` at `x = 0` in `sqrt(x² + ε²)` / the Log sparsifier. Floor it
628            // at the smallest positive normal so the smoothing stays strictly
629            // positive while still shrinking arbitrarily close to zero.
630            (Some(idx), _) => rho[idx].exp().max(f64::MIN_POSITIVE),
631            (None, SparsityKind::SmoothedL1 { eps }) => eps,
632            (None, SparsityKind::Log { delta }) => delta,
633            (None, SparsityKind::Hoyer) => 0.0,
634        };
635        (strength, smoothing)
636    }
637}
638
639impl AnalyticPenalty for SparsityPenalty {
    /// Reports the tier this penalty was constructed with (`target_tier`).
    fn tier(&self) -> PenaltyTier {
        self.target_tier
    }
643
644    fn value(&self, target: ArrayView1<'_, f64>, rho: ArrayView1<'_, f64>) -> f64 {
645        let (lam, smooth) = self.resolved(rho);
646        match self.kind {
647            SparsityKind::SmoothedL1 { .. } => {
648                let mut acc = 0.0;
649                for &x in target.iter() {
650                    acc += (x * x + smooth * smooth).sqrt();
651                }
652                lam * acc
653            }
654            SparsityKind::Hoyer => {
655                // Normalized anti-sparsity penalty
656                //   P(x) = (||x||_1 / ||x||_2 - 1) / (sqrt(n) - 1)
657                // maps [1, sqrt(n)] -> [0, 1]. A perfectly dense
658                // equal-magnitude vector hits ||x||_1/||x||_2 = sqrt(n),
659                // so P = 1; a 1-sparse vector has ratio 1, so P = 0
660                // (sparse vectors minimize the penalty).
661                let n = target.len() as f64;
662                assert!(n > 1.0, "Hoyer requires n > 1");
663                let l1: f64 = target.iter().map(|x| x.abs()).sum();
664                let l2: f64 = target.iter().map(|x| x * x).sum::<f64>().sqrt();
665                if l2 == 0.0 {
666                    return 0.0;
667                }
668                let h = (l1 / l2 - 1.0) / (n.sqrt() - 1.0);
669                lam * h
670            }
671            SparsityKind::Log { .. } => {
672                let mut acc = 0.0;
673                let d2 = smooth * smooth;
674                for &x in target.iter() {
675                    acc += (1.0 + x * x / d2).ln();
676                }
677                lam * acc
678            }
679        }
680    }
681
682    fn grad_target(&self, target: ArrayView1<'_, f64>, rho: ArrayView1<'_, f64>) -> Array1<f64> {
683        let (lam, smooth) = self.resolved(rho);
684        let mut g = Array1::<f64>::zeros(target.len());
685        match self.kind {
686            SparsityKind::SmoothedL1 { .. } => {
687                let eps2 = smooth * smooth;
688                for (i, &x) in target.iter().enumerate() {
689                    g[i] = lam * x / (x * x + eps2).sqrt();
690                }
691            }
692            SparsityKind::Hoyer => {
693                // P(x) = A · (L1/L2 - 1), A = lam / (sqrt(n) - 1).
694                // ∂P/∂x_i = A · (sign(x_i)/L2 - L1 · x_i / L2³).
695                let n = target.len() as f64;
696                assert!(n > 1.0, "Hoyer requires n > 1");
697                let l1: f64 = target.iter().map(|x| x.abs()).sum();
698                let l2: f64 = target.iter().map(|x| x * x).sum::<f64>().sqrt();
699                if l2 == 0.0 {
700                    return g;
701                }
702                let denom = n.sqrt() - 1.0;
703                let a = lam / denom;
704                let inv_l2 = 1.0 / l2;
705                let inv_l2_cubed = inv_l2 * inv_l2 * inv_l2;
706                for (i, &x) in target.iter().enumerate() {
707                    let sgn = if x > 0.0 {
708                        1.0
709                    } else if x < 0.0 {
710                        -1.0
711                    } else {
712                        0.0
713                    };
714                    g[i] = a * (sgn * inv_l2 - l1 * x * inv_l2_cubed);
715                }
716            }
717            SparsityKind::Log { .. } => {
718                let d2 = smooth * smooth;
719                for (i, &x) in target.iter().enumerate() {
720                    g[i] = lam * 2.0 * x / (d2 + x * x);
721                }
722            }
723        }
724        g
725    }
726
727    fn hessian_diag(
728        &self,
729        target: ArrayView1<'_, f64>,
730        rho: ArrayView1<'_, f64>,
731    ) -> Option<Array1<f64>> {
732        let (lam, smooth) = self.resolved(rho);
733        match self.kind {
734            SparsityKind::SmoothedL1 { .. } => {
735                let mut d = Array1::<f64>::zeros(target.len());
736                let eps2 = smooth * smooth;
737                for (i, &x) in target.iter().enumerate() {
738                    let r = (x * x + eps2).sqrt();
739                    d[i] = lam * eps2 / (r * r * r);
740                }
741                Some(d)
742            }
743            SparsityKind::Log { .. } => {
744                let mut d = Array1::<f64>::zeros(target.len());
745                // The EXACT second derivative of λ log(1 + x²/δ²):
746                //   d/dx [ 2λx/(δ²+x²) ] = 2λ(δ² − x²)/(δ² + x²)²,
747                // which is NEGATIVE for |x| > δ — Log is nonconvex. This is
748                // the genuine Hessian diagonal and exactly differentiates
749                // `grad_target`. PSD consumers (Newton block, preconditioner,
750                // `log_det_plus_λI`, FrozenAnalyticPenaltyOp) must instead
751                // route through `psd_majorizer_diag`/`psd_majorizer_hvp`,
752                // which expose the IRLS/MM surrogate `2λ/(δ²+x²)`.
753                let d2 = smooth * smooth;
754                for (i, &x) in target.iter().enumerate() {
755                    let denom = d2 + x * x;
756                    d[i] = lam * 2.0 * (d2 - x * x) / (denom * denom);
757                }
758                Some(d)
759            }
760            // Hoyer's Hessian is DENSE and NOT generally PSD (Hoyer is a
761            // nonconvex sparsifier). We cannot return a meaningful diagonal
762            // that would be safe to use as a preconditioner / Newton block
763            // through the standard `hessian_diag` path, so we return `None`
764            // and force callers through `hvp`. See `hvp` below for the exact
765            // dense-Hessian-vector product.
766            SparsityKind::Hoyer => None,
767        }
768    }
769
770    fn hvp(
771        &self,
772        target: ArrayView1<'_, f64>,
773        rho: ArrayView1<'_, f64>,
774        v: ArrayView1<'_, f64>,
775    ) -> Array1<f64> {
776        // For SmoothedL1/Log/Hoyer we route through the closed-form Hessian.
777        // SmoothedL1 and Log have purely diagonal Hessians and would
778        // ordinarily reach the diagonal branch of the default `hvp`; we
779        // override here to also serve Hoyer (whose Hessian is dense
780        // rank-1-plus-diagonal).
781        let (lam, smooth) = self.resolved(rho);
782        let n_target = target.len();
783        assert_eq!(v.len(), n_target, "hvp dimension mismatch");
784        match self.kind {
785            SparsityKind::SmoothedL1 { .. } => {
786                let mut out = Array1::<f64>::zeros(n_target);
787                let eps2 = smooth * smooth;
788                for (i, &x) in target.iter().enumerate() {
789                    let r = (x * x + eps2).sqrt();
790                    out[i] = lam * eps2 / (r * r * r) * v[i];
791                }
792                out
793            }
794            SparsityKind::Log { .. } => {
795                // EXACT Hessian-vector product: the Log Hessian is diagonal
796                // with entries 2λ(δ²−x²)/(δ²+x²)², so (Hv)_i = h_i v_i. This
797                // is the genuine second derivative (indefinite for |x|>δ).
798                // PSD consumers use `psd_majorizer_hvp` for the IRLS/MM
799                // surrogate 2λ/(δ²+x²) instead.
800                let mut out = Array1::<f64>::zeros(n_target);
801                let d2 = smooth * smooth;
802                for (i, &x) in target.iter().enumerate() {
803                    let denom = d2 + x * x;
804                    out[i] = lam * 2.0 * (d2 - x * x) / (denom * denom) * v[i];
805                }
806                out
807            }
808            SparsityKind::Hoyer => {
809                // P(x) = A · (L1/L2 - 1), A = lam / (sqrt(n) - 1).
810                // H_ij = A · [ -s_i x_j/L2³ - x_i s_j/L2³
811                //              - L1 δ_ij/L2³ + 3 L1 x_i x_j/L2⁵ ]
812                // (Hv)_i = A · [ -s_i (xᵀv)/L2³ - x_i (sᵀv)/L2³
813                //                - L1 v_i/L2³ + 3 L1 x_i (xᵀv)/L2⁵ ]
814                let n = n_target as f64;
815                assert!(n > 1.0, "Hoyer requires n > 1");
816                let l1: f64 = target.iter().map(|x| x.abs()).sum();
817                let l2: f64 = target.iter().map(|x| x * x).sum::<f64>().sqrt();
818                let mut out = Array1::<f64>::zeros(n_target);
819                if l2 == 0.0 {
820                    return out;
821                }
822                let a = lam / (n.sqrt() - 1.0);
823                let inv_l2_cubed = 1.0 / (l2 * l2 * l2);
824                let inv_l2_5 = inv_l2_cubed / (l2 * l2);
825                let mut x_dot_v = 0.0;
826                let mut s_dot_v = 0.0;
827                for i in 0..n_target {
828                    let xi = target[i];
829                    let si = if xi > 0.0 {
830                        1.0
831                    } else if xi < 0.0 {
832                        -1.0
833                    } else {
834                        0.0
835                    };
836                    x_dot_v += xi * v[i];
837                    s_dot_v += si * v[i];
838                }
839                for i in 0..n_target {
840                    let xi = target[i];
841                    let si = if xi > 0.0 {
842                        1.0
843                    } else if xi < 0.0 {
844                        -1.0
845                    } else {
846                        0.0
847                    };
848                    out[i] = a
849                        * (-si * x_dot_v * inv_l2_cubed
850                            - xi * s_dot_v * inv_l2_cubed
851                            - l1 * v[i] * inv_l2_cubed
852                            + 3.0 * l1 * xi * x_dot_v * inv_l2_5);
853                }
854                out
855            }
856        }
857    }
858
859    fn psd_majorizer_diag(
860        &self,
861        target: ArrayView1<'_, f64>,
862        rho: ArrayView1<'_, f64>,
863    ) -> Option<Array1<f64>> {
864        let (lam, smooth) = self.resolved(rho);
865        match self.kind {
866            // SmoothedL1 is convex: the majorizer equals the exact Hessian.
867            SparsityKind::SmoothedL1 { .. } => self.hessian_diag(target, rho),
868            // Log is nonconvex; expose the IRLS/MM re-weighted-ℓ₂ surrogate
869            //   2λ/(δ²+x²) ⪰ 2λ(δ²−x²)/(δ²+x²)²,
870            // strictly positive, agreeing with the exact Hessian at x = 0.
871            SparsityKind::Log { .. } => {
872                let mut d = Array1::<f64>::zeros(target.len());
873                let d2 = smooth * smooth;
874                for (i, &x) in target.iter().enumerate() {
875                    d[i] = lam * 2.0 / (d2 + x * x);
876                }
877                Some(d)
878            }
879            // Hoyer's Hessian is dense; no diagonal majorizer. Callers fall
880            // back to the exact dense `hvp` through `psd_majorizer_hvp`.
881            SparsityKind::Hoyer => None,
882        }
883    }
884
885    fn grad_rho(&self, target: ArrayView1<'_, f64>, rho: ArrayView1<'_, f64>) -> Array1<f64> {
886        // Strength axis: ∂P/∂ρ_strength = P (chain rule through exp).
887        // ε axis (if owned): ∂P/∂ρ_eps = ε · ∂P/∂ε.
888        let n_rho = self.rho_count();
889        let mut out = Array1::<f64>::zeros(n_rho);
890        let p_val = self.value(target, rho);
891        out[self.strength_rho_index] = p_val;
892        if let Some(eps_idx) = self.eps_rho_index {
893            let (lam, smooth) = self.resolved(rho);
894            let mut dp_deps = 0.0;
895            match self.kind {
896                SparsityKind::SmoothedL1 { .. } => {
897                    for &x in target.iter() {
898                        dp_deps += smooth / (x * x + smooth * smooth).sqrt();
899                    }
900                    dp_deps *= lam;
901                }
902                SparsityKind::Log { .. } => {
903                    // d/dδ log(1 + x²/δ²) = -2 x² / (δ (δ² + x²))
904                    let d2 = smooth * smooth;
905                    for &x in target.iter() {
906                        dp_deps += -2.0 * x * x / (smooth * (d2 + x * x));
907                    }
908                    dp_deps *= lam;
909                }
910                SparsityKind::Hoyer => {}
911            }
912            // Chain through ρ_eps = log(ε)  ⇒  ∂ε/∂ρ_eps = ε.
913            out[eps_idx] = smooth * dp_deps;
914        }
915        out
916    }
917
918    fn rho_count(&self) -> usize {
919        1 + if self.eps_rho_index.is_some() { 1 } else { 0 }
920    }
921
922    fn name(&self) -> &str {
923        "sparsity"
924    }
925
926    impl_scalar_apply_schedule!(weight);
927}
928
929// ---------------------------------------------------------------------------
930// TopK activation penalty
931// ---------------------------------------------------------------------------
932
933#[derive(Debug, Clone)]
934pub struct TopKActivationPenalty {
935    pub target: PsiSlice,
936    pub k: usize,
937    pub latent_dim: usize,
938    pub weight: f64,
939    pub weight_schedule: Option<ScalarWeightSchedule>,
940}
941
942impl TopKActivationPenalty {
943    #[must_use = "build error must be handled"]
944    pub fn new(target: PsiSlice, k: usize, weight: f64) -> Result<Self, String> {
945        let latent_dim = target
946            .latent_dim
947            .ok_or_else(|| "TopKActivationPenalty::new requires target.latent_dim".to_string())?;
948        if latent_dim == 0 {
949            return Err("TopKActivationPenalty::new requires latent_dim > 0".to_string());
950        }
951        if k == 0 || k > latent_dim {
952            return Err(format!(
953                "TopKActivationPenalty::new requires 0 < k <= latent_dim; got k={k}, latent_dim={latent_dim}"
954            ));
955        }
956        if !(weight.is_finite() && weight > 0.0) {
957            return Err(format!(
958                "TopKActivationPenalty::new requires finite weight > 0, got {weight}"
959            ));
960        }
961        Ok(Self {
962            target,
963            k,
964            latent_dim,
965            weight,
966            weight_schedule: None,
967        })
968    }
969
970    impl_with_weight_schedule!(weight);
971
972    fn topk_mask_row(&self, target: ArrayView1<'_, f64>, row: usize, mask: &mut [bool]) {
973        mask.fill(false);
974        let d = self.latent_dim;
975        let base = row * d;
976        let mut order = (0..d).collect::<Vec<_>>();
977        order.sort_by(|&a, &b| {
978            target[base + b]
979                .abs()
980                .total_cmp(&target[base + a].abs())
981                .then_with(|| a.cmp(&b))
982        });
983        for &axis in order.iter().take(self.k) {
984            mask[axis] = true;
985        }
986    }
987}
988
989impl AnalyticPenalty for TopKActivationPenalty {
990    fn tier(&self) -> PenaltyTier {
991        PenaltyTier::Psi
992    }
993
994    fn value(&self, target: ArrayView1<'_, f64>, rho: ArrayView1<'_, f64>) -> f64 {
995        assert_eq!(rho.len(), 0, "TopKActivationPenalty has no rho parameters");
996        let d = self.latent_dim;
997        let n_obs = target.len() / d;
998        let mut mask = vec![false; d];
999        let mut acc = 0.0;
1000        for row in 0..n_obs {
1001            self.topk_mask_row(target, row, &mut mask);
1002            let base = row * d;
1003            for axis in 0..d {
1004                if mask[axis] {
1005                    let v = target[base + axis];
1006                    acc += 0.5 * self.weight * v * v;
1007                }
1008            }
1009        }
1010        acc
1011    }
1012
1013    fn grad_target(&self, target: ArrayView1<'_, f64>, rho: ArrayView1<'_, f64>) -> Array1<f64> {
1014        assert_eq!(rho.len(), 0, "TopKActivationPenalty has no rho parameters");
1015        let d = self.latent_dim;
1016        let n_obs = target.len() / d;
1017        let mut mask = vec![false; d];
1018        let mut grad = Array1::<f64>::zeros(target.len());
1019        for row in 0..n_obs {
1020            self.topk_mask_row(target, row, &mut mask);
1021            let base = row * d;
1022            for axis in 0..d {
1023                if mask[axis] {
1024                    grad[base + axis] = self.weight * target[base + axis];
1025                }
1026            }
1027        }
1028        grad
1029    }
1030
1031    fn hessian_diag(
1032        &self,
1033        target: ArrayView1<'_, f64>,
1034        rho: ArrayView1<'_, f64>,
1035    ) -> Option<Array1<f64>> {
1036        assert_eq!(rho.len(), 0, "TopKActivationPenalty has no rho parameters");
1037        let d = self.latent_dim;
1038        let n_obs = target.len() / d;
1039        let mut mask = vec![false; d];
1040        let mut diag = Array1::<f64>::zeros(target.len());
1041        for row in 0..n_obs {
1042            self.topk_mask_row(target, row, &mut mask);
1043            let base = row * d;
1044            for axis in 0..d {
1045                if mask[axis] {
1046                    diag[base + axis] = self.weight;
1047                }
1048            }
1049        }
1050        Some(diag)
1051    }
1052
1053    fn grad_rho(&self, target: ArrayView1<'_, f64>, rho: ArrayView1<'_, f64>) -> Array1<f64> {
1054        assert_eq!(rho.len(), 0, "TopKActivationPenalty has no rho parameters");
1055        assert_eq!(
1056            target.len() % self.latent_dim,
1057            0,
1058            "TopKActivationPenalty target length must be a multiple of latent_dim"
1059        );
1060        Array1::<f64>::zeros(0)
1061    }
1062
1063    fn rho_count(&self) -> usize {
1064        0
1065    }
1066
1067    fn name(&self) -> &str {
1068        "topk_activation"
1069    }
1070
1071    impl_scalar_apply_schedule!(weight);
1072}
1073
1074// ---------------------------------------------------------------------------
1075// JumpReLU penalty
1076// ---------------------------------------------------------------------------
1077
1078#[derive(Debug, Clone)]
1079pub struct JumpReLUPenalty {
1080    pub target: PsiSlice,
1081    pub latent_dim: usize,
1082    pub thresholds: Array1<f64>,
1083    pub weight: f64,
1084    pub smoothing_eps: f64,
1085    pub weight_schedule: Option<ScalarWeightSchedule>,
1086}
1087
1088impl JumpReLUPenalty {
1089    #[must_use = "build error must be handled"]
1090    pub fn new(
1091        target: PsiSlice,
1092        thresholds: Array1<f64>,
1093        weight: f64,
1094        smoothing_eps: f64,
1095    ) -> Result<Self, String> {
1096        let latent_dim = target
1097            .latent_dim
1098            .ok_or_else(|| "JumpReLUPenalty::new requires target.latent_dim".to_string())?;
1099        if latent_dim == 0 {
1100            return Err("JumpReLUPenalty::new requires latent_dim > 0".to_string());
1101        }
1102        if thresholds.len() != latent_dim {
1103            return Err(format!(
1104                "JumpReLUPenalty::new thresholds length {} does not match latent_dim {latent_dim}",
1105                thresholds.len()
1106            ));
1107        }
1108        for (idx, &tau) in thresholds.iter().enumerate() {
1109            if !(tau.is_finite() && tau > 0.0) {
1110                return Err(format!(
1111                    "JumpReLUPenalty::new thresholds[{idx}] must be finite and > 0, got {tau}"
1112                ));
1113            }
1114        }
1115        if !(weight.is_finite() && weight > 0.0) {
1116            return Err(format!(
1117                "JumpReLUPenalty::new requires finite weight > 0, got {weight}"
1118            ));
1119        }
1120        if !(smoothing_eps.is_finite() && smoothing_eps > 0.0) {
1121            return Err(format!(
1122                "JumpReLUPenalty::new requires finite smoothing_eps > 0, got {smoothing_eps}"
1123            ));
1124        }
1125        Ok(Self {
1126            target,
1127            latent_dim,
1128            thresholds,
1129            weight,
1130            smoothing_eps,
1131            weight_schedule: None,
1132        })
1133    }
1134
1135    impl_with_weight_schedule!(weight);
1136
1137    fn threshold(&self, axis: usize, rho: ArrayView1<'_, f64>) -> f64 {
1138        // A learnable threshold `θ·exp(rho)` overflows to `inf` for large `rho`;
1139        // the downstream gate `σ((l−θ)/τ)` then evaluates `inf·gate = NaN`. Clamp
1140        // the log-magnitude so the threshold stays a finite normal.
1141        resolve_learnable_weight(self.thresholds[axis], rho[axis])
1142    }
1143
1144    pub(crate) fn sigmoid_gate(&self, x: f64) -> f64 {
1145        if x >= 0.0 {
1146            1.0 / (1.0 + (-x).exp())
1147        } else {
1148            let ex = x.exp();
1149            ex / (1.0 + ex)
1150        }
1151    }
1152
1153    fn true_hessian_diag_entry(&self, tau: f64, gate: f64) -> f64 {
1154        self.weight * tau * gate * (1.0 - gate) * (1.0 - 2.0 * gate)
1155            / (self.smoothing_eps * self.smoothing_eps)
1156    }
1157
1158    fn psd_hessian_diag_entry(&self, tau: f64, gate: f64) -> f64 {
1159        // Genuine PSD majorizer of the indefinite exact diagonal Hessian
1160        //   h(g) = λτ·g(1−g)(1−2g)/ε².
1161        // The bare re-weighted-ℓ₂ surrogate λτ·[g(1−g)]²/ε² is ≥ 0 but only
1162        // dominates h in the concave region g > ½. For g < (3−√5)/2 ≈ 0.382 the
1163        // exact curvature is positive and strictly larger, so the square alone
1164        // is NOT an upper bound — the `B ⪰ ∂²P` contract is violated for exactly
1165        // the comfortably-below-threshold (inactive) coordinates JumpReLU is
1166        // meant to suppress, costing the MM step its monotone-decrease guarantee.
1167        //
1168        // Take the elementwise max of that surrogate and the absolute exact
1169        // Hessian |h| = λτ·g(1−g)|1−2g|/ε². Since |h| ≥ h everywhere and ≥ 0, the
1170        // max is a true PSD upper bound; it equals |h| in the wings (tight where
1171        // the bare square failed) and keeps the surrogate's strictly-positive
1172        // floor near the inflection g ≈ ½ (where h ≈ 0) so the curvature block
1173        // never collapses to zero.
1174        let slope = gate * (1.0 - gate);
1175        let reweighted_l2 = slope * slope;
1176        let abs_exact = slope * (1.0 - 2.0 * gate).abs();
1177        self.weight * tau * reweighted_l2.max(abs_exact) / (self.smoothing_eps * self.smoothing_eps)
1178    }
1179}
1180
1181/// JumpReLU activation gate `φ(z) = z · 1[z > τ]` together with the
1182/// straight-through-estimator derivatives of its smooth surrogate
1183/// `φ̃(z) = z · σ((z − τ)/ε)`. The forward value is the hard gate; the backward
1184/// uses the surrogate's gradients so the activation has a usable subgradient in
1185/// the smoothing band `|z − τ| ≲ ε`:
1186///
1187///   g       = σ((z − τ)/ε)
1188///   φ        = z · 1[z > τ]                 (returned value)
1189///   ∂φ̃/∂z   = g + z · g (1 − g) / ε          (`dphi_dz`)
1190///   ∂φ̃/∂τ   = − z · g (1 − g) / ε            (`dphi_dtau`)
1191///
1192/// This is the single Rust source of truth that `gamfit.torch`'s
1193/// `_JumpReLUSTEFn` consumes so the torch activation gate's backward matches the
1194/// smoothed gate exactly instead of re-deriving it in Python.
1195#[must_use]
1196pub fn jumprelu_gate_value_grad(z: f64, tau: f64, smoothing_eps: f64) -> (f64, f64, f64) {
1197    let g = gam_linalg::utils::stable_logistic((z - tau) / smoothing_eps);
1198    let value = if z > tau { z } else { 0.0 };
1199    let slope = z * g * (1.0 - g) / smoothing_eps;
1200    let dphi_dz = g + slope;
1201    let dphi_dtau = -slope;
1202    (value, dphi_dz, dphi_dtau)
1203}
1204
1205impl AnalyticPenalty for JumpReLUPenalty {
1206    fn tier(&self) -> PenaltyTier {
1207        PenaltyTier::Psi
1208    }
1209
1210    fn value(&self, target: ArrayView1<'_, f64>, rho: ArrayView1<'_, f64>) -> f64 {
1211        let d = self.latent_dim;
1212        let n_obs = target.len() / d;
1213        let mut acc = 0.0;
1214        for row in 0..n_obs {
1215            let base = row * d;
1216            for axis in 0..d {
1217                let tau = self.threshold(axis, rho);
1218                let gate = self.sigmoid_gate((target[base + axis] - tau) / self.smoothing_eps);
1219                acc += self.weight * tau * gate;
1220            }
1221        }
1222        acc
1223    }
1224
1225    fn grad_target(&self, target: ArrayView1<'_, f64>, rho: ArrayView1<'_, f64>) -> Array1<f64> {
1226        let d = self.latent_dim;
1227        let n_obs = target.len() / d;
1228        let mut grad = Array1::<f64>::zeros(target.len());
1229        for row in 0..n_obs {
1230            let base = row * d;
1231            for axis in 0..d {
1232                let tau = self.threshold(axis, rho);
1233                let gate = self.sigmoid_gate((target[base + axis] - tau) / self.smoothing_eps);
1234                grad[base + axis] = self.weight * tau * gate * (1.0 - gate) / self.smoothing_eps;
1235            }
1236        }
1237        grad
1238    }
1239
1240    fn hessian_diag(
1241        &self,
1242        target: ArrayView1<'_, f64>,
1243        rho: ArrayView1<'_, f64>,
1244    ) -> Option<Array1<f64>> {
1245        let d = self.latent_dim;
1246        let n_obs = target.len() / d;
1247        let mut diag = Array1::<f64>::zeros(target.len());
1248        for row in 0..n_obs {
1249            let base = row * d;
1250            for axis in 0..d {
1251                let tau = self.threshold(axis, rho);
1252                let gate = self.sigmoid_gate((target[base + axis] - tau) / self.smoothing_eps);
1253                diag[base + axis] = self.true_hessian_diag_entry(tau, gate);
1254            }
1255        }
1256        Some(diag)
1257    }
1258
1259    fn hvp(
1260        &self,
1261        target: ArrayView1<'_, f64>,
1262        rho: ArrayView1<'_, f64>,
1263        v: ArrayView1<'_, f64>,
1264    ) -> Array1<f64> {
1265        assert_eq!(target.len(), v.len(), "hvp dimension mismatch");
1266        let d = self.latent_dim;
1267        let n_obs = target.len() / d;
1268        let mut out = Array1::<f64>::zeros(target.len());
1269        for row in 0..n_obs {
1270            let base = row * d;
1271            for axis in 0..d {
1272                let tau = self.threshold(axis, rho);
1273                let gate = self.sigmoid_gate((target[base + axis] - tau) / self.smoothing_eps);
1274                out[base + axis] = self.true_hessian_diag_entry(tau, gate) * v[base + axis];
1275            }
1276        }
1277        out
1278    }
1279
1280    fn psd_majorizer_diag(
1281        &self,
1282        target: ArrayView1<'_, f64>,
1283        rho: ArrayView1<'_, f64>,
1284    ) -> Option<Array1<f64>> {
1285        // The smoothed JumpReLU surrogate's exact diagonal Hessian
1286        //   λτ·g(1−g)(1−2g)/ε²
1287        // is indefinite (negative once the gate passes the inflection
1288        // g = ½). The Newton / PIRLS pipeline needs a PSD curvature block, so
1289        // expose the PSD upper bound implemented by `psd_hessian_diag_entry`:
1290        // the elementwise max of the re-weighted surrogate and the absolute
1291        // exact curvature.
1292        let d = self.latent_dim;
1293        let n_obs = target.len() / d;
1294        let mut diag = Array1::<f64>::zeros(target.len());
1295        for row in 0..n_obs {
1296            let base = row * d;
1297            for axis in 0..d {
1298                let tau = self.threshold(axis, rho);
1299                let gate = self.sigmoid_gate((target[base + axis] - tau) / self.smoothing_eps);
1300                diag[base + axis] = self.psd_hessian_diag_entry(tau, gate);
1301            }
1302        }
1303        Some(diag)
1304    }
1305
1306    fn grad_rho(&self, target: ArrayView1<'_, f64>, rho: ArrayView1<'_, f64>) -> Array1<f64> {
1307        let d = self.latent_dim;
1308        let n_obs = target.len() / d;
1309        let mut out = Array1::<f64>::zeros(d);
1310        for axis in 0..d {
1311            let tau = self.threshold(axis, rho);
1312            let mut g_tau = 0.0;
1313            for row in 0..n_obs {
1314                let x = target[row * d + axis];
1315                let gate = self.sigmoid_gate((x - tau) / self.smoothing_eps);
1316                g_tau += gate - tau * gate * (1.0 - gate) / self.smoothing_eps;
1317            }
1318            out[axis] = self.weight * tau * g_tau;
1319        }
1320        out
1321    }
1322
1323    fn rho_count(&self) -> usize {
1324        self.latent_dim
1325    }
1326
1327    fn name(&self) -> &str {
1328        "jumprelu"
1329    }
1330
1331    impl_scalar_apply_schedule!(weight);
1332}
1333
#[cfg(test)]
mod fisher_majorizer_1419_tests {
    use super::*;
    use approx::assert_abs_diff_eq;
    use gam_linalg::faer_ndarray::FaerEigh;
    use ndarray::Array2;

    /// #1419 — the Fisher information metric `G = scale·(diag(a) − a aᵀ)` is PSD
    /// yet does NOT majorize the exact softmax-entropy Hessian `H_entropy`:
    /// `G − H_entropy` is indefinite. The Gershgorin diagonal operator
    /// `D_kk = Σ_j|H_kj|` (`row_psd_majorizer`) IS a Loewner majorizer:
    /// `D − H_entropy ⪰ 0` and `D ⪰ 0`.
    ///
    /// Oracle: the exact entropy Hessian comes from `row_dense_hessian`, and the
    /// smallest eigenvalue of `M − H` from a direct symmetric eigensolve. The
    /// K=2 counterexample (`a=(0.95,0.05)`, `λ=τ=1`) is pinned numerically
    /// against the issue's `H_11 = 0.0783747664` and `G_11 = 0.0475`; the
    /// contrast (Fisher FAILS, Gershgorin PASSES) is asserted both for the full
    /// K×K block and for the single free direction of the reference-logit chart.
    #[test]
    fn gershgorin_majorizes_entropy_where_fisher_does_not_1419() {
        // K=2 with λ=τ=1, so scale = λ/τ² = 1. For softmax([z0,z1]) to equal
        // (0.95, 0.05) the logit gap must be z0 − z1 = ln(0.95/0.05) = ln(19).
        let temperature = 1.0_f64;
        let scale = 1.0_f64; // λ/τ² with λ=1, τ=1.
        let pen = SoftmaxAssignmentSparsityPenalty::new(2, temperature);
        let z1 = 0.0_f64;
        let z0 = z1 + (0.95_f64 / 0.05_f64).ln();
        let row = [z0, z1];

        // The logits must reproduce the intended softmax weights.
        let a = pen.softmax_row(&row);
        for (realized, expected) in [(a[0], 0.95), (a[1], 0.05)] {
            assert_abs_diff_eq!(realized, expected, epsilon = 1e-12);
        }

        // Independent oracles: exact entropy Hessian, Fisher metric, majorizer.
        let h = pen.row_dense_hessian(&row, scale);
        let g = pen.row_fisher_metric(&row, scale);
        let m = pen.row_psd_majorizer(&row, scale);

        // The issue's numbers in the sole free direction (index 0):
        //   H_11 = 0.0783747664,  G_11 = a0·a1 = 0.0475.
        assert_abs_diff_eq!(h[[0, 0]], 0.0783747664, epsilon = 1e-9);
        assert_abs_diff_eq!(g[[0, 0]], 0.95 * 0.05, epsilon = 1e-12);

        // The majorizer's diagonal is the absolute row sum D_kk = Σ_j|H_kj|.
        for kk in 0..2 {
            let row_sum: f64 = (0..2).map(|jj| h[[kk, jj]].abs()).sum();
            assert_abs_diff_eq!(m[[kk, kk]], row_sum, epsilon = 1e-12);
        }
        // M is a non-negative diagonal matrix, hence PSD by inspection.
        for (i, j) in [(0, 1), (1, 0)] {
            assert_abs_diff_eq!(m[[i, j]], 0.0, epsilon = 1e-15);
        }
        assert!(m[[0, 0]] >= 0.0 && m[[1, 1]] >= 0.0);

        // Reference-logit chart: with z1 held fixed the only free direction is
        // z0, so the reduced 1×1 curvature is the (0,0) entry. Fisher violates
        // the Loewner bound there (G_11 − H_11 < 0); the majorizer satisfies it.
        let fisher_free = g[[0, 0]] - h[[0, 0]];
        let major_free = m[[0, 0]] - h[[0, 0]];
        assert!(
            fisher_free < -1e-3,
            "Fisher must FAIL the majorizer bound in the free direction (#1419); \
             G_11 − H_11 = {fisher_free}"
        );
        assert!(
            major_free >= -1e-12,
            "Gershgorin majorizer must SATISFY the bound in the free direction (#1419); \
             D_11 − H_11 = {major_free}"
        );

        // Full K×K Loewner check via a direct symmetric eigensolve: the smallest
        // eigenvalue of (M − H) being ≥ −tiny means M ⪰ H, whereas the Fisher
        // difference has a strictly negative smallest eigenvalue, i.e. G ⋡ H.
        let mut m_minus_h = Array2::<f64>::zeros((2, 2));
        let mut g_minus_h = Array2::<f64>::zeros((2, 2));
        for ((i, j), slot) in m_minus_h.indexed_iter_mut() {
            *slot = m[[i, j]] - h[[i, j]];
        }
        for ((i, j), slot) in g_minus_h.indexed_iter_mut() {
            *slot = g[[i, j]] - h[[i, j]];
        }
        let (m_evals, _) = m_minus_h.eigh(faer::Side::Lower).expect("eigh(M−H)");
        let (g_evals, _) = g_minus_h.eigh(faer::Side::Lower).expect("eigh(G−H)");
        let m_min = m_evals.iter().copied().fold(f64::INFINITY, f64::min);
        let g_min = g_evals.iter().copied().fold(f64::INFINITY, f64::min);
        assert!(
            m_min >= -1e-12,
            "Gershgorin majorizer must be a Loewner majorizer (M − H ⪰ 0, #1419); \
             smallest eigenvalue of M−H = {m_min}"
        );
        assert!(
            g_min < -1e-9,
            "the OLD Fisher metric must FAIL the Loewner majorizer test (#1419); \
             smallest eigenvalue of G−H = {g_min} (expected strictly negative)"
        );
    }

    /// #1419 — the majorizer's logit derivative
    /// `∂D_kk/∂z_w = Σ_j sign(H_kj)∂H_kj/∂z_w` must be the exact derivative of
    /// the operator the assembly installs, so that value and log-det adjoint
    /// differentiate the SAME `D`. Oracle: a central finite difference of
    /// `row_psd_majorizer` itself (away from any sign change the absolute row
    /// sum is smooth). Finite differences are used ONLY inside this test, as an
    /// independent check of the closed-form derivative.
    #[test]
    fn gershgorin_majorizer_logit_derivative_matches_fd_1419() {
        let pen = SoftmaxAssignmentSparsityPenalty::new(4, 0.8);
        let row = [0.3_f64, -0.6, 0.9, 0.2];
        let scale = 1.1_f64 * (1.0 / 0.8_f64) * (1.0 / 0.8_f64);
        let eps = 1e-6;
        for w in 0..4 {
            let dd = pen.row_psd_majorizer_logit_derivative(&row, scale, w);
            // Majorizer evaluated with logit `w` shifted by `delta`.
            let shifted = |delta: f64| {
                let mut moved = row;
                moved[w] += delta;
                pen.row_psd_majorizer(&moved, scale)
            };
            let mp = shifted(eps);
            let mm = shifted(-eps);
            for k in 0..4 {
                let fd = (mp[[k, k]] - mm[[k, k]]) / (2.0 * eps);
                assert_abs_diff_eq!(dd[[k, k]], fd, epsilon = 1e-6);
            }
            // D is diagonal, so its derivative must be diagonal as well.
            let off_diagonal = (0..4)
                .flat_map(|i| (0..4).map(move |j| (i, j)))
                .filter(|(i, j)| i != j);
            for (i, j) in off_diagonal {
                assert_abs_diff_eq!(dd[[i, j]], 0.0, epsilon = 1e-15);
            }
        }
    }
}