Skip to main content

gam_sae/
assignment.rs

1//! Assignment gates and sparsity-prior helpers for the SAE manifold term.
2//! Mechanically split from `sae_manifold.rs`.
3
4use ndarray::{Array1, Array2, ArrayView1, ArrayView2};
5
6use gam_solve::evidence::{HybridAtomCandidate, HybridAtomChoice, select_hybrid_atom};
7use gam_terms::analytic_penalties::{
8    AnalyticPenalty, IBPAssignmentPenalty, IbpHessianDiagThirdChannels,
9    SoftmaxAssignmentSparsityPenalty, resolve_learnable_weight,
10};
11use gam_terms::latent::{LatentCoordValues, LatentIdMode, LatentManifold};
12use crate::manifold::SaeManifoldRho;
13
/// #976 Layer-1 guard — maximum assignment-logit change allowed in a single
/// accepted inner iteration, measured in multiples of the gate temperature τ.
///
/// τ is the natural length scale of every assignment mode, because logits are
/// always read through `σ(·/τ)` or `softmax(·/τ)`. A 4τ move already spans the
/// gate's entire soft range, so healthy convergence is never throttled. What
/// the cap rules out is a single iteration carrying a gate all the way from
/// contention to numerically-zero support. A collapse therefore needs several
/// accepted iterations, and the per-iteration active-mass guard is guaranteed
/// to observe the decay before it completes.
///
/// The clamp is enforced where the step is realised. When it binds, the
/// realised objective is evaluated on the clamped state, which keeps the Armijo
/// comparison value-consistent. The unclamped quadratic model is then merely
/// conservative, and step halvings shrink the trial below the cap.
pub(crate) const SAE_ASSIGNMENT_LOGIT_STEP_CAP_TAUS: f64 = 4.0;
27
/// #976 Layer-1 guard — how many times a single atom may be re-seeded within
/// one joint fit.
///
/// The atom gets exactly one second chance from a fresh basin. A second breach
/// is taken as the objective's (local) verdict at the current hyperparameters:
/// it is recorded as a terminal collapse event and left for the
/// structure-search death move to adjudicate. Re-seeding in a loop would only
/// fight the optimizer.
pub(crate) const SAE_ATOM_COLLAPSE_RESEED_BUDGET: usize = 1;
34
/// #976 Layer-1 guard, decoder arm — collapse threshold on an atom's decoder
/// block Frobenius norm, expressed as a fraction of the MEDIAN decoder norm
/// across the dictionary.
///
/// An atom below this fraction has degenerated to (near-)zero output and
/// carries no material reconstruction signal; it decodes the same nothing as
/// every other collapsed atom. This targets the real-data K>1 failure that the
/// gate-mass floor cannot see. The assignment gates can stay spread across rows
/// (mass guard satisfied) while every decoder shrinks to ~0. That leaves EV≈0
/// and a rank-deficient per-row coordinate Hessian on every row (the 0→K·n
/// evidence-deflation jump).
///
/// The statistic is a RATIO to the dictionary median, so it is scale-free. A
/// uniformly small but well-conditioned decoder never trips it; only an atom
/// that has fallen far behind its peers does. For K=1 it is a no-op by
/// construction: the median is the lone atom's own norm, so the K=1 path is
/// byte-for-byte unchanged.
pub(crate) const SAE_ATOM_DECODER_NORM_COLLAPSE_RATIO: f64 = 0.001;
49
/// #976 / #1117 K>1 robustness — bounded DICTIONARY-level multi-start budget
/// for the simultaneous co-collapse arm (the EV-floor branch of
/// [`crate::manifold::SaeManifoldTerm::enforce_decoder_norm_guard`]).
///
/// This is deliberately separate from the per-atom
/// [`SAE_ATOM_COLLAPSE_RESEED_BUDGET`] (= 1). That budget governs re-seeding
/// ONE atom's gate logits against an optimizer that keeps killing it, where
/// looping would fight the optimizer.
///
/// A co-collapse re-seed is categorically different. It is a full-dictionary
/// multi-start that re-diversifies ALL atoms onto distinct principal directions
/// of a FRESHLY recomputed residual, so each attempt explores a genuinely
/// different basin. Empirically, one such re-seed does not always break a K≥3
/// three-way basin: the identical (K, seed) flips EV≈0.40 ↔ 0.00. Hence this
/// small, bounded budget of independent multi-starts.
///
/// The budget is consumed ONLY while the whole dictionary's reconstruction EV
/// is at or below the data-derived collapse bar (`collapse_ev_bar` =
/// `SAE_COLLAPSE_PCA_EV_FRACTION` × the rank-K PCA ceiling, i.e. less than half
/// the variance any rank-K linear dictionary could reach). That bar REPLACED
/// the old absolute magic EV floor; the check itself was not deleted.
///
/// Healthy fits never touch it (real OLMo K=1 ~0.22, K=2 ~0.40, both well
/// above the bar). On the rough patches of a hard K≥3 fit, a transient sub-bar
/// EV can still consume the budget before the fit recovers.
pub(crate) const SAE_DICTIONARY_COCOLLAPSE_RESEED_BUDGET: usize = 3;
70
/// Optimization-support cutoff for the smooth JumpReLU assignment prior,
/// measured in gate temperatures below the hard threshold.
///
/// The forward gate is still hard-zero at and below `threshold`. The prior's
/// value/gradient and the compact Newton layout, however, retain every logit
/// with `(logit - threshold) / tau > -36`. At that edge `sigma(-36) ~= 2e-16`,
/// so the value/gradient/Hessian terms that get dropped are below f64 noise
/// rather than introducing an algorithmic discontinuity.
pub(crate) const JUMPRELU_OPTIMIZATION_LOGIT_CUTOFF: f64 = -36.0;
78
79/// Shared support predicate for JumpReLU optimization inclusion. This is
80/// strictly weaker than the hard forward gate `logit > threshold`, which still
81/// governs data-fit reconstruction and its logit JVP.
82#[inline]
83pub(crate) fn jumprelu_in_optimization_band(logit: f64, threshold: f64, temperature: f64) -> bool {
84    (logit - threshold) / temperature > JUMPRELU_OPTIMIZATION_LOGIT_CUTOFF
85}
86
/// Which assignment prior / relaxation a [`SaeAssignment`] uses to turn its
/// per-row logits into non-negative gates.
#[derive(Debug, Clone, Copy)]
pub enum AssignmentMode {
    /// Row-wise simplex (softmax) assignment regularized by entropy sparsity.
    Softmax { temperature: f64, sparsity: f64 },
    /// Deterministic concrete relaxation of a truncated IBP active set.
    IBPMap {
        temperature: f64,
        alpha: f64,
        learnable_alpha: bool,
    },
    /// Bounded gate with a hard threshold.
    ///
    /// Each atom is switched off (gate = 0) when its logit is at or below
    /// `threshold`. Above it, the gate is the threshold-centred sigmoid
    /// `σ((logit − threshold) / temperature) ∈ [0.5, 1)`. The intended "jump"
    /// is the 0 → 0.5 discontinuity at `threshold`.
    ///
    /// Despite the name this is NOT literal JumpReLU `z·1[z>θ]`: the gate has no
    /// magnitude and stays in [0, 1], like the other gate-family modes (softmax
    /// simplex, IBP sigmoid). All reconstruction magnitude lives in the decoder
    /// curve `g_k(t) = φ(t)ᵀ B_k`.
    JumpReLU { temperature: f64, threshold: f64 },
}
108
/// #1033 — fixed-form predictor that produces the ρ-invariant FROZEN
/// (amortized) routing.
///
/// Neither form uses a learned network: both are deterministic functions of
/// the current dictionary. They differ in how closely they follow the
/// dictionary as it evolves across outer iterates. Both are kept so the
/// accuracy gate can choose whichever clears the fit-quality bar: the cheap
/// `Snapshot` when it suffices, the `ChartGeometry` distill otherwise.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum RoutingPredictor {
    /// Freeze the current (converged) logits as the routing.
    ///
    /// This is the cheapest fixed-form distill and the MVP/baseline. It is exact
    /// at the dictionary it was taken from but goes stale as the dictionary
    /// moves, so it needs refreshing to keep up.
    Snapshot,
    /// Recompute each (row, atom) routing logit from the atom's encode-chart
    /// geometry against the CURRENT dictionary.
    ///
    /// Each row is encoded to its predicted coordinate `t̂`, and the
    /// amplitude-1 image `γ_k(t̂) = Bᵀφ(t̂)` is reconstructed. The
    /// reconstruction ALIGNMENT is then mapped to a logit. A moved decoder
    /// changes `γ_k(t̂)` and hence the routing, so this tracks the dictionary
    /// without re-running the free-logit inner solve. It is the
    /// default-readiness form when the snapshot proves too stale.
    ChartGeometry,
}
131
132impl AssignmentMode {
133    #[must_use]
134    pub fn softmax(temperature: f64) -> Self {
135        Self::Softmax {
136            temperature,
137            sparsity: 1.0,
138        }
139    }
140
141    #[must_use]
142    pub fn ibp_map(temperature: f64, alpha: f64, learnable_alpha: bool) -> Self {
143        Self::IBPMap {
144            temperature,
145            alpha,
146            learnable_alpha,
147        }
148    }
149
150    #[must_use]
151    pub fn jumprelu(temperature: f64, threshold: f64) -> Self {
152        Self::JumpReLU {
153            temperature,
154            threshold,
155        }
156    }
157
158    pub fn temperature(&self) -> f64 {
159        match *self {
160            AssignmentMode::Softmax { temperature, .. }
161            | AssignmentMode::IBPMap { temperature, .. }
162            | AssignmentMode::JumpReLU { temperature, .. } => temperature,
163        }
164    }
165
166    pub(crate) fn set_temperature(&mut self, new_temperature: f64) -> Result<(), String> {
167        if !(new_temperature.is_finite() && new_temperature > 0.0) {
168            return Err(format!(
169                "AssignmentMode: temperature must be finite and positive; got {new_temperature}"
170            ));
171        }
172        match self {
173            AssignmentMode::Softmax { temperature, .. }
174            | AssignmentMode::IBPMap { temperature, .. }
175            | AssignmentMode::JumpReLU { temperature, .. } => {
176                *temperature = new_temperature;
177            }
178        }
179        Ok(())
180    }
181
182    pub(crate) fn validate(&self) -> Result<(), String> {
183        let temperature = self.temperature();
184        if !(temperature.is_finite() && temperature > 0.0) {
185            return Err(format!(
186                "AssignmentMode: temperature must be finite and positive; got {temperature}"
187            ));
188        }
189        match *self {
190            AssignmentMode::Softmax { sparsity, .. } => {
191                if !(sparsity.is_finite() && sparsity > 0.0) {
192                    return Err(format!(
193                        "AssignmentMode::Softmax: sparsity must be finite and positive; got {sparsity}"
194                    ));
195                }
196            }
197            AssignmentMode::IBPMap { alpha, .. } => {
198                if !(alpha.is_finite() && alpha > 0.0) {
199                    return Err(format!(
200                        "AssignmentMode::IBPMap: alpha must be finite and positive; got {alpha}"
201                    ));
202                }
203            }
204            AssignmentMode::JumpReLU { threshold, .. } => {
205                if !threshold.is_finite() {
206                    return Err(format!(
207                        "AssignmentMode::JumpReLU: threshold must be finite; got {threshold}"
208                    ));
209                }
210            }
211        }
212        Ok(())
213    }
214
215    pub(crate) fn resolved_ibp_alpha(&self, rho: &SaeManifoldRho) -> Option<f64> {
216        match *self {
217            AssignmentMode::IBPMap {
218                alpha,
219                learnable_alpha,
220                ..
221            } => Some(if let Some(over) = ibp_alpha_override() {
222                // #1026 — a process-global α override flattens the ordered
223                // geometric prior π_k = (α/(α+1))^{k+1} so all K atoms can
224                // contribute to the reconstruction (the production α=1 gives a
225                // (0.5)^{k+1} schedule that structurally caps atoms 4..K to a few
226                // percent → effective-K≈3). Forces the fixed value, bypassing the
227                // learnable schedule, so a sweep can attribute the EV ceiling.
228                over
229            } else if learnable_alpha {
230                resolve_learnable_weight(alpha, rho.log_lambda_sparse)
231            } else {
232                alpha
233            }),
234            _ => None,
235        }
236    }
237}
238
// #1026 — process-global IBP-α override, stored as raw f64 bits. The initial
// bit pattern 0x7ff8_0000_0000_0000 is a quiet NaN, which the reader treats as
// "unset", so the AssignmentMode's compiled α is used. Lets a single built
// wheel sweep the prior-flattening axis from Python (`sae_set_ibp_alpha`)
// without recompiling the gam crate.
static IBP_ALPHA_OVERRIDE_BITS: std::sync::atomic::AtomicU64 =
    std::sync::atomic::AtomicU64::new(0x7FF8_0000_0000_0000);
