gam_sae/assignment.rs
1//! Assignment gates and sparsity-prior helpers for the SAE manifold term.
2//! Mechanically split from `sae_manifold.rs`.
3
4use ndarray::{Array1, Array2, ArrayView1, ArrayView2};
5
6use gam_solve::evidence::{HybridAtomCandidate, HybridAtomChoice, select_hybrid_atom};
7use gam_terms::analytic_penalties::{
8 AnalyticPenalty, IBPAssignmentPenalty, IbpHessianDiagThirdChannels,
9 SoftmaxAssignmentSparsityPenalty, resolve_learnable_weight,
10};
11use gam_terms::latent::{LatentCoordValues, LatentIdMode, LatentManifold};
12use crate::manifold::SaeManifoldRho;
13
/// Layer-1 guard from #976: the largest assignment-logit move one accepted
/// iteration may make, expressed in multiples of the gate temperature τ. τ is
/// the gate's natural length scale because every assignment mode reads its
/// logits through `σ(·/τ)` or `softmax(·/τ)`.
///
/// A move of 4τ already spans the gate's whole soft range, so healthy
/// convergence is never throttled. What the cap rules out is a single inner
/// iteration carrying a gate from contention to numerically-zero support: a
/// collapse then takes multiple accepted iterations, which guarantees the
/// per-iteration active-mass guard observes the decay before it completes.
///
/// The clamp is applied where the step is realised. When it binds, the
/// realised objective is evaluated on the clamped state, so the Armijo
/// comparison stays value-consistent (the unclamped quadratic model is merely
/// conservative, and step halvings shrink the trial below the cap).
pub(crate) const SAE_ASSIGNMENT_LOGIT_STEP_CAP_TAUS: f64 = 4.0;
27
/// Layer-1 guard from #976: how many times a single atom may be re-seeded
/// during one joint fit.
///
/// The atom gets exactly one second chance from a fresh basin. If it breaches
/// again, the collapse is (locally) the objective's verdict at the current
/// hyperparameters: it is recorded as a terminal collapse event and left for
/// the structure-search death move to adjudicate. Re-seeding in a loop would
/// fight the optimizer.
pub(crate) const SAE_ATOM_COLLAPSE_RESEED_BUDGET: usize = 1;
34
/// Layer-1 guard from #976, decoder arm: the fraction of the dictionary's
/// MEDIAN decoder norm below which an atom's decoder block (Frobenius norm) is
/// treated as collapsed.
///
/// An atom that far below the median carries no material reconstruction
/// signal — it has degenerated to (near-)zero output and decodes the same
/// nothing as every other collapsed atom. This is the real-data K>1 failure
/// the gate-mass floor cannot see: the assignment gates can stay spread across
/// rows (mass guard satisfied) while the decoders all collapse to ~0, giving
/// EV≈0 and a rank-deficient per-row coordinate Hessian on every row (the
/// 0→K·n evidence-deflation jump).
///
/// The statistic is a RATIO to the dictionary median, so it is scale-free and
/// never fires for a uniformly-small but well-conditioned decoder; only an atom
/// that has fallen far behind its peers is caught. By construction it is a
/// no-op for K=1 (a single atom has no peer to fall behind, and the median
/// equals its own norm), so the K=1 path is byte-for-byte unchanged.
pub(crate) const SAE_ATOM_DECODER_NORM_COLLAPSE_RATIO: f64 = 1.0e-3;
49
/// K>1 robustness (#976 / #1117): bounded DICTIONARY-level multi-start budget
/// for the simultaneous co-collapse arm, i.e. the EV-floor branch of
/// [`crate::manifold::SaeManifoldTerm::enforce_decoder_norm_guard`].
///
/// This is deliberately separate from the per-atom
/// [`SAE_ATOM_COLLAPSE_RESEED_BUDGET`] (= 1). That budget governs reseeding
/// ONE atom's gate logits against an optimizer that keeps killing it, where a
/// loop would fight the optimizer. A co-collapse reseed is categorically
/// different: it is a full-dictionary multi-start that re-diversifies ALL
/// atoms onto distinct principal directions of a FRESHLY recomputed residual,
/// so successive attempts explore genuinely different basins. A single such
/// reseed empirically cannot always break a K≥3 three-way basin (identical
/// (K, seed) flips EV≈0.40 ↔ 0.00), hence the small bounded budget of
/// independent multi-starts.
///
/// The budget is consumed ONLY when the whole dictionary's reconstruction EV
/// is at or below the data-derived collapse bar (`collapse_ev_bar` =
/// `SAE_COLLAPSE_PCA_EV_FRACTION` × the rank-K PCA ceiling, i.e. less than
/// half the variance any rank-K linear dictionary could reach). The old
/// absolute magic EV floor was REPLACED by that data-derived bar, not deleted.
///
/// It is a no-op for any healthy fit (real OLMo K=1 ~0.22, K=2 ~0.40, both
/// well above the bar); on the rough patches of a hard K≥3 fit a transient
/// sub-bar EV can still consume the budget before the fit recovers.
pub(crate) const SAE_DICTIONARY_COCOLLAPSE_RESEED_BUDGET: usize = 3;
70
/// Support cutoff, at machine precision, for the smooth JumpReLU assignment
/// prior. It is measured in gate temperatures below the hard threshold.
///
/// The forward gate stays hard-zero at and below `threshold`, but the prior
/// value/gradient and the compact Newton layout keep every logit satisfying
/// `(logit - threshold)/tau > -36`. At the excluded edge
/// `sigma(-36) ~= 2e-16`, so the dropped value/gradient/Hessian terms sit below
/// f64 noise rather than introducing an algorithmic discontinuity.
pub(crate) const JUMPRELU_OPTIMIZATION_LOGIT_CUTOFF: f64 = -36.0;

/// Common inclusion test for JumpReLU optimization support.
///
/// Returns `true` when the temperature-scaled margin
/// `(logit - threshold) / temperature` is strictly above
/// [`JUMPRELU_OPTIMIZATION_LOGIT_CUTOFF`]. This predicate is strictly weaker
/// than the hard forward gate `logit > threshold`, which still governs the
/// data-fit reconstruction and its logit JVP.
#[inline]
pub(crate) fn jumprelu_in_optimization_band(logit: f64, threshold: f64, temperature: f64) -> bool {
    let scaled_margin = (logit - threshold) / temperature;
    scaled_margin > JUMPRELU_OPTIMIZATION_LOGIT_CUTOFF
}
86
/// The assignment prior / relaxation a [`SaeAssignment`] reads its logits
/// through. Every variant carries the gate `temperature`.
#[derive(Clone, Copy, Debug)]
pub enum AssignmentMode {
    /// Row-wise simplex assignment with an entropy sparsity weight.
    Softmax { temperature: f64, sparsity: f64 },
    /// Deterministic concrete relaxation of a truncated IBP active set.
    IBPMap {
        temperature: f64,
        alpha: f64,
        learnable_alpha: bool,
    },
    /// Hard-thresholded bounded gate.
    ///
    /// An atom is off (gate = 0) when its logit is at or below `threshold`;
    /// above it the gate is the threshold-centered shifted sigmoid
    /// `σ((logit − threshold) / temperature) ∈ [0.5, 1)`. This is NOT the
    /// literal JumpReLU `z·1[z>θ]`: the gate carries no magnitude. It belongs
    /// to the gate family (softmax simplex / IBP sigmoid / this hard gate) and
    /// stays bounded in [0, 1]; reconstruction magnitude lives entirely in the
    /// decoder curve `g_k(t) = φ(t)ᵀ B_k`. The discontinuity at `threshold`
    /// (0 → 0.5) is the intended "jump".
    JumpReLU { temperature: f64, threshold: f64 },
}
108
/// #1033 — which fixed-form predictor produces the ρ-invariant FROZEN routing
/// (amortized routing).
///
/// Both forms are deterministic functions of the current dictionary with NO
/// learned net; they differ in how faithfully they track the dictionary as it
/// evolves across outer iterates. They are kept side by side so the accuracy
/// gate can pick whichever passes the fit-quality bar: the cheap `Snapshot`
/// when it suffices, the `ChartGeometry` distill otherwise.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum RoutingPredictor {
    /// Use the current (converged) logits, snapshotted, as the frozen routing.
    /// This is the cheapest fixed-form distill and is exact at the dictionary
    /// it is taken from. It goes stale as the dictionary moves (a refresh is
    /// needed to track it), so it is the MVP/baseline form.
    Snapshot,
    /// Re-derive each per-(row, atom) routing logit from the atom's
    /// encode-chart geometry against the CURRENT dictionary: encode the row to
    /// its predicted coord `t̂`, reconstruct the amplitude-1 image
    /// `γ_k(t̂) = Bᵀφ(t̂)`, and map the reconstruction ALIGNMENT to a logit.
    /// Because a moved decoder changes `γ_k(t̂)` and hence the routing, this
    /// form tracks the dictionary without re-running the free-logit inner
    /// solve; it is the default-readiness form when the snapshot proves too
    /// stale.
    ChartGeometry,
}
131
132impl AssignmentMode {
133 #[must_use]
134 pub fn softmax(temperature: f64) -> Self {
135 Self::Softmax {
136 temperature,
137 sparsity: 1.0,
138 }
139 }
140
141 #[must_use]
142 pub fn ibp_map(temperature: f64, alpha: f64, learnable_alpha: bool) -> Self {
143 Self::IBPMap {
144 temperature,
145 alpha,
146 learnable_alpha,
147 }
148 }
149
150 #[must_use]
151 pub fn jumprelu(temperature: f64, threshold: f64) -> Self {
152 Self::JumpReLU {
153 temperature,
154 threshold,
155 }
156 }
157
158 pub fn temperature(&self) -> f64 {
159 match *self {
160 AssignmentMode::Softmax { temperature, .. }
161 | AssignmentMode::IBPMap { temperature, .. }
162 | AssignmentMode::JumpReLU { temperature, .. } => temperature,
163 }
164 }
165
166 pub(crate) fn set_temperature(&mut self, new_temperature: f64) -> Result<(), String> {
167 if !(new_temperature.is_finite() && new_temperature > 0.0) {
168 return Err(format!(
169 "AssignmentMode: temperature must be finite and positive; got {new_temperature}"
170 ));
171 }
172 match self {
173 AssignmentMode::Softmax { temperature, .. }
174 | AssignmentMode::IBPMap { temperature, .. }
175 | AssignmentMode::JumpReLU { temperature, .. } => {
176 *temperature = new_temperature;
177 }
178 }
179 Ok(())
180 }
181
182 pub(crate) fn validate(&self) -> Result<(), String> {
183 let temperature = self.temperature();
184 if !(temperature.is_finite() && temperature > 0.0) {
185 return Err(format!(
186 "AssignmentMode: temperature must be finite and positive; got {temperature}"
187 ));
188 }
189 match *self {
190 AssignmentMode::Softmax { sparsity, .. } => {
191 if !(sparsity.is_finite() && sparsity > 0.0) {
192 return Err(format!(
193 "AssignmentMode::Softmax: sparsity must be finite and positive; got {sparsity}"
194 ));
195 }
196 }
197 AssignmentMode::IBPMap { alpha, .. } => {
198 if !(alpha.is_finite() && alpha > 0.0) {
199 return Err(format!(
200 "AssignmentMode::IBPMap: alpha must be finite and positive; got {alpha}"
201 ));
202 }
203 }
204 AssignmentMode::JumpReLU { threshold, .. } => {
205 if !threshold.is_finite() {
206 return Err(format!(
207 "AssignmentMode::JumpReLU: threshold must be finite; got {threshold}"
208 ));
209 }
210 }
211 }
212 Ok(())
213 }
214
215 pub(crate) fn resolved_ibp_alpha(&self, rho: &SaeManifoldRho) -> Option<f64> {
216 match *self {
217 AssignmentMode::IBPMap {
218 alpha,
219 learnable_alpha,
220 ..
221 } => Some(if let Some(over) = ibp_alpha_override() {
222 // #1026 — a process-global α override flattens the ordered
223 // geometric prior π_k = (α/(α+1))^{k+1} so all K atoms can
224 // contribute to the reconstruction (the production α=1 gives a
225 // (0.5)^{k+1} schedule that structurally caps atoms 4..K to a few
226 // percent → effective-K≈3). Forces the fixed value, bypassing the
227 // learnable schedule, so a sweep can attribute the EV ceiling.
228 over
229 } else if learnable_alpha {
230 resolve_learnable_weight(alpha, rho.log_lambda_sparse)
231 } else {
232 alpha
233 }),
234 _ => None,
235 }
236 }
237}
238
// #1026 — process-global IBP-α override, stored as raw f64 bits. The initial
// value is a NaN bit pattern, which acts as the sentinel for "unset → use the
// AssignmentMode's compiled α". This lets ONE wheel sweep the prior-flattening
// axis from Python (`sae_set_ibp_alpha`) without recompiling the gam crate.
static IBP_ALPHA_OVERRIDE_BITS: std::sync::atomic::AtomicU64 =
    std::sync::atomic::AtomicU64::new(0x7ff8_0000_0000_0000);

