Skip to main content

gam_terms/analytic_penalties/
orthogonality.rs

1use super::*;
2
3// ---------------------------------------------------------------------------
4// Block-orthogonality penalty
5// ---------------------------------------------------------------------------
6
7/// Between-block-only orthogonality on a row-major matrix-valued latent
8/// block.
9///
10/// Lives on the extension-coordinate tier. Penalizes the squared Frobenius
11/// norm of the between-block Gram matrices, where `T` is the row-major
12/// `n_eff × latent_dim` view of the target slice and `groups` partitions
13/// the latent axes into disjoint subsets:
14///
15/// ```text
16///   P(T) = ½ · w · Σ_{g < h} ‖ T[:, group_g]^T T[:, group_h] ‖²_F
17/// ```
18///
19/// Within-block structure is unconstrained: this penalty only pushes different
20/// groups into mutually orthogonal subspaces. In the SAE objective it is the
21/// block-level separability / gauge term for latent decompositions where known
22/// or supervised coordinates should not leak into free coordinates.
23///
24/// Typical use: gauge-fixing a latent decomposition where one block has been
25/// supervised (e.g. anchored to known coordinates) and a free block needs to
26/// inhabit the orthogonal complement of that supervision. Pair with per-block
27/// ARD or sparsity when you also want within-block axis selection.
28///
29/// Gotchas:
30///
31/// * `groups` must be a true partition of all latent axes: every axis appears
32///   exactly once, and at least two groups are required.
33/// * The Hessian is dense across rows and axes even though an exact diagonal is
34///   available for diagnostics/preconditioning. Use the HVP for the full
35///   Newton curvature.
36#[derive(Debug, Clone)]
37pub struct BlockOrthogonalityPenalty {
38    pub target: PsiSlice,
39    pub groups: Vec<Vec<usize>>,
40    /// Base strength. If `learnable_weight` is true, the resolved strength is
41    /// `weight * exp(rho[rho_index])`; otherwise it is fixed at `weight`.
42    pub weight: f64,
43    /// Number of rows in the row-major matrix-valued latent block.
44    pub n_eff: usize,
45    pub learnable_weight: bool,
46    pub rho_index: usize,
47    pub weight_schedule: Option<ScalarWeightSchedule>,
48}
49
50impl BlockOrthogonalityPenalty {
51    #[must_use = "build error must be handled"]
52    pub fn new(
53        target: PsiSlice,
54        groups: Vec<Vec<usize>>,
55        weight: f64,
56        n_eff: usize,
57        learnable_weight: bool,
58    ) -> Result<Self, String> {
59        if target.is_empty() {
60            return Err("BlockOrthogonalityPenalty::new requires a non-empty target".to_string());
61        }
62        if !(weight.is_finite() && weight > 0.0) {
63            return Err(format!(
64                "BlockOrthogonalityPenalty::new requires finite weight > 0, got {weight}"
65            ));
66        }
67        if n_eff == 0 {
68            return Err("BlockOrthogonalityPenalty::new requires n_eff > 0".to_string());
69        }
70        if !target.len().is_multiple_of(n_eff) {
71            return Err(format!(
72                "BlockOrthogonalityPenalty::new target length {} is not divisible by n_eff {}",
73                target.len(),
74                n_eff
75            ));
76        }
77        let latent_dim = target.len() / n_eff;
78        if let Some(expected_dim) = target.latent_dim {
79            let expected = n_eff.checked_mul(expected_dim).ok_or_else(|| {
80                "BlockOrthogonalityPenalty::new target shape overflows usize".to_string()
81            })?;
82            if expected != target.len() {
83                return Err(format!(
84                    "BlockOrthogonalityPenalty::new target length {} does not match n_eff {} × latent_dim {}",
85                    target.len(),
86                    n_eff,
87                    expected_dim
88                ));
89            }
90        }
91        if groups.len() < 2 {
92            return Err("BlockOrthogonalityPenalty::new requires at least two groups".to_string());
93        }
94        let mut seen = vec![false; latent_dim];
95        for (group_idx, group) in groups.iter().enumerate() {
96            if group.is_empty() {
97                return Err(format!(
98                    "BlockOrthogonalityPenalty::new groups[{group_idx}] must not be empty"
99                ));
100            }
101            for &axis in group {
102                if axis >= latent_dim {
103                    return Err(format!(
104                        "BlockOrthogonalityPenalty::new groups[{group_idx}] axis {axis} exceeds latent_dim {latent_dim}"
105                    ));
106                }
107                if seen[axis] {
108                    return Err(format!(
109                        "BlockOrthogonalityPenalty::new axis {axis} appears in more than one group"
110                    ));
111                }
112                seen[axis] = true;
113            }
114        }
115        for (axis, present) in seen.iter().copied().enumerate() {
116            if !present {
117                return Err(format!(
118                    "BlockOrthogonalityPenalty::new groups must partition latent axes; missing axis {axis}"
119                ));
120            }
121        }
122        Ok(Self {
123            target,
124            groups,
125            weight,
126            n_eff,
127            learnable_weight,
128            rho_index: 0,
129            weight_schedule: None,
130        })
131    }
132
133    impl_with_weight_schedule!(weight);
134
135    fn resolved_weight(&self, rho: ArrayView1<'_, f64>) -> f64 {
136        if self.learnable_weight {
137            resolve_learnable_weight(self.weight, rho[self.rho_index])
138        } else {
139            self.weight
140        }
141    }
142
143    fn latent_dim(&self, target_len: usize) -> Option<usize> {
144        if self.n_eff == 0 || !target_len.is_multiple_of(self.n_eff) {
145            assert_eq!(
146                target_len % self.n_eff.max(1),
147                0,
148                "target length must be divisible by n_eff"
149            );
150            return None;
151        }
152        Some(target_len / self.n_eff)
153    }
154
155    fn target_matrix<'a>(&self, target: ArrayView1<'a, f64>) -> Option<ArrayView2<'a, f64>> {
156        let d = self.latent_dim(target.len())?;
157        target.into_shape_with_order((self.n_eff, d)).ok()
158    }
159
160    fn flatten_matrix(m: &Array2<f64>) -> Array1<f64> {
161        let n_obs = m.nrows();
162        let d = m.ncols();
163        let mut out = Array1::<f64>::zeros(n_obs * d);
164        for n in 0..n_obs {
165            for a in 0..d {
166                out[n * d + a] = m[[n, a]];
167            }
168        }
169        out
170    }
171
172    fn cross_gram(t: ArrayView2<'_, f64>, left: &[usize], right: &[usize]) -> Array2<f64> {
173        let mut out = Array2::<f64>::zeros((left.len(), right.len()));
174        for (li, &a) in left.iter().enumerate() {
175            for (ri, &b) in right.iter().enumerate() {
176                let mut s = 0.0;
177                for n in 0..t.nrows() {
178                    s += t[[n, a]] * t[[n, b]];
179                }
180                out[[li, ri]] = s;
181            }
182        }
183        out
184    }
185
186    /// `out[li, ri] = Σ_n a[n, left[li]] · b[n, right[ri]]` — two-argument
187    /// cross-gram used to assemble the directional derivative of `C_{gh}` in
188    /// direction `v`:  `∂_v C_{gh}[gi, hi] = Σ_n {v[n, axes_g[gi]] · t[n, axes_h[hi]] + t[n, axes_g[gi]] · v[n, axes_h[hi]]}`.
189    /// The `cross_gram` helper (single-input self-product) was previously
190    /// (mis)used for both terms, but `cross_gram(v, h, g) + cross_gram(t, h, g)`
191    /// is the unrelated quantity `(v⊗v) + (t⊗t)`, not `(v⊗t) + (t⊗v)`.
192    fn mixed_cross_gram(
193        a: ArrayView2<'_, f64>,
194        b: ArrayView2<'_, f64>,
195        left: &[usize],
196        right: &[usize],
197    ) -> Array2<f64> {
198        assert_eq!(a.nrows(), b.nrows(), "mixed_cross_gram row mismatch");
199        let mut out = Array2::<f64>::zeros((left.len(), right.len()));
200        for (li, &al) in left.iter().enumerate() {
201            for (ri, &br) in right.iter().enumerate() {
202                let mut s = 0.0;
203                for n in 0..a.nrows() {
204                    s += a[[n, al]] * b[[n, br]];
205                }
206                out[[li, ri]] = s;
207            }
208        }
209        out
210    }
211
212    fn add_right_times_cross(
213        out: &mut Array2<f64>,
214        right: ArrayView2<'_, f64>,
215        left_axes: &[usize],
216        right_axes: &[usize],
217        cross_right_left: ArrayView2<'_, f64>,
218        factor: f64,
219    ) {
220        assert_eq!(cross_right_left.dim(), (right_axes.len(), left_axes.len()));
221        for n in 0..out.nrows() {
222            for (li, &left_axis) in left_axes.iter().enumerate() {
223                let mut s = 0.0;
224                for (ri, &right_axis) in right_axes.iter().enumerate() {
225                    s += right[[n, right_axis]] * cross_right_left[[ri, li]];
226                }
227                out[[n, left_axis]] += factor * s;
228            }
229        }
230    }
231
232    fn hvp_with_precomputed_cross(
233        &self,
234        t: ArrayView2<'_, f64>,
235        cross: &[Vec<Option<Array2<f64>>>],
236        v: ArrayView2<'_, f64>,
237        weight: f64,
238    ) -> Array2<f64> {
239        assert_eq!(v.dim(), t.dim(), "hvp matrix dimension mismatch");
240        if v.dim() != t.dim() {
241            return Array2::<f64>::zeros(t.dim());
242        }
243        let mut out = Array2::<f64>::zeros(t.dim());
244        for g in 0..self.groups.len() {
245            let group_g = &self.groups[g];
246            for h in 0..self.groups.len() {
247                if g == h {
248                    continue;
249                }
250                let group_h = &self.groups[h];
251                let c_hg = cross[h][g]
252                    .as_ref()
253                    .expect("between-block cross Gram must be precomputed");
254                // Linear contribution: w · Σ_b C_{g,h}[i,b] · v[n, axes_h[b]] —
255                // the C-direct piece of d/dv (∂P/∂t).
256                Self::add_right_times_cross(&mut out, v, group_g, group_h, c_hg.view(), weight);
257
258                // Directional derivative of C_{hg} in direction v:
259                //   ∂_v C_{hg}[hi, gi] = Σ_n {v[n, axes_h[hi]] · t[n, axes_g[gi]]
260                //                            + t[n, axes_h[hi]] · v[n, axes_g[gi]]}
261                // = MixedCross(v, t, h, g) + MixedCross(t, v, h, g).
262                // The earlier formulation used `cross_gram(v, h, g) +
263                // cross_gram(t, h, g)`, which is `(v⊗v) + (t⊗t)` — quadratic in v
264                // (resp. independent of v) and unrelated to the JVP. The bug made
265                // the Hessian non-symmetric (it added a fixed `(t⊗t)`-driven term
266                // to every column), violated the gradient/Hessian consistency
267                // check that REML's spectral solve relies on, and the sibling
268                // `OrthogonalityPenalty::hvp_with_precomputed_m` already uses the
269                // correct `v_c · t_b + t_c · v_b` mixed pattern.
270                let dv_h_g = Self::mixed_cross_gram(v, t, group_h, group_g);
271                let tv_h_g = Self::mixed_cross_gram(t, v, group_h, group_g);
272                let mut d_c_hg = dv_h_g;
273                d_c_hg += &tv_h_g;
274                Self::add_right_times_cross(&mut out, t, group_g, group_h, d_c_hg.view(), weight);
275            }
276        }
277        out
278    }
279
280    fn precompute_cross(&self, t: ArrayView2<'_, f64>) -> Vec<Vec<Option<Array2<f64>>>> {
281        let mut cross = vec![vec![None; self.groups.len()]; self.groups.len()];
282        for g in 0..self.groups.len() {
283            for h in 0..self.groups.len() {
284                if g != h {
285                    cross[g][h] = Some(Self::cross_gram(t, &self.groups[g], &self.groups[h]));
286                }
287            }
288        }
289        cross
290    }
291
292    /// Materialize the between-block orthogonality Hessian for small-block
293    /// spectral paths.
294    pub fn as_dense(&self, target: ArrayView1<'_, f64>, rho: ArrayView1<'_, f64>) -> Array2<f64> {
295        let n = target.len();
296        let Some(t) = self.target_matrix(target) else {
297            return Array2::<f64>::zeros((n, n));
298        };
299        let cross = self.precompute_cross(t.view());
300        let weight = self.resolved_weight(rho);
301        let mut dense = Array2::<f64>::zeros((n, n));
302        let mut e = Array1::<f64>::zeros(n);
303        for j in 0..n {
304            e[j] = 1.0;
305            let Some(e_mat) = self.target_matrix(e.view()) else {
306                return Array2::<f64>::zeros((n, n));
307            };
308            let col = self.hvp_with_precomputed_cross(t.view(), &cross, e_mat, weight);
309            for i in 0..n {
310                dense[[i, j]] = col[[i / t.ncols(), i % t.ncols()]];
311            }
312            e[j] = 0.0;
313        }
314        dense
315    }
316}
317
318impl AnalyticPenalty for BlockOrthogonalityPenalty {
319    fn tier(&self) -> PenaltyTier {
320        PenaltyTier::Psi
321    }
322
323    fn value(&self, target: ArrayView1<'_, f64>, rho: ArrayView1<'_, f64>) -> f64 {
324        let Some(t) = self.target_matrix(target) else {
325            return 0.0;
326        };
327        let mut acc = 0.0;
328        for g in 0..self.groups.len() {
329            for h in (g + 1)..self.groups.len() {
330                let c = Self::cross_gram(t.view(), &self.groups[g], &self.groups[h]);
331                for &v in c.iter() {
332                    acc += v * v;
333                }
334            }
335        }
336        0.5 * self.resolved_weight(rho) * acc
337    }
338
339    fn grad_target(&self, target: ArrayView1<'_, f64>, rho: ArrayView1<'_, f64>) -> Array1<f64> {
340        let Some(t) = self.target_matrix(target) else {
341            return Array1::<f64>::zeros(target.len());
342        };
343        let cross = self.precompute_cross(t.view());
344        let weight = self.resolved_weight(rho);
345        let mut grad = Array2::<f64>::zeros(t.dim());
346        for g in 0..self.groups.len() {
347            for h in 0..self.groups.len() {
348                if g == h {
349                    continue;
350                }
351                let c_hg = cross[h][g]
352                    .as_ref()
353                    .expect("between-block cross Gram must be precomputed");
354                Self::add_right_times_cross(
355                    &mut grad,
356                    t.view(),
357                    &self.groups[g],
358                    &self.groups[h],
359                    c_hg.view(),
360                    weight,
361                );
362            }
363        }
364        Self::flatten_matrix(&grad)
365    }
366
367    fn hvp(
368        &self,
369        target: ArrayView1<'_, f64>,
370        rho: ArrayView1<'_, f64>,
371        v: ArrayView1<'_, f64>,
372    ) -> Array1<f64> {
373        assert_eq!(target.len(), v.len(), "hvp dimension mismatch");
374        if target.len() != v.len() {
375            return Array1::<f64>::zeros(target.len());
376        }
377        let Some(t) = self.target_matrix(target) else {
378            return Array1::<f64>::zeros(target.len());
379        };
380        let Some(v_mat) = self.target_matrix(v) else {
381            return Array1::<f64>::zeros(target.len());
382        };
383        let cross = self.precompute_cross(t.view());
384        let hv = self.hvp_with_precomputed_cross(
385            t.view(),
386            &cross,
387            v_mat.view(),
388            self.resolved_weight(rho),
389        );
390        Self::flatten_matrix(&hv)
391    }
392
393    fn hessian_diag(
394        &self,
395        target: ArrayView1<'_, f64>,
396        rho: ArrayView1<'_, f64>,
397    ) -> Option<Array1<f64>> {
398        let t = self.target_matrix(target)?;
399        let n_obs = t.nrows();
400        let d = t.ncols();
401        let weight = self.resolved_weight(rho);
402        let mut group_of = vec![usize::MAX; d];
403        for (gi, group) in self.groups.iter().enumerate() {
404            for &axis in group {
405                group_of[axis] = gi;
406            }
407        }
408        let mut out = Array1::<f64>::zeros(n_obs * d);
409        for n in 0..n_obs {
410            let mut row_sq = 0.0_f64;
411            let mut group_sq = vec![0.0_f64; self.groups.len()];
412            for b in 0..d {
413                let v = t[[n, b]];
414                let v2 = v * v;
415                row_sq += v2;
416                group_sq[group_of[b]] += v2;
417            }
418            for a in 0..d {
419                let g = group_of[a];
420                out[n * d + a] = weight * (row_sq - group_sq[g]);
421            }
422        }
423        Some(out)
424    }
425
426    impl_learnable_weight_grad_rho!();
427
428    impl_learnable_weight_rho_count!();
429
430    fn name(&self) -> &str {
431        "block_orthogonality"
432    }
433
434    impl_scalar_apply_schedule!(weight);
435}
436
437// ---------------------------------------------------------------------------
438// Decoder column-space incoherence penalty
439// ---------------------------------------------------------------------------
440
441/// Cross-atom decoder column-space incoherence, restricted to co-activating
442/// atom pairs (issue #671).
443///
444/// Lives on the β tier and targets the flat SAE decoder coefficient block. The
445/// β layout concatenates the per-atom decoder blocks in atom order: atom `k`
446/// owns `M_k · p_out` coefficients, stored as
447/// `β[off_k + a·p_out + o]` for basis row `a` and output feature `o`.
448/// The stored block is `B_k ∈ ℝ^{M_k × p_out}` with rows `B_k[a, :]`
449/// representing decoder directions in output space.
450///
451/// The penalty is the co-activation-masked cross-column-space overlap
452///
453/// ```text
454///   P = ½ · w · Σ_{j<k} W[j,k] · ‖B_j B_k^T‖²_F,
455///   W[j,k] = ½ · (coactivation[j,k] + coactivation[k,j]).
456/// ```
457///
458/// `coactivation[j,k]` is the mean over observations of
459/// `gate[n,j] · gate[n,k]`; pairs that never co-fire (`W[j,k] = 0`) contribute
460/// nothing. In the SAE objective this is the separability lever: atoms that
461/// are active on the same examples are discouraged from spanning the same
462/// decoder output directions, while unrelated atoms are not pushed apart just
463/// because they both exist in the dictionary.
464///
465/// The Hessian used here is the Gauss-Newton (positive-semidefinite) curvature
466/// of the Frobenius objective in `C`, dropping the indefinite second-order term
467/// in `C`. This keeps the β-tier Newton / PIRLS curvature block PSD, matching
468/// the other quadratic-on-Gram penalties.
469///
470/// Gotchas:
471///
472/// * `block_sizes` are decoder basis-row counts `M_k`, not output widths;
473///   every atom shares the same `p_out`. Stored decoder blocks are
474///   `(M_k, p_out)`, so `B_j B_k^T` is the cross-Gram of decoder directions in
475///   output space and remains well-defined for heterogeneous `M_k`.
476/// * The descriptor path builds a placeholder penalty; live SAE wiring replaces
477///   the co-activation matrix with the current mean gate products.
478/// * Offsets are interpreted against the vector passed to this penalty. In the
479///   SAE decoder-incoherence path the registered target slice is zero-based;
480///   callers using an already sliced target view must keep that convention.
481#[derive(Debug, Clone)]
482pub struct DecoderIncoherencePenalty {
483    pub target: PsiSlice,
484    /// Per-atom decoder basis-function counts `M_k`. The atom blocks are laid
485    /// out contiguously in β order; `Σ_k M_k·p_out == target.len()`.
486    pub block_sizes: Vec<usize>,
487    /// Output / feature dimension `p_out` (decoder column count, shared by all
488    /// atoms).
489    pub p_out: usize,
490    /// Atom count `K`. The operator only stores the SPARSE list of penalized
491    /// atom pairs (`pairs`), not the dense `K×K` co-activation matrix — at
492    /// `K = 32768` that dense matrix is 8 GiB. Every consumer of this operator
493    /// already skipped pairs whose symmetrized weight is `0`, so storing only
494    /// the nonzero pairs is exactly equivalent to the dense matrix while being
495    /// linear in the number of co-active / near-collinear pairs (#1026).
496    pub k_atoms: usize,
497    /// Sparse penalized atom pairs `(j, k, w)` with `j < k` and the symmetrized
498    /// weight `w = ½·(W[j,k] + W[k,j]) > 0` (this is exactly the value the old
499    /// `pair_weight(j, k)` returned). Pairs with `w == 0` are omitted; the dense
500    /// operator skipped them, so results are byte-identical.
501    pub pairs: Vec<(usize, usize, f64)>,
502    /// Base strength. If `learnable_weight` is true the resolved strength is
503    /// `weight·exp(rho[rho_index])`; otherwise it is fixed at `weight`.
504    pub weight: f64,
505    pub learnable_weight: bool,
506    pub rho_index: usize,
507    pub weight_schedule: Option<ScalarWeightSchedule>,
508}
509
510impl DecoderIncoherencePenalty {
511    #[must_use = "build error must be handled"]
512    pub fn new(
513        target: PsiSlice,
514        block_sizes: Vec<usize>,
515        p_out: usize,
516        coactivation: Array2<f64>,
517        weight: f64,
518        learnable_weight: bool,
519    ) -> Result<Self, String> {
520        if target.is_empty() {
521            return Err("DecoderIncoherencePenalty::new requires a non-empty target".to_string());
522        }
523        if !(weight.is_finite() && weight > 0.0) {
524            return Err(format!(
525                "DecoderIncoherencePenalty::new requires finite weight > 0, got {weight}"
526            ));
527        }
528        if p_out == 0 {
529            return Err("DecoderIncoherencePenalty::new requires p_out > 0".to_string());
530        }
531        if block_sizes.len() < 2 {
532            return Err(
533                "DecoderIncoherencePenalty::new requires at least two atom blocks".to_string(),
534            );
535        }
536        let k = block_sizes.len();
537        if coactivation.dim() != (k, k) {
538            return Err(format!(
539                "DecoderIncoherencePenalty::new requires (K, K)=({k}, {k}) coactivation; got {:?}",
540                coactivation.dim()
541            ));
542        }
543        if !coactivation
544            .iter()
545            .all(|value| value.is_finite() && *value >= 0.0)
546        {
547            return Err(
548                "DecoderIncoherencePenalty::new requires finite non-negative coactivation entries"
549                    .to_string(),
550            );
551        }
552        let mut total = 0usize;
553        for (atom_idx, &m) in block_sizes.iter().enumerate() {
554            if m == 0 {
555                return Err(format!(
556                    "DecoderIncoherencePenalty::new block_sizes[{atom_idx}] must be > 0"
557                ));
558            }
559            let span = m.checked_mul(p_out).ok_or_else(|| {
560                "DecoderIncoherencePenalty::new block span overflows usize".to_string()
561            })?;
562            total = total.checked_add(span).ok_or_else(|| {
563                "DecoderIncoherencePenalty::new total span overflows usize".to_string()
564            })?;
565        }
566        if total != target.len() {
567            return Err(format!(
568                "DecoderIncoherencePenalty::new Σ_k M_k·p_out = {total} does not match target length {}",
569                target.len()
570            ));
571        }
572        // Sparsify: store only the upper-triangular pairs whose symmetrized
573        // weight `½·(W[j,k]+W[k,j])` is nonzero. The dense operator skipped every
574        // pair with a zero symmetrized weight, so the sparse list reproduces it
575        // bit-for-bit while never materializing the dense `K×K` matrix downstream.
576        let mut pairs = Vec::new();
577        for j in 0..k {
578            for kk in (j + 1)..k {
579                let w = 0.5 * (coactivation[[j, kk]] + coactivation[[kk, j]]);
580                if w != 0.0 {
581                    pairs.push((j, kk, w));
582                }
583            }
584        }
585        Ok(Self {
586            target,
587            block_sizes,
588            p_out,
589            k_atoms: k,
590            pairs,
591            weight,
592            learnable_weight,
593            rho_index: 0,
594            weight_schedule: None,
595        })
596    }
597
598    /// Sparse-pair constructor used by the SAE live wiring (#1026): build the
599    /// operator directly from a list of penalized atom pairs `(j, k, w)` with
600    /// `j < k` and the symmetrized per-pair weight `w` (exactly the value the old
601    /// dense `pair_weight(j, k)` returned), avoiding any dense `K×K` allocation.
602    /// `w == 0` pairs and out-of-range indices are dropped / rejected. This is
603    /// equivalent to [`Self::new`] fed the dense symmetric matrix with the same
604    /// nonzero entries.
605    #[must_use = "build error must be handled"]
606    pub fn new_sparse(
607        target: PsiSlice,
608        block_sizes: Vec<usize>,
609        p_out: usize,
610        pairs: Vec<(usize, usize, f64)>,
611        weight: f64,
612        learnable_weight: bool,
613    ) -> Result<Self, String> {
614        if target.is_empty() {
615            return Err(
616                "DecoderIncoherencePenalty::new_sparse requires a non-empty target".to_string(),
617            );
618        }
619        if !(weight.is_finite() && weight > 0.0) {
620            return Err(format!(
621                "DecoderIncoherencePenalty::new_sparse requires finite weight > 0, got {weight}"
622            ));
623        }
624        if p_out == 0 {
625            return Err("DecoderIncoherencePenalty::new_sparse requires p_out > 0".to_string());
626        }
627        if block_sizes.len() < 2 {
628            return Err(
629                "DecoderIncoherencePenalty::new_sparse requires at least two atom blocks"
630                    .to_string(),
631            );
632        }
633        let k = block_sizes.len();
634        let mut total = 0usize;
635        for (atom_idx, &m) in block_sizes.iter().enumerate() {
636            if m == 0 {
637                return Err(format!(
638                    "DecoderIncoherencePenalty::new_sparse block_sizes[{atom_idx}] must be > 0"
639                ));
640            }
641            let span = m.checked_mul(p_out).ok_or_else(|| {
642                "DecoderIncoherencePenalty::new_sparse block span overflows usize".to_string()
643            })?;
644            total = total.checked_add(span).ok_or_else(|| {
645                "DecoderIncoherencePenalty::new_sparse total span overflows usize".to_string()
646            })?;
647        }
648        if total != target.len() {
649            return Err(format!(
650                "DecoderIncoherencePenalty::new_sparse Σ_k M_k·p_out = {total} does not match target length {}",
651                target.len()
652            ));
653        }
654        let mut clean = Vec::with_capacity(pairs.len());
655        for (j, kk, w) in pairs {
656            if j >= k || kk >= k {
657                return Err(format!(
658                    "DecoderIncoherencePenalty::new_sparse pair ({j}, {kk}) out of range K={k}"
659                ));
660            }
661            if j >= kk {
662                return Err(format!(
663                    "DecoderIncoherencePenalty::new_sparse requires j < k for each pair, got ({j}, {kk})"
664                ));
665            }
666            if !(w.is_finite() && w >= 0.0) {
667                return Err(format!(
668                    "DecoderIncoherencePenalty::new_sparse requires finite non-negative pair weight, got {w}"
669                ));
670            }
671            if w != 0.0 {
672                clean.push((j, kk, w));
673            }
674        }
675        Ok(Self {
676            target,
677            block_sizes,
678            p_out,
679            k_atoms: k,
680            pairs: clean,
681            weight,
682            learnable_weight,
683            rho_index: 0,
684            weight_schedule: None,
685        })
686    }
687
    // NOTE(review): macro defined outside this view; presumably it generates
    // the weight-schedule builder for the `weight` field — confirm against the
    // macro definition.
    impl_with_weight_schedule!(weight);
689
690    fn resolved_weight(&self, rho: ArrayView1<'_, f64>) -> f64 {
691        if self.learnable_weight {
692            resolve_learnable_weight(self.weight, rho[self.rho_index])
693        } else {
694            self.weight
695        }
696    }
697
698    /// Flat-β offset of atom `k`'s decoder block within the vector passed to
699    /// this penalty. SAE decoder-incoherence wiring registers a zero-based
700    /// target slice, so `target.range.start` is normally zero here.
701    fn block_offsets(&self) -> Vec<usize> {
702        let mut out = Vec::with_capacity(self.block_sizes.len());
703        let mut cursor = self.target.range.start;
704        for &m in &self.block_sizes {
705            out.push(cursor);
706            cursor += m * self.p_out;
707        }
708        out
709    }
710
711    /// Cross-Gram `C[a, b] = Σ_o B_j[a, o]·B_k[b, o]`, shape `(M_j, M_k)`.
712    fn cross_gram(
713        target: ArrayView1<'_, f64>,
714        off_j: usize,
715        m_j: usize,
716        off_k: usize,
717        m_k: usize,
718        p_out: usize,
719    ) -> Array2<f64> {
720        let mut out = Array2::<f64>::zeros((m_j, m_k));
721        for a in 0..m_j {
722            for b in 0..m_k {
723                let mut s = 0.0;
724                for o in 0..p_out {
725                    s += target[off_j + a * p_out + o] * target[off_k + b * p_out + o];
726                }
727                out[[a, b]] = s;
728            }
729        }
730        out
731    }
732
733    /// Shared kernel for the two curvature operators. Accumulates, per penalized
734    /// atom pair `(j, k)`, the Gauss-Newton term `W·Σ_b dC[a,b]·B_k[b,o]` (and
735    /// its `_k` transpose) always, and the residual term `W·Σ_b C[a,b]·V_k[b,o]`
736    /// (and `_k` transpose) only when `include_residual`. With the residual the
737    /// result is the exact `∂²P·v` ([`AnalyticPenalty::hvp`]); without it the
738    /// result is the PSD Gauss-Newton surrogate
739    /// ([`AnalyticPenalty::psd_majorizer_hvp`]).
740    fn hvp_impl(
741        &self,
742        target: ArrayView1<'_, f64>,
743        rho: ArrayView1<'_, f64>,
744        v: ArrayView1<'_, f64>,
745        include_residual: bool,
746    ) -> Array1<f64> {
747        let mut out = Array1::<f64>::zeros(target.len());
748        if target.len() != self.target.len() {
749            return out;
750        }
751        let offsets = self.block_offsets();
752        let weight = self.resolved_weight(rho);
753        let p_out = self.p_out;
754        for &(j, k, w_sym) in &self.pairs {
755            {
756                let w_pair = w_sym * weight;
757                if w_pair == 0.0 {
758                    continue;
759                }
760                let off_j = offsets[j];
761                let off_k = offsets[k];
762                let m_j = self.block_sizes[j];
763                let m_k = self.block_sizes[k];
764                // Directional Gram derivative driving the Gauss-Newton term:
765                //   dC[a, b] = Σ_o (Vj[a, o]·Bk[b, o] + Bj[a, o]·Vk[b, o]).
766                let mut d_c = Array2::<f64>::zeros((m_j, m_k));
767                for a in 0..m_j {
768                    for b in 0..m_k {
769                        let mut s = 0.0;
770                        for o in 0..p_out {
771                            s += v[off_j + a * p_out + o] * target[off_k + b * p_out + o]
772                                + target[off_j + a * p_out + o] * v[off_k + b * p_out + o];
773                        }
774                        d_c[[a, b]] = s;
775                    }
776                }
777                // Cross-Gram C[a, b] = Σ_o Bj[a, o]·Bk[b, o] feeds the residual
778                // term; only materialized for the exact Hessian path.
779                let c = if include_residual {
780                    Some(Self::cross_gram(target, off_j, m_j, off_k, m_k, p_out))
781                } else {
782                    None
783                };
784                // out_j[a, o] += w · Σ_b ( dC[a, b]·Bk[b, o] + C[a, b]·Vk[b, o] )
785                for a in 0..m_j {
786                    for o in 0..p_out {
787                        let mut s = 0.0;
788                        for b in 0..m_k {
789                            s += d_c[[a, b]] * target[off_k + b * p_out + o];
790                            if let Some(c) = &c {
791                                s += c[[a, b]] * v[off_k + b * p_out + o];
792                            }
793                        }
794                        out[off_j + a * p_out + o] += w_pair * s;
795                    }
796                }
797                // out_k[b, o] += w · Σ_a ( dC[a, b]·Bj[a, o] + C[a, b]·Vj[a, o] )
798                for b in 0..m_k {
799                    for o in 0..p_out {
800                        let mut s = 0.0;
801                        for a in 0..m_j {
802                            s += d_c[[a, b]] * target[off_j + a * p_out + o];
803                            if let Some(c) = &c {
804                                s += c[[a, b]] * v[off_j + a * p_out + o];
805                            }
806                        }
807                        out[off_k + b * p_out + o] += w_pair * s;
808                    }
809                }
810            }
811        }
812        out
813    }
814
815    /// Scatter the Gauss-Newton (PSD majorizer) curvature DIRECTLY into a dense
816    /// `β × β` block, accumulating `scale · H_GN` onto `hbb`.
817    ///
818    /// This produces exactly the operator [`AnalyticPenalty::psd_majorizer_hvp`]
819    /// applies (the `include_residual = false` branch of [`Self::hvp_impl`]), but
820    /// assembled block-by-block over the penalized atom pairs instead of
821    /// reconstructed column-by-column from `β` unit-probe HVPs. Since `H_GN` is
822    /// pair-local — it couples only the `(j, k)` pairs in `self.pairs`, each within
823    /// their `(M·p)` decoder blocks — reading off the four output loops of
824    /// `hvp_impl` at a unit probe gives, per pair `(j, k)` with
825    /// `w = w_sym · λ · scale` and `G_x = B_xᵀ B_x` (the `p × p` decoder output
826    /// Gram of atom `x`):
827    ///   * j-block diagonal  `H[(j,a,o),(j,a,o')] += w · G_k[o,o']`
828    ///   * k-block diagonal  `H[(k,b,o),(k,b,o')] += w · G_j[o,o']`
829    ///   * off-diagonal      `H[(j,a,o₁),(k,b,o₂)] += w · B_j[a,o₂] · B_k[b,o₁]`
830    ///     and its symmetric transpose into the `(k, j)` block.
831    ///
832    /// Cost is `O(Σ_pairs (M_j·M_k + M_j + M_k)·p²)`, versus the probe loop's
833    /// `O(β · Σ_pairs M_j·M_k·p)`: once `β = K·M·p` and the collinearity gate
834    /// admits `O(K)` co-active pairs, the probe loop spends `O(K²)` time
835    /// rebuilding a matrix this assembles in `O(K)` (#1026).
836    pub fn accumulate_psd_majorizer_dense(
837        &self,
838        target: ArrayView1<'_, f64>,
839        rho: ArrayView1<'_, f64>,
840        scale: f64,
841        hbb: &mut Array2<f64>,
842    ) {
843        if target.len() != self.target.len() {
844            return;
845        }
846        let offsets = self.block_offsets();
847        let weight = self.resolved_weight(rho);
848        let p = self.p_out;
849        for &(j, k, w_sym) in &self.pairs {
850            let w = w_sym * weight * scale;
851            if w == 0.0 {
852                continue;
853            }
854            let off_j = offsets[j];
855            let off_k = offsets[k];
856            let m_j = self.block_sizes[j];
857            let m_k = self.block_sizes[k];
858            // Per-pair output Grams G_j = B_jᵀB_j and G_k = B_kᵀB_k (p × p), which
859            // drive the within-block diagonal curvature of the partner atom.
860            let mut g_j = vec![0.0_f64; p * p];
861            let mut g_k = vec![0.0_f64; p * p];
862            for o in 0..p {
863                for o2 in 0..p {
864                    let mut sj = 0.0;
865                    for a in 0..m_j {
866                        sj += target[off_j + a * p + o] * target[off_j + a * p + o2];
867                    }
868                    g_j[o * p + o2] = sj;
869                    let mut sk = 0.0;
870                    for b in 0..m_k {
871                        sk += target[off_k + b * p + o] * target[off_k + b * p + o2];
872                    }
873                    g_k[o * p + o2] = sk;
874                }
875            }
876            // j-block diagonal: H[(j,a,o),(j,a,o')] += w · G_k[o,o'].
877            for a in 0..m_j {
878                let base = off_j + a * p;
879                for o in 0..p {
880                    for o2 in 0..p {
881                        hbb[[base + o, base + o2]] += w * g_k[o * p + o2];
882                    }
883                }
884            }
885            // k-block diagonal: H[(k,b,o),(k,b,o')] += w · G_j[o,o'].
886            for b in 0..m_k {
887                let base = off_k + b * p;
888                for o in 0..p {
889                    for o2 in 0..p {
890                        hbb[[base + o, base + o2]] += w * g_j[o * p + o2];
891                    }
892                }
893            }
894            // Off-diagonal coupling: H[(j,a,o₁),(k,b,o₂)] += w · B_j[a,o₂]·B_k[b,o₁],
895            // and the symmetric transpose into the (k, j) block.
896            for a in 0..m_j {
897                for b in 0..m_k {
898                    for o1 in 0..p {
899                        let row_j = off_j + a * p + o1;
900                        let bk_b_o1 = target[off_k + b * p + o1];
901                        for o2 in 0..p {
902                            let col_k = off_k + b * p + o2;
903                            let contrib = w * target[off_j + a * p + o2] * bk_b_o1;
904                            hbb[[row_j, col_k]] += contrib;
905                            hbb[[col_k, row_j]] += contrib;
906                        }
907                    }
908                }
909            }
910        }
911    }
912}
913
914impl AnalyticPenalty for DecoderIncoherencePenalty {
915    fn tier(&self) -> PenaltyTier {
916        PenaltyTier::Beta
917    }
918
919    fn value(&self, target: ArrayView1<'_, f64>, rho: ArrayView1<'_, f64>) -> f64 {
920        if target.len() != self.target.len() {
921            return 0.0;
922        }
923        let offsets = self.block_offsets();
924        let mut acc = 0.0;
925        for &(j, k, w_pair) in &self.pairs {
926            {
927                if w_pair == 0.0 {
928                    continue;
929                }
930                let c = Self::cross_gram(
931                    target,
932                    offsets[j],
933                    self.block_sizes[j],
934                    offsets[k],
935                    self.block_sizes[k],
936                    self.p_out,
937                );
938                let mut frob_sq = 0.0;
939                for &value in c.iter() {
940                    frob_sq += value * value;
941                }
942                acc += w_pair * frob_sq;
943            }
944        }
945        0.5 * self.resolved_weight(rho) * acc
946    }
947
948    fn grad_target(&self, target: ArrayView1<'_, f64>, rho: ArrayView1<'_, f64>) -> Array1<f64> {
949        let mut grad = Array1::<f64>::zeros(target.len());
950        if target.len() != self.target.len() {
951            return grad;
952        }
953        let offsets = self.block_offsets();
954        let weight = self.resolved_weight(rho);
955        for &(j, k, w_sym) in &self.pairs {
956            {
957                let w_pair = w_sym * weight;
958                if w_pair == 0.0 {
959                    continue;
960                }
961                let off_j = offsets[j];
962                let off_k = offsets[k];
963                let m_j = self.block_sizes[j];
964                let m_k = self.block_sizes[k];
965                let c = Self::cross_gram(target, off_j, m_j, off_k, m_k, self.p_out);
966                // grad_j[a, o] += w · Σ_b C[a, b] · B_k[b, o]
967                for a in 0..m_j {
968                    for o in 0..self.p_out {
969                        let mut s = 0.0;
970                        for b in 0..m_k {
971                            s += c[[a, b]] * target[off_k + b * self.p_out + o];
972                        }
973                        grad[off_j + a * self.p_out + o] += w_pair * s;
974                    }
975                }
976                // grad_k[b, o] += w · Σ_a C[a, b] · B_j[a, o]
977                for b in 0..m_k {
978                    for o in 0..self.p_out {
979                        let mut s = 0.0;
980                        for a in 0..m_j {
981                            s += c[[a, b]] * target[off_j + a * self.p_out + o];
982                        }
983                        grad[off_k + b * self.p_out + o] += w_pair * s;
984                    }
985                }
986            }
987        }
988        grad
989    }
990
991    /// Exact Hessian-vector product `H v = (∂²P/∂target²) v`.
992    ///
993    /// `P = ½ w Σ_{j<k} w_{jk} ‖C_{jk}‖²_F` is biquadratic (quartic) in the
994    /// decoder blocks, so the second derivative of the nonlinear-least-squares
995    /// objective carries **two** pieces along a direction `V` (per pair, with
996    /// `W = w·w_{jk}`):
997    ///
998    /// ```text
999    ///   (H v)_j[a,o] = W [ Σ_b dC[a,b]·B_k[b,o]   +   Σ_b C[a,b]·V_k[b,o] ]
1000    /// ```
1001    ///
1002    /// the Gauss-Newton term `Σ dC·B` and the residual term `Σ C·V`, with
1003    /// `dC[a,b] = Σ_o (V_j[a,o]·B_k[b,o] + B_j[a,o]·V_k[b,o])` (and the symmetric
1004    /// `_k` block). The residual term is what makes the exact Hessian indefinite;
1005    /// the GN-only surrogate lives in [`Self::psd_majorizer_hvp`].
1006    fn hvp(
1007        &self,
1008        target: ArrayView1<'_, f64>,
1009        rho: ArrayView1<'_, f64>,
1010        v: ArrayView1<'_, f64>,
1011    ) -> Array1<f64> {
1012        assert_eq!(target.len(), v.len(), "hvp dimension mismatch");
1013        self.hvp_impl(target, rho, v, /* include_residual = */ true)
1014    }
1015
1016    /// PSD majorizer-vector product `B_GN(target; ρ) v` for the **nonconvex**
1017    /// decoder-incoherence penalty.
1018    ///
1019    /// Dropping the indefinite residual term `W·Σ C·V` from the exact
1020    /// [`Self::hvp`] leaves the Gauss-Newton block `W·Jᵀ(J v)` with
1021    /// `J = ∂vec(C)/∂vec(B)`. That block is PSD by construction — a sum of
1022    /// `W ≥ 0` (`weight > 0`, `coactivation ≥ 0`) times rank-structured Gram
1023    /// products `JᵀJ` — and coincides with the exact Hessian as the cross-Gram
1024    /// `C → 0`. The inner Newton / PIRLS curvature block must stay
1025    /// positive-definite, so the GN block is the correct operator here, mirroring
1026    /// the other nonconvex penalties (sparsity, JumpReLU, isometry) that override
1027    /// the majorizer rather than hand back the indefinite true Hessian.
1028    fn psd_majorizer_hvp(
1029        &self,
1030        target: ArrayView1<'_, f64>,
1031        rho: ArrayView1<'_, f64>,
1032        v: ArrayView1<'_, f64>,
1033    ) -> Array1<f64> {
1034        assert_eq!(
1035            target.len(),
1036            v.len(),
1037            "psd_majorizer_hvp dimension mismatch"
1038        );
1039        self.hvp_impl(target, rho, v, /* include_residual = */ false)
1040    }
1041
1042    // `hessian_diag` is intentionally left at the trait default (returns `None`
1043    // for a non-empty target): the Hessian of the cross-Gram Frobenius objective
1044    // is dense, not diagonal, so curvature is supplied via the closed-form
1045    // `hvp` / `psd_majorizer_hvp` path above.
1046
1047    impl_learnable_weight_grad_rho!();
1048
1049    impl_learnable_weight_rho_count!();
1050
1051    fn name(&self) -> &str {
1052        "decoder_incoherence"
1053    }
1054
1055    impl_scalar_apply_schedule!(weight);
1056}
1057
1058// ---------------------------------------------------------------------------
1059// Orthogonality penalty
1060// ---------------------------------------------------------------------------
1061
/// Gauge-fixing penalty for latent-coordinate axes.
///
/// ARD alone is rotation-invariant — pair with Orthogonality to identify
/// intrinsic dim. This penalty locks a canonical orthonormal basis first;
/// ARD can then shrink axes after the rotation gauge has been identified.
///
/// With `T` the row-major `n_obs × latent_dim` view of the target slice, the
/// penalty value is
///
/// ```text
///   P(T) = ½ · (w / n_eff) · ‖ TᵀT − I ‖²_F
/// ```
///
/// where `w` is the resolved strength described on `weight`.
#[derive(Debug, Clone)]
pub struct OrthogonalityPenalty {
    /// Slice this penalty acts on. `new` requires its length to be a multiple
    /// of `latent_dim`.
    pub target: PsiSlice,
    /// Number of latent axes, i.e. the column count of the matrix view of the
    /// target. `new` requires it to be non-zero and no larger than the row
    /// count.
    pub latent_dim: usize,
    /// Base strength. If `learnable_weight` is true, the resolved strength is
    /// `weight * exp(rho[rho_index])`; otherwise it is fixed at `weight`.
    pub weight: f64,
    /// Effective observation count used to keep the Frobenius contribution on
    /// the same scale as per-axis latent priors. The resolved strength is
    /// divided by this value; `new` requires it to equal
    /// `target.len() / latent_dim`.
    pub n_eff: usize,
    /// Whether the strength is modulated by an entry of `rho` (see `weight`).
    pub learnable_weight: bool,
    /// Position in `rho` read when `learnable_weight` is true. `new`
    /// initializes it to `0`.
    pub rho_index: usize,
    /// Optional schedule for `weight`; `new` leaves it unset (`None`).
    pub weight_schedule: Option<ScalarWeightSchedule>,
}
1081
1082impl OrthogonalityPenalty {
1083    #[must_use = "build error must be handled"]
1084    pub fn new(
1085        target: PsiSlice,
1086        latent_dim: usize,
1087        weight: f64,
1088        n_eff: usize,
1089        learnable_weight: bool,
1090    ) -> Result<Self, String> {
1091        if latent_dim == 0 {
1092            return Err("OrthogonalityPenalty::new requires latent_dim > 0".to_string());
1093        }
1094        if !target.len().is_multiple_of(latent_dim) {
1095            return Err(format!(
1096                "OrthogonalityPenalty::new target length {} is not divisible by latent_dim {}",
1097                target.len(),
1098                latent_dim
1099            ));
1100        }
1101        let n_obs = target.len() / latent_dim;
1102        if n_obs < latent_dim {
1103            return Err(format!(
1104                "OrthogonalityPenalty::new requires n_obs >= latent_dim for a feasible \
1105                 Stiefel target, got n_obs {n_obs} and latent_dim {latent_dim}"
1106            ));
1107        }
1108        if !(weight.is_finite() && weight > 0.0) {
1109            return Err(format!(
1110                "OrthogonalityPenalty::new requires finite weight > 0, got {weight}"
1111            ));
1112        }
1113        if n_eff == 0 {
1114            return Err("OrthogonalityPenalty::new requires n_eff > 0".to_string());
1115        }
1116        if n_eff != n_obs {
1117            return Err(format!(
1118                "OrthogonalityPenalty::new requires n_eff to match target rows, got \
1119                 n_eff {n_eff} and target rows {n_obs}"
1120            ));
1121        }
1122        Ok(Self {
1123            target,
1124            latent_dim,
1125            weight,
1126            n_eff,
1127            learnable_weight,
1128            rho_index: 0,
1129            weight_schedule: None,
1130        })
1131    }
1132
1133    impl_with_weight_schedule!(weight);
1134
1135    fn resolved_weight(&self, rho: ArrayView1<'_, f64>) -> f64 {
1136        if self.learnable_weight {
1137            resolve_learnable_weight(self.weight, rho[self.rho_index])
1138        } else {
1139            self.weight
1140        }
1141    }
1142
1143    pub(crate) fn scale(&self, rho: ArrayView1<'_, f64>) -> f64 {
1144        self.resolved_weight(rho) / self.n_eff as f64
1145    }
1146
1147    pub(crate) fn target_matrix<'a>(
1148        &self,
1149        target: ArrayView1<'a, f64>,
1150    ) -> Option<ArrayView2<'a, f64>> {
1151        let d = self.latent_dim;
1152        if !target.len().is_multiple_of(d) {
1153            assert_eq!(
1154                target.len() % d,
1155                0,
1156                "target length must be divisible by latent_dim"
1157            );
1158            return None;
1159        }
1160        let n_obs = target.len() / d;
1161        target.into_shape_with_order((n_obs, d)).ok()
1162    }
1163
1164    pub(crate) fn gram_minus_identity(t: ArrayView2<'_, f64>) -> Array2<f64> {
1165        let n_obs = t.nrows();
1166        let d = t.ncols();
1167        let mut gram = Array2::<f64>::zeros((d, d));
1168        for a in 0..d {
1169            for b in 0..d {
1170                let mut s = 0.0;
1171                for n in 0..n_obs {
1172                    s += t[[n, a]] * t[[n, b]];
1173                }
1174                gram[[a, b]] = s;
1175            }
1176            gram[[a, a]] -= 1.0;
1177        }
1178        gram
1179    }
1180
1181    fn flatten_matrix(m: &Array2<f64>) -> Array1<f64> {
1182        let n_obs = m.nrows();
1183        let d = m.ncols();
1184        let mut out = Array1::<f64>::zeros(n_obs * d);
1185        for n in 0..n_obs {
1186            for a in 0..d {
1187                out[n * d + a] = m[[n, a]];
1188            }
1189        }
1190        out
1191    }
1192
1193    pub(crate) fn hvp_with_precomputed_m(
1194        &self,
1195        t: ArrayView2<'_, f64>,
1196        m: ArrayView2<'_, f64>,
1197        v: ArrayView2<'_, f64>,
1198        scale: f64,
1199    ) -> Array2<f64> {
1200        let n_obs = t.nrows();
1201        let d = t.ncols();
1202        assert_eq!(v.dim(), t.dim(), "hvp matrix dimension mismatch");
1203        assert_eq!(m.dim(), (d, d), "precomputed gram dimension mismatch");
1204        if v.dim() != t.dim() {
1205            return Array2::<f64>::zeros((n_obs, d));
1206        }
1207
1208        let mut vt_t_plus_tt_v = Array2::<f64>::zeros((d, d));
1209        for c in 0..d {
1210            for b in 0..d {
1211                let mut s = 0.0;
1212                for n in 0..n_obs {
1213                    s += v[[n, c]] * t[[n, b]] + t[[n, c]] * v[[n, b]];
1214                }
1215                vt_t_plus_tt_v[[c, b]] = s;
1216            }
1217        }
1218
1219        let mut out = Array2::<f64>::zeros((n_obs, d));
1220        for n in 0..n_obs {
1221            for b in 0..d {
1222                let mut va = 0.0;
1223                let mut tb = 0.0;
1224                for c in 0..d {
1225                    va += v[[n, c]] * m[[c, b]];
1226                    tb += t[[n, c]] * vt_t_plus_tt_v[[c, b]];
1227                }
1228                out[[n, b]] = 2.0 * scale * (va + tb);
1229            }
1230        }
1231        out
1232    }
1233
1234    pub(crate) fn as_dense_with_precomputed_m(
1235        &self,
1236        t: ArrayView2<'_, f64>,
1237        m: ArrayView2<'_, f64>,
1238        scale: f64,
1239    ) -> Array2<f64> {
1240        let n_obs = t.nrows();
1241        let d = t.ncols();
1242        assert_eq!(m.dim(), (d, d), "precomputed gram dimension mismatch");
1243        if m.dim() != (d, d) {
1244            return Array2::<f64>::zeros((n_obs * d, n_obs * d));
1245        }
1246
1247        let mut dense = Array2::<f64>::zeros((n_obs * d, n_obs * d));
1248        let factor = 2.0 * scale;
1249        for row1 in 0..n_obs {
1250            for row2 in 0..n_obs {
1251                let mut row_dot = 0.0;
1252                for axis in 0..d {
1253                    row_dot += t[[row1, axis]] * t[[row2, axis]];
1254                }
1255                for col1 in 0..d {
1256                    let i = row1 * d + col1;
1257                    for col2 in 0..d {
1258                        let j = row2 * d + col2;
1259                        let mut entry = t[[row1, col2]] * t[[row2, col1]];
1260                        if row1 == row2 {
1261                            entry += m[[col2, col1]];
1262                        }
1263                        if col1 == col2 {
1264                            entry += row_dot;
1265                        }
1266                        dense[[i, j]] = factor * entry;
1267                    }
1268                }
1269            }
1270        }
1271        dense
1272    }
1273}
1274
1275impl AnalyticPenalty for OrthogonalityPenalty {
1276    fn tier(&self) -> PenaltyTier {
1277        PenaltyTier::Psi
1278    }
1279
1280    fn value(&self, target: ArrayView1<'_, f64>, rho: ArrayView1<'_, f64>) -> f64 {
1281        let Some(t) = self.target_matrix(target) else {
1282            return 0.0;
1283        };
1284        let gram = Self::gram_minus_identity(t.view());
1285        let mut acc = 0.0;
1286        for &v in gram.iter() {
1287            acc += v * v;
1288        }
1289        0.5 * self.scale(rho) * acc
1290    }
1291
1292    fn grad_target(&self, target: ArrayView1<'_, f64>, rho: ArrayView1<'_, f64>) -> Array1<f64> {
1293        // Matrix-calculus core:
1294        //   d/dT ½·scale·||TᵀT - I||²_F = 2·scale·T·(TᵀT - I),
1295        // because TᵀT - I is symmetric.
1296        let Some(t) = self.target_matrix(target) else {
1297            return Array1::<f64>::zeros(target.len());
1298        };
1299        let gram = Self::gram_minus_identity(t.view());
1300        let n_obs = t.nrows();
1301        let d = t.ncols();
1302        let factor = 2.0 * self.scale(rho);
1303        let mut grad = Array2::<f64>::zeros((n_obs, d));
1304        for n in 0..n_obs {
1305            for a in 0..d {
1306                let mut s = 0.0;
1307                for b in 0..d {
1308                    s += t[[n, b]] * gram[[b, a]];
1309                }
1310                grad[[n, a]] = factor * s;
1311            }
1312        }
1313        Self::flatten_matrix(&grad)
1314    }
1315
1316    fn hvp(
1317        &self,
1318        target: ArrayView1<'_, f64>,
1319        rho: ArrayView1<'_, f64>,
1320        v: ArrayView1<'_, f64>,
1321    ) -> Array1<f64> {
1322        assert_eq!(target.len(), v.len(), "hvp dimension mismatch");
1323        if target.len() != v.len() {
1324            return Array1::<f64>::zeros(target.len());
1325        }
1326        let Some(t) = self.target_matrix(target) else {
1327            return Array1::<f64>::zeros(target.len());
1328        };
1329        let Some(v_mat) = self.target_matrix(v) else {
1330            return Array1::<f64>::zeros(target.len());
1331        };
1332        let m = Self::gram_minus_identity(t.view());
1333        let hv = self.hvp_with_precomputed_m(t.view(), m.view(), v_mat.view(), self.scale(rho));
1334        Self::flatten_matrix(&hv)
1335    }
1336
1337    impl_learnable_weight_grad_rho!();
1338
1339    impl_learnable_weight_rho_count!();
1340
1341    fn name(&self) -> &str {
1342        "orthogonality"
1343    }
1344
1345    impl_scalar_apply_schedule!(weight);
1346}