Skip to main content

gam_sae/
assignment.rs

1//! Assignment gates and sparsity-prior helpers for the SAE manifold term.
2//! Mechanically split from `sae_manifold.rs`.
3
4use ndarray::{Array1, Array2, ArrayView1, ArrayView2};
5
6use gam_solve::evidence::{HybridAtomCandidate, HybridAtomChoice, select_hybrid_atom};
7use gam_terms::analytic_penalties::{
8    AnalyticPenalty, IBPAssignmentPenalty, IbpHessianDiagThirdChannels,
9    SoftmaxAssignmentSparsityPenalty, resolve_learnable_weight,
10};
11use gam_terms::latent::{LatentCoordValues, LatentIdMode, LatentManifold};
12use crate::manifold::SaeManifoldRho;
13
/// #976 Layer-1 guard: upper bound on how far ONE accepted iteration may move
/// an assignment logit, measured in multiples of the gate temperature τ.
///
/// τ is the gate's natural length scale, because every assignment mode reads
/// its logits through `σ(·/τ)` / `softmax(·/τ)`. A 4τ move already spans the
/// gate's whole soft range, so healthy convergence is never throttled. What the
/// cap rules out is a single inner iteration carrying a gate from contention to
/// numerically-zero support: a collapse has to take several accepted
/// iterations, which guarantees the per-iteration active-mass guard observes
/// the decay before it completes.
///
/// The clamp is applied where the step is realised. When it binds, the realised
/// objective is evaluated on the clamped state, so the Armijo comparison stays
/// value-consistent (the unclamped quadratic model is merely conservative, and
/// step halvings shrink the trial below the cap).
pub(crate) const SAE_ASSIGNMENT_LOGIT_STEP_CAP_TAUS: f64 = 4.0;
27
/// #976 Layer-1 guard: how many times a single atom may be re-seeded within one
/// joint fit.
///
/// The atom gets exactly one second chance from a fresh basin. If it breaches
/// again, the collapse is (locally) the objective's verdict at the current
/// hyperparameters: it is recorded as a terminal collapse event and left for
/// the structure-search death move to adjudicate. Re-seeding in a loop would
/// fight the optimizer.
pub(crate) const SAE_ATOM_COLLAPSE_RESEED_BUDGET: usize = 1;
34
/// #976 Layer-1 guard (decoder arm): collapse threshold on an atom's decoder
/// block Frobenius norm, expressed as a fraction of the dictionary's MEDIAN
/// decoder norm.
///
/// An atom at or below this fraction carries no material reconstruction signal:
/// it has degenerated to (near-)zero output and decodes the same nothing as
/// every other collapsed atom. This is the real-data K>1 failure the gate-mass
/// floor cannot see — the assignment gates can stay spread across rows (mass
/// guard satisfied) while the decoders all collapse to ~0, giving EV≈0 and a
/// rank-deficient per-row coordinate Hessian on every row (the 0→K·n
/// evidence-deflation jump).
///
/// Because the statistic is a RATIO to the dictionary median it is scale-free:
/// it never fires for a uniformly-small but well-conditioned decoder, only for
/// an atom that has fallen far behind its peers. By construction it is a no-op
/// for K=1 (a single atom has no peer to fall behind, and the median equals its
/// own norm), so the K=1 path is byte-for-byte unchanged.
pub(crate) const SAE_ATOM_DECODER_NORM_COLLAPSE_RATIO: f64 = 1.0e-3;
49
/// #976 / #1117 K>1 robustness: bounded DICTIONARY-level multi-start budget for
/// the simultaneous co-collapse arm (the EV-floor branch of
/// [`crate::manifold::SaeManifoldTerm::enforce_decoder_norm_guard`]).
///
/// This is a different budget from the per-atom
/// [`SAE_ATOM_COLLAPSE_RESEED_BUDGET`] (= 1). That one governs re-seeding ONE
/// atom's gate logits against an optimizer that keeps killing it, where a loop
/// would fight the optimizer. A co-collapse reseed is categorically different:
/// it is a full-dictionary multi-start that re-diversifies ALL atoms onto
/// distinct principal directions of a FRESHLY recomputed residual, so successive
/// attempts explore genuinely different basins. Empirically a single such
/// reseed cannot always break a K≥3 three-way basin (identical (K, seed) flips
/// EV≈0.40 ↔ 0.00), hence a small bounded budget of independent multi-starts.
///
/// The budget is consumed ONLY when the whole dictionary's reconstruction EV is
/// at or below the data-derived collapse bar (`collapse_ev_bar` =
/// `SAE_COLLAPSE_PCA_EV_FRACTION` × the rank-K PCA ceiling, i.e. less than half
/// the variance any rank-K linear dictionary could reach). The old absolute
/// magic EV floor was REPLACED by that data-derived bar, not deleted. It is a
/// no-op for any healthy fit (real OLMo K=1 ~0.22, K=2 ~0.40, both well above
/// the bar); on the rough patches of a hard K≥3 fit a transient sub-bar EV can
/// still consume the budget before the fit recovers.
pub(crate) const SAE_DICTIONARY_COCOLLAPSE_RESEED_BUDGET: usize = 3;
70
/// Machine-precision support cutoff for the smooth JumpReLU assignment prior,
/// measured in gate temperatures below the hard threshold.
///
/// The forward gate stays hard-zero at and below `threshold`, but the prior
/// value/gradient and the compact Newton layout keep every logit with
/// `(logit - threshold)/tau > -36`. At the excluded edge `sigma(-36) ~= 2e-16`,
/// so the dropped value/gradient/Hessian terms sit below f64 noise rather than
/// creating an algorithmic discontinuity.
pub(crate) const JUMPRELU_OPTIMIZATION_LOGIT_CUTOFF: f64 = -36.0;

/// Shared support predicate: is this logit included in JumpReLU optimization?
///
/// Returns `true` when the temperature-scaled margin
/// `(logit - threshold) / temperature` lies strictly above
/// [`JUMPRELU_OPTIMIZATION_LOGIT_CUTOFF`]. This is strictly weaker than the hard
/// forward gate `logit > threshold`, which still governs data-fit
/// reconstruction and its logit JVP. A NaN margin compares false and is
/// therefore excluded.
#[inline]
pub(crate) fn jumprelu_in_optimization_band(logit: f64, threshold: f64, temperature: f64) -> bool {
    let scaled_margin = (logit - threshold) / temperature;
    scaled_margin > JUMPRELU_OPTIMIZATION_LOGIT_CUTOFF
}
86
/// Assignment prior/relaxation selected for a [`SaeAssignment`].
#[derive(Debug, Clone, Copy)]
pub enum AssignmentMode {
    /// Simplex assignment per row, with an entropy sparsity term.
    Softmax { temperature: f64, sparsity: f64 },
    /// Deterministic concrete relaxation of a truncated IBP active set.
    ///
    /// Each atom's gate is the INDEPENDENT prior-shrunk activation
    /// `σ(logit/temperature) · π_k`, where `π_k = (α/(α+1))^{k+1}` is the
    /// stick-breaking prior mean (see [`ibp_map_row`]). These are per-atom
    /// gates, NOT mixture/simplex responsibilities: every column is computed on
    /// its own and the gates do NOT sum to 1 across atoms (there is no row
    /// normalization). Each `a_k ∈ [0, π_k] ⊂ [0, 1)` is the relaxed "atom k is
    /// active in this row" indicator of a truncated IBP, not a share of a unit
    /// reconstruction budget.
    IBPMap {
        temperature: f64,
        alpha: f64,
        learnable_alpha: bool,
    },
    /// Hard-thresholded bounded gate.
    ///
    /// An atom is off (gate = 0) while its logit is at or below `threshold`;
    /// above it the gate is the threshold-centered shifted sigmoid
    /// `σ((logit − threshold) / temperature) ∈ [0.5, 1)`.
    ///
    /// #1777 RENAMED from `JumpReLU`, which was an inaccurate name: this is NOT
    /// the literature JumpReLU activation `z·1[z>θ]`, which carries the
    /// thresholded MAGNITUDE `z`. This mode is a thresholded-logistic GATE (a
    /// hard-sigmoid gate) that carries no magnitude at all — its output is a
    /// bounded `[0, 1)` indicator, and `ThresholdGate` names it for what it is.
    /// It belongs to the gate family (softmax simplex / IBP sigmoid / this hard
    /// gate); reconstruction magnitude lives entirely in the decoder curve
    /// `g_k(t) = φ(t)ᵀ B_k`. The discontinuity at `threshold` (0 → 0.5) is the
    /// intended "jump".
    ///
    /// BACK-COMPAT: [`Self::threshold_gate`] is the primary constructor
    /// spelling; [`Self::jumprelu`] is kept as a legacy alias, and the FFI
    /// string parser accepts both `"threshold_gate"` and the legacy
    /// `"jumprelu"`.
    ThresholdGate { temperature: f64, threshold: f64 },
}
124
/// #1033 — which fixed-form predictor produces the ρ-invariant FROZEN routing
/// (amortized routing).
///
/// Both forms are deterministic functions of the current dictionary with NO
/// learned net; they differ in how faithfully they follow the dictionary as it
/// evolves across outer iterates. Both are kept so the accuracy gate can pick
/// whichever passes the fit-quality bar: the cheap `Snapshot` when it suffices,
/// the `ChartGeometry` distill otherwise.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum RoutingPredictor {
    /// Freeze a snapshot of the current (converged) logits. This is the
    /// cheapest fixed-form distill and is exact at the dictionary it was taken
    /// from, but it goes stale as the dictionary moves (it needs a refresh to
    /// track), so it is the MVP/baseline form.
    Snapshot,
    /// Re-derive each (row, atom) routing logit from the atom's encode-chart
    /// geometry against the CURRENT dictionary: encode the row to its predicted
    /// coord `t̂`, reconstruct the amplitude-1 image `γ_k(t̂) = Bᵀφ(t̂)`, and map
    /// the reconstruction ALIGNMENT to a logit. A moved decoder changes
    /// `γ_k(t̂)` and hence the routing, so this form tracks the dictionary
    /// without re-running the free-logit inner solve; it is the
    /// default-readiness form when the snapshot proves too stale.
    ChartGeometry,
}
147
148impl AssignmentMode {
149    #[must_use]
150    pub fn softmax(temperature: f64) -> Self {
151        Self::Softmax {
152            temperature,
153            sparsity: 1.0,
154        }
155    }
156
157    #[must_use]
158    pub fn ibp_map(temperature: f64, alpha: f64, learnable_alpha: bool) -> Self {
159        Self::IBPMap {
160            temperature,
161            alpha,
162            learnable_alpha,
163        }
164    }
165
166    /// #1777 — construct the hard-sigmoid [`Self::ThresholdGate`] (the accurate
167    /// name for what was `jumprelu`). Primary spelling; [`Self::jumprelu`] is a
168    /// deprecated alias kept for back-compat.
169    #[must_use]
170    pub fn threshold_gate(temperature: f64, threshold: f64) -> Self {
171        Self::ThresholdGate {
172            temperature,
173            threshold,
174        }
175    }
176
177    /// Back-compat alias for [`Self::threshold_gate`] (#1777): the mode is a
178    /// hard-sigmoid gate, not the literature JumpReLU magnitude activation. Retained
179    /// (NOT `#[deprecated]`, since the workspace denies warnings and many callers
180    /// and the legacy `"jumprelu"` FFI token still use it) so existing code keeps
181    /// compiling; new code should prefer `threshold_gate`.
182    #[must_use]
183    pub fn jumprelu(temperature: f64, threshold: f64) -> Self {
184        Self::threshold_gate(temperature, threshold)
185    }
186
187    pub fn temperature(&self) -> f64 {
188        match *self {
189            AssignmentMode::Softmax { temperature, .. }
190            | AssignmentMode::IBPMap { temperature, .. }
191            | AssignmentMode::ThresholdGate { temperature, .. } => temperature,
192        }
193    }
194
195    pub(crate) fn set_temperature(&mut self, new_temperature: f64) -> Result<(), String> {
196        if !(new_temperature.is_finite() && new_temperature > 0.0) {
197            return Err(format!(
198                "AssignmentMode: temperature must be finite and positive; got {new_temperature}"
199            ));
200        }
201        match self {
202            AssignmentMode::Softmax { temperature, .. }
203            | AssignmentMode::IBPMap { temperature, .. }
204            | AssignmentMode::ThresholdGate { temperature, .. } => {
205                *temperature = new_temperature;
206            }
207        }
208        Ok(())
209    }
210
211    pub(crate) fn validate(&self) -> Result<(), String> {
212        let temperature = self.temperature();
213        if !(temperature.is_finite() && temperature > 0.0) {
214            return Err(format!(
215                "AssignmentMode: temperature must be finite and positive; got {temperature}"
216            ));
217        }
218        match *self {
219            AssignmentMode::Softmax { sparsity, .. } => {
220                if !(sparsity.is_finite() && sparsity > 0.0) {
221                    return Err(format!(
222                        "AssignmentMode::Softmax: sparsity must be finite and positive; got {sparsity}"
223                    ));
224                }
225            }
226            AssignmentMode::IBPMap { alpha, .. } => {
227                if !(alpha.is_finite() && alpha > 0.0) {
228                    return Err(format!(
229                        "AssignmentMode::IBPMap: alpha must be finite and positive; got {alpha}"
230                    ));
231                }
232            }
233            AssignmentMode::ThresholdGate { threshold, .. } => {
234                if !threshold.is_finite() {
235                    return Err(format!(
236                        "AssignmentMode::ThresholdGate: threshold must be finite; got {threshold}"
237                    ));
238                }
239            }
240        }
241        Ok(())
242    }
243
244    /// Resolve the effective truncated-IBP concentration `α` for this mode.
245    ///
246    /// `per_fit_override` is the #1777 PER-FIT override (from
247    /// [`SaeAssignment::ibp_alpha_override`]) and is the SOURCE OF TRUTH when set.
248    /// It falls back to the deprecated process-global [`ibp_alpha_override`] atomic
249    /// only when the per-fit field is unset, then to the mode's own `α` /
250    /// learnable schedule — so nothing breaks, but concurrent fits are isolatable
251    /// via the per-fit field.
252    pub(crate) fn resolved_ibp_alpha(
253        &self,
254        rho: &SaeManifoldRho,
255        per_fit_override: Option<f64>,
256    ) -> Option<f64> {
257        match *self {
258            AssignmentMode::IBPMap {
259                alpha,
260                learnable_alpha,
261                ..
262            } => Some(if let Some(over) = per_fit_override.or_else(ibp_alpha_override) {
263                // #1777 — the per-fit override (else the deprecated process-global
264                // one) flattens the ordered geometric prior π_k = (α/(α+1))^{k+1}
265                // so all K atoms can contribute to the reconstruction (the
266                // production α=1 gives a (0.5)^{k+1} schedule that structurally
267                // caps atoms 4..K → effective-K≈3). Forces the fixed value,
268                // bypassing the learnable schedule.
269                over
270            } else if learnable_alpha {
271                resolve_learnable_weight(alpha, rho.log_lambda_sparse)
272            } else {
273                alpha
274            }),
275            _ => None,
276        }
277    }
278}
279
// #1026 — process-global IBP-α override, stored as the raw bits of an `f64`.
// A NaN sentinel (the initial value) means "unset → use the AssignmentMode's
// compiled α". It lets ONE wheel sweep the prior-flattening axis from Python
// (`sae_set_ibp_alpha`) without recompiling the gam crate.
//
// CONCURRENCY WARNING: this is a PROCESS-GLOBAL atomic, not per-fit config. It
// is read through `ibp_alpha_override` by every IBP assignment evaluation in
// the process, so setting it affects ALL in-flight fits, not just the caller's:
// one fit's sweep value leaks into another's gates. It is therefore UNSAFE
// across concurrent / parallel in-process fits and safe only for serial,
// whole-process sweeps (the single-wheel FFI sweep driver it exists for).
// The per-fit replacement is `SaeAssignment::ibp_alpha_override` (#1777); this
// global is the deprecated fallback consulted only when that field is unset.
static IBP_ALPHA_OVERRIDE_BITS: std::sync::atomic::AtomicU64 =
    std::sync::atomic::AtomicU64::new(0x7ff8_0000_0000_0000);

