gam 0.2.3

Generalized penalized likelihood engine
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
//! # Atom selection for multi-manifold / overlapping atoms (Piece 6).
//!
//! This module is the structural sibling of [`crate::terms::latent_coord`] for
//! the **multi-atom** regime described in `proposals/sae_manifold.md` §3.
//! Where the single-atom case stores one per-row latent `t_n ∈ ℝ^d`, here we
//! maintain:
//!
//! 1. an [`AtomLibrary`] of `K` candidate manifold-atoms (each with its own
//!    on-atom basis, its own intrinsic dimension `d_k`, and its own
//!    [`crate::terms::latent_coord::LatentCoordValues`] block of per-row
//!    on-atom coordinates `t_{·, k} ∈ ℝ^{N × d_k}`),
//! 2. per-observation [`crate::terms::atom_codes::SparseAtomCode`]s recording
//!    a *soft assignment* `a_n ∈ ℝ^K` with active support `S_n ⊆ {1..K}`,
//! 3. a pluggable [`AtomSelectionStrategy`] governing how the assignment is
//!    parameterised and how its discrete active-set is differentiated through.
//!
//! ## Parameter partition
//!
//! Following the SAE-manifold tier-assignment table (`sae_manifold.md` §3.1):
//!
//! | Block | Lives in | Owner (piece) |
//! |---|---|---|
//! | `B_1..B_K` decoder coefficients (shared, *one block per atom*) | β / inner Newton | Piece 1 (decoder) — we only hold references |
//! | `t_{n, k}` on-atom coordinate (per row, per atom) | ext-coord / row-local | this module + Piece 1 |
//! | `a_{n, ·}` soft atom assignment (per row) | ext-coord / row-local | **this module** |
//! | `λ_sp` sparsity strength, `λ_sm` smoothness, `α_kj` ARD | ρ / REML outer loop | Piece 4 (sparsity), Piece 1 (ARD) |
//! | `K` atom count | discrete / `compare_models` | upstream wrapper |
//!
//! The Schur / arrow structure is preserved: each row's
//! `ext_n = (a_{n,·}, t_{n,1,·}, …, t_{n,K,·})` is block-diagonal across `n`,
//! coupled to the dense decoder border only through the *active subset*
//! `S_n` (inactive atoms contribute zero through the gating
//! `a_{n,k} = 0`). The math-audit caveat from `sae_manifold.md` §3.3 about
//! the shared `Schur⁻¹` factor in the REML `log|H|` gradient applies
//! unchanged.
//!
//! ### Per-row local-block size
//!
//! The single-atom case from Piece 1 (`latent_coord.rs`) carries a per-row
//! local block of size `d × d` (just `t_n ∈ ℝ^d`). For the multi-atom case
//! the per-row local block stacks the assignment row and the on-atom
//! coordinates of every atom:
//!
//! ```text
//!   ext_n  =  ( a_{n, 1..K}  ;  t_{n, 1, ·}  ;  …  ;  t_{n, K, ·} )
//!         ∈  ℝ^{K + Σ_k d_k}.
//! ```
//!
//! So the local-block dimension is
//!
//! ```text
//!   dim(ext_n)  =  K  +  Σ_{k=1..K} d_k,
//! ```
//!
//! and the local Hessian block is `(K + Σ_k d_k) × (K + Σ_k d_k)`,
//! block-diagonal across `n`. Piece 1's `solve_arrow_newton_step_with_options` Schur
//! elimination generalises by:
//!
//! 1. Eliminating shared β = `(B_1, …, B_K)` first (the existing inner
//!    factorisation), restricted on each row to the *active subset* `S_n`
//!    — atoms with `a_{n,k} = 0` contribute neither to the border nor to
//!    the row's `(t_{n,k}, ·)` block at first order.
//! 2. Solving each row's `(K + Σ_k d_k) × (K + Σ_k d_k)` local block. In
//!    the typical sparse regime `|S_n| ≪ K`, so the *effective* local
//!    block collapses to `(|S_n| + Σ_{k ∈ S_n} d_k) × (·)` after dropping
//!    the inactive coordinates from the active-set.
//!
//! The production SAE-manifold assembler now applies that stacking recipe in
//! [`crate::terms::sae_manifold::SaeManifoldTerm::assemble_arrow_schur`]:
//! the `(K, K)` assignment block sits on the diagonal corner of `ext_n`, the
//! `K` per-atom `(d_k, d_k)` coordinate blocks tile the rest, and the
//! off-diagonal `(a_{n,k}, t_{n,k,·})` couplings are populated from each
//! atom's basis-derivative jet evaluated against `B_k`.
//!
//! ## Relaxation choices for the assignment
//!
//! The assignment `a_n` is intrinsically combinatorial: in the ideal sparse
//! regime it picks a small support `S_n` and a real-valued amplitude on it.
//! Three differentiable relaxations are exposed via
//! [`AtomSelectionStrategy`]:
//!
//! * [`EntropicSoftmax`] — write `a_n = softmax(ℓ_n / τ)` for free logits
//!   `ℓ_n ∈ ℝ^K`. Stays on the open simplex; gradient is the standard
//!   softmax Jacobian; sparsity is encouraged by adding an
//!   entropic penalty `−H(a_n) = Σ_k a_{n,k} log a_{n,k}` whose strength
//!   trades against the data fit (small `τ` → near-hard assignment, larger
//!   `τ` → diffuse). Default strategy.
//! * [`TopK`] — keep only the `k` largest free amplitudes per row. Exact
//!   sparsity, but the active-set choice is discrete; we use a *masked
//!   straight-through* gradient estimator (forward pass uses the
//!   sparsified `a_n`; backward pass copies the incoming gradient on the
//!   selected atoms as if the thresholding were the identity there, and
//!   zeroes it on the unselected atoms — see
//!   [`TopK::backward_straight_through`]). Straight-through estimation is
//!   the standard convention in the TopK-SAE literature and is the canonical
//!   choice when pairing a discrete active-set with smooth manifold atoms;
//!   the bias of the estimator is discussed in the [`TopK`] type docs.
//! * [`L1Relaxed`] — non-negative free amplitudes with a smoothed-L¹
//!   ([`crate::terms::analytic_penalties::SparsityPenalty`]) penalty. Pairs
//!   directly with the active-set inner solver (Piece 4) — strictly-zero
//!   weights *and* a smooth gradient. The relaxation parameter is the
//!   smoothing scale `ε` of the smoothed-L¹, REML-selectable.
//!
//! All three implement [`AtomSelectionStrategy`], which exposes the value and
//! gradient of the assignment-to-code map plus the corresponding penalty
//! contribution.
//!
//! ## Closed-form gradients and production assembly
//!
//! Fully implemented (closed-form, this module):
//!
//! * Softmax forward / Jacobian-vector product
//!   ([`EntropicSoftmax::apply`], [`EntropicSoftmax::jvp_logits`]).
//! * TopK projection with straight-through gradient
//!   ([`TopK::apply`], [`TopK::backward_straight_through`]).
//! * Sparsity-penalty coupling trait
//!   ([`AssignmentSparsityCoupling`]) wired to
//!   [`crate::terms::analytic_penalties::SparsityPenalty`].
//!
//! Production SAE-manifold assembly is now first-class in
//! [`crate::terms::sae_manifold::SaeManifoldTerm::assemble_arrow_schur`]:
//! it materializes the joint `(logits, t)` per-row block, including the
//! assignment diagonal and `(a, t)` cross terms from the atom basis jets, then
//! hands the result to [`crate::solver::arrow_schur::ArrowSchurSystem`].
//!
//! ## Integration hooks to other pieces
//!
//! * Piece 1 (`arrow_schur.rs`, `solve_arrow_newton_step_with_options`): consumed by the
//!   first-class SAE-manifold assembler in
//!   [`crate::terms::sae_manifold::SaeManifoldTerm::assemble_arrow_schur`].
//! * Piece 4 (`SparsityPenalty`): consumed as a black box via the
//!   [`AssignmentSparsityCoupling`] trait below. We do not edit Piece 4.
//! * Piece 5 (REML outer loop): the per-strategy relaxation parameter
//!   (`temperature`, `k`, `eps`) joins the outer ρ vector through the
//!   already-existing [`crate::terms::analytic_penalties`] `rho_index`
//!   plumbing; no new outer-loop code is needed here.