244
245pub(crate) fn ibp_alpha_override() -> Option<f64> {
246    let v = f64::from_bits(IBP_ALPHA_OVERRIDE_BITS.load(std::sync::atomic::Ordering::Relaxed));
247    if v.is_finite() && v > 0.0 {
248        Some(v)
249    } else {
250        None
251    }
252}
253
254/// Set (or, with a non-finite/non-positive value, clear) the process-global
255/// IBP-α override. Called from the gamfit Python FFI sweep driver.
256pub fn set_ibp_alpha_override(alpha: f64) {
257    IBP_ALPHA_OVERRIDE_BITS.store(alpha.to_bits(), std::sync::atomic::Ordering::Relaxed);
258}
259
260/// Per-row latent assignment state.
261///
262/// The stored assignment parameter is `logits`; non-negative assignments are
263/// derived by row-wise softmax, independent IBP-MAP sigmoid active indicators,
264/// or JumpReLU gates. Softmax logits are canonicalized to the reference chart
265/// `logits[K - 1] = 0`, so the row-local Newton coordinates contain only the
266/// first `K - 1` logits (`0` coordinates for `K = 1`). Gate-style modes keep
267/// all `K` logits as identifiable scalar parameters. `coords[k]` holds
268/// `t_{.,k}` for atom `k`.
269#[derive(Debug, Clone)]
270pub struct SaeAssignment {
271    pub logits: Array2<f64>,
272    pub coords: Vec<LatentCoordValues>,
273    pub mode: AssignmentMode,
274    /// #1026 — per-atom UNGATED flag (length `K`, default all-`false`). An
275    /// ungated atom is the dense linear/background tier: its per-row gate is
276    /// fixed at `a_k ≡ 1` (it contributes `γ_k(t_k)` to EVERY row, unweighted),
277    /// it is excluded from the other atoms' gate (for the column-separable
278    /// IBP / JumpReLU modes the remaining atoms are computed independently, so
279    /// they are unaffected), and its logit is NOT a free parameter — its
280    /// logit-JVP, sparsity-prior gradient/curvature, and softmax majorizer
281    /// contributions are all zero, leaving its logit slot an inert
282    /// (ridge-regularized) null direction in the per-row Newton block. This lets
283    /// the linear tier carry FULL-RANK reconstructible variance
284    /// (`fitted = γ_ungated(x) + Σ_{gated} a_k·γ_k(x)`) so a linear SAE can reach
285    /// the rank-(K·d) PCA ceiling, while the gated curved atoms still add sparse
286    /// structure on the residual (#1026 routing-bound finding).
287    pub ungated: Vec<bool>,
288    /// #1033 — AMORTIZED / FROZEN routing. When `Some`, this `(n, K)` matrix is a
289    /// ρ-INVARIANT predicted routing (the amortized `x → logits` map distilled
290    /// from the frozen dictionary): the gates are computed from THESE logits
291    /// instead of the free `self.logits`, and the logits are NOT optimized by the
292    /// inner Newton (their gradient/curvature/prior contributions are zeroed,
293    /// exactly as for [`Self::ungated`]). This is the generalization of an ungated
294    /// atom from "pin the gate at 1" to "pin the gate at the predicted value": it
295    /// makes the per-row routing a fixed function of `x` + the frozen dictionary,
296    /// so the outer ρ-search reuses ONE routing instead of re-solving per-row
297    /// gates every outer eval — the n-independent-outer-loop lever (#1033). `None`
298    /// is the historical free-logit path (bit-identical).
299    pub frozen_logits: Option<Array2<f64>>,
300}
301
302impl SaeAssignment {
303    #[must_use = "build error must be handled"]
304    pub fn new(
305        logits: Array2<f64>,
306        coords: Vec<LatentCoordValues>,
307        temperature: f64,
308    ) -> Result<Self, String> {
309        Self::with_mode(logits, coords, AssignmentMode::softmax(temperature))
310    }
311
312    #[must_use = "build error must be handled"]
313    pub fn with_mode(
314        mut logits: Array2<f64>,
315        coords: Vec<LatentCoordValues>,
316        mode: AssignmentMode,
317    ) -> Result<Self, String> {
318        mode.validate()?;
319        let n = logits.nrows();
320        let k = logits.ncols();
321        if coords.len() != k {
322            return Err(format!(
323                "SaeAssignment::new: coords length {} must equal K={k}",
324                coords.len()
325            ));
326        }
327        for (atom, coord) in coords.iter().enumerate() {
328            if coord.n_obs() != n {
329                return Err(format!(
330                    "SaeAssignment::new: coord atom {atom} has n_obs={} but logits has {n}",
331                    coord.n_obs()
332                ));
333            }
334        }
335        for row in 0..n {
336            validate_finite_logits(logits.row(row), row)?;
337        }
338        if matches!(mode, AssignmentMode::Softmax { .. }) {
339            canonicalize_softmax_logits(&mut logits);
340        }
341        Ok(Self {
342            logits,
343            coords,
344            mode,
345            ungated: vec![false; k],
346            frozen_logits: None,
347        })
348    }
349
350    /// #1033 — install a ρ-INVARIANT FROZEN routing (the amortized predicted
351    /// logits; see [`SaeAssignment::frozen_logits`]). `predicted` must be
352    /// `(n, K)`. With routing frozen, the gates are computed from `predicted` and
353    /// the logits are excluded from the inner Newton (their gradient/curvature are
354    /// inert, like an ungated atom's). Passing `None` restores the free-logit
355    /// path.
356    #[must_use = "build error must be handled"]
357    pub fn with_frozen_routing(mut self, predicted: Option<Array2<f64>>) -> Result<Self, String> {
358        if let Some(ref p) = predicted {
359            if p.dim() != (self.n_obs(), self.k_atoms()) {
360                return Err(format!(
361                    "SaeAssignment::with_frozen_routing: predicted shape {:?} must be ({}, {})",
362                    p.dim(),
363                    self.n_obs(),
364                    self.k_atoms()
365                ));
366            }
367            if matches!(self.mode, AssignmentMode::Softmax { .. }) {
368                return Err(
369                    "SaeAssignment::with_frozen_routing: frozen routing under Softmax is rejected \
370                     — the coupled simplex's entropy majorizer is assembled over the logits, which \
371                     a frozen (non-optimized) routing would leave inconsistent; this separable-mode \
372                     contract supports IBP-MAP and JumpReLU, whose per-atom gates have no \
373                     simplex-coupled curvature to skip"
374                        .to_string(),
375                );
376            }
377            for row in 0..p.nrows() {
378                validate_finite_logits(p.row(row), row)?;
379            }
380        }
381        self.frozen_logits = predicted;
382        Ok(self)
383    }
384
385    /// Whether the per-row routing is FROZEN (amortized) rather than free-logit.
386    pub fn routing_is_frozen(&self) -> bool {
387        self.frozen_logits.is_some()
388    }
389
390    /// The active routing logits for `row`: the frozen/predicted logits when
391    /// routing is frozen (#1033), else the free `self.logits`. This is the SINGLE
392    /// source the gate value reads, so freezing routing changes every gate
393    /// consistently.
394    pub(crate) fn routing_logits_row(&self, row: usize) -> ArrayView1<'_, f64> {
395        match self.frozen_logits {
396            Some(ref f) => f.row(row),
397            None => self.logits.row(row),
398        }
399    }
400
401    /// Whether atom `k`'s logit is held fixed (not a free Newton parameter): true
402    /// for an ungated atom (#1026, gate pinned at 1) OR when routing is frozen
403    /// (#1033, gate pinned at the predicted value). Both share the same inert
404    /// treatment — zero logit-JVP, zero sparsity-prior gradient/curvature, zero
405    /// softmax majorizer — so the logit slot never moves.
406    pub(crate) fn logit_is_fixed(&self, k: usize) -> bool {
407        self.routing_is_frozen() || self.ungated.get(k).copied().unwrap_or(false)
408    }
409
410    /// Per-atom mask (length `K`) of [`Self::logit_is_fixed`] — the logit slots
411    /// that are NOT free Newton parameters (ungated #1026 and/or frozen-routing
412    /// #1033). Precompute once per assembly and pass to the logit-JVP fillers so
413    /// the data-fit Jacobian zeroes those rows. Under frozen routing every entry
414    /// is `true`; with only ungated atoms it equals `ungated`; otherwise all
415    /// `false` (the historical free-logit path).
416    pub(crate) fn fixed_logit_mask(&self) -> Vec<bool> {
417        if self.routing_is_frozen() {
418            vec![true; self.k_atoms()]
419        } else {
420            self.ungated.clone()
421        }
422    }
423
424    /// #1033 — install the simplest faithful AMORTIZED routing predictor: a
425    /// fixed-form DISTILL of the current dictionary's routing, namely the current
426    /// (converged) logits SNAPSHOTTED as the ρ-invariant frozen routing. This is
427    /// the `x → logits` map "evaluated once at the frozen dictionary" — the
428    /// routing the dictionary already expresses — held fixed so the outer ρ-search
429    /// reuses it instead of re-optimizing the gates at every ρ. (A richer
430    /// predictor that recomputes logits from `x` via the encode-atlas chart
431    /// geometry is a later refinement; snapshotting the converged routing is the
432    /// exact fixed-point it would target at the frozen dictionary.) Rejected for
433    /// Softmax for the same simplex-coupling reason as [`Self::with_frozen_routing`].
434    #[must_use = "build error must be handled"]
435    pub fn freeze_routing_from_current_logits(self) -> Result<Self, String> {
436        let snapshot = self.logits.clone();
437        self.with_frozen_routing(Some(snapshot))
438    }
439
440    /// #1033 — in-place variant of [`Self::freeze_routing_from_current_logits`]
441    /// for callers holding `&mut SaeAssignment` (e.g. inside a `SaeManifoldTerm`),
442    /// where moving the assignment out is awkward. Same contract: snapshot the
443    /// current logits as the ρ-invariant frozen routing; reject Softmax.
444    pub fn freeze_routing_in_place(&mut self) -> Result<(), String> {
445        if matches!(self.mode, AssignmentMode::Softmax { .. }) {
446            return Err(
447                "SaeAssignment::freeze_routing_in_place: frozen routing under Softmax is rejected \
448                 (coupled-simplex entropy-majorizer); use IBP-MAP or JumpReLU"
449                    .to_string(),
450            );
451        }
452        let snapshot = self.logits.clone();
453        for row in 0..snapshot.nrows() {
454            validate_finite_logits(snapshot.row(row), row)?;
455        }
456        self.frozen_logits = Some(snapshot);
457        Ok(())
458    }
459
460    /// #1033 — install an explicit predicted routing in place (the
461    /// [`RoutingPredictor::ChartGeometry`] output), `&mut self` variant of
462    /// [`Self::with_frozen_routing`]. `predicted` must be `(n, K)`; rejects Softmax
463    /// (separable-mode contract) and non-finite predictions.
464    pub fn set_frozen_routing_in_place(&mut self, predicted: Array2<f64>) -> Result<(), String> {
465        if predicted.dim() != (self.n_obs(), self.k_atoms()) {
466            return Err(format!(
467                "SaeAssignment::set_frozen_routing_in_place: predicted shape {:?} must be ({}, {})",
468                predicted.dim(),
469                self.n_obs(),
470                self.k_atoms()
471            ));
472        }
473        if matches!(self.mode, AssignmentMode::Softmax { .. }) {
474            return Err(
475                "SaeAssignment::set_frozen_routing_in_place: frozen routing under Softmax is \
476                 rejected (coupled-simplex entropy-majorizer); use IBP-MAP or JumpReLU"
477                    .to_string(),
478            );
479        }
480        for row in 0..predicted.nrows() {
481            validate_finite_logits(predicted.row(row), row)?;
482        }
483        self.frozen_logits = Some(predicted);
484        Ok(())
485    }
486
487    /// #1033 — lift the frozen routing, restoring the free-logit search path.
488    pub fn thaw_routing(&mut self) {
489        self.frozen_logits = None;
490    }
491
492    /// #1026 — designate which atoms are UNGATED (the dense linear/background
493    /// tier; see [`SaeAssignment::ungated`]). `flags` must have length `K`.
494    ///
495    /// Ungating is defined for the COLUMN-SEPARABLE gate modes (IBP-MAP and
496    /// JumpReLU): each atom's gate is an independent per-atom function of its own
497    /// logit, so pinning one atom to `a_k ≡ 1` leaves every other atom's gate
498    /// exactly as computed. Softmax is a coupled simplex (`Σ_k a_k = 1` over all
499    /// `K`), so a unit gate for one atom is only well defined relative to a
500    /// gated-subset renormalization that must also be reflected in the logit-JVP
501    /// and the entropy majorizer; this constructor's contract is restricted to
502    /// the separable modes, and an ungated atom under Softmax is REJECTED here so
503    /// the inner solve never runs on a value/gradient-mismatched gate. Callers
504    /// wanting a dense background tier under Softmax route it as an IBP-MAP or
505    /// JumpReLU atom.
506    #[must_use = "build error must be handled"]
507    pub fn with_ungated(mut self, flags: Vec<bool>) -> Result<Self, String> {
508        if flags.len() != self.k_atoms() {
509            return Err(format!(
510                "SaeAssignment::with_ungated: flags length {} must equal K={}",
511                flags.len(),
512                self.k_atoms()
513            ));
514        }
515        if matches!(self.mode, AssignmentMode::Softmax { .. }) && flags.iter().any(|&u| u) {
516            return Err(
517                "SaeAssignment::with_ungated: an ungated atom under Softmax routing is \
518                 rejected — the coupled simplex requires a gated-subset renormalization \
519                 reflected in the logit-JVP and entropy majorizer, which this separable-mode \
520                 contract does not perform; route a dense background tier as IBP-MAP or JumpReLU"
521                    .to_string(),
522            );
523        }
524        self.ungated = flags;
525        Ok(self)
526    }
527
528    /// Whether any atom is ungated (the #1026 background tier is engaged).
529    pub fn has_ungated(&self) -> bool {
530        self.ungated.iter().any(|&u| u)
531    }
532
533    pub fn n_obs(&self) -> usize {
534        self.logits.nrows()
535    }
536
537    pub fn k_atoms(&self) -> usize {
538        self.logits.ncols()
539    }
540
541    pub fn total_coord_dim(&self) -> usize {
542        self.coords.iter().map(|c| c.latent_dim()).sum()
543    }
544
545    pub fn assignment_coord_dim(&self) -> usize {
546        match self.mode {
547            AssignmentMode::Softmax { .. } => self.k_atoms().saturating_sub(1),
548            AssignmentMode::IBPMap { .. } | AssignmentMode::JumpReLU { .. } => self.k_atoms(),
549        }
550    }
551
552    pub fn row_block_dim(&self) -> usize {
553        self.assignment_coord_dim() + self.total_coord_dim()
554    }
555
556    pub fn coord_offsets(&self) -> Vec<usize> {
557        let mut out = Vec::with_capacity(self.k_atoms());
558        let mut cursor = self.assignment_coord_dim();
559        for coord in &self.coords {
560            out.push(cursor);
561            cursor += coord.latent_dim();
562        }
563        out
564    }
565
566    pub fn assignments(&self) -> Array2<f64> {
567        let n = self.n_obs();
568        let k = self.k_atoms();
569        let mut out = Array2::<f64>::zeros((n, k));
570        for row in 0..n {
571            let a = self.assignments_row(row);
572            for atom in 0..k {
573                out[[row, atom]] = a[atom];
574            }
575        }
576        out
577    }
578
579    pub fn assignments_row(&self, row: usize) -> Array1<f64> {
580        self.try_assignments_row(row)
581            .expect("assignment logits must be finite")
582    }
583
584    pub fn try_assignments_row(&self, row: usize) -> Result<Array1<f64>, String> {
585        self.try_assignments_row_with_alpha(row, None)
586    }
587
588    pub(crate) fn try_assignments_row_for_rho(
589        &self,
590        row: usize,
591        rho: &SaeManifoldRho,
592    ) -> Result<Array1<f64>, String> {
593        self.try_assignments_row_with_alpha(row, self.mode.resolved_ibp_alpha(rho))
594    }
595
596    fn try_assignments_row_with_alpha(
597        &self,
598        row: usize,
599        resolved_ibp_alpha: Option<f64>,
600    ) -> Result<Array1<f64>, String> {
601        // #1033 — read the ACTIVE routing logits: the ρ-invariant frozen/predicted
602        // logits when routing is frozen, else the free `self.logits`. This single
603        // source makes the gate value ρ-invariant under frozen routing (the
604        // amortized-routing lever) and bit-identical to the historical path when
605        // not frozen.
606        let routing = self.routing_logits_row(row);
607        validate_finite_logits(routing, row)?;
608        // Only Softmax collapses to a fixed assignment at K==1: its
609        // assignment_coord_dim is K-1 = 0, so there is no free logit. IBPMap and
610        // JumpReLU keep a free per-atom gate logit even at K==1
611        // (assignment_coord_dim = K = 1), so they must fall through to their real
612        // row functions or the logit would move the prior but not the gate.
613        if self.k_atoms() == 1 && matches!(self.mode, AssignmentMode::Softmax { .. }) {
614            return Ok(Array1::from_vec(vec![1.0]));
615        }
616        let mut row_gates = match self.mode {
617            AssignmentMode::Softmax { temperature, .. } => softmax_row(routing, temperature),
618            AssignmentMode::IBPMap {
619                temperature, alpha, ..
620            } => ibp_map_row(routing, temperature, resolved_ibp_alpha.unwrap_or(alpha)),
621            AssignmentMode::JumpReLU {
622                temperature,
623                threshold,
624            } => jumprelu_row(routing, temperature, threshold),
625        };
626        // #1026 — ungated (background-tier) atoms have a fixed unit gate. For the
627        // column-separable IBP / JumpReLU modes the other atoms' gates are
628        // computed independently above, so overwriting the ungated entries to 1.0
629        // leaves the gated atoms exactly as they were; the ungated atom then
630        // contributes `γ_k(t_k)` unweighted to every row. (Softmax + ungated is
631        // rejected at `with_ungated`, so no simplex renormalization is needed
632        // here.)
633        if self.has_ungated() {
634            for (k, gate) in row_gates.iter_mut().enumerate() {
635                if self.ungated[k] {
636                    *gate = 1.0;
637                }
638            }
639        }
640        Ok(row_gates)
641    }
642
643    /// #1557 — fill-into-caller-buffer twin of [`Self::try_assignments_row_for_rho`].
644    ///
645    /// Writes the EXACT SAME per-atom assignment row into `out` (length
646    /// `k_atoms()`) instead of allocating a fresh `Array1`. Bit-identical to the
647    /// allocating path; intended for the hot per-row loops that immediately
648    /// consume the row, reusing a single scratch buffer across rows.
649    pub(crate) fn try_assignments_row_for_rho_into(
650        &self,
651        row: usize,
652        rho: &SaeManifoldRho,
653        out: &mut [f64],
654    ) -> Result<(), String> {
655        self.try_assignments_row_with_alpha_into(row, self.mode.resolved_ibp_alpha(rho), out)
656    }
657
658    /// #1557 — fill-into-caller-buffer twin of [`Self::try_assignments_row_with_alpha`].
659    ///
660    /// `out` must have length `k_atoms()`; it is fully overwritten with the same
661    /// values the allocating variant would return. Every branch (early-return
662    /// K==1 Softmax, the per-mode row math, the #1026 ungated overwrite) mirrors
663    /// the allocating path exactly so the two are bit-identical.
664    pub(crate) fn try_assignments_row_with_alpha_into(
665        &self,
666        row: usize,
667        resolved_ibp_alpha: Option<f64>,
668        out: &mut [f64],
669    ) -> Result<(), String> {
670        // `out` is sized `k_atoms()` by every caller; the per-mode helpers below
671        // fully overwrite indices `0..k_atoms()`.
672        let routing = self.routing_logits_row(row);
673        validate_finite_logits(routing, row)?;
674        // Mirror the allocating early-return: only Softmax collapses to a fixed
675        // unit assignment at K==1.
676        if self.k_atoms() == 1 && matches!(self.mode, AssignmentMode::Softmax { .. }) {
677            out[0] = 1.0;
678            return Ok(());
679        }
680        match self.mode {
681            AssignmentMode::Softmax { temperature, .. } => {
682                softmax_row_into(routing, temperature, out)
683            }
684            AssignmentMode::IBPMap {
685                temperature, alpha, ..
686            } => ibp_map_row_into(
687                routing,
688                temperature,
689                resolved_ibp_alpha.unwrap_or(alpha),
690                out,
691            ),
692            AssignmentMode::JumpReLU {
693                temperature,
694                threshold,
695            } => jumprelu_row_into(routing, temperature, threshold, out),
696        };
697        // #1026 — ungated (background-tier) atoms have a fixed unit gate, exactly
698        // as in the allocating path.
699        if self.has_ungated() {
700            for (k, gate) in out.iter_mut().enumerate() {
701                if self.ungated[k] {
702                    *gate = 1.0;
703                }
704            }
705        }
706        Ok(())
707    }
708
709    pub(crate) fn persist_resolved_ibp_alpha(&mut self, rho: &SaeManifoldRho) -> bool {
710        let AssignmentMode::IBPMap {
711            temperature,
712            alpha,
713            learnable_alpha: true,
714        } = self.mode
715        else {
716            return false;
717        };
718        let resolved_alpha = resolve_learnable_weight(alpha, rho.log_lambda_sparse);
719        self.mode = AssignmentMode::IBPMap {
720            temperature,
721            alpha: resolved_alpha,
722            learnable_alpha: false,
723        };
724        true
725    }
726
727    pub(crate) fn assignments_for_rho(&self, rho: &SaeManifoldRho) -> Result<Array2<f64>, String> {
728        let n = self.n_obs();
729        let k = self.k_atoms();
730        let mut out = Array2::<f64>::zeros((n, k));
731        for row in 0..n {
732            let a = self.try_assignments_row_for_rho(row, rho)?;
733            for atom in 0..k {
734                out[[row, atom]] = a[atom];
735            }
736        }
737        Ok(out)
738    }
739
740    /// Flatten extension coordinates in row-major SAE layout:
741    /// `(assignment chart_i, t_i0[0..d_0], ..., t_iK[0..d_K])` for every row.
742    /// Softmax contributes the first `K - 1` reference logits and omits the
743    /// fixed reference logit; gate-style assignment modes contribute all `K`
744    /// logits.
745    pub fn flatten_ext_coords(&self) -> Array1<f64> {
746        let n = self.n_obs();
747        let q = self.row_block_dim();
748        let k = self.k_atoms();
749        let assignment_dim = self.assignment_coord_dim();
750        let offsets = self.coord_offsets();
751        let mut out = Array1::<f64>::zeros(n * q);
752        for row in 0..n {
753            let base = row * q;
754            for atom in 0..assignment_dim {
755                out[base + atom] = self.logits[[row, atom]];
756            }
757            for atom in 0..k {
758                let d = self.coords[atom].latent_dim();
759                let t_row = self.coords[atom].row(row);
760                for axis in 0..d {
761                    out[base + offsets[atom] + axis] = t_row[axis];
762                }
763            }
764        }
765        out
766    }
767
768    #[must_use = "build error must be handled"]
769    pub fn from_blocks_with_mode(
770        logits: Array2<f64>,
771        coord_blocks: Vec<Array2<f64>>,
772        mode: AssignmentMode,
773    ) -> Result<Self, String> {
774        let coords = coord_blocks
775            .iter()
776            .map(|c| LatentCoordValues::from_matrix(c.view(), LatentIdMode::None))
777            .collect();
778        Self::with_mode(logits, coords, mode)
779    }
780
781    #[must_use = "build error must be handled"]
782    pub fn from_blocks_with_mode_and_manifolds(
783        logits: Array2<f64>,
784        coord_blocks: Vec<Array2<f64>>,
785        manifolds: Vec<LatentManifold>,
786        mode: AssignmentMode,
787    ) -> Result<Self, String> {
788        if coord_blocks.len() != manifolds.len() {
789            return Err(format!(
790                "SaeAssignment::from_blocks_with_mode_and_manifolds: coord block length {} != manifold length {}",
791                coord_blocks.len(),
792                manifolds.len()
793            ));
794        }
795        let coords = coord_blocks
796            .iter()
797            .zip(manifolds)
798            .map(|(c, manifold)| {
799                LatentCoordValues::from_matrix_with_manifold(c.view(), LatentIdMode::None, manifold)
800            })
801            .collect();
802        Self::with_mode(logits, coords, mode)
803    }
804}
805
806pub(crate) fn neutral_gate_weights(mode: AssignmentMode, k_atoms: usize) -> Array1<f64> {
807    match mode {
808        AssignmentMode::Softmax { .. } => Array1::from_elem(k_atoms, 1.0 / (k_atoms.max(1) as f64)),
809        AssignmentMode::IBPMap {
810            temperature, alpha, ..
811        } => ibp_map_row(Array1::<f64>::zeros(k_atoms).view(), temperature, alpha),
812        AssignmentMode::JumpReLU { .. } => Array1::from_elem(k_atoms, 0.5),
813    }
814}
815
816pub(crate) fn softmax_row(logits: ArrayView1<'_, f64>, temperature: f64) -> Array1<f64> {
817    let k = logits.len();
818    let inv_tau = 1.0 / temperature;
819    let mut max_logit = f64::NEG_INFINITY;
820    for &v in logits.iter() {
821        max_logit = max_logit.max(v);
822    }
823    let mut out = Array1::<f64>::zeros(k);
824    let mut sum = 0.0;
825    for i in 0..k {
826        let v = ((logits[i] - max_logit) * inv_tau).exp();
827        out[i] = v;
828        sum += v;
829    }
830    assert!(sum.is_finite() && sum > 0.0);
831    for v in out.iter_mut() {
832        *v /= sum;
833    }
834    out
835}
836
837pub(crate) fn validate_finite_logits(
838    logits: ArrayView1<'_, f64>,
839    row: usize,
840) -> Result<(), String> {
841    for (col, &v) in logits.iter().enumerate() {
842        if !v.is_finite() {
843            return Err(format!(
844                "SaeAssignment: non-finite assignment logit at row {row}, atom {col}: {v}"
845            ));
846        }
847    }
848    Ok(())
849}
850
851pub(crate) fn canonicalize_softmax_logits(logits: &mut Array2<f64>) {
852    let k = logits.ncols();
853    if k == 0 {
854        return;
855    }
856    if k == 1 {
857        logits.fill(0.0);
858        return;
859    }
860    for row in 0..logits.nrows() {
861        let reference = logits[[row, k - 1]];
862        for col in 0..k - 1 {
863            logits[[row, col]] -= reference;
864        }
865        logits[[row, k - 1]] = 0.0;
866    }
867}
868
869/// Truncated Indian-Buffet-Process stick-breaking prior *means*
870/// `π_k = E[∏_{j=0}^{k} v_j] = (α/(α+1))^{k+1}` for k = 0, .., K-1, with sticks
871/// `v_j ~ Beta(α, 1)` so `E[v_j] = α/(α+1)`. EVERY atom (including the first,
872/// `π_0 = α/(α+1)`) carries the consistent Beta(α, 1) shrinkage: there is no
873/// special-cased always-on base atom, so `α` behaves as a genuine IBP
874/// concentration — larger `α` ⇒ heavier mass / slower decay, `α → 0` ⇒ all mass
875/// collapses onto nothing, matching the stick-breaking limit. This is the
876/// deterministic MAP / mean-field form of the IBP prior (the closed form the
877/// analytic Newton / Hessian / Woodbury machinery differentiates); no sticks are
878/// *sampled* here, the per-atom weight is the exact expectation of the
879/// stick-breaking product. (#614: previously `π_0 = 1` left the first atom
880/// unshrunk, which is the prior mean of NO stick at all and broke α's role as a
881/// concentration; the consistent product mean restores genuine IBP semantics.)
882pub(crate) fn ordered_geometric_shrinkage_prior(k_atoms: usize, alpha: f64) -> Array1<f64> {
883    // Accumulate the geometric schedule `π_k = ratio^(k+1)` in LOG space so the
884    // prior stays a finite *soft* weight even for large `K`. The naive product
885    // `acc *= ratio` underflows to exact `0.0` once `ratio^(k+1) < f64::MIN_POSITIVE`
886    // (e.g. `(0.1/1.1)^320`), which would turn the soft shrinkage prior into a
887    // HARD mask: such atoms would receive zero assignment AND zero logit
888    // gradient (the gradient is multiplied by `π_k`), so they could never
889    // reactivate. Working in log-space and flooring the exponentiated weight at
890    // the smallest positive normal keeps every atom's gradient path alive while
891    // preserving the geometric ordering.
892    let mut out = Array1::<f64>::zeros(k_atoms);
893    let log_ratio = (alpha / (alpha + 1.0)).ln();
894    for k in 0..k_atoms {
895        // π_k = (α/(α+1))^{k+1}: the product of (k+1) i.i.d. Beta(α,1) stick
896        // means, so atom 0 is also shrunk by one stick (E[v_0] = α/(α+1)).
897        let log_pi = ((k + 1) as f64) * log_ratio;
898        out[k] = log_pi.exp().max(f64::MIN_POSITIVE);
899    }
900    out
901}
902
903/// IBP-MAP row activations: per-atom sigmoid likelihood times the truncated
904/// stick-breaking prior mean `π_k = (α/(α+1))^{k+1}`. With tied logits the prior
905/// dominates and yields strictly decreasing activations in atom index, with the
906/// first atom already shrunk by one Beta(α,1) stick mean (no unshrunk base atom).
907pub fn ibp_map_row(logits: ArrayView1<'_, f64>, temperature: f64, alpha: f64) -> Array1<f64> {
908    let prior = ordered_geometric_shrinkage_prior(logits.len(), alpha);
909    let mut out = Array1::<f64>::zeros(logits.len());
910    for i in 0..logits.len() {
911        out[i] = gam_linalg::utils::stable_logistic(logits[i] / temperature) * prior[i];
912    }
913    out
914}
915
916/// IBP-MAP activations together with the diagonal Jacobian `∂z_k/∂l_k`,
917/// shared with the torch autograd `Function` so the Python IBP-Gumbel path
918/// applies the same stick-breaking prior mean `π_k = (α/(α+1))^{k+1}` and
919/// temperature scaling as the Rust closed form. With `z_k = σ(l_k/τ)·π_k` the
920/// per-atom derivative is
921/// `σ(l_k/τ)(1 − σ(l_k/τ))·π_k / τ`; the map is diagonal in `k`, so the
922/// Jacobian is returned as the per-atom diagonal vector.
923#[must_use]
924pub fn ibp_map_row_value_grad(
925    logits: ArrayView1<'_, f64>,
926    temperature: f64,
927    alpha: f64,
928) -> (Array1<f64>, Array1<f64>) {
929    let prior = ordered_geometric_shrinkage_prior(logits.len(), alpha);
930    let inv_tau = 1.0 / temperature;
931    let mut value = Array1::<f64>::zeros(logits.len());
932    let mut grad = Array1::<f64>::zeros(logits.len());
933    for i in 0..logits.len() {
934        let sig = gam_linalg::utils::stable_logistic(logits[i] * inv_tau);
935        value[i] = sig * prior[i];
936        grad[i] = sig * (1.0 - sig) * inv_tau * prior[i];
937    }
938    (value, grad)
939}
940
941pub fn jumprelu_row(logits: ArrayView1<'_, f64>, temperature: f64, threshold: f64) -> Array1<f64> {
942    let mut out = Array1::<f64>::zeros(logits.len());
943    for i in 0..logits.len() {
944        // Hard gate: strictly zero below threshold (the intended "jump"). Above
945        // threshold the surrogate is centered at the threshold so the gate is
946        // most informative exactly at the boundary it switches on:
947        // σ((l−θ)/τ) ∈ [0.5, 1). Magnitude lives in the decoder, so the gate
948        // stays bounded in [0, 1] by design.
949        if logits[i] > threshold {
950            out[i] = gam_linalg::utils::stable_logistic((logits[i] - threshold) / temperature);
951        }
952    }
953    out
954}
955
956// #1557 — fill-into-caller-buffer variants of the three per-mode row functions.
957// These compute the EXACT SAME values as `softmax_row` / `ibp_map_row` /
958// `jumprelu_row` (same arithmetic, same order of operations) but write into a
959// caller-provided `&mut [f64]` slice instead of heap-allocating a fresh
960// `Array1<f64>` per call. The hot per-row loops (loss eval, arrow/Schur row
961// loops) call these with a reused scratch buffer, eliminating millions of tiny
962// K-sized allocations while staying bit-identical to the allocating path.
963// `out` must have length `logits.len()`; the slice is fully overwritten.
964
965pub(crate) fn softmax_row_into(logits: ArrayView1<'_, f64>, temperature: f64, out: &mut [f64]) {
966    let k = logits.len();
967    let inv_tau = 1.0 / temperature;
968    let mut max_logit = f64::NEG_INFINITY;
969    for &v in logits.iter() {
970        max_logit = max_logit.max(v);
971    }
972    let mut sum = 0.0;
973    for i in 0..k {
974        let v = ((logits[i] - max_logit) * inv_tau).exp();
975        out[i] = v;
976        sum += v;
977    }
978    assert!(sum.is_finite() && sum > 0.0);
979    for v in out.iter_mut() {
980        *v /= sum;
981    }
982}
983
984pub(crate) fn ibp_map_row_into(
985    logits: ArrayView1<'_, f64>,
986    temperature: f64,
987    alpha: f64,
988    out: &mut [f64],
989) {
990    let prior = ordered_geometric_shrinkage_prior(logits.len(), alpha);
991    for i in 0..logits.len() {
992        out[i] = gam_linalg::utils::stable_logistic(logits[i] / temperature) * prior[i];
993    }
994}
995
996pub(crate) fn jumprelu_row_into(
997    logits: ArrayView1<'_, f64>,
998    temperature: f64,
999    threshold: f64,
1000    out: &mut [f64],
1001) {
1002    for i in 0..logits.len() {
1003        // Match `jumprelu_row`: strictly zero below threshold, sigmoid surrogate
1004        // above. The buffer is fully overwritten (no read of prior contents).
1005        if logits[i] > threshold {
1006            out[i] = gam_linalg::utils::stable_logistic((logits[i] - threshold) / temperature);
1007        } else {
1008            out[i] = 0.0;
1009        }
1010    }
1011}
1012
/// Inputs for [`fill_active_atom_logit_jvp`]: everything needed to write the
/// compact logit-JVP row of a single active atom `k`.
pub(crate) struct ActiveAtomLogitJvp<'a> {
    /// Assignment mode; selects which per-mode `da_k/dl_k` formula is used.
    pub(crate) mode: AssignmentMode,
    /// Atom index in the full (dense) atom set; used to look up `ibp_prior[k]`.
    pub(crate) k: usize,
    /// Routing logit of atom `k`; read by the JumpReLU branch.
    pub(crate) logit_k: f64,
    /// Current gate value of atom `k`; read by the Softmax and IBPMap branches.
    pub(crate) a_k: f64,
    /// Decoded output of atom `k` for this row. Indexed over
    /// `0..fitted.len()`, so it is assumed to have at least that length.
    pub(crate) decoded_k: ArrayView1<'a, f64>,
    /// The row's active-set reconstruction; its length sets the number of
    /// output columns written.
    pub(crate) fitted: ArrayView1<'a, f64>,
    /// Precomputed stick-breaking prior means; must be `Some` for `IBPMap`.
    pub(crate) ibp_prior: Option<&'a [f64]>,
    /// Row of the compact Jacobian block that receives this atom's JVP.
    pub(crate) compact_index: usize,
    /// #1026 — when `true`, atom `k` is the ungated background tier: its gate is
    /// the constant `1`, so its logit-JVP `da_k/dl_k` is identically zero (the
    /// compact row is left untouched / zero).
    pub(crate) ungated: bool,
}
1027
/// Fill the single compact logit-JVP row for active atom `k`, using the
/// per-mode assignment sensitivity `da_k/dl_k` contracted into the decoded /
/// fitted-corrected output direction. This is the active-set analogue of
/// [`fill_assignment_logit_jvp_rows`]: it reproduces that function's diagonal
/// logit row exactly for the atom `k`, but writes into a compact position of a
/// heterogeneous-`q` row block instead of the dense full-`K` Jacobian. `fitted`
/// is the row's *active-set* reconstruction so the softmax cross term
/// `(decoded_k − fitted)` is consistent with the curvature the compact block
/// carries.
///
/// Only row `compact_index` of `jac_compact`, columns `0..fitted.len()`, is
/// written. Ungated atoms, and JumpReLU atoms at or below the threshold, return
/// without writing anything, so the caller must supply that row pre-zeroed.
///
/// # Panics
/// Panics if `mode` is `IBPMap` and `ibp_prior` is `None`, or on out-of-bounds
/// indexing of `ibp_prior`, `decoded_k` or `jac_compact`.
pub(crate) fn fill_active_atom_logit_jvp(
    input: ActiveAtomLogitJvp<'_>,
    jac_compact: &mut Array2<f64>,
) {
    let ActiveAtomLogitJvp {
        mode,
        k,
        logit_k,
        a_k,
        decoded_k,
        fitted,
        ibp_prior,
        compact_index,
        ungated,
    } = input;
    let p = fitted.len();
    // #1026 — an ungated atom's gate is constant, so its logit-JVP is zero; leave
    // its compact row untouched (the buffer row is pre-zeroed by the caller).
    if ungated {
        return;
    }
    match mode {
        AssignmentMode::Softmax { temperature, .. } => {
            // Diagonal softmax sensitivity contracted against the row output:
            // a_k (decoded_k − fitted) / τ. This equals ∂fitted/∂l_k if fitted is
            // Σ_j a_j decoded_j (NOTE(review): assumed from the dense twin).
            let inv_tau = 1.0 / temperature;
            for out_col in 0..p {
                jac_compact[[compact_index, out_col]] =
                    a_k * (decoded_k[out_col] - fitted[out_col]) * inv_tau;
            }
        }
        AssignmentMode::IBPMap { temperature, .. } => {
            // z_k = σ(l_k/τ)·π_k ⇒ dz_k/dl_k = σ(1 − σ)·π_k/τ, with σ recovered as
            // a_k/π_k; equivalently a_k(π_k − a_k)/(π_k τ). Matches
            // `fill_assignment_logit_jvp_rows`.
            let inv_tau = 1.0 / temperature;
            let prior =
                ibp_prior.expect("fill_active_atom_logit_jvp: IBPMap requires precomputed prior");
            let pi_k = prior[k];
            // Guard against dividing by a zero prior mean.
            let sig = if pi_k > 0.0 { a_k / pi_k } else { 0.0 };
            let dz = sig * (1.0 - sig) * inv_tau * pi_k;
            for out_col in 0..p {
                jac_compact[[compact_index, out_col]] = dz * decoded_k[out_col];
            }
        }
        AssignmentMode::JumpReLU {
            temperature,
            threshold,
        } => {
            // The data-fit Jacobian follows the hard forward gate. Below the
            // threshold the reconstruction contribution is exactly zero, so the
            // data-fit logit derivative must also be zero. Band-only atoms stay
            // in the compact row for prior terms, not phantom reconstruction
            // slope.
            if logit_k <= threshold {
                return;
            }
            // Above threshold: derivative of σ((l − θ)/τ), i.e. σ(1 − σ)/τ.
            let inv_tau = 1.0 / temperature;
            let activation = gam_linalg::utils::stable_logistic((logit_k - threshold) * inv_tau);
            let da = activation * (1.0 - activation) * inv_tau;
            for out_col in 0..p {
                jac_compact[[compact_index, out_col]] = da * decoded_k[out_col];
            }
        }
    }
}
1101
1102pub(crate) fn fill_assignment_logit_jvp_rows(
1103    mode: AssignmentMode,
1104    logits: ArrayView1<'_, f64>,
1105    assignments: ArrayView1<'_, f64>,
1106    decoded: ArrayView2<'_, f64>,
1107    fitted: ArrayView1<'_, f64>,
1108    ibp_prior: Option<&[f64]>,
1109    // #1026 — per-atom ungated flags (length `K`). An ungated atom's gate is
1110    // constant, so its logit-JVP row is identically zero (skipped below). Empty
1111    // ⇒ no atom is ungated (the historical path, bit-identical).
1112    ungated: &[bool],
1113    local_jac: &mut Array2<f64>,
1114) {
1115    let is_ungated = |k: usize| ungated.get(k).copied().unwrap_or(false);
1116    match mode {
1117        AssignmentMode::Softmax { temperature, .. } => {
1118            if assignments.len() == 1 {
1119                return;
1120            }
1121            // da_k/dl_j = a_k (1[k=j] - a_j) / tau, contracted against
1122            // the assignment-weighted fitted row. The dense row layout uses
1123            // the reference-logit chart, so only columns `0..K-1` are free;
1124            // the final reference logit is fixed at zero and has no row.
1125            let inv_tau = 1.0 / temperature;
1126            for logit_col in 0..assignments.len() - 1 {
1127                if is_ungated(logit_col) {
1128                    continue;
1129                }
1130                for out_col in 0..fitted.len() {
1131                    local_jac[[logit_col, out_col]] = assignments[logit_col]
1132                        * (decoded[[logit_col, out_col]] - fitted[out_col])
1133                        * inv_tau;
1134                }
1135            }
1136        }
1137        AssignmentMode::IBPMap { temperature, .. } => {
1138            // Truncated-IBP concrete relaxation: z_k = σ(l_k/τ) · π_k where
1139            // π_k is the stick-breaking prior. Thus
1140            // dz_k/dl_k = σ(l/τ)(1-σ(l/τ))/τ · π_k = a_k(π_k - a_k)/(π_k τ).
1141            let inv_tau = 1.0 / temperature;
1142            let prior = ibp_prior
1143                .expect("fill_assignment_logit_jvp_rows: IBPMap requires precomputed prior");
1144            for logit_col in 0..assignments.len() {
1145                if is_ungated(logit_col) {
1146                    continue;
1147                }
1148                let pi_k = prior[logit_col];
1149                let a_k = assignments[logit_col];
1150                let sig = if pi_k > 0.0 { a_k / pi_k } else { 0.0 };
1151                let dz = sig * (1.0 - sig) * inv_tau * pi_k;
1152                for out_col in 0..fitted.len() {
1153                    local_jac[[logit_col, out_col]] = dz * decoded[[logit_col, out_col]];
1154                }
1155            }
1156        }
1157        AssignmentMode::JumpReLU {
1158            temperature,
1159            threshold,
1160        } => {
1161            // Data-fit sensitivity follows the hard forward gate: rows at or
1162            // below the threshold have zero reconstruction value and therefore
1163            // zero data-fit logit derivative. The wider machine-precision prior
1164            // support is a compact-layout/prior rule, not a data-fit STE.
1165            let inv_tau = 1.0 / temperature;
1166            for logit_col in 0..assignments.len() {
1167                if is_ungated(logit_col) || logits[logit_col] <= threshold {
1168                    continue;
1169                }
1170                let activation = gam_linalg::utils::stable_logistic(
1171                    (logits[logit_col] - threshold) * inv_tau,
1172                );
1173                let da = activation * (1.0 - activation) * inv_tau;
1174                for out_col in 0..fitted.len() {
1175                    local_jac[[logit_col, out_col]] = da * decoded[[logit_col, out_col]];
1176                }
1177            }
1178        }
1179    }
1180}
1181
1182pub(crate) fn flat_logits(logits: ArrayView2<'_, f64>) -> Array1<f64> {
1183    let mut out = Array1::<f64>::zeros(logits.len());
1184    for row in 0..logits.nrows() {
1185        let start = row * logits.ncols();
1186        for col in 0..logits.ncols() {
1187            out[start + col] = logits[[row, col]];
1188        }
1189    }
1190    out
1191}
1192
1193pub(crate) fn assignment_prior_value(assignment: &SaeAssignment, rho: &SaeManifoldRho) -> f64 {
1194    for row in 0..assignment.n_obs() {
1195        validate_finite_logits(assignment.logits.row(row), row)
1196            .expect("assignment logits must be finite");
1197    }
1198    let target = flat_logits(assignment.logits.view());
1199    if matches!(assignment.mode, AssignmentMode::Softmax { .. }) && assignment.k_atoms() == 1 {
1200        return 0.0;
1201    }
1202    match assignment.mode {
1203        AssignmentMode::Softmax {
1204            temperature,
1205            sparsity,
1206        } => {
1207            let penalty = SoftmaxAssignmentSparsityPenalty::new(assignment.k_atoms(), temperature);
1208            let rho_view = Array1::from_vec(vec![rho.log_lambda_sparse + sparsity.ln()]);
1209            penalty.value(target.view(), rho_view.view())
1210        }
1211        AssignmentMode::IBPMap {
1212            temperature,
1213            alpha,
1214            learnable_alpha,
1215        } => {
1216            let mut penalty = IBPAssignmentPenalty::new(
1217                assignment.k_atoms(),
1218                alpha,
1219                temperature,
1220                learnable_alpha,
1221            );
1222            let rho_view = if learnable_alpha {
1223                Array1::from_vec(vec![rho.log_lambda_sparse])
1224            } else {
1225                // Keep the fixed-alpha value path on the same weighting branch as
1226                // assignment_prior_grad_hdiag; that gradient path owns the
1227                // lambda_sparse convention for IBP assignment sparsity.
1228                penalty.weight = rho.lambda_sparse();
1229                Array1::zeros(0)
1230            };
1231            penalty.value(target.view(), rho_view.view())
1232        }
1233        AssignmentMode::JumpReLU {
1234            temperature,
1235            threshold,
1236        } => {
1237            // Sparsity penalty uses the same threshold-centered surrogate and
1238            // machine-precision support as its gradient/Hessian. Data-fit
1239            // reconstruction remains hard-gated by `jumprelu_row`.
1240            let sparsity_strength = rho.lambda_sparse();
1241            let mut acc = 0.0;
1242            for &logit in target.iter() {
1243                if jumprelu_in_optimization_band(logit, threshold, temperature) {
1244                    acc += gam_linalg::utils::stable_logistic((logit - threshold) / temperature);
1245                }
1246            }
1247            sparsity_strength * acc
1248        }
1249    }
1250}
1251
1252pub(crate) fn assignment_prior_log_strength_derivative(
1253    assignment: &SaeAssignment,
1254    rho: &SaeManifoldRho,
1255) -> f64 {
1256    for row in 0..assignment.n_obs() {
1257        validate_finite_logits(assignment.logits.row(row), row)
1258            .expect("assignment logits must be finite");
1259    }
1260    let target = flat_logits(assignment.logits.view());
1261    if matches!(assignment.mode, AssignmentMode::Softmax { .. }) && assignment.k_atoms() == 1 {
1262        return 0.0;
1263    }
1264    match assignment.mode {
1265        AssignmentMode::Softmax { .. } | AssignmentMode::JumpReLU { .. } => {
1266            assignment_prior_value(assignment, rho)
1267        }
1268        AssignmentMode::IBPMap {
1269            temperature,
1270            alpha,
1271            learnable_alpha,
1272        } => {
1273            let mut penalty = IBPAssignmentPenalty::new(
1274                assignment.k_atoms(),
1275                alpha,
1276                temperature,
1277                learnable_alpha,
1278            );
1279            if learnable_alpha {
1280                let rho_view = Array1::from_vec(vec![rho.log_lambda_sparse]);
1281                penalty.grad_rho(target.view(), rho_view.view())[0]
1282            } else {
1283                penalty.weight = rho.lambda_sparse();
1284                penalty.value(target.view(), Array1::<f64>::zeros(0).view())
1285            }
1286        }
1287    }
1288}
1289
1290pub(crate) fn assignment_prior_log_strength_hdiag(
1291    assignment: &SaeAssignment,
1292    rho: &SaeManifoldRho,
1293) -> Result<Array1<f64>, String> {
1294    for row in 0..assignment.n_obs() {
1295        validate_finite_logits(assignment.logits.row(row), row)?;
1296    }
1297    let target = flat_logits(assignment.logits.view());
1298    if matches!(assignment.mode, AssignmentMode::Softmax { .. }) && assignment.k_atoms() == 1 {
1299        return Ok(Array1::<f64>::zeros(target.len()));
1300    }
1301    match assignment.mode {
1302        AssignmentMode::Softmax {
1303            temperature,
1304            sparsity,
1305        } => {
1306            let penalty = SoftmaxAssignmentSparsityPenalty::new(assignment.k_atoms(), temperature);
1307            let rho_view = Array1::from_vec(vec![rho.log_lambda_sparse + sparsity.ln()]);
1308            penalty
1309                .hessian_diag(target.view(), rho_view.view())
1310                .ok_or_else(|| {
1311                    "softmax assignment log-strength hessian diag unavailable".to_string()
1312                })
1313        }
1314        AssignmentMode::JumpReLU {
1315            temperature,
1316            threshold,
1317        } => {
1318            let sparsity_strength = rho.lambda_sparse();
1319            let inv_tau = 1.0 / temperature;
1320            let inv_tau2 = inv_tau * inv_tau;
1321            let mut d = Array1::<f64>::zeros(target.len());
1322            for idx in 0..target.len() {
1323                let logit = target[idx];
1324                if !jumprelu_in_optimization_band(logit, threshold, temperature) {
1325                    continue;
1326                }
1327                let activation =
1328                    gam_linalg::utils::stable_logistic((logit - threshold) * inv_tau);
1329                let slope = activation * (1.0 - activation);
1330                d[idx] = sparsity_strength * slope * (1.0 - 2.0 * activation) * inv_tau2;
1331            }
1332            Ok(d)
1333        }
1334        AssignmentMode::IBPMap {
1335            temperature,
1336            alpha,
1337            learnable_alpha,
1338        } => {
1339            let mut penalty = IBPAssignmentPenalty::new(
1340                assignment.k_atoms(),
1341                alpha,
1342                temperature,
1343                learnable_alpha,
1344            );
1345            if learnable_alpha {
1346                let rho_view = Array1::from_vec(vec![rho.log_lambda_sparse]);
1347                Ok(penalty.hessian_diag_log_alpha_derivative(target.view(), rho_view.view()))
1348            } else {
1349                penalty.weight = rho.lambda_sparse();
1350                penalty
1351                    .hessian_diag(target.view(), Array1::<f64>::zeros(0).view())
1352                    .ok_or_else(|| {
1353                        "IBP assignment log-strength hessian diag unavailable".to_string()
1354                    })
1355            }
1356        }
1357    }
1358}
1359
1360pub(crate) fn assignment_prior_log_strength_target_mixed(
1361    assignment: &SaeAssignment,
1362    rho: &SaeManifoldRho,
1363) -> Result<Array1<f64>, String> {
1364    for row in 0..assignment.n_obs() {
1365        validate_finite_logits(assignment.logits.row(row), row)?;
1366    }
1367    let target = flat_logits(assignment.logits.view());
1368    if matches!(assignment.mode, AssignmentMode::Softmax { .. }) && assignment.k_atoms() == 1 {
1369        return Ok(Array1::<f64>::zeros(target.len()));
1370    }
1371    match assignment.mode {
1372        AssignmentMode::IBPMap {
1373            temperature,
1374            alpha,
1375            learnable_alpha: true,
1376        } => {
1377            let penalty = IBPAssignmentPenalty::new(assignment.k_atoms(), alpha, temperature, true);
1378            let rho_view = Array1::from_vec(vec![rho.log_lambda_sparse]);
1379            Ok(penalty.log_alpha_target_mixed_derivative(target.view(), rho_view.view()))
1380        }
1381        _ => Ok(assignment_prior_grad_hdiag(assignment, rho)?.0),
1382    }
1383}
1384
1385pub(crate) fn assignment_prior_grad_hdiag(
1386    assignment: &SaeAssignment,
1387    rho: &SaeManifoldRho,
1388) -> Result<(Array1<f64>, Array1<f64>), String> {
1389    for row in 0..assignment.n_obs() {
1390        validate_finite_logits(assignment.logits.row(row), row)?;
1391    }
1392    let target = flat_logits(assignment.logits.view());
1393    let mut grad = Array1::<f64>::zeros(target.len());
1394    let mut diag = Array1::<f64>::zeros(target.len());
1395    if matches!(assignment.mode, AssignmentMode::Softmax { .. }) && assignment.k_atoms() == 1 {
1396        return Ok((grad, diag));
1397    }
1398    let (sparsity_grad, sparsity_diag) = match assignment.mode {
1399        AssignmentMode::Softmax {
1400            temperature,
1401            sparsity,
1402        } => {
1403            let penalty = SoftmaxAssignmentSparsityPenalty::new(assignment.k_atoms(), temperature);
1404            let rho_view = Array1::from_vec(vec![rho.log_lambda_sparse + sparsity.ln()]);
1405            let g = penalty.grad_target(target.view(), rho_view.view());
1406            let d = penalty
1407                .hessian_diag(target.view(), rho_view.view())
1408                .ok_or_else(|| "softmax assignment hessian diag unavailable".to_string())?;
1409            (g, d)
1410        }
1411        AssignmentMode::IBPMap {
1412            temperature,
1413            alpha,
1414            learnable_alpha,
1415        } => {
1416            // Scale the IBP assignment-sparsity prior by `lambda_sparse`, exactly
1417            // like the Softmax and JumpReLU branches do (Softmax folds it into the
1418            // penalty's rho coordinate, JumpReLU multiplies `sparsity_strength`).
1419            // Previously the IBP penalty used its hardcoded `weight = 1.0` and the
1420            // `rho.log_lambda_sparse` coordinate never reached it (the rho_view was
1421            // empty for the common `learnable_alpha = false` config), so the prior
1422            // ran at full strength with no way to dial it down — and its
1423            // Beta-Bernoulli BCE energy `−mass·ln π_k − (n−mass)·ln(1−π_k)` toward
1424            // the self-referential empirical active fraction `π_k` has its global
1425            // minimum at the all-off gate, so at full weight it over-shrank the
1426            // assignment off both atoms even with a truth-seeded decoder (#853).
1427            // Routing `lambda_sparse` into the penalty weight makes the prior a
1428            // genuine, user-controllable lever balanced against the data fit.
1429            let mut penalty = IBPAssignmentPenalty::new(
1430                assignment.k_atoms(),
1431                alpha,
1432                temperature,
1433                learnable_alpha,
1434            );
1435            // When `alpha` is learnable, `log_lambda_sparse` already modulates
1436            // it through `resolved_alpha(rho)`, so the weight stays 1.0 to avoid
1437            // double-counting that coordinate. Only when `alpha` is fixed (so the
1438            // sparse coordinate would otherwise be ignored entirely) does
1439            // `lambda_sparse` become the prior's weight lever.
1440            let rho_view = if learnable_alpha {
1441                Array1::from_vec(vec![rho.log_lambda_sparse])
1442            } else {
1443                penalty.weight = rho.lambda_sparse();
1444                Array1::zeros(0)
1445            };
1446            let g = penalty.grad_target(target.view(), rho_view.view());
1447            let d = penalty
1448                .hessian_diag(target.view(), rho_view.view())
1449                .ok_or_else(|| "IBP assignment hessian diag unavailable".to_string())?;
1450            (g, d)
1451        }
1452        AssignmentMode::JumpReLU {
1453            temperature,
1454            threshold,
1455        } => {
1456            // Gradient and exact diagonal Hessian of the sparsity value's
1457            // threshold-centered surrogate σ((l−θ)/τ), using the same
1458            // machine-precision support as the value path. Data-fit JVP support
1459            // is narrower and follows the hard forward gate.
1460            let sparsity_strength = rho.lambda_sparse();
1461            let inv_tau = 1.0 / temperature;
1462            let inv_tau2 = inv_tau * inv_tau;
1463            let mut g = Array1::<f64>::zeros(target.len());
1464            let mut d = Array1::<f64>::zeros(target.len());
1465            for idx in 0..target.len() {
1466                let logit = target[idx];
1467                if !jumprelu_in_optimization_band(logit, threshold, temperature) {
1468                    continue;
1469                }
1470                let activation =
1471                    gam_linalg::utils::stable_logistic((logit - threshold) * inv_tau);
1472                let slope = activation * (1.0 - activation);
1473                g[idx] = sparsity_strength * slope * inv_tau;
1474                d[idx] = sparsity_strength * slope * (1.0 - 2.0 * activation) * inv_tau2;
1475            }
1476            (g, d)
1477        }
1478    };
1479    grad += &sparsity_grad;
1480    diag += &sparsity_diag;
1481    // #1026/#1033 — a FIXED logit (an ungated atom's, or every atom's under
1482    // frozen routing) is not a free parameter, so it carries NO sparsity-prior
1483    // gradient or curvature. Zero its flat columns (`flat_logits` is row-major
1484    // `row*K + atom`) so the assembled `gt` and `htt` logit slots stay zero —
1485    // matching the zero logit-JVP. The column-separable IBP / JumpReLU priors are
1486    // per-atom, so zeroing one atom's columns leaves the others' prior intact;
1487    // under frozen routing ALL atoms' logit columns are zeroed (the whole routing
1488    // is a fixed predicted function, not optimized).
1489    if assignment.has_ungated() || assignment.routing_is_frozen() {
1490        let k = assignment.k_atoms();
1491        for idx in 0..grad.len() {
1492            if assignment.logit_is_fixed(idx % k) {
1493                grad[idx] = 0.0;
1494                diag[idx] = 0.0;
1495            }
1496        }
1497    }
1498    Ok((grad, diag))
1499}
1500
1501/// Build the exact IBP `hessian_diag` logit third-derivative channels (#1006)
1502/// for the SAE log-det adjoint Γ, using the SAME penalty configuration —
1503/// `alpha`/`tau`/`learnable_alpha` and the `lambda_sparse` weight convention —
1504/// that [`assignment_prior_grad_hdiag`] assembles into `htt`. Returns `None`
1505/// for non-IBP assignment modes (no cross-row empirical-π coupling to correct).
1506pub(crate) fn ibp_assignment_third_channels(
1507    assignment: &SaeAssignment,
1508    rho: &SaeManifoldRho,
1509) -> Result<Option<IbpHessianDiagThirdChannels>, String> {
1510    let AssignmentMode::IBPMap {
1511        temperature,
1512        alpha,
1513        learnable_alpha,
1514    } = assignment.mode
1515    else {
1516        return Ok(None);
1517    };
1518    for row in 0..assignment.n_obs() {
1519        validate_finite_logits(assignment.logits.row(row), row)?;
1520    }
1521    let target = flat_logits(assignment.logits.view());
1522    let mut penalty =
1523        IBPAssignmentPenalty::new(assignment.k_atoms(), alpha, temperature, learnable_alpha);
1524    // Mirror assignment_prior_grad_hdiag exactly: when alpha is learnable the
1525    // sparse coordinate already modulates it through resolved_alpha(rho), so the
1526    // weight stays 1.0; otherwise lambda_sparse becomes the prior's weight lever.
1527    let rho_view = if learnable_alpha {
1528        Array1::from_vec(vec![rho.log_lambda_sparse])
1529    } else {
1530        penalty.weight = rho.lambda_sparse();
1531        Array1::zeros(0)
1532    };
1533    let mut channels = penalty.hessian_diag_logit_third_channels(target.view(), rho_view.view());
1534    // #1026/#1033 — zero the log-det third-derivative channels of FIXED-logit
1535    // atoms (ungated, or all atoms under frozen routing) so the #1006 θ-adjoint
1536    // differentiates the SAME (fixed-logit-zeroed) `htt` that
1537    // `assignment_prior_grad_hdiag` assembled. `k_max` columns, row-major `N·K`
1538    // for the per-(row,atom) arrays and length-`K` for the per-column ones.
1539    if assignment.has_ungated() || assignment.routing_is_frozen() {
1540        let k = channels.k_max;
1541        for idx in 0..channels.z_jac.len() {
1542            if assignment.logit_is_fixed(idx % k) {
1543                channels.z_jac[idx] = 0.0;
1544                channels.local_logit_third[idx] = 0.0;
1545                channels.m_channel[idx] = 0.0;
1546                channels.logit_curvature[idx] = 0.0;
1547            }
1548        }
1549        for atom in 0..k {
1550            if assignment.logit_is_fixed(atom) {
1551                channels.cross_row_d[atom] = 0.0;
1552                channels.cross_row_dd[atom] = 0.0;
1553            }
1554        }
1555    }
1556    Ok(Some(channels))
1557}
1558
1559/// #1026 hybrid curved + linear-tail adjudication for one SAE atom slot.
1560///
1561/// A hybrid dictionary lets each atom slot be either a CURVED atom (its fitted
1562/// `latent_dim ≥ 1` manifold chart, whose decoded image may turn) or its LINEAR
1563/// special case (the euclidean-d=1-linear atom — one straight decoder direction,
1564/// `γ(t) = t·b`, zero turning). The two are nested: the linear atom is exactly
1565/// the curved family restricted to its straight sub-model, so a hybrid slot
1566/// cannot lose to pure-linear at matched actives — it strictly generalizes it.
1567///
1568/// This is the single call the SAE fitter makes per atom to choose the split by
1569/// EVIDENCE rather than fiat. It packages the atom's two already-fitted
1570/// candidates — each scored on the COMMON rank-aware Laplace scale (`−V = NLE`,
1571/// lower wins, identical to the union/mixture rungs) on the same rows — and
1572/// routes them through [`select_hybrid_atom`]. The curved candidate's fitted
1573/// turning `Θ` (from
1574/// [`crate::chart_canonicalization::d1_atom_fitted_turning`]) enters
1575/// as the decision feature: a `Θ → 0` atom yields to the cheaper linear tail by
1576/// construction (the dominance floor — a curved atom buys nothing on a straight
1577/// feature), a high-`Θ` atom takes the curved parameterization when its
1578/// curvature lowers the NLE by more than its extra-parameter price (the `Θ/√ε`
1579/// crossover).
1580///
1581/// `manifold` is the atom's fitted chart manifold; a non-curveable (already
1582/// Euclidean-flat) chart can only present the linear candidate, which this
1583/// helper enforces by ignoring any curved candidate offered for a flat chart —
1584/// a flat chart has no curvature to price, so the linear special case is its
1585/// only honest parameterization. Curveable charts present both candidates.
1586///
1587/// # Wiring into the fitter (the one call into `sae_manifold.rs`)
1588///
1589/// The post-fit pass in `sae_manifold.rs` already computes each d=1 atom's
1590/// fitted turning `Θ` (the read-only EV-vs-Θ diagnostic). To make the split
1591/// load-bearing, that pass supplies, per atom, the curved-candidate NLE +
1592/// parameter count + `Θ` and the linear-candidate NLE + parameter count (both
1593/// fitted on the atom's rows), and calls this helper; the returned
1594/// [`HybridAtomChoice`] tells the fitter which parameterization to keep for that
1595/// slot. The fitting of the two candidates lives in `sae_manifold.rs` (the
1596/// manifold-chart fitter); the SELECTION/scoring lives here.
1597pub fn select_hybrid_atom_parameterization(
1598    manifold: &LatentManifold,
1599    curved: Option<HybridAtomCandidate>,
1600    linear: HybridAtomCandidate,
1601) -> HybridAtomChoice {
1602    // A flat (Euclidean) chart has no curvature to price: its only honest
1603    // parameterization is the linear special case, so any curved candidate
1604    // offered for it is dropped before the evidence comparison. Curveable charts
1605    // (Circle / Sphere / Torus / curved products) present both candidates.
1606    let curved = if manifold.is_euclidean() {
1607        None
1608    } else {
1609        curved
1610    };
1611    let candidates: Vec<HybridAtomCandidate> = match curved {
1612        Some(c) => vec![linear, c],
1613        None => vec![linear],
1614    };
1615    // `candidates` is never empty (it always contains the linear candidate), so
1616    // the selector always returns a choice.
1617    select_hybrid_atom(&candidates).expect("hybrid atom slot always has the linear candidate")
1618}
1619
#[cfg(test)]
mod ibp_prior_614_tests {
    // #614: `ibp_stick_breaking_prior` used to compute `π_k = (α/(α+1))^k` with
    // `π_0 = 1`, i.e. an UNSHRUNK first atom — the prior mean of no stick at all,
    // which broke α's role as an IBP concentration parameter. The consistent
    // truncated-IBP stick-breaking prior mean is `π_k = (α/(α+1))^{k+1}`, the
    // expectation of the product of (k+1) i.i.d. Beta(α,1) stick means, so EVERY
    // atom (including the first) carries one stick of shrinkage. These tests pin
    // that contract on `ordered_geometric_shrinkage_prior` so the regression
    // cannot silently return.
    use super::*;