/// Read the process-global IBP-α override.
///
/// Returns `Some(α)` only when the stored value is finite and strictly
/// positive; any other stored value (including the initial NaN sentinel) reads
/// as `None`, i.e. "unset".
pub(crate) fn ibp_alpha_override() -> Option<f64> {
    use std::sync::atomic::Ordering;

    let stored = f64::from_bits(IBP_ALPHA_OVERRIDE_BITS.load(Ordering::Relaxed));
    Some(stored).filter(|alpha| alpha.is_finite() && *alpha > 0.0)
}

/// Set (or, with a non-finite/non-positive value, clear) the process-global
/// IBP-α override. Called from the gamfit Python FFI sweep driver.
///
/// PROCESS-GLOBAL / NOT CONCURRENCY-SAFE: this mutates one process-wide atomic
/// read by every IBP assignment in the process. Calling it while any other fit
/// is running leaks the override into that fit's gates. Use it only for serial,
/// whole-process sweeps; do not use it across concurrent in-process fits.
pub fn set_ibp_alpha_override(alpha: f64) {
    use std::sync::atomic::Ordering;

    let bits = alpha.to_bits();
    IBP_ALPHA_OVERRIDE_BITS.store(bits, Ordering::Relaxed);
}
316
317/// Per-row latent assignment state.
318///
319/// The stored assignment parameter is `logits`; non-negative assignments are
320/// derived by row-wise softmax, independent IBP-MAP sigmoid active indicators,
321/// or JumpReLU gates. Softmax logits are canonicalized to the reference chart
322/// `logits[K - 1] = 0`, so the row-local Newton coordinates contain only the
323/// first `K - 1` logits (`0` coordinates for `K = 1`). Gate-style modes keep
324/// all `K` logits as identifiable scalar parameters. `coords[k]` holds
325/// `t_{.,k}` for atom `k`.
326#[derive(Debug, Clone)]
327pub struct SaeAssignment {
328    pub logits: Array2<f64>,
329    pub coords: Vec<LatentCoordValues>,
330    pub mode: AssignmentMode,
331    /// #1026 — per-atom UNGATED flag (length `K`, default all-`false`). An
332    /// ungated atom is the dense linear/background tier: its per-row gate is
333    /// fixed at `a_k ≡ 1` (it contributes `γ_k(t_k)` to EVERY row, unweighted),
334    /// it is excluded from the other atoms' gate (for the column-separable
335    /// IBP / JumpReLU modes the remaining atoms are computed independently, so
336    /// they are unaffected), and its logit is NOT a free parameter — its
337    /// logit-JVP, sparsity-prior gradient/curvature, and softmax majorizer
338    /// contributions are all zero, leaving its logit slot an inert
339    /// (ridge-regularized) null direction in the per-row Newton block. This lets
340    /// the linear tier carry FULL-RANK reconstructible variance
341    /// (`fitted = γ_ungated(x) + Σ_{gated} a_k·γ_k(x)`) so a linear SAE can reach
342    /// the rank-(K·d) PCA ceiling, while the gated curved atoms still add sparse
343    /// structure on the residual (#1026 routing-bound finding).
344    pub ungated: Vec<bool>,
345    /// #1033 — AMORTIZED / FROZEN routing. When `Some`, this `(n, K)` matrix is a
346    /// ρ-INVARIANT predicted routing (the amortized `x → logits` map distilled
347    /// from the frozen dictionary): the gates are computed from THESE logits
348    /// instead of the free `self.logits`, and the logits are NOT optimized by the
349    /// inner Newton (their gradient/curvature/prior contributions are zeroed,
350    /// exactly as for [`Self::ungated`]). This is the generalization of an ungated
351    /// atom from "pin the gate at 1" to "pin the gate at the predicted value": it
352    /// makes the per-row routing a fixed function of `x` + the frozen dictionary,
353    /// so the outer ρ-search reuses ONE routing instead of re-solving per-row
354    /// gates every outer eval — the n-independent-outer-loop lever (#1033). `None`
355    /// is the historical free-logit path (bit-identical).
356    pub frozen_logits: Option<Array2<f64>>,
357    /// #1777 PER-FIT IBP-α override — the source of truth for the truncated-IBP
358    /// concentration `α` when set, replacing the process-global
359    /// [`set_ibp_alpha_override`] atomic. `Some(α)` forces the fixed value
360    /// (bypassing the learnable schedule), scoped to THIS assignment/fit so
361    /// concurrent in-process fits are isolated. `None` ⇒ fall back to the
362    /// deprecated process-global override, then to the `AssignmentMode`'s own `α` /
363    /// learnable schedule (bit-identical to the historical path when neither
364    /// override is set). Read via [`Self::resolved_ibp_alpha`]; set from the FFI
365    /// through the term's `set_fit_config`.
366    pub ibp_alpha_override: Option<f64>,
367}
368
369impl SaeAssignment {
370    #[must_use = "build error must be handled"]
371    pub fn new(
372        logits: Array2<f64>,
373        coords: Vec<LatentCoordValues>,
374        temperature: f64,
375    ) -> Result<Self, String> {
376        Self::with_mode(logits, coords, AssignmentMode::softmax(temperature))
377    }
378
379    #[must_use = "build error must be handled"]
380    pub fn with_mode(
381        mut logits: Array2<f64>,
382        coords: Vec<LatentCoordValues>,
383        mode: AssignmentMode,
384    ) -> Result<Self, String> {
385        mode.validate()?;
386        let n = logits.nrows();
387        let k = logits.ncols();
388        if coords.len() != k {
389            return Err(format!(
390                "SaeAssignment::new: coords length {} must equal K={k}",
391                coords.len()
392            ));
393        }
394        for (atom, coord) in coords.iter().enumerate() {
395            if coord.n_obs() != n {
396                return Err(format!(
397                    "SaeAssignment::new: coord atom {atom} has n_obs={} but logits has {n}",
398                    coord.n_obs()
399                ));
400            }
401        }
402        for row in 0..n {
403            validate_finite_logits(logits.row(row), row)?;
404        }
405        if matches!(mode, AssignmentMode::Softmax { .. }) {
406            canonicalize_softmax_logits(&mut logits);
407        }
408        Ok(Self {
409            logits,
410            coords,
411            mode,
412            ungated: vec![false; k],
413            frozen_logits: None,
414            ibp_alpha_override: None,
415        })
416    }
417
418    /// #1033 — install a ρ-INVARIANT FROZEN routing (the amortized predicted
419    /// logits; see [`SaeAssignment::frozen_logits`]). `predicted` must be
420    /// `(n, K)`. With routing frozen, the gates are computed from `predicted` and
421    /// the logits are excluded from the inner Newton (their gradient/curvature are
422    /// inert, like an ungated atom's). Passing `None` restores the free-logit
423    /// path.
424    #[must_use = "build error must be handled"]
425    pub fn with_frozen_routing(mut self, predicted: Option<Array2<f64>>) -> Result<Self, String> {
426        if let Some(ref p) = predicted {
427            if p.dim() != (self.n_obs(), self.k_atoms()) {
428                return Err(format!(
429                    "SaeAssignment::with_frozen_routing: predicted shape {:?} must be ({}, {})",
430                    p.dim(),
431                    self.n_obs(),
432                    self.k_atoms()
433                ));
434            }
435            if matches!(self.mode, AssignmentMode::Softmax { .. }) {
436                return Err(
437                    "SaeAssignment::with_frozen_routing: frozen routing under Softmax is rejected \
438                     — the coupled simplex's entropy majorizer is assembled over the logits, which \
439                     a frozen (non-optimized) routing would leave inconsistent; this separable-mode \
440                     contract supports IBP-MAP and JumpReLU, whose per-atom gates have no \
441                     simplex-coupled curvature to skip"
442                        .to_string(),
443                );
444            }
445            for row in 0..p.nrows() {
446                validate_finite_logits(p.row(row), row)?;
447            }
448        }
449        self.frozen_logits = predicted;
450        Ok(self)
451    }
452
453    /// Whether the per-row routing is FROZEN (amortized) rather than free-logit.
454    pub fn routing_is_frozen(&self) -> bool {
455        self.frozen_logits.is_some()
456    }
457
458    /// The active routing logits for `row`: the frozen/predicted logits when
459    /// routing is frozen (#1033), else the free `self.logits`. This is the SINGLE
460    /// source the gate value reads, so freezing routing changes every gate
461    /// consistently.
462    pub(crate) fn routing_logits_row(&self, row: usize) -> ArrayView1<'_, f64> {
463        match self.frozen_logits {
464            Some(ref f) => f.row(row),
465            None => self.logits.row(row),
466        }
467    }
468
469    /// Whether atom `k`'s logit is held fixed (not a free Newton parameter): true
470    /// for an ungated atom (#1026, gate pinned at 1) OR when routing is frozen
471    /// (#1033, gate pinned at the predicted value). Both share the same inert
472    /// treatment — zero logit-JVP, zero sparsity-prior gradient/curvature, zero
473    /// softmax majorizer — so the logit slot never moves.
474    pub(crate) fn logit_is_fixed(&self, k: usize) -> bool {
475        self.routing_is_frozen() || self.ungated.get(k).copied().unwrap_or(false)
476    }
477
478    /// Per-atom mask (length `K`) of [`Self::logit_is_fixed`] — the logit slots
479    /// that are NOT free Newton parameters (ungated #1026 and/or frozen-routing
480    /// #1033). Precompute once per assembly and pass to the logit-JVP fillers so
481    /// the data-fit Jacobian zeroes those rows. Under frozen routing every entry
482    /// is `true`; with only ungated atoms it equals `ungated`; otherwise all
483    /// `false` (the historical free-logit path).
484    pub(crate) fn fixed_logit_mask(&self) -> Vec<bool> {
485        if self.routing_is_frozen() {
486            vec![true; self.k_atoms()]
487        } else {
488            self.ungated.clone()
489        }
490    }
491
492    /// #1033 — install the simplest faithful AMORTIZED routing predictor: a
493    /// fixed-form DISTILL of the current dictionary's routing, namely the current
494    /// (converged) logits SNAPSHOTTED as the ρ-invariant frozen routing. This is
495    /// the `x → logits` map "evaluated once at the frozen dictionary" — the
496    /// routing the dictionary already expresses — held fixed so the outer ρ-search
497    /// reuses it instead of re-optimizing the gates at every ρ. (A richer
498    /// predictor that recomputes logits from `x` via the encode-atlas chart
499    /// geometry is a later refinement; snapshotting the converged routing is the
500    /// exact fixed-point it would target at the frozen dictionary.) Rejected for
501    /// Softmax for the same simplex-coupling reason as [`Self::with_frozen_routing`].
502    #[must_use = "build error must be handled"]
503    pub fn freeze_routing_from_current_logits(self) -> Result<Self, String> {
504        let snapshot = self.logits.clone();
505        self.with_frozen_routing(Some(snapshot))
506    }
507
508    /// #1033 — in-place variant of [`Self::freeze_routing_from_current_logits`]
509    /// for callers holding `&mut SaeAssignment` (e.g. inside a `SaeManifoldTerm`),
510    /// where moving the assignment out is awkward. Same contract: snapshot the
511    /// current logits as the ρ-invariant frozen routing; reject Softmax.
512    pub fn freeze_routing_in_place(&mut self) -> Result<(), String> {
513        if matches!(self.mode, AssignmentMode::Softmax { .. }) {
514            return Err(
515                "SaeAssignment::freeze_routing_in_place: frozen routing under Softmax is rejected \
516                 (coupled-simplex entropy-majorizer); use IBP-MAP or JumpReLU"
517                    .to_string(),
518            );
519        }
520        let snapshot = self.logits.clone();
521        for row in 0..snapshot.nrows() {
522            validate_finite_logits(snapshot.row(row), row)?;
523        }
524        self.frozen_logits = Some(snapshot);
525        Ok(())
526    }
527
528    /// #1033 — install an explicit predicted routing in place (the
529    /// [`RoutingPredictor::ChartGeometry`] output), `&mut self` variant of
530    /// [`Self::with_frozen_routing`]. `predicted` must be `(n, K)`; rejects Softmax
531    /// (separable-mode contract) and non-finite predictions.
532    pub fn set_frozen_routing_in_place(&mut self, predicted: Array2<f64>) -> Result<(), String> {
533        if predicted.dim() != (self.n_obs(), self.k_atoms()) {
534            return Err(format!(
535                "SaeAssignment::set_frozen_routing_in_place: predicted shape {:?} must be ({}, {})",
536                predicted.dim(),
537                self.n_obs(),
538                self.k_atoms()
539            ));
540        }
541        if matches!(self.mode, AssignmentMode::Softmax { .. }) {
542            return Err(
543                "SaeAssignment::set_frozen_routing_in_place: frozen routing under Softmax is \
544                 rejected (coupled-simplex entropy-majorizer); use IBP-MAP or JumpReLU"
545                    .to_string(),
546            );
547        }
548        for row in 0..predicted.nrows() {
549            validate_finite_logits(predicted.row(row), row)?;
550        }
551        self.frozen_logits = Some(predicted);
552        Ok(())
553    }
554
555    /// #1033 — lift the frozen routing, restoring the free-logit search path.
556    pub fn thaw_routing(&mut self) {
557        self.frozen_logits = None;
558    }
559
560    /// #1026 — designate which atoms are UNGATED (the dense linear/background
561    /// tier; see [`SaeAssignment::ungated`]). `flags` must have length `K`.
562    ///
563    /// Ungating is defined for the COLUMN-SEPARABLE gate modes (IBP-MAP and
564    /// JumpReLU): each atom's gate is an independent per-atom function of its own
565    /// logit, so pinning one atom to `a_k ≡ 1` leaves every other atom's gate
566    /// exactly as computed. Softmax is a coupled simplex (`Σ_k a_k = 1` over all
567    /// `K`), so a unit gate for one atom is only well defined relative to a
568    /// gated-subset renormalization that must also be reflected in the logit-JVP
569    /// and the entropy majorizer; this constructor's contract is restricted to
570    /// the separable modes, and an ungated atom under Softmax is REJECTED here so
571    /// the inner solve never runs on a value/gradient-mismatched gate. Callers
572    /// wanting a dense background tier under Softmax route it as an IBP-MAP or
573    /// JumpReLU atom.
574    #[must_use = "build error must be handled"]
575    pub fn with_ungated(mut self, flags: Vec<bool>) -> Result<Self, String> {
576        if flags.len() != self.k_atoms() {
577            return Err(format!(
578                "SaeAssignment::with_ungated: flags length {} must equal K={}",
579                flags.len(),
580                self.k_atoms()
581            ));
582        }
583        if matches!(self.mode, AssignmentMode::Softmax { .. }) && flags.iter().any(|&u| u) {
584            return Err(
585                "SaeAssignment::with_ungated: an ungated atom under Softmax routing is \
586                 rejected — the coupled simplex requires a gated-subset renormalization \
587                 reflected in the logit-JVP and entropy majorizer, which this separable-mode \
588                 contract does not perform; route a dense background tier as IBP-MAP or JumpReLU"
589                    .to_string(),
590            );
591        }
592        self.ungated = flags;
593        Ok(self)
594    }
595
596    /// Whether any atom is ungated (the #1026 background tier is engaged).
597    pub fn has_ungated(&self) -> bool {
598        self.ungated.iter().any(|&u| u)
599    }
600
601    pub fn n_obs(&self) -> usize {
602        self.logits.nrows()
603    }
604
605    pub fn k_atoms(&self) -> usize {
606        self.logits.ncols()
607    }
608
609    pub fn total_coord_dim(&self) -> usize {
610        self.coords.iter().map(|c| c.latent_dim()).sum()
611    }
612
613    pub fn assignment_coord_dim(&self) -> usize {
614        match self.mode {
615            AssignmentMode::Softmax { .. } => self.k_atoms().saturating_sub(1),
616            AssignmentMode::IBPMap { .. } | AssignmentMode::ThresholdGate { .. } => self.k_atoms(),
617        }
618    }
619
620    pub fn row_block_dim(&self) -> usize {
621        self.assignment_coord_dim() + self.total_coord_dim()
622    }
623
624    pub fn coord_offsets(&self) -> Vec<usize> {
625        let mut out = Vec::with_capacity(self.k_atoms());
626        let mut cursor = self.assignment_coord_dim();
627        for coord in &self.coords {
628            out.push(cursor);
629            cursor += coord.latent_dim();
630        }
631        out
632    }
633
634    pub fn assignments(&self) -> Array2<f64> {
635        let n = self.n_obs();
636        let k = self.k_atoms();
637        let mut out = Array2::<f64>::zeros((n, k));
638        for row in 0..n {
639            let a = self.assignments_row(row);
640            for atom in 0..k {
641                out[[row, atom]] = a[atom];
642            }
643        }
644        out
645    }
646
647    pub fn assignments_row(&self, row: usize) -> Array1<f64> {
648        self.try_assignments_row(row)
649            .expect("assignment logits must be finite")
650    }
651
652    pub fn try_assignments_row(&self, row: usize) -> Result<Array1<f64>, String> {
653        self.try_assignments_row_with_alpha(row, None)
654    }
655
656    /// #1777 — the effective truncated-IBP `α` for this assignment at `rho`,
657    /// honoring the PER-FIT [`Self::ibp_alpha_override`] first (source of truth),
658    /// then the deprecated process-global override, then the mode's own schedule.
659    /// The single seam every gate/jet/prior site reads so the per-fit override is
660    /// applied consistently. `None` for non-IBP modes.
661    pub(crate) fn resolved_ibp_alpha(&self, rho: &SaeManifoldRho) -> Option<f64> {
662        self.mode.resolved_ibp_alpha(rho, self.ibp_alpha_override)
663    }
664
665    /// #1777 — install (or clear, with `None`) the PER-FIT IBP-α override on this
666    /// assignment. Source of truth used by [`Self::resolved_ibp_alpha`]; the FFI
667    /// reaches it through the term's `set_fit_config`.
668    pub fn set_ibp_alpha_override(&mut self, alpha: Option<f64>) {
669        self.ibp_alpha_override = alpha;
670    }
671
672    pub(crate) fn try_assignments_row_for_rho(
673        &self,
674        row: usize,
675        rho: &SaeManifoldRho,
676    ) -> Result<Array1<f64>, String> {
677        self.try_assignments_row_with_alpha(row, self.resolved_ibp_alpha(rho))
678    }
679
680    fn try_assignments_row_with_alpha(
681        &self,
682        row: usize,
683        resolved_ibp_alpha: Option<f64>,
684    ) -> Result<Array1<f64>, String> {
685        // #1033 — read the ACTIVE routing logits: the ρ-invariant frozen/predicted
686        // logits when routing is frozen, else the free `self.logits`. This single
687        // source makes the gate value ρ-invariant under frozen routing (the
688        // amortized-routing lever) and bit-identical to the historical path when
689        // not frozen.
690        let routing = self.routing_logits_row(row);
691        validate_finite_logits(routing, row)?;
692        // Only Softmax collapses to a fixed assignment at K==1: its
693        // assignment_coord_dim is K-1 = 0, so there is no free logit. IBPMap and
694        // JumpReLU keep a free per-atom gate logit even at K==1
695        // (assignment_coord_dim = K = 1), so they must fall through to their real
696        // row functions or the logit would move the prior but not the gate.
697        if self.k_atoms() == 1 && matches!(self.mode, AssignmentMode::Softmax { .. }) {
698            return Ok(Array1::from_vec(vec![1.0]));
699        }
700        let mut row_gates = match self.mode {
701            AssignmentMode::Softmax { temperature, .. } => softmax_row(routing, temperature),
702            AssignmentMode::IBPMap {
703                temperature, alpha, ..
704            } => ibp_map_row(routing, temperature, resolved_ibp_alpha.unwrap_or(alpha)),
705            AssignmentMode::ThresholdGate {
706                temperature,
707                threshold,
708            } => jumprelu_row(routing, temperature, threshold),
709        };
710        // #1026 — ungated (background-tier) atoms have a fixed unit gate. For the
711        // column-separable IBP / JumpReLU modes the other atoms' gates are
712        // computed independently above, so overwriting the ungated entries to 1.0
713        // leaves the gated atoms exactly as they were; the ungated atom then
714        // contributes `γ_k(t_k)` unweighted to every row. (Softmax + ungated is
715        // rejected at `with_ungated`, so no simplex renormalization is needed
716        // here.)
717        if self.has_ungated() {
718            for (k, gate) in row_gates.iter_mut().enumerate() {
719                if self.ungated[k] {
720                    *gate = 1.0;
721                }
722            }
723        }
724        Ok(row_gates)
725    }
726
727    /// #1557 — fill-into-caller-buffer twin of [`Self::try_assignments_row_for_rho`].
728    ///
729    /// Writes the EXACT SAME per-atom assignment row into `out` (length
730    /// `k_atoms()`) instead of allocating a fresh `Array1`. Bit-identical to the
731    /// allocating path; intended for the hot per-row loops that immediately
732    /// consume the row, reusing a single scratch buffer across rows.
733    pub(crate) fn try_assignments_row_for_rho_into(
734        &self,
735        row: usize,
736        rho: &SaeManifoldRho,
737        out: &mut [f64],
738    ) -> Result<(), String> {
739        self.try_assignments_row_with_alpha_into(row, self.resolved_ibp_alpha(rho), out)
740    }
741
742    /// #1557 — fill-into-caller-buffer twin of [`Self::try_assignments_row_with_alpha`].
743    ///
744    /// `out` must have length `k_atoms()`; it is fully overwritten with the same
745    /// values the allocating variant would return. Every branch (early-return
746    /// K==1 Softmax, the per-mode row math, the #1026 ungated overwrite) mirrors
747    /// the allocating path exactly so the two are bit-identical.
748    pub(crate) fn try_assignments_row_with_alpha_into(
749        &self,
750        row: usize,
751        resolved_ibp_alpha: Option<f64>,
752        out: &mut [f64],
753    ) -> Result<(), String> {
754        // `out` is sized `k_atoms()` by every caller; the per-mode helpers below
755        // fully overwrite indices `0..k_atoms()`.
756        let routing = self.routing_logits_row(row);
757        validate_finite_logits(routing, row)?;
758        // Mirror the allocating early-return: only Softmax collapses to a fixed
759        // unit assignment at K==1.
760        if self.k_atoms() == 1 && matches!(self.mode, AssignmentMode::Softmax { .. }) {
761            out[0] = 1.0;
762            return Ok(());
763        }
764        match self.mode {
765            AssignmentMode::Softmax { temperature, .. } => {
766                softmax_row_into(routing, temperature, out)
767            }
768            AssignmentMode::IBPMap {
769                temperature, alpha, ..
770            } => ibp_map_row_into(
771                routing,
772                temperature,
773                resolved_ibp_alpha.unwrap_or(alpha),
774                out,
775            ),
776            AssignmentMode::ThresholdGate {
777                temperature,
778                threshold,
779            } => jumprelu_row_into(routing, temperature, threshold, out),
780        };
781        // #1026 — ungated (background-tier) atoms have a fixed unit gate, exactly
782        // as in the allocating path.
783        if self.has_ungated() {
784            for (k, gate) in out.iter_mut().enumerate() {
785                if self.ungated[k] {
786                    *gate = 1.0;
787                }
788            }
789        }
790        Ok(())
791    }
792
793    pub(crate) fn persist_resolved_ibp_alpha(&mut self, rho: &SaeManifoldRho) -> bool {
794        let AssignmentMode::IBPMap {
795            temperature,
796            alpha,
797            learnable_alpha: true,
798        } = self.mode
799        else {
800            return false;
801        };
802        let resolved_alpha = resolve_learnable_weight(alpha, rho.log_lambda_sparse);
803        self.mode = AssignmentMode::IBPMap {
804            temperature,
805            alpha: resolved_alpha,
806            learnable_alpha: false,
807        };
808        true
809    }
810
811    pub(crate) fn assignments_for_rho(&self, rho: &SaeManifoldRho) -> Result<Array2<f64>, String> {
812        let n = self.n_obs();
813        let k = self.k_atoms();
814        let mut out = Array2::<f64>::zeros((n, k));
815        for row in 0..n {
816            let a = self.try_assignments_row_for_rho(row, rho)?;
817            for atom in 0..k {
818                out[[row, atom]] = a[atom];
819            }
820        }
821        Ok(out)
822    }
823
824    /// Flatten extension coordinates in row-major SAE layout:
825    /// `(assignment chart_i, t_i0[0..d_0], ..., t_iK[0..d_K])` for every row.
826    /// Softmax contributes the first `K - 1` reference logits and omits the
827    /// fixed reference logit; gate-style assignment modes contribute all `K`
828    /// logits.
829    pub fn flatten_ext_coords(&self) -> Array1<f64> {
830        let n = self.n_obs();
831        let q = self.row_block_dim();
832        let k = self.k_atoms();
833        let assignment_dim = self.assignment_coord_dim();
834        let offsets = self.coord_offsets();
835        let mut out = Array1::<f64>::zeros(n * q);
836        for row in 0..n {
837            let base = row * q;
838            for atom in 0..assignment_dim {
839                out[base + atom] = self.logits[[row, atom]];
840            }
841            for atom in 0..k {
842                let d = self.coords[atom].latent_dim();
843                let t_row = self.coords[atom].row(row);
844                for axis in 0..d {
845                    out[base + offsets[atom] + axis] = t_row[axis];
846                }
847            }
848        }
849        out
850    }
851
852    #[must_use = "build error must be handled"]
853    pub fn from_blocks_with_mode(
854        logits: Array2<f64>,
855        coord_blocks: Vec<Array2<f64>>,
856        mode: AssignmentMode,
857    ) -> Result<Self, String> {
858        let coords = coord_blocks
859            .iter()
860            .map(|c| LatentCoordValues::from_matrix(c.view(), LatentIdMode::None))
861            .collect();
862        Self::with_mode(logits, coords, mode)
863    }
864
865    #[must_use = "build error must be handled"]
866    pub fn from_blocks_with_mode_and_manifolds(
867        logits: Array2<f64>,
868        coord_blocks: Vec<Array2<f64>>,
869        manifolds: Vec<LatentManifold>,
870        mode: AssignmentMode,
871    ) -> Result<Self, String> {
872        if coord_blocks.len() != manifolds.len() {
873            return Err(format!(
874                "SaeAssignment::from_blocks_with_mode_and_manifolds: coord block length {} != manifold length {}",
875                coord_blocks.len(),
876                manifolds.len()
877            ));
878        }
879        let coords = coord_blocks
880            .iter()
881            .zip(manifolds)
882            .map(|(c, manifold)| {
883                LatentCoordValues::from_matrix_with_manifold(c.view(), LatentIdMode::None, manifold)
884            })
885            .collect();
886        Self::with_mode(logits, coords, mode)
887    }
888}
889
890pub(crate) fn neutral_gate_weights(mode: AssignmentMode, k_atoms: usize) -> Array1<f64> {
891    match mode {
892        AssignmentMode::Softmax { .. } => Array1::from_elem(k_atoms, 1.0 / (k_atoms.max(1) as f64)),
893        AssignmentMode::IBPMap {
894            temperature, alpha, ..
895        } => ibp_map_row(Array1::<f64>::zeros(k_atoms).view(), temperature, alpha),
896        AssignmentMode::ThresholdGate { .. } => Array1::from_elem(k_atoms, 0.5),
897    }
898}
899
900pub(crate) fn softmax_row(logits: ArrayView1<'_, f64>, temperature: f64) -> Array1<f64> {
901    let k = logits.len();
902    let inv_tau = 1.0 / temperature;
903    let mut max_logit = f64::NEG_INFINITY;
904    for &v in logits.iter() {
905        max_logit = max_logit.max(v);
906    }
907    let mut out = Array1::<f64>::zeros(k);
908    let mut sum = 0.0;
909    for i in 0..k {
910        let v = ((logits[i] - max_logit) * inv_tau).exp();
911        out[i] = v;
912        sum += v;
913    }
914    assert!(sum.is_finite() && sum > 0.0);
915    for v in out.iter_mut() {
916        *v /= sum;
917    }
918    out
919}
920
921pub(crate) fn validate_finite_logits(
922    logits: ArrayView1<'_, f64>,
923    row: usize,
924) -> Result<(), String> {
925    for (col, &v) in logits.iter().enumerate() {
926        if !v.is_finite() {
927            return Err(format!(
928                "SaeAssignment: non-finite assignment logit at row {row}, atom {col}: {v}"
929            ));
930        }
931    }
932    Ok(())
933}
934
935pub(crate) fn canonicalize_softmax_logits(logits: &mut Array2<f64>) {
936    let k = logits.ncols();
937    if k == 0 {
938        return;
939    }
940    if k == 1 {
941        logits.fill(0.0);
942        return;
943    }
944    for row in 0..logits.nrows() {
945        let reference = logits[[row, k - 1]];
946        for col in 0..k - 1 {
947            logits[[row, col]] -= reference;
948        }
949        logits[[row, k - 1]] = 0.0;
950    }
951}
952
953/// Truncated Indian-Buffet-Process stick-breaking prior *means*
954/// `π_k = E[∏_{j=0}^{k} v_j] = (α/(α+1))^{k+1}` for k = 0, .., K-1, with sticks
955/// `v_j ~ Beta(α, 1)` so `E[v_j] = α/(α+1)`. EVERY atom (including the first,
956/// `π_0 = α/(α+1)`) carries the consistent Beta(α, 1) shrinkage: there is no
957/// special-cased always-on base atom, so `α` behaves as a genuine IBP
958/// concentration — larger `α` ⇒ heavier mass / slower decay, `α → 0` ⇒ all mass
959/// collapses onto nothing, matching the stick-breaking limit. This is the
960/// deterministic MAP / mean-field form of the IBP prior (the closed form the
961/// analytic Newton / Hessian / Woodbury machinery differentiates); no sticks are
962/// *sampled* here, the per-atom weight is the exact expectation of the
963/// stick-breaking product. (#614: previously `π_0 = 1` left the first atom
964/// unshrunk, which is the prior mean of NO stick at all and broke α's role as a
965/// concentration; the consistent product mean restores genuine IBP semantics.)
966pub(crate) fn ordered_geometric_shrinkage_prior(k_atoms: usize, alpha: f64) -> Array1<f64> {
967    // Accumulate the geometric schedule `π_k = ratio^(k+1)` in LOG space so the
968    // prior stays a finite *soft* weight even for large `K`. The naive product
969    // `acc *= ratio` underflows to exact `0.0` once `ratio^(k+1) < f64::MIN_POSITIVE`
970    // (e.g. `(0.1/1.1)^320`), which would turn the soft shrinkage prior into a
971    // HARD mask: such atoms would receive zero assignment AND zero logit
972    // gradient (the gradient is multiplied by `π_k`), so they could never
973    // reactivate. Working in log-space and flooring the exponentiated weight at
974    // the smallest positive normal keeps every atom's gradient path alive while
975    // preserving the geometric ordering.
976    let mut out = Array1::<f64>::zeros(k_atoms);
977    let log_ratio = (alpha / (alpha + 1.0)).ln();
978    for k in 0..k_atoms {
979        // π_k = (α/(α+1))^{k+1}: the product of (k+1) i.i.d. Beta(α,1) stick
980        // means, so atom 0 is also shrunk by one stick (E[v_0] = α/(α+1)).
981        let log_pi = ((k + 1) as f64) * log_ratio;
982        out[k] = log_pi.exp().max(f64::MIN_POSITIVE);
983    }
984    out
985}
986
/// #1784 — K-aware default IBP concentration.
///
/// The ordered stick-breaking prior mean `π_k = (α/(α+1))^{k+1}` decays
/// GEOMETRICALLY in the atom INDEX, so a fixed small concentration (the
/// historical default `α = 1`, i.e. the `(0.5)^{k+1}` schedule) collapses to a
/// near-hard mask past atom ~3: a K-atom dictionary can then only ever place
/// mass on its first handful of atoms.
///
/// For a K-atom dictionary to actually USE all K atoms the concentration must
/// scale with K. Choosing `α` so the LAST atom retains prior mass
/// `π_{K-1} = (α/(α+1))^K = e^{-1}` spans the whole dictionary while keeping the
/// prior monotone. Solving gives `α = 1/(exp(1/K) − 1) ≈ K − 1/2`.
///
/// The result is floored at `1.0`, so `K = 1` keeps the historical `α = 1`;
/// `k_atoms = 0` is treated as `K = 1`.
pub fn default_ibp_concentration_for_k_atoms(k_atoms: usize) -> f64 {
    let k = k_atoms.max(1) as f64;
    // π_{K-1} = (α/(α+1))^K = e^{-1}  ⇒  α = 1/(e^{1/K} − 1).
    // `exp_m1` evaluates `e^x − 1` without the cancellation that `exp(x) - 1.0`
    // suffers for small `x = 1/K`; for very large K the subtracted form loses
    // the `−1/2` term entirely.
    let alpha = 1.0 / (1.0 / k).exp_m1();
    alpha.max(1.0)
}
1011
1012/// IBP-MAP row activations: per-atom sigmoid likelihood times the truncated
1013/// stick-breaking prior mean `π_k = (α/(α+1))^{k+1}`. With tied logits the prior
1014/// dominates and yields strictly decreasing activations in atom index, with the
1015/// first atom already shrunk by one Beta(α,1) stick mean (no unshrunk base atom).
1016pub fn ibp_map_row(logits: ArrayView1<'_, f64>, temperature: f64, alpha: f64) -> Array1<f64> {
1017    let prior = ordered_geometric_shrinkage_prior(logits.len(), alpha);
1018    let mut out = Array1::<f64>::zeros(logits.len());
1019    for i in 0..logits.len() {
1020        out[i] = gam_linalg::utils::stable_logistic(logits[i] / temperature) * prior[i];
1021    }
1022    out
1023}
1024
1025/// IBP-MAP activations together with the diagonal Jacobian `∂z_k/∂l_k`,
1026/// shared with the torch autograd `Function` so the Python IBP-Gumbel path
1027/// applies the same stick-breaking prior mean `π_k = (α/(α+1))^{k+1}` and
1028/// temperature scaling as the Rust closed form. With `z_k = σ(l_k/τ)·π_k` the
1029/// per-atom derivative is
1030/// `σ(l_k/τ)(1 − σ(l_k/τ))·π_k / τ`; the map is diagonal in `k`, so the
1031/// Jacobian is returned as the per-atom diagonal vector.
1032#[must_use]
1033pub fn ibp_map_row_value_grad(
1034    logits: ArrayView1<'_, f64>,
1035    temperature: f64,
1036    alpha: f64,
1037) -> (Array1<f64>, Array1<f64>) {
1038    let prior = ordered_geometric_shrinkage_prior(logits.len(), alpha);
1039    let inv_tau = 1.0 / temperature;
1040    let mut value = Array1::<f64>::zeros(logits.len());
1041    let mut grad = Array1::<f64>::zeros(logits.len());
1042    for i in 0..logits.len() {
1043        let sig = gam_linalg::utils::stable_logistic(logits[i] * inv_tau);
1044        value[i] = sig * prior[i];
1045        grad[i] = sig * (1.0 - sig) * inv_tau * prior[i];
1046    }
1047    (value, grad)
1048}
1049
1050pub fn jumprelu_row(logits: ArrayView1<'_, f64>, temperature: f64, threshold: f64) -> Array1<f64> {
1051    let mut out = Array1::<f64>::zeros(logits.len());
1052    for i in 0..logits.len() {
1053        // Hard gate: strictly zero below threshold (the intended "jump"). Above
1054        // threshold the surrogate is centered at the threshold so the gate is
1055        // most informative exactly at the boundary it switches on:
1056        // σ((l−θ)/τ) ∈ [0.5, 1). Magnitude lives in the decoder, so the gate
1057        // stays bounded in [0, 1] by design.
1058        if logits[i] > threshold {
1059            out[i] = gam_linalg::utils::stable_logistic((logits[i] - threshold) / temperature);
1060        }
1061    }
1062    out
1063}
1064
1065/// Bounded threshold-gate activations together with the straight-through
1066/// derivative `∂a_k/∂l_k` of the smooth surrogate, shared with the torch
1067/// autograd `Function` so the torch `jumprelu` lane applies the SAME bounded
1068/// gate as the closed-form fit (`jumprelu_row`): `a_k = σ((l_k − θ_k)/τ)` for
1069/// `l_k > θ_k` and exactly `0` otherwise — magnitude lives in the decoder, the
1070/// gate stays in `[0, 1)`. The returned derivative is the smooth surrogate's
1071/// `σ'((l_k − θ_k)/τ)/τ` evaluated on BOTH sides of the jump (a straight-through
1072/// estimator: the hard forward has zero derivative below threshold, which would
1073/// permanently kill gradient flow to gated-off atoms). `∂a_k/∂θ_k` is the
1074/// negation of the returned logit derivative; callers negate. `thresholds` is
1075/// per-atom (the torch lane learns one threshold per atom); the scalar
1076/// closed-form threshold is the constant-vector special case and the value
1077/// arithmetic matches `jumprelu_row` exactly there.
1078#[must_use]
1079pub fn jumprelu_row_value_grad(
1080    logits: ArrayView1<'_, f64>,
1081    temperature: f64,
1082    thresholds: ArrayView1<'_, f64>,
1083) -> (Array1<f64>, Array1<f64>) {
1084    assert_eq!(
1085        logits.len(),
1086        thresholds.len(),
1087        "jumprelu_row_value_grad: logits/thresholds length mismatch"
1088    );
1089    let inv_tau = 1.0 / temperature;
1090    let mut value = Array1::<f64>::zeros(logits.len());
1091    let mut grad = Array1::<f64>::zeros(logits.len());
1092    for i in 0..logits.len() {
1093        let sig = gam_linalg::utils::stable_logistic((logits[i] - thresholds[i]) * inv_tau);
1094        if logits[i] > thresholds[i] {
1095            value[i] = sig;
1096        }
1097        grad[i] = sig * (1.0 - sig) * inv_tau;
1098    }
1099    (value, grad)
1100}
1101
1102// #1557 — fill-into-caller-buffer variants of the three per-mode row functions.
1103// These compute the EXACT SAME values as `softmax_row` / `ibp_map_row` /
1104// `jumprelu_row` (same arithmetic, same order of operations) but write into a
1105// caller-provided `&mut [f64]` slice instead of heap-allocating a fresh
1106// `Array1<f64>` per call. The hot per-row loops (loss eval, arrow/Schur row
1107// loops) call these with a reused scratch buffer, eliminating millions of tiny
1108// K-sized allocations while staying bit-identical to the allocating path.
1109// `out` must have length `logits.len()`; the slice is fully overwritten.
1110
1111pub(crate) fn softmax_row_into(logits: ArrayView1<'_, f64>, temperature: f64, out: &mut [f64]) {
1112    let k = logits.len();
1113    let inv_tau = 1.0 / temperature;
1114    let mut max_logit = f64::NEG_INFINITY;
1115    for &v in logits.iter() {
1116        max_logit = max_logit.max(v);
1117    }
1118    let mut sum = 0.0;
1119    for i in 0..k {
1120        let v = ((logits[i] - max_logit) * inv_tau).exp();
1121        out[i] = v;
1122        sum += v;
1123    }
1124    assert!(sum.is_finite() && sum > 0.0);
1125    for v in out.iter_mut() {
1126        *v /= sum;
1127    }
1128}
1129
1130pub(crate) fn ibp_map_row_into(
1131    logits: ArrayView1<'_, f64>,
1132    temperature: f64,
1133    alpha: f64,
1134    out: &mut [f64],
1135) {
1136    let prior = ordered_geometric_shrinkage_prior(logits.len(), alpha);
1137    for i in 0..logits.len() {
1138        out[i] = gam_linalg::utils::stable_logistic(logits[i] / temperature) * prior[i];
1139    }
1140}
1141
1142pub(crate) fn jumprelu_row_into(
1143    logits: ArrayView1<'_, f64>,
1144    temperature: f64,
1145    threshold: f64,
1146    out: &mut [f64],
1147) {
1148    for i in 0..logits.len() {
1149        // Match `jumprelu_row`: strictly zero below threshold, sigmoid surrogate
1150        // above. The buffer is fully overwritten (no read of prior contents).
1151        if logits[i] > threshold {
1152            out[i] = gam_linalg::utils::stable_logistic((logits[i] - threshold) / temperature);
1153        } else {
1154            out[i] = 0.0;
1155        }
1156    }
1157}
1158
/// Inputs for [`fill_active_atom_logit_jvp`]: everything needed to fill the
/// single compact logit-JVP row of one active atom.
pub(crate) struct ActiveAtomLogitJvp<'a> {
    /// Assignment mode; selects which per-mode sensitivity formula is applied.
    pub(crate) mode: AssignmentMode,
    /// Atom index; used to look up the atom's entry in `ibp_prior`.
    pub(crate) k: usize,
    /// The atom's logit; compared against the threshold in threshold-gate mode.
    pub(crate) logit_k: f64,
    /// The atom's assignment (gate) value for this row.
    pub(crate) a_k: f64,
    /// Decoded output direction of atom `k`, indexed by output column.
    pub(crate) decoded_k: ArrayView1<'a, f64>,
    /// The row's active-set reconstruction; its length sets the number of
    /// output columns written.
    pub(crate) fitted: ArrayView1<'a, f64>,
    /// Precomputed IBP prior weights; must be `Some` in IBP-MAP mode.
    pub(crate) ibp_prior: Option<&'a [f64]>,
    /// Row of the compact Jacobian that receives the result.
    pub(crate) compact_index: usize,
    /// #1026 — when `true`, atom `k` is the ungated background tier: its gate is
    /// the constant `1`, so its logit-JVP `da_k/dl_k` is identically zero (the
    /// compact row is left untouched / zero).
    pub(crate) ungated: bool,
}
1173
1174/// Fill the single compact logit-JVP row for active atom `k`, using the
1175/// per-mode assignment sensitivity `da_k/dl_k` contracted into the decoded /
1176/// fitted-corrected output direction. This is the active-set analogue of
1177/// [`fill_assignment_logit_jvp_rows`]: it reproduces that function's diagonal
1178/// logit row exactly for the atom `k`, but writes into a compact position of a
1179/// heterogeneous-`q` row block instead of the dense full-`K` Jacobian. `fitted`
1180/// is the row's *active-set* reconstruction so the softmax cross term
1181/// `(decoded_k − fitted)` is consistent with the curvature the compact block
1182/// carries.
1183pub(crate) fn fill_active_atom_logit_jvp(
1184    input: ActiveAtomLogitJvp<'_>,
1185    jac_compact: &mut Array2<f64>,
1186) {
1187    let ActiveAtomLogitJvp {
1188        mode,
1189        k,
1190        logit_k,
1191        a_k,
1192        decoded_k,
1193        fitted,
1194        ibp_prior,
1195        compact_index,
1196        ungated,
1197    } = input;
1198    let p = fitted.len();
1199    // #1026 — an ungated atom's gate is constant, so its logit-JVP is zero; leave
1200    // its compact row untouched (the buffer row is pre-zeroed by the caller).
1201    if ungated {
1202        return;
1203    }
1204    match mode {
1205        AssignmentMode::Softmax { temperature, .. } => {
1206            // da_k/dl_k contracted: a_k (decoded_k − fitted) / τ.
1207            let inv_tau = 1.0 / temperature;
1208            for out_col in 0..p {
1209                jac_compact[[compact_index, out_col]] =
1210                    a_k * (decoded_k[out_col] - fitted[out_col]) * inv_tau;
1211            }
1212        }
1213        AssignmentMode::IBPMap { temperature, .. } => {
1214            // z_k = σ(l_k/τ)·π_k ⇒ dz_k/dl_k = a_k(π_k − a_k)/(π_k τ) · π_k form
1215            // (matches `fill_assignment_logit_jvp_rows`).
1216            let inv_tau = 1.0 / temperature;
1217            let prior =
1218                ibp_prior.expect("fill_active_atom_logit_jvp: IBPMap requires precomputed prior");
1219            let pi_k = prior[k];
1220            let sig = if pi_k > 0.0 { a_k / pi_k } else { 0.0 };
1221            let dz = sig * (1.0 - sig) * inv_tau * pi_k;
1222            for out_col in 0..p {
1223                jac_compact[[compact_index, out_col]] = dz * decoded_k[out_col];
1224            }
1225        }
1226        AssignmentMode::ThresholdGate {
1227            temperature,
1228            threshold,
1229        } => {
1230            // The data-fit Jacobian follows the hard forward gate. Below the
1231            // threshold the reconstruction contribution is exactly zero, so the
1232            // data-fit logit derivative must also be zero. Band-only atoms stay
1233            // in the compact row for prior terms, not phantom reconstruction
1234            // slope.
1235            if logit_k <= threshold {
1236                return;
1237            }
1238            let inv_tau = 1.0 / temperature;
1239            let activation = gam_linalg::utils::stable_logistic((logit_k - threshold) * inv_tau);
1240            let da = activation * (1.0 - activation) * inv_tau;
1241            for out_col in 0..p {
1242                jac_compact[[compact_index, out_col]] = da * decoded_k[out_col];
1243            }
1244        }
1245    }
1246}
1247
1248pub(crate) fn fill_assignment_logit_jvp_rows(
1249    mode: AssignmentMode,
1250    logits: ArrayView1<'_, f64>,
1251    assignments: ArrayView1<'_, f64>,
1252    decoded: ArrayView2<'_, f64>,
1253    fitted: ArrayView1<'_, f64>,
1254    ibp_prior: Option<&[f64]>,
1255    // #1026 — per-atom ungated flags (length `K`). An ungated atom's gate is
1256    // constant, so its logit-JVP row is identically zero (skipped below). Empty
1257    // ⇒ no atom is ungated (the historical path, bit-identical).
1258    ungated: &[bool],
1259    local_jac: &mut Array2<f64>,
1260) {
1261    let is_ungated = |k: usize| ungated.get(k).copied().unwrap_or(false);
1262    match mode {
1263        AssignmentMode::Softmax { temperature, .. } => {
1264            if assignments.len() == 1 {
1265                return;
1266            }
1267            // da_k/dl_j = a_k (1[k=j] - a_j) / tau, contracted against
1268            // the assignment-weighted fitted row. The dense row layout uses
1269            // the reference-logit chart, so only columns `0..K-1` are free;
1270            // the final reference logit is fixed at zero and has no row.
1271            let inv_tau = 1.0 / temperature;
1272            for logit_col in 0..assignments.len() - 1 {
1273                if is_ungated(logit_col) {
1274                    continue;
1275                }
1276                for out_col in 0..fitted.len() {
1277                    local_jac[[logit_col, out_col]] = assignments[logit_col]
1278                        * (decoded[[logit_col, out_col]] - fitted[out_col])
1279                        * inv_tau;
1280                }
1281            }
1282        }
1283        AssignmentMode::IBPMap { temperature, .. } => {
1284            // Truncated-IBP concrete relaxation: z_k = σ(l_k/τ) · π_k where
1285            // π_k is the stick-breaking prior. Thus
1286            // dz_k/dl_k = σ(l/τ)(1-σ(l/τ))/τ · π_k = a_k(π_k - a_k)/(π_k τ).
1287            let inv_tau = 1.0 / temperature;
1288            let prior = ibp_prior
1289                .expect("fill_assignment_logit_jvp_rows: IBPMap requires precomputed prior");
1290            for logit_col in 0..assignments.len() {
1291                if is_ungated(logit_col) {
1292                    continue;
1293                }
1294                let pi_k = prior[logit_col];
1295                let a_k = assignments[logit_col];
1296                let sig = if pi_k > 0.0 { a_k / pi_k } else { 0.0 };
1297                let dz = sig * (1.0 - sig) * inv_tau * pi_k;
1298                for out_col in 0..fitted.len() {
1299                    local_jac[[logit_col, out_col]] = dz * decoded[[logit_col, out_col]];
1300                }
1301            }
1302        }
1303        AssignmentMode::ThresholdGate {
1304            temperature,
1305            threshold,
1306        } => {
1307            // Data-fit sensitivity follows the hard forward gate: rows at or
1308            // below the threshold have zero reconstruction value and therefore
1309            // zero data-fit logit derivative. The wider machine-precision prior
1310            // support is a compact-layout/prior rule, not a data-fit STE.
1311            let inv_tau = 1.0 / temperature;
1312            for logit_col in 0..assignments.len() {
1313                if is_ungated(logit_col) || logits[logit_col] <= threshold {
1314                    continue;
1315                }
1316                let activation = gam_linalg::utils::stable_logistic(
1317                    (logits[logit_col] - threshold) * inv_tau,
1318                );
1319                let da = activation * (1.0 - activation) * inv_tau;
1320                for out_col in 0..fitted.len() {
1321                    local_jac[[logit_col, out_col]] = da * decoded[[logit_col, out_col]];
1322                }
1323            }
1324        }
1325    }
1326}
1327
1328pub(crate) fn flat_logits(logits: ArrayView2<'_, f64>) -> Array1<f64> {
1329    let mut out = Array1::<f64>::zeros(logits.len());
1330    for row in 0..logits.nrows() {
1331        let start = row * logits.ncols();
1332        for col in 0..logits.ncols() {
1333            out[start + col] = logits[[row, col]];
1334        }
1335    }
1336    out
1337}
1338
1339pub(crate) fn assignment_prior_value(assignment: &SaeAssignment, rho: &SaeManifoldRho) -> f64 {
1340    for row in 0..assignment.n_obs() {
1341        validate_finite_logits(assignment.logits.row(row), row)
1342            .expect("assignment logits must be finite");
1343    }
1344    let target = flat_logits(assignment.logits.view());
1345    if matches!(assignment.mode, AssignmentMode::Softmax { .. }) && assignment.k_atoms() == 1 {
1346        return 0.0;
1347    }
1348    match assignment.mode {
1349        AssignmentMode::Softmax {
1350            temperature,
1351            sparsity,
1352        } => {
1353            let penalty = SoftmaxAssignmentSparsityPenalty::new(assignment.k_atoms(), temperature);
1354            let rho_view = Array1::from_vec(vec![rho.log_lambda_sparse + sparsity.ln()]);
1355            penalty.value(target.view(), rho_view.view())
1356        }
1357        AssignmentMode::IBPMap {
1358            temperature,
1359            alpha,
1360            learnable_alpha,
1361        } => {
1362            let mut penalty = IBPAssignmentPenalty::new(
1363                assignment.k_atoms(),
1364                alpha,
1365                temperature,
1366                learnable_alpha,
1367            );
1368            let rho_view = if learnable_alpha {
1369                Array1::from_vec(vec![rho.log_lambda_sparse])
1370            } else {
1371                // Keep the fixed-alpha value path on the same weighting branch as
1372                // assignment_prior_grad_hdiag; that gradient path owns the
1373                // lambda_sparse convention for IBP assignment sparsity.
1374                penalty.weight = rho.lambda_sparse();
1375                Array1::zeros(0)
1376            };
1377            penalty.value(target.view(), rho_view.view())
1378        }
1379        AssignmentMode::ThresholdGate {
1380            temperature,
1381            threshold,
1382        } => {
1383            // Sparsity penalty uses the same threshold-centered surrogate and
1384            // machine-precision support as its gradient/Hessian. Data-fit
1385            // reconstruction remains hard-gated by `jumprelu_row`.
1386            let sparsity_strength = rho.lambda_sparse();
1387            let mut acc = 0.0;
1388            for &logit in target.iter() {
1389                if jumprelu_in_optimization_band(logit, threshold, temperature) {
1390                    acc += gam_linalg::utils::stable_logistic((logit - threshold) / temperature);
1391                }
1392            }
1393            sparsity_strength * acc
1394        }
1395    }
1396}
1397
1398pub(crate) fn assignment_prior_log_strength_derivative(
1399    assignment: &SaeAssignment,
1400    rho: &SaeManifoldRho,
1401) -> f64 {
1402    for row in 0..assignment.n_obs() {
1403        validate_finite_logits(assignment.logits.row(row), row)
1404            .expect("assignment logits must be finite");
1405    }
1406    let target = flat_logits(assignment.logits.view());
1407    if matches!(assignment.mode, AssignmentMode::Softmax { .. }) && assignment.k_atoms() == 1 {
1408        return 0.0;
1409    }
1410    match assignment.mode {
1411        AssignmentMode::Softmax { .. } | AssignmentMode::ThresholdGate { .. } => {
1412            assignment_prior_value(assignment, rho)
1413        }
1414        AssignmentMode::IBPMap {
1415            temperature,
1416            alpha,
1417            learnable_alpha,
1418        } => {
1419            let mut penalty = IBPAssignmentPenalty::new(
1420                assignment.k_atoms(),
1421                alpha,
1422                temperature,
1423                learnable_alpha,
1424            );
1425            if learnable_alpha {
1426                let rho_view = Array1::from_vec(vec![rho.log_lambda_sparse]);
1427                penalty.grad_rho(target.view(), rho_view.view())[0]
1428            } else {
1429                penalty.weight = rho.lambda_sparse();
1430                penalty.value(target.view(), Array1::<f64>::zeros(0).view())
1431            }
1432        }
1433    }
1434}
1435
1436pub(crate) fn assignment_prior_log_strength_hdiag(
1437    assignment: &SaeAssignment,
1438    rho: &SaeManifoldRho,
1439) -> Result<Array1<f64>, String> {
1440    for row in 0..assignment.n_obs() {
1441        validate_finite_logits(assignment.logits.row(row), row)?;
1442    }
1443    let target = flat_logits(assignment.logits.view());
1444    if matches!(assignment.mode, AssignmentMode::Softmax { .. }) && assignment.k_atoms() == 1 {
1445        return Ok(Array1::<f64>::zeros(target.len()));
1446    }
1447    match assignment.mode {
1448        AssignmentMode::Softmax {
1449            temperature,
1450            sparsity,
1451        } => {
1452            let penalty = SoftmaxAssignmentSparsityPenalty::new(assignment.k_atoms(), temperature);
1453            let rho_view = Array1::from_vec(vec![rho.log_lambda_sparse + sparsity.ln()]);
1454            penalty
1455                .hessian_diag(target.view(), rho_view.view())
1456                .ok_or_else(|| {
1457                    "softmax assignment log-strength hessian diag unavailable".to_string()
1458                })
1459        }
1460        AssignmentMode::ThresholdGate {
1461            temperature,
1462            threshold,
1463        } => {
1464            let sparsity_strength = rho.lambda_sparse();
1465            let inv_tau = 1.0 / temperature;
1466            let inv_tau2 = inv_tau * inv_tau;
1467            let mut d = Array1::<f64>::zeros(target.len());
1468            for idx in 0..target.len() {
1469                let logit = target[idx];
1470                if !jumprelu_in_optimization_band(logit, threshold, temperature) {
1471                    continue;
1472                }
1473                let activation =
1474                    gam_linalg::utils::stable_logistic((logit - threshold) * inv_tau);
1475                let slope = activation * (1.0 - activation);
1476                d[idx] = sparsity_strength * slope * (1.0 - 2.0 * activation) * inv_tau2;
1477            }
1478            Ok(d)
1479        }
1480        AssignmentMode::IBPMap {
1481            temperature,
1482            alpha,
1483            learnable_alpha,
1484        } => {
1485            let mut penalty = IBPAssignmentPenalty::new(
1486                assignment.k_atoms(),
1487                alpha,
1488                temperature,
1489                learnable_alpha,
1490            );
1491            if learnable_alpha {
1492                let rho_view = Array1::from_vec(vec![rho.log_lambda_sparse]);
1493                Ok(penalty.hessian_diag_log_alpha_derivative(target.view(), rho_view.view()))
1494            } else {
1495                penalty.weight = rho.lambda_sparse();
1496                penalty
1497                    .hessian_diag(target.view(), Array1::<f64>::zeros(0).view())
1498                    .ok_or_else(|| {
1499                        "IBP assignment log-strength hessian diag unavailable".to_string()
1500                    })
1501            }
1502        }
1503    }
1504}
1505
1506pub(crate) fn assignment_prior_log_strength_target_mixed(
1507    assignment: &SaeAssignment,
1508    rho: &SaeManifoldRho,
1509) -> Result<Array1<f64>, String> {
1510    for row in 0..assignment.n_obs() {
1511        validate_finite_logits(assignment.logits.row(row), row)?;
1512    }
1513    let target = flat_logits(assignment.logits.view());
1514    if matches!(assignment.mode, AssignmentMode::Softmax { .. }) && assignment.k_atoms() == 1 {
1515        return Ok(Array1::<f64>::zeros(target.len()));
1516    }
1517    match assignment.mode {
1518        AssignmentMode::IBPMap {
1519            temperature,
1520            alpha,
1521            learnable_alpha: true,
1522        } => {
1523            let penalty = IBPAssignmentPenalty::new(assignment.k_atoms(), alpha, temperature, true);
1524            let rho_view = Array1::from_vec(vec![rho.log_lambda_sparse]);
1525            Ok(penalty.log_alpha_target_mixed_derivative(target.view(), rho_view.view()))
1526        }
1527        _ => Ok(assignment_prior_grad_hdiag(assignment, rho)?.0),
1528    }
1529}
1530
1531pub(crate) fn assignment_prior_grad_hdiag(
1532    assignment: &SaeAssignment,
1533    rho: &SaeManifoldRho,
1534) -> Result<(Array1<f64>, Array1<f64>), String> {
1535    for row in 0..assignment.n_obs() {
1536        validate_finite_logits(assignment.logits.row(row), row)?;
1537    }
1538    let target = flat_logits(assignment.logits.view());
1539    let mut grad = Array1::<f64>::zeros(target.len());
1540    let mut diag = Array1::<f64>::zeros(target.len());
1541    if matches!(assignment.mode, AssignmentMode::Softmax { .. }) && assignment.k_atoms() == 1 {
1542        return Ok((grad, diag));
1543    }
1544    let (sparsity_grad, sparsity_diag) = match assignment.mode {
1545        AssignmentMode::Softmax {
1546            temperature,
1547            sparsity,
1548        } => {
1549            let penalty = SoftmaxAssignmentSparsityPenalty::new(assignment.k_atoms(), temperature);
1550            let rho_view = Array1::from_vec(vec![rho.log_lambda_sparse + sparsity.ln()]);
1551            let g = penalty.grad_target(target.view(), rho_view.view());
1552            let d = penalty
1553                .hessian_diag(target.view(), rho_view.view())
1554                .ok_or_else(|| "softmax assignment hessian diag unavailable".to_string())?;
1555            (g, d)
1556        }
1557        AssignmentMode::IBPMap {
1558            temperature,
1559            alpha,
1560            learnable_alpha,
1561        } => {
1562            // Scale the IBP assignment-sparsity prior by `lambda_sparse`, exactly
1563            // like the Softmax and JumpReLU branches do (Softmax folds it into the
1564            // penalty's rho coordinate, JumpReLU multiplies `sparsity_strength`).
1565            // Previously the IBP penalty used its hardcoded `weight = 1.0` and the
1566            // `rho.log_lambda_sparse` coordinate never reached it (the rho_view was
1567            // empty for the common `learnable_alpha = false` config), so the prior
1568            // ran at full strength with no way to dial it down — and its
1569            // Beta-Bernoulli BCE energy `−mass·ln π_k − (n−mass)·ln(1−π_k)` toward
1570            // the self-referential empirical active fraction `π_k` has its global
1571            // minimum at the all-off gate, so at full weight it over-shrank the
1572            // assignment off both atoms even with a truth-seeded decoder (#853).
1573            // Routing `lambda_sparse` into the penalty weight makes the prior a
1574            // genuine, user-controllable lever balanced against the data fit.
1575            let mut penalty = IBPAssignmentPenalty::new(
1576                assignment.k_atoms(),
1577                alpha,
1578                temperature,
1579                learnable_alpha,
1580            );
1581            // When `alpha` is learnable, `log_lambda_sparse` already modulates
1582            // it through `resolved_alpha(rho)`, so the weight stays 1.0 to avoid
1583            // double-counting that coordinate. Only when `alpha` is fixed (so the
1584            // sparse coordinate would otherwise be ignored entirely) does
1585            // `lambda_sparse` become the prior's weight lever.
1586            let rho_view = if learnable_alpha {
1587                Array1::from_vec(vec![rho.log_lambda_sparse])
1588            } else {
1589                penalty.weight = rho.lambda_sparse();
1590                Array1::zeros(0)
1591            };
1592            let g = penalty.grad_target(target.view(), rho_view.view());
1593            let d = penalty
1594                .hessian_diag(target.view(), rho_view.view())
1595                .ok_or_else(|| "IBP assignment hessian diag unavailable".to_string())?;
1596            (g, d)
1597        }
1598        AssignmentMode::ThresholdGate {
1599            temperature,
1600            threshold,
1601        } => {
1602            // Gradient and exact diagonal Hessian of the sparsity value's
1603            // threshold-centered surrogate σ((l−θ)/τ), using the same
1604            // machine-precision support as the value path. Data-fit JVP support
1605            // is narrower and follows the hard forward gate.
1606            let sparsity_strength = rho.lambda_sparse();
1607            let inv_tau = 1.0 / temperature;
1608            let inv_tau2 = inv_tau * inv_tau;
1609            let mut g = Array1::<f64>::zeros(target.len());
1610            let mut d = Array1::<f64>::zeros(target.len());
1611            for idx in 0..target.len() {
1612                let logit = target[idx];
1613                if !jumprelu_in_optimization_band(logit, threshold, temperature) {
1614                    continue;
1615                }
1616                let activation =
1617                    gam_linalg::utils::stable_logistic((logit - threshold) * inv_tau);
1618                let slope = activation * (1.0 - activation);
1619                g[idx] = sparsity_strength * slope * inv_tau;
1620                d[idx] = sparsity_strength * slope * (1.0 - 2.0 * activation) * inv_tau2;
1621            }
1622            (g, d)
1623        }
1624    };
1625    grad += &sparsity_grad;
1626    diag += &sparsity_diag;
1627    // #1026/#1033 — a FIXED logit (an ungated atom's, or every atom's under
1628    // frozen routing) is not a free parameter, so it carries NO sparsity-prior
1629    // gradient or curvature. Zero its flat columns (`flat_logits` is row-major
1630    // `row*K + atom`) so the assembled `gt` and `htt` logit slots stay zero —
1631    // matching the zero logit-JVP. The column-separable IBP / JumpReLU priors are
1632    // per-atom, so zeroing one atom's columns leaves the others' prior intact;
1633    // under frozen routing ALL atoms' logit columns are zeroed (the whole routing
1634    // is a fixed predicted function, not optimized).
1635    if assignment.has_ungated() || assignment.routing_is_frozen() {
1636        let k = assignment.k_atoms();
1637        for idx in 0..grad.len() {
1638            if assignment.logit_is_fixed(idx % k) {
1639                grad[idx] = 0.0;
1640                diag[idx] = 0.0;
1641            }
1642        }
1643    }
1644    Ok((grad, diag))
1645}
1646
1647/// Build the exact IBP `hessian_diag` logit third-derivative channels (#1006)
1648/// for the SAE log-det adjoint Γ, using the SAME penalty configuration —
1649/// `alpha`/`tau`/`learnable_alpha` and the `lambda_sparse` weight convention —
1650/// that [`assignment_prior_grad_hdiag`] assembles into `htt`. Returns `None`
1651/// for non-IBP assignment modes (no cross-row empirical-π coupling to correct).
1652pub(crate) fn ibp_assignment_third_channels(
1653    assignment: &SaeAssignment,
1654    rho: &SaeManifoldRho,
1655) -> Result<Option<IbpHessianDiagThirdChannels>, String> {
1656    let AssignmentMode::IBPMap {
1657        temperature,
1658        alpha,
1659        learnable_alpha,
1660    } = assignment.mode
1661    else {
1662        return Ok(None);
1663    };
1664    for row in 0..assignment.n_obs() {
1665        validate_finite_logits(assignment.logits.row(row), row)?;
1666    }
1667    let target = flat_logits(assignment.logits.view());
1668    let mut penalty =
1669        IBPAssignmentPenalty::new(assignment.k_atoms(), alpha, temperature, learnable_alpha);
1670    // Mirror assignment_prior_grad_hdiag exactly: when alpha is learnable the
1671    // sparse coordinate already modulates it through resolved_alpha(rho), so the
1672    // weight stays 1.0; otherwise lambda_sparse becomes the prior's weight lever.
1673    let rho_view = if learnable_alpha {
1674        Array1::from_vec(vec![rho.log_lambda_sparse])
1675    } else {
1676        penalty.weight = rho.lambda_sparse();
1677        Array1::zeros(0)
1678    };
1679    let mut channels = penalty.hessian_diag_logit_third_channels(target.view(), rho_view.view());
1680    // #1026/#1033 — zero the log-det third-derivative channels of FIXED-logit
1681    // atoms (ungated, or all atoms under frozen routing) so the #1006 θ-adjoint
1682    // differentiates the SAME (fixed-logit-zeroed) `htt` that
1683    // `assignment_prior_grad_hdiag` assembled. `k_max` columns, row-major `N·K`
1684    // for the per-(row,atom) arrays and length-`K` for the per-column ones.
1685    if assignment.has_ungated() || assignment.routing_is_frozen() {
1686        let k = channels.k_max;
1687        for idx in 0..channels.z_jac.len() {
1688            if assignment.logit_is_fixed(idx % k) {
1689                channels.z_jac[idx] = 0.0;
1690                channels.local_logit_third[idx] = 0.0;
1691                channels.m_channel[idx] = 0.0;
1692                channels.logit_curvature[idx] = 0.0;
1693            }
1694        }
1695        for atom in 0..k {
1696            if assignment.logit_is_fixed(atom) {
1697                channels.cross_row_d[atom] = 0.0;
1698                channels.cross_row_dd[atom] = 0.0;
1699            }
1700        }
1701    }
1702    Ok(Some(channels))
1703}
1704
1705/// #1026 hybrid curved + linear-tail adjudication for one SAE atom slot.
1706///
1707/// A hybrid dictionary lets each atom slot be either a CURVED atom (its fitted
1708/// `latent_dim ≥ 1` manifold chart, whose decoded image may turn) or its LINEAR
1709/// special case (the euclidean-d=1-linear atom — one straight decoder direction,
1710/// `γ(t) = t·b`, zero turning). The two are nested: the linear atom is exactly
1711/// the curved family restricted to its straight sub-model, so a hybrid slot
1712/// cannot lose to pure-linear at matched actives — it strictly generalizes it.
1713///
1714/// This is the single call the SAE fitter makes per atom to choose the split by
1715/// EVIDENCE rather than fiat. It packages the atom's two already-fitted
1716/// candidates — each scored on the COMMON rank-aware Laplace scale (`−V = NLE`,
1717/// lower wins, identical to the union/mixture rungs) on the same rows — and
1718/// routes them through [`select_hybrid_atom`]. The curved candidate's fitted
1719/// turning `Θ` (from
1720/// [`crate::chart_canonicalization::d1_atom_fitted_turning`]) enters
1721/// as the decision feature: a `Θ → 0` atom yields to the cheaper linear tail by
1722/// construction (the dominance floor — a curved atom buys nothing on a straight
1723/// feature), a high-`Θ` atom takes the curved parameterization when its
1724/// curvature lowers the NLE by more than its extra-parameter price (the `Θ/√ε`
1725/// crossover).
1726///
1727/// `manifold` is the atom's fitted chart manifold; a non-curveable (already
1728/// Euclidean-flat) chart can only present the linear candidate, which this
1729/// helper enforces by ignoring any curved candidate offered for a flat chart —
1730/// a flat chart has no curvature to price, so the linear special case is its
1731/// only honest parameterization. Curveable charts present both candidates.
1732///
1733/// # Wiring into the fitter (the one call into `sae_manifold.rs`)
1734///
1735/// The post-fit pass in `sae_manifold.rs` already computes each d=1 atom's
1736/// fitted turning `Θ` (the read-only EV-vs-Θ diagnostic). To make the split
1737/// load-bearing, that pass supplies, per atom, the curved-candidate NLE +
1738/// parameter count + `Θ` and the linear-candidate NLE + parameter count (both
1739/// fitted on the atom's rows), and calls this helper; the returned
1740/// [`HybridAtomChoice`] tells the fitter which parameterization to keep for that
1741/// slot. The fitting of the two candidates lives in `sae_manifold.rs` (the
1742/// manifold-chart fitter); the SELECTION/scoring lives here.
1743pub fn select_hybrid_atom_parameterization(
1744    manifold: &LatentManifold,
1745    curved: Option<HybridAtomCandidate>,
1746    linear: HybridAtomCandidate,
1747) -> HybridAtomChoice {
1748    // A flat (Euclidean) chart has no curvature to price: its only honest
1749    // parameterization is the linear special case, so any curved candidate
1750    // offered for it is dropped before the evidence comparison. Curveable charts
1751    // (Circle / Sphere / Torus / curved products) present both candidates.
1752    let curved = if manifold.is_euclidean() {
1753        None
1754    } else {
1755        curved
1756    };
1757    let candidates: Vec<HybridAtomCandidate> = match curved {
1758        Some(c) => vec![linear, c],
1759        None => vec![linear],
1760    };
1761    // `candidates` is never empty (it always contains the linear candidate), so
1762    // the selector always returns a choice.
1763    select_hybrid_atom(&candidates).expect("hybrid atom slot always has the linear candidate")
1764}
1765
#[cfg(test)]
mod ibp_prior_614_tests {
    // #614 regression pins. `ibp_stick_breaking_prior` used to compute
    // `π_k = (α/(α+1))^k`, giving `π_0 = 1`: an UNSHRUNK first atom, which broke
    // α's role as an IBP concentration parameter. The consistent truncated-IBP
    // stick-breaking prior mean is `π_k = (α/(α+1))^{k+1}` — the expectation of
    // a product of (k+1) i.i.d. Beta(α,1) stick means — so EVERY atom, the first
    // included, carries one stick of shrinkage. These tests pin that contract.
    use super::*;