/// Read the process-global IBP-α override.
///
/// Returns `Some(α)` only when the stored value is finite and strictly
/// positive; any other stored value (including the initial NaN sentinel) reads
/// as "no override".
pub(crate) fn ibp_alpha_override() -> Option<f64> {
    use std::sync::atomic::Ordering;

    let stored = f64::from_bits(IBP_ALPHA_OVERRIDE_BITS.load(Ordering::Relaxed));
    Some(stored).filter(|alpha| alpha.is_finite() && *alpha > 0.0)
}

/// Install the process-global IBP-α override, or clear it by passing a
/// non-finite or non-positive value (the reader treats those as unset).
/// Called from the gamfit Python FFI sweep driver.
pub fn set_ibp_alpha_override(alpha: f64) {
    use std::sync::atomic::Ordering;

    let bits = alpha.to_bits();
    IBP_ALPHA_OVERRIDE_BITS.store(bits, Ordering::Relaxed);
}
259
260/// Per-row latent assignment state.
261///
262/// The stored assignment parameter is `logits`; non-negative assignments are
263/// derived by row-wise softmax, independent IBP-MAP sigmoid active indicators,
264/// or JumpReLU gates. Softmax logits are canonicalized to the reference chart
265/// `logits[K - 1] = 0`, so the row-local Newton coordinates contain only the
266/// first `K - 1` logits (`0` coordinates for `K = 1`). Gate-style modes keep
267/// all `K` logits as identifiable scalar parameters. `coords[k]` holds
268/// `t_{.,k}` for atom `k`.
269#[derive(Debug, Clone)]
270pub struct SaeAssignment {
271 pub logits: Array2<f64>,
272 pub coords: Vec<LatentCoordValues>,
273 pub mode: AssignmentMode,
274 /// #1026 — per-atom UNGATED flag (length `K`, default all-`false`). An
275 /// ungated atom is the dense linear/background tier: its per-row gate is
276 /// fixed at `a_k ≡ 1` (it contributes `γ_k(t_k)` to EVERY row, unweighted),
277 /// it is excluded from the other atoms' gate (for the column-separable
278 /// IBP / JumpReLU modes the remaining atoms are computed independently, so
279 /// they are unaffected), and its logit is NOT a free parameter — its
280 /// logit-JVP, sparsity-prior gradient/curvature, and softmax majorizer
281 /// contributions are all zero, leaving its logit slot an inert
282 /// (ridge-regularized) null direction in the per-row Newton block. This lets
283 /// the linear tier carry FULL-RANK reconstructible variance
284 /// (`fitted = γ_ungated(x) + Σ_{gated} a_k·γ_k(x)`) so a linear SAE can reach
285 /// the rank-(K·d) PCA ceiling, while the gated curved atoms still add sparse
286 /// structure on the residual (#1026 routing-bound finding).
287 pub ungated: Vec<bool>,
288 /// #1033 — AMORTIZED / FROZEN routing. When `Some`, this `(n, K)` matrix is a
289 /// ρ-INVARIANT predicted routing (the amortized `x → logits` map distilled
290 /// from the frozen dictionary): the gates are computed from THESE logits
291 /// instead of the free `self.logits`, and the logits are NOT optimized by the
292 /// inner Newton (their gradient/curvature/prior contributions are zeroed,
293 /// exactly as for [`Self::ungated`]). This is the generalization of an ungated
294 /// atom from "pin the gate at 1" to "pin the gate at the predicted value": it
295 /// makes the per-row routing a fixed function of `x` + the frozen dictionary,
296 /// so the outer ρ-search reuses ONE routing instead of re-solving per-row
297 /// gates every outer eval — the n-independent-outer-loop lever (#1033). `None`
298 /// is the historical free-logit path (bit-identical).
299 pub frozen_logits: Option<Array2<f64>>,
300}
301
302impl SaeAssignment {
303 #[must_use = "build error must be handled"]
304 pub fn new(
305 logits: Array2<f64>,
306 coords: Vec<LatentCoordValues>,
307 temperature: f64,
308 ) -> Result<Self, String> {
309 Self::with_mode(logits, coords, AssignmentMode::softmax(temperature))
310 }
311
312 #[must_use = "build error must be handled"]
313 pub fn with_mode(
314 mut logits: Array2<f64>,
315 coords: Vec<LatentCoordValues>,
316 mode: AssignmentMode,
317 ) -> Result<Self, String> {
318 mode.validate()?;
319 let n = logits.nrows();
320 let k = logits.ncols();
321 if coords.len() != k {
322 return Err(format!(
323 "SaeAssignment::new: coords length {} must equal K={k}",
324 coords.len()
325 ));
326 }
327 for (atom, coord) in coords.iter().enumerate() {
328 if coord.n_obs() != n {
329 return Err(format!(
330 "SaeAssignment::new: coord atom {atom} has n_obs={} but logits has {n}",
331 coord.n_obs()
332 ));
333 }
334 }
335 for row in 0..n {
336 validate_finite_logits(logits.row(row), row)?;
337 }
338 if matches!(mode, AssignmentMode::Softmax { .. }) {
339 canonicalize_softmax_logits(&mut logits);
340 }
341 Ok(Self {
342 logits,
343 coords,
344 mode,
345 ungated: vec![false; k],
346 frozen_logits: None,
347 })
348 }
349
350 /// #1033 — install a ρ-INVARIANT FROZEN routing (the amortized predicted
351 /// logits; see [`SaeAssignment::frozen_logits`]). `predicted` must be
352 /// `(n, K)`. With routing frozen, the gates are computed from `predicted` and
353 /// the logits are excluded from the inner Newton (their gradient/curvature are
354 /// inert, like an ungated atom's). Passing `None` restores the free-logit
355 /// path.
356 #[must_use = "build error must be handled"]
357 pub fn with_frozen_routing(mut self, predicted: Option<Array2<f64>>) -> Result<Self, String> {
358 if let Some(ref p) = predicted {
359 if p.dim() != (self.n_obs(), self.k_atoms()) {
360 return Err(format!(
361 "SaeAssignment::with_frozen_routing: predicted shape {:?} must be ({}, {})",
362 p.dim(),
363 self.n_obs(),
364 self.k_atoms()
365 ));
366 }
367 if matches!(self.mode, AssignmentMode::Softmax { .. }) {
368 return Err(
369 "SaeAssignment::with_frozen_routing: frozen routing under Softmax is rejected \
370 — the coupled simplex's entropy majorizer is assembled over the logits, which \
371 a frozen (non-optimized) routing would leave inconsistent; this separable-mode \
372 contract supports IBP-MAP and JumpReLU, whose per-atom gates have no \
373 simplex-coupled curvature to skip"
374 .to_string(),
375 );
376 }
377 for row in 0..p.nrows() {
378 validate_finite_logits(p.row(row), row)?;
379 }
380 }
381 self.frozen_logits = predicted;
382 Ok(self)
383 }
384
385 /// Whether the per-row routing is FROZEN (amortized) rather than free-logit.
386 pub fn routing_is_frozen(&self) -> bool {
387 self.frozen_logits.is_some()
388 }
389
390 /// The active routing logits for `row`: the frozen/predicted logits when
391 /// routing is frozen (#1033), else the free `self.logits`. This is the SINGLE
392 /// source the gate value reads, so freezing routing changes every gate
393 /// consistently.
394 pub(crate) fn routing_logits_row(&self, row: usize) -> ArrayView1<'_, f64> {
395 match self.frozen_logits {
396 Some(ref f) => f.row(row),
397 None => self.logits.row(row),
398 }
399 }
400
401 /// Whether atom `k`'s logit is held fixed (not a free Newton parameter): true
402 /// for an ungated atom (#1026, gate pinned at 1) OR when routing is frozen
403 /// (#1033, gate pinned at the predicted value). Both share the same inert
404 /// treatment — zero logit-JVP, zero sparsity-prior gradient/curvature, zero
405 /// softmax majorizer — so the logit slot never moves.
406 pub(crate) fn logit_is_fixed(&self, k: usize) -> bool {
407 self.routing_is_frozen() || self.ungated.get(k).copied().unwrap_or(false)
408 }
409
410 /// Per-atom mask (length `K`) of [`Self::logit_is_fixed`] — the logit slots
411 /// that are NOT free Newton parameters (ungated #1026 and/or frozen-routing
412 /// #1033). Precompute once per assembly and pass to the logit-JVP fillers so
413 /// the data-fit Jacobian zeroes those rows. Under frozen routing every entry
414 /// is `true`; with only ungated atoms it equals `ungated`; otherwise all
415 /// `false` (the historical free-logit path).
416 pub(crate) fn fixed_logit_mask(&self) -> Vec<bool> {
417 if self.routing_is_frozen() {
418 vec![true; self.k_atoms()]
419 } else {
420 self.ungated.clone()
421 }
422 }
423
424 /// #1033 — install the simplest faithful AMORTIZED routing predictor: a
425 /// fixed-form DISTILL of the current dictionary's routing, namely the current
426 /// (converged) logits SNAPSHOTTED as the ρ-invariant frozen routing. This is
427 /// the `x → logits` map "evaluated once at the frozen dictionary" — the
428 /// routing the dictionary already expresses — held fixed so the outer ρ-search
429 /// reuses it instead of re-optimizing the gates at every ρ. (A richer
430 /// predictor that recomputes logits from `x` via the encode-atlas chart
431 /// geometry is a later refinement; snapshotting the converged routing is the
432 /// exact fixed-point it would target at the frozen dictionary.) Rejected for
433 /// Softmax for the same simplex-coupling reason as [`Self::with_frozen_routing`].
434 #[must_use = "build error must be handled"]
435 pub fn freeze_routing_from_current_logits(self) -> Result<Self, String> {
436 let snapshot = self.logits.clone();
437 self.with_frozen_routing(Some(snapshot))
438 }
439
440 /// #1033 — in-place variant of [`Self::freeze_routing_from_current_logits`]
441 /// for callers holding `&mut SaeAssignment` (e.g. inside a `SaeManifoldTerm`),
442 /// where moving the assignment out is awkward. Same contract: snapshot the
443 /// current logits as the ρ-invariant frozen routing; reject Softmax.
444 pub fn freeze_routing_in_place(&mut self) -> Result<(), String> {
445 if matches!(self.mode, AssignmentMode::Softmax { .. }) {
446 return Err(
447 "SaeAssignment::freeze_routing_in_place: frozen routing under Softmax is rejected \
448 (coupled-simplex entropy-majorizer); use IBP-MAP or JumpReLU"
449 .to_string(),
450 );
451 }
452 let snapshot = self.logits.clone();
453 for row in 0..snapshot.nrows() {
454 validate_finite_logits(snapshot.row(row), row)?;
455 }
456 self.frozen_logits = Some(snapshot);
457 Ok(())
458 }
459
460 /// #1033 — install an explicit predicted routing in place (the
461 /// [`RoutingPredictor::ChartGeometry`] output), `&mut self` variant of
462 /// [`Self::with_frozen_routing`]. `predicted` must be `(n, K)`; rejects Softmax
463 /// (separable-mode contract) and non-finite predictions.
464 pub fn set_frozen_routing_in_place(&mut self, predicted: Array2<f64>) -> Result<(), String> {
465 if predicted.dim() != (self.n_obs(), self.k_atoms()) {
466 return Err(format!(
467 "SaeAssignment::set_frozen_routing_in_place: predicted shape {:?} must be ({}, {})",
468 predicted.dim(),
469 self.n_obs(),
470 self.k_atoms()
471 ));
472 }
473 if matches!(self.mode, AssignmentMode::Softmax { .. }) {
474 return Err(
475 "SaeAssignment::set_frozen_routing_in_place: frozen routing under Softmax is \
476 rejected (coupled-simplex entropy-majorizer); use IBP-MAP or JumpReLU"
477 .to_string(),
478 );
479 }
480 for row in 0..predicted.nrows() {
481 validate_finite_logits(predicted.row(row), row)?;
482 }
483 self.frozen_logits = Some(predicted);
484 Ok(())
485 }
486
487 /// #1033 — lift the frozen routing, restoring the free-logit search path.
488 pub fn thaw_routing(&mut self) {
489 self.frozen_logits = None;
490 }
491
492 /// #1026 — designate which atoms are UNGATED (the dense linear/background
493 /// tier; see [`SaeAssignment::ungated`]). `flags` must have length `K`.
494 ///
495 /// Ungating is defined for the COLUMN-SEPARABLE gate modes (IBP-MAP and
496 /// JumpReLU): each atom's gate is an independent per-atom function of its own
497 /// logit, so pinning one atom to `a_k ≡ 1` leaves every other atom's gate
498 /// exactly as computed. Softmax is a coupled simplex (`Σ_k a_k = 1` over all
499 /// `K`), so a unit gate for one atom is only well defined relative to a
500 /// gated-subset renormalization that must also be reflected in the logit-JVP
501 /// and the entropy majorizer; this constructor's contract is restricted to
502 /// the separable modes, and an ungated atom under Softmax is REJECTED here so
503 /// the inner solve never runs on a value/gradient-mismatched gate. Callers
504 /// wanting a dense background tier under Softmax route it as an IBP-MAP or
505 /// JumpReLU atom.
506 #[must_use = "build error must be handled"]
507 pub fn with_ungated(mut self, flags: Vec<bool>) -> Result<Self, String> {
508 if flags.len() != self.k_atoms() {
509 return Err(format!(
510 "SaeAssignment::with_ungated: flags length {} must equal K={}",
511 flags.len(),
512 self.k_atoms()
513 ));
514 }
515 if matches!(self.mode, AssignmentMode::Softmax { .. }) && flags.iter().any(|&u| u) {
516 return Err(
517 "SaeAssignment::with_ungated: an ungated atom under Softmax routing is \
518 rejected — the coupled simplex requires a gated-subset renormalization \
519 reflected in the logit-JVP and entropy majorizer, which this separable-mode \
520 contract does not perform; route a dense background tier as IBP-MAP or JumpReLU"
521 .to_string(),
522 );
523 }
524 self.ungated = flags;
525 Ok(self)
526 }
527
528 /// Whether any atom is ungated (the #1026 background tier is engaged).
529 pub fn has_ungated(&self) -> bool {
530 self.ungated.iter().any(|&u| u)
531 }
532
533 pub fn n_obs(&self) -> usize {
534 self.logits.nrows()
535 }
536
537 pub fn k_atoms(&self) -> usize {
538 self.logits.ncols()
539 }
540
541 pub fn total_coord_dim(&self) -> usize {
542 self.coords.iter().map(|c| c.latent_dim()).sum()
543 }
544
545 pub fn assignment_coord_dim(&self) -> usize {
546 match self.mode {
547 AssignmentMode::Softmax { .. } => self.k_atoms().saturating_sub(1),
548 AssignmentMode::IBPMap { .. } | AssignmentMode::JumpReLU { .. } => self.k_atoms(),
549 }
550 }
551
552 pub fn row_block_dim(&self) -> usize {
553 self.assignment_coord_dim() + self.total_coord_dim()
554 }
555
556 pub fn coord_offsets(&self) -> Vec<usize> {
557 let mut out = Vec::with_capacity(self.k_atoms());
558 let mut cursor = self.assignment_coord_dim();
559 for coord in &self.coords {
560 out.push(cursor);
561 cursor += coord.latent_dim();
562 }
563 out
564 }
565
566 pub fn assignments(&self) -> Array2<f64> {
567 let n = self.n_obs();
568 let k = self.k_atoms();
569 let mut out = Array2::<f64>::zeros((n, k));
570 for row in 0..n {
571 let a = self.assignments_row(row);
572 for atom in 0..k {
573 out[[row, atom]] = a[atom];
574 }
575 }
576 out
577 }
578
579 pub fn assignments_row(&self, row: usize) -> Array1<f64> {
580 self.try_assignments_row(row)
581 .expect("assignment logits must be finite")
582 }
583
584 pub fn try_assignments_row(&self, row: usize) -> Result<Array1<f64>, String> {
585 self.try_assignments_row_with_alpha(row, None)
586 }
587
588 pub(crate) fn try_assignments_row_for_rho(
589 &self,
590 row: usize,
591 rho: &SaeManifoldRho,
592 ) -> Result<Array1<f64>, String> {
593 self.try_assignments_row_with_alpha(row, self.mode.resolved_ibp_alpha(rho))
594 }
595
596 fn try_assignments_row_with_alpha(
597 &self,
598 row: usize,
599 resolved_ibp_alpha: Option<f64>,
600 ) -> Result<Array1<f64>, String> {
601 // #1033 — read the ACTIVE routing logits: the ρ-invariant frozen/predicted
602 // logits when routing is frozen, else the free `self.logits`. This single
603 // source makes the gate value ρ-invariant under frozen routing (the
604 // amortized-routing lever) and bit-identical to the historical path when
605 // not frozen.
606 let routing = self.routing_logits_row(row);
607 validate_finite_logits(routing, row)?;
608 // Only Softmax collapses to a fixed assignment at K==1: its
609 // assignment_coord_dim is K-1 = 0, so there is no free logit. IBPMap and
610 // JumpReLU keep a free per-atom gate logit even at K==1
611 // (assignment_coord_dim = K = 1), so they must fall through to their real
612 // row functions or the logit would move the prior but not the gate.
613 if self.k_atoms() == 1 && matches!(self.mode, AssignmentMode::Softmax { .. }) {
614 return Ok(Array1::from_vec(vec![1.0]));
615 }
616 let mut row_gates = match self.mode {
617 AssignmentMode::Softmax { temperature, .. } => softmax_row(routing, temperature),
618 AssignmentMode::IBPMap {
619 temperature, alpha, ..
620 } => ibp_map_row(routing, temperature, resolved_ibp_alpha.unwrap_or(alpha)),
621 AssignmentMode::JumpReLU {
622 temperature,
623 threshold,
624 } => jumprelu_row(routing, temperature, threshold),
625 };
626 // #1026 — ungated (background-tier) atoms have a fixed unit gate. For the
627 // column-separable IBP / JumpReLU modes the other atoms' gates are
628 // computed independently above, so overwriting the ungated entries to 1.0
629 // leaves the gated atoms exactly as they were; the ungated atom then
630 // contributes `γ_k(t_k)` unweighted to every row. (Softmax + ungated is
631 // rejected at `with_ungated`, so no simplex renormalization is needed
632 // here.)
633 if self.has_ungated() {
634 for (k, gate) in row_gates.iter_mut().enumerate() {
635 if self.ungated[k] {
636 *gate = 1.0;
637 }
638 }
639 }
640 Ok(row_gates)
641 }
642
643 /// #1557 — fill-into-caller-buffer twin of [`Self::try_assignments_row_for_rho`].
644 ///
645 /// Writes the EXACT SAME per-atom assignment row into `out` (length
646 /// `k_atoms()`) instead of allocating a fresh `Array1`. Bit-identical to the
647 /// allocating path; intended for the hot per-row loops that immediately
648 /// consume the row, reusing a single scratch buffer across rows.
649 pub(crate) fn try_assignments_row_for_rho_into(
650 &self,
651 row: usize,
652 rho: &SaeManifoldRho,
653 out: &mut [f64],
654 ) -> Result<(), String> {
655 self.try_assignments_row_with_alpha_into(row, self.mode.resolved_ibp_alpha(rho), out)
656 }
657
658 /// #1557 — fill-into-caller-buffer twin of [`Self::try_assignments_row_with_alpha`].
659 ///
660 /// `out` must have length `k_atoms()`; it is fully overwritten with the same
661 /// values the allocating variant would return. Every branch (early-return
662 /// K==1 Softmax, the per-mode row math, the #1026 ungated overwrite) mirrors
663 /// the allocating path exactly so the two are bit-identical.
664 pub(crate) fn try_assignments_row_with_alpha_into(
665 &self,
666 row: usize,
667 resolved_ibp_alpha: Option<f64>,
668 out: &mut [f64],
669 ) -> Result<(), String> {
670 // `out` is sized `k_atoms()` by every caller; the per-mode helpers below
671 // fully overwrite indices `0..k_atoms()`.
672 let routing = self.routing_logits_row(row);
673 validate_finite_logits(routing, row)?;
674 // Mirror the allocating early-return: only Softmax collapses to a fixed
675 // unit assignment at K==1.
676 if self.k_atoms() == 1 && matches!(self.mode, AssignmentMode::Softmax { .. }) {
677 out[0] = 1.0;
678 return Ok(());
679 }
680 match self.mode {
681 AssignmentMode::Softmax { temperature, .. } => {
682 softmax_row_into(routing, temperature, out)
683 }
684 AssignmentMode::IBPMap {
685 temperature, alpha, ..
686 } => ibp_map_row_into(
687 routing,
688 temperature,
689 resolved_ibp_alpha.unwrap_or(alpha),
690 out,
691 ),
692 AssignmentMode::JumpReLU {
693 temperature,
694 threshold,
695 } => jumprelu_row_into(routing, temperature, threshold, out),
696 };
697 // #1026 — ungated (background-tier) atoms have a fixed unit gate, exactly
698 // as in the allocating path.
699 if self.has_ungated() {
700 for (k, gate) in out.iter_mut().enumerate() {
701 if self.ungated[k] {
702 *gate = 1.0;
703 }
704 }
705 }
706 Ok(())
707 }
708
709 pub(crate) fn persist_resolved_ibp_alpha(&mut self, rho: &SaeManifoldRho) -> bool {
710 let AssignmentMode::IBPMap {
711 temperature,
712 alpha,
713 learnable_alpha: true,
714 } = self.mode
715 else {
716 return false;
717 };
718 let resolved_alpha = resolve_learnable_weight(alpha, rho.log_lambda_sparse);
719 self.mode = AssignmentMode::IBPMap {
720 temperature,
721 alpha: resolved_alpha,
722 learnable_alpha: false,
723 };
724 true
725 }
726
727 pub(crate) fn assignments_for_rho(&self, rho: &SaeManifoldRho) -> Result<Array2<f64>, String> {
728 let n = self.n_obs();
729 let k = self.k_atoms();
730 let mut out = Array2::<f64>::zeros((n, k));
731 for row in 0..n {
732 let a = self.try_assignments_row_for_rho(row, rho)?;
733 for atom in 0..k {
734 out[[row, atom]] = a[atom];
735 }
736 }
737 Ok(out)
738 }
739
740 /// Flatten extension coordinates in row-major SAE layout:
741 /// `(assignment chart_i, t_i0[0..d_0], ..., t_iK[0..d_K])` for every row.
742 /// Softmax contributes the first `K - 1` reference logits and omits the
743 /// fixed reference logit; gate-style assignment modes contribute all `K`
744 /// logits.
745 pub fn flatten_ext_coords(&self) -> Array1<f64> {
746 let n = self.n_obs();
747 let q = self.row_block_dim();
748 let k = self.k_atoms();
749 let assignment_dim = self.assignment_coord_dim();
750 let offsets = self.coord_offsets();
751 let mut out = Array1::<f64>::zeros(n * q);
752 for row in 0..n {
753 let base = row * q;
754 for atom in 0..assignment_dim {
755 out[base + atom] = self.logits[[row, atom]];
756 }
757 for atom in 0..k {
758 let d = self.coords[atom].latent_dim();
759 let t_row = self.coords[atom].row(row);
760 for axis in 0..d {
761 out[base + offsets[atom] + axis] = t_row[axis];
762 }
763 }
764 }
765 out
766 }
767
768 #[must_use = "build error must be handled"]
769 pub fn from_blocks_with_mode(
770 logits: Array2<f64>,
771 coord_blocks: Vec<Array2<f64>>,
772 mode: AssignmentMode,
773 ) -> Result<Self, String> {
774 let coords = coord_blocks
775 .iter()
776 .map(|c| LatentCoordValues::from_matrix(c.view(), LatentIdMode::None))
777 .collect();
778 Self::with_mode(logits, coords, mode)
779 }
780
781 #[must_use = "build error must be handled"]
782 pub fn from_blocks_with_mode_and_manifolds(
783 logits: Array2<f64>,
784 coord_blocks: Vec<Array2<f64>>,
785 manifolds: Vec<LatentManifold>,
786 mode: AssignmentMode,
787 ) -> Result<Self, String> {
788 if coord_blocks.len() != manifolds.len() {
789 return Err(format!(
790 "SaeAssignment::from_blocks_with_mode_and_manifolds: coord block length {} != manifold length {}",
791 coord_blocks.len(),
792 manifolds.len()
793 ));
794 }
795 let coords = coord_blocks
796 .iter()
797 .zip(manifolds)
798 .map(|(c, manifold)| {
799 LatentCoordValues::from_matrix_with_manifold(c.view(), LatentIdMode::None, manifold)
800 })
801 .collect();
802 Self::with_mode(logits, coords, mode)
803 }
804}
805
806pub(crate) fn neutral_gate_weights(mode: AssignmentMode, k_atoms: usize) -> Array1<f64> {
807 match mode {
808 AssignmentMode::Softmax { .. } => Array1::from_elem(k_atoms, 1.0 / (k_atoms.max(1) as f64)),
809 AssignmentMode::IBPMap {
810 temperature, alpha, ..
811 } => ibp_map_row(Array1::<f64>::zeros(k_atoms).view(), temperature, alpha),
812 AssignmentMode::JumpReLU { .. } => Array1::from_elem(k_atoms, 0.5),
813 }
814}
815
816pub(crate) fn softmax_row(logits: ArrayView1<'_, f64>, temperature: f64) -> Array1<f64> {
817 let k = logits.len();
818 let inv_tau = 1.0 / temperature;
819 let mut max_logit = f64::NEG_INFINITY;
820 for &v in logits.iter() {
821 max_logit = max_logit.max(v);
822 }
823 let mut out = Array1::<f64>::zeros(k);
824 let mut sum = 0.0;
825 for i in 0..k {
826 let v = ((logits[i] - max_logit) * inv_tau).exp();
827 out[i] = v;
828 sum += v;
829 }
830 assert!(sum.is_finite() && sum > 0.0);
831 for v in out.iter_mut() {
832 *v /= sum;
833 }
834 out
835}
836
837pub(crate) fn validate_finite_logits(
838 logits: ArrayView1<'_, f64>,
839 row: usize,
840) -> Result<(), String> {
841 for (col, &v) in logits.iter().enumerate() {
842 if !v.is_finite() {
843 return Err(format!(
844 "SaeAssignment: non-finite assignment logit at row {row}, atom {col}: {v}"
845 ));
846 }
847 }
848 Ok(())
849}
850
851pub(crate) fn canonicalize_softmax_logits(logits: &mut Array2<f64>) {
852 let k = logits.ncols();
853 if k == 0 {
854 return;
855 }
856 if k == 1 {
857 logits.fill(0.0);
858 return;
859 }
860 for row in 0..logits.nrows() {
861 let reference = logits[[row, k - 1]];
862 for col in 0..k - 1 {
863 logits[[row, col]] -= reference;
864 }
865 logits[[row, k - 1]] = 0.0;
866 }
867}
868
869/// Truncated Indian-Buffet-Process stick-breaking prior *means*
870/// `π_k = E[∏_{j=0}^{k} v_j] = (α/(α+1))^{k+1}` for k = 0, .., K-1, with sticks
871/// `v_j ~ Beta(α, 1)` so `E[v_j] = α/(α+1)`. EVERY atom (including the first,
872/// `π_0 = α/(α+1)`) carries the consistent Beta(α, 1) shrinkage: there is no
873/// special-cased always-on base atom, so `α` behaves as a genuine IBP
874/// concentration — larger `α` ⇒ heavier mass / slower decay, `α → 0` ⇒ all mass
875/// collapses onto nothing, matching the stick-breaking limit. This is the
876/// deterministic MAP / mean-field form of the IBP prior (the closed form the
877/// analytic Newton / Hessian / Woodbury machinery differentiates); no sticks are
878/// *sampled* here, the per-atom weight is the exact expectation of the
879/// stick-breaking product. (#614: previously `π_0 = 1` left the first atom
880/// unshrunk, which is the prior mean of NO stick at all and broke α's role as a
881/// concentration; the consistent product mean restores genuine IBP semantics.)
882pub(crate) fn ordered_geometric_shrinkage_prior(k_atoms: usize, alpha: f64) -> Array1<f64> {
883 // Accumulate the geometric schedule `π_k = ratio^(k+1)` in LOG space so the
884 // prior stays a finite *soft* weight even for large `K`. The naive product
885 // `acc *= ratio` underflows to exact `0.0` once `ratio^(k+1) < f64::MIN_POSITIVE`
886 // (e.g. `(0.1/1.1)^320`), which would turn the soft shrinkage prior into a
887 // HARD mask: such atoms would receive zero assignment AND zero logit
888 // gradient (the gradient is multiplied by `π_k`), so they could never
889 // reactivate. Working in log-space and flooring the exponentiated weight at
890 // the smallest positive normal keeps every atom's gradient path alive while
891 // preserving the geometric ordering.
892 let mut out = Array1::<f64>::zeros(k_atoms);
893 let log_ratio = (alpha / (alpha + 1.0)).ln();
894 for k in 0..k_atoms {
895 // π_k = (α/(α+1))^{k+1}: the product of (k+1) i.i.d. Beta(α,1) stick
896 // means, so atom 0 is also shrunk by one stick (E[v_0] = α/(α+1)).
897 let log_pi = ((k + 1) as f64) * log_ratio;
898 out[k] = log_pi.exp().max(f64::MIN_POSITIVE);
899 }
900 out
901}
902
903/// IBP-MAP row activations: per-atom sigmoid likelihood times the truncated
904/// stick-breaking prior mean `π_k = (α/(α+1))^{k+1}`. With tied logits the prior
905/// dominates and yields strictly decreasing activations in atom index, with the
906/// first atom already shrunk by one Beta(α,1) stick mean (no unshrunk base atom).
907pub fn ibp_map_row(logits: ArrayView1<'_, f64>, temperature: f64, alpha: f64) -> Array1<f64> {
908 let prior = ordered_geometric_shrinkage_prior(logits.len(), alpha);
909 let mut out = Array1::<f64>::zeros(logits.len());
910 for i in 0..logits.len() {
911 out[i] = gam_linalg::utils::stable_logistic(logits[i] / temperature) * prior[i];
912 }
913 out
914}
915
916/// IBP-MAP activations together with the diagonal Jacobian `∂z_k/∂l_k`,
917/// shared with the torch autograd `Function` so the Python IBP-Gumbel path
918/// applies the same stick-breaking prior mean `π_k = (α/(α+1))^{k+1}` and
919/// temperature scaling as the Rust closed form. With `z_k = σ(l_k/τ)·π_k` the
920/// per-atom derivative is
921/// `σ(l_k/τ)(1 − σ(l_k/τ))·π_k / τ`; the map is diagonal in `k`, so the
922/// Jacobian is returned as the per-atom diagonal vector.
923#[must_use]
924pub fn ibp_map_row_value_grad(
925 logits: ArrayView1<'_, f64>,
926 temperature: f64,
927 alpha: f64,
928) -> (Array1<f64>, Array1<f64>) {
929 let prior = ordered_geometric_shrinkage_prior(logits.len(), alpha);
930 let inv_tau = 1.0 / temperature;
931 let mut value = Array1::<f64>::zeros(logits.len());
932 let mut grad = Array1::<f64>::zeros(logits.len());
933 for i in 0..logits.len() {
934 let sig = gam_linalg::utils::stable_logistic(logits[i] * inv_tau);
935 value[i] = sig * prior[i];
936 grad[i] = sig * (1.0 - sig) * inv_tau * prior[i];
937 }
938 (value, grad)
939}
940
941pub fn jumprelu_row(logits: ArrayView1<'_, f64>, temperature: f64, threshold: f64) -> Array1<f64> {
942 let mut out = Array1::<f64>::zeros(logits.len());
943 for i in 0..logits.len() {
944 // Hard gate: strictly zero below threshold (the intended "jump"). Above
945 // threshold the surrogate is centered at the threshold so the gate is
946 // most informative exactly at the boundary it switches on:
947 // σ((l−θ)/τ) ∈ [0.5, 1). Magnitude lives in the decoder, so the gate
948 // stays bounded in [0, 1] by design.
949 if logits[i] > threshold {
950 out[i] = gam_linalg::utils::stable_logistic((logits[i] - threshold) / temperature);
951 }
952 }
953 out
954}
955
956// #1557 — fill-into-caller-buffer variants of the three per-mode row functions.
957// These compute the EXACT SAME values as `softmax_row` / `ibp_map_row` /
958// `jumprelu_row` (same arithmetic, same order of operations) but write into a
959// caller-provided `&mut [f64]` slice instead of heap-allocating a fresh
960// `Array1<f64>` per call. The hot per-row loops (loss eval, arrow/Schur row
961// loops) call these with a reused scratch buffer, eliminating millions of tiny
962// K-sized allocations while staying bit-identical to the allocating path.
963// `out` must have length `logits.len()`; the slice is fully overwritten.
964
965pub(crate) fn softmax_row_into(logits: ArrayView1<'_, f64>, temperature: f64, out: &mut [f64]) {
966 let k = logits.len();
967 let inv_tau = 1.0 / temperature;
968 let mut max_logit = f64::NEG_INFINITY;
969 for &v in logits.iter() {
970 max_logit = max_logit.max(v);
971 }
972 let mut sum = 0.0;
973 for i in 0..k {
974 let v = ((logits[i] - max_logit) * inv_tau).exp();
975 out[i] = v;
976 sum += v;
977 }
978 assert!(sum.is_finite() && sum > 0.0);
979 for v in out.iter_mut() {
980 *v /= sum;
981 }
982}
983
984pub(crate) fn ibp_map_row_into(
985 logits: ArrayView1<'_, f64>,
986 temperature: f64,
987 alpha: f64,
988 out: &mut [f64],
989) {
990 let prior = ordered_geometric_shrinkage_prior(logits.len(), alpha);
991 for i in 0..logits.len() {
992 out[i] = gam_linalg::utils::stable_logistic(logits[i] / temperature) * prior[i];
993 }
994}
995
996pub(crate) fn jumprelu_row_into(
997 logits: ArrayView1<'_, f64>,
998 temperature: f64,
999 threshold: f64,
1000 out: &mut [f64],
1001) {
1002 for i in 0..logits.len() {
1003 // Match `jumprelu_row`: strictly zero below threshold, sigmoid surrogate
1004 // above. The buffer is fully overwritten (no read of prior contents).
1005 if logits[i] > threshold {
1006 out[i] = gam_linalg::utils::stable_logistic((logits[i] - threshold) / temperature);
1007 } else {
1008 out[i] = 0.0;
1009 }
1010 }
1011}
1012
/// Inputs for [`fill_active_atom_logit_jvp`]: everything needed to write one
/// active atom's compact logit-JVP row.
pub(crate) struct ActiveAtomLogitJvp<'a> {
    /// Assignment mode; selects which per-mode derivative formula is applied.
    pub(crate) mode: AssignmentMode,
    /// Atom index; used to look up this atom's entry in `ibp_prior`.
    pub(crate) k: usize,
    /// The atom's logit for this row; read by the JumpReLU branch for the
    /// threshold test and the activation.
    pub(crate) logit_k: f64,
    /// The atom's assignment value for this row; read by the Softmax and
    /// IBP-MAP branches.
    pub(crate) a_k: f64,
    /// The atom's decoded output direction, indexed by output column.
    pub(crate) decoded_k: ArrayView1<'a, f64>,
    /// The row's reconstruction; its length sets the number of output columns
    /// written, and its values enter the Softmax cross term.
    pub(crate) fitted: ArrayView1<'a, f64>,
    /// Precomputed IBP prior weights; must be `Some` for `IBPMap` (the fill
    /// function panics otherwise).
    pub(crate) ibp_prior: Option<&'a [f64]>,
    /// Row of the compact Jacobian that receives this atom's JVP.
    pub(crate) compact_index: usize,
    /// #1026 — when `true`, atom `k` is the ungated background tier: its gate is
    /// the constant `1`, so its logit-JVP `da_k/dl_k` is identically zero (the
    /// compact row is left untouched / zero).
    pub(crate) ungated: bool,
}
1027
1028/// Fill the single compact logit-JVP row for active atom `k`, using the
1029/// per-mode assignment sensitivity `da_k/dl_k` contracted into the decoded /
1030/// fitted-corrected output direction. This is the active-set analogue of
1031/// [`fill_assignment_logit_jvp_rows`]: it reproduces that function's diagonal
1032/// logit row exactly for the atom `k`, but writes into a compact position of a
1033/// heterogeneous-`q` row block instead of the dense full-`K` Jacobian. `fitted`
1034/// is the row's *active-set* reconstruction so the softmax cross term
1035/// `(decoded_k − fitted)` is consistent with the curvature the compact block
1036/// carries.
1037pub(crate) fn fill_active_atom_logit_jvp(
1038 input: ActiveAtomLogitJvp<'_>,
1039 jac_compact: &mut Array2<f64>,
1040) {
1041 let ActiveAtomLogitJvp {
1042 mode,
1043 k,
1044 logit_k,
1045 a_k,
1046 decoded_k,
1047 fitted,
1048 ibp_prior,
1049 compact_index,
1050 ungated,
1051 } = input;
1052 let p = fitted.len();
1053 // #1026 — an ungated atom's gate is constant, so its logit-JVP is zero; leave
1054 // its compact row untouched (the buffer row is pre-zeroed by the caller).
1055 if ungated {
1056 return;
1057 }
1058 match mode {
1059 AssignmentMode::Softmax { temperature, .. } => {
1060 // da_k/dl_k contracted: a_k (decoded_k − fitted) / τ.
1061 let inv_tau = 1.0 / temperature;
1062 for out_col in 0..p {
1063 jac_compact[[compact_index, out_col]] =
1064 a_k * (decoded_k[out_col] - fitted[out_col]) * inv_tau;
1065 }
1066 }
1067 AssignmentMode::IBPMap { temperature, .. } => {
1068 // z_k = σ(l_k/τ)·π_k ⇒ dz_k/dl_k = a_k(π_k − a_k)/(π_k τ) · π_k form
1069 // (matches `fill_assignment_logit_jvp_rows`).
1070 let inv_tau = 1.0 / temperature;
1071 let prior =
1072 ibp_prior.expect("fill_active_atom_logit_jvp: IBPMap requires precomputed prior");
1073 let pi_k = prior[k];
1074 let sig = if pi_k > 0.0 { a_k / pi_k } else { 0.0 };
1075 let dz = sig * (1.0 - sig) * inv_tau * pi_k;
1076 for out_col in 0..p {
1077 jac_compact[[compact_index, out_col]] = dz * decoded_k[out_col];
1078 }
1079 }
1080 AssignmentMode::JumpReLU {
1081 temperature,
1082 threshold,
1083 } => {
1084 // The data-fit Jacobian follows the hard forward gate. Below the
1085 // threshold the reconstruction contribution is exactly zero, so the
1086 // data-fit logit derivative must also be zero. Band-only atoms stay
1087 // in the compact row for prior terms, not phantom reconstruction
1088 // slope.
1089 if logit_k <= threshold {
1090 return;
1091 }
1092 let inv_tau = 1.0 / temperature;
1093 let activation = gam_linalg::utils::stable_logistic((logit_k - threshold) * inv_tau);
1094 let da = activation * (1.0 - activation) * inv_tau;
1095 for out_col in 0..p {
1096 jac_compact[[compact_index, out_col]] = da * decoded_k[out_col];
1097 }
1098 }
1099 }
1100}
1101
1102pub(crate) fn fill_assignment_logit_jvp_rows(
1103 mode: AssignmentMode,
1104 logits: ArrayView1<'_, f64>,
1105 assignments: ArrayView1<'_, f64>,
1106 decoded: ArrayView2<'_, f64>,
1107 fitted: ArrayView1<'_, f64>,
1108 ibp_prior: Option<&[f64]>,
1109 // #1026 — per-atom ungated flags (length `K`). An ungated atom's gate is
1110 // constant, so its logit-JVP row is identically zero (skipped below). Empty
1111 // ⇒ no atom is ungated (the historical path, bit-identical).
1112 ungated: &[bool],
1113 local_jac: &mut Array2<f64>,
1114) {
1115 let is_ungated = |k: usize| ungated.get(k).copied().unwrap_or(false);
1116 match mode {
1117 AssignmentMode::Softmax { temperature, .. } => {
1118 if assignments.len() == 1 {
1119 return;
1120 }
1121 // da_k/dl_j = a_k (1[k=j] - a_j) / tau, contracted against
1122 // the assignment-weighted fitted row. The dense row layout uses
1123 // the reference-logit chart, so only columns `0..K-1` are free;
1124 // the final reference logit is fixed at zero and has no row.
1125 let inv_tau = 1.0 / temperature;
1126 for logit_col in 0..assignments.len() - 1 {
1127 if is_ungated(logit_col) {
1128 continue;
1129 }
1130 for out_col in 0..fitted.len() {
1131 local_jac[[logit_col, out_col]] = assignments[logit_col]
1132 * (decoded[[logit_col, out_col]] - fitted[out_col])
1133 * inv_tau;
1134 }
1135 }
1136 }
1137 AssignmentMode::IBPMap { temperature, .. } => {
1138 // Truncated-IBP concrete relaxation: z_k = σ(l_k/τ) · π_k where
1139 // π_k is the stick-breaking prior. Thus
1140 // dz_k/dl_k = σ(l/τ)(1-σ(l/τ))/τ · π_k = a_k(π_k - a_k)/(π_k τ).
1141 let inv_tau = 1.0 / temperature;
1142 let prior = ibp_prior
1143 .expect("fill_assignment_logit_jvp_rows: IBPMap requires precomputed prior");
1144 for logit_col in 0..assignments.len() {
1145 if is_ungated(logit_col) {
1146 continue;
1147 }
1148 let pi_k = prior[logit_col];
1149 let a_k = assignments[logit_col];
1150 let sig = if pi_k > 0.0 { a_k / pi_k } else { 0.0 };
1151 let dz = sig * (1.0 - sig) * inv_tau * pi_k;
1152 for out_col in 0..fitted.len() {
1153 local_jac[[logit_col, out_col]] = dz * decoded[[logit_col, out_col]];
1154 }
1155 }
1156 }
1157 AssignmentMode::JumpReLU {
1158 temperature,
1159 threshold,
1160 } => {
1161 // Data-fit sensitivity follows the hard forward gate: rows at or
1162 // below the threshold have zero reconstruction value and therefore
1163 // zero data-fit logit derivative. The wider machine-precision prior
1164 // support is a compact-layout/prior rule, not a data-fit STE.
1165 let inv_tau = 1.0 / temperature;
1166 for logit_col in 0..assignments.len() {
1167 if is_ungated(logit_col) || logits[logit_col] <= threshold {
1168 continue;
1169 }
1170 let activation = gam_linalg::utils::stable_logistic(
1171 (logits[logit_col] - threshold) * inv_tau,
1172 );
1173 let da = activation * (1.0 - activation) * inv_tau;
1174 for out_col in 0..fitted.len() {
1175 local_jac[[logit_col, out_col]] = da * decoded[[logit_col, out_col]];
1176 }
1177 }
1178 }
1179 }
1180}
1181
1182pub(crate) fn flat_logits(logits: ArrayView2<'_, f64>) -> Array1<f64> {
1183 let mut out = Array1::<f64>::zeros(logits.len());
1184 for row in 0..logits.nrows() {
1185 let start = row * logits.ncols();
1186 for col in 0..logits.ncols() {
1187 out[start + col] = logits[[row, col]];
1188 }
1189 }
1190 out
1191}
1192
1193pub(crate) fn assignment_prior_value(assignment: &SaeAssignment, rho: &SaeManifoldRho) -> f64 {
1194 for row in 0..assignment.n_obs() {
1195 validate_finite_logits(assignment.logits.row(row), row)
1196 .expect("assignment logits must be finite");
1197 }
1198 let target = flat_logits(assignment.logits.view());
1199 if matches!(assignment.mode, AssignmentMode::Softmax { .. }) && assignment.k_atoms() == 1 {
1200 return 0.0;
1201 }
1202 match assignment.mode {
1203 AssignmentMode::Softmax {
1204 temperature,
1205 sparsity,
1206 } => {
1207 let penalty = SoftmaxAssignmentSparsityPenalty::new(assignment.k_atoms(), temperature);
1208 let rho_view = Array1::from_vec(vec![rho.log_lambda_sparse + sparsity.ln()]);
1209 penalty.value(target.view(), rho_view.view())
1210 }
1211 AssignmentMode::IBPMap {
1212 temperature,
1213 alpha,
1214 learnable_alpha,
1215 } => {
1216 let mut penalty = IBPAssignmentPenalty::new(
1217 assignment.k_atoms(),
1218 alpha,
1219 temperature,
1220 learnable_alpha,
1221 );
1222 let rho_view = if learnable_alpha {
1223 Array1::from_vec(vec![rho.log_lambda_sparse])
1224 } else {
1225 // Keep the fixed-alpha value path on the same weighting branch as
1226 // assignment_prior_grad_hdiag; that gradient path owns the
1227 // lambda_sparse convention for IBP assignment sparsity.
1228 penalty.weight = rho.lambda_sparse();
1229 Array1::zeros(0)
1230 };
1231 penalty.value(target.view(), rho_view.view())
1232 }
1233 AssignmentMode::JumpReLU {
1234 temperature,
1235 threshold,
1236 } => {
1237 // Sparsity penalty uses the same threshold-centered surrogate and
1238 // machine-precision support as its gradient/Hessian. Data-fit
1239 // reconstruction remains hard-gated by `jumprelu_row`.
1240 let sparsity_strength = rho.lambda_sparse();
1241 let mut acc = 0.0;
1242 for &logit in target.iter() {
1243 if jumprelu_in_optimization_band(logit, threshold, temperature) {
1244 acc += gam_linalg::utils::stable_logistic((logit - threshold) / temperature);
1245 }
1246 }
1247 sparsity_strength * acc
1248 }
1249 }
1250}
1251
1252pub(crate) fn assignment_prior_log_strength_derivative(
1253 assignment: &SaeAssignment,
1254 rho: &SaeManifoldRho,
1255) -> f64 {
1256 for row in 0..assignment.n_obs() {
1257 validate_finite_logits(assignment.logits.row(row), row)
1258 .expect("assignment logits must be finite");
1259 }
1260 let target = flat_logits(assignment.logits.view());
1261 if matches!(assignment.mode, AssignmentMode::Softmax { .. }) && assignment.k_atoms() == 1 {
1262 return 0.0;
1263 }
1264 match assignment.mode {
1265 AssignmentMode::Softmax { .. } | AssignmentMode::JumpReLU { .. } => {
1266 assignment_prior_value(assignment, rho)
1267 }
1268 AssignmentMode::IBPMap {
1269 temperature,
1270 alpha,
1271 learnable_alpha,
1272 } => {
1273 let mut penalty = IBPAssignmentPenalty::new(
1274 assignment.k_atoms(),
1275 alpha,
1276 temperature,
1277 learnable_alpha,
1278 );
1279 if learnable_alpha {
1280 let rho_view = Array1::from_vec(vec![rho.log_lambda_sparse]);
1281 penalty.grad_rho(target.view(), rho_view.view())[0]
1282 } else {
1283 penalty.weight = rho.lambda_sparse();
1284 penalty.value(target.view(), Array1::<f64>::zeros(0).view())
1285 }
1286 }
1287 }
1288}
1289
1290pub(crate) fn assignment_prior_log_strength_hdiag(
1291 assignment: &SaeAssignment,
1292 rho: &SaeManifoldRho,
1293) -> Result<Array1<f64>, String> {
1294 for row in 0..assignment.n_obs() {
1295 validate_finite_logits(assignment.logits.row(row), row)?;
1296 }
1297 let target = flat_logits(assignment.logits.view());
1298 if matches!(assignment.mode, AssignmentMode::Softmax { .. }) && assignment.k_atoms() == 1 {
1299 return Ok(Array1::<f64>::zeros(target.len()));
1300 }
1301 match assignment.mode {
1302 AssignmentMode::Softmax {
1303 temperature,
1304 sparsity,
1305 } => {
1306 let penalty = SoftmaxAssignmentSparsityPenalty::new(assignment.k_atoms(), temperature);
1307 let rho_view = Array1::from_vec(vec![rho.log_lambda_sparse + sparsity.ln()]);
1308 penalty
1309 .hessian_diag(target.view(), rho_view.view())
1310 .ok_or_else(|| {
1311 "softmax assignment log-strength hessian diag unavailable".to_string()
1312 })
1313 }
1314 AssignmentMode::JumpReLU {
1315 temperature,
1316 threshold,
1317 } => {
1318 let sparsity_strength = rho.lambda_sparse();
1319 let inv_tau = 1.0 / temperature;
1320 let inv_tau2 = inv_tau * inv_tau;
1321 let mut d = Array1::<f64>::zeros(target.len());
1322 for idx in 0..target.len() {
1323 let logit = target[idx];
1324 if !jumprelu_in_optimization_band(logit, threshold, temperature) {
1325 continue;
1326 }
1327 let activation =
1328 gam_linalg::utils::stable_logistic((logit - threshold) * inv_tau);
1329 let slope = activation * (1.0 - activation);
1330 d[idx] = sparsity_strength * slope * (1.0 - 2.0 * activation) * inv_tau2;
1331 }
1332 Ok(d)
1333 }
1334 AssignmentMode::IBPMap {
1335 temperature,
1336 alpha,
1337 learnable_alpha,
1338 } => {
1339 let mut penalty = IBPAssignmentPenalty::new(
1340 assignment.k_atoms(),
1341 alpha,
1342 temperature,
1343 learnable_alpha,
1344 );
1345 if learnable_alpha {
1346 let rho_view = Array1::from_vec(vec![rho.log_lambda_sparse]);
1347 Ok(penalty.hessian_diag_log_alpha_derivative(target.view(), rho_view.view()))
1348 } else {
1349 penalty.weight = rho.lambda_sparse();
1350 penalty
1351 .hessian_diag(target.view(), Array1::<f64>::zeros(0).view())
1352 .ok_or_else(|| {
1353 "IBP assignment log-strength hessian diag unavailable".to_string()
1354 })
1355 }
1356 }
1357 }
1358}
1359
1360pub(crate) fn assignment_prior_log_strength_target_mixed(
1361 assignment: &SaeAssignment,
1362 rho: &SaeManifoldRho,
1363) -> Result<Array1<f64>, String> {
1364 for row in 0..assignment.n_obs() {
1365 validate_finite_logits(assignment.logits.row(row), row)?;
1366 }
1367 let target = flat_logits(assignment.logits.view());
1368 if matches!(assignment.mode, AssignmentMode::Softmax { .. }) && assignment.k_atoms() == 1 {
1369 return Ok(Array1::<f64>::zeros(target.len()));
1370 }
1371 match assignment.mode {
1372 AssignmentMode::IBPMap {
1373 temperature,
1374 alpha,
1375 learnable_alpha: true,
1376 } => {
1377 let penalty = IBPAssignmentPenalty::new(assignment.k_atoms(), alpha, temperature, true);
1378 let rho_view = Array1::from_vec(vec![rho.log_lambda_sparse]);
1379 Ok(penalty.log_alpha_target_mixed_derivative(target.view(), rho_view.view()))
1380 }
1381 _ => Ok(assignment_prior_grad_hdiag(assignment, rho)?.0),
1382 }
1383}
1384
1385pub(crate) fn assignment_prior_grad_hdiag(
1386 assignment: &SaeAssignment,
1387 rho: &SaeManifoldRho,
1388) -> Result<(Array1<f64>, Array1<f64>), String> {
1389 for row in 0..assignment.n_obs() {
1390 validate_finite_logits(assignment.logits.row(row), row)?;
1391 }
1392 let target = flat_logits(assignment.logits.view());
1393 let mut grad = Array1::<f64>::zeros(target.len());
1394 let mut diag = Array1::<f64>::zeros(target.len());
1395 if matches!(assignment.mode, AssignmentMode::Softmax { .. }) && assignment.k_atoms() == 1 {
1396 return Ok((grad, diag));
1397 }
1398 let (sparsity_grad, sparsity_diag) = match assignment.mode {
1399 AssignmentMode::Softmax {
1400 temperature,
1401 sparsity,
1402 } => {
1403 let penalty = SoftmaxAssignmentSparsityPenalty::new(assignment.k_atoms(), temperature);
1404 let rho_view = Array1::from_vec(vec![rho.log_lambda_sparse + sparsity.ln()]);
1405 let g = penalty.grad_target(target.view(), rho_view.view());
1406 let d = penalty
1407 .hessian_diag(target.view(), rho_view.view())
1408 .ok_or_else(|| "softmax assignment hessian diag unavailable".to_string())?;
1409 (g, d)
1410 }
1411 AssignmentMode::IBPMap {
1412 temperature,
1413 alpha,
1414 learnable_alpha,
1415 } => {
1416 // Scale the IBP assignment-sparsity prior by `lambda_sparse`, exactly
1417 // like the Softmax and JumpReLU branches do (Softmax folds it into the
1418 // penalty's rho coordinate, JumpReLU multiplies `sparsity_strength`).
1419 // Previously the IBP penalty used its hardcoded `weight = 1.0` and the
1420 // `rho.log_lambda_sparse` coordinate never reached it (the rho_view was
1421 // empty for the common `learnable_alpha = false` config), so the prior
1422 // ran at full strength with no way to dial it down — and its
1423 // Beta-Bernoulli BCE energy `−mass·ln π_k − (n−mass)·ln(1−π_k)` toward
1424 // the self-referential empirical active fraction `π_k` has its global
1425 // minimum at the all-off gate, so at full weight it over-shrank the
1426 // assignment off both atoms even with a truth-seeded decoder (#853).
1427 // Routing `lambda_sparse` into the penalty weight makes the prior a
1428 // genuine, user-controllable lever balanced against the data fit.
1429 let mut penalty = IBPAssignmentPenalty::new(
1430 assignment.k_atoms(),
1431 alpha,
1432 temperature,
1433 learnable_alpha,
1434 );
1435 // When `alpha` is learnable, `log_lambda_sparse` already modulates
1436 // it through `resolved_alpha(rho)`, so the weight stays 1.0 to avoid
1437 // double-counting that coordinate. Only when `alpha` is fixed (so the
1438 // sparse coordinate would otherwise be ignored entirely) does
1439 // `lambda_sparse` become the prior's weight lever.
1440 let rho_view = if learnable_alpha {
1441 Array1::from_vec(vec![rho.log_lambda_sparse])
1442 } else {
1443 penalty.weight = rho.lambda_sparse();
1444 Array1::zeros(0)
1445 };
1446 let g = penalty.grad_target(target.view(), rho_view.view());
1447 let d = penalty
1448 .hessian_diag(target.view(), rho_view.view())
1449 .ok_or_else(|| "IBP assignment hessian diag unavailable".to_string())?;
1450 (g, d)
1451 }
1452 AssignmentMode::JumpReLU {
1453 temperature,
1454 threshold,
1455 } => {
1456 // Gradient and exact diagonal Hessian of the sparsity value's
1457 // threshold-centered surrogate σ((l−θ)/τ), using the same
1458 // machine-precision support as the value path. Data-fit JVP support
1459 // is narrower and follows the hard forward gate.
1460 let sparsity_strength = rho.lambda_sparse();
1461 let inv_tau = 1.0 / temperature;
1462 let inv_tau2 = inv_tau * inv_tau;
1463 let mut g = Array1::<f64>::zeros(target.len());
1464 let mut d = Array1::<f64>::zeros(target.len());
1465 for idx in 0..target.len() {
1466 let logit = target[idx];
1467 if !jumprelu_in_optimization_band(logit, threshold, temperature) {
1468 continue;
1469 }
1470 let activation =
1471 gam_linalg::utils::stable_logistic((logit - threshold) * inv_tau);
1472 let slope = activation * (1.0 - activation);
1473 g[idx] = sparsity_strength * slope * inv_tau;
1474 d[idx] = sparsity_strength * slope * (1.0 - 2.0 * activation) * inv_tau2;
1475 }
1476 (g, d)
1477 }
1478 };
1479 grad += &sparsity_grad;
1480 diag += &sparsity_diag;
1481 // #1026/#1033 — a FIXED logit (an ungated atom's, or every atom's under
1482 // frozen routing) is not a free parameter, so it carries NO sparsity-prior
1483 // gradient or curvature. Zero its flat columns (`flat_logits` is row-major
1484 // `row*K + atom`) so the assembled `gt` and `htt` logit slots stay zero —
1485 // matching the zero logit-JVP. The column-separable IBP / JumpReLU priors are
1486 // per-atom, so zeroing one atom's columns leaves the others' prior intact;
1487 // under frozen routing ALL atoms' logit columns are zeroed (the whole routing
1488 // is a fixed predicted function, not optimized).
1489 if assignment.has_ungated() || assignment.routing_is_frozen() {
1490 let k = assignment.k_atoms();
1491 for idx in 0..grad.len() {
1492 if assignment.logit_is_fixed(idx % k) {
1493 grad[idx] = 0.0;
1494 diag[idx] = 0.0;
1495 }
1496 }
1497 }
1498 Ok((grad, diag))
1499}
1500
1501/// Build the exact IBP `hessian_diag` logit third-derivative channels (#1006)
1502/// for the SAE log-det adjoint Γ, using the SAME penalty configuration —
1503/// `alpha`/`tau`/`learnable_alpha` and the `lambda_sparse` weight convention —
1504/// that [`assignment_prior_grad_hdiag`] assembles into `htt`. Returns `None`
1505/// for non-IBP assignment modes (no cross-row empirical-π coupling to correct).
1506pub(crate) fn ibp_assignment_third_channels(
1507 assignment: &SaeAssignment,
1508 rho: &SaeManifoldRho,
1509) -> Result<Option<IbpHessianDiagThirdChannels>, String> {
1510 let AssignmentMode::IBPMap {
1511 temperature,
1512 alpha,
1513 learnable_alpha,
1514 } = assignment.mode
1515 else {
1516 return Ok(None);
1517 };
1518 for row in 0..assignment.n_obs() {
1519 validate_finite_logits(assignment.logits.row(row), row)?;
1520 }
1521 let target = flat_logits(assignment.logits.view());
1522 let mut penalty =
1523 IBPAssignmentPenalty::new(assignment.k_atoms(), alpha, temperature, learnable_alpha);
1524 // Mirror assignment_prior_grad_hdiag exactly: when alpha is learnable the
1525 // sparse coordinate already modulates it through resolved_alpha(rho), so the
1526 // weight stays 1.0; otherwise lambda_sparse becomes the prior's weight lever.
1527 let rho_view = if learnable_alpha {
1528 Array1::from_vec(vec![rho.log_lambda_sparse])
1529 } else {
1530 penalty.weight = rho.lambda_sparse();
1531 Array1::zeros(0)
1532 };
1533 let mut channels = penalty.hessian_diag_logit_third_channels(target.view(), rho_view.view());
1534 // #1026/#1033 — zero the log-det third-derivative channels of FIXED-logit
1535 // atoms (ungated, or all atoms under frozen routing) so the #1006 θ-adjoint
1536 // differentiates the SAME (fixed-logit-zeroed) `htt` that
1537 // `assignment_prior_grad_hdiag` assembled. `k_max` columns, row-major `N·K`
1538 // for the per-(row,atom) arrays and length-`K` for the per-column ones.
1539 if assignment.has_ungated() || assignment.routing_is_frozen() {
1540 let k = channels.k_max;
1541 for idx in 0..channels.z_jac.len() {
1542 if assignment.logit_is_fixed(idx % k) {
1543 channels.z_jac[idx] = 0.0;
1544 channels.local_logit_third[idx] = 0.0;
1545 channels.m_channel[idx] = 0.0;
1546 channels.logit_curvature[idx] = 0.0;
1547 }
1548 }
1549 for atom in 0..k {
1550 if assignment.logit_is_fixed(atom) {
1551 channels.cross_row_d[atom] = 0.0;
1552 channels.cross_row_dd[atom] = 0.0;
1553 }
1554 }
1555 }
1556 Ok(Some(channels))
1557}
1558
1559/// #1026 hybrid curved + linear-tail adjudication for one SAE atom slot.
1560///
1561/// A hybrid dictionary lets each atom slot be either a CURVED atom (its fitted
1562/// `latent_dim ≥ 1` manifold chart, whose decoded image may turn) or its LINEAR
1563/// special case (the euclidean-d=1-linear atom — one straight decoder direction,
1564/// `γ(t) = t·b`, zero turning). The two are nested: the linear atom is exactly
1565/// the curved family restricted to its straight sub-model, so a hybrid slot
1566/// cannot lose to pure-linear at matched actives — it strictly generalizes it.
1567///
1568/// This is the single call the SAE fitter makes per atom to choose the split by
1569/// EVIDENCE rather than fiat. It packages the atom's two already-fitted
1570/// candidates — each scored on the COMMON rank-aware Laplace scale (`−V = NLE`,
1571/// lower wins, identical to the union/mixture rungs) on the same rows — and
1572/// routes them through [`select_hybrid_atom`]. The curved candidate's fitted
1573/// turning `Θ` (from
1574/// [`crate::chart_canonicalization::d1_atom_fitted_turning`]) enters
1575/// as the decision feature: a `Θ → 0` atom yields to the cheaper linear tail by
1576/// construction (the dominance floor — a curved atom buys nothing on a straight
1577/// feature), a high-`Θ` atom takes the curved parameterization when its
1578/// curvature lowers the NLE by more than its extra-parameter price (the `Θ/√ε`
1579/// crossover).
1580///
1581/// `manifold` is the atom's fitted chart manifold; a non-curveable (already
1582/// Euclidean-flat) chart can only present the linear candidate, which this
1583/// helper enforces by ignoring any curved candidate offered for a flat chart —
1584/// a flat chart has no curvature to price, so the linear special case is its
1585/// only honest parameterization. Curveable charts present both candidates.
1586///
1587/// # Wiring into the fitter (the one call into `sae_manifold.rs`)
1588///
1589/// The post-fit pass in `sae_manifold.rs` already computes each d=1 atom's
1590/// fitted turning `Θ` (the read-only EV-vs-Θ diagnostic). To make the split
1591/// load-bearing, that pass supplies, per atom, the curved-candidate NLE +
1592/// parameter count + `Θ` and the linear-candidate NLE + parameter count (both
1593/// fitted on the atom's rows), and calls this helper; the returned
1594/// [`HybridAtomChoice`] tells the fitter which parameterization to keep for that
1595/// slot. The fitting of the two candidates lives in `sae_manifold.rs` (the
1596/// manifold-chart fitter); the SELECTION/scoring lives here.
1597pub fn select_hybrid_atom_parameterization(
1598 manifold: &LatentManifold,
1599 curved: Option<HybridAtomCandidate>,
1600 linear: HybridAtomCandidate,
1601) -> HybridAtomChoice {
1602 // A flat (Euclidean) chart has no curvature to price: its only honest
1603 // parameterization is the linear special case, so any curved candidate
1604 // offered for it is dropped before the evidence comparison. Curveable charts
1605 // (Circle / Sphere / Torus / curved products) present both candidates.
1606 let curved = if manifold.is_euclidean() {
1607 None
1608 } else {
1609 curved
1610 };
1611 let candidates: Vec<HybridAtomCandidate> = match curved {
1612 Some(c) => vec![linear, c],
1613 None => vec![linear],
1614 };
1615 // `candidates` is never empty (it always contains the linear candidate), so
1616 // the selector always returns a choice.
1617 select_hybrid_atom(&candidates).expect("hybrid atom slot always has the linear candidate")
1618}
1619
#[cfg(test)]
mod ibp_prior_614_tests {
    //! #614 regression pins for `ordered_geometric_shrinkage_prior`.
    //!
    //! `ibp_stick_breaking_prior` used to compute `π_k = (α/(α+1))^k` with
    //! `π_0 = 1`, i.e. an UNSHRUNK first atom — the prior mean of no stick at
    //! all, which broke α's role as an IBP concentration parameter. The
    //! consistent truncated-IBP stick-breaking prior mean is
    //! `π_k = (α/(α+1))^{k+1}`, the expectation of the product of (k+1) i.i.d.
    //! Beta(α,1) stick means, so EVERY atom (including the first) carries one
    //! stick of shrinkage. These tests pin that contract so the regression
    //! cannot silently return.
    use super::*;