    /// Single-stick mean `α/(α+1)`: the geometric ratio of the prior.
    fn ratio(alpha: f64) -> f64 {
        alpha / (alpha + 1.0)
    }

    /// Relative tolerance for the closed-form comparisons. Every `π_k` lies in
    /// `(0, 1)`, so the former bound `1e-12 * expected.max(1.0)` always reduced
    /// to an absolute `1e-12`. That is vacuous for deep-tail entries such as
    /// `π_11 ≈ 2e-8` at `α = 0.3`. A relative bound checks every index
    /// against its own magnitude.
    const REL_TOL: f64 = 1e-12;

    #[test]
    fn first_atom_is_shrunk_not_unity() {
        // The #614 defect: π_0 must equal the single-stick mean α/(α+1), NOT 1.0.
        for &alpha in &[0.1_f64, 0.5, 1.0, 2.0, 5.0] {
            let prior = ordered_geometric_shrinkage_prior(8, alpha);
            let r = ratio(alpha);
            assert!(
                (prior[0] - r).abs() <= REL_TOL * r,
                "π_0 must be the single-stick mean α/(α+1)={r} (was the unshrunk 1.0 in #614); got {}",
                prior[0]
            );
            assert!(
                prior[0] < 1.0,
                "first atom must be shrunk (π_0<1) for alpha={alpha}; got {}",
                prior[0]
            );
        }
    }