    /// Single-stick mean `α/(α+1)`.
    fn stick_mean(alpha: f64) -> f64 {
        alpha / (alpha + 1.0)
    }

    #[test]
    fn first_atom_is_shrunk_not_unity() {
        // The #614 defect: π_0 must be the single-stick mean α/(α+1), NOT 1.0.
        for alpha in [0.1_f64, 0.5, 1.0, 2.0, 5.0] {
            let prior = ordered_geometric_shrinkage_prior(8, alpha);
            let r = stick_mean(alpha);
            let head = prior[0];
            assert!(
                (head - r).abs() < 1e-12,
                "π_0 must be the single-stick mean α/(α+1)={r} (was the unshrunk 1.0 in #614); got {}",
                head
            );
            assert!(
                head < 1.0,
                "first atom must be shrunk (π_0<1) for alpha={alpha}; got {}",
                head
            );
        }
    }

    #[test]
    fn prior_is_consistent_geometric_product_mean() {
        // π_k = (α/(α+1))^{k+1} exactly, so successive entries shrink by α/(α+1).
        let k = 12;
        for alpha in [0.3_f64, 1.0, 4.0] {
            let prior = ordered_geometric_shrinkage_prior(k, alpha);
            let r = stick_mean(alpha);
            for j in 0..k {
                let expected = r.powi((j + 1) as i32);
                let tolerance = 1e-12 * expected.max(1.0);
                assert!(
                    (prior[j] - expected).abs() < tolerance,
                    "alpha={alpha} π_{j}: expected {expected}, got {}",
                    prior[j]
                );
            }
            // Ordered shrinkage: strictly decreasing, no plateau at the head.
            for j in 1..k {
                assert!(
                    prior[j - 1] > prior[j],
                    "alpha={alpha}: prior must strictly decrease at index {j}"
                );
            }
        }
    }