    /// Single-stick prior mean `α/(α+1)`.
    fn ratio(alpha: f64) -> f64 {
        let denom = alpha + 1.0;
        alpha / denom
    }

    #[test]
    fn first_atom_is_shrunk_not_unity() {
        // The #614 defect: π_0 must equal the single-stick mean α/(α+1), NOT 1.0.
        let alphas = [0.1_f64, 0.5, 1.0, 2.0, 5.0];
        for alpha in alphas.iter().copied() {
            let head = ordered_geometric_shrinkage_prior(8, alpha)[0];
            let r = ratio(alpha);
            let gap = (head - r).abs();
            assert!(
                gap < 1e-12,
                "π_0 must be the single-stick mean α/(α+1)={r} (was the unshrunk 1.0 in #614); got {}",
                head
            );
            assert!(
                head < 1.0,
                "first atom must be shrunk (π_0<1) for alpha={alpha}; got {}",
                head
            );
        }
    }

    #[test]
    fn prior_is_consistent_geometric_product_mean() {
        // π_k = (α/(α+1))^{k+1} exactly, and every successive ratio equals α/(α+1).
        let k = 12;
        for alpha in [0.3_f64, 1.0, 4.0].iter().copied() {
            let prior = ordered_geometric_shrinkage_prior(k, alpha);
            let r = ratio(alpha);

            // Closed form, checked with a tolerance relative to max(expected, 1).
            for j in 0..k {
                let expected = r.powi((j + 1) as i32);
                let tolerance = 1e-12 * expected.max(1.0);
                let got = prior[j];
                assert!(
                    (got - expected).abs() < tolerance,
                    "alpha={alpha} π_{j}: expected {expected}, got {}",
                    got
                );
            }

            // Strictly decreasing (ordered shrinkage), no plateau at the head.
            for j in 1..k {
                let (previous, current) = (prior[j - 1], prior[j]);
                assert!(
                    current < previous,
                    "alpha={alpha}: prior must strictly decrease at index {j}"
                );
            }
        }
    }