use ndarray::{Array1, ArrayView1};

use crate::terms::analytic_penalties::{AnalyticPenalty, SparsityPenalty};
use crate::terms::atom_codes::{BitVec, SparseAtomCode, SparseAtomCodes};
use crate::terms::latent_coord::LatentCoordValues;

// ---------------------------------------------------------------------------
// Atom shape (decoder reference) — kept as an opaque token here.
// ---------------------------------------------------------------------------

/// Opaque handle to an atom's smooth-decoder shape function.
///
/// In the full integration this will resolve to a concrete `Smooth` from
/// `crate::terms::smooth` — but `smooth.rs` is owned by another piece, so we
/// keep this layer abstract: a [`ShapeRef`] is a stable index into an
/// externally-held registry of decoder bases, plus the intrinsic dimension
/// `d_k` and basis size `M_k` of that atom. The atom-selection layer never
/// dereferences the shape directly; it asks the caller for evaluated
/// decoder outputs `g_k(t_{n,k}) ∈ ℝ^p` and design-gradient jets.
///
/// The handle is plain `Copy` data. Within this module only `intrinsic_dim`
/// is read: [`AtomRecord::new`] checks it against the coordinate block's
/// `latent_dim()`. `id` and `basis_size` are carried along for the caller and
/// are not validated by the code in this file.
#[derive(Debug, Clone, Copy)]
pub struct ShapeRef {
    /// Stable index into the caller's decoder-shape registry.
    pub id: usize,
    /// Intrinsic dimension `d_k` of this atom's on-manifold coordinate.
    /// Must equal the `latent_dim()` of the coordinates paired with it in an
    /// [`AtomRecord`].
    pub intrinsic_dim: usize,
    /// Basis size `M_k` (number of decoder coefficient columns per output dim).
    pub basis_size: usize,
}

// ---------------------------------------------------------------------------
// Per-atom record and library
// ---------------------------------------------------------------------------