    #[test]
    fn prior_is_consistent_geometric_product_mean() {
        // π_k = (α/(α+1))^{k+1} exactly, and every successive ratio equals α/(α+1).
        for &alpha in &[0.3_f64, 1.0, 4.0] {
            let k = 12;
            let prior = ordered_geometric_shrinkage_prior(k, alpha);
            let r = ratio(alpha);
            for j in 0..k {
                let expected = r.powi((j + 1) as i32);
                assert!(
                    (prior[j] - expected).abs() <= REL_TOL * expected,
                    "alpha={alpha} π_{j}: expected {expected}, got {}",
                    prior[j]
                );
            }
            // Strictly decreasing (ordered shrinkage), no plateau at the head.
            for j in 1..k {
                assert!(
                    prior[j] < prior[j - 1],
                    "alpha={alpha}: prior must strictly decrease at index {j}"
                );
            }
        }
    }

    #[test]
    fn alpha_behaves_as_concentration() {
        // Larger α => heavier mass / slower decay: π_0 increases toward 1 and the
        // tail (e.g. π_4) carries more mass. This is the IBP-concentration role
        // the #614 fix restored.
        let lo = ordered_geometric_shrinkage_prior(8, 0.5);
        let hi = ordered_geometric_shrinkage_prior(8, 5.0);
        assert!(
            hi[0] > lo[0],
            "larger alpha must raise π_0 (concentration): {} vs {}",
            hi[0],
            lo[0]
        );
        assert!(
            hi[4] > lo[4],
            "larger alpha must put more mass in the tail: {} vs {}",
            hi[4],
            lo[4]
        );
    }
}
1700
#[cfg(test)]
mod hybrid_split_tests {
    use super::*;
    use gam_solve::evidence::HybridAtomParam;
    use std::f64::consts::PI;