    #[test]
    fn alpha_behaves_as_concentration() {
        // Larger α => heavier mass / slower decay: π_0 increases toward 1 and the
        // tail (e.g. π_4) carries more mass. This is the IBP-concentration role
        // the #614 fix restored.
        let small_alpha = ordered_geometric_shrinkage_prior(8, 0.5);
        let large_alpha = ordered_geometric_shrinkage_prior(8, 5.0);

        let (head_small, head_large) = (small_alpha[0], large_alpha[0]);
        assert!(
            head_large > head_small,
            "larger alpha must raise π_0 (concentration): {} vs {}",
            head_large,
            head_small
        );

        let (tail_small, tail_large) = (small_alpha[4], large_alpha[4]);
        assert!(
            tail_large > tail_small,
            "larger alpha must put more mass in the tail: {} vs {}",
            tail_large,
            tail_small
        );
    }
}
1700
#[cfg(test)]
mod hybrid_split_tests {
    //! Checks on `select_hybrid_atom_parameterization`: a flat chart only ever
    //! yields the linear candidate, while a curveable chart lets the evidence
    //! comparison decide (and falls back to linear when no curved candidate is
    //! offered).
    use super::*;
    use gam_solve::evidence::HybridAtomParam;
    use std::f64::consts::PI;

    /// A curveable chart: a circle whose period is 2π.
    fn circle_chart() -> LatentManifold {
        LatentManifold::Circle { period: 2.0 * PI }
    }