    #[test]
    fn alpha_behaves_as_concentration() {
        // A larger α means heavier mass and slower decay: π_0 moves toward 1 and
        // the tail (e.g. π_4) carries more mass — the IBP-concentration role the
        // #614 fix restored.
        let small_alpha = ordered_geometric_shrinkage_prior(8, 0.5);
        let large_alpha = ordered_geometric_shrinkage_prior(8, 5.0);
        assert!(
            large_alpha[0] > small_alpha[0],
            "larger alpha must raise π_0 (concentration): {} vs {}",
            large_alpha[0],
            small_alpha[0]
        );
        assert!(
            large_alpha[4] > small_alpha[4],
            "larger alpha must put more mass in the tail: {} vs {}",
            large_alpha[4],
            small_alpha[4]
        );
    }
}
1846
#[cfg(test)]
mod hybrid_split_tests {
    use super::*;
    use gam_solve::evidence::HybridAtomParam;
    use std::f64::consts::PI;

    /// Circle chart with period 2π, shared by the curveable-chart tests.
    fn circle_chart() -> LatentManifold {
        LatentManifold::Circle { period: 2.0 * PI }
    }

    #[test]
    fn flat_chart_drops_curved_candidate_and_keeps_linear() {
        // A Euclidean chart has no curvature: the curved candidate is dropped
        // even though it is offered with a lower NLE than the linear one.
        let flat = LatentManifold::Euclidean;
        let curved = HybridAtomCandidate::curved(1, 1.0, 5, Some(2.0));
        let linear = HybridAtomCandidate::linear(100.0, 2);
        let choice = select_hybrid_atom_parameterization(&flat, Some(curved), linear);
        assert!(choice.param.is_linear());
    }

