gam_sae/assignment.rs
1//! Assignment gates and sparsity-prior helpers for the SAE manifold term.
2//! Mechanically split from `sae_manifold.rs`.
3
4use ndarray::{Array1, Array2, ArrayView1, ArrayView2};
5
6use gam_solve::evidence::{HybridAtomCandidate, HybridAtomChoice, select_hybrid_atom};
7use gam_terms::analytic_penalties::{
8 AnalyticPenalty, IBPAssignmentPenalty, IbpHessianDiagThirdChannels,
9 SoftmaxAssignmentSparsityPenalty, resolve_learnable_weight,
10};
11use gam_terms::latent::{LatentCoordValues, LatentIdMode, LatentManifold};
12use crate::manifold::SaeManifoldRho;
13
/// #976 Layer-1 guard — per-iteration cap on the assignment-logit update.
///
/// Measured in multiples of the gate temperature τ, the natural length scale
/// of every gate (all assignment modes read logits through `σ(·/τ)` or
/// `softmax(·/τ)`). A 4τ step already spans the whole soft range of a gate, so
/// healthy convergence is not slowed. What the cap prevents is a single inner
/// iteration dragging a gate from contention straight to numerically-zero
/// support. A collapse therefore needs several accepted iterations, which gives
/// the per-iteration active-mass guard the chance to observe it.
///
/// The cap is applied where the step is realised. When it binds, the objective
/// is evaluated on the clamped state, so the Armijo comparison stays
/// value-consistent: the unclamped quadratic model is merely conservative, and
/// step halvings bring the trial back under the cap.
pub(crate) const SAE_ASSIGNMENT_LOGIT_STEP_CAP_TAUS: f64 = 4.0_f64;
27
/// #976 Layer-1 guard — how many times a single atom may be re-seeded during
/// one joint fit.
///
/// One re-seed gives the atom a second start from a fresh basin. If it breaches
/// the guard again, the collapse is taken as the objective's (local) verdict at
/// the current hyperparameters. It is recorded as a terminal collapse event and
/// left for the structure-search death move to adjudicate. Re-seeding
/// repeatedly would only fight the optimizer.
pub(crate) const SAE_ATOM_COLLAPSE_RESEED_BUDGET: usize = 1_usize;
34
/// #976 Layer-1 guard (decoder arm) — collapse threshold on an atom's decoder
/// Frobenius norm, as a fraction of the dictionary's MEDIAN decoder norm.
///
/// An atom whose decoder block has shrunk to this fraction of the median no
/// longer carries material reconstruction signal. Its output is (near-)zero
/// and indistinguishable from any other collapsed atom.
///
/// This catches the real-data K>1 failure that the gate-mass floor cannot see.
/// The gates may stay spread over rows, so the mass guard is satisfied, while
/// every decoder collapses towards 0. The result is EV≈0 and a rank-deficient
/// per-row coordinate Hessian on every row (the 0→K·n evidence-deflation jump).
///
/// Because the statistic is a RATIO to the median, it is scale-free. A
/// dictionary whose decoders are all small but comparable never triggers it;
/// only an atom that has fallen far behind its peers does. For K=1 the median
/// is the atom's own norm, so the guard cannot fire and the K=1 path is
/// unchanged.
pub(crate) const SAE_ATOM_DECODER_NORM_COLLAPSE_RATIO: f64 = 0.001;
49
/// #976 / #1117 K>1 robustness — bounded DICTIONARY-level multi-start budget
/// for the simultaneous co-collapse arm (the EV-floor branch of
/// [`crate::manifold::SaeManifoldTerm::enforce_decoder_norm_guard`]).
///
/// Not to be confused with the per-atom [`SAE_ATOM_COLLAPSE_RESEED_BUDGET`]
/// (= 1), which limits re-seeding ONE atom's gate logits against an optimizer
/// that keeps killing it; there a loop would fight the optimizer.
///
/// A co-collapse reseed is categorically different. It is a full-dictionary
/// multi-start that re-diversifies ALL atoms onto distinct principal directions
/// of a FRESHLY recomputed residual, so successive attempts explore genuinely
/// different basins. A single such reseed cannot always escape a K≥3 three-way
/// basin: the same (K, seed) has been seen to flip between EV≈0.40 and 0.00.
/// This arm therefore gets a small bounded budget of independent multi-starts.
///
/// The budget is consumed ONLY when the whole dictionary's reconstruction EV is
/// at or below the data-derived collapse bar (`collapse_ev_bar` =
/// `SAE_COLLAPSE_PCA_EV_FRACTION` × the rank-K PCA ceiling). That is less than
/// half the variance any rank-K linear dictionary could reach. This bar
/// REPLACED the old absolute magic EV floor; the check itself was not deleted.
///
/// The budget is a no-op for any healthy fit (real OLMo K=1 ~0.22, K=2 ~0.40,
/// both well above the bar). On the rough patches of a hard K≥3 fit, a
/// transient sub-bar EV can still use up the budget before the fit recovers.
pub(crate) const SAE_DICTIONARY_COCOLLAPSE_RESEED_BUDGET: usize = 3_usize;
70
/// Support cutoff used when optimizing the smooth threshold-gate (JumpReLU)
/// assignment prior, measured in gate temperatures below the hard threshold.
///
/// The forward gate is still exactly zero at or below `threshold`. This cutoff
/// only decides which logits the prior value/gradient and the compact Newton
/// layout keep: those with `(logit - threshold)/tau > -36`.
///
/// At the excluded edge `sigma(-36) ~= 2e-16`. The dropped
/// value/gradient/Hessian terms are therefore below f64 noise rather than
/// introducing an algorithmic discontinuity.
pub(crate) const JUMPRELU_OPTIMIZATION_LOGIT_CUTOFF: f64 = -36.0_f64;
78
79/// Shared support predicate for JumpReLU optimization inclusion. This is
80/// strictly weaker than the hard forward gate `logit > threshold`, which still
81/// governs data-fit reconstruction and its logit JVP.
82#[inline]
83pub(crate) fn jumprelu_in_optimization_band(logit: f64, threshold: f64, temperature: f64) -> bool {
84 (logit - threshold) / temperature > JUMPRELU_OPTIMIZATION_LOGIT_CUTOFF
85}
86
/// Assignment prior / relaxation that turns per-row logits into gates for
/// [`SaeAssignment`].
#[derive(Debug, Clone, Copy)]
pub enum AssignmentMode {
    /// Row-wise simplex assignment (softmax over atoms) with an entropy
    /// sparsity weight.
    Softmax { temperature: f64, sparsity: f64 },
    /// Deterministic concrete relaxation of a truncated IBP active set.
    ///
    /// Atom `k` receives the INDEPENDENT prior-shrunk activation
    /// `σ(logit/temperature) · π_k`, where `π_k = (α/(α+1))^{k+1}` is the
    /// stick-breaking prior mean (see [`ibp_map_row`]).
    ///
    /// These are per-atom gates, NOT mixture/simplex responsibilities. Each
    /// column is computed on its own and there is no row normalization, so the
    /// gates do NOT sum to 1 across atoms. Each `a_k ∈ [0, π_k] ⊂ [0, 1)` is the
    /// relaxed "atom k is active in this row" indicator of a truncated IBP, not
    /// a share of a unit reconstruction budget.
    IBPMap {
        temperature: f64,
        alpha: f64,
        learnable_alpha: bool,
    },
    /// Hard-thresholded bounded gate.
    ///
    /// An atom is off (gate = 0) when its logit is at or below `threshold`.
    /// Above it, the gate is the threshold-centred shifted sigmoid
    /// `σ((logit − threshold) / temperature) ∈ [0.5, 1)`. The 0 → 0.5 step at
    /// `threshold` is the intended "jump".
    ///
    /// #1777: this variant was formerly named `JumpReLU`, which was
    /// inaccurate. The literature JumpReLU `z·1[z>θ]` passes the thresholded
    /// MAGNITUDE `z` through, whereas this mode carries no magnitude at all.
    /// Its output is a bounded `[0, 1)` indicator (a hard-sigmoid gate). It
    /// belongs to the same gate family as the softmax simplex and the IBP
    /// sigmoid. Reconstruction magnitude lives entirely in the decoder curve
    /// `g_k(t) = φ(t)ᵀ B_k`.
    ///
    /// Back-compat: [`Self::threshold_gate`] is the primary constructor and
    /// [`Self::jumprelu`] a retained alias. The FFI string parser accepts both
    /// `"threshold_gate"` and the legacy `"jumprelu"`.
    ThresholdGate { temperature: f64, threshold: f64 },
}
124
/// #1033 — fixed-form predictor that produces the ρ-invariant FROZEN
/// (amortized) routing.
///
/// Neither form uses a learned network. Both are deterministic functions of
/// the current dictionary; they differ in how faithfully they follow the
/// dictionary as it moves across outer iterates. Both are kept so the accuracy
/// gate can choose whichever clears the fit-quality bar: the cheap `Snapshot`
/// when it suffices, otherwise the `ChartGeometry` distill.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum RoutingPredictor {
    /// Freeze the current (converged) logits as the routing.
    ///
    /// This is the cheapest fixed-form distill and is exact at the dictionary
    /// it was taken from. It goes stale as the dictionary moves and needs a
    /// refresh to keep up, which makes it the MVP/baseline form.
    Snapshot,
    /// Recompute each (row, atom) routing logit from the atom's encode-chart
    /// geometry against the CURRENT dictionary.
    ///
    /// Each row is encoded to its predicted coord `t̂`, the amplitude-1 image
    /// `γ_k(t̂) = Bᵀφ(t̂)` is reconstructed, and the reconstruction ALIGNMENT is
    /// mapped to a logit. A moved decoder changes `γ_k(t̂)` and hence the
    /// routing, so this form tracks the dictionary without re-running the
    /// free-logit inner solve. It is the default-readiness form once the
    /// snapshot proves too stale.
    ChartGeometry,
}
147
148impl AssignmentMode {
149 #[must_use]
150 pub fn softmax(temperature: f64) -> Self {
151 Self::Softmax {
152 temperature,
153 sparsity: 1.0,
154 }
155 }
156
157 #[must_use]
158 pub fn ibp_map(temperature: f64, alpha: f64, learnable_alpha: bool) -> Self {
159 Self::IBPMap {
160 temperature,
161 alpha,
162 learnable_alpha,
163 }
164 }
165
166 /// #1777 — construct the hard-sigmoid [`Self::ThresholdGate`] (the accurate
167 /// name for what was `jumprelu`). Primary spelling; [`Self::jumprelu`] is a
168 /// deprecated alias kept for back-compat.
169 #[must_use]
170 pub fn threshold_gate(temperature: f64, threshold: f64) -> Self {
171 Self::ThresholdGate {
172 temperature,
173 threshold,
174 }
175 }
176
177 /// Back-compat alias for [`Self::threshold_gate`] (#1777): the mode is a
178 /// hard-sigmoid gate, not the literature JumpReLU magnitude activation. Retained
179 /// (NOT `#[deprecated]`, since the workspace denies warnings and many callers
180 /// and the legacy `"jumprelu"` FFI token still use it) so existing code keeps
181 /// compiling; new code should prefer `threshold_gate`.
182 #[must_use]
183 pub fn jumprelu(temperature: f64, threshold: f64) -> Self {
184 Self::threshold_gate(temperature, threshold)
185 }
186
187 pub fn temperature(&self) -> f64 {
188 match *self {
189 AssignmentMode::Softmax { temperature, .. }
190 | AssignmentMode::IBPMap { temperature, .. }
191 | AssignmentMode::ThresholdGate { temperature, .. } => temperature,
192 }
193 }
194
195 pub(crate) fn set_temperature(&mut self, new_temperature: f64) -> Result<(), String> {
196 if !(new_temperature.is_finite() && new_temperature > 0.0) {
197 return Err(format!(
198 "AssignmentMode: temperature must be finite and positive; got {new_temperature}"
199 ));
200 }
201 match self {
202 AssignmentMode::Softmax { temperature, .. }
203 | AssignmentMode::IBPMap { temperature, .. }
204 | AssignmentMode::ThresholdGate { temperature, .. } => {
205 *temperature = new_temperature;
206 }
207 }
208 Ok(())
209 }
210
211 pub(crate) fn validate(&self) -> Result<(), String> {
212 let temperature = self.temperature();
213 if !(temperature.is_finite() && temperature > 0.0) {
214 return Err(format!(
215 "AssignmentMode: temperature must be finite and positive; got {temperature}"
216 ));
217 }
218 match *self {
219 AssignmentMode::Softmax { sparsity, .. } => {
220 if !(sparsity.is_finite() && sparsity > 0.0) {
221 return Err(format!(
222 "AssignmentMode::Softmax: sparsity must be finite and positive; got {sparsity}"
223 ));
224 }
225 }
226 AssignmentMode::IBPMap { alpha, .. } => {
227 if !(alpha.is_finite() && alpha > 0.0) {
228 return Err(format!(
229 "AssignmentMode::IBPMap: alpha must be finite and positive; got {alpha}"
230 ));
231 }
232 }
233 AssignmentMode::ThresholdGate { threshold, .. } => {
234 if !threshold.is_finite() {
235 return Err(format!(
236 "AssignmentMode::ThresholdGate: threshold must be finite; got {threshold}"
237 ));
238 }
239 }
240 }
241 Ok(())
242 }
243
244 /// Resolve the effective truncated-IBP concentration `α` for this mode.
245 ///
246 /// `per_fit_override` is the #1777 PER-FIT override (from
247 /// [`SaeAssignment::ibp_alpha_override`]) and is the SOURCE OF TRUTH when set.
248 /// It falls back to the deprecated process-global [`ibp_alpha_override`] atomic
249 /// only when the per-fit field is unset, then to the mode's own `α` /
250 /// learnable schedule — so nothing breaks, but concurrent fits are isolatable
251 /// via the per-fit field.
252 pub(crate) fn resolved_ibp_alpha(
253 &self,
254 rho: &SaeManifoldRho,
255 per_fit_override: Option<f64>,
256 ) -> Option<f64> {
257 match *self {
258 AssignmentMode::IBPMap {
259 alpha,
260 learnable_alpha,
261 ..
262 } => Some(if let Some(over) = per_fit_override.or_else(ibp_alpha_override) {
263 // #1777 — the per-fit override (else the deprecated process-global
264 // one) flattens the ordered geometric prior π_k = (α/(α+1))^{k+1}
265 // so all K atoms can contribute to the reconstruction (the
266 // production α=1 gives a (0.5)^{k+1} schedule that structurally
267 // caps atoms 4..K → effective-K≈3). Forces the fixed value,
268 // bypassing the learnable schedule.
269 over
270 } else if learnable_alpha {
271 resolve_learnable_weight(alpha, rho.log_lambda_sparse)
272 } else {
273 alpha
274 }),
275 _ => None,
276 }
277 }
278}
279
// #1026 — process-global IBP-α override, stored as the bit pattern of an f64.
// The initial value 0x7ff8_0000_0000_0000 is a quiet NaN, which
// `ibp_alpha_override` reads as "unset". Any non-finite or non-positive value
// means "no override, use the AssignmentMode's compiled α". The override lets
// ONE wheel sweep the prior-flattening axis from Python (`sae_set_ibp_alpha`)
// without recompiling the gam crate.
//
// CONCURRENCY WARNING: this is ONE atomic shared by the whole process, not
// per-fit config. Every IBP assignment evaluation that falls back to
// `ibp_alpha_override` reads it, so a value set for one fit leaks into the
// gates of every other in-flight fit. It is safe only for serial,
// whole-process sweeps (the single-wheel FFI sweep driver it exists for).
//
// #1777 added the per-fit replacement, `SaeAssignment::ibp_alpha_override`
// (installed via `SaeAssignment::set_ibp_alpha_override` / the term's
// `set_fit_config`). When that per-fit value is set, it takes precedence and
// this global is not consulted (see `AssignmentMode::resolved_ibp_alpha`).
// Concurrent in-process fits should therefore use the per-fit field. This
// global remains only as a deprecated fallback.
static IBP_ALPHA_OVERRIDE_BITS: std::sync::atomic::AtomicU64 =
    std::sync::atomic::AtomicU64::new(0x7ff8_0000_0000_0000);