    #[test]
    fn flat_chart_drops_curved_candidate_and_keeps_linear() {
        // A Euclidean chart has no curvature: even if a curved candidate with a
        // lower NLE is offered, the helper drops it (a flat chart cannot honestly
        // present a curved parameterization).
        let straight = HybridAtomCandidate::linear(100.0, 2);
        let bent = HybridAtomCandidate::curved(1, 1.0, 5, Some(2.0));
        let flat_chart = LatentManifold::Euclidean;

        let picked = select_hybrid_atom_parameterization(&flat_chart, Some(bent), straight);

        assert!(picked.param.is_linear());
    }

    #[test]
    fn curveable_chart_selects_curved_when_turning_pays() {
        // A Circle chart presents both candidates; a turning feature whose curved
        // fit beats the linear secant on evidence selects curved.
        let straight = HybridAtomCandidate::linear(100.0, 2);
        let bent = HybridAtomCandidate::curved(1, 70.0, 5, Some(2.0 * PI));
        let chart = circle_chart();

        let picked = select_hybrid_atom_parameterization(&chart, Some(bent), straight);

        assert_eq!(picked.param, HybridAtomParam::Curved { latent_dim: 1 });
    }

    #[test]
    fn curveable_chart_falls_back_to_linear_when_no_curved_candidate() {
        // With nothing curved on offer the linear candidate is the whole slate,
        // so it is returned with its own parameter count.
        let straight = HybridAtomCandidate::linear(33.0, 2);
        let chart = circle_chart();

        let picked = select_hybrid_atom_parameterization(&chart, None, straight);

        assert!(picked.param.is_linear());
        assert_eq!(picked.num_parameters, 2);
    }
}
1748
#[cfg(test)]
mod frozen_routing_1033_tests {
    //! #1033 — invariants of the FROZEN (amortized) routing mechanism.
    //!
    //! Once installed, the per-row gate is a ρ-invariant function of the FROZEN
    //! predicted logits and is DECOUPLED from any later update to the free
    //! `self.logits` (the inner-fit logit drift the outer ρ-search would
    //! otherwise re-incur every eval). These are deterministic mechanism
    //! invariants — no inner fit — so they pin the load-bearing freeze
    //! properties without the cluster.
    use super::*;