    #[test]
    fn curveable_chart_selects_curved_when_turning_pays() {
        // A Circle chart presents both candidates; the curved fit of a turning
        // feature that beats the linear secant on evidence is selected.
        let chart = circle_chart();
        let curved = HybridAtomCandidate::curved(1, 70.0, 5, Some(2.0 * PI));
        let linear = HybridAtomCandidate::linear(100.0, 2);
        let choice = select_hybrid_atom_parameterization(&chart, Some(curved), linear);
        assert_eq!(choice.param, HybridAtomParam::Curved { latent_dim: 1 });
    }

    #[test]
    fn curveable_chart_falls_back_to_linear_when_no_curved_candidate() {
        // With no curved candidate the linear one is the only option.
        let chart = circle_chart();
        let linear = HybridAtomCandidate::linear(33.0, 2);
        let choice = select_hybrid_atom_parameterization(&chart, None, linear);
        assert!(choice.param.is_linear());
        assert_eq!(choice.num_parameters, 2);
    }
}
1894
#[cfg(test)]
mod frozen_routing_1033_tests {
    //! #1033 — the FROZEN (amortized) routing mechanism. Once installed, the
    //! per-row gate is a ρ-invariant function of the FROZEN predicted logits and
    //! is DECOUPLED from later updates to the free `self.logits` (the inner-fit
    //! logit drift the outer ρ-search would otherwise re-incur on every eval).
    //! These are deterministic mechanism invariants — no inner fit is run — so
    //! they pin the load-bearing freeze properties without the cluster.
    use super::*;