295
296pub(crate) fn ibp_alpha_override() -> Option<f64> {
297 let v = f64::from_bits(IBP_ALPHA_OVERRIDE_BITS.load(std::sync::atomic::Ordering::Relaxed));
298 if v.is_finite() && v > 0.0 {
299 Some(v)
300 } else {
301 None
302 }
303}
304
305/// Set (or, with a non-finite/non-positive value, clear) the process-global
306/// IBP-α override. Called from the gamfit Python FFI sweep driver.
307///
308/// PROCESS-GLOBAL / NOT CONCURRENCY-SAFE: this mutates one process-wide atomic
309/// read by every IBP assignment in the process. Calling it while any other fit is
310/// running leaks the override into that fit's gates. Use only for serial,
311/// whole-process sweeps; do not use across concurrent in-process fits. See the
312/// `IBP_ALPHA_OVERRIDE_BITS` note on migrating this to per-fit config.
313pub fn set_ibp_alpha_override(alpha: f64) {
314 IBP_ALPHA_OVERRIDE_BITS.store(alpha.to_bits(), std::sync::atomic::Ordering::Relaxed);
315}
316
317/// Per-row latent assignment state.
318///
319/// The stored assignment parameter is `logits`; non-negative assignments are
320/// derived by row-wise softmax, independent IBP-MAP sigmoid active indicators,
321/// or JumpReLU gates. Softmax logits are canonicalized to the reference chart
322/// `logits[K - 1] = 0`, so the row-local Newton coordinates contain only the
323/// first `K - 1` logits (`0` coordinates for `K = 1`). Gate-style modes keep
324/// all `K` logits as identifiable scalar parameters. `coords[k]` holds
325/// `t_{.,k}` for atom `k`.
326#[derive(Debug, Clone)]
327pub struct SaeAssignment {
328 pub logits: Array2<f64>,
329 pub coords: Vec<LatentCoordValues>,
330 pub mode: AssignmentMode,
331 /// #1026 — per-atom UNGATED flag (length `K`, default all-`false`). An
332 /// ungated atom is the dense linear/background tier: its per-row gate is
333 /// fixed at `a_k ≡ 1` (it contributes `γ_k(t_k)` to EVERY row, unweighted),
334 /// it is excluded from the other atoms' gate (for the column-separable
335 /// IBP / JumpReLU modes the remaining atoms are computed independently, so
336 /// they are unaffected), and its logit is NOT a free parameter — its
337 /// logit-JVP, sparsity-prior gradient/curvature, and softmax majorizer
338 /// contributions are all zero, leaving its logit slot an inert
339 /// (ridge-regularized) null direction in the per-row Newton block. This lets
340 /// the linear tier carry FULL-RANK reconstructible variance
341 /// (`fitted = γ_ungated(x) + Σ_{gated} a_k·γ_k(x)`) so a linear SAE can reach
342 /// the rank-(K·d) PCA ceiling, while the gated curved atoms still add sparse
343 /// structure on the residual (#1026 routing-bound finding).
344 pub ungated: Vec<bool>,
345 /// #1033 — AMORTIZED / FROZEN routing. When `Some`, this `(n, K)` matrix is a
346 /// ρ-INVARIANT predicted routing (the amortized `x → logits` map distilled
347 /// from the frozen dictionary): the gates are computed from THESE logits
348 /// instead of the free `self.logits`, and the logits are NOT optimized by the
349 /// inner Newton (their gradient/curvature/prior contributions are zeroed,
350 /// exactly as for [`Self::ungated`]). This is the generalization of an ungated
351 /// atom from "pin the gate at 1" to "pin the gate at the predicted value": it
352 /// makes the per-row routing a fixed function of `x` + the frozen dictionary,
353 /// so the outer ρ-search reuses ONE routing instead of re-solving per-row
354 /// gates every outer eval — the n-independent-outer-loop lever (#1033). `None`
355 /// is the historical free-logit path (bit-identical).
356 pub frozen_logits: Option<Array2<f64>>,
357 /// #1777 PER-FIT IBP-α override — the source of truth for the truncated-IBP
358 /// concentration `α` when set, replacing the process-global
359 /// [`set_ibp_alpha_override`] atomic. `Some(α)` forces the fixed value
360 /// (bypassing the learnable schedule), scoped to THIS assignment/fit so
361 /// concurrent in-process fits are isolated. `None` ⇒ fall back to the
362 /// deprecated process-global override, then to the `AssignmentMode`'s own `α` /
363 /// learnable schedule (bit-identical to the historical path when neither
364 /// override is set). Read via [`Self::resolved_ibp_alpha`]; set from the FFI
365 /// through the term's `set_fit_config`.
366 pub ibp_alpha_override: Option<f64>,
367}
368
369impl SaeAssignment {
370 #[must_use = "build error must be handled"]
371 pub fn new(
372 logits: Array2<f64>,
373 coords: Vec<LatentCoordValues>,
374 temperature: f64,
375 ) -> Result<Self, String> {
376 Self::with_mode(logits, coords, AssignmentMode::softmax(temperature))
377 }
378
379 #[must_use = "build error must be handled"]
380 pub fn with_mode(
381 mut logits: Array2<f64>,
382 coords: Vec<LatentCoordValues>,
383 mode: AssignmentMode,
384 ) -> Result<Self, String> {
385 mode.validate()?;
386 let n = logits.nrows();
387 let k = logits.ncols();
388 if coords.len() != k {
389 return Err(format!(
390 "SaeAssignment::new: coords length {} must equal K={k}",
391 coords.len()
392 ));
393 }
394 for (atom, coord) in coords.iter().enumerate() {
395 if coord.n_obs() != n {
396 return Err(format!(
397 "SaeAssignment::new: coord atom {atom} has n_obs={} but logits has {n}",
398 coord.n_obs()
399 ));
400 }
401 }
402 for row in 0..n {
403 validate_finite_logits(logits.row(row), row)?;
404 }
405 if matches!(mode, AssignmentMode::Softmax { .. }) {
406 canonicalize_softmax_logits(&mut logits);
407 }
408 Ok(Self {
409 logits,
410 coords,
411 mode,
412 ungated: vec![false; k],
413 frozen_logits: None,
414 ibp_alpha_override: None,
415 })
416 }
417
418 /// #1033 — install a ρ-INVARIANT FROZEN routing (the amortized predicted
419 /// logits; see [`SaeAssignment::frozen_logits`]). `predicted` must be
420 /// `(n, K)`. With routing frozen, the gates are computed from `predicted` and
421 /// the logits are excluded from the inner Newton (their gradient/curvature are
422 /// inert, like an ungated atom's). Passing `None` restores the free-logit
423 /// path.
424 #[must_use = "build error must be handled"]
425 pub fn with_frozen_routing(mut self, predicted: Option<Array2<f64>>) -> Result<Self, String> {
426 if let Some(ref p) = predicted {
427 if p.dim() != (self.n_obs(), self.k_atoms()) {
428 return Err(format!(
429 "SaeAssignment::with_frozen_routing: predicted shape {:?} must be ({}, {})",
430 p.dim(),
431 self.n_obs(),
432 self.k_atoms()
433 ));
434 }
435 if matches!(self.mode, AssignmentMode::Softmax { .. }) {
436 return Err(
437 "SaeAssignment::with_frozen_routing: frozen routing under Softmax is rejected \
438 — the coupled simplex's entropy majorizer is assembled over the logits, which \
439 a frozen (non-optimized) routing would leave inconsistent; this separable-mode \
440 contract supports IBP-MAP and JumpReLU, whose per-atom gates have no \
441 simplex-coupled curvature to skip"
442 .to_string(),
443 );
444 }
445 for row in 0..p.nrows() {
446 validate_finite_logits(p.row(row), row)?;
447 }
448 }
449 self.frozen_logits = predicted;
450 Ok(self)
451 }
452
453 /// Whether the per-row routing is FROZEN (amortized) rather than free-logit.
454 pub fn routing_is_frozen(&self) -> bool {
455 self.frozen_logits.is_some()
456 }
457
458 /// The active routing logits for `row`: the frozen/predicted logits when
459 /// routing is frozen (#1033), else the free `self.logits`. This is the SINGLE
460 /// source the gate value reads, so freezing routing changes every gate
461 /// consistently.
462 pub(crate) fn routing_logits_row(&self, row: usize) -> ArrayView1<'_, f64> {
463 match self.frozen_logits {
464 Some(ref f) => f.row(row),
465 None => self.logits.row(row),
466 }
467 }
468
469 /// Whether atom `k`'s logit is held fixed (not a free Newton parameter): true
470 /// for an ungated atom (#1026, gate pinned at 1) OR when routing is frozen
471 /// (#1033, gate pinned at the predicted value). Both share the same inert
472 /// treatment — zero logit-JVP, zero sparsity-prior gradient/curvature, zero
473 /// softmax majorizer — so the logit slot never moves.
474 pub(crate) fn logit_is_fixed(&self, k: usize) -> bool {
475 self.routing_is_frozen() || self.ungated.get(k).copied().unwrap_or(false)
476 }
477
478 /// Per-atom mask (length `K`) of [`Self::logit_is_fixed`] — the logit slots
479 /// that are NOT free Newton parameters (ungated #1026 and/or frozen-routing
480 /// #1033). Precompute once per assembly and pass to the logit-JVP fillers so
481 /// the data-fit Jacobian zeroes those rows. Under frozen routing every entry
482 /// is `true`; with only ungated atoms it equals `ungated`; otherwise all
483 /// `false` (the historical free-logit path).
484 pub(crate) fn fixed_logit_mask(&self) -> Vec<bool> {
485 if self.routing_is_frozen() {
486 vec![true; self.k_atoms()]
487 } else {
488 self.ungated.clone()
489 }
490 }
491
492 /// #1033 — install the simplest faithful AMORTIZED routing predictor: a
493 /// fixed-form DISTILL of the current dictionary's routing, namely the current
494 /// (converged) logits SNAPSHOTTED as the ρ-invariant frozen routing. This is
495 /// the `x → logits` map "evaluated once at the frozen dictionary" — the
496 /// routing the dictionary already expresses — held fixed so the outer ρ-search
497 /// reuses it instead of re-optimizing the gates at every ρ. (A richer
498 /// predictor that recomputes logits from `x` via the encode-atlas chart
499 /// geometry is a later refinement; snapshotting the converged routing is the
500 /// exact fixed-point it would target at the frozen dictionary.) Rejected for
501 /// Softmax for the same simplex-coupling reason as [`Self::with_frozen_routing`].
502 #[must_use = "build error must be handled"]
503 pub fn freeze_routing_from_current_logits(self) -> Result<Self, String> {
504 let snapshot = self.logits.clone();
505 self.with_frozen_routing(Some(snapshot))
506 }
507
508 /// #1033 — in-place variant of [`Self::freeze_routing_from_current_logits`]
509 /// for callers holding `&mut SaeAssignment` (e.g. inside a `SaeManifoldTerm`),
510 /// where moving the assignment out is awkward. Same contract: snapshot the
511 /// current logits as the ρ-invariant frozen routing; reject Softmax.
512 pub fn freeze_routing_in_place(&mut self) -> Result<(), String> {
513 if matches!(self.mode, AssignmentMode::Softmax { .. }) {
514 return Err(
515 "SaeAssignment::freeze_routing_in_place: frozen routing under Softmax is rejected \
516 (coupled-simplex entropy-majorizer); use IBP-MAP or JumpReLU"
517 .to_string(),
518 );
519 }
520 let snapshot = self.logits.clone();
521 for row in 0..snapshot.nrows() {
522 validate_finite_logits(snapshot.row(row), row)?;
523 }
524 self.frozen_logits = Some(snapshot);
525 Ok(())
526 }
527
528 /// #1033 — install an explicit predicted routing in place (the
529 /// [`RoutingPredictor::ChartGeometry`] output), `&mut self` variant of
530 /// [`Self::with_frozen_routing`]. `predicted` must be `(n, K)`; rejects Softmax
531 /// (separable-mode contract) and non-finite predictions.
532 pub fn set_frozen_routing_in_place(&mut self, predicted: Array2<f64>) -> Result<(), String> {
533 if predicted.dim() != (self.n_obs(), self.k_atoms()) {
534 return Err(format!(
535 "SaeAssignment::set_frozen_routing_in_place: predicted shape {:?} must be ({}, {})",
536 predicted.dim(),
537 self.n_obs(),
538 self.k_atoms()
539 ));
540 }
541 if matches!(self.mode, AssignmentMode::Softmax { .. }) {
542 return Err(
543 "SaeAssignment::set_frozen_routing_in_place: frozen routing under Softmax is \
544 rejected (coupled-simplex entropy-majorizer); use IBP-MAP or JumpReLU"
545 .to_string(),
546 );
547 }
548 for row in 0..predicted.nrows() {
549 validate_finite_logits(predicted.row(row), row)?;
550 }
551 self.frozen_logits = Some(predicted);
552 Ok(())
553 }
554
555 /// #1033 — lift the frozen routing, restoring the free-logit search path.
556 pub fn thaw_routing(&mut self) {
557 self.frozen_logits = None;
558 }
559
560 /// #1026 — designate which atoms are UNGATED (the dense linear/background
561 /// tier; see [`SaeAssignment::ungated`]). `flags` must have length `K`.
562 ///
563 /// Ungating is defined for the COLUMN-SEPARABLE gate modes (IBP-MAP and
564 /// JumpReLU): each atom's gate is an independent per-atom function of its own
565 /// logit, so pinning one atom to `a_k ≡ 1` leaves every other atom's gate
566 /// exactly as computed. Softmax is a coupled simplex (`Σ_k a_k = 1` over all
567 /// `K`), so a unit gate for one atom is only well defined relative to a
568 /// gated-subset renormalization that must also be reflected in the logit-JVP
569 /// and the entropy majorizer; this constructor's contract is restricted to
570 /// the separable modes, and an ungated atom under Softmax is REJECTED here so
571 /// the inner solve never runs on a value/gradient-mismatched gate. Callers
572 /// wanting a dense background tier under Softmax route it as an IBP-MAP or
573 /// JumpReLU atom.
574 #[must_use = "build error must be handled"]
575 pub fn with_ungated(mut self, flags: Vec<bool>) -> Result<Self, String> {
576 if flags.len() != self.k_atoms() {
577 return Err(format!(
578 "SaeAssignment::with_ungated: flags length {} must equal K={}",
579 flags.len(),
580 self.k_atoms()
581 ));
582 }
583 if matches!(self.mode, AssignmentMode::Softmax { .. }) && flags.iter().any(|&u| u) {
584 return Err(
585 "SaeAssignment::with_ungated: an ungated atom under Softmax routing is \
586 rejected — the coupled simplex requires a gated-subset renormalization \
587 reflected in the logit-JVP and entropy majorizer, which this separable-mode \
588 contract does not perform; route a dense background tier as IBP-MAP or JumpReLU"
589 .to_string(),
590 );
591 }
592 self.ungated = flags;
593 Ok(self)
594 }
595
596 /// Whether any atom is ungated (the #1026 background tier is engaged).
597 pub fn has_ungated(&self) -> bool {
598 self.ungated.iter().any(|&u| u)
599 }
600
601 pub fn n_obs(&self) -> usize {
602 self.logits.nrows()
603 }
604
605 pub fn k_atoms(&self) -> usize {
606 self.logits.ncols()
607 }
608
609 pub fn total_coord_dim(&self) -> usize {
610 self.coords.iter().map(|c| c.latent_dim()).sum()
611 }
612
613 pub fn assignment_coord_dim(&self) -> usize {
614 match self.mode {
615 AssignmentMode::Softmax { .. } => self.k_atoms().saturating_sub(1),
616 AssignmentMode::IBPMap { .. } | AssignmentMode::ThresholdGate { .. } => self.k_atoms(),
617 }
618 }
619
620 pub fn row_block_dim(&self) -> usize {
621 self.assignment_coord_dim() + self.total_coord_dim()
622 }
623
624 pub fn coord_offsets(&self) -> Vec<usize> {
625 let mut out = Vec::with_capacity(self.k_atoms());
626 let mut cursor = self.assignment_coord_dim();
627 for coord in &self.coords {
628 out.push(cursor);
629 cursor += coord.latent_dim();
630 }
631 out
632 }
633
634 pub fn assignments(&self) -> Array2<f64> {
635 let n = self.n_obs();
636 let k = self.k_atoms();
637 let mut out = Array2::<f64>::zeros((n, k));
638 for row in 0..n {
639 let a = self.assignments_row(row);
640 for atom in 0..k {
641 out[[row, atom]] = a[atom];
642 }
643 }
644 out
645 }
646
647 pub fn assignments_row(&self, row: usize) -> Array1<f64> {
648 self.try_assignments_row(row)
649 .expect("assignment logits must be finite")
650 }
651
652 pub fn try_assignments_row(&self, row: usize) -> Result<Array1<f64>, String> {
653 self.try_assignments_row_with_alpha(row, None)
654 }
655
656 /// #1777 — the effective truncated-IBP `α` for this assignment at `rho`,
657 /// honoring the PER-FIT [`Self::ibp_alpha_override`] first (source of truth),
658 /// then the deprecated process-global override, then the mode's own schedule.
659 /// The single seam every gate/jet/prior site reads so the per-fit override is
660 /// applied consistently. `None` for non-IBP modes.
661 pub(crate) fn resolved_ibp_alpha(&self, rho: &SaeManifoldRho) -> Option<f64> {
662 self.mode.resolved_ibp_alpha(rho, self.ibp_alpha_override)
663 }
664
665 /// #1777 — install (or clear, with `None`) the PER-FIT IBP-α override on this
666 /// assignment. Source of truth used by [`Self::resolved_ibp_alpha`]; the FFI
667 /// reaches it through the term's `set_fit_config`.
668 pub fn set_ibp_alpha_override(&mut self, alpha: Option<f64>) {
669 self.ibp_alpha_override = alpha;
670 }
671
672 pub(crate) fn try_assignments_row_for_rho(
673 &self,
674 row: usize,
675 rho: &SaeManifoldRho,
676 ) -> Result<Array1<f64>, String> {
677 self.try_assignments_row_with_alpha(row, self.resolved_ibp_alpha(rho))
678 }
679
680 fn try_assignments_row_with_alpha(
681 &self,
682 row: usize,
683 resolved_ibp_alpha: Option<f64>,
684 ) -> Result<Array1<f64>, String> {
685 // #1033 — read the ACTIVE routing logits: the ρ-invariant frozen/predicted
686 // logits when routing is frozen, else the free `self.logits`. This single
687 // source makes the gate value ρ-invariant under frozen routing (the
688 // amortized-routing lever) and bit-identical to the historical path when
689 // not frozen.
690 let routing = self.routing_logits_row(row);
691 validate_finite_logits(routing, row)?;
692 // Only Softmax collapses to a fixed assignment at K==1: its
693 // assignment_coord_dim is K-1 = 0, so there is no free logit. IBPMap and
694 // JumpReLU keep a free per-atom gate logit even at K==1
695 // (assignment_coord_dim = K = 1), so they must fall through to their real
696 // row functions or the logit would move the prior but not the gate.
697 if self.k_atoms() == 1 && matches!(self.mode, AssignmentMode::Softmax { .. }) {
698 return Ok(Array1::from_vec(vec![1.0]));
699 }
700 let mut row_gates = match self.mode {
701 AssignmentMode::Softmax { temperature, .. } => softmax_row(routing, temperature),
702 AssignmentMode::IBPMap {
703 temperature, alpha, ..
704 } => ibp_map_row(routing, temperature, resolved_ibp_alpha.unwrap_or(alpha)),
705 AssignmentMode::ThresholdGate {
706 temperature,
707 threshold,
708 } => jumprelu_row(routing, temperature, threshold),
709 };
710 // #1026 — ungated (background-tier) atoms have a fixed unit gate. For the
711 // column-separable IBP / JumpReLU modes the other atoms' gates are
712 // computed independently above, so overwriting the ungated entries to 1.0
713 // leaves the gated atoms exactly as they were; the ungated atom then
714 // contributes `γ_k(t_k)` unweighted to every row. (Softmax + ungated is
715 // rejected at `with_ungated`, so no simplex renormalization is needed
716 // here.)
717 if self.has_ungated() {
718 for (k, gate) in row_gates.iter_mut().enumerate() {
719 if self.ungated[k] {
720 *gate = 1.0;
721 }
722 }
723 }
724 Ok(row_gates)
725 }
726
727 /// #1557 — fill-into-caller-buffer twin of [`Self::try_assignments_row_for_rho`].
728 ///
729 /// Writes the EXACT SAME per-atom assignment row into `out` (length
730 /// `k_atoms()`) instead of allocating a fresh `Array1`. Bit-identical to the
731 /// allocating path; intended for the hot per-row loops that immediately
732 /// consume the row, reusing a single scratch buffer across rows.
733 pub(crate) fn try_assignments_row_for_rho_into(
734 &self,
735 row: usize,
736 rho: &SaeManifoldRho,
737 out: &mut [f64],
738 ) -> Result<(), String> {
739 self.try_assignments_row_with_alpha_into(row, self.resolved_ibp_alpha(rho), out)
740 }
741
742 /// #1557 — fill-into-caller-buffer twin of [`Self::try_assignments_row_with_alpha`].
743 ///
744 /// `out` must have length `k_atoms()`; it is fully overwritten with the same
745 /// values the allocating variant would return. Every branch (early-return
746 /// K==1 Softmax, the per-mode row math, the #1026 ungated overwrite) mirrors
747 /// the allocating path exactly so the two are bit-identical.
748 pub(crate) fn try_assignments_row_with_alpha_into(
749 &self,
750 row: usize,
751 resolved_ibp_alpha: Option<f64>,
752 out: &mut [f64],
753 ) -> Result<(), String> {
754 // `out` is sized `k_atoms()` by every caller; the per-mode helpers below
755 // fully overwrite indices `0..k_atoms()`.
756 let routing = self.routing_logits_row(row);
757 validate_finite_logits(routing, row)?;
758 // Mirror the allocating early-return: only Softmax collapses to a fixed
759 // unit assignment at K==1.
760 if self.k_atoms() == 1 && matches!(self.mode, AssignmentMode::Softmax { .. }) {
761 out[0] = 1.0;
762 return Ok(());
763 }
764 match self.mode {
765 AssignmentMode::Softmax { temperature, .. } => {
766 softmax_row_into(routing, temperature, out)
767 }
768 AssignmentMode::IBPMap {
769 temperature, alpha, ..
770 } => ibp_map_row_into(
771 routing,
772 temperature,
773 resolved_ibp_alpha.unwrap_or(alpha),
774 out,
775 ),
776 AssignmentMode::ThresholdGate {
777 temperature,
778 threshold,
779 } => jumprelu_row_into(routing, temperature, threshold, out),
780 };
781 // #1026 — ungated (background-tier) atoms have a fixed unit gate, exactly
782 // as in the allocating path.
783 if self.has_ungated() {
784 for (k, gate) in out.iter_mut().enumerate() {
785 if self.ungated[k] {
786 *gate = 1.0;
787 }
788 }
789 }
790 Ok(())
791 }
792
793 pub(crate) fn persist_resolved_ibp_alpha(&mut self, rho: &SaeManifoldRho) -> bool {
794 let AssignmentMode::IBPMap {
795 temperature,
796 alpha,
797 learnable_alpha: true,
798 } = self.mode
799 else {
800 return false;
801 };
802 let resolved_alpha = resolve_learnable_weight(alpha, rho.log_lambda_sparse);
803 self.mode = AssignmentMode::IBPMap {
804 temperature,
805 alpha: resolved_alpha,
806 learnable_alpha: false,
807 };
808 true
809 }
810
811 pub(crate) fn assignments_for_rho(&self, rho: &SaeManifoldRho) -> Result<Array2<f64>, String> {
812 let n = self.n_obs();
813 let k = self.k_atoms();
814 let mut out = Array2::<f64>::zeros((n, k));
815 for row in 0..n {
816 let a = self.try_assignments_row_for_rho(row, rho)?;
817 for atom in 0..k {
818 out[[row, atom]] = a[atom];
819 }
820 }
821 Ok(out)
822 }
823
824 /// Flatten extension coordinates in row-major SAE layout:
825 /// `(assignment chart_i, t_i0[0..d_0], ..., t_iK[0..d_K])` for every row.
826 /// Softmax contributes the first `K - 1` reference logits and omits the
827 /// fixed reference logit; gate-style assignment modes contribute all `K`
828 /// logits.
829 pub fn flatten_ext_coords(&self) -> Array1<f64> {
830 let n = self.n_obs();
831 let q = self.row_block_dim();
832 let k = self.k_atoms();
833 let assignment_dim = self.assignment_coord_dim();
834 let offsets = self.coord_offsets();
835 let mut out = Array1::<f64>::zeros(n * q);
836 for row in 0..n {
837 let base = row * q;
838 for atom in 0..assignment_dim {
839 out[base + atom] = self.logits[[row, atom]];
840 }
841 for atom in 0..k {
842 let d = self.coords[atom].latent_dim();
843 let t_row = self.coords[atom].row(row);
844 for axis in 0..d {
845 out[base + offsets[atom] + axis] = t_row[axis];
846 }
847 }
848 }
849 out
850 }
851
852 #[must_use = "build error must be handled"]
853 pub fn from_blocks_with_mode(
854 logits: Array2<f64>,
855 coord_blocks: Vec<Array2<f64>>,
856 mode: AssignmentMode,
857 ) -> Result<Self, String> {
858 let coords = coord_blocks
859 .iter()
860 .map(|c| LatentCoordValues::from_matrix(c.view(), LatentIdMode::None))
861 .collect();
862 Self::with_mode(logits, coords, mode)
863 }
864
865 #[must_use = "build error must be handled"]
866 pub fn from_blocks_with_mode_and_manifolds(
867 logits: Array2<f64>,
868 coord_blocks: Vec<Array2<f64>>,
869 manifolds: Vec<LatentManifold>,
870 mode: AssignmentMode,
871 ) -> Result<Self, String> {
872 if coord_blocks.len() != manifolds.len() {
873 return Err(format!(
874 "SaeAssignment::from_blocks_with_mode_and_manifolds: coord block length {} != manifold length {}",
875 coord_blocks.len(),
876 manifolds.len()
877 ));
878 }
879 let coords = coord_blocks
880 .iter()
881 .zip(manifolds)
882 .map(|(c, manifold)| {
883 LatentCoordValues::from_matrix_with_manifold(c.view(), LatentIdMode::None, manifold)
884 })
885 .collect();
886 Self::with_mode(logits, coords, mode)
887 }
888}
889
890pub(crate) fn neutral_gate_weights(mode: AssignmentMode, k_atoms: usize) -> Array1<f64> {
891 match mode {
892 AssignmentMode::Softmax { .. } => Array1::from_elem(k_atoms, 1.0 / (k_atoms.max(1) as f64)),
893 AssignmentMode::IBPMap {
894 temperature, alpha, ..
895 } => ibp_map_row(Array1::<f64>::zeros(k_atoms).view(), temperature, alpha),
896 AssignmentMode::ThresholdGate { .. } => Array1::from_elem(k_atoms, 0.5),
897 }
898}
899
900pub(crate) fn softmax_row(logits: ArrayView1<'_, f64>, temperature: f64) -> Array1<f64> {
901 let k = logits.len();
902 let inv_tau = 1.0 / temperature;
903 let mut max_logit = f64::NEG_INFINITY;
904 for &v in logits.iter() {
905 max_logit = max_logit.max(v);
906 }
907 let mut out = Array1::<f64>::zeros(k);
908 let mut sum = 0.0;
909 for i in 0..k {
910 let v = ((logits[i] - max_logit) * inv_tau).exp();
911 out[i] = v;
912 sum += v;
913 }
914 assert!(sum.is_finite() && sum > 0.0);
915 for v in out.iter_mut() {
916 *v /= sum;
917 }
918 out
919}
920
921pub(crate) fn validate_finite_logits(
922 logits: ArrayView1<'_, f64>,
923 row: usize,
924) -> Result<(), String> {
925 for (col, &v) in logits.iter().enumerate() {
926 if !v.is_finite() {
927 return Err(format!(
928 "SaeAssignment: non-finite assignment logit at row {row}, atom {col}: {v}"
929 ));
930 }
931 }
932 Ok(())
933}
934
935pub(crate) fn canonicalize_softmax_logits(logits: &mut Array2<f64>) {
936 let k = logits.ncols();
937 if k == 0 {
938 return;
939 }
940 if k == 1 {
941 logits.fill(0.0);
942 return;
943 }
944 for row in 0..logits.nrows() {
945 let reference = logits[[row, k - 1]];
946 for col in 0..k - 1 {
947 logits[[row, col]] -= reference;
948 }
949 logits[[row, k - 1]] = 0.0;
950 }
951}
952
953/// Truncated Indian-Buffet-Process stick-breaking prior *means*
954/// `π_k = E[∏_{j=0}^{k} v_j] = (α/(α+1))^{k+1}` for k = 0, .., K-1, with sticks
955/// `v_j ~ Beta(α, 1)` so `E[v_j] = α/(α+1)`. EVERY atom (including the first,
956/// `π_0 = α/(α+1)`) carries the consistent Beta(α, 1) shrinkage: there is no
957/// special-cased always-on base atom, so `α` behaves as a genuine IBP
958/// concentration — larger `α` ⇒ heavier mass / slower decay, `α → 0` ⇒ all mass
959/// collapses onto nothing, matching the stick-breaking limit. This is the
960/// deterministic MAP / mean-field form of the IBP prior (the closed form the
961/// analytic Newton / Hessian / Woodbury machinery differentiates); no sticks are
962/// *sampled* here, the per-atom weight is the exact expectation of the
963/// stick-breaking product. (#614: previously `π_0 = 1` left the first atom
964/// unshrunk, which is the prior mean of NO stick at all and broke α's role as a
965/// concentration; the consistent product mean restores genuine IBP semantics.)
966pub(crate) fn ordered_geometric_shrinkage_prior(k_atoms: usize, alpha: f64) -> Array1<f64> {
967 // Accumulate the geometric schedule `π_k = ratio^(k+1)` in LOG space so the
968 // prior stays a finite *soft* weight even for large `K`. The naive product
969 // `acc *= ratio` underflows to exact `0.0` once `ratio^(k+1) < f64::MIN_POSITIVE`
970 // (e.g. `(0.1/1.1)^320`), which would turn the soft shrinkage prior into a
971 // HARD mask: such atoms would receive zero assignment AND zero logit
972 // gradient (the gradient is multiplied by `π_k`), so they could never
973 // reactivate. Working in log-space and flooring the exponentiated weight at
974 // the smallest positive normal keeps every atom's gradient path alive while
975 // preserving the geometric ordering.
976 let mut out = Array1::<f64>::zeros(k_atoms);
977 let log_ratio = (alpha / (alpha + 1.0)).ln();
978 for k in 0..k_atoms {
979 // π_k = (α/(α+1))^{k+1}: the product of (k+1) i.i.d. Beta(α,1) stick
980 // means, so atom 0 is also shrunk by one stick (E[v_0] = α/(α+1)).
981 let log_pi = ((k + 1) as f64) * log_ratio;
982 out[k] = log_pi.exp().max(f64::MIN_POSITIVE);
983 }
984 out
985}
986
/// #1784 — K-aware default IBP concentration.
///
/// The ordered stick-breaking prior mean `π_k = (α/(α+1))^{k+1}` decays
/// GEOMETRICALLY in the atom INDEX. A fixed small concentration (the historical
/// default `α = 1`, i.e. the `(0.5)^{k+1}` schedule) therefore collapses to a
/// near-hard mask past atom ~3, and a K-atom dictionary can only ever place mass
/// on its first handful of atoms. That is why the manifold SAE UNDERFITS a
/// linear dictionary of equal K on real activations. It is also why its late
/// atoms carry zero mass and leave the per-row joint Hessian rank-deficient
/// (the K = 128 `RemlConvergenceError`).
///
/// For a K-atom dictionary to actually USE all K atoms, the IBP concentration
/// must scale with K. This function chooses `α` so that the LAST atom retains
/// prior mass `π_{K-1} = (α/(α+1))^K = e^{-1}`. That spans the whole dictionary
/// while keeping the prior monotone, so it is still an honest ordered
/// stick-breaking prior and no atom is structurally masked.
///
/// Solving gives `α = 1/(exp(1/K) − 1) ≈ K − 1/2`. It is evaluated as
/// `1 / expm1(1/K)` because `exp(1/K) − 1` cancels catastrophically for large
/// `K`: at `K = 10^9` the naive form is off by ~80 in `α`. The result is floored
/// at `1.0`, so `K ≤ 1` keeps the historical `α = 1`.
pub fn default_ibp_concentration_for_k_atoms(k_atoms: usize) -> f64 {
    let k = k_atoms.max(1) as f64;
    // π_{K-1} = (α/(α+1))^K = e^{-1} ⇒ α = 1/(e^{1/K} − 1) = 1/expm1(1/K).
    let alpha = 1.0 / (1.0 / k).exp_m1();
    alpha.max(1.0)
}
1011
1012/// IBP-MAP row activations: per-atom sigmoid likelihood times the truncated
1013/// stick-breaking prior mean `π_k = (α/(α+1))^{k+1}`. With tied logits the prior
1014/// dominates and yields strictly decreasing activations in atom index, with the
1015/// first atom already shrunk by one Beta(α,1) stick mean (no unshrunk base atom).
1016pub fn ibp_map_row(logits: ArrayView1<'_, f64>, temperature: f64, alpha: f64) -> Array1<f64> {
1017 let prior = ordered_geometric_shrinkage_prior(logits.len(), alpha);
1018 let mut out = Array1::<f64>::zeros(logits.len());
1019 for i in 0..logits.len() {
1020 out[i] = gam_linalg::utils::stable_logistic(logits[i] / temperature) * prior[i];
1021 }
1022 out
1023}
1024
1025/// IBP-MAP activations together with the diagonal Jacobian `∂z_k/∂l_k`,
1026/// shared with the torch autograd `Function` so the Python IBP-Gumbel path
1027/// applies the same stick-breaking prior mean `π_k = (α/(α+1))^{k+1}` and
1028/// temperature scaling as the Rust closed form. With `z_k = σ(l_k/τ)·π_k` the
1029/// per-atom derivative is
1030/// `σ(l_k/τ)(1 − σ(l_k/τ))·π_k / τ`; the map is diagonal in `k`, so the
1031/// Jacobian is returned as the per-atom diagonal vector.
1032#[must_use]
1033pub fn ibp_map_row_value_grad(
1034 logits: ArrayView1<'_, f64>,
1035 temperature: f64,
1036 alpha: f64,
1037) -> (Array1<f64>, Array1<f64>) {
1038 let prior = ordered_geometric_shrinkage_prior(logits.len(), alpha);
1039 let inv_tau = 1.0 / temperature;
1040 let mut value = Array1::<f64>::zeros(logits.len());
1041 let mut grad = Array1::<f64>::zeros(logits.len());
1042 for i in 0..logits.len() {
1043 let sig = gam_linalg::utils::stable_logistic(logits[i] * inv_tau);
1044 value[i] = sig * prior[i];
1045 grad[i] = sig * (1.0 - sig) * inv_tau * prior[i];
1046 }
1047 (value, grad)
1048}
1049
1050pub fn jumprelu_row(logits: ArrayView1<'_, f64>, temperature: f64, threshold: f64) -> Array1<f64> {
1051 let mut out = Array1::<f64>::zeros(logits.len());
1052 for i in 0..logits.len() {
1053 // Hard gate: strictly zero below threshold (the intended "jump"). Above
1054 // threshold the surrogate is centered at the threshold so the gate is
1055 // most informative exactly at the boundary it switches on:
1056 // σ((l−θ)/τ) ∈ [0.5, 1). Magnitude lives in the decoder, so the gate
1057 // stays bounded in [0, 1] by design.
1058 if logits[i] > threshold {
1059 out[i] = gam_linalg::utils::stable_logistic((logits[i] - threshold) / temperature);
1060 }
1061 }
1062 out
1063}
1064
1065/// Bounded threshold-gate activations together with the straight-through
1066/// derivative `∂a_k/∂l_k` of the smooth surrogate, shared with the torch
1067/// autograd `Function` so the torch `jumprelu` lane applies the SAME bounded
1068/// gate as the closed-form fit (`jumprelu_row`): `a_k = σ((l_k − θ_k)/τ)` for
1069/// `l_k > θ_k` and exactly `0` otherwise — magnitude lives in the decoder, the
1070/// gate stays in `[0, 1)`. The returned derivative is the smooth surrogate's
1071/// `σ'((l_k − θ_k)/τ)/τ` evaluated on BOTH sides of the jump (a straight-through
1072/// estimator: the hard forward has zero derivative below threshold, which would
1073/// permanently kill gradient flow to gated-off atoms). `∂a_k/∂θ_k` is the
1074/// negation of the returned logit derivative; callers negate. `thresholds` is
1075/// per-atom (the torch lane learns one threshold per atom); the scalar
1076/// closed-form threshold is the constant-vector special case and the value
1077/// arithmetic matches `jumprelu_row` exactly there.
1078#[must_use]
1079pub fn jumprelu_row_value_grad(
1080 logits: ArrayView1<'_, f64>,
1081 temperature: f64,
1082 thresholds: ArrayView1<'_, f64>,
1083) -> (Array1<f64>, Array1<f64>) {
1084 assert_eq!(
1085 logits.len(),
1086 thresholds.len(),
1087 "jumprelu_row_value_grad: logits/thresholds length mismatch"
1088 );
1089 let inv_tau = 1.0 / temperature;
1090 let mut value = Array1::<f64>::zeros(logits.len());
1091 let mut grad = Array1::<f64>::zeros(logits.len());
1092 for i in 0..logits.len() {
1093 let sig = gam_linalg::utils::stable_logistic((logits[i] - thresholds[i]) * inv_tau);
1094 if logits[i] > thresholds[i] {
1095 value[i] = sig;
1096 }
1097 grad[i] = sig * (1.0 - sig) * inv_tau;
1098 }
1099 (value, grad)
1100}
1101
1102// #1557 — fill-into-caller-buffer variants of the three per-mode row functions.
1103// These compute the EXACT SAME values as `softmax_row` / `ibp_map_row` /
1104// `jumprelu_row` (same arithmetic, same order of operations) but write into a
1105// caller-provided `&mut [f64]` slice instead of heap-allocating a fresh
1106// `Array1<f64>` per call. The hot per-row loops (loss eval, arrow/Schur row
1107// loops) call these with a reused scratch buffer, eliminating millions of tiny
1108// K-sized allocations while staying bit-identical to the allocating path.
1109// `out` must have length `logits.len()`; the slice is fully overwritten.
1110
1111pub(crate) fn softmax_row_into(logits: ArrayView1<'_, f64>, temperature: f64, out: &mut [f64]) {
1112 let k = logits.len();
1113 let inv_tau = 1.0 / temperature;
1114 let mut max_logit = f64::NEG_INFINITY;
1115 for &v in logits.iter() {
1116 max_logit = max_logit.max(v);
1117 }
1118 let mut sum = 0.0;
1119 for i in 0..k {
1120 let v = ((logits[i] - max_logit) * inv_tau).exp();
1121 out[i] = v;
1122 sum += v;
1123 }
1124 assert!(sum.is_finite() && sum > 0.0);
1125 for v in out.iter_mut() {
1126 *v /= sum;
1127 }
1128}
1129
1130pub(crate) fn ibp_map_row_into(
1131 logits: ArrayView1<'_, f64>,
1132 temperature: f64,
1133 alpha: f64,
1134 out: &mut [f64],
1135) {
1136 let prior = ordered_geometric_shrinkage_prior(logits.len(), alpha);
1137 for i in 0..logits.len() {
1138 out[i] = gam_linalg::utils::stable_logistic(logits[i] / temperature) * prior[i];
1139 }
1140}
1141
1142pub(crate) fn jumprelu_row_into(
1143 logits: ArrayView1<'_, f64>,
1144 temperature: f64,
1145 threshold: f64,
1146 out: &mut [f64],
1147) {
1148 for i in 0..logits.len() {
1149 // Match `jumprelu_row`: strictly zero below threshold, sigmoid surrogate
1150 // above. The buffer is fully overwritten (no read of prior contents).
1151 if logits[i] > threshold {
1152 out[i] = gam_linalg::utils::stable_logistic((logits[i] - threshold) / temperature);
1153 } else {
1154 out[i] = 0.0;
1155 }
1156 }
1157}
1158
/// Inputs for [`fill_active_atom_logit_jvp`]: everything needed to write one
/// active atom's compact logit-JVP row.
pub(crate) struct ActiveAtomLogitJvp<'a> {
    /// Assignment mode whose gate derivative `da_k/dl_k` is applied.
    pub(crate) mode: AssignmentMode,
    /// Atom index; used to look up `ibp_prior[k]` in the IBPMap branch.
    pub(crate) k: usize,
    /// Atom `k`'s logit; read only by the ThresholdGate branch.
    pub(crate) logit_k: f64,
    /// Atom `k`'s gate value; read by the Softmax and IBPMap branches.
    pub(crate) a_k: f64,
    /// Atom `k`'s decoded output row, indexed over `0..fitted.len()`.
    pub(crate) decoded_k: ArrayView1<'a, f64>,
    /// The row's active-set reconstruction. Its length sets how many output
    /// columns are written, and the Softmax branch uses it in `(decoded_k − fitted)`.
    pub(crate) fitted: ArrayView1<'a, f64>,
    /// Precomputed per-atom IBP prior means; required by the IBPMap branch
    /// (it panics via `expect` when absent).
    pub(crate) ibp_prior: Option<&'a [f64]>,
    /// Row of the compact Jacobian buffer that receives this atom's JVP.
    pub(crate) compact_index: usize,
    /// #1026 — when `true`, atom `k` is the ungated background tier: its gate is
    /// the constant `1`, so its logit-JVP `da_k/dl_k` is identically zero (the
    /// compact row is left untouched / zero).
    pub(crate) ungated: bool,
}
1173
/// Fill the single compact logit-JVP row for active atom `k`.
///
/// Writes row `compact_index` of `jac_compact`, columns `0..fitted.len()`, with
/// the per-mode gate sensitivity `da_k/dl_k` contracted into atom `k`'s output
/// direction:
///
/// * Softmax — `a_k (decoded_k − fitted) / τ`. `fitted` is the row's
///   *active-set* reconstruction, so the cross term `(decoded_k − fitted)` is
///   consistent with the curvature the compact block carries.
/// * IBPMap — `σ(1 − σ)·π_k / τ · decoded_k`. Here `σ = a_k / π_k` is recovered
///   from the gate value and the precomputed prior mean `π_k = ibp_prior[k]`.
/// * ThresholdGate — `σ'((l_k − θ)/τ)/τ · decoded_k` strictly above the
///   threshold. Nothing is written at or below it, because the hard forward gate
///   contributes no data-fit slope there.
///
/// This is the active-set analogue of [`fill_assignment_logit_jvp_rows`]. It
/// reproduces that function's diagonal logit row for atom `k`, but writes into
/// a compact position of a heterogeneous-`q` row block instead of the dense
/// full-`K` Jacobian. Skipped cases (an ungated atom, or ThresholdGate at or
/// below the threshold) leave the row exactly as the caller provided it.
///
/// # Panics
/// Panics for `IBPMap` when `ibp_prior` is `None`.
pub(crate) fn fill_active_atom_logit_jvp(
    input: ActiveAtomLogitJvp<'_>,
    jac_compact: &mut Array2<f64>,
) {
    let ActiveAtomLogitJvp {
        mode,
        k,
        logit_k,
        a_k,
        decoded_k,
        fitted,
        ibp_prior,
        compact_index,
        ungated,
    } = input;
    // Number of output columns written into the compact row.
    let p = fitted.len();
    // #1026 — an ungated atom's gate is constant, so its logit-JVP is zero; leave
    // its compact row untouched (the buffer row is pre-zeroed by the caller).
    if ungated {
        return;
    }
    match mode {
        AssignmentMode::Softmax { temperature, .. } => {
            // da_k/dl_k contracted: a_k (decoded_k − fitted) / τ.
            let inv_tau = 1.0 / temperature;
            for out_col in 0..p {
                jac_compact[[compact_index, out_col]] =
                    a_k * (decoded_k[out_col] - fitted[out_col]) * inv_tau;
            }
        }
        AssignmentMode::IBPMap { temperature, .. } => {
            // z_k = σ(l_k/τ)·π_k ⇒ dz_k/dl_k = σ(1 − σ)·π_k/τ with σ = a_k/π_k,
            // equivalently a_k(π_k − a_k)/(π_k τ). Same arithmetic as
            // `fill_assignment_logit_jvp_rows`.
            let inv_tau = 1.0 / temperature;
            let prior =
                ibp_prior.expect("fill_active_atom_logit_jvp: IBPMap requires precomputed prior");
            let pi_k = prior[k];
            // Guard the division; a zero prior mean yields a zero derivative.
            let sig = if pi_k > 0.0 { a_k / pi_k } else { 0.0 };
            let dz = sig * (1.0 - sig) * inv_tau * pi_k;
            for out_col in 0..p {
                jac_compact[[compact_index, out_col]] = dz * decoded_k[out_col];
            }
        }
        AssignmentMode::ThresholdGate {
            temperature,
            threshold,
        } => {
            // The data-fit Jacobian follows the hard forward gate. Below the
            // threshold the reconstruction contribution is exactly zero, so the
            // data-fit logit derivative must also be zero. Band-only atoms stay
            // in the compact row for prior terms, not phantom reconstruction
            // slope.
            if logit_k <= threshold {
                return;
            }
            let inv_tau = 1.0 / temperature;
            let activation = gam_linalg::utils::stable_logistic((logit_k - threshold) * inv_tau);
            // σ'(x)/τ with σ' = σ(1 − σ).
            let da = activation * (1.0 - activation) * inv_tau;
            for out_col in 0..p {
                jac_compact[[compact_index, out_col]] = da * decoded_k[out_col];
            }
        }
    }
}
1247
/// Fill the dense per-logit JVP rows of one observation's assignment block.
///
/// Row `j` of `local_jac` receives the logit-`j` sensitivity of the row's
/// reconstruction over output columns `0..fitted.len()`. `decoded` holds one
/// decoded output row per atom, and `assignments` holds this row's gate values:
///
/// * Softmax — `a_j (decoded_j − fitted) / τ` for `j in 0..K-1`. The last logit
///   is the fixed reference of the chart and gets no row. Nothing is written
///   when `K == 1`.
/// * IBPMap — `σ(1 − σ)·π_j / τ · decoded_j`, with `σ = a_j / π_j` and
///   `π_j = ibp_prior[j]`.
/// * ThresholdGate — `σ'((l_j − θ)/τ)/τ · decoded_j` for logits strictly above
///   `θ`. Rows at or below the threshold are skipped.
///
/// Ungated atoms are skipped in every mode. Skipped rows are not written, so
/// whatever the caller placed there persists (presumably a pre-zeroed buffer;
/// TODO confirm at call sites). `logits` is read only by the ThresholdGate
/// branch.
///
/// # Panics
/// Panics for `IBPMap` when `ibp_prior` is `None`.
pub(crate) fn fill_assignment_logit_jvp_rows(
    mode: AssignmentMode,
    logits: ArrayView1<'_, f64>,
    assignments: ArrayView1<'_, f64>,
    decoded: ArrayView2<'_, f64>,
    fitted: ArrayView1<'_, f64>,
    ibp_prior: Option<&[f64]>,
    // #1026 — per-atom ungated flags (length `K`). An ungated atom's gate is
    // constant, so its logit-JVP row is identically zero (skipped below). Empty
    // ⇒ no atom is ungated (the historical path, bit-identical).
    ungated: &[bool],
    local_jac: &mut Array2<f64>,
) {
    // Out-of-range indices (e.g. an empty `ungated`) read as "gated".
    let is_ungated = |k: usize| ungated.get(k).copied().unwrap_or(false);
    match mode {
        AssignmentMode::Softmax { temperature, .. } => {
            if assignments.len() == 1 {
                return;
            }
            // da_k/dl_j = a_k (1[k=j] - a_j) / tau, contracted against
            // the assignment-weighted fitted row. The dense row layout uses
            // the reference-logit chart, so only columns `0..K-1` are free;
            // the final reference logit is fixed at zero and has no row.
            let inv_tau = 1.0 / temperature;
            for logit_col in 0..assignments.len() - 1 {
                if is_ungated(logit_col) {
                    continue;
                }
                for out_col in 0..fitted.len() {
                    local_jac[[logit_col, out_col]] = assignments[logit_col]
                        * (decoded[[logit_col, out_col]] - fitted[out_col])
                        * inv_tau;
                }
            }
        }
        AssignmentMode::IBPMap { temperature, .. } => {
            // Truncated-IBP concrete relaxation: z_k = σ(l_k/τ) · π_k where
            // π_k is the stick-breaking prior. Thus
            // dz_k/dl_k = σ(l/τ)(1-σ(l/τ))/τ · π_k = a_k(π_k - a_k)/(π_k τ).
            let inv_tau = 1.0 / temperature;
            let prior = ibp_prior
                .expect("fill_assignment_logit_jvp_rows: IBPMap requires precomputed prior");
            for logit_col in 0..assignments.len() {
                if is_ungated(logit_col) {
                    continue;
                }
                let pi_k = prior[logit_col];
                let a_k = assignments[logit_col];
                // σ recovered from the gate value; a zero prior mean yields zero.
                let sig = if pi_k > 0.0 { a_k / pi_k } else { 0.0 };
                let dz = sig * (1.0 - sig) * inv_tau * pi_k;
                for out_col in 0..fitted.len() {
                    local_jac[[logit_col, out_col]] = dz * decoded[[logit_col, out_col]];
                }
            }
        }
        AssignmentMode::ThresholdGate {
            temperature,
            threshold,
        } => {
            // Data-fit sensitivity follows the hard forward gate: rows at or
            // below the threshold have zero reconstruction value and therefore
            // zero data-fit logit derivative. The wider machine-precision prior
            // support is a compact-layout/prior rule, not a data-fit STE.
            let inv_tau = 1.0 / temperature;
            for logit_col in 0..assignments.len() {
                if is_ungated(logit_col) || logits[logit_col] <= threshold {
                    continue;
                }
                let activation = gam_linalg::utils::stable_logistic(
                    (logits[logit_col] - threshold) * inv_tau,
                );
                // σ'(x)/τ with σ' = σ(1 − σ).
                let da = activation * (1.0 - activation) * inv_tau;
                for out_col in 0..fitted.len() {
                    local_jac[[logit_col, out_col]] = da * decoded[[logit_col, out_col]];
                }
            }
        }
    }
}
1327
1328pub(crate) fn flat_logits(logits: ArrayView2<'_, f64>) -> Array1<f64> {
1329 let mut out = Array1::<f64>::zeros(logits.len());
1330 for row in 0..logits.nrows() {
1331 let start = row * logits.ncols();
1332 for col in 0..logits.ncols() {
1333 out[start + col] = logits[[row, col]];
1334 }
1335 }
1336 out
1337}
1338
1339pub(crate) fn assignment_prior_value(assignment: &SaeAssignment, rho: &SaeManifoldRho) -> f64 {
1340 for row in 0..assignment.n_obs() {
1341 validate_finite_logits(assignment.logits.row(row), row)
1342 .expect("assignment logits must be finite");
1343 }
1344 let target = flat_logits(assignment.logits.view());
1345 if matches!(assignment.mode, AssignmentMode::Softmax { .. }) && assignment.k_atoms() == 1 {
1346 return 0.0;
1347 }
1348 match assignment.mode {
1349 AssignmentMode::Softmax {
1350 temperature,
1351 sparsity,
1352 } => {
1353 let penalty = SoftmaxAssignmentSparsityPenalty::new(assignment.k_atoms(), temperature);
1354 let rho_view = Array1::from_vec(vec![rho.log_lambda_sparse + sparsity.ln()]);
1355 penalty.value(target.view(), rho_view.view())
1356 }
1357 AssignmentMode::IBPMap {
1358 temperature,
1359 alpha,
1360 learnable_alpha,
1361 } => {
1362 let mut penalty = IBPAssignmentPenalty::new(
1363 assignment.k_atoms(),
1364 alpha,
1365 temperature,
1366 learnable_alpha,
1367 );
1368 let rho_view = if learnable_alpha {
1369 Array1::from_vec(vec![rho.log_lambda_sparse])
1370 } else {
1371 // Keep the fixed-alpha value path on the same weighting branch as
1372 // assignment_prior_grad_hdiag; that gradient path owns the
1373 // lambda_sparse convention for IBP assignment sparsity.
1374 penalty.weight = rho.lambda_sparse();
1375 Array1::zeros(0)
1376 };
1377 penalty.value(target.view(), rho_view.view())
1378 }
1379 AssignmentMode::ThresholdGate {
1380 temperature,
1381 threshold,
1382 } => {
1383 // Sparsity penalty uses the same threshold-centered surrogate and
1384 // machine-precision support as its gradient/Hessian. Data-fit
1385 // reconstruction remains hard-gated by `jumprelu_row`.
1386 let sparsity_strength = rho.lambda_sparse();
1387 let mut acc = 0.0;
1388 for &logit in target.iter() {
1389 if jumprelu_in_optimization_band(logit, threshold, temperature) {
1390 acc += gam_linalg::utils::stable_logistic((logit - threshold) / temperature);
1391 }
1392 }
1393 sparsity_strength * acc
1394 }
1395 }
1396}
1397
1398pub(crate) fn assignment_prior_log_strength_derivative(
1399 assignment: &SaeAssignment,
1400 rho: &SaeManifoldRho,
1401) -> f64 {
1402 for row in 0..assignment.n_obs() {
1403 validate_finite_logits(assignment.logits.row(row), row)
1404 .expect("assignment logits must be finite");
1405 }
1406 let target = flat_logits(assignment.logits.view());
1407 if matches!(assignment.mode, AssignmentMode::Softmax { .. }) && assignment.k_atoms() == 1 {
1408 return 0.0;
1409 }
1410 match assignment.mode {
1411 AssignmentMode::Softmax { .. } | AssignmentMode::ThresholdGate { .. } => {
1412 assignment_prior_value(assignment, rho)
1413 }
1414 AssignmentMode::IBPMap {
1415 temperature,
1416 alpha,
1417 learnable_alpha,
1418 } => {
1419 let mut penalty = IBPAssignmentPenalty::new(
1420 assignment.k_atoms(),
1421 alpha,
1422 temperature,
1423 learnable_alpha,
1424 );
1425 if learnable_alpha {
1426 let rho_view = Array1::from_vec(vec![rho.log_lambda_sparse]);
1427 penalty.grad_rho(target.view(), rho_view.view())[0]
1428 } else {
1429 penalty.weight = rho.lambda_sparse();
1430 penalty.value(target.view(), Array1::<f64>::zeros(0).view())
1431 }
1432 }
1433 }
1434}
1435
1436pub(crate) fn assignment_prior_log_strength_hdiag(
1437 assignment: &SaeAssignment,
1438 rho: &SaeManifoldRho,
1439) -> Result<Array1<f64>, String> {
1440 for row in 0..assignment.n_obs() {
1441 validate_finite_logits(assignment.logits.row(row), row)?;
1442 }
1443 let target = flat_logits(assignment.logits.view());
1444 if matches!(assignment.mode, AssignmentMode::Softmax { .. }) && assignment.k_atoms() == 1 {
1445 return Ok(Array1::<f64>::zeros(target.len()));
1446 }
1447 match assignment.mode {
1448 AssignmentMode::Softmax {
1449 temperature,
1450 sparsity,
1451 } => {
1452 let penalty = SoftmaxAssignmentSparsityPenalty::new(assignment.k_atoms(), temperature);
1453 let rho_view = Array1::from_vec(vec![rho.log_lambda_sparse + sparsity.ln()]);
1454 penalty
1455 .hessian_diag(target.view(), rho_view.view())
1456 .ok_or_else(|| {
1457 "softmax assignment log-strength hessian diag unavailable".to_string()
1458 })
1459 }
1460 AssignmentMode::ThresholdGate {
1461 temperature,
1462 threshold,
1463 } => {
1464 let sparsity_strength = rho.lambda_sparse();
1465 let inv_tau = 1.0 / temperature;
1466 let inv_tau2 = inv_tau * inv_tau;
1467 let mut d = Array1::<f64>::zeros(target.len());
1468 for idx in 0..target.len() {
1469 let logit = target[idx];
1470 if !jumprelu_in_optimization_band(logit, threshold, temperature) {
1471 continue;
1472 }
1473 let activation =
1474 gam_linalg::utils::stable_logistic((logit - threshold) * inv_tau);
1475 let slope = activation * (1.0 - activation);
1476 d[idx] = sparsity_strength * slope * (1.0 - 2.0 * activation) * inv_tau2;
1477 }
1478 Ok(d)
1479 }
1480 AssignmentMode::IBPMap {
1481 temperature,
1482 alpha,
1483 learnable_alpha,
1484 } => {
1485 let mut penalty = IBPAssignmentPenalty::new(
1486 assignment.k_atoms(),
1487 alpha,
1488 temperature,
1489 learnable_alpha,
1490 );
1491 if learnable_alpha {
1492 let rho_view = Array1::from_vec(vec![rho.log_lambda_sparse]);
1493 Ok(penalty.hessian_diag_log_alpha_derivative(target.view(), rho_view.view()))
1494 } else {
1495 penalty.weight = rho.lambda_sparse();
1496 penalty
1497 .hessian_diag(target.view(), Array1::<f64>::zeros(0).view())
1498 .ok_or_else(|| {
1499 "IBP assignment log-strength hessian diag unavailable".to_string()
1500 })
1501 }
1502 }
1503 }
1504}
1505
1506pub(crate) fn assignment_prior_log_strength_target_mixed(
1507 assignment: &SaeAssignment,
1508 rho: &SaeManifoldRho,
1509) -> Result<Array1<f64>, String> {
1510 for row in 0..assignment.n_obs() {
1511 validate_finite_logits(assignment.logits.row(row), row)?;
1512 }
1513 let target = flat_logits(assignment.logits.view());
1514 if matches!(assignment.mode, AssignmentMode::Softmax { .. }) && assignment.k_atoms() == 1 {
1515 return Ok(Array1::<f64>::zeros(target.len()));
1516 }
1517 match assignment.mode {
1518 AssignmentMode::IBPMap {
1519 temperature,
1520 alpha,
1521 learnable_alpha: true,
1522 } => {
1523 let penalty = IBPAssignmentPenalty::new(assignment.k_atoms(), alpha, temperature, true);
1524 let rho_view = Array1::from_vec(vec![rho.log_lambda_sparse]);
1525 Ok(penalty.log_alpha_target_mixed_derivative(target.view(), rho_view.view()))
1526 }
1527 _ => Ok(assignment_prior_grad_hdiag(assignment, rho)?.0),
1528 }
1529}
1530
1531pub(crate) fn assignment_prior_grad_hdiag(
1532 assignment: &SaeAssignment,
1533 rho: &SaeManifoldRho,
1534) -> Result<(Array1<f64>, Array1<f64>), String> {
1535 for row in 0..assignment.n_obs() {
1536 validate_finite_logits(assignment.logits.row(row), row)?;
1537 }
1538 let target = flat_logits(assignment.logits.view());
1539 let mut grad = Array1::<f64>::zeros(target.len());
1540 let mut diag = Array1::<f64>::zeros(target.len());
1541 if matches!(assignment.mode, AssignmentMode::Softmax { .. }) && assignment.k_atoms() == 1 {
1542 return Ok((grad, diag));
1543 }
1544 let (sparsity_grad, sparsity_diag) = match assignment.mode {
1545 AssignmentMode::Softmax {
1546 temperature,
1547 sparsity,
1548 } => {
1549 let penalty = SoftmaxAssignmentSparsityPenalty::new(assignment.k_atoms(), temperature);
1550 let rho_view = Array1::from_vec(vec![rho.log_lambda_sparse + sparsity.ln()]);
1551 let g = penalty.grad_target(target.view(), rho_view.view());
1552 let d = penalty
1553 .hessian_diag(target.view(), rho_view.view())
1554 .ok_or_else(|| "softmax assignment hessian diag unavailable".to_string())?;
1555 (g, d)
1556 }
1557 AssignmentMode::IBPMap {
1558 temperature,
1559 alpha,
1560 learnable_alpha,
1561 } => {
1562 // Scale the IBP assignment-sparsity prior by `lambda_sparse`, exactly
1563 // like the Softmax and JumpReLU branches do (Softmax folds it into the
1564 // penalty's rho coordinate, JumpReLU multiplies `sparsity_strength`).
1565 // Previously the IBP penalty used its hardcoded `weight = 1.0` and the
1566 // `rho.log_lambda_sparse` coordinate never reached it (the rho_view was
1567 // empty for the common `learnable_alpha = false` config), so the prior
1568 // ran at full strength with no way to dial it down — and its
1569 // Beta-Bernoulli BCE energy `−mass·ln π_k − (n−mass)·ln(1−π_k)` toward
1570 // the self-referential empirical active fraction `π_k` has its global
1571 // minimum at the all-off gate, so at full weight it over-shrank the
1572 // assignment off both atoms even with a truth-seeded decoder (#853).
1573 // Routing `lambda_sparse` into the penalty weight makes the prior a
1574 // genuine, user-controllable lever balanced against the data fit.
1575 let mut penalty = IBPAssignmentPenalty::new(
1576 assignment.k_atoms(),
1577 alpha,
1578 temperature,
1579 learnable_alpha,
1580 );
1581 // When `alpha` is learnable, `log_lambda_sparse` already modulates
1582 // it through `resolved_alpha(rho)`, so the weight stays 1.0 to avoid
1583 // double-counting that coordinate. Only when `alpha` is fixed (so the
1584 // sparse coordinate would otherwise be ignored entirely) does
1585 // `lambda_sparse` become the prior's weight lever.
1586 let rho_view = if learnable_alpha {
1587 Array1::from_vec(vec![rho.log_lambda_sparse])
1588 } else {
1589 penalty.weight = rho.lambda_sparse();
1590 Array1::zeros(0)
1591 };
1592 let g = penalty.grad_target(target.view(), rho_view.view());
1593 let d = penalty
1594 .hessian_diag(target.view(), rho_view.view())
1595 .ok_or_else(|| "IBP assignment hessian diag unavailable".to_string())?;
1596 (g, d)
1597 }
1598 AssignmentMode::ThresholdGate {
1599 temperature,
1600 threshold,
1601 } => {
1602 // Gradient and exact diagonal Hessian of the sparsity value's
1603 // threshold-centered surrogate σ((l−θ)/τ), using the same
1604 // machine-precision support as the value path. Data-fit JVP support
1605 // is narrower and follows the hard forward gate.
1606 let sparsity_strength = rho.lambda_sparse();
1607 let inv_tau = 1.0 / temperature;
1608 let inv_tau2 = inv_tau * inv_tau;
1609 let mut g = Array1::<f64>::zeros(target.len());
1610 let mut d = Array1::<f64>::zeros(target.len());
1611 for idx in 0..target.len() {
1612 let logit = target[idx];
1613 if !jumprelu_in_optimization_band(logit, threshold, temperature) {
1614 continue;
1615 }
1616 let activation =
1617 gam_linalg::utils::stable_logistic((logit - threshold) * inv_tau);
1618 let slope = activation * (1.0 - activation);
1619 g[idx] = sparsity_strength * slope * inv_tau;
1620 d[idx] = sparsity_strength * slope * (1.0 - 2.0 * activation) * inv_tau2;
1621 }
1622 (g, d)
1623 }
1624 };
1625 grad += &sparsity_grad;
1626 diag += &sparsity_diag;
1627 // #1026/#1033 — a FIXED logit (an ungated atom's, or every atom's under
1628 // frozen routing) is not a free parameter, so it carries NO sparsity-prior
1629 // gradient or curvature. Zero its flat columns (`flat_logits` is row-major
1630 // `row*K + atom`) so the assembled `gt` and `htt` logit slots stay zero —
1631 // matching the zero logit-JVP. The column-separable IBP / JumpReLU priors are
1632 // per-atom, so zeroing one atom's columns leaves the others' prior intact;
1633 // under frozen routing ALL atoms' logit columns are zeroed (the whole routing
1634 // is a fixed predicted function, not optimized).
1635 if assignment.has_ungated() || assignment.routing_is_frozen() {
1636 let k = assignment.k_atoms();
1637 for idx in 0..grad.len() {
1638 if assignment.logit_is_fixed(idx % k) {
1639 grad[idx] = 0.0;
1640 diag[idx] = 0.0;
1641 }
1642 }
1643 }
1644 Ok((grad, diag))
1645}
1646
1647/// Build the exact IBP `hessian_diag` logit third-derivative channels (#1006)
1648/// for the SAE log-det adjoint Γ, using the SAME penalty configuration —
1649/// `alpha`/`tau`/`learnable_alpha` and the `lambda_sparse` weight convention —
1650/// that [`assignment_prior_grad_hdiag`] assembles into `htt`. Returns `None`
1651/// for non-IBP assignment modes (no cross-row empirical-π coupling to correct).
1652pub(crate) fn ibp_assignment_third_channels(
1653 assignment: &SaeAssignment,
1654 rho: &SaeManifoldRho,
1655) -> Result<Option<IbpHessianDiagThirdChannels>, String> {
1656 let AssignmentMode::IBPMap {
1657 temperature,
1658 alpha,
1659 learnable_alpha,
1660 } = assignment.mode
1661 else {
1662 return Ok(None);
1663 };
1664 for row in 0..assignment.n_obs() {
1665 validate_finite_logits(assignment.logits.row(row), row)?;
1666 }
1667 let target = flat_logits(assignment.logits.view());
1668 let mut penalty =
1669 IBPAssignmentPenalty::new(assignment.k_atoms(), alpha, temperature, learnable_alpha);
1670 // Mirror assignment_prior_grad_hdiag exactly: when alpha is learnable the
1671 // sparse coordinate already modulates it through resolved_alpha(rho), so the
1672 // weight stays 1.0; otherwise lambda_sparse becomes the prior's weight lever.
1673 let rho_view = if learnable_alpha {
1674 Array1::from_vec(vec![rho.log_lambda_sparse])
1675 } else {
1676 penalty.weight = rho.lambda_sparse();
1677 Array1::zeros(0)
1678 };
1679 let mut channels = penalty.hessian_diag_logit_third_channels(target.view(), rho_view.view());
1680 // #1026/#1033 — zero the log-det third-derivative channels of FIXED-logit
1681 // atoms (ungated, or all atoms under frozen routing) so the #1006 θ-adjoint
1682 // differentiates the SAME (fixed-logit-zeroed) `htt` that
1683 // `assignment_prior_grad_hdiag` assembled. `k_max` columns, row-major `N·K`
1684 // for the per-(row,atom) arrays and length-`K` for the per-column ones.
1685 if assignment.has_ungated() || assignment.routing_is_frozen() {
1686 let k = channels.k_max;
1687 for idx in 0..channels.z_jac.len() {
1688 if assignment.logit_is_fixed(idx % k) {
1689 channels.z_jac[idx] = 0.0;
1690 channels.local_logit_third[idx] = 0.0;
1691 channels.m_channel[idx] = 0.0;
1692 channels.logit_curvature[idx] = 0.0;
1693 }
1694 }
1695 for atom in 0..k {
1696 if assignment.logit_is_fixed(atom) {
1697 channels.cross_row_d[atom] = 0.0;
1698 channels.cross_row_dd[atom] = 0.0;
1699 }
1700 }
1701 }
1702 Ok(Some(channels))
1703}
1704
1705/// #1026 hybrid curved + linear-tail adjudication for one SAE atom slot.
1706///
1707/// A hybrid dictionary lets each atom slot be either a CURVED atom (its fitted
1708/// `latent_dim ≥ 1` manifold chart, whose decoded image may turn) or its LINEAR
1709/// special case (the euclidean-d=1-linear atom — one straight decoder direction,
1710/// `γ(t) = t·b`, zero turning). The two are nested: the linear atom is exactly
1711/// the curved family restricted to its straight sub-model, so a hybrid slot
1712/// cannot lose to pure-linear at matched actives — it strictly generalizes it.
1713///
1714/// This is the single call the SAE fitter makes per atom to choose the split by
1715/// EVIDENCE rather than fiat. It packages the atom's two already-fitted
1716/// candidates — each scored on the COMMON rank-aware Laplace scale (`−V = NLE`,
1717/// lower wins, identical to the union/mixture rungs) on the same rows — and
1718/// routes them through [`select_hybrid_atom`]. The curved candidate's fitted
1719/// turning `Θ` (from
1720/// [`crate::chart_canonicalization::d1_atom_fitted_turning`]) enters
1721/// as the decision feature: a `Θ → 0` atom yields to the cheaper linear tail by
1722/// construction (the dominance floor — a curved atom buys nothing on a straight
1723/// feature), a high-`Θ` atom takes the curved parameterization when its
1724/// curvature lowers the NLE by more than its extra-parameter price (the `Θ/√ε`
1725/// crossover).
1726///
1727/// `manifold` is the atom's fitted chart manifold; a non-curveable (already
1728/// Euclidean-flat) chart can only present the linear candidate, which this
1729/// helper enforces by ignoring any curved candidate offered for a flat chart —
1730/// a flat chart has no curvature to price, so the linear special case is its
1731/// only honest parameterization. Curveable charts present both candidates.
1732///
1733/// # Wiring into the fitter (the one call into `sae_manifold.rs`)
1734///
1735/// The post-fit pass in `sae_manifold.rs` already computes each d=1 atom's
1736/// fitted turning `Θ` (the read-only EV-vs-Θ diagnostic). To make the split
1737/// load-bearing, that pass supplies, per atom, the curved-candidate NLE +
1738/// parameter count + `Θ` and the linear-candidate NLE + parameter count (both
1739/// fitted on the atom's rows), and calls this helper; the returned
1740/// [`HybridAtomChoice`] tells the fitter which parameterization to keep for that
1741/// slot. The fitting of the two candidates lives in `sae_manifold.rs` (the
1742/// manifold-chart fitter); the SELECTION/scoring lives here.
1743pub fn select_hybrid_atom_parameterization(
1744 manifold: &LatentManifold,
1745 curved: Option<HybridAtomCandidate>,
1746 linear: HybridAtomCandidate,
1747) -> HybridAtomChoice {
1748 // A flat (Euclidean) chart has no curvature to price: its only honest
1749 // parameterization is the linear special case, so any curved candidate
1750 // offered for it is dropped before the evidence comparison. Curveable charts
1751 // (Circle / Sphere / Torus / curved products) present both candidates.
1752 let curved = if manifold.is_euclidean() {
1753 None
1754 } else {
1755 curved
1756 };
1757 let candidates: Vec<HybridAtomCandidate> = match curved {
1758 Some(c) => vec![linear, c],
1759 None => vec![linear],
1760 };
1761 // `candidates` is never empty (it always contains the linear candidate), so
1762 // the selector always returns a choice.
1763 select_hybrid_atom(&candidates).expect("hybrid atom slot always has the linear candidate")
1764}
1765
#[cfg(test)]
mod ibp_prior_614_tests {
    // #614 regression guard.
    //
    // `ibp_stick_breaking_prior` once returned `π_k = (α/(α+1))^k` with `π_0 = 1`.
    // That left the first atom UNSHRUNK, the prior mean of zero sticks, and
    // broke α's role as an IBP concentration parameter.
    //
    // The truncated-IBP stick-breaking prior mean is `π_k = (α/(α+1))^{k+1}`:
    // the expectation of a product of (k+1) i.i.d. Beta(α,1) stick means. Every
    // atom, the first included, therefore carries one stick of shrinkage. These
    // tests pin that contract.
    use super::*;