    /// A curveable circle chart with period 2π.
    fn circle_chart() -> LatentManifold {
        LatentManifold::Circle { period: 2.0 * PI }
    }

    #[test]
    fn flat_chart_drops_curved_candidate_and_keeps_linear() {
        // Even a curved candidate with a far lower NLE must be discarded on a
        // Euclidean chart: there is no curvature for it to price.
        let choice = select_hybrid_atom_parameterization(
            &LatentManifold::Euclidean,
            Some(HybridAtomCandidate::curved(1, 1.0, 5, Some(2.0))),
            HybridAtomCandidate::linear(100.0, 2),
        );
        assert!(choice.param.is_linear());
    }

    #[test]
    fn curveable_chart_selects_curved_when_turning_pays() {
        // On a circle chart both candidates compete. A full-turn feature whose
        // curved fit clearly beats the linear secant must select curved.
        let choice = select_hybrid_atom_parameterization(
            &circle_chart(),
            Some(HybridAtomCandidate::curved(1, 70.0, 5, Some(2.0 * PI))),
            HybridAtomCandidate::linear(100.0, 2),
        );
        assert_eq!(choice.param, HybridAtomParam::Curved { latent_dim: 1 });
    }

    #[test]
    fn curveable_chart_falls_back_to_linear_when_no_curved_candidate() {
        // No curved candidate offered: the linear one is chosen as-is.
        let choice = select_hybrid_atom_parameterization(
            &circle_chart(),
            None,
            HybridAtomCandidate::linear(33.0, 2),
        );
        assert!(choice.param.is_linear());
        assert_eq!(choice.num_parameters, 2);
    }
}
1748
#[cfg(test)]
mod frozen_routing_1033_tests {
    //! #1033 — FROZEN (amortized) routing. Once installed, each row's gate is
    //! read from a snapshot of the predicted logits. It is ρ-invariant and
    //! does not follow later edits to the free `self.logits`, i.e. the
    //! inner-fit logit drift the outer ρ-search would otherwise re-incur on
    //! every eval. These are deterministic mechanism checks that run no inner
    //! fit, so they pin the freeze contract without the cluster.
    use super::*;