    /// One `(n, 1)` coordinate block per atom, with row `i` set to `0.1·i`.
    fn unit_coords(n: usize, k: usize) -> Vec<Array2<f64>> {
        let mut blocks = Vec::with_capacity(k);
        for _ in 0..k {
            blocks.push(Array2::from_shape_fn((n, 1), |(i, _)| (i as f64) * 0.1));
        }
        blocks
    }

    /// `k` length-1 zero vectors, the per-atom argument of `SaeManifoldRho::new`.
    fn zero_rho_blocks(k: usize) -> Vec<Array1<f64>> {
        vec![Array1::<f64>::zeros(1); k]
    }

    /// Gate vector of every row `0..n` of `a` evaluated at `rho`.
    fn row_gates(a: &SaeAssignment, rho: &SaeManifoldRho, n: usize) -> Vec<Array1<f64>> {
        let mut gates = Vec::with_capacity(n);
        for r in 0..n {
            gates.push(a.try_assignments_row_for_rho(r, rho).unwrap());
        }
        gates
    }

    /// IBP-MAP assignment over `n` rows and `k` atoms with deterministic logits.
    fn ibp_assignment(n: usize, k: usize) -> SaeAssignment {
        let logits = Array2::from_shape_fn((n, k), |(i, kk)| {
            0.3 + 0.05 * (i as f64) - 0.1 * (kk as f64)
        });
        let coords = unit_coords(n, k);
        // learnable_alpha = false: alpha is ρ-independent, isolating the routing.
        let mode = AssignmentMode::ibp_map(0.5, 1.0, false);
        SaeAssignment::from_blocks_with_mode(logits, coords, mode).unwrap()
    }