    /// Mean of a single Beta(α, 1) stick: α/(α+1).
    fn single_stick_mean(alpha: f64) -> f64 {
        alpha / (alpha + 1.0)
    }

    #[test]
    fn first_atom_is_shrunk_not_unity() {
        // π_0 must be the single-stick mean, never the unshrunk 1.0 of #614.
        for alpha in [0.1_f64, 0.5, 1.0, 2.0, 5.0] {
            let prior = ordered_geometric_shrinkage_prior(8, alpha);
            let r = single_stick_mean(alpha);
            let head = prior[0];
            assert!(
                (head - r).abs() < 1e-12,
                "π_0 must be the single-stick mean α/(α+1)={r} (was the unshrunk 1.0 in #614); got {}",
                head
            );
            assert!(
                head < 1.0,
                "first atom must be shrunk (π_0<1) for alpha={alpha}; got {}",
                head
            );
        }
    }

    #[test]
    fn prior_is_consistent_geometric_product_mean() {
        // π_k equals (α/(α+1))^{k+1} exactly, and consecutive entries differ by
        // the factor α/(α+1).
        for alpha in [0.3_f64, 1.0, 4.0] {
            let len = 12;
            let prior = ordered_geometric_shrinkage_prior(len, alpha);
            let r = single_stick_mean(alpha);
            for j in 0..len {
                let expected = r.powi((j + 1) as i32);
                let tolerance = 1e-12 * expected.max(1.0);
                assert!(
                    (prior[j] - expected).abs() < tolerance,
                    "alpha={alpha} π_{j}: expected {expected}, got {}",
                    prior[j]
                );
            }
            // Ordered shrinkage: strictly decreasing, with no plateau at the head.
            for j in 1..len {
                assert!(
                    prior[j] < prior[j - 1],
                    "alpha={alpha}: prior must strictly decrease at index {j}"
                );
            }
        }
    }