    /// IBP-MAP assignment with a deterministic logit ramp and one latent
    /// coordinate per atom. `learnable_alpha = false` keeps alpha
    /// ρ-independent, so only the routing can move the gates.
    fn ibp_assignment(n: usize, k: usize) -> SaeAssignment {
        let logits = Array2::from_shape_fn((n, k), |(i, kk)| {
            0.3 + 0.05 * (i as f64) - 0.1 * (kk as f64)
        });
        let coords: Vec<Array2<f64>> =
            std::iter::repeat_with(|| Array2::from_shape_fn((n, 1), |(i, _)| (i as f64) * 0.1))
                .take(k)
                .collect();
        let mode = AssignmentMode::ibp_map(0.5, 1.0, false);
        SaeAssignment::from_blocks_with_mode(logits, coords, mode).unwrap()
    }

    /// [`ibp_assignment`] with routing frozen from its current logits.
    fn frozen_ibp_assignment(n: usize, k: usize) -> SaeAssignment {
        ibp_assignment(n, k)
            .freeze_routing_from_current_logits()
            .unwrap()
    }

    /// Gates of rows `0..n` evaluated at `rho`.
    fn row_gates(a: &SaeAssignment, rho: &SaeManifoldRho, n: usize) -> Vec<Array1<f64>> {
        (0..n)
            .map(|row| a.try_assignments_row_for_rho(row, rho).unwrap())
            .collect()
    }