/// One candidate manifold-atom: its decoder-shape reference plus the per-row
/// on-atom coordinates `t_{·, k} ∈ ℝ^{N × d_k}`.
///
/// Note that the *decoder coefficients* `B_k` live in the β tier (owned by
/// Piece 1 / `pirls.rs`); we hold only row-local extension-coordinate state here.
#[derive(Debug, Clone)]
pub struct AtomRecord {
    pub shape: ShapeRef,
    pub coords: LatentCoordValues,
}

impl AtomRecord {
    pub fn new(shape: ShapeRef, coords: LatentCoordValues) -> Self {
        assert_eq!(
            coords.latent_dim(),
            shape.intrinsic_dim,
            "AtomRecord: coord latent_dim {} != shape.intrinsic_dim {}",
            coords.latent_dim(),
            shape.intrinsic_dim,
        );
        Self { shape, coords }
    }

    pub fn intrinsic_dim(&self) -> usize {
        self.shape.intrinsic_dim
    }
}

/// `K` candidate manifold-atoms sharing a single observation set of size `N`.
///
/// All atoms must agree on `n_obs`. They may have different intrinsic
/// dimensions `d_k` (ragged), which is why the per-atom on-row coordinate
/// blocks are stored as separate [`LatentCoordValues`] rather than as one
/// dense `(N, K, d)` tensor.
#[derive(Debug, Clone)]
pub struct AtomLibrary {
    atoms: Vec<AtomRecord>,
    n_obs: usize,
}

impl AtomLibrary {
    /// Construct from a non-empty `Vec` of atoms. Errors if the per-atom
    /// `n_obs` disagree, or if no atoms are supplied.
    pub fn new(atoms: Vec<AtomRecord>) -> Result<Self, String> {
        if atoms.is_empty() {
            return Err("AtomLibrary::new: at least one atom required".into());
        }
        let n_obs = atoms[0].coords.n_obs();
        for (k, a) in atoms.iter().enumerate() {
            if a.coords.n_obs() != n_obs {
                return Err(format!(
                    "AtomLibrary::new: atom {k} has n_obs={} but atom 0 has n_obs={n_obs}",
                    a.coords.n_obs()
                ));
            }
        }
        Ok(Self { atoms, n_obs })
    }

    pub fn n_obs(&self) -> usize {
        self.n_obs
    }

    pub fn k_atoms(&self) -> usize {
        self.atoms.len()
    }

    pub fn atom(&self, k: usize) -> &AtomRecord {
        &self.atoms[k]
    }

    pub fn atom_mut(&mut self, k: usize) -> &mut AtomRecord {
        &mut self.atoms[k]
    }

    pub fn iter(&self) -> impl Iterator<Item = &AtomRecord> {
        self.atoms.iter()
    }

    /// Total intrinsic-dimension count `Σ_k d_k`. The per-row ext-coordinate block has
    /// size `K + Σ_k d_k` (assignment plus per-atom coord).
    pub fn total_intrinsic_dim(&self) -> usize {
        self.atoms.iter().map(|a| a.intrinsic_dim()).sum()
    }

    /// Allocate matching [`SparseAtomCodes`] storage (all-empty).
    pub fn fresh_codes(&self) -> SparseAtomCodes {
        SparseAtomCodes::empty(self.n_obs, self.k_atoms())
    }
}

// ---------------------------------------------------------------------------
// Sparsity coupling trait
// ---------------------------------------------------------------------------

/// Trait wiring an [`AtomSelectionStrategy`] into a
/// [`SparsityPenalty`] (Piece 4) without depending on Piece 4's internal
/// representation.
///
/// The contract: an implementor decides how (and whether) the sparsity
/// penalty applies to one row's free amplitudes and returns the resulting
/// penalty value together with its gradient. For the implementations in this
/// file:
///
/// * the L1-relaxed strategy evaluates the penalty on the free-amplitude
///   vector clipped at zero;
/// * entropic-softmax returns a zero contribution (the entropic regulariser
///   replaces L¹);
/// * TopK also returns a zero contribution (cardinality is the regulariser).
pub trait AssignmentSparsityCoupling {
    /// Apply `penalty` to the row-`n` assignment, returning `(value, grad)`
    /// over the row's `K` free amplitudes. `rho` is the local penalty view.
    ///
    /// The returned gradient has one entry per free amplitude in
    /// `free_amplitudes_row`. Implementations in this file that ignore the
    /// penalty still validate `rho` (length and finiteness) and panic on a
    /// mismatch.
    fn penalty_value_and_grad(
        &self,
        penalty: &SparsityPenalty,
        free_amplitudes_row: ArrayView1<'_, f64>,
        rho: ArrayView1<'_, f64>,
    ) -> (f64, Array1<f64>);
}

// ---------------------------------------------------------------------------
// AtomSelectionStrategy trait + three impls
// ---------------------------------------------------------------------------

/// Pluggable strategy governing the assignment parameterisation.
///
/// All strategies operate row-wise (the assignment is per-observation) and
/// take a length-`K` slice of *free amplitudes* `ℓ_n` (logits for softmax,
/// raw non-negative amplitudes for L¹-relaxed, raw amplitudes for TopK).
///
/// Every strategy must also implement [`AssignmentSparsityCoupling`] (it is
/// a supertrait), so a strategy always states how it interacts with the
/// sparsity penalty.
pub trait AtomSelectionStrategy: AssignmentSparsityCoupling {
    /// Strategy tag — useful for diagnostics / `compare_models` keying.
    fn name(&self) -> &'static str;