    #[test]
    fn alpha_behaves_as_concentration() {
        // Concentration role restored by #614: a larger α decays more slowly.
        // π_0 rises toward 1, and the tail (π_4 here) keeps more mass.
        let (lo, hi) = (
            ordered_geometric_shrinkage_prior(8, 0.5),
            ordered_geometric_shrinkage_prior(8, 5.0),
        );
        assert!(
            hi[0] > lo[0],
            "larger alpha must raise π_0 (concentration): {} vs {}",
            hi[0],
            lo[0]
        );
        assert!(
            hi[4] > lo[4],
            "larger alpha must put more mass in the tail: {} vs {}",
            hi[4],
            lo[4]
        );
    }
}
1846
#[cfg(test)]
mod hybrid_split_tests {
    //! Behaviour of `select_hybrid_atom_parameterization` on flat charts and on
    //! curveable charts.
    use super::*;
    use gam_solve::evidence::HybridAtomParam;

    /// A full-turn circle chart: a curveable manifold that presents both candidates.
    fn full_turn_circle() -> LatentManifold {
        LatentManifold::Circle {
            period: 2.0 * std::f64::consts::PI,
        }
    }

    #[test]
    fn flat_chart_drops_curved_candidate_and_keeps_linear() {
        // A Euclidean chart has no curvature. A curved offer is discarded even
        // when its NLE is lower, since a flat chart cannot honestly present it.
        let choice = select_hybrid_atom_parameterization(
            &LatentManifold::Euclidean,
            Some(HybridAtomCandidate::curved(1, 1.0, 5, Some(2.0))),
            HybridAtomCandidate::linear(100.0, 2),
        );
        assert!(choice.param.is_linear());
    }