    /// IBP-mode assignment with deterministic logits and one coordinate per atom.
    fn ibp_assignment(n: usize, k: usize) -> SaeAssignment {
        let logits = Array2::from_shape_fn((n, k), |(i, kk)| {
            0.3 + 0.05 * (i as f64) - 0.1 * (kk as f64)
        });
        let mut coords: Vec<Array2<f64>> = Vec::with_capacity(k);
        for _ in 0..k {
            coords.push(Array2::from_shape_fn((n, 1), |(i, _)| (i as f64) * 0.1));
        }
        // learnable_alpha = false keeps alpha ρ-independent, isolating the routing.
        let mode = AssignmentMode::ibp_map(0.5, 1.0, false);
        SaeAssignment::from_blocks_with_mode(logits, coords, mode).unwrap()
    }

    /// Builds a ρ from its two leading scalars plus one zero vector per atom.
    fn rho_with(first: f64, second: f64, k: usize) -> SaeManifoldRho {
        SaeManifoldRho::new(first, second, vec![Array1::<f64>::zeros(1); k])
    }

    /// Evaluates the gate row of each of the first `n` observations at `rho`.
    fn gate_rows(a: &SaeAssignment, rho: &SaeManifoldRho, n: usize) -> Vec<Array1<f64>> {
        let mut rows = Vec::with_capacity(n);
        for r in 0..n {
            rows.push(a.try_assignments_row_for_rho(r, rho).unwrap());
        }
        rows
    }