    /// Forward: map free amplitudes `ℓ_n` to a [`SparseAtomCode`] for row `n`.
    ///
    /// The implementations in this file return a code whose mask and weight
    /// vector both have one entry per free amplitude, with weight `0.0` on
    /// every inactive atom.
    fn apply(&self, free_amplitudes_row: ArrayView1<'_, f64>) -> SparseAtomCode;

    /// Backward: given `∂ℒ/∂a_{n,·}` from the data-fit (length `K`, dense),
    /// return `∂ℒ/∂ℓ_{n,·}` (length `K`).
    ///
    /// `code` is the forward result for the same row; strategies with a
    /// discrete active set use its mask to decide where gradient flows.
    ///
    /// Strategies differ in their Jacobian; see per-impl docs.
    fn backward(
        &self,
        free_amplitudes_row: ArrayView1<'_, f64>,
        code: &SparseAtomCode,
        grad_a_row: ArrayView1<'_, f64>,
    ) -> Array1<f64>;
}

// --- EntropicSoftmax --------------------------------------------------------

/// Fully differentiable simplex parameterisation.
///
/// ```text
///   a_{n,k} = exp(ℓ_{n,k} / τ) / Σ_j exp(ℓ_{n,j} / τ).
/// ```
///
/// The temperature `τ > 0` is the relaxation parameter. Lower `τ` produces
/// near-hard assignments (but with vanishing gradients); higher `τ` keeps
/// assignments diffuse. There is no `Default` impl; callers pass `τ`
/// explicitly to [`EntropicSoftmax::new`].
///
/// The entropic regulariser `−H(a_n)` is *not* materialised by this type.
/// Note that this type's [`AssignmentSparsityCoupling`] implementation
/// contributes a zero penalty value and gradient, so no sparsity penalty
/// reaches the softmax logits through that hook. Pure cross-entropy support
/// is deferred until Piece 4 grows a dedicated `EntropyPenalty`.
#[derive(Debug, Clone)]
pub struct EntropicSoftmax {
    /// Softmax temperature `τ`; [`EntropicSoftmax::new`] requires it to be
    /// finite and strictly positive.
    pub temperature: f64,
    /// If `Some(thr)`, atoms with softmax mass below `thr` are masked out
    /// (still soft below — the mask only affects which atoms count as
    /// active for the per-row Schur reduction). The default is `None`,
    /// i.e. full dense support, which is appropriate for very small `K`.
    pub mask_threshold: Option<f64>,
}

impl EntropicSoftmax {
    /// Create a strategy with temperature `temperature` and no mask
    /// threshold.
    ///
    /// # Panics
    ///
    /// Panics unless `temperature` is finite and strictly positive.
    pub fn new(temperature: f64) -> Self {
        assert!(
            temperature.is_finite() && temperature > 0.0,
            "EntropicSoftmax temperature must be finite and positive, got {temperature}"
        );
        Self {
            temperature,
            mask_threshold: None,
        }
    }

    /// Builder-style setter for the active-set mask threshold.
    ///
    /// # Panics
    ///
    /// Panics if `thr` is not finite.
    pub fn with_mask_threshold(mut self, thr: f64) -> Self {
        assert!(
            thr.is_finite(),
            "EntropicSoftmax mask threshold must be finite, got {thr}"
        );
        self.mask_threshold = Some(thr);
        self
    }

    /// Numerically stable softmax with temperature.
    ///
    /// The scaled logits are shifted by their maximum before exponentiation,
    /// so the largest term is `exp(0) = 1` and nothing overflows.
    ///
    /// # Panics
    ///
    /// Panics when the normaliser is not strictly positive, which happens
    /// for an empty row, a row containing NaN, or a row whose maximum is not
    /// finite (`+∞` present, or every entry `−∞`).
    fn softmax(&self, logits: ArrayView1<'_, f64>) -> Array1<f64> {
        let tau = self.temperature;
        // Running maximum of ℓ/τ. The `>` comparison skips NaN entries here;
        // they are caught by the normaliser check below instead.
        let shift = logits
            .iter()
            .map(|&l| l / tau)
            .fold(f64::NEG_INFINITY, |best, s| if s > best { s } else { best });

        let mut out = Array1::<f64>::zeros(logits.len());
        // Accumulate in index order so the result is reproducible.
        let mut normaliser = 0.0;
        for (slot, &l) in out.iter_mut().zip(logits.iter()) {
            let v = (l / tau - shift).exp();
            *slot = v;
            normaliser += v;
        }
        assert!(
            normaliser > 0.0,
            "EntropicSoftmax softmax normaliser must be positive, got {normaliser}; \
             logits must be non-empty, NaN-free and have a finite maximum"
        );
        for v in out.iter_mut() {
            *v /= normaliser;
        }
        out
    }