    #[test]
    fn curveable_chart_selects_curved_when_turning_pays() {
        // The turning feature's curved fit beats the linear secant on evidence,
        // so the circle chart keeps the curved parameterization.
        let turning_curved =
            HybridAtomCandidate::curved(1, 70.0, 5, Some(2.0 * std::f64::consts::PI));
        let choice = select_hybrid_atom_parameterization(
            &full_turn_circle(),
            Some(turning_curved),
            HybridAtomCandidate::linear(100.0, 2),
        );
        assert_eq!(choice.param, HybridAtomParam::Curved { latent_dim: 1 });
    }

    #[test]
    fn curveable_chart_falls_back_to_linear_when_no_curved_candidate() {
        let choice = select_hybrid_atom_parameterization(
            &full_turn_circle(),
            None,
            HybridAtomCandidate::linear(33.0, 2),
        );
        assert!(choice.param.is_linear());
        assert_eq!(choice.num_parameters, 2);
    }
}
1894
#[cfg(test)]
mod frozen_routing_1033_tests {
    //! #1033: invariants of FROZEN (amortized) routing.
    //!
    //! After `freeze_routing_from_current_logits`, each row's gate is computed
    //! from a snapshot of the predicted logits. It must therefore:
    //! - ignore later edits to the free `self.logits` (the inner-fit drift an
    //!   outer ρ-search would otherwise reintroduce on every evaluation);
    //! - be independent of ρ.
    //!
    //! Every check is a deterministic property of the mechanism and no inner fit
    //! runs, so these tests pin the freeze contract without the cluster.
    use super::*;