    #[test]
    fn frozen_routing_decouples_gates_from_logit_updates_1033() {
        let (n, k) = (6usize, 3usize);
        let mut a = ibp_assignment(n, k)
            .freeze_routing_from_current_logits()
            .unwrap();
        assert!(a.routing_is_frozen());
        let rho = rho_with(0.0, 0.0, k);
        // Gates BEFORE the free logits are touched.
        let before = gate_rows(&a, &rho, n);
        // Simulate an inner-fit logit update (what the ρ-search would otherwise
        // do on every eval): shift every free logit substantially.
        a.logits.mapv_inplace(|v| v + 5.0);
        let after = gate_rows(&a, &rho, n);
        // FROZEN routing reads the snapshot, so the perturbation must leave the
        // gates UNCHANGED — the routing is decoupled from inner-fit drift.
        for r in 0..n {
            for kk in 0..k {
                assert_eq!(
                    before[r][kk], after[r][kk],
                    "row {r} atom {kk}: frozen-routing gate must be UNCHANGED by a free-logit \
                     update (decoupled from inner-fit drift); {} vs {}",
                    before[r][kk], after[r][kk]
                );
            }
        }
    }

    #[test]
    fn frozen_routing_gates_are_rho_invariant_1033() {
        let (n, k) = (5usize, 2usize);
        let a = ibp_assignment(n, k)
            .freeze_routing_from_current_logits()
            .unwrap();
        // Two different ρ (different sparse + smooth strengths). With frozen
        // routing and learnable_alpha=false the gate must be identical at both.
        let rho_a = rho_with((1e-3_f64).ln(), (1e-2_f64).ln(), k);
        let rho_b = rho_with((1e3_f64).ln(), (1e1_f64).ln(), k);
        for r in 0..n {
            let ga = a.try_assignments_row_for_rho(r, &rho_a).unwrap();
            let gb = a.try_assignments_row_for_rho(r, &rho_b).unwrap();
            for kk in 0..k {
                assert_eq!(
                    ga[kk], gb[kk],
                    "row {r} atom {kk}: frozen-routing gate must be ρ-INVARIANT (the n-independence \
                     lever); {} at ρ_a vs {} at ρ_b",
                    ga[kk], gb[kk]
                );
            }
        }
    }