    #[test]
    fn frozen_routing_decouples_gates_from_logit_updates_1033() {
        let (n, k) = (6usize, 3usize);
        let mut a = frozen_ibp_assignment(n, k);
        assert!(a.routing_is_frozen());
        let rho = SaeManifoldRho::new(0.0, 0.0, vec![Array1::<f64>::zeros(1); k]);
        let before = row_gates(&a, &rho, n);
        // Mimic inner-fit drift: shove every free logit by +5.
        a.logits.mapv_inplace(|v| v + 5.0);
        let after = row_gates(&a, &rho, n);
        // The frozen snapshot, not the free logits, drives the gates.
        for (r, (old, new)) in before.iter().zip(after.iter()).enumerate() {
            for kk in 0..k {
                assert_eq!(
                    old[kk], new[kk],
                    "row {r} atom {kk}: frozen-routing gate must be UNCHANGED by a free-logit \
                     update (decoupled from inner-fit drift); {} vs {}",
                    old[kk], new[kk]
                );
            }
        }
    }

    #[test]
    fn frozen_routing_gates_are_rho_invariant_1033() {
        let (n, k) = (5usize, 2usize);
        let a = frozen_ibp_assignment(n, k);
        // Two far-apart ρ settings. With frozen routing and a fixed alpha the
        // gate values must not move at all.
        let flat_coords = || vec![Array1::<f64>::zeros(1); k];
        let rho_a = SaeManifoldRho::new((1e-3_f64).ln(), (1e-2_f64).ln(), flat_coords());
        let rho_b = SaeManifoldRho::new((1e3_f64).ln(), (1e1_f64).ln(), flat_coords());
        for r in 0..n {
            let ga = a.try_assignments_row_for_rho(r, &rho_a).unwrap();
            let gb = a.try_assignments_row_for_rho(r, &rho_b).unwrap();
            for kk in 0..k {
                assert_eq!(
                    ga[kk], gb[kk],
                    "row {r} atom {kk}: frozen-routing gate must be ρ-INVARIANT (the n-independence \
                     lever); {} at ρ_a vs {} at ρ_b",
                    ga[kk], gb[kk]
                );
            }
        }
    }