    /// Small IBP assignment with distinct per-(row, atom) logits and one 1-D
    /// chart per atom. `learnable_alpha = false` keeps alpha ρ-independent, so
    /// only the routing is under test.
    fn ibp_assignment(n: usize, k: usize) -> SaeAssignment {
        let logits =
            Array2::from_shape_fn((n, k), |(i, kk)| 0.3 + 0.05 * (i as f64) - 0.1 * (kk as f64));
        let coords: Vec<Array2<f64>> =
            std::iter::repeat_with(|| Array2::from_shape_fn((n, 1), |(i, _)| (i as f64) * 0.1))
                .take(k)
                .collect();
        SaeAssignment::from_blocks_with_mode(logits, coords, AssignmentMode::ibp_map(0.5, 1.0, false))
            .unwrap()
    }

    /// Every row's gate vector, evaluated at `rho`.
    fn gates_at(a: &SaeAssignment, rho: &SaeManifoldRho) -> Vec<Array1<f64>> {
        (0..a.n_obs())
            .map(|row| a.try_assignments_row_for_rho(row, rho).unwrap())
            .collect()
    }

    #[test]
    fn frozen_routing_decouples_gates_from_logit_updates_1033() {
        let (n, k) = (6usize, 3usize);
        let mut a = ibp_assignment(n, k).freeze_routing_from_current_logits().unwrap();
        assert!(a.routing_is_frozen());
        let rho = SaeManifoldRho::new(0.0, 0.0, vec![Array1::<f64>::zeros(1); k]);
        let before = gates_at(&a, &rho);
        // Stand-in for an inner-fit logit update: shift every free logit a lot.
        a.logits.mapv_inplace(|v| v + 5.0);
        let after = gates_at(&a, &rho);
        // The frozen gates read the snapshot, so the shift must not move them.
        for (r, (gate_before, gate_after)) in before.iter().zip(&after).enumerate() {
            for kk in 0..k {
                assert_eq!(
                    gate_before[kk], gate_after[kk],
                    "row {r} atom {kk}: frozen-routing gate must be UNCHANGED by a free-logit \
                     update (decoupled from inner-fit drift); {} vs {}",
                    gate_before[kk], gate_after[kk]
                );
            }
        }
    }