    #[test]
    fn frozen_routing_fixes_all_logits_and_thaw_restores_free_path_1033() {
        let (n, k) = (4usize, 3usize);
        let mut a = ibp_assignment(n, k)
            .freeze_routing_from_current_logits()
            .unwrap();
        // Under frozen routing EVERY logit is fixed (not a free Newton coord).
        let mask = a.fixed_logit_mask();
        assert_eq!(mask.len(), k);
        assert!(
            !mask.iter().any(|&f| !f),
            "frozen routing must fix ALL logits"
        );
        for kk in 0..k {
            assert!(
                a.logit_is_fixed(kk),
                "atom {kk} logit must be fixed under frozen routing"
            );
        }
        // Thawing restores the free-logit path (no fixed logits, no ungated).
        a.thaw_routing();
        assert!(!a.routing_is_frozen());
        assert!(
            !a.fixed_logit_mask().iter().any(|&f| f),
            "thaw must restore the free-logit path"
        );
    }

    #[test]
    fn frozen_routing_rejects_softmax_1033() {
        let (n, k) = (4usize, 3usize);
        let logits = Array2::from_shape_fn((n, k), |(i, kk)| 0.1 * (i as f64) - 0.05 * (kk as f64));
        let mut coords: Vec<Array2<f64>> = Vec::with_capacity(k);
        for _ in 0..k {
            coords.push(Array2::from_shape_fn((n, 1), |(i, _)| (i as f64) * 0.1));
        }
        let a = SaeAssignment::from_blocks_with_mode(logits, coords, AssignmentMode::softmax(1.0))
            .unwrap();
        // Softmax + frozen routing is rejected (the coupled-simplex entropy
        // majorizer would be inconsistent with a frozen, non-optimized routing).
        let frozen = a.freeze_routing_from_current_logits();
        assert!(
            frozen.is_err(),
            "frozen routing under Softmax must be rejected (simplex entropy-majorizer coupling)"
        );
    }
}
2030
#[cfg(test)]
mod fill_into_buffer_1557_tests {
    //! #1557 — equivalence tests for the fill-into-caller-buffer gate row.
    //!
    //! [`SaeAssignment::try_assignments_row_for_rho_into`] is a pure
    //! allocation-elision refactor of the allocating
    //! [`SaeAssignment::try_assignments_row_for_rho`], so the two must agree
    //! under exact `==` on `f64` (no tolerance — any numeric drift is a
    //! regression). Coverage: every assignment mode (Softmax, IBPMap,
    //! JumpReLU), the #1026 ungated case, and the K==1 edge.
    use super::*;

    /// Builds an `n × k` assignment in `mode` from deterministic, asymmetric
    /// logits so every atom takes a distinct value (no accidental ties that
    /// could mask an index bug). Each atom gets the same `n × 1` coordinate
    /// block.
    fn build(n: usize, k: usize, mode: AssignmentMode) -> SaeAssignment {
        let logits = Array2::from_shape_fn((n, k), |(row, atom)| {
            0.37 + 0.11 * (row as f64) - 0.23 * (atom as f64)
        });
        let coords: Vec<Array2<f64>> = std::iter::repeat_with(|| {
            Array2::from_shape_fn((n, 1), |(row, _)| 0.1 + 0.05 * (row as f64))
        })
        .take(k)
        .collect();
        SaeAssignment::from_blocks_with_mode(logits, coords, mode).unwrap()
    }

    /// Fixed `SaeManifoldRho` fixture with one length-1 zero vector per atom.
    fn rho(k: usize) -> SaeManifoldRho {
        let ln_1e_2 = (1e-2_f64).ln();
        let ln_1e_1 = (1e-1_f64).ln();
        let per_atom = vec![Array1::<f64>::zeros(1); k];
        SaeManifoldRho::new(ln_1e_2, ln_1e_1, per_atom)
    }

    /// For every row of `a`, checks that the `_into` variant writes exactly
    /// the values the allocating variant returns.
    fn assert_into_matches_alloc(a: &SaeAssignment) {
        let rows = a.n_obs();
        let atoms = a.k_atoms();
        let rho_fixture = rho(atoms);
        let mut scratch = vec![f64::NAN; atoms];
        for row in 0..rows {
            let allocated = a.try_assignments_row_for_rho(row, &rho_fixture).unwrap();
            // Poison the buffer with NaN before each fill: an entry the `_into`
            // path leaves untouched (e.g. a JumpReLU below-threshold slot) then
            // shows up as a mismatch instead of silently passing.
            scratch.fill(f64::NAN);
            a.try_assignments_row_for_rho_into(row, &rho_fixture, &mut scratch)
                .unwrap();
            assert_eq!(allocated.len(), atoms);
            (0..atoms).for_each(|kk| {
                let (want, got) = (allocated[kk], scratch[kk]);
                assert_eq!(
                    want, got,
                    "row {row} atom {kk}: _into must be BIT-IDENTICAL to the allocating \
                     try_assignments_row_for_rho; got {} vs {}",
                    want, got
                );
            });
        }
    }

    #[test]
    fn softmax_into_is_bit_identical() {
        let assignment = build(7, 4, AssignmentMode::softmax(0.8));
        assert_into_matches_alloc(&assignment);
    }

    #[test]
    fn ibp_map_into_is_bit_identical() {
        // Both learnable and fixed alpha exercise the resolved-alpha branch.
        for flag in [false, true] {
            let assignment = build(7, 5, AssignmentMode::ibp_map(0.6, 1.3, flag));
            assert_into_matches_alloc(&assignment);
        }
    }

    #[test]
    fn jumprelu_into_is_bit_identical() {
        // Threshold chosen so SOME atoms fall below it (the untouched-entry path)
        // and some clear it (the sigmoid path) — both branches are exercised.
        let assignment = build(7, 5, AssignmentMode::jumprelu(0.9, 0.2));
        assert_into_matches_alloc(&assignment);
    }

    #[test]
    fn ungated_into_is_bit_identical() {
        // #1026 ungated overwrite under a gate-style mode (IBP/JumpReLU allow it).
        let cases = [
            (
                AssignmentMode::ibp_map(0.6, 1.1, false),
                vec![false, true, false, true],
            ),
            (
                AssignmentMode::jumprelu(0.9, 0.15),
                vec![true, false, true, false],
            ),
        ];
        for (mode, ungated) in cases {
            let assignment = build(6, 4, mode).with_ungated(ungated).unwrap();
            assert_into_matches_alloc(&assignment);
        }
    }

    #[test]
    fn k_equals_one_into_is_bit_identical() {
        // Softmax K==1 hits the fixed-unit early return; IBP/JumpReLU K==1 keep a
        // free per-atom gate and fall through to the real row functions.
        let single_atom_modes = [
            AssignmentMode::softmax(1.0),
            AssignmentMode::ibp_map(0.7, 1.0, false),
            AssignmentMode::jumprelu(0.8, 0.1),
        ];
        for mode in single_atom_modes {
            assert_into_matches_alloc(&build(5, 1, mode));
        }
    }
}