    #[test]
    fn frozen_routing_fixes_all_logits_and_thaw_restores_free_path_1033() {
        let (n, k) = (4usize, 3usize);
        let mut a = frozen_ibp_assignment(n, k);
        // Frozen: no logit is a free Newton coordinate.
        let mask = a.fixed_logit_mask();
        assert_eq!(mask.len(), k);
        assert!(
            mask.iter().all(|&f| f),
            "frozen routing must fix ALL logits"
        );
        (0..k).for_each(|kk| {
            assert!(
                a.logit_is_fixed(kk),
                "atom {kk} logit must be fixed under frozen routing"
            );
        });
        // Thawed: back to the free-logit path with nothing fixed.
        a.thaw_routing();
        assert!(!a.routing_is_frozen());
        assert!(
            !a.fixed_logit_mask().iter().any(|&f| f),
            "thaw must restore the free-logit path"
        );
    }

    #[test]
    fn frozen_routing_rejects_softmax_1033() {
        let (n, k) = (4usize, 3usize);
        let logits = Array2::from_shape_fn((n, k), |(i, kk)| 0.1 * (i as f64) - 0.05 * (kk as f64));
        let coords: Vec<Array2<f64>> =
            std::iter::repeat_with(|| Array2::from_shape_fn((n, 1), |(i, _)| (i as f64) * 0.1))
                .take(k)
                .collect();
        let softmax =
            SaeAssignment::from_blocks_with_mode(logits, coords, AssignmentMode::softmax(1.0))
                .unwrap();
        // A frozen, non-optimized routing would be inconsistent with the
        // coupled-simplex entropy majorizer, so freezing must fail.
        let attempt = softmax.freeze_routing_from_current_logits();
        assert!(
            attempt.is_err(),
            "frozen routing under Softmax must be rejected (simplex entropy-majorizer coupling)"
        );
    }
}
1884
#[cfg(test)]
mod fill_into_buffer_1557_tests {
    //! #1557 — the fill-into-caller-buffer variant
    //! [`SaeAssignment::try_assignments_row_for_rho_into`] must produce
    //! BIT-IDENTICAL output to the allocating
    //! [`SaeAssignment::try_assignments_row_for_rho`] across every assignment
    //! mode (Softmax, IBPMap, JumpReLU), the #1026 ungated case, and the K==1
    //! edge. Entries are compared by IEEE-754 bit pattern (`f64::to_bits`). An
    //! approximate tolerance is not used, and neither is `==`, which equates
    //! `0.0` with `-0.0` and so cannot prove bit identity. The `_into` path is
    //! a pure allocation-elision refactor, so any numeric drift is a
    //! regression.
    use super::*;

    fn build(n: usize, k: usize, mode: AssignmentMode) -> SaeAssignment {
        // Deterministic, asymmetric logits/coords so every atom takes a distinct
        // value (no accidental ties masking an index bug).
        let logits = Array2::from_shape_fn((n, k), |(i, kk)| {
            0.37 + 0.11 * (i as f64) - 0.23 * (kk as f64)
        });
        let coords: Vec<Array2<f64>> = (0..k)
            .map(|_| Array2::from_shape_fn((n, 1), |(i, _)| 0.1 + 0.05 * (i as f64)))
            .collect();
        SaeAssignment::from_blocks_with_mode(logits, coords, mode).unwrap()
    }

    fn rho(k: usize) -> SaeManifoldRho {
        SaeManifoldRho::new(
            (1e-2_f64).ln(),
            (1e-1_f64).ln(),
            vec![Array1::<f64>::zeros(1); k],
        )
    }

    /// Assert, row by row, that the `_into` variant reproduces the allocating
    /// variant's output bit-for-bit.
    fn assert_into_matches_alloc(a: &SaeAssignment) {
        let n = a.n_obs();
        let k = a.k_atoms();
        let rho = rho(k);
        let mut scratch = vec![f64::NAN; k];
        for row in 0..n {
            let allocated = a.try_assignments_row_for_rho(row, &rho).unwrap();
            // Re-poison with NaN every row. A partial write (e.g. a JumpReLU
            // below-threshold entry left untouched) then shows up as a
            // mismatch instead of passing silently.
            scratch.fill(f64::NAN);
            a.try_assignments_row_for_rho_into(row, &rho, &mut scratch)
                .unwrap();
            assert_eq!(allocated.len(), k);
            for kk in 0..k {
                // Bit patterns, not `==`. A ±0.0 swap would pass `==`, and a
                // NaN would fail `==` even if both sides produced the same one.
                assert_eq!(
                    allocated[kk].to_bits(),
                    scratch[kk].to_bits(),
                    "row {row} atom {kk}: _into must be BIT-IDENTICAL to the allocating \
                     try_assignments_row_for_rho; got {} vs {}",
                    allocated[kk],
                    scratch[kk]
                );
            }
        }
    }

    #[test]
    fn softmax_into_is_bit_identical() {
        assert_into_matches_alloc(&build(7, 4, AssignmentMode::softmax(0.8)));
    }

    #[test]
    fn ibp_map_into_is_bit_identical() {
        // Both learnable and fixed alpha exercise the resolved-alpha branch.
        assert_into_matches_alloc(&build(7, 5, AssignmentMode::ibp_map(0.6, 1.3, false)));
        assert_into_matches_alloc(&build(7, 5, AssignmentMode::ibp_map(0.6, 1.3, true)));
    }

    #[test]
    fn jumprelu_into_is_bit_identical() {
        // Threshold chosen so SOME atoms fall below it (the untouched-entry path)
        // and some clear it (the sigmoid path) — both branches are exercised.
        assert_into_matches_alloc(&build(7, 5, AssignmentMode::jumprelu(0.9, 0.2)));
    }

    #[test]
    fn ungated_into_is_bit_identical() {
        // #1026 ungated overwrite under a gate-style mode (IBP/JumpReLU allow it).
        let a = build(6, 4, AssignmentMode::ibp_map(0.6, 1.1, false))
            .with_ungated(vec![false, true, false, true])
            .unwrap();
        assert_into_matches_alloc(&a);
        let j = build(6, 4, AssignmentMode::jumprelu(0.9, 0.15))
            .with_ungated(vec![true, false, true, false])
            .unwrap();
        assert_into_matches_alloc(&j);
    }

    #[test]
    fn k_equals_one_into_is_bit_identical() {
        // Softmax K==1 hits the fixed-unit early return; IBP/JumpReLU K==1 keep a
        // free per-atom gate and fall through to the real row functions.
        assert_into_matches_alloc(&build(5, 1, AssignmentMode::softmax(1.0)));
        assert_into_matches_alloc(&build(5, 1, AssignmentMode::ibp_map(0.7, 1.0, false)));
        assert_into_matches_alloc(&build(5, 1, AssignmentMode::jumprelu(0.8, 0.1)));
    }
}