    #[test]
    fn frozen_routing_gates_are_rho_invariant_1033() {
        let (n, k) = (5usize, 2usize);
        let a = ibp_assignment(n, k).freeze_routing_from_current_logits().unwrap();
        // The two ρ differ by six decades in sparse strength and three in smooth
        // strength. With frozen routing and a fixed alpha, the gates must match.
        let rho_a = SaeManifoldRho::new(
            (1e-3_f64).ln(),
            (1e-2_f64).ln(),
            vec![Array1::<f64>::zeros(1); k],
        );
        let rho_b = SaeManifoldRho::new(
            (1e3_f64).ln(),
            (1e1_f64).ln(),
            vec![Array1::<f64>::zeros(1); k],
        );
        for r in 0..n {
            let ga = a.try_assignments_row_for_rho(r, &rho_a).unwrap();
            let gb = a.try_assignments_row_for_rho(r, &rho_b).unwrap();
            for kk in 0..k {
                assert_eq!(
                    ga[kk], gb[kk],
                    "row {r} atom {kk}: frozen-routing gate must be ρ-INVARIANT (the n-independence \
                     lever); {} at ρ_a vs {} at ρ_b",
                    ga[kk], gb[kk]
                );
            }
        }
    }

    #[test]
    fn frozen_routing_fixes_all_logits_and_thaw_restores_free_path_1033() {
        let (n, k) = (4usize, 3usize);
        let mut a = ibp_assignment(n, k).freeze_routing_from_current_logits().unwrap();
        // Frozen: no logit is a free Newton coordinate.
        let mask = a.fixed_logit_mask();
        assert_eq!(mask.len(), k);
        assert!(mask.iter().all(|&f| f), "frozen routing must fix ALL logits");
        (0..k).for_each(|kk| {
            assert!(
                a.logit_is_fixed(kk),
                "atom {kk} logit must be fixed under frozen routing"
            )
        });
        // Thawed: the free-logit path is back (nothing fixed, nothing ungated).
        a.thaw_routing();
        assert!(!a.routing_is_frozen());
        assert!(
            a.fixed_logit_mask().iter().all(|&f| !f),
            "thaw must restore the free-logit path"
        );
    }

    #[test]
    fn frozen_routing_rejects_softmax_1033() {
        let (n, k) = (4usize, 3usize);
        let logits =
            Array2::from_shape_fn((n, k), |(i, kk)| 0.1 * (i as f64) - 0.05 * (kk as f64));
        let coords: Vec<Array2<f64>> =
            std::iter::repeat_with(|| Array2::from_shape_fn((n, 1), |(i, _)| (i as f64) * 0.1))
                .take(k)
                .collect();
        let a = SaeAssignment::from_blocks_with_mode(logits, coords, AssignmentMode::softmax(1.0))
            .unwrap();
        // Softmax's coupled-simplex entropy majorizer assumes an optimized
        // routing, so freezing it must be refused.
        let frozen = a.freeze_routing_from_current_logits();
        assert!(
            frozen.is_err(),
            "frozen routing under Softmax must be rejected (simplex entropy-majorizer coupling)"
        );
    }
}
2030
#[cfg(test)]
mod fill_into_buffer_1557_tests {
    //! #1557: the fill-into-caller-buffer variant
    //! [`SaeAssignment::try_assignments_row_for_rho_into`] must produce output
    //! BIT-IDENTICAL to the allocating
    //! [`SaeAssignment::try_assignments_row_for_rho`].
    //!
    //! Coverage: every assignment mode (Softmax, IBPMap, JumpReLU), the #1026
    //! ungated case, and the K==1 edge.
    //!
    //! Entries are compared with `f64::to_bits`, not an approximate tolerance and
    //! not f64 `==`. Plain `==` would accept a `-0.0`/`+0.0` drift and reject two
    //! identical NaNs. The `_into` path is a pure allocation-elision refactor, so
    //! ANY bit-level difference is a regression.
    use super::*;

    fn build(n: usize, k: usize, mode: AssignmentMode) -> SaeAssignment {
        // Deterministic, asymmetric logits and coords give every atom a distinct
        // value, so accidental ties cannot mask an index bug.
        let logits = Array2::from_shape_fn((n, k), |(i, kk)| {
            0.37 + 0.11 * (i as f64) - 0.23 * (kk as f64)
        });
        let coords: Vec<Array2<f64>> = (0..k)
            .map(|_| Array2::from_shape_fn((n, 1), |(i, _)| 0.1 + 0.05 * (i as f64)))
            .collect();
        SaeAssignment::from_blocks_with_mode(logits, coords, mode).unwrap()
    }

    fn rho(k: usize) -> SaeManifoldRho {
        SaeManifoldRho::new(
            (1e-2_f64).ln(),
            (1e-1_f64).ln(),
            vec![Array1::<f64>::zeros(1); k],
        )
    }

    /// Asserts that every row of `a` fills the caller buffer with exactly the
    /// bits the allocating path returns.
    fn assert_into_matches_alloc(a: &SaeAssignment) {
        let n = a.n_obs();
        let k = a.k_atoms();
        let rho = rho(k);
        let mut scratch = vec![f64::NAN; k];
        for row in 0..n {
            let allocated = a.try_assignments_row_for_rho(row, &rho).unwrap();
            // Pre-fill with NaN. A partial write, such as a JumpReLU
            // below-threshold entry left untouched, then shows up as a mismatch
            // instead of passing silently.
            scratch.fill(f64::NAN);
            a.try_assignments_row_for_rho_into(row, &rho, &mut scratch)
                .unwrap();
            assert_eq!(allocated.len(), k);
            for kk in 0..k {
                // Bitwise comparison distinguishes ±0.0 and NaN payloads.
                assert_eq!(
                    allocated[kk].to_bits(),
                    scratch[kk].to_bits(),
                    "row {row} atom {kk}: _into must be BIT-IDENTICAL to the allocating \
                     try_assignments_row_for_rho; got {} vs {}",
                    allocated[kk],
                    scratch[kk]
                );
            }
        }
    }

    #[test]
    fn softmax_into_is_bit_identical() {
        assert_into_matches_alloc(&build(7, 4, AssignmentMode::softmax(0.8)));
    }

    #[test]
    fn ibp_map_into_is_bit_identical() {
        // Both learnable and fixed alpha exercise the resolved-alpha branch.
        assert_into_matches_alloc(&build(7, 5, AssignmentMode::ibp_map(0.6, 1.3, false)));
        assert_into_matches_alloc(&build(7, 5, AssignmentMode::ibp_map(0.6, 1.3, true)));
    }

    #[test]
    fn jumprelu_into_is_bit_identical() {
        // The threshold puts SOME atoms below it (the untouched-entry path) and
        // lets others clear it (the sigmoid path), so both branches run.
        assert_into_matches_alloc(&build(7, 5, AssignmentMode::jumprelu(0.9, 0.2)));
    }

    #[test]
    fn ungated_into_is_bit_identical() {
        // #1026 ungated overwrite under the gate-style modes that allow it
        // (IBP and JumpReLU).
        let a = build(6, 4, AssignmentMode::ibp_map(0.6, 1.1, false))
            .with_ungated(vec![false, true, false, true])
            .unwrap();
        assert_into_matches_alloc(&a);
        let j = build(6, 4, AssignmentMode::jumprelu(0.9, 0.15))
            .with_ungated(vec![true, false, true, false])
            .unwrap();
        assert_into_matches_alloc(&j);
    }

    #[test]
    fn k_equals_one_into_is_bit_identical() {
        // Softmax with K==1 takes the fixed-unit early return. IBP and JumpReLU
        // with K==1 keep a free per-atom gate and fall through to the real row
        // functions.
        assert_into_matches_alloc(&build(5, 1, AssignmentMode::softmax(1.0)));
        assert_into_matches_alloc(&build(5, 1, AssignmentMode::ibp_map(0.7, 1.0, false)));
        assert_into_matches_alloc(&build(5, 1, AssignmentMode::jumprelu(0.8, 0.1)));
    }
}