    /// Jacobian-vector product: given `g_a = ∂ℒ/∂a`, return `∂ℒ/∂ℓ`.
    ///
    /// The softmax Jacobian (per row) is `J = (diag(a) − a aᵀ) / τ`, so
    /// `∂ℒ/∂ℓ = (a ⊙ g_a − a · (a · g_a)) / τ`. `J` is symmetric, so the
    /// Jacobian-vector and vector-Jacobian products coincide.
    ///
    /// `a` must be the *unmasked* softmax output for the row.
    ///
    /// # Panics
    ///
    /// Panics if `a` and `g_a` have different lengths. Without this check a
    /// longer `g_a` would be silently truncated and a shorter one would fail
    /// with an opaque index panic.
    pub fn jvp_logits(&self, a: ArrayView1<'_, f64>, g_a: ArrayView1<'_, f64>) -> Array1<f64> {
        assert_eq!(
            a.len(),
            g_a.len(),
            "EntropicSoftmax jvp_logits assignment/gradient length mismatch"
        );
        // Sequential accumulation in index order keeps results reproducible.
        let dot = a
            .iter()
            .zip(g_a.iter())
            .fold(0.0_f64, |acc, (&a_i, &g_i)| acc + a_i * g_i);
        let inv_tau = 1.0 / self.temperature;
        Array1::<f64>::from_iter(
            a.iter()
                .zip(g_a.iter())
                .map(|(&a_i, &g_i)| a_i * (g_i - dot) * inv_tau),
        )
    }
}

impl AtomSelectionStrategy for EntropicSoftmax {
    /// Diagnostic tag for this strategy.
    fn name(&self) -> &'static str {
        "entropic_softmax"
    }

    /// Forward pass: softmax over the logits, then optional masking.
    ///
    /// Every atom starts active. When a mask threshold is configured, atoms
    /// whose softmax mass is strictly below it are deactivated and their
    /// weight is left at `0.0`; the remaining weights are *not*
    /// renormalised.
    fn apply(&self, free_amplitudes_row: ArrayView1<'_, f64>) -> SparseAtomCode {
        let probs = self.softmax(free_amplitudes_row);
        let n_atoms = probs.len();

        let mut active_mask = BitVec::ones(n_atoms);
        let mut weights = vec![0.0_f64; n_atoms];
        for (idx, &mass) in probs.iter().enumerate() {
            let below_threshold = match self.mask_threshold {
                Some(thr) => mass < thr,
                None => false,
            };
            if below_threshold {
                active_mask.set(idx, false);
            } else {
                weights[idx] = mass;
            }
        }

        SparseAtomCode {
            active_mask,
            weights,
        }
    }

    /// Backward pass through the softmax: maps `∂ℒ/∂a` to `∂ℒ/∂ℓ`.
    ///
    /// The Jacobian is evaluated at the unmasked softmax of the logits, not
    /// at the (possibly masked) weights stored in `atom_code`.
    ///
    /// # Panics
    ///
    /// Panics if the gradient or the code disagree in length with the
    /// logits, or if the code carries a non-finite weight.
    fn backward(
        &self,
        free_amplitudes_row: ArrayView1<'_, f64>,
        atom_code: &SparseAtomCode,
        grad_a_row: ArrayView1<'_, f64>,
    ) -> Array1<f64> {
        let n_logits = free_amplitudes_row.len();
        assert_eq!(
            grad_a_row.len(),
            n_logits,
            "EntropicSoftmax backward gradient length mismatch"
        );
        assert_eq!(
            atom_code.k_atoms(),
            n_logits,
            "EntropicSoftmax backward code/free-amplitude length mismatch"
        );
        let has_non_finite_weight = atom_code.weights.iter().any(|weight| !weight.is_finite());
        assert!(
            !has_non_finite_weight,
            "EntropicSoftmax backward requires finite assignment weights"
        );
        // The softmax is recomputed rather than read back from the code:
        // the code stores masked weights, whereas the Jacobian needs the
        // full unmasked distribution.
        let probs = self.softmax(free_amplitudes_row);
        self.jvp_logits(probs.view(), grad_a_row)
    }
}

impl AssignmentSparsityCoupling for EntropicSoftmax {
    /// Returns a zero penalty value and an all-zero gradient of length `K`.
    ///
    /// The softmax strategy does not consume the L¹ sparsity penalty, so
    /// Piece 4 sees nothing to penalise and the global energy is not
    /// double-counted. `rho` is still validated.
    ///
    /// # Panics
    ///
    /// Panics if `rho.len()` differs from `sparsity_penalty.rho_count()` or
    /// if any entry of `rho` is non-finite.
    fn penalty_value_and_grad(
        &self,
        sparsity_penalty: &SparsityPenalty,
        free_amplitudes_row: ArrayView1<'_, f64>,
        rho: ArrayView1<'_, f64>,
    ) -> (f64, Array1<f64>) {
        let expected_rho_len = sparsity_penalty.rho_count();
        assert_eq!(
            rho.len(),
            expected_rho_len,
            "EntropicSoftmax sparsity rho length mismatch"
        );
        let rho_has_non_finite = rho.iter().any(|value| !value.is_finite());
        assert!(
            !rho_has_non_finite,
            "EntropicSoftmax sparsity rho must be finite"
        );
        let zero_grad = Array1::<f64>::zeros(free_amplitudes_row.len());
        (0.0, zero_grad)
    }
}

// --- TopK -------------------------------------------------------------------

/// Hard active-set: keep the `k` largest free amplitudes per row.
///
/// Reconstruction uses `a_{n,j} = ℓ_{n,j}` if `j ∈ topk(ℓ_n)` else `0`.
/// Ties are broken in favour of the lower atom index. NaN amplitudes rank
/// below every number, so they are selected only when fewer than `k`
/// non-NaN amplitudes exist.
///
/// **Straight-through gradient.** The forward map is discontinuous (the
/// active-set changes at amplitude ties); the backward pass treats the
/// thresholding as the identity *on the selected atoms* and passes no
/// gradient to the unselected ones, so `∂ℒ/∂ℓ_j ≈ ∂ℒ/∂a_j` for active `j`
/// and `0` otherwise (see [`TopK::backward_straight_through`]).
/// Straight-through estimation is the standard TopK-SAE convention
/// (Makhzani & Frey 2014; Gao et al. 2024). The bias is
/// (i) zero whenever the active set is locally stable, and (ii) bounded by
/// `‖∂ℒ/∂a‖_∞` at tie crossings, so a small temperature in the upstream
/// objective is a sufficient mitigation when used together with adaptive
/// step sizes.
#[derive(Debug, Clone, Copy)]
pub struct TopK {
    /// Maximum number of active atoms per row; rows with fewer than `k`
    /// atoms keep all of them.
    pub k: usize,
}

impl TopK {
    /// Create a TopK strategy.
    ///
    /// # Panics
    ///
    /// Panics if `k == 0`.
    pub fn new(k: usize) -> Self {
        assert!(k > 0, "TopK requires k > 0");
        Self { k }
    }

    /// Total preorder used for selection: larger amplitudes first, NaN last.
    ///
    /// `partial_cmp(..).unwrap_or(Equal)` would make NaN "equal" to every
    /// number while the numbers still differ from each other. That is not a
    /// total order, and `sort_by` may panic or return an unspecified order
    /// when handed such a comparator. Ranking NaN explicitly keeps the
    /// comparator consistent. For NaN-free input it agrees with the plain
    /// descending `partial_cmp` (including `-0.0 == 0.0`).
    fn cmp_descending_nan_last(x: f64, y: f64) -> std::cmp::Ordering {
        use std::cmp::Ordering;
        match (x.is_nan(), y.is_nan()) {
            (true, true) => Ordering::Equal,
            (true, false) => Ordering::Greater,
            (false, true) => Ordering::Less,
            // Both operands are numbers, so `partial_cmp` cannot return
            // `None`; the fallback is never taken.
            (false, false) => y.partial_cmp(&x).unwrap_or(Ordering::Equal),
        }
    }

    /// Indices of the `min(k, len)` largest amplitudes, largest first.
    ///
    /// Uses a stable sort, so equal amplitudes keep ascending index order.
    fn topk_indices(&self, amps: ArrayView1<'_, f64>) -> Vec<usize> {
        let n = amps.len();
        let k_use = self.k.min(n);
        if k_use == 0 {
            return Vec::new();
        }
        let mut idx: Vec<usize> = (0..n).collect();
        idx.sort_by(|&i, &j| Self::cmp_descending_nan_last(amps[i], amps[j]));
        idx.truncate(k_use);
        idx
    }

    /// Straight-through backward: ∂ℒ/∂ℓ = ∂ℒ/∂a *masked to the active set*.
    ///
    /// Documenting the convention: we zero out the gradient on inactive
    /// coordinates (a stricter form of straight-through that matches the
    /// "dead-feature freezing" behaviour observed in TopK-SAE). Some
    /// references (e.g. Hubinger-style straight-through) pass the gradient
    /// through unmodified; the masked form is empirically better at avoiding
    /// dead atoms.
    ///
    /// The result has the same length as `grad_a_row`. Indexing panics if
    /// the code's mask marks an atom index at or beyond that length.
    pub fn backward_straight_through(
        &self,
        code: &SparseAtomCode,
        grad_a_row: ArrayView1<'_, f64>,
    ) -> Array1<f64> {
        let k = grad_a_row.len();
        let mut out = Array1::<f64>::zeros(k);
        for i in code.active_mask.iter_ones() {
            out[i] = grad_a_row[i];
        }
        out
    }
}

impl AtomSelectionStrategy for TopK {
    /// Diagnostic tag for this strategy.
    fn name(&self) -> &'static str {
        "topk"
    }

    /// Forward pass: mark the selected atoms active and copy their raw
    /// amplitudes as weights; every other atom stays inactive with weight
    /// `0.0`.
    fn apply(&self, free_amplitudes_row: ArrayView1<'_, f64>) -> SparseAtomCode {
        let n_atoms = free_amplitudes_row.len();
        let mut active_mask = BitVec::zeros(n_atoms);
        let mut weights = vec![0.0_f64; n_atoms];

        let selected = self.topk_indices(free_amplitudes_row);
        for idx in selected {
            active_mask.set(idx, true);
            weights[idx] = free_amplitudes_row[idx];
        }

        SparseAtomCode {
            active_mask,
            weights,
        }
    }

    /// Backward pass: validates the shapes, then delegates to
    /// [`TopK::backward_straight_through`].
    ///
    /// # Panics
    ///
    /// Panics if the free amplitudes or the code disagree in length with
    /// the incoming gradient.
    fn backward(
        &self,
        free_amplitudes_row: ArrayView1<'_, f64>,
        code: &SparseAtomCode,
        grad_a_row: ArrayView1<'_, f64>,
    ) -> Array1<f64> {
        let n_grad = grad_a_row.len();
        assert_eq!(
            free_amplitudes_row.len(),
            n_grad,
            "TopK backward free-amplitude/gradient length mismatch"
        );
        assert_eq!(
            code.k_atoms(),
            n_grad,
            "TopK backward code/gradient length mismatch"
        );
        self.backward_straight_through(code, grad_a_row)
    }
}

impl AssignmentSparsityCoupling for TopK {
    /// Returns a zero penalty value and an all-zero gradient of length `K`.
    ///
    /// TopK enforces cardinality structurally, so it consumes no smooth
    /// penalty. `rho` is still validated.
    ///
    /// # Panics
    ///
    /// Panics if `rho.len()` differs from `sparsity_penalty.rho_count()` or
    /// if any entry of `rho` is non-finite.
    fn penalty_value_and_grad(
        &self,
        sparsity_penalty: &SparsityPenalty,
        free_amplitudes_row: ArrayView1<'_, f64>,
        rho: ArrayView1<'_, f64>,
    ) -> (f64, Array1<f64>) {
        let expected_rho_len = sparsity_penalty.rho_count();
        assert_eq!(
            rho.len(),
            expected_rho_len,
            "TopK sparsity rho length mismatch"
        );
        let rho_has_non_finite = rho.iter().any(|value| !value.is_finite());
        assert!(!rho_has_non_finite, "TopK sparsity rho must be finite");
        let zero_grad = Array1::<f64>::zeros(free_amplitudes_row.len());
        (0.0, zero_grad)
    }
}

// --- L1Relaxed --------------------------------------------------------------

/// Non-negative free amplitudes with a smoothed-L¹ penalty (Piece 4).
///
/// Forward: `a_{n,k} = max(ℓ_{n,k}, 0)`; an atom is active iff
/// `a_{n,k} > active_threshold` (with the default threshold `0.0` this is
/// `a_{n,k} > 0`).
///
/// The active-set is *induced* by the smoothed-L¹ via the existing
/// active-set inner solver (see `src/solver/active_set.rs`). The
/// smoothing scale `ε` is the relaxation parameter (REML-selectable through
/// [`SparsityPenalty::with_eps_reml`]).
#[derive(Debug, Clone)]
pub struct L1Relaxed {
    /// Threshold below which an amplitude is treated as inactive. Defaults
    /// to `0.0` (exact non-negativity); larger values give a deadzone.
    ///
    /// The field is public, so writing it directly bypasses the finiteness
    /// check performed by [`L1Relaxed::with_threshold`].
    pub active_threshold: f64,
}

impl L1Relaxed {
    /// Strategy with the default threshold `0.0`.
    pub fn new() -> Self {
        Self {
            active_threshold: 0.0,
        }
    }

    /// Strategy with an explicit activation threshold.
    ///
    /// # Panics
    ///
    /// Panics if `thr` is NaN or infinite. A NaN or `+∞` threshold makes the
    /// comparison `a > thr` false for every atom, silently producing an
    /// empty active set; `−∞` makes every atom active. This check matches
    /// the validation done by `EntropicSoftmax::with_mask_threshold`.
    pub fn with_threshold(thr: f64) -> Self {
        assert!(
            thr.is_finite(),
            "L1Relaxed active threshold must be finite, got {thr}"
        );
        Self {
            active_threshold: thr,
        }
    }
}

impl Default for L1Relaxed {
    /// Same as [`L1Relaxed::new`]: threshold `0.0`.
    fn default() -> Self {
        Self::new()
    }
}

impl AtomSelectionStrategy for L1Relaxed {
    /// Diagnostic tag for this strategy.
    fn name(&self) -> &'static str {
        "l1_relaxed"
    }

    /// Forward pass: clip each amplitude at zero and activate the atoms
    /// whose clipped value strictly exceeds `active_threshold`.
    ///
    /// Inactive atoms keep weight `0.0`. A NaN amplitude clips to `0.0`
    /// (`f64::max` returns the non-NaN operand).
    fn apply(&self, free_amplitudes_row: ArrayView1<'_, f64>) -> SparseAtomCode {
        let k = free_amplitudes_row.len();
        let mut mask = BitVec::zeros(k);
        let mut weights = vec![0.0_f64; k];
        for (i, &raw) in free_amplitudes_row.iter().enumerate() {
            let a = raw.max(0.0);
            if a > self.active_threshold {
                mask.set(i, true);
                weights[i] = a;
            }
        }
        SparseAtomCode {
            active_mask: mask,
            weights,
        }
    }

    /// Backward pass through `a = max(ℓ, 0)`, restricted to the active set.
    ///
    /// The derivative is `1` for `ℓ > 0` and `0` otherwise, so the incoming
    /// gradient is copied on active atoms with a strictly positive free
    /// amplitude and zeroed everywhere else.
    ///
    /// # Panics
    ///
    /// Panics if the free amplitudes or the code disagree in length with
    /// the incoming gradient. The other strategies in this file perform the
    /// same checks; without them a mismatch shows up as an opaque index
    /// panic or goes unnoticed.
    fn backward(
        &self,
        free_amplitudes_row: ArrayView1<'_, f64>,
        code: &SparseAtomCode,
        grad_a_row: ArrayView1<'_, f64>,
    ) -> Array1<f64> {
        assert_eq!(
            free_amplitudes_row.len(),
            grad_a_row.len(),
            "L1Relaxed backward free-amplitude/gradient length mismatch"
        );
        assert_eq!(
            code.k_atoms(),
            grad_a_row.len(),
            "L1Relaxed backward code/gradient length mismatch"
        );
        let mut out = Array1::<f64>::zeros(grad_a_row.len());
        for i in code.active_mask.iter_ones() {
            // An externally supplied code may mark an atom active even
            // though its free amplitude is non-positive; no gradient flows
            // there.
            if free_amplitudes_row[i] > 0.0 {
                out[i] = grad_a_row[i];
            }
        }
        out
    }
}

impl AssignmentSparsityCoupling for L1Relaxed {
    /// Evaluate the smoothed-L¹ penalty on the row's amplitudes clipped at
    /// zero, returning `(value, grad)`.
    ///
    /// Negative amplitudes map to zero in the forward pass, so the penalty
    /// is evaluated on the *clipped* values. Gradient entries whose raw
    /// amplitude is `<= 0` are then forced to zero, keeping the
    /// sub-gradient at zero consistent with the active-set semantics.
    fn penalty_value_and_grad(
        &self,
        penalty: &SparsityPenalty,
        free_amplitudes_row: ArrayView1<'_, f64>,
        rho: ArrayView1<'_, f64>,
    ) -> (f64, Array1<f64>) {
        let clipped =
            Array1::<f64>::from_iter(free_amplitudes_row.iter().map(|&raw| raw.max(0.0)));
        let value = penalty.value(clipped.view(), rho);
        let mut grad = penalty.grad_target(clipped.view(), rho);
        for (i, &raw) in free_amplitudes_row.iter().enumerate() {
            if raw <= 0.0 {
                grad[i] = 0.0;
            }
        }
        (value, grad)
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use crate::terms::latent_coord::{LatentCoordValues, LatentIdMode};
    use ndarray::array;

    /// Two-atom fixture over two observations: atom 0 has intrinsic
    /// dimension 2, atom 1 has intrinsic dimension 1.
    fn lib() -> AtomLibrary {
        let planar = array![[0.0, 0.0], [0.1, 0.2]];
        let c0 = LatentCoordValues::from_matrix(planar.view(), LatentIdMode::None);
        let linear = array![[0.0], [1.0]];
        let c1 = LatentCoordValues::from_matrix(linear.view(), LatentIdMode::None);

        let shape0 = ShapeRef {
            id: 0,
            intrinsic_dim: 2,
            basis_size: 8,
        };
        let shape1 = ShapeRef {
            id: 1,
            intrinsic_dim: 1,
            basis_size: 5,
        };
        let records = vec![AtomRecord::new(shape0, c0), AtomRecord::new(shape1, c1)];
        AtomLibrary::new(records).expect("library")
    }

    /// The library reports atom count, observation count and `Σ_k d_k`.
    #[test]
    fn library_construct() {
        let library = lib();
        assert_eq!(library.k_atoms(), 2);
        assert_eq!(library.n_obs(), 2);
        assert_eq!(library.total_intrinsic_dim(), 3);
    }

    /// Unmasked softmax weights sum to one and every atom is active.
    #[test]
    fn softmax_is_simplex() {
        let strategy = EntropicSoftmax::new(1.0);
        let logits = array![1.0_f64, 2.0, 3.0];
        let code = strategy.apply(logits.view());
        let total: f64 = code.weights.iter().sum();
        assert!((total - 1.0).abs() < 1e-12);
        assert_eq!(code.active_mask.count_ones(), 3);
    }

    /// With `k = 2` the two largest amplitudes (indices 1 and 3) are kept.
    #[test]
    fn topk_keeps_top() {
        let strategy = TopK::new(2);
        let amps = array![0.1_f64, 0.9, 0.4, 0.5];
        let code = strategy.apply(amps.view());
        assert_eq!(code.active_mask.count_ones(), 2);
        assert!(code.active_mask.get(1));
        assert!(code.active_mask.get(3));
    }

    /// Negative amplitudes are inactive; positive ones keep their value.
    #[test]
    fn l1_relaxed_clips_negatives() {
        let strategy = L1Relaxed::new();
        let amps = array![-0.5_f64, 0.3, -0.1, 0.7];
        let code = strategy.apply(amps.view());
        assert_eq!(code.active_mask.count_ones(), 2);
        assert_eq!(code.weights[1], 0.3);
        assert_eq!(code.weights[3], 0.7);
    }
}