    #[test]
    fn frozen_routing_decouples_gates_from_logit_updates_1033() {
        let (n, k) = (6usize, 3usize);
        let mut a = ibp_assignment(n, k)
            .freeze_routing_from_current_logits()
            .unwrap();
        assert!(a.routing_is_frozen());

        // Gates BEFORE mutating the free logits.
        let rho = SaeManifoldRho::new(0.0, 0.0, zero_rho_blocks(k));
        let before = row_gates(&a, &rho, n);

        // Simulate an inner-fit logit update (what the ρ-search would otherwise do
        // every eval): perturb every free logit substantially.
        a.logits.mapv_inplace(|v| v + 5.0);
        let after = row_gates(&a, &rho, n);

        // FROZEN routing reads the snapshot, so the gates are UNCHANGED by the
        // free-logit perturbation — the routing is decoupled from inner-fit drift.
        for (r, (pre, post)) in before.iter().zip(after.iter()).enumerate() {
            for kk in 0..k {
                assert_eq!(
                    pre[kk], post[kk],
                    "row {r} atom {kk}: frozen-routing gate must be UNCHANGED by a free-logit \
                     update (decoupled from inner-fit drift); {} vs {}",
                    pre[kk], post[kk]
                );
            }
        }
    }

    #[test]
    fn frozen_routing_gates_are_rho_invariant_1033() {
        let (n, k) = (5usize, 2usize);
        let a = ibp_assignment(n, k)
            .freeze_routing_from_current_logits()
            .unwrap();

        // Two different ρ (different sparse + smooth strengths). With frozen routing
        // and learnable_alpha=false, the gate value must be identical at both ρ.
        let rho_a = SaeManifoldRho::new((1e-3_f64).ln(), (1e-2_f64).ln(), zero_rho_blocks(k));
        let rho_b = SaeManifoldRho::new((1e3_f64).ln(), (1e1_f64).ln(), zero_rho_blocks(k));

        for r in 0..n {
            let gates_at_a = a.try_assignments_row_for_rho(r, &rho_a).unwrap();
            let gates_at_b = a.try_assignments_row_for_rho(r, &rho_b).unwrap();
            for kk in 0..k {
                assert_eq!(
                    gates_at_a[kk], gates_at_b[kk],
                    "row {r} atom {kk}: frozen-routing gate must be ρ-INVARIANT (the n-independence \
                     lever); {} at ρ_a vs {} at ρ_b",
                    gates_at_a[kk], gates_at_b[kk]
                );
            }
        }
    }

    #[test]
    fn frozen_routing_fixes_all_logits_and_thaw_restores_free_path_1033() {
        let (n, k) = (4usize, 3usize);
        let mut a = ibp_assignment(n, k)
            .freeze_routing_from_current_logits()
            .unwrap();

        // Under frozen routing EVERY logit is fixed (not a free Newton coord).
        let mask = a.fixed_logit_mask();
        assert_eq!(mask.len(), k);
        let every_logit_fixed = mask.iter().all(|&f| f);
        assert!(every_logit_fixed, "frozen routing must fix ALL logits");
        for kk in 0..k {
            assert!(
                a.logit_is_fixed(kk),
                "atom {kk} logit must be fixed under frozen routing"
            );
        }

        // Thawing restores the free-logit path (no fixed logits, no ungated).
        a.thaw_routing();
        assert!(!a.routing_is_frozen());
        let any_logit_still_fixed = a.fixed_logit_mask().iter().any(|&f| f);
        assert!(
            !any_logit_still_fixed,
            "thaw must restore the free-logit path"
        );
    }

    #[test]
    fn frozen_routing_rejects_softmax_1033() {
        let (n, k) = (4usize, 3usize);
        let logits =
            Array2::from_shape_fn((n, k), |(i, kk)| 0.1 * (i as f64) - 0.05 * (kk as f64));
        let coords = unit_coords(n, k);
        let a = SaeAssignment::from_blocks_with_mode(logits, coords, AssignmentMode::softmax(1.0))
            .unwrap();

        // Softmax + frozen routing is rejected (the coupled-simplex entropy
        // majorizer would be inconsistent with a frozen, non-optimized routing).
        let outcome = a.freeze_routing_from_current_logits();
        assert!(
            outcome.is_err(),
            "frozen routing under Softmax must be rejected (simplex entropy-majorizer coupling)"
        );
    }
}
1884
#[cfg(test)]
mod fill_into_buffer_1557_tests {
    //! #1557 — equivalence of the fill-into-caller-buffer gate evaluation.
    //!
    //! [`SaeAssignment::try_assignments_row_for_rho_into`] must produce
    //! BIT-IDENTICAL output to the allocating
    //! [`SaeAssignment::try_assignments_row_for_rho`] across every assignment
    //! mode (Softmax, IBPMap, JumpReLU), the #1026 ungated case, and the K==1
    //! edge. The comparison is exact `==` on f64 — not an approximate
    //! tolerance — because the `_into` path is a pure allocation-elision
    //! refactor and any numeric drift is a regression.
    use super::*;

    /// Assignment over `n` rows and `k` atoms in the given mode.
    fn build(n: usize, k: usize, mode: AssignmentMode) -> SaeAssignment {
        // Deterministic, asymmetric logits/coords so every atom takes a distinct
        // value (no accidental ties masking an index bug).
        let logits = Array2::from_shape_fn((n, k), |(i, kk)| {
            0.37 + 0.11 * (i as f64) - 0.23 * (kk as f64)
        });
        let mut coords: Vec<Array2<f64>> = Vec::with_capacity(k);
        for _ in 0..k {
            coords.push(Array2::from_shape_fn((n, 1), |(i, _)| {
                0.1 + 0.05 * (i as f64)
            }));
        }
        SaeAssignment::from_blocks_with_mode(logits, coords, mode).unwrap()
    }

    /// Fixed ρ point shared by every comparison, with `k` length-1 zero blocks.
    fn rho(k: usize) -> SaeManifoldRho {
        let per_atom = vec![Array1::<f64>::zeros(1); k];
        SaeManifoldRho::new((1e-2_f64).ln(), (1e-1_f64).ln(), per_atom)
    }

    /// Asserts, row by row, that the `_into` variant writes exactly the values
    /// the allocating variant returns.
    fn assert_into_matches_alloc(a: &SaeAssignment) {
        let n = a.n_obs();
        let k = a.k_atoms();
        let rho_point = rho(k);
        let mut scratch = vec![f64::NAN; k];
        for row in 0..n {
            let allocated = a.try_assignments_row_for_rho(row, &rho_point).unwrap();
            // Pre-fill with NaN so a partial write (e.g. a JumpReLU below-threshold
            // entry left untouched) is caught as a mismatch, not silently passed.
            scratch.fill(f64::NAN);
            a.try_assignments_row_for_rho_into(row, &rho_point, &mut scratch)
                .unwrap();
            assert_eq!(allocated.len(), k);
            for (kk, (&from_alloc, &from_into)) in
                allocated.iter().zip(scratch.iter()).enumerate()
            {
                assert_eq!(
                    from_alloc, from_into,
                    "row {row} atom {kk}: _into must be BIT-IDENTICAL to the allocating \
                     try_assignments_row_for_rho; got {} vs {}",
                    from_alloc, from_into
                );
            }
        }
    }

    #[test]
    fn softmax_into_is_bit_identical() {
        let a = build(7, 4, AssignmentMode::softmax(0.8));
        assert_into_matches_alloc(&a);
    }

    #[test]
    fn ibp_map_into_is_bit_identical() {
        // Both learnable and fixed alpha exercise the resolved-alpha branch.
        let fixed_alpha = build(7, 5, AssignmentMode::ibp_map(0.6, 1.3, false));
        assert_into_matches_alloc(&fixed_alpha);
        let learnable_alpha = build(7, 5, AssignmentMode::ibp_map(0.6, 1.3, true));
        assert_into_matches_alloc(&learnable_alpha);
    }

    #[test]
    fn jumprelu_into_is_bit_identical() {
        // Threshold chosen so SOME atoms fall below it (the untouched-entry path)
        // and some clear it (the sigmoid path) — both branches are exercised.
        let a = build(7, 5, AssignmentMode::jumprelu(0.9, 0.2));
        assert_into_matches_alloc(&a);
    }

    #[test]
    fn ungated_into_is_bit_identical() {
        // #1026 ungated overwrite under a gate-style mode (IBP/JumpReLU allow it).
        let ibp_mask = vec![false, true, false, true];
        let ibp = build(6, 4, AssignmentMode::ibp_map(0.6, 1.1, false))
            .with_ungated(ibp_mask)
            .unwrap();
        assert_into_matches_alloc(&ibp);

        let jump_mask = vec![true, false, true, false];
        let jump = build(6, 4, AssignmentMode::jumprelu(0.9, 0.15))
            .with_ungated(jump_mask)
            .unwrap();
        assert_into_matches_alloc(&jump);
    }

    #[test]
    fn k_equals_one_into_is_bit_identical() {
        // Softmax K==1 hits the fixed-unit early return; IBP/JumpReLU K==1 keep a
        // free per-atom gate and fall through to the real row functions.
        let check_single_atom =
            |mode: AssignmentMode| assert_into_matches_alloc(&build(5, 1, mode));
        check_single_atom(AssignmentMode::softmax(1.0));
        check_single_atom(AssignmentMode::ibp_map(0.7, 1.0, false));
        check_single_atom(AssignmentMode::jumprelu(0.8, 0.1));
    }
}