Skip to main content

gam_sae/manifold/
construction.rs

1use super::*;
2use gam_math::jet_scalar::JetScalar;
3
4// [#780] Softmax-entropy Gershgorin majorizer leaf helpers live in a sibling
5// cohesive module, inlined here so they share this module scope.
6include!("softmax_entropy_majorizer.rs");
7
8/// Typed error from the SAE outer-gradient analytic assembly path (#1436).
9///
10/// The `eval()` analytic fallback (#1273/#1440: the plain undeflated analytic
11/// outer gradient, NOT a finite difference) must fire ONLY for the genuine
12/// conditioning/identifiability failure modes it was designed for — a
13/// near-singular-but-valid joint Hessian or a gauge-degenerate direction.
14/// Shape/indexing bugs, non-finite intermediates, and violated internal
15/// invariants are [`OuterGradientError::InternalInvariant`] and MUST propagate
16/// as hard errors so regressions surface instead of being silently masked by a
17/// degraded descent direction.
18#[derive(Clone, Debug)]
19pub(crate) enum OuterGradientError {
20    /// Expected: near-singular or ill-conditioned joint Hessian at a feasible ρ
21    /// (the genuine #1273 flat-valley case). Eligible for the FD fallback.
22    IllConditioned { reason: String },
23    /// Expected: a non-identifiable / gauge-degenerate direction at this ρ.
24    /// Eligible for the FD fallback.
25    NonIdentifiable { reason: String },
26    /// Unexpected: shape/dimension mismatch, non-finite intermediate, or a
27    /// violated internal invariant. MUST propagate — never fall back to FD.
28    InternalInvariant { reason: String },
29}
30
31impl OuterGradientError {
32    /// Whether this error class is recoverable by the #1273/#1440 analytic
33    /// plain-solver fallback (i.e. it represents a legitimate
34    /// conditioning/identifiability failure, not a programming/invariant defect).
35    pub(crate) fn is_conditioning_recoverable(&self) -> bool {
36        matches!(
37            self,
38            Self::IllConditioned { .. } | Self::NonIdentifiable { .. }
39        )
40    }
41
42    /// Construct an [`OuterGradientError::InternalInvariant`] from any error
43    /// displayable — the default classification for unexpected assembly failures
44    /// (shape mismatches, non-finite intermediates, violated invariants).
45    pub(crate) fn internal<E: std::fmt::Display>(err: E) -> Self {
46        Self::InternalInvariant {
47            reason: err.to_string(),
48        }
49    }
50
51    /// #1451 — classify a `String` error surfaced by the deflation linear-algebra
52    /// path (`apply_cached_arrow_hessian`, `DeflatedArrowSolver::from_orthonormal_gauges`)
53    /// into the correct [`OuterGradientError`] class.
54    ///
55    /// A genuine rank-deficiency / near-singularity failure (a back-solve or
56    /// Cholesky/Woodbury factor that tripped on a finite, correctly-shaped input)
57    /// is a legitimate #1273 conditioning failure and keeps `conditioning_err`
58    /// (`IllConditioned`), so it stays recoverable by the analytic fallback. A
59    /// shape/dimension mismatch or a non-finite intermediate is an
60    /// internal-invariant defect and MUST propagate ([`Self::internal`]) instead
61    /// of being masked as a plausible-but-wrong descent direction — exactly the
62    /// #1436 contract.
63    ///
64    /// The two solver helpers return `String` (not a typed error), so the
65    /// distinction is drawn from the stable markers those helpers emit for their
66    /// shape/non-finite guards (`vector shapes`, `gauge length`, `must be finite`,
67    /// `non-finite`). Everything else — including the `cholesky`/back-solve
68    /// near-singular failures — is treated as a genuine conditioning trip.
69    pub(crate) fn classify_arrow_solver_error(message: &str, conditioning_err: Self) -> Self {
70        let lower = message.to_ascii_lowercase();
71        let is_internal = lower.contains("vector shapes")
72            || lower.contains("gauge length")
73            || lower.contains("solution length")
74            || lower.contains("!= cache")
75            || lower.contains("must be finite")
76            || lower.contains("non-finite")
77            || lower.contains("not finite")
78            || lower.contains("nan")
79            || lower.contains("inf");
80        if is_internal {
81            Self::internal(message)
82        } else {
83            conditioning_err
84        }
85    }
86
87    /// The exact gate the gradient lane (`SaeManifoldOuterObjective::eval`) uses
88    /// to decide whether to descend with the #1273/#1440 analytic plain-solver
89    /// fallback instead of propagating the error as a hard failure.
90    ///
91    /// The fallback is admissible ONLY when BOTH hold:
92    /// * the REML cost at this rho is finite (a genuinely feasible point -- the
93    ///   plain analytic solver supplies a descent direction for a value the
94    ///   analytic path already produced), and
95    /// * the error is a legitimate conditioning/identifiability failure
96    ///   ([`Self::is_conditioning_recoverable`]) -- the genuine #1273 flat-valley
97    ///   case.
98    ///
99    /// A non-finite cost or an [`OuterGradientError::InternalInvariant`] must
100    /// propagate: masking a shape/indexing bug, a non-finite intermediate, or a
101    /// violated invariant behind a plausible-but-wrong step is exactly the
102    /// regression #1436 closes. Centralising the decision here (rather than
103    /// inlining the boolean at the call site) makes the `cost x error-class`
104    /// contract a single, directly unit-testable predicate.
105    pub(crate) fn admits_plain_solver_fallback(&self, cost: f64) -> bool {
106        cost.is_finite() && self.is_conditioning_recoverable()
107    }
108}
109
110impl std::fmt::Display for OuterGradientError {
111    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
112        match self {
113            Self::IllConditioned { reason } => write!(f, "ill-conditioned: {reason}"),
114            Self::NonIdentifiable { reason } => write!(f, "non-identifiable: {reason}"),
115            Self::InternalInvariant { reason } => write!(f, "internal invariant: {reason}"),
116        }
117    }
118}
119
120impl From<OuterGradientError> for String {
121    fn from(e: OuterGradientError) -> String {
122        e.to_string()
123    }
124}
125
126/// Active-set layout override for [`SaeManifoldTerm::assemble_arrow_schur_inner`].
127///
128/// `None` is the production path: the layout is derived from the assignment mode
129/// and `sparse_active_plan`. `Some(layout_opt)` pins a specific layout — dense
130/// (`Some(None)`) or a chosen compact `SaeRowLayout` (`Some(Some(..))`) — so the
131/// compact-vs-dense Riemannian-geometry equality regression can drive both code
132/// paths on identical data without depending on the host/device memory budget
133/// that gates the compact path in production.
134pub(crate) type ForcedRowLayout = Option<Option<SaeRowLayout>>;
135
136/// #1154 — base co-training weight for the amortized-encoder reconstruction
137/// consistency penalty, as a fraction of the REML criterion magnitude. The
138/// effective weight is `COTRAIN_RECON_WEIGHT · max(|REML|, 1)`, so the penalty
139/// is a bounded, scale-free share of the objective and needs no caller knob.
140pub(crate) const COTRAIN_RECON_WEIGHT: f64 = 0.1;
141
142/// #1154 — base co-training weight for the encoder's certifiable-coverage
143/// penalty (the fraction of (row, atom) encodes the Kantorovich certificate
144/// rejected). Scaled like [`COTRAIN_RECON_WEIGHT`].
145pub(crate) const COTRAIN_CERT_WEIGHT: f64 = 0.05;
146
147/// #1154 — amortized-encoder consistency of a fitted dictionary against its own
148/// fit-time target. The co-training signal of the joint amortized-encoder +
149/// REML loop: how faithfully (and how certifiably) the cheap one-mat-vec
150/// encoder inverts the dictionary the inner solve converged to.
151#[derive(Debug, Clone, Copy)]
152pub struct AmortizedEncoderConsistency {
153    /// Mean per-element squared gap between the amortized reconstruction and the
154    /// exact fitted reconstruction (`‖x̂_amortized − x̂_exact‖² / (n·p)`). Zero ⇒
155    /// the IFT predictor reproduces the encode map exactly to first order.
156    pub recon_consistency: f64,
157    /// Fraction of (row, atom) amortized encodes whose Kantorovich certificate
158    /// failed (`h > ½`) and fell back to the certified Newton encode.
159    pub uncertified_fraction: f64,
160    /// Count of uncertified (row, atom) encodes (numerator of the fraction).
161    pub n_uncertified: usize,
162    /// Total (row, atom) encodes scored (`n · K`).
163    pub n_encodes: usize,
164}
165
166impl SaeManifoldTerm {
167    #[must_use = "build error must be handled"]
168    pub fn new(atoms: Vec<SaeManifoldAtom>, assignment: SaeAssignment) -> Result<Self, String> {
169        if atoms.is_empty() {
170            return Err("SaeManifoldTerm::new: at least one atom required".into());
171        }
172        let n = atoms[0].n_obs();
173        let p = atoms[0].output_dim();
174        if assignment.n_obs() != n || assignment.k_atoms() != atoms.len() {
175            return Err(format!(
176                "SaeManifoldTerm::new: assignment shape ({}, {}) does not match atoms ({n}, {})",
177                assignment.n_obs(),
178                assignment.k_atoms(),
179                atoms.len()
180            ));
181        }
182        for (k, atom) in atoms.iter().enumerate() {
183            if atom.n_obs() != n {
184                return Err(format!(
185                    "SaeManifoldTerm::new: atom {k} has n_obs={} but atom 0 has {n}",
186                    atom.n_obs()
187                ));
188            }
189            if atom.output_dim() != p {
190                return Err(format!(
191                    "SaeManifoldTerm::new: atom {k} output_dim={} but atom 0 has {p}",
192                    atom.output_dim()
193                ));
194            }
195            if atom.latent_dim != assignment.coords[k].latent_dim() {
196                return Err(format!(
197                    "SaeManifoldTerm::new: atom {k} latent_dim={} but assignment coord has {}",
198                    atom.latent_dim,
199                    assignment.coords[k].latent_dim()
200                ));
201            }
202        }
203        Ok(Self {
204            atoms,
205            assignment,
206            temperature_schedule: None,
207            last_row_layout: None,
208            row_metric: None,
209            collapse_events: Vec::new(),
210            row_loss_weights: None,
211            last_frames_active: false,
212            assembly_chunk_override: None,
213            fixed_decoder_assembly: false,
214            softmax_active_cap: None,
215            border_hbb_workspace: Array2::<f64>::zeros((0, 0)),
216            certificate_dispersion: None,
217            curvature_walk_report: None,
218            expected_evidence_gauge_deflated_directions: None,
219            evidence_gauge_deflation_reanchors: 0,
220            evidence_gauge_deflation_last_delta_sign: 0,
221            dictionary_cocollapse_reseeds: 0,
222            best_cocollapse_incumbent: None,
223            decoder_repulsion_gate: None,
224            barrier_coactivation_gate: None,
225            hybrid_split_report: None,
226            atom_inner_fits: None,
227            oos_linear_images: None,
228        })
229    }
230
231    /// #1408/#1409 — install the optional hard per-row active-atom cap for
232    /// Softmax mode (threaded from the fit/encode `top_k`). A `Some(k)` with
233    /// `1 <= k < K` makes the Softmax assignment optimize on the COMPACT
234    /// top-`k` row layout (see [`Self::softmax_active_cap`]); `Some(k) >= K`
235    /// and `None` are both no-ops (full support). Non-softmax modes ignore it.
236    pub fn set_softmax_active_cap(&mut self, top_k: Option<usize>) {
237        self.softmax_active_cap = match top_k {
238            Some(k) if k >= 1 && k < self.k_atoms() => Some(k),
239            _ => None,
240        };
241    }
242
243    /// Install the fitted reconstruction dispersion used by
244    /// [`dictionary_incoherence_report`]. This is a pure diagnostic scalar and
245    /// does not feed any loss, criterion, penalty, or optimizer state.
246    pub fn set_certificate_dispersion(&mut self, dispersion: f64) -> Result<(), String> {
247        if !dispersion.is_finite() || dispersion <= 0.0 {
248            return Err(format!(
249                "SaeManifoldTerm::set_certificate_dispersion: dispersion must be finite and positive, got {dispersion}"
250            ));
251        }
252        self.certificate_dispersion = Some(dispersion);
253        Ok(())
254    }
255
256    /// Harvest the per-atom inner-decoder-smooth byproducts (#1097 / #1103) the
257    /// residual-gauge certificate's post-PIRLS atom inference reports consume.
258    ///
259    /// This is the post-fit harness seam: it needs the reconstruction target `Z`
260    /// (`target`) and the fitted dispersion `φ` (`dispersion`), both available
261    /// only after the joint fit converges and the engine has discarded `Z` from
262    /// the objective. For each atom `k` it captures the Gaussian-identity
263    /// penalized smooth of the atom's leading decoder output channel `j`
264    /// (largest column 2-norm of `B_k`) against its partial residual
265    /// `e_{i} = z_i − fitted_i + a_{ik} g_k(t_i)` on channel `j`, holding all
266    /// other atoms and the assignment fixed at the fitted optimum — exactly the
267    /// fixed snapshot ([`crate::identifiability::AtomInnerFit`]) the Riesz
268    /// debiasing and split-LRT smooth-structure e-value read.
269    ///
270    /// A pure read of the fitted state: it mutates only the diagnostic
271    /// `atom_inner_fits` field, never a loss / criterion / penalty / optimizer
272    /// state. Atoms with no active rows or a degenerate (rank-deficient,
273    /// non-SPD) inner Hessian get a `None` slot — the genuine prerequisite (an
274    /// SPD penalized inner Hessian on a non-empty active set) is absent there.
275    pub fn set_atom_inner_fits(
276        &mut self,
277        target: ArrayView2<'_, f64>,
278        rho: &SaeManifoldRho,
279        dispersion: f64,
280    ) -> Result<(), String> {
281        if !dispersion.is_finite() || dispersion <= 0.0 {
282            return Err(format!(
283                "SaeManifoldTerm::set_atom_inner_fits: dispersion must be finite and positive, got {dispersion}"
284            ));
285        }
286        let n = self.n_obs();
287        let p = self.output_dim();
288        let k_atoms = self.k_atoms();
289        if target.dim() != (n, p) {
290            return Err(format!(
291                "SaeManifoldTerm::set_atom_inner_fits: target {:?} != ({n}, {p})",
292                target.dim()
293            ));
294        }
295
296        // #1026 — `atom_inner_fits` is a pure diagnostic; skip its dense (N×K×P)
297        // tensor (~256 GiB at K=32768,P=32) past a cell ceiling — all-None slots,
298        // never OOM. The fit is unaffected; only this audit field is absent.
299        if n.saturating_mul(k_atoms).saturating_mul(p) > 64_000_000 {
300            self.atom_inner_fits = Some((0..k_atoms).map(|_| None).collect());
301            return Ok(());
302        }
303
304        // Settled per-row assignments and per-(row, atom) decoded outputs, so the
305        // per-atom partial residual is `e_k = (z − fitted) + a_k decoded_k`.
306        let mut assignments = Vec::with_capacity(n);
307        for row in 0..n {
308            assignments.push(self.assignment.try_assignments_row_for_rho(row, rho)?);
309        }
310        let mut decoded = Array3::<f64>::zeros((n, k_atoms, p));
311        let mut dbuf = vec![0.0_f64; p];
312        for row in 0..n {
313            for atom_idx in 0..k_atoms {
314                self.atoms[atom_idx].fill_decoded_row(row, &mut dbuf);
315                for c in 0..p {
316                    decoded[[row, atom_idx, c]] = dbuf[c];
317                }
318            }
319        }
320        let mut fitted = Array2::<f64>::zeros((n, p));
321        for row in 0..n {
322            for atom_idx in 0..k_atoms {
323                let a = assignments[row][atom_idx];
324                if a == 0.0 {
325                    continue;
326                }
327                for c in 0..p {
328                    fitted[[row, c]] += a * decoded[[row, atom_idx, c]];
329                }
330            }
331        }
332
333        let mut inner_fits: Vec<Option<crate::identifiability::AtomInnerFit>> =
334            Vec::with_capacity(k_atoms);
335        for atom_idx in 0..k_atoms {
336            inner_fits.push(self.build_atom_inner_fit(
337                atom_idx,
338                target,
339                &assignments,
340                decoded.view(),
341                fitted.view(),
342                dispersion,
343            )?);
344        }
345        self.atom_inner_fits = Some(inner_fits);
346        Ok(())
347    }
348
349    /// Build one atom's fixed inner-smooth snapshot for the post-PIRLS atom
350    /// inference reports, or `None` when the atom has no active rows or the
351    /// penalized inner Hessian is not SPD. Returns `Err` only on a structural
352    /// inconsistency (shape mismatch), never on a benign degenerate atom.
353    pub(crate) fn build_atom_inner_fit(
354        &self,
355        atom_idx: usize,
356        target: ArrayView2<'_, f64>,
357        assignments: &[Array1<f64>],
358        decoded: ArrayView3<'_, f64>,
359        fitted: ArrayView2<'_, f64>,
360        dispersion: f64,
361    ) -> Result<Option<crate::identifiability::AtomInnerFit>, String> {
362        let atom = &self.atoms[atom_idx];
363        let n = atom.n_obs();
364        let m = atom.basis_size();
365        let p = atom.output_dim();
366        if m == 0 || p == 0 {
367            return Ok(None);
368        }
369
370        // Leading decoder output channel j = argmax_j ‖B_k[:, j]‖, the channel
371        // that carries the atom's signal.
372        let mut j_lead = 0usize;
373        let mut best_norm = -1.0_f64;
374        for col in 0..p {
375            let mut norm = 0.0_f64;
376            for r in 0..m {
377                let v = atom.decoder_coefficients[[r, col]];
378                norm += v * v;
379            }
380            if norm > best_norm {
381                best_norm = norm;
382                j_lead = col;
383            }
384        }
385        let beta = atom.decoder_coefficients.column(j_lead).to_owned();
386
387        // Active rows: a_{ik} > 0.
388        let active: Vec<usize> = (0..n)
389            .filter(|&row| assignments[row][atom_idx] > 0.0)
390            .collect();
391        let n_active = active.len();
392        // The penalized smooth needs at least as many active rows as it has
393        // basis columns to give a non-degenerate data Gram; below that the inner
394        // fit's SPD prerequisite is genuinely unmet.
395        if n_active == 0 {
396            return Ok(None);
397        }
398
399        let mut design = Array2::<f64>::zeros((n_active, m));
400        let mut derivative_design = Array2::<f64>::zeros((n_active, m));
401        let mut row_scores = Array2::<f64>::zeros((n_active, m));
402        let mut weights = Array1::<f64>::zeros(n_active);
403        for (slot, &row) in active.iter().enumerate() {
404            let a_ik = assignments[row][atom_idx];
405            let w_i = a_ik * a_ik;
406            weights[slot] = w_i;
407            for col in 0..m {
408                design[[slot, col]] = atom.basis_values[[row, col]];
409                // Leading latent axis (axis 0) is the atom's primary coordinate;
410                // it is the one the average-derivative functional integrates.
411                derivative_design[[slot, col]] = atom.basis_jacobian[[row, col, 0]];
412            }
413            // Partial residual on channel j, then the inner-smooth working
414            // response z_i = e_i / a_ik so that w_i (z_i − Φᵀβ) = a_ik r_i.
415            let e_i = target[[row, j_lead]] - fitted[[row, j_lead]]
416                + a_ik * decoded[[row, atom_idx, j_lead]];
417            let mu_hat = design.row(slot).dot(&beta);
418            let z_i = e_i / a_ik;
419            let res_i = z_i - mu_hat;
420            // Gaussian-identity score s_i = −w_i res_i Φ_i / φ.
421            let scale = -w_i * res_i / dispersion;
422            for col in 0..m {
423                row_scores[[slot, col]] = scale * design[[slot, col]];
424            }
425        }
426
427        // Penalized inner Hessian H = ΦᵀWΦ + S̃_k.
428        let mut xtwx = Array2::<f64>::zeros((m, m));
429        for slot in 0..n_active {
430            let w_i = weights[slot];
431            for a in 0..m {
432                let xa = design[[slot, a]];
433                if xa == 0.0 {
434                    continue;
435                }
436                for b in 0..m {
437                    xtwx[[a, b]] += w_i * xa * design[[slot, b]];
438                }
439            }
440        }
441        let penalty = atom.smooth_penalty.clone();
442        if penalty.dim() != (m, m) {
443            return Err(format!(
444                "build_atom_inner_fit: atom {atom_idx} smooth penalty {:?} != ({m}, {m})",
445                penalty.dim()
446            ));
447        }
448        let penalized_hessian = &xtwx + &penalty;
449
450        // SPD prerequisite: the inner penalized Hessian must factor, else the
451        // atom's inner-smooth fit is degenerate and no report is producible.
452        if penalized_hessian.cholesky(Side::Lower).is_err() {
453            return Ok(None);
454        }
455
456        // Peak (largest fitted |g_k| on channel j) and mode (largest assignment
457        // mass) design rows, over the active set.
458        let mut peak_slot = 0usize;
459        let mut peak_val = -1.0_f64;
460        let mut mode_slot = 0usize;
461        let mut mode_mass = -1.0_f64;
462        for (slot, &row) in active.iter().enumerate() {
463            let g_val = design.row(slot).dot(&beta).abs();
464            if g_val > peak_val {
465                peak_val = g_val;
466                peak_slot = slot;
467            }
468            let mass = assignments[row][atom_idx];
469            if mass > mode_mass {
470                mode_mass = mass;
471                mode_slot = slot;
472            }
473        }
474        let peak_design_row = design.row(peak_slot).to_owned();
475        let mode_design_row = design.row(mode_slot).to_owned();
476
477        Ok(Some(crate::identifiability::AtomInnerFit {
478            design,
479            derivative_design,
480            beta,
481            penalty,
482            penalized_hessian,
483            row_scores,
484            weights,
485            dispersion,
486            peak_design_row,
487            mode_design_row,
488        }))
489    }
490
491    /// Profile the Gaussian reconstruction dispersion at the current seed
492    /// state. This is the scale used to make SAE penalty seeds dimensionless
493    /// before the outer rho search starts.
494    pub fn seed_reconstruction_dispersion(
495        &self,
496        target: ArrayView2<'_, f64>,
497    ) -> Result<f64, String> {
498        let fitted = self.try_fitted()?;
499        if fitted.dim() != target.dim() {
500            return Err(format!(
501                "SaeManifoldTerm::seed_reconstruction_dispersion: fitted {:?} != target {:?}",
502                fitted.dim(),
503                target.dim()
504            ));
505        }
506        let n_scalar = (target.nrows() * target.ncols()).max(1) as f64;
507        let mut rss = 0.0_f64;
508        for row in 0..target.nrows() {
509            for col in 0..target.ncols() {
510                let r = target[[row, col]] - fitted[[row, col]];
511                rss += r * r;
512            }
513        }
514        if !rss.is_finite() || rss < 0.0 {
515            return Err(format!(
516                "SaeManifoldTerm::seed_reconstruction_dispersion: non-finite seed RSS {rss}"
517            ));
518        }
519        Ok((rss / n_scalar).max(SAE_SEED_DISPERSION_FLOOR))
520    }
521
522    /// Install per-row design honesty weights (#991) — the `1/π` inclusion
523    /// corrections of a designed corpus subsample (see the field docs on
524    /// `row_loss_weights` for exactly where they enter the objective).
525    ///
526    /// Weights must be finite and strictly positive, one per term row. They
527    /// are self-normalized to mean `1.0` here (only the *relative* design
528    /// correction matters at the fitted sample size; the absolute `n/budget`
529    /// scale would silently inflate the dispersion estimate against the
530    /// sample-sized dof). Weights that are identically equal after
531    /// normalization (an exact full pass, or any uniform design) are stored
532    /// as `None`, so the unweighted path stays bit-for-bit identical rather
533    /// than "multiplied by 1.0".
534    pub fn set_row_loss_weights(&mut self, weights: Vec<f64>) -> Result<(), String> {
535        if weights.len() != self.n_obs() {
536            return Err(format!(
537                "SaeManifoldTerm::set_row_loss_weights: {} weights for {} rows",
538                weights.len(),
539                self.n_obs()
540            ));
541        }
542        if weights.is_empty() {
543            self.row_loss_weights = None;
544            return Ok(());
545        }
546        if !weights.iter().all(|w| w.is_finite() && *w > 0.0) {
547            return Err(
548                "SaeManifoldTerm::set_row_loss_weights: weights must be finite and strictly \
549                 positive"
550                    .to_string(),
551            );
552        }
553        let first = weights[0];
554        if weights.iter().all(|w| *w == first) {
555            // Uniform design (full pass, or flat measure): the normalized
556            // weight is exactly 1 everywhere — take the unweighted path.
557            self.row_loss_weights = None;
558            return Ok(());
559        }
560        let mean = weights.iter().sum::<f64>() / weights.len() as f64;
561        self.row_loss_weights = Some(weights.into_iter().map(|w| w / mean).collect());
562        Ok(())
563    }
564
565    /// The installed (mean-1 normalized) design honesty weights, `None` on the
566    /// exact unweighted path.
567    pub fn row_loss_weights(&self) -> Option<&[f64]> {
568        self.row_loss_weights.as_deref()
569    }
570
571    /// Drop any installed per-row reconstruction weights, returning the term to
572    /// the exact unweighted (full-pass) path. Used by the #997 structure-search
573    /// wiring to clear the internal estimation/evaluation mask off the adopted
574    /// term before the payload reconstruction is read over all rows.
575    pub fn clear_row_loss_weights(&mut self) {
576        self.row_loss_weights = None;
577    }
578
579    /// Huber-style OUTLIER-ROBUST per-row weights from the target activation
580    /// norms — the missing default *policy* for the existing
581    /// [`set_row_loss_weights`](Self::set_row_loss_weights) mechanism.
582    ///
583    /// The SAE fits unweighted least squares, which weights each token by its
584    /// squared residual ∝ `‖z_i‖²`. On real LLM residual streams the per-token
585    /// norm distribution is heavy-tailed (e.g. an OLMo mixed-layer slice has
586    /// `p99/median ≈ 4.7`), so a small **coherent** cluster of high-norm tokens —
587    /// typically special / attention-sink tokens, not semantic content —
588    /// dominates the objective (measured: the top 5% of tokens carry ~31% of the
589    /// total `‖z‖²` budget) and pulls dictionary atoms toward their direction.
590    /// Mean-centering does NOT address this (it is per-feature, not per-token).
591    ///
592    /// This returns Huber weights `w_i = min(1, δ·m / ‖z_i‖)` where `m` is the
593    /// MEDIAN token norm: tokens at or below `δ·m` keep full weight, higher-norm
594    /// tokens are downweighted so their objective share grows only LINEARLY (not
595    /// quadratically) with norm. `δ` is the robustness knob (`δ=1` thresholds at
596    /// the median; larger `δ` only touches the extreme tail). The result is
597    /// mean-normalized (overall objective scale preserved). OPT-IN: the caller
598    /// installs it via `set_row_loss_weights` — the default fit is unchanged.
599    pub fn robust_norm_row_weights(
600        target: ArrayView2<'_, f64>,
601        delta: f64,
602    ) -> Result<Vec<f64>, String> {
603        if !(delta.is_finite() && delta > 0.0) {
604            return Err(format!(
605                "robust_norm_row_weights: delta must be finite and positive; got {delta}"
606            ));
607        }
608        let n = target.nrows();
609        if n == 0 {
610            return Ok(Vec::new());
611        }
612        let norms: Vec<f64> = (0..n)
613            .map(|i| {
614                let r = target.row(i);
615                r.dot(&r).sqrt()
616            })
617            .collect();
618        let mut sorted = norms.clone();
619        sorted.sort_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal));
620        // Median token norm (lower-median for even n; floored off zero so an
621        // all-zero/degenerate slice yields uniform weights instead of NaN).
622        let median = sorted[n / 2].max(f64::MIN_POSITIVE);
623        let thresh = delta * median;
624        let raw: Vec<f64> = norms
625            .iter()
626            .map(|&nm| if nm <= thresh { 1.0 } else { thresh / nm })
627            .collect();
628        let mean = raw.iter().sum::<f64>() / n as f64;
629        if !(mean.is_finite() && mean > 0.0) {
630            return Err("robust_norm_row_weights: degenerate weight normalizer".to_string());
631        }
632        Ok(raw.into_iter().map(|w| w / mean).collect())
633    }
634
635    /// Install the single per-row [`RowMetric`](gam_problem::RowMetric)
636    /// that both the reconstruction likelihood and the isometry gauge read.
637    /// Installing per-row output-Fisher factors here flips the provenance to
638    /// `OutputFisher` *and* is the only way the gauge acquires a non-identity
639    /// weight, so the two inner products cannot diverge. Passing a Euclidean
640    /// metric (or never calling this) keeps the bit-identical isotropic path.
641    ///
642    /// The metric's row count and output dimension must match the term.
643    pub fn set_row_metric(
644        &mut self,
645        metric: gam_problem::RowMetric,
646    ) -> Result<(), String> {
647        if metric.n_rows() != self.n_obs() {
648            return Err(format!(
649                "SaeManifoldTerm::set_row_metric: metric has {} rows but term has {}",
650                metric.n_rows(),
651                self.n_obs()
652            ));
653        }
654        if metric.p_out() != self.output_dim() {
655            return Err(format!(
656                "SaeManifoldTerm::set_row_metric: metric output dim {} but term has {}",
657                metric.p_out(),
658                self.output_dim()
659            ));
660        }
661        self.row_metric = Some(metric);
662        Ok(())
663    }
664
665    /// The installed per-row metric, if any. `None` ⇒ Euclidean / isotropic.
666    /// Consumed by the gauge wiring (to build the matching `WeightField`) and by
667    /// Object 4 (to read the [`MetricProvenance`](gam_problem::MetricProvenance)).
668    pub fn row_metric(&self) -> Option<&gam_problem::RowMetric> {
669        self.row_metric.as_ref()
670    }
671
672    /// The per-row inner product the additive diagnostics read through: the
673    /// installed [`RowMetric`](gam_problem::RowMetric) when one
674    /// was set (output-Fisher harvest present), otherwise a freshly-built
675    /// Euclidean metric of the term's own `(n_obs, output_dim)` shape. Either way
676    /// a metric always exists, so the diagnostics are never gated by a flag — the
677    /// Euclidean fallback is the bit-identical isotropic path.
678    pub(crate) fn diagnostic_metric(
679        &self,
680    ) -> Result<gam_problem::RowMetric, String> {
681        match self.row_metric() {
682            Some(metric) => Ok(metric.clone()),
683            None => {
684                gam_problem::RowMetric::euclidean(self.n_obs(), self.output_dim())
685            }
686        }
687    }
688
689    /// Build the additive post-fit diagnostic report for this fitted term: the
690    /// two-score per-atom [`AtomTwoLensReport`](crate::inference::atom_lens::AtomTwoLensReport)
691    /// (presence / behavioral coupling / discrepancy) and the residual-gauge
692    /// [`ResidualGaugeReport`](crate::identifiability::ResidualGaugeReport)
693    /// certificate.
694    ///
695    /// Both reports are read through the same single metric
696    /// ([`Self::diagnostic_metric`]): under a Euclidean / no-harvest provenance
697    /// the lens coupling is `None` and the gauge is certified under Euclidean
698    /// provenance — never an error, never gated by a flag (magic-by-default,
699    /// mirroring the metric selection itself).
700    ///
701    /// `per_atom_ard_variances`, when supplied, is one ARD variance vector per
702    /// atom (length = `latent_dim_k`), threaded into the certificate's
703    /// equal-ARD-rotation detection. `None` (or a per-atom `None`) ⇒ no ARD prior
704    /// on that atom. `isometry_pin_active` records whether an isometry gauge
705    /// penalty was installed on the fit: `false` escalates the certificate to the
706    /// `diffeomorphism-unpinned` verdict (the honest "no metric pin" statement),
707    /// exactly as the certificate's own escalation flag specifies.
708    ///
709    /// Pure read: it never mutates the term, never touches a loss / criterion /
710    /// penalty / optimizer state.
711    pub fn fit_diagnostics_report(
712        &self,
713        per_atom_ard_variances: Option<&[Option<Array1<f64>>]>,
714        isometry_pin_active: bool,
715        reconstruction_dispersion: Option<f64>,
716        assignments_override: Option<ArrayView2<'_, f64>>,
717    ) -> Result<SaeManifoldFitDiagnostics, String> {
718        if let Some(view) = assignments_override {
719            let n = self.n_obs();
720            let k = self.k_atoms();
721            if view.dim() != (n, k) {
722                return Err(format!(
723                    "fit_diagnostics_report: assignments_override shape {:?} must be ({n}, {k})",
724                    view.dim()
725                ));
726            }
727        }
728        let metric = self.diagnostic_metric()?;
729        let atom_two_lens =
730            crate::inference::atom_lens::atom_two_lens(self, &metric, assignments_override);
731
732        let (certificate_model, streamed_curvature) =
733            self.to_residual_gauge_model(metric, per_atom_ard_variances, isometry_pin_active)?;
734        // #998: within-atom gauge families are certified on their EXACT orbits
735        // in the model's own (decoder, coordinate) parameter space — compensated
736        // symmetries are data-nulls by construction there, no lowering-error
737        // calibration involved. This now holds whether or not an isometry pin is
738        // active:
739        //   * pin INACTIVE ⇒ the orbit verdict is the data residual alone (no
740        //     penalty operator);
741        //   * pin ACTIVE ⇒ the orbit verdict adds the isometry pin's orbit-space
742        //     curvature through an [`OrbitPenaltyOperator`] lowered from the
743        //     atom's second jet `Φ''` (the pullback-metric change along the orbit
744        //     differentiates `J = Φ'B` through `t`). A model-class symmetry that
745        //     preserves the metric stays a certified freedom; a non-isometric
746        //     orbit (a basis not closed under the action) is genuinely pinned.
747        // The relative-curvature fraction `cost/stiffness²` is invariant to the
748        // pin strength μ (both faces scale with μ), so the operator is built at a
749        // canonical unit weight. An atom whose basis exposes no analytic second
750        // jet supplies no operator and falls back to the data residual — never an
751        // error. Magic-by-default either way: the choice is derived from the fit,
752        // never a flag.
753        let views = self.atom_parameter_views();
754        let ops: Vec<Option<crate::identifiability::OrbitPenaltyOperator>> =
755            if isometry_pin_active {
756                views
757                    .iter()
758                    .map(|view| {
759                        view.as_ref().and_then(|v| {
760                            crate::identifiability::isometry_orbit_penalty_operator(
761                                v, 1.0,
762                            )
763                        })
764                    })
765                    .collect()
766            } else {
767                (0..self.k_atoms()).map(|_| None).collect()
768            };
769        let residual_gauge = if isometry_pin_active {
770            // The pin-active path consumes the per-row Jacobian curvature
771            // directly (the certificate_model retains it under a pin), so route
772            // through the non-streamed exact entry point.
773            crate::identifiability::residual_gauge_exact(
774                &certificate_model,
775                &views,
776                &ops,
777            )?
778        } else {
779            let (curvature_gram, root_rows) = streamed_curvature.ok_or_else(|| {
780                "fit_diagnostics_report: missing streamed residual-gauge curvature for unpinned exact path"
781                    .to_string()
782            })?;
783            crate::identifiability::residual_gauge_exact_from_curvature_gram(
784                &certificate_model,
785                &views,
786                &ops,
787                curvature_gram,
788                root_rows,
789            )?
790        };
791
792        // #1097 / #1103: per-atom Riesz-debiased functionals and the any-n-valid
793        // split-LRT smooth-structure e-value (non-constant vs constant inner
794        // decoder), read straight off the certificate model — which carries
795        // each atom's `inner_fit` snapshot when the caller harvested it via
796        // [`Self::set_atom_inner_fits`] before this report. Atoms without a
797        // harvested inner fit degrade their inference fields to `None` inside
798        // `atom_inference_reports`, so this is always populated (one entry per
799        // atom) and never gated by a flag.
800        let atom_inference =
801            crate::identifiability::atom_inference_reports(&certificate_model);
802
803        Ok(SaeManifoldFitDiagnostics {
804            atom_two_lens,
805            residual_gauge,
806            incoherence_report: match reconstruction_dispersion.or(self.certificate_dispersion) {
807                Some(dispersion) => Some(dictionary_incoherence_report_with_dispersion(
808                    self, dispersion,
809                )?),
810                None => None,
811            },
812            atom_inference,
813        })
814    }
815
816    /// Build the trust-diagnostics producer for the Python `diagnostics` block.
817    ///
818    /// `assignments` is supplied by the payload assembly site so top-k projection,
819    /// when requested, is reflected in coverage/frequency and in the tangent
820    /// spectra. The active threshold is shared with the atom lens so all
821    /// assignment-support diagnostics agree on what "active" means.
822    pub fn trust_diagnostics_report(
823        &self,
824        assignments: ArrayView2<'_, f64>,
825    ) -> Result<SaeTrustDiagnostics, String> {
826        let n = self.n_obs();
827        let k_atoms = self.k_atoms();
828        if assignments.dim() != (n, k_atoms) {
829            return Err(format!(
830                "trust_diagnostics_report: assignments shape {:?} must be ({n}, {k_atoms})",
831                assignments.dim()
832            ));
833        }
834        if !assignments.iter().all(|v| v.is_finite()) {
835            return Err("trust_diagnostics_report: assignments must be finite".to_string());
836        }
837        let metric = self.diagnostic_metric()?;
838        let active_threshold = crate::inference::atom_lens::SAE_TRUST_ACTIVE_MASS_FLOOR;
839        let mut atoms = Vec::with_capacity(k_atoms);
840        let mut atom_trust = Vec::with_capacity(k_atoms);
841        for (atom_idx, atom) in self.atoms.iter().enumerate() {
842            let mut active_token_count = 0usize;
843            let mut activation_sum = 0.0_f64;
844            for row in 0..n {
845                let mass = assignments[[row, atom_idx]];
846                activation_sum += mass;
847                if mass > active_threshold {
848                    active_token_count += 1;
849                }
850            }
851            let coverage = if n > 0 {
852                active_token_count as f64 / n as f64
853            } else {
854                0.0
855            };
856            let activation_frequency = if n > 0 {
857                activation_sum / n as f64
858            } else {
859                0.0
860            };
861            let (sigma_min_tangent, sigma_max_tangent) = self
862                .atom_tangent_spectrum_from_assignments(
863                    atom_idx,
864                    assignments,
865                    &metric,
866                    active_threshold,
867                )?;
868            let tangent_condition_score = if sigma_max_tangent > 0.0 {
869                (sigma_min_tangent / sigma_max_tangent).clamp(0.0, 1.0)
870            } else {
871                0.0
872            };
873            let trust_score = tangent_condition_score;
874            atom_trust.push(trust_score);
875            atoms.push(SaeAtomTrustDiagnostics {
876                trust_score,
877                sigma_min_tangent,
878                sigma_max_tangent,
879                tangent_condition_score,
880                coverage,
881                activation_frequency,
882                untyped: matches!(atom.basis_kind, SaeAtomBasisKind::Precomputed(_)),
883                active_token_count,
884            });
885        }
886        Ok(SaeTrustDiagnostics { atom_trust, atoms })
887    }
888
889    pub(crate) fn atom_tangent_spectrum_from_assignments(
890        &self,
891        atom_idx: usize,
892        assignments: ArrayView2<'_, f64>,
893        metric: &gam_problem::RowMetric,
894        active_threshold: f64,
895    ) -> Result<(f64, f64), String> {
896        let atom = &self.atoms[atom_idx];
897        let d = atom.latent_dim;
898        let p = self.output_dim();
899        if d == 0 || p == 0 {
900            return Ok((0.0, 0.0));
901        }
902        let mut gram = Array2::<f64>::zeros((d, d));
903        let mut active_mass_sum = 0.0_f64;
904        let mut jac_row = vec![0.0_f64; p * d];
905        for row in 0..self.n_obs() {
906            let mass = assignments[[row, atom_idx]];
907            if !(mass > active_threshold) {
908                continue;
909            }
910            active_mass_sum += mass;
911            for axis in 0..d {
912                let start = axis;
913                let mut tangent = vec![0.0_f64; p];
914                atom.fill_decoded_derivative_row(row, axis, &mut tangent);
915                for out in 0..p {
916                    jac_row[out * d + start] = tangent[out];
917                }
918            }
919            let row_pullback = metric.pullback(row, &jac_row, d);
920            for axis_a in 0..d {
921                for axis_b in 0..=axis_a {
922                    gram[[axis_a, axis_b]] += mass * row_pullback[[axis_a, axis_b]];
923                }
924            }
925            jac_row.fill(0.0);
926        }
927        if !(active_mass_sum > 0.0) {
928            return Ok((0.0, 0.0));
929        }
930        let inv_mass = 1.0 / active_mass_sum;
931        for axis_a in 0..d {
932            for axis_b in 0..=axis_a {
933                let value = gram[[axis_a, axis_b]] * inv_mass;
934                gram[[axis_a, axis_b]] = value;
935                gram[[axis_b, axis_a]] = value;
936            }
937        }
938        let (evals, _) = gram.eigh(Side::Lower).map_err(|e| {
939            format!(
940                "trust_diagnostics_report: atom {atom_idx} tangent eigendecomposition failed: {e}"
941            )
942        })?;
943        let mut sigma_min = f64::INFINITY;
944        let mut sigma_max = 0.0_f64;
945        for value in evals.iter().copied() {
946            let clamped = value.max(0.0);
947            let sigma = clamped.sqrt();
948            sigma_min = sigma_min.min(sigma);
949            sigma_max = sigma_max.max(sigma);
950        }
951        if sigma_min.is_finite() {
952            Ok((sigma_min, sigma_max))
953        } else {
954            Ok((0.0, 0.0))
955        }
956    }
957
958    /// Per-atom exact parameter-space views for the #998 certificate path:
959    /// the basis values / first-derivative jet, decoder coefficients, latent
960    /// coordinates, and assignment mass each atom was actually fitted with.
961    /// Sphere atoms get `None` (their chart's group action is nonlinear, so
962    /// the exact-orbit realisation does not apply and they stay on the frame
963    /// path), as does any atom whose coordinate chart width disagrees with its
964    /// latent dimension (a structurally inconsistent atom must not masquerade
965    /// as exactly certified).
966    pub(crate) fn atom_parameter_views(
967        &self,
968    ) -> Vec<Option<crate::identifiability::AtomParameterView>> {
969        let assignments = self.assignment.assignments();
970        let n = self.n_obs();
971        self.atoms
972            .iter()
973            .enumerate()
974            .map(|(k, atom)| {
975                if matches!(atom.basis_kind, SaeAtomBasisKind::Sphere) {
976                    return None;
977                }
978                let coords = self.assignment.coords[k].as_matrix().to_owned();
979                if coords.nrows() != n || coords.ncols() != atom.latent_dim {
980                    return None;
981                }
982                let mut activations = Array1::<f64>::zeros(n);
983                for row in 0..n {
984                    activations[row] = assignments[[row, k]];
985                }
986                // Second jet Φ'' (#998): supplied when the atom's evaluator
987                // exposes an analytic Hessian, so a pin-active fit can lower its
988                // orbit-space isometry penalty operator (the metric-change of the
989                // pullback gram differentiates Φ' through t). Absent ⇒ the orbit
990                // verdict stays on the data residual / no-pin path, never an
991                // error.
992                let basis_second_jet = atom
993                    .basis_evaluator
994                    .as_ref()
995                    .and_then(|evaluator| evaluator.second_jet_dyn(coords.view()))
996                    .and_then(|res| res.ok());
997                Some(crate::identifiability::AtomParameterView {
998                    basis_values: atom.basis_values.clone(),
999                    basis_jacobian: atom.basis_jacobian.clone(),
1000                    decoder: atom.decoder_coefficients.clone(),
1001                    coords,
1002                    activations,
1003                    basis_second_jet,
1004                })
1005            })
1006            .collect()
1007    }
1008
1009    /// Lower this fitted term into the self-contained
1010    /// [`FittedSaeManifold`](crate::identifiability::FittedSaeManifold) the
1011    /// residual-gauge certificate consumes.
1012    ///
1013    /// The certificate's parameter space is the per-atom decoder **frame** — the
1014    /// `(output_dim, latent_dim)` image of the atom's latent axes in output space.
1015    /// We realise it as the active-mass-weighted mean decoder tangent
1016    /// `frame_k[:, a] = (Σ_n a_{nk} · ∂g_k/∂t_a(n)) / Σ_n a_{nk}` over the atom's
1017    /// active rows (the centroid decoder Jacobian columns the certificate docs
1018    /// name). The per-row pinning Jacobian block `J_n ∈ ℝ^{p × param_dim}` is the
1019    /// assignment-weighted per-row decoder tangent placed at each atom's frame
1020    /// slot: column `(k, i, a)` of `J_n` is `a_{nk} · ∂g_k/∂t_a(n)[i]` — exactly
1021    /// the directions the reconstruction data gives cost to, in the same metric
1022    /// the fit used (whitened by the certificate through `RowMetric`).
1023    ///
1024    /// The flattened frame layout matches the certificate's
1025    /// `vec(frame_0) ⊕ vec(frame_1) ⊕ …`, row-major within each frame
1026    /// (`frame_k[i, a]` at offset `atom_offset(k) + i·latent_dim_k + a`).
1027    pub(crate) fn to_residual_gauge_model(
1028        &self,
1029        metric: gam_problem::RowMetric,
1030        per_atom_ard_variances: Option<&[Option<Array1<f64>>]>,
1031        isometry_pin_active: bool,
1032    ) -> Result<
1033        (
1034            crate::identifiability::FittedSaeManifold,
1035            Option<(Array2<f64>, usize)>,
1036        ),
1037        String,
1038    > {
1039        use crate::identifiability::{AtomTopology, FittedAtom, FittedSaeManifold};
1040
1041        let n = self.n_obs();
1042        let p = self.output_dim();
1043        let k = self.k_atoms();
1044        let assignments = self.assignment.assignments();
1045
1046        // Per-atom frame `(p, d)` = active-mass-weighted mean decoder tangent,
1047        // and the flattened-frame column offset bookkeeping for the joint
1048        // parameter vector (`vec(frame_0) ⊕ …`, row-major within each frame).
1049        let mut fitted_atoms: Vec<FittedAtom> = Vec::with_capacity(k);
1050        let mut atom_offsets: Vec<usize> = Vec::with_capacity(k);
1051        let mut atom_axis_dim: Vec<usize> = Vec::with_capacity(k);
1052        let mut cursor = 0usize;
1053        for (atom_idx, atom) in self.atoms.iter().enumerate() {
1054            let d = atom.latent_dim;
1055            let topology = match (&atom.basis_kind, d) {
1056                (SaeAtomBasisKind::Periodic, 1) | (SaeAtomBasisKind::Torus, 1) => {
1057                    AtomTopology::Circle
1058                }
1059                (SaeAtomBasisKind::Periodic, _) | (SaeAtomBasisKind::Torus, _) => {
1060                    AtomTopology::Torus { latent_dim: d }
1061                }
1062                (SaeAtomBasisKind::Sphere, _) => AtomTopology::Sphere,
1063                // `Cylinder` (`S¹ × ℝ`) has exactly one continuous gauge: the
1064                // rotation (shift) of the periodic axis. The unbounded line axis
1065                // carries no rotational gauge, and its translation is already
1066                // pinned by the design's constant column — so the identifiability
1067                // gauge is that of a single circle. Fixing it as `Torus` would
1068                // over-impose a second (nonexistent) circle shift; fixing it as
1069                // `EuclideanPatch { 2 }` would over-impose a frame rotation
1070                // mixing the periodic and linear axes. `Circle` fixes the one
1071                // real continuous gauge and leaves the linear axis ungauged.
1072                (SaeAtomBasisKind::Cylinder, _) => AtomTopology::Circle,
1073                (
1074                    SaeAtomBasisKind::Linear
1075                    | SaeAtomBasisKind::Duchon
1076                    | SaeAtomBasisKind::EuclideanPatch
1077                    | SaeAtomBasisKind::Poincare
1078                    | SaeAtomBasisKind::Precomputed(_),
1079                    _,
1080                ) => AtomTopology::EuclideanPatch { latent_dim: d },
1081            };
1082
1083            let mut frame = Array2::<f64>::zeros((p, d));
1084            let mut active_mass = 0.0_f64;
1085            let mut tangent = vec![0.0_f64; p];
1086            for row in 0..n {
1087                let a_nk = assignments[[row, atom_idx]];
1088                if !(a_nk > 0.0) {
1089                    continue;
1090                }
1091                active_mass += a_nk;
1092                for axis in 0..d {
1093                    atom.fill_decoded_derivative_row(row, axis, &mut tangent);
1094                    for i in 0..p {
1095                        frame[[i, axis]] += a_nk * tangent[i];
1096                    }
1097                }
1098            }
1099            if active_mass > 0.0 {
1100                let inv = 1.0 / active_mass;
1101                frame.mapv_inplace(|v| v * inv);
1102            }
1103
1104            // #995 lowering-error scale: mass-weighted relative dispersion of
1105            // the per-row tangents around the mean frame just built,
1106            //   Σ_n a_n Σ_ax ‖t_ax(n) − frame[:,ax]‖² / Σ_n a_n Σ_ax ‖t_ax(n)‖².
1107            // 0 ⇒ the frame represents every active row exactly (flat
1108            // decoder); → 1 ⇒ the tangent field disperses so strongly (e.g. a
1109            // full circle, whose tangents average out) that the mean-frame
1110            // compression cannot distinguish gauge motion from curvature. The
1111            // certificate calibrates its per-generator verdict tolerance to
1112            // this scale so it never claims a pin it cannot resolve.
1113            let mut disp_num = 0.0_f64;
1114            let mut disp_den = 0.0_f64;
1115            for row in 0..n {
1116                let a_nk = assignments[[row, atom_idx]];
1117                if !(a_nk > 0.0) {
1118                    continue;
1119                }
1120                for axis in 0..d {
1121                    atom.fill_decoded_derivative_row(row, axis, &mut tangent);
1122                    for i in 0..p {
1123                        let dev = tangent[i] - frame[[i, axis]];
1124                        disp_num += a_nk * dev * dev;
1125                        disp_den += a_nk * tangent[i] * tangent[i];
1126                    }
1127                }
1128            }
1129            let lowering_error = if disp_den > 0.0 {
1130                (disp_num / disp_den).clamp(0.0, 1.0)
1131            } else {
1132                0.0
1133            };
1134
1135            let ard_variances = per_atom_ard_variances
1136                .and_then(|all| all.get(atom_idx))
1137                .and_then(|opt| opt.clone())
1138                .filter(|v| v.len() == d);
1139
1140            fitted_atoms.push(FittedAtom {
1141                name: atom.name.clone(),
1142                topology,
1143                frame,
1144                ard_variances,
1145                lowering_error,
1146                // #1019: post-fit chart canonicalization (arc length for
1147                // d = 1, isometry-flow for d = 2 torus, flat-reference
1148                // isometry-flow for d = 2 free/patch, round-sphere
1149                // conformal-boost flow for d = 2 sphere atoms) pins the chart;
1150                // the certificate downgrades this atom's chart freedom to the
1151                // finite isometry group with PinnedByCanonicalization
1152                // provenance.
1153                chart_canonicalized: atom.chart_canonicalized
1154                    && (d == 1
1155                        || (d == 2
1156                            && matches!(
1157                                atom.basis_kind,
1158                                SaeAtomBasisKind::Torus
1159                                    | SaeAtomBasisKind::Linear
1160                                    | SaeAtomBasisKind::Duchon
1161                                    | SaeAtomBasisKind::EuclideanPatch
1162                                    | SaeAtomBasisKind::Sphere
1163                            ))),
1164                // #1097 / #1103: the per-atom inner-decoder-smooth snapshot,
1165                // attached when the post-fit harness has run
1166                // [`Self::set_atom_inner_fits`] (it needs the reconstruction
1167                // target Z, dropped from the objective at fit end). `None` on a
1168                // bare certificate-only model, or for a degenerate atom whose
1169                // inner Hessian was not SPD.
1170                inner_fit: self
1171                    .atom_inner_fits
1172                    .as_ref()
1173                    .and_then(|fits| fits.get(atom_idx))
1174                    .and_then(|slot| slot.clone()),
1175            });
1176            atom_offsets.push(cursor);
1177            atom_axis_dim.push(d);
1178            cursor += p * d;
1179        }
1180        let param_dim = cursor;
1181
1182        // Per-row pinning Jacobian `J_n ∈ ℝ^{p × param_dim}` flattened row-major
1183        // (`J_n[i, c] = jacobian_rows[n][i · param_dim + c]`). Column `(k, i', a)`
1184        // of `J_n` is `a_{nk} · ∂g_k/∂t_a(n)[i']` placed at the atom-k frame slot
1185        // and read out on output coordinate `i = i'` (a frame perturbation of
1186        // output `i'` moves only the row's output coordinate `i'`).
1187        //
1188        // The pinned certificate still consumes the legacy row-block contract.
1189        // The unpinned exact path consumes only `RᵀR`, so stream each transient
1190        // row Jacobian through the metric whitening and discard it immediately.
1191        let (jacobian_rows, streamed_curvature) = if isometry_pin_active {
1192            let mut jacobian_rows: Vec<Vec<f64>> = Vec::with_capacity(n);
1193            let mut tangent = vec![0.0_f64; p];
1194            for row in 0..n {
1195                let mut j_flat = vec![0.0_f64; p * param_dim];
1196                for (atom_idx, atom) in self.atoms.iter().enumerate() {
1197                    let a_nk = assignments[[row, atom_idx]];
1198                    if !(a_nk > 0.0) {
1199                        continue;
1200                    }
1201                    let d = atom_axis_dim[atom_idx];
1202                    let base = atom_offsets[atom_idx];
1203                    for axis in 0..d {
1204                        atom.fill_decoded_derivative_row(row, axis, &mut tangent);
1205                        for i in 0..p {
1206                            // Frame coordinate `(k, i, axis)` sits at column
1207                            // `base + i·d + axis`; it sources output coordinate `i`.
1208                            j_flat[i * param_dim + base + i * d + axis] += a_nk * tangent[i];
1209                        }
1210                    }
1211                }
1212                jacobian_rows.push(j_flat);
1213            }
1214            (jacobian_rows, None)
1215        } else {
1216            let streamed = self.residual_gauge_streamed_data_curvature(
1217                &metric,
1218                &atom_offsets,
1219                &atom_axis_dim,
1220                param_dim,
1221            )?;
1222            (Vec::new(), Some(streamed))
1223        };
1224
1225        // Isometry-penalty curvature root over the frame parameter space. When
1226        // the isometry gauge pin is active it gives curvature along every fitted
1227        // frame direction (it resists deviation of the decoder image from its
1228        // arc-length parameterization), so its row space is the span of the
1229        // per-atom frame columns: one root row per `(k, axis)` carrying that
1230        // atom's frame column at the atom's frame slot. Empty (`0 × param_dim`)
1231        // when the pin is inactive — exactly the certificate's escalation
1232        // condition to `diffeomorphism-unpinned`.
1233        let isometry_penalty_root = if isometry_pin_active && param_dim > 0 {
1234            let mut root_rows: Vec<Array1<f64>> = Vec::new();
1235            for (atom_idx, fitted) in fitted_atoms.iter().enumerate() {
1236                let d = atom_axis_dim[atom_idx];
1237                let base = atom_offsets[atom_idx];
1238                for axis in 0..d {
1239                    let mut r = Array1::<f64>::zeros(param_dim);
1240                    let mut any = false;
1241                    for i in 0..p {
1242                        let v = fitted.frame[[i, axis]];
1243                        if v != 0.0 {
1244                            any = true;
1245                        }
1246                        r[base + i * d + axis] = v;
1247                    }
1248                    if any {
1249                        root_rows.push(r);
1250                    }
1251                }
1252            }
1253            let mut root = Array2::<f64>::zeros((root_rows.len(), param_dim));
1254            for (ri, r) in root_rows.iter().enumerate() {
1255                root.row_mut(ri).assign(r);
1256            }
1257            root
1258        } else {
1259            Array2::<f64>::zeros((0, param_dim))
1260        };
1261
1262        Ok((
1263            FittedSaeManifold {
1264                atoms: fitted_atoms,
1265                jacobian_rows,
1266                isometry_penalty_root,
1267                metric,
1268            },
1269            streamed_curvature,
1270        ))
1271    }
1272
1273    pub(crate) fn residual_gauge_streamed_data_curvature(
1274        &self,
1275        metric: &gam_problem::RowMetric,
1276        atom_offsets: &[usize],
1277        atom_axis_dim: &[usize],
1278        param_dim: usize,
1279    ) -> Result<(Array2<f64>, usize), String> {
1280        let n = self.n_obs();
1281        let p = self.output_dim();
1282        if metric.p_out() != p {
1283            return Err(format!(
1284                "residual_gauge_streamed_data_curvature: metric output dim {} but term has {p}",
1285                metric.p_out()
1286            ));
1287        }
1288        let rank = metric.metric_rank();
1289        let mut gram = Array2::<f64>::zeros((param_dim, param_dim));
1290        if param_dim == 0 || n == 0 || rank == 0 {
1291            return Ok((gram, n * rank));
1292        }
1293
1294        let assignments = self.assignment.assignments();
1295        let mut tangent = vec![0.0_f64; p];
1296        let mut j_flat = vec![0.0_f64; p * param_dim];
1297        let mut root_row = Array1::<f64>::zeros(param_dim);
1298        for row in 0..n {
1299            j_flat.fill(0.0);
1300            for (atom_idx, atom) in self.atoms.iter().enumerate() {
1301                let a_nk = assignments[[row, atom_idx]];
1302                if !(a_nk > 0.0) {
1303                    continue;
1304                }
1305                let d = atom_axis_dim[atom_idx];
1306                let base = atom_offsets[atom_idx];
1307                for axis in 0..d {
1308                    atom.fill_decoded_derivative_row(row, axis, &mut tangent);
1309                    for i in 0..p {
1310                        j_flat[i * param_dim + base + i * d + axis] += a_nk * tangent[i];
1311                    }
1312                }
1313            }
1314
1315            if metric.drives_gauge() {
1316                for r in 0..rank {
1317                    root_row.fill(0.0);
1318                    for c in 0..param_dim {
1319                        let mut acc = 0.0_f64;
1320                        for i in 0..p {
1321                            acc += metric.factor_entry(row, i, r) * j_flat[i * param_dim + c];
1322                        }
1323                        root_row[c] = acc;
1324                    }
1325                    let row_slice = root_row.as_slice().ok_or_else(|| {
1326                        "residual_gauge_streamed_data_curvature: non-contiguous root row"
1327                            .to_string()
1328                    })?;
1329                    Self::accumulate_residual_gauge_gram_row(&mut gram, row_slice);
1330                }
1331            } else {
1332                for i in 0..p {
1333                    let start = i * param_dim;
1334                    let end = start + param_dim;
1335                    Self::accumulate_residual_gauge_gram_row(&mut gram, &j_flat[start..end]);
1336                }
1337            }
1338        }
1339
1340        for a in 0..param_dim {
1341            for b in 0..a {
1342                gram[[b, a]] = gram[[a, b]];
1343            }
1344        }
1345        Ok((gram, n * rank))
1346    }
1347
1348    pub(crate) fn accumulate_residual_gauge_gram_row(gram: &mut Array2<f64>, row: &[f64]) {
1349        for a in 0..row.len() {
1350            let va = row[a];
1351            if va == 0.0 {
1352                continue;
1353            }
1354            for b in 0..=a {
1355                let vb = row[b];
1356                if vb != 0.0 {
1357                    gram[[a, b]] += va * vb;
1358                }
1359            }
1360        }
1361    }
1362
1363    pub fn set_temperature_schedule(
1364        &mut self,
1365        sched: GumbelTemperatureSchedule,
1366    ) -> Result<(), String> {
1367        sched.validate()?;
1368        self.assignment
1369            .mode
1370            .set_temperature(sched.current_tau(sched.iter_count))?;
1371        self.temperature_schedule = Some(sched);
1372        Ok(())
1373    }
1374
1375    pub(crate) fn advance_temperature_schedule(&mut self) -> Result<Option<f64>, String> {
1376        let Some(schedule) = self.temperature_schedule.as_mut() else {
1377            return Ok(None);
1378        };
1379        schedule.validate()?;
1380        let tau = schedule.step();
1381        self.assignment.mode.set_temperature(tau)?;
1382        Ok(Some(tau))
1383    }
1384
1385    pub fn n_obs(&self) -> usize {
1386        self.assignment.n_obs()
1387    }
1388
1389    pub fn k_atoms(&self) -> usize {
1390        self.atoms.len()
1391    }
1392
1393    /// Auto-derived in-core vs streaming plan for SAE Arrow-Schur work.
1394    ///
1395    /// This is intentionally not user-configurable: the route follows the
1396    /// retained full-batch working-set estimate and the currently selected GPU
1397    /// memory budget when CUDA is usable, otherwise a conservative host budget.
1398    pub fn streaming_plan(&self) -> SaeStreamingPlan {
1399        let n_obs = self.n_obs();
1400        let total_basis: usize = self.atoms.iter().map(|atom| atom.basis_size()).sum();
1401        let d_max = self
1402            .atoms
1403            .iter()
1404            .map(|atom| atom.latent_dim)
1405            .max()
1406            .unwrap_or(0);
1407        let border_dim = if self.any_frame_active() {
1408            self.factored_border_dim()
1409        } else {
1410            self.beta_dim()
1411        };
1412        sae_streaming_plan_for_shape(n_obs, total_basis, self.k_atoms(), d_max, border_dim)
1413    }
1414
1415    /// Construction-time validation: every Psi-tier analytic penalty in the
1416    /// registry must be dispatchable into the SAE arrow-Schur row layout.
1417    ///
1418    /// Two invariants are enforced upfront so the dispatch loop in
1419    /// `add_sae_analytic_penalty_contributions` is total (no runtime
1420    /// "unsupported penalty" fallthrough, no per-call K-gating):
1421    ///
1422    /// 1. Every Psi-tier penalty is either in [`sae_penalty_is_row_block_supported`],
1423    ///    or `NuclearNorm` (which is redirected to the per-atom decoder (β) block
1424    ///    rather than the coord "t" row block). Assignment sparsity penalties
1425    ///    (`IBPAssignment`, `SoftmaxAssignmentSparsity`) are refused because the SAE
1426    ///    term already owns them through its built-in assignment path
1427    ///    (`loss.assignment_sparsity`). Penalty kinds with cross-row structure
1428    ///    (`TotalVariation`, `Monotonicity`, `BlockSparsity`,
1429    ///    `IvaeRidgeMeanGauge`, `Orthogonality`, `NestedPrefix`,
1430    ///    `SheafConsistency`) cannot be expressed in the SAE row-block layout
1431    ///    and are refused here.
1432    ///
1433    /// 2. If any Psi-tier row-block penalty is present, every atom shares
1434    ///    the same coord latent dim. The current registry model carries one
1435    ///    `latent_dim` per descriptor (the "t" latent block declares one
1436    ///    `d` value); per-atom dispatch with heterogeneous `d_k` would
1437    ///    require per-atom registry entries or per-kind in-place
1438    ///    reshaping. Mixed-d row-block fits are rejected with an actionable
1439    ///    error pointing at the configuration mismatch.
1440    ///
1441    /// The K=1 case trivially satisfies (2). Beta-tier and rho-tier
1442    /// penalties are not constrained here.
1443    pub(crate) fn validate_analytic_penalty_registry(
1444        &self,
1445        registry: &AnalyticPenaltyRegistry,
1446    ) -> Result<(), String> {
1447        let mut row_block_penalty_present = false;
1448        for penalty in &registry.penalties {
1449            if penalty.tier() != PenaltyTier::Psi {
1450                continue;
1451            }
1452            if matches!(
1453                penalty,
1454                AnalyticPenaltyKind::IBPAssignment(_)
1455                    | AnalyticPenaltyKind::SoftmaxAssignmentSparsity(_)
1456            ) {
1457                return Err(format!(
1458                    "SAE-manifold term refuses analytic penalty {:?}: assignment sparsity \
1459                     is owned by the built-in SAE assignment path (loss.assignment_sparsity). \
1460                     Registering it would double-count the objective and gradient",
1461                    penalty.name()
1462                ));
1463            }
1464            // NuclearNorm is redirected to the per-atom decoder (β) block in
1465            // `add_sae_beta_penalty` (it penalizes each atom's decoder matrix
1466            // singular spectrum, i.e. its embedding rank), so it bypasses the
1467            // coord "t" row-block requirement below.
1468            if matches!(penalty, AnalyticPenaltyKind::NuclearNorm(_)) {
1469                continue;
1470            }
1471            if !sae_penalty_is_row_block_supported(penalty) {
1472                return Err(format!(
1473                    "SAE-manifold term refuses analytic penalty {:?}: this kind \
1474                     has cross-row structure and cannot be expressed in the \
1475                     arrow-Schur row layout. Use only row-block-supported \
1476                     coord penalties (ARD, BlockOrthogonality, \
1477                     Sparsity/TopK/JumpReLU, RowPrecisionPrior, \
1478                     ParametricRowPrecisionPrior, ScadMcp, Isometry) on the \
1479                     coord latent block, or move the penalty to a non-SAE \
1480                     term",
1481                    penalty.name()
1482                ));
1483            }
1484            row_block_penalty_present = true;
1485        }
1486        if row_block_penalty_present {
1487            let mut dims = self.assignment.coords.iter().map(|c| c.latent_dim());
1488            if let Some(first) = dims.next() {
1489                if let Some(mismatch) = dims.find(|d| *d != first) {
1490                    return Err(format!(
1491                        "SAE-manifold term refuses row-block analytic penalty: \
1492                         atoms have heterogeneous coord latent dims (saw {first} \
1493                         and {mismatch}). Row-block penalties (ARD, \
1494                         BlockOrthogonality, ...) target the unified \"t\" \
1495                         latent block whose declared `d` matches one shape; \
1496                         per-atom dispatch with mixed `d_k` would silently \
1497                         truncate or expand axes. Configure all atoms with the \
1498                         same `atom_dim`, or split the row-block penalty into \
1499                         per-atom descriptors keyed to per-atom latent blocks"
1500                    ));
1501                }
1502            }
1503        }
1504        Ok(())
1505    }
1506
1507    pub fn output_dim(&self) -> usize {
1508        self.atoms[0].output_dim()
1509    }
1510
1511    pub fn beta_dim(&self) -> usize {
1512        let p = self.output_dim();
1513        self.atoms.iter().map(|a| a.basis_size() * p).sum()
1514    }
1515
1516    pub(crate) fn take_border_hbb_workspace(&mut self, border_dim: usize) -> Array2<f64> {
1517        let mut workspace =
1518            std::mem::replace(&mut self.border_hbb_workspace, Array2::<f64>::zeros((0, 0)));
1519        if workspace.dim() != (border_dim, border_dim) {
1520            workspace = Array2::<f64>::zeros((border_dim, border_dim));
1521        } else {
1522            workspace.fill(0.0);
1523        }
1524        workspace
1525    }
1526
1527    pub(crate) fn reclaim_border_hbb_workspace(&mut self, sys: &mut ArrowSchurSystem) {
1528        let workspace = std::mem::replace(&mut sys.hbb, Array2::<f64>::zeros((0, 0)));
1529        self.border_hbb_workspace = workspace;
1530    }
1531
1532    /// Factored arrow-Schur border dimension `Σ_k M_k · r_k` (issue #972): the
1533    /// number of decoder coordinates the border actually carries once the
1534    /// low-rank Grassmann frames are profiled out. Atoms with no active frame
1535    /// contribute their full `M_k · p` (`r_k == p`), so on the all-full-`B` path
1536    /// this equals [`Self::beta_dim`]. The border Cholesky / evidence log-det
1537    /// scale with THIS count, not `beta_dim`.
1538    pub fn factored_border_dim(&self) -> usize {
1539        self.atoms.iter().map(|a| a.border_coeff_count()).sum()
1540    }
1541
1542    /// Total profiled-out Grassmann manifold dimension `Σ_k r_k·(p − r_k)` across
1543    /// all active frames (issue #972). This is the count of decoder-frame degrees
1544    /// of freedom estimated OUTSIDE the border by closed-form polar steps, and it
1545    /// must enter the Laplace evidence dimension accounting (evidence honesty):
1546    /// the profiled frame is a MAP point on `∏_k Gr(r_k, p)`, contributing this
1547    /// many free dimensions to the model. `0` when every atom is on the full-`B`
1548    /// path. Threaded into [`Self::reml_occam_term`].
1549    pub fn grassmann_evidence_dimension(&self) -> usize {
1550        self.atoms
1551            .iter()
1552            .map(|a| a.frame_manifold_dimension())
1553            .sum()
1554    }
1555
1556    /// True iff any atom has an active low-rank Grassmann frame (issue #972).
1557    pub fn frames_active(&self) -> bool {
1558        self.atoms.iter().any(|a| a.decoder_frame.is_some())
1559    }
1560
1561    /// Alias of [`Self::frames_active`] (issue #972 / #977 T1): the predicate the
1562    /// assembly / step-lift branch on to decide whether the β-tier is built in
1563    /// the factored coordinate layout. Named to read as the question
1564    /// "is the factored path engaged?" at its call sites.
1565    pub fn any_frame_active(&self) -> bool {
1566        self.frames_active()
1567    }
1568
1569    /// Per-atom column offsets of the *factored* border (issue #972 / #977 T1):
1570    /// the running prefix sum of `M_k · r_k`, one entry per atom (the same
1571    /// convention as [`Self::beta_offsets`]). This is the start of each atom's
1572    /// `C_k` block in the reduced border vector; on the all-full-`B` path it
1573    /// equals `beta_offsets`. Distinct from [`Self::factored_border_offsets`]
1574    /// only in name (both compute the identical prefix sum) — this method is the
1575    /// one the frame transform reads, mirroring `beta_offsets` at the call site.
1576    pub fn factored_beta_offsets(&self) -> Vec<usize> {
1577        self.factored_border_offsets()
1578    }
1579
1580    /// Frame output matrix `U_k ∈ St(p, r_k)` for atom `k` (issue #972 / #977 T1).
1581    /// Returns the active frame `U_k` (`p × r_k`) when atom `k` is framed, else
1582    /// the identity `I_p` (the `r_k == p`, `U_k == I_p` full-`B` special case) so
1583    /// the projection / lift code is uniform across a mixed dictionary.
1584    pub fn frame_output_matrix(&self, atom_idx: usize) -> Array2<f64> {
1585        let atom = &self.atoms[atom_idx];
1586        match &atom.decoder_frame {
1587            Some(frame) => frame.frame().to_owned(),
1588            None => Array2::<f64>::eye(atom.output_dim()),
1589        }
1590    }
1591
1592    /// Per-pair frame factor `W_{ij} = U_iᵀ U_j` (`r_i × r_j`) used as the output
1593    /// factor of the factored data β-Hessian block `G_{ij} ⊗ W_{ij}` (issue #972
1594    /// / #977 T1). When both atoms are framed this is the dense principal-angle
1595    /// cosine matrix between the two frames; for `i == j` with an orthonormal
1596    /// frame it is exactly `I_{r_i}`; for any un-framed atom the corresponding
1597    /// `U` is `I_p`, so a same-atom un-framed pair gives `I_p` (the clean full-`B`
1598    /// `G ⊗ I_p` collapse) and a framed/un-framed cross pair gives the rectangular
1599    /// `U_iᵀ` / `U_j` overlap.
1600    pub fn frame_cross_factor(&self, atom_i: usize, atom_j: usize) -> Array2<f64> {
1601        let ui = self.frame_output_matrix(atom_i);
1602        let uj = self.frame_output_matrix(atom_j);
1603        // `U_iᵀ U_j`: `(r_i × p) · (p × r_j)`. `fast_atb` forms `U_iᵀ U_j` directly.
1604        fast_atb(&ui, &uj)
1605    }
1606
1607    /// Per-atom column offsets of the *factored* border (issue #972): the
1608    /// running prefix sum of `M_k · r_k`. The analogue of [`Self::beta_offsets`]
1609    /// for the reduced coordinate layout — atom `k`'s `C_k` occupies
1610    /// `[factored_border_offsets()[k] .. + M_k·r_k)`. On the full-`B` path this
1611    /// equals `beta_offsets`.
1612    pub fn factored_border_offsets(&self) -> Vec<usize> {
1613        let mut out = Vec::with_capacity(self.k_atoms());
1614        let mut cursor = 0usize;
1615        for atom in &self.atoms {
1616            out.push(cursor);
1617            cursor += atom.border_coeff_count();
1618        }
1619        out
1620    }
1621
1622    /// Assemble the factored border coordinate vector `C = [vec(C_1); …; vec(C_K)]`
1623    /// in row-major `C_k[m, j] → C[off_k + m·r_k + j]` layout (issue #972).
1624    ///
1625    /// This is the reduced state the arrow-Schur border carries when frames are
1626    /// active: its length is [`Self::factored_border_dim`] (`Σ M_k·r_k`), the
1627    /// border-size invariant verified by [`grassmann_assert_border_dim_invariant`].
1628    /// Atoms
1629    /// without an active frame contribute their full `vec(B_k)` (their `r_k == p`
1630    /// coordinates are the decoder itself), so on the all-full-`B` path this
1631    /// reproduces [`Self::flatten_beta`].
1632    pub fn flatten_factored_border(&self) -> Result<Array1<f64>, String> {
1633        let offsets = self.factored_border_offsets();
1634        let mut out = Array1::<f64>::zeros(self.factored_border_dim());
1635        for (atom_idx, atom) in self.atoms.iter().enumerate() {
1636            let off = offsets[atom_idx];
1637            let r = atom.border_frame_rank();
1638            let m = atom.basis_size();
1639            let coords = match atom.factored_coordinates()? {
1640                Some(c) => c,
1641                // Full-`B` path: the decoder itself is the coordinate matrix.
1642                None => atom.decoder_coefficients.clone(),
1643            };
1644            for basis_col in 0..m {
1645                for j in 0..r {
1646                    out[off + basis_col * r + j] = coords[[basis_col, j]];
1647                }
1648            }
1649        }
1650        Ok(out)
1651    }
1652
1653    /// Scatter a factored border coordinate vector `C` (length
1654    /// [`Self::factored_border_dim`]) back into the per-atom decoders, refreshing
1655    /// each `decoder_coefficients = C_k · U_kᵀ` so the full-`B` consumers stay
1656    /// consistent after a factored border solve (issue #972). The inverse of
1657    /// [`Self::flatten_factored_border`].
1658    pub fn scatter_factored_border(&mut self, border: ArrayView1<'_, f64>) -> Result<(), String> {
1659        let expected = self.factored_border_dim();
1660        if border.len() != expected {
1661            return Err(format!(
1662                "SaeManifoldTerm::scatter_factored_border: border length {} must equal \
1663                 factored border dim {expected}",
1664                border.len()
1665            ));
1666        }
1667        let offsets = self.factored_border_offsets();
1668        for atom_idx in 0..self.atoms.len() {
1669            let off = offsets[atom_idx];
1670            let (r, m, has_frame) = {
1671                let atom = &self.atoms[atom_idx];
1672                (
1673                    atom.border_frame_rank(),
1674                    atom.basis_size(),
1675                    atom.decoder_frame.is_some(),
1676                )
1677            };
1678            let mut coords = Array2::<f64>::zeros((m, r));
1679            for basis_col in 0..m {
1680                for j in 0..r {
1681                    coords[[basis_col, j]] = border[off + basis_col * r + j];
1682                }
1683            }
1684            if has_frame {
1685                self.atoms[atom_idx].set_factored_coordinates(coords.view())?;
1686            } else {
1687                // Full-`B` path: the coordinates ARE the decoder.
1688                self.atoms[atom_idx].decoder_coefficients = coords;
1689            }
1690        }
1691        Ok(())
1692    }
1693
1694    /// Auto-derive and install low-rank Grassmann decoder frames across all
1695    /// atoms (issue #972) — magic-by-default, no flag. Each atom independently
1696    /// activates its frame iff the factorization materially shrinks its border
1697    /// (see [`SaeManifoldAtom::maybe_activate_decoder_frame`]). Returns the
1698    /// number of atoms that activated a frame. Idempotent: re-running re-derives
1699    /// each frame from the current decoder.
1700    ///
1701    /// The decision keys on the *frontier* regime the issue targets: at large
1702    /// ambient `p` the full border `Σ M_k · p` reaches `10^7`–`10^8` and the
1703    /// border Cholesky dies, while the decoder's effective column rank `r` stays
1704    /// `≪ p`. Small-`p` atoms (where `r` cannot beat the activation margin)
1705    /// keep the bit-for-bit full-`B` path, so the small-model evidence is
1706    /// unchanged (verified by `factored_evidence_matches_full_b_at_small_p`).
1707    pub fn auto_activate_decoder_frames(&mut self) -> Result<usize, String> {
1708        let mut activated = 0usize;
1709        for atom in &mut self.atoms {
1710            let expected_rank = atom.decoder_frame_activation_rank()?;
1711            match (
1712                expected_rank,
1713                atom.decoder_frame.as_ref().map(GrassmannFrame::rank),
1714            ) {
1715                (Some(expected), Some(current)) if expected == current => {
1716                    continue;
1717                }
1718                (None, Some(_)) => {
1719                    atom.deactivate_decoder_frame();
1720                    continue;
1721                }
1722                (None, None) => {
1723                    continue;
1724                }
1725                (Some(_), _) => {}
1726            }
1727            if atom.maybe_activate_decoder_frame()?.is_some() {
1728                activated += 1;
1729            }
1730        }
1731        Ok(activated)
1732    }
1733
1734    /// Reconcile decoder-frame activation before a fit entry point. The
1735    /// user-facing `auto_activate_decoder_frames` contract returns only newly
1736    /// installed frames; this helper enforces the stronger invariant the large-p
1737    /// solver needs: every atom whose current decoder satisfies the activation
1738    /// predicate has an active frame after the pass.
1739    pub(crate) fn ensure_decoder_frames_active_for_current_decoder(
1740        &mut self,
1741    ) -> Result<(), String> {
1742        self.auto_activate_decoder_frames()?;
1743        for (atom_idx, atom) in self.atoms.iter().enumerate() {
1744            let expected_rank = atom.decoder_frame_activation_rank()?;
1745            if let Some(expected_rank) = expected_rank {
1746                match atom.decoder_frame.as_ref() {
1747                    Some(frame) if frame.rank() == expected_rank => {}
1748                    Some(frame) => {
1749                        return Err(format!(
1750                            "SaeManifoldTerm::ensure_decoder_frames_active_for_current_decoder: \
1751                             atom {atom_idx} frame rank {} must equal audited rank {expected_rank}",
1752                            frame.rank()
1753                        ));
1754                    }
1755                    None => {
1756                        return Err(format!(
1757                            "SaeManifoldTerm::ensure_decoder_frames_active_for_current_decoder: \
1758                             atom {atom_idx} has audited rank {expected_rank} but no active frame"
1759                        ));
1760                    }
1761                }
1762            } else if atom.decoder_frame.is_some() {
1763                return Err(format!(
1764                    "SaeManifoldTerm::ensure_decoder_frames_active_for_current_decoder: \
1765                     atom {atom_idx} kept a frame after the full-B predicate won"
1766                ));
1767            }
1768        }
1769        Ok(())
1770    }
1771
1772    /// Closed-form streaming POLAR refresh of every ACTIVE decoder frame from the
1773    /// current data evidence (issue #972 / #977 T1) — the U-block of the
1774    /// alternating block-coordinate ascent that complements the border's
1775    /// C-block Newton step.
1776    ///
1777    /// For each framed atom `k` we accumulate the `p × r_k` cross-moment
1778    ///   `A_k = Σ_n a_{n,k} · e_{n,k} · ĉ_{n,k}ᵀ`,
1779    /// where `e_{n,k} = z_n − Σ_{k'≠k} a_{n,k'}·decoded_{k'}(n)` is the row's
1780    /// partial reconstruction residual (everything except atom `k`) and
1781    /// `ĉ_{n,k} = Φ_k(t_n)·C_k ∈ ℝ^{r_k}` is atom `k`'s in-span decoded
1782    /// coordinate. The polar factor `U_new = polar(A_k)` is the closed-form MAP
1783    /// frame on `Gr(r_k, p)` given the C-coordinates held fixed — the same
1784    /// `O(p r²)` thin SVD the issue prescribes, run OUTSIDE the border. The frame
1785    /// is then re-installed and the decoder re-projected onto it so the
1786    /// authoritative `B_k = C_k U_newᵀ` and the `(C_k, U_new)` pair stay
1787    /// consistent (a no-op in span for a truly rank-`r` atom). Un-framed atoms
1788    /// are skipped. Returns the number of frames refreshed.
1789    pub(crate) fn refresh_active_frames_from_data(
1790        &mut self,
1791        target: ArrayView2<'_, f64>,
1792        rho: &SaeManifoldRho,
1793    ) -> Result<usize, String> {
1794        let n = self.n_obs();
1795        let p = self.output_dim();
1796        let k_atoms = self.k_atoms();
1797        if n == 0 {
1798            return Ok(0);
1799        }
1800        // Per-row assignments and per-(row, atom) decoded outputs, computed once.
1801        let mut assignments = Vec::with_capacity(n);
1802        for row in 0..n {
1803            assignments.push(self.assignment.try_assignments_row_for_rho(row, rho)?);
1804        }
1805        let mut decoded = Array3::<f64>::zeros((n, k_atoms, p));
1806        let mut dbuf = vec![0.0_f64; p];
1807        for row in 0..n {
1808            for atom_idx in 0..k_atoms {
1809                self.atoms[atom_idx].fill_decoded_row(row, &mut dbuf);
1810                for c in 0..p {
1811                    decoded[[row, atom_idx, c]] = dbuf[c];
1812                }
1813            }
1814        }
1815        // Full fitted reconstruction `Σ_k a_k decoded_k`, so the per-atom partial
1816        // residual is `e_k = (z − fitted) + a_k decoded_k` (add atom k back in).
1817        let mut fitted = Array2::<f64>::zeros((n, p));
1818        for row in 0..n {
1819            for atom_idx in 0..k_atoms {
1820                let a = assignments[row][atom_idx];
1821                if a == 0.0 {
1822                    continue;
1823                }
1824                for c in 0..p {
1825                    fitted[[row, c]] += a * decoded[[row, atom_idx, c]];
1826                }
1827            }
1828        }
1829        let mut refreshed = 0usize;
1830        for atom_idx in 0..k_atoms {
1831            // Only atoms with an active frame are refreshed.
1832            let Some(coords_c) = self.atoms[atom_idx].factored_coordinates()? else {
1833                continue;
1834            };
1835            let r = self.atoms[atom_idx].border_frame_rank();
1836            let m = self.atoms[atom_idx].basis_size();
1837            // Accumulate `A_k = Σ_n a_k · e_{n,k} · ĉ_{n,k}ᵀ` directly (p × r).
1838            let mut cross = GrassmannCrossMoment::new(p, r);
1839            // Build per-row p-target `a_k·e_k` and r-coord `a_k·ĉ` batched, then
1840            // accumulate as one outer-product sum. `accumulate` forms
1841            // `targetsᵀ·coords`, so scaling EITHER side by `a_k` once gives the
1842            // `a_k²` weight on the cross-moment that matches the C-block normal
1843            // equations (residual leg carries `a_k`, coordinate leg carries
1844            // `a_k`).
1845            let mut targets = Array2::<f64>::zeros((n, p));
1846            let mut rcoords = Array2::<f64>::zeros((n, r));
1847            for row in 0..n {
1848                let a = assignments[row][atom_idx];
1849                // Partial residual e_{n,k} = z_n − (fitted − a_k decoded_k).
1850                for c in 0..p {
1851                    let e = target[[row, c]] - fitted[[row, c]] + a * decoded[[row, atom_idx, c]];
1852                    targets[[row, c]] = a * e;
1853                }
1854                // In-span coordinate ĉ_{n,k} = Φ_k(t_n)·C_k ∈ ℝ^r.
1855                for j in 0..r {
1856                    let mut acc = 0.0_f64;
1857                    for basis_col in 0..m {
1858                        acc += self.atoms[atom_idx].basis_values[[row, basis_col]]
1859                            * coords_c[[basis_col, j]];
1860                    }
1861                    rcoords[[row, j]] = a * acc;
1862                }
1863            }
1864            cross.accumulate(targets.view(), rcoords.view())?;
1865            // `polar(A_k)` is well-defined only when the moment is non-trivial;
1866            // a zero moment (e.g. a fully collapsed atom) leaves the frame as-is.
1867            if cross.moment().iter().all(|&v| v == 0.0) {
1868                continue;
1869            }
1870            self.atoms[atom_idx].refresh_frame_from_cross_moment(cross.moment())?;
1871            refreshed += 1;
1872        }
1873        Ok(refreshed)
1874    }
1875
1876    pub fn beta_offsets(&self) -> Vec<usize> {
1877        let p = self.output_dim();
1878        let mut out = Vec::with_capacity(self.k_atoms());
1879        let mut cursor = 0usize;
1880        for atom in &self.atoms {
1881            out.push(cursor);
1882            cursor += atom.basis_size() * p;
1883        }
1884        out
1885    }
1886
1887    /// Per-atom β column ranges for the block-Jacobi Schur preconditioner.
1888    ///
1889    /// Returns one `Range<usize>` per atom, covering that atom's decoder
1890    /// coefficients in the flat β vector:
1891    ///   `[beta_offsets[k] .. beta_offsets[k] + basis_size[k] * p_out]`.
1892    ///
1893    /// Pass to [`ArrowSchurSystem::set_block_offsets`] so that
1894    /// [`gam_solve::arrow_schur::JacobiPreconditioner`] builds one dense
1895    /// Schur sub-block per atom instead of scalar-diagonal inversion.
1896    pub fn beta_block_offsets(&self) -> Arc<[std::ops::Range<usize>]> {
1897        let p = self.output_dim();
1898        let mut ranges: Vec<std::ops::Range<usize>> = Vec::with_capacity(self.k_atoms());
1899        let mut cursor = 0usize;
1900        for atom in &self.atoms {
1901            let width = atom.basis_size() * p;
1902            ranges.push(cursor..cursor + width);
1903            cursor += width;
1904        }
1905        Arc::from(ranges.into_boxed_slice())
1906    }
1907
1908    /// Decide whether the sparse per-row active-set layout is engaged for a
1909    /// dense-weight assignment mode, and if so derive the per-row active-atom
1910    /// cap and magnitude cutoff.
1911    ///
1912    /// #1408: this plan is mode-agnostic. `assemble_arrow_schur` consults it
1913    /// directly for IBP-MAP, and for `AssignmentMode::Softmax` via
1914    /// [`Self::softmax_active_plan`], which tightens it with an explicit `top_k`
1915    /// (`softmax_active_cap`). Softmax therefore engages the compact active-set
1916    /// layout whenever `top_k` or the budget bounds the active set (the
1917    /// active-sub-block Gershgorin majorizer + coherent logdet/θ-adjoint are
1918    /// landed — see `SaeRowLayout`'s doc); it keeps the full `K`-atom layout only
1919    /// when neither lever engages. The decision is auto-derived from
1920    /// the problem size and the device/host working-set budget — never a CLI flag
1921    /// or kwarg. JumpReLU is not handled here (it always uses its structural gate
1922    /// via [`SaeRowLayout::from_jumprelu`]). The dense Gauss-Newton data Gram `G`
1923    /// is `(m_total × m_total)` f64; if its dense form fits the budget we keep
1924    /// the exact full-support solve (every atom active per row), so small-`K`
1925    /// problems are bit-for-bit unchanged. Above that, we cap each row to the
1926    /// `k_active` atoms that make the *sparse* Gram fit the same budget, with a
1927    /// relative magnitude cutoff that drops assignment mass contributing
1928    /// negligible `O(a²)` curvature.
1929    ///
1930    /// Returns `Some((k_active_cap, cutoff))` to engage sparsity, or `None` to
1931    /// keep the dense full-support layout.
1932    pub(crate) fn sparse_active_plan(&self) -> Option<(usize, f64)> {
1933        // The per-row Riemannian tangent projection for non-Euclidean atom
1934        // latents is now applied directly on the compact active-set rows (see
1935        // the `Some(layout)` arm in `assemble_arrow_schur`, via
1936        // `compact_row_ext_manifold_and_point`), which rebuilds each row's
1937        // product manifold in its compact column order and applies the SAME
1938        // gt/htt/htbeta + Kronecker-Jacobian projections the dense path uses. So
1939        // the sparse plan may engage on curved ext-coord manifolds (circle /
1940        // torus / sphere atoms) — the affordability lever for manifold-SAE at
1941        // large `K`, where the dense `K²` co-assignment Gram is the cost. (The
1942        // former `is_euclidean()`-only restriction punted every curved atom to
1943        // the dense layout; it is lifted.) The host/device in-core budget is the
1944        // single gate now; it is parameterised in `sparse_active_plan_for_budget`
1945        // so the engagement regression can pin a small budget without allocating
1946        // a multi-GB dense Gram.
1947        let budget = match crate::gpu::device_runtime::GpuRuntime::global() {
1948            // Allow up to one quarter of the AGGREGATE device budget for the dense
1949            // Gram, matching the streaming dispatcher's in-core fraction. The
1950            // per-atom-pair Gram blocks fan out across the whole device pool, so
1951            // the in-core fraction sums every ordinal's budget, not just the
1952            // primary's.
1953            Some(rt) => {
1954                let aggregate: usize = rt
1955                    .device_ordinals()
1956                    .iter()
1957                    .map(|&ord| rt.memory_budget_for(ord))
1958                    .sum();
1959                aggregate / 4
1960            }
1961            None => sae_host_in_core_budget_bytes().0,
1962        };
1963        self.sparse_active_plan_for_budget(budget)
1964    }
1965
1966    /// Budget-parameterised core of [`Self::sparse_active_plan`]. The dense data
1967    /// Gram footprint `(m_total · m_total) f64` is compared against `budget`; a
1968    /// term whose dense Gram exceeds the budget engages the compact active-set
1969    /// plan (returns `Some((k_active_cap, cutoff))`), regardless of whether any
1970    /// atom latent is curved. Pulled out so the curved-atom engagement
1971    /// regression can pin a small budget deterministically.
1972    pub(crate) fn sparse_active_plan_for_budget(&self, budget: usize) -> Option<(usize, f64)> {
1973        // Relative magnitude cutoff: assignment mass below this fraction of the
1974        // row's peak `|a_k|` enters the Gram only as `O(a²)` curvature and is
1975        // dropped. Chosen so dropped terms are ~1e-6 of the peak self-coupling.
1976        const RELATIVE_CUTOFF: f64 = 1.0e-3;
1977
1978        let k_atoms = self.k_atoms();
1979        if k_atoms <= 1 {
1980            return None;
1981        }
1982        let p = self.output_dim();
1983        let m_total: usize = self.atoms.iter().map(|a| a.basis_size()).sum();
1984        // Dense data Gram footprint: (m_total · m_total) f64.
1985        let dense_gram_bytes = m_total
1986            .saturating_mul(m_total)
1987            .saturating_mul(SAE_BYTES_PER_F64);
1988        if dense_gram_bytes <= budget {
1989            return None;
1990        }
1991
1992        // Sparse Gram footprint scales with the per-row active basis count
1993        // `k_active · m_atom`. Solve for the largest `k_active` whose sparse
1994        // Gram `(k_active · m_atom)²` still fits the budget.
1995        let m_atom = (m_total as f64 / k_atoms as f64).max(1.0);
1996        let max_active_basis = ((budget as f64 / SAE_BYTES_PER_F64 as f64).sqrt() / m_atom).floor();
1997        let k_active_cap = (max_active_basis as usize).clamp(1, k_atoms);
1998        // p does not enter the Gram dimension (it is carried by the `⊗ I_p`
1999        // structure), but a degenerate `p == 0` term has no decoder columns.
2000        if p == 0 {
2001            return None;
2002        }
2003        Some((k_active_cap, RELATIVE_CUTOFF))
2004    }
2005
2006    /// #1408/#1409 — per-row active-set plan for the Softmax assignment.
2007    ///
2008    /// Engages the compact top-`k` row layout when EITHER the user supplied a
2009    /// hard `top_k` cap ([`Self::softmax_active_cap`], `1 <= k < K`) OR the
2010    /// dense data Gram exceeds the in-core budget (the same memory lever the
2011    /// IBP path uses via [`Self::sparse_active_plan`]). The returned
2012    /// `k_active_cap` is the tighter of the two, so an explicit `top_k`
2013    /// genuinely bounds the optimization even below the memory threshold and a
2014    /// large-K budget breach still bounds it when no `top_k` is set. Returns
2015    /// `None` (keep the exact full-`K` dense softmax layout) when neither lever
2016    /// engages.
2017    ///
2018    /// The cutoff is the same relative magnitude floor as the budget plan
2019    /// (`1e-3` of the row peak); under an explicit `top_k` cap alone (no budget
2020    /// breach) it is `0.0` so exactly the top-`k` atoms are retained.
2021    pub(crate) fn softmax_active_plan(&self) -> Option<(usize, f64)> {
2022        if self.k_atoms() <= 1 {
2023            return None;
2024        }
2025        let budget_plan = self.sparse_active_plan();
2026        match (self.softmax_active_cap, budget_plan) {
2027            (Some(cap), Some((budget_cap, cutoff))) => Some((cap.min(budget_cap), cutoff)),
2028            // Explicit cap only: retain exactly the top-`cap` atoms (no extra
2029            // magnitude cutoff beyond the cap).
2030            (Some(cap), None) => Some((cap, 0.0)),
2031            (None, plan) => plan,
2032        }
2033    }
2034
2035    pub fn flatten_beta(&self) -> Array1<f64> {
2036        let p = self.output_dim();
2037        let offsets = self.beta_offsets();
2038        let mut out = Array1::<f64>::zeros(self.beta_dim());
2039        for (atom_idx, atom) in self.atoms.iter().enumerate() {
2040            let m = atom.basis_size();
2041            let off = offsets[atom_idx];
2042            for basis_col in 0..m {
2043                for out_col in 0..p {
2044                    out[off + basis_col * p + out_col] =
2045                        atom.decoder_coefficients[[basis_col, out_col]];
2046                }
2047            }
2048        }
2049        out
2050    }
2051
2052    pub fn set_flat_beta(&mut self, beta: ArrayView1<'_, f64>) -> Result<(), String> {
2053        if beta.len() != self.beta_dim() {
2054            return Err(format!(
2055                "set_flat_beta: beta length {} != expected {}",
2056                beta.len(),
2057                self.beta_dim()
2058            ));
2059        }
2060        let p = self.output_dim();
2061        let offsets = self.beta_offsets();
2062        for (atom_idx, atom) in self.atoms.iter_mut().enumerate() {
2063            let m = atom.basis_size();
2064            let off = offsets[atom_idx];
2065            for basis_col in 0..m {
2066                for out_col in 0..p {
2067                    atom.decoder_coefficients[[basis_col, out_col]] =
2068                        beta[off + basis_col * p + out_col];
2069                }
2070            }
2071        }
2072        Ok(())
2073    }
2074
2075    pub fn refit_decoder_least_squares_at_current_state(
2076        &mut self,
2077        target: ArrayView2<'_, f64>,
2078        rho: Option<&SaeManifoldRho>,
2079    ) -> Result<(), String> {
2080        let n = self.n_obs();
2081        let p = self.output_dim();
2082        if target.dim() != (n, p) {
2083            return Err(format!(
2084                "SaeManifoldTerm::refit_decoder_least_squares_at_current_state: target shape {:?} != ({n}, {p})",
2085                target.dim()
2086            ));
2087        }
2088        let k_atoms = self.k_atoms();
2089        let offsets = self.beta_offsets();
2090        let m_total = self.beta_dim() / p;
2091        let mut design = Array2::<f64>::zeros((n, m_total));
2092        for row in 0..n {
2093            let assignments = match rho {
2094                Some(rho) => self.assignment.try_assignments_row_for_rho(row, rho)?,
2095                None => self.assignment.try_assignments_row(row)?,
2096            };
2097            for atom_idx in 0..k_atoms {
2098                let atom = &self.atoms[atom_idx];
2099                let weight = assignments[atom_idx];
2100                let m = atom.basis_size();
2101                let off = offsets[atom_idx] / p;
2102                for basis_col in 0..m {
2103                    design[[row, off + basis_col]] = weight * atom.basis_values[[row, basis_col]];
2104                }
2105            }
2106        }
2107        let beta = solve_design_least_squares(design.view(), target)?;
2108        if beta.dim() != (m_total, p) {
2109            return Err(format!(
2110                "SaeManifoldTerm::refit_decoder_least_squares_at_current_state: beta shape {:?} != ({m_total}, {p})",
2111                beta.dim()
2112            ));
2113        }
2114        for atom_idx in 0..k_atoms {
2115            let m = self.atoms[atom_idx].basis_size();
2116            let off = offsets[atom_idx] / p;
2117            for basis_col in 0..m {
2118                for out_col in 0..p {
2119                    self.atoms[atom_idx].decoder_coefficients[[basis_col, out_col]] =
2120                        beta[[off + basis_col, out_col]];
2121                }
2122            }
2123            self.atoms[atom_idx].refresh_intrinsic_smooth_penalty();
2124        }
2125        Ok(())
2126    }
2127
2128    pub fn fitted(&self) -> Array2<f64> {
2129        self.try_fitted().expect("assignment logits must be finite")
2130    }
2131
2132    /// The #1026 hybrid-collapse substitution map: `atom_idx → &AtomLinearImage`
2133    /// for every `d = 1` slot whose post-fit verdict selected its straight
2134    /// (`Θ → 0`) sub-model. Empty when no report has been computed
2135    /// (`hybrid_split_report == None`, e.g. mid-fit) or no slot collapsed. The
2136    /// SINGLE source of the collapse policy — every reconstruction path (the
2137    /// rho-keyed `try_fitted_with_rho`, the explicit-assignment
2138    /// [`Self::reconstruct_from_assignments`] used by the top-k projection)
2139    /// reads it so train, OOS, and top-k reconstructions decode collapsed slots
2140    /// identically (#1228, #1233).
2141    pub(crate) fn hybrid_linear_image_map(
2142        &self,
2143    ) -> std::collections::HashMap<usize, &crate::hybrid_split::AtomLinearImage> {
2144        // A fitted term carries its collapse policy on the post-fit
2145        // `hybrid_split_report`; an OOS term carries the same trained images on
2146        // `oos_linear_images` (#1228). At most one is `Some` in practice, but
2147        // prefer the report when both are present.
2148        if let Some(report) = self.hybrid_split_report.as_ref() {
2149            return report
2150                .verdicts
2151                .iter()
2152                .filter_map(|v| v.linear_image.as_ref().map(|img| (img.atom_idx, img)))
2153                .collect();
2154        }
2155        if let Some(images) = self.oos_linear_images.as_ref() {
2156            return images.iter().map(|img| (img.atom_idx, img)).collect();
2157        }
2158        std::collections::HashMap::new()
2159    }
2160
2161    /// #1228 — attach the trained dictionary's hybrid-collapsed linear images to
2162    /// this (typically OOS) term so its reconstruction (`fitted` / the top-k
2163    /// assembler) decodes verdict-linear `d = 1` slots by the SAME straight
2164    /// sub-model the training reconstruction used, instead of the original
2165    /// curved decoder. Each image's `atom_idx` must index a real slot; an image
2166    /// whose channel count `p` disagrees with this term's output dim, or whose
2167    /// `atom_idx` is out of range, is rejected so a stale/mismatched payload
2168    /// cannot silently corrupt the reconstruction. Pass an empty slice (or never
2169    /// call this) for an all-curved OOS reconstruction.
2170    ///
2171    /// `pub` (not `pub(crate)`): this is part of the FFI surface — the gam-pyffi
2172    /// crate calls it from `latent_basis_and_sae_ffi.rs` to attach a trained
2173    /// dictionary's hybrid-linear images to an OOS reconstruction term (#1228).
2174    /// Downgrading it to `pub(crate)` breaks the gam-pyffi cdylib build with
2175    /// E0624 (the gam lib still compiles, so the lib build does not catch it).
2176    pub fn set_hybrid_linear_images(
2177        &mut self,
2178        images: Vec<crate::hybrid_split::AtomLinearImage>,
2179    ) -> Result<(), String> {
2180        let p = self.output_dim();
2181        let k_atoms = self.k_atoms();
2182        for img in &images {
2183            if img.atom_idx >= k_atoms {
2184                return Err(format!(
2185                    "set_hybrid_linear_images: atom_idx {} out of range (k_atoms={k_atoms})",
2186                    img.atom_idx
2187                ));
2188            }
2189            if img.b0.len() != p || img.b1.len() != p {
2190                return Err(format!(
2191                    "set_hybrid_linear_images: atom {} linear image has p=({}, {}) != output_dim {p}",
2192                    img.atom_idx,
2193                    img.b0.len(),
2194                    img.b1.len()
2195                ));
2196            }
2197            if self.atoms[img.atom_idx].latent_dim != 1 {
2198                return Err(format!(
2199                    "set_hybrid_linear_images: atom {} is not d=1; only d=1 slots collapse to a straight image",
2200                    img.atom_idx
2201                ));
2202            }
2203        }
2204        self.oos_linear_images = if images.is_empty() {
2205            None
2206        } else {
2207            Some(images)
2208        };
2209        Ok(())
2210    }
2211
2212    /// Assemble the reconstruction `Σ_k a[i,k]·g_k(t_{ik})` from an EXPLICIT
2213    /// per-row assignment matrix (e.g. a hard top-k projection of the fitted
2214    /// soft assignments), honouring the #1026 hybrid collapse when `collapse` is
2215    /// set: a verdict-linear `d = 1` slot decodes its straight sub-model image
2216    /// instead of its curved curve, exactly as the production `try_fitted` does.
2217    /// This is the shared assembler the FFI top-k path uses so the projected
2218    /// reconstruction composes with hybrid collapse (#1233) instead of
2219    /// re-deriving the curved image by hand and silently bypassing the verdict.
2220    /// The atom coordinates (`t`) and decoded curves are the term's own fitted
2221    /// ones; only the assignment masses come from `assignments`.
2222    pub fn reconstruct_from_assignments(
2223        &self,
2224        assignments: ArrayView2<'_, f64>,
2225        collapse: bool,
2226    ) -> Result<Array2<f64>, String> {
2227        let n = self.n_obs();
2228        let p = self.output_dim();
2229        let k_atoms = self.k_atoms();
2230        if assignments.dim() != (n, k_atoms) {
2231            return Err(format!(
2232                "SaeManifoldTerm::reconstruct_from_assignments: assignments {:?} != ({n}, {k_atoms})",
2233                assignments.dim()
2234            ));
2235        }
2236        let linear_images = if collapse {
2237            self.hybrid_linear_image_map()
2238        } else {
2239            std::collections::HashMap::new()
2240        };
2241        let mut out = Array2::<f64>::zeros((n, p));
2242        let mut g_buf = vec![0.0_f64; p];
2243        for row in 0..n {
2244            for atom_idx in 0..k_atoms {
2245                let a_k = assignments[[row, atom_idx]];
2246                if a_k == 0.0 {
2247                    continue;
2248                }
2249                if let Some(image) = linear_images.get(&atom_idx) {
2250                    let own_t = self.assignment.coords[atom_idx].as_matrix()[[row, 0]];
2251                    image.fill_row(image.coordinate_for_row(row, own_t), &mut g_buf);
2252                } else {
2253                    self.atoms[atom_idx].fill_decoded_row(row, &mut g_buf);
2254                }
2255                let mut out_row = out.row_mut(row);
2256                for out_col in 0..p {
2257                    out_row[out_col] += a_k * g_buf[out_col];
2258                }
2259            }
2260        }
2261        Ok(out)
2262    }
2263
2264    pub fn try_fitted(&self) -> Result<Array2<f64>, String> {
2265        // Production/user-facing reconstruction: honours the #1026 hybrid-split
2266        // verdict (verdict-linear `d = 1` slots decode their straight sub-model).
2267        self.try_fitted_with_rho(None, true)
2268    }
2269
2270    pub(crate) fn try_fitted_for_rho(&self, rho: &SaeManifoldRho) -> Result<Array2<f64>, String> {
2271        // Internal/fitting reconstruction: the pure CURVED image (the joint fit
2272        // and the #1026 adjudication both require the uncollapsed curve).
2273        self.try_fitted_with_rho(Some(rho), false)
2274    }
2275
2276    pub(crate) fn try_fitted_with_rho(
2277        &self,
2278        rho: Option<&SaeManifoldRho>,
2279        collapse: bool,
2280    ) -> Result<Array2<f64>, String> {
2281        let n = self.n_obs();
2282        let p = self.output_dim();
2283        let k_atoms = self.k_atoms();
2284        let mut out = Array2::<f64>::zeros((n, p));
2285        // #1026 — the curved/linear hybrid-split verdict is LOAD-BEARING on the
2286        // production reconstruction, not just a side report. When
2287        // [`Self::compute_hybrid_split_report`] (run post-fit in
2288        // `canonicalize_charts_post_fit`) adjudicated a `d = 1` atom's evidence
2289        // in favour of its straight (Θ→0) sub-model, the model's output
2290        // reconstruction (`fitted()` / `try_fitted` → predict and the user-facing
2291        // output) decodes that slot with its fitted linear image instead of its
2292        // curved decoded curve. The linear images are coordinate-keyed and
2293        // rho-independent (exact weighted-LS lines realised inside the
2294        // adjudication — no re-fit, no #1051 outer continuation).
2295        //
2296        // The collapse engages only when the caller asks for it (`collapse`):
2297        // the production `try_fitted` path and the explicit
2298        // `hybrid_collapsed_reconstruction` entry point. The pure-curved
2299        // `try_fitted_for_rho` opts out — the joint fit's loss/assembly optimise
2300        // the curved decoder coefficients and must see the curved image, and the
2301        // #1026 adjudication itself compares the curved fit against its straight
2302        // sub-model — both require the uncollapsed curve. (During fitting the
2303        // report is `None` regardless; it is only computed post-fit.)
2304        let linear_images = if collapse {
2305            self.hybrid_linear_image_map()
2306        } else {
2307            std::collections::HashMap::new()
2308        };
2309        // Reuse a single scratch buffer across all (row, atom) pairs instead of
2310        // allocating a fresh `Array1<f64>` of length p per call.
2311        let mut g_buf = vec![0.0_f64; p];
2312        for row in 0..n {
2313            let a = match rho {
2314                Some(rho) => self.assignment.try_assignments_row_for_rho(row, rho)?,
2315                None => self.assignment.try_assignments_row(row)?,
2316            };
2317            for atom_idx in 0..k_atoms {
2318                let a_k = a[atom_idx];
2319                if let Some(image) = linear_images.get(&atom_idx) {
2320                    // Verdict-linear slot: substitute the straight sub-model image
2321                    // at this row's fitted on-atom coordinate — or, for a #1026
2322                    // collapse-rescued slot, at its fresh per-row code.
2323                    let own_t = self.assignment.coords[atom_idx].as_matrix()[[row, 0]];
2324                    image.fill_row(image.coordinate_for_row(row, own_t), &mut g_buf);
2325                } else {
2326                    self.atoms[atom_idx].fill_decoded_row(row, &mut g_buf);
2327                }
2328                let mut out_row = out.row_mut(row);
2329                for out_col in 0..p {
2330                    out_row[out_col] += a_k * g_buf[out_col];
2331                }
2332            }
2333        }
2334        Ok(out)
2335    }
2336
2337    /// Per-atom **leave-one-atom-out (LOAO) explained-variance contribution**
2338    /// (#1026): for each atom `k`, the drop in reconstruction explained variance
2339    /// `ΔEV_k = EV(full) − EV(full ⊖ atom_k)` when that atom's contribution
2340    /// `a[i,k]·g_k(coord[i,k])` is removed from the assembled reconstruction and
2341    /// nothing else is refit. Because every atom adds linearly into the same
2342    /// fitted reconstruction (`fitted[i] = Σ_k a[i,k]·g_k`), zeroing one atom is
2343    /// the exact "this atom withheld" counterfactual, and the EV it was earning
2344    /// is `EV(full) − EV(without k)`. This is the per-atom held-out EV
2345    /// attribution the #1026 roadmap pairs with each atom's fitted turning `Θ`:
2346    /// a `Θ ≈ 0` atom earning a large `ΔEV` is a linear-tail direction; a
2347    /// high-`Θ` atom earning a large `ΔEV` is a genuine curved family carrying
2348    /// reconstruction it would otherwise shatter into `N(ε) ≈ Θ/(2√(2ε))` linear
2349    /// directions. Pure read-only diagnostic — never mutates any atom.
2350    ///
2351    /// Returns one `Option<f64>` per atom in atom order; `None` for an atom
2352    /// whose ⊖-reconstruction EV is undefined (degenerate target variance), and
2353    /// `None` for the whole vector if the full-reconstruction EV is undefined.
2354    /// #1026: the load-bearing curved-vs-linear hybrid-split verdict for the
2355    /// fitted dictionary, or `None` until [`Self::canonicalize_charts_post_fit`]
2356    /// has run (or when no `d = 1` atom is eligible). Surfaced in the Python model
2357    /// output so the user sees which atoms genuinely earn their curvature.
2358    pub fn hybrid_split_report(
2359        &self,
2360    ) -> Option<&crate::hybrid_split::SaeHybridSplitReport> {
2361        self.hybrid_split_report.as_ref()
2362    }
2363
2364    /// Build the #1026 curved-vs-linear hybrid-split report by adjudicating each
2365    /// eligible `d = 1` atom's fitted curved image against its straight (linear
2366    /// special-case) sub-model on the common rank-aware Laplace evidence scale.
2367    ///
2368    /// Both candidates are scored against the SAME data — the atom's
2369    /// leave-this-atom-out response residual `y_resp = target − (full − a_k·γ_k)`
2370    /// (#1202) — over its assigned rows: the curved candidate predicts its actual
2371    /// mass-scaled contribution `a_k·γ_k`, the linear candidate the best
2372    /// mass-weighted straight line fit to `y_resp` (the collapsed linear lane —
2373    /// closed form, NOT the broken euclidean outer fit path of #1051). Linear is
2374    /// the curved family's nested `Θ = 0` sub-model on common data, so the
2375    /// per-slot evidence argmin is a genuine match-or-beat comparison. Eligible
2376    /// atoms are `d = 1` atoms with an installed evaluator at the full curvature
2377    /// dial (`homotopy_eta == 1.0`) whose live coordinate dim still matches the
2378    /// atom's latent dim. Returns `None` when no reconstruction `target` is
2379    /// supplied (there is no data to adjudicate against).
2380    pub fn compute_hybrid_split_report(
2381        &self,
2382        rho: &SaeManifoldRho,
2383        target: Option<ArrayView2<'_, f64>>,
2384    ) -> Result<Option<crate::hybrid_split::SaeHybridSplitReport>, String> {
2385        let n = self.n_obs();
2386        let p = self.output_dim();
2387        // Per-atom held-out `ΔEV_k` (leave-one-atom-out explained-variance drop),
2388        // paired with each atom's fitted turning Θ onto the verdict so the report
2389        // carries the #1026 `(Θ, ΔEV)` frontier point as structured data. Absent
2390        // when no reconstruction target is supplied.
2391        let loao_ev: Vec<Option<f64>> = match target {
2392            Some(t) => self.per_atom_loao_explained_variance(t, rho)?,
2393            None => vec![None; self.k_atoms()],
2394        };
2395        let delta_ev_for =
2396            |atom_idx: usize| -> Option<f64> { loao_ev.get(atom_idx).copied().flatten() };
2397        // The common-evidence comparison (#1202) scores both candidates against
2398        // the response data the atom is responsible for. That requires a target;
2399        // with none supplied there is nothing to adjudicate against, so no report.
2400        let Some(target) = target else {
2401            return Ok(None);
2402        };
2403        if target.dim() != (n, p) {
2404            return Err(format!(
2405                "SaeManifoldTerm::compute_hybrid_split_report: target {:?} != ({n}, {p})",
2406                target.dim()
2407            ));
2408        }
2409        // Per-row assignment masses (once), so each atom's weighted straight-line
2410        // fit uses the same row weighting the joint reconstruction loss does.
2411        let mut weights: Vec<Array1<f64>> = Vec::with_capacity(n);
2412        for row in 0..n {
2413            weights.push(self.assignment.try_assignments_row_for_rho(row, rho)?);
2414        }
2415        // The full assembled reconstruction `Σ_k a[i,k]·γ_k`, computed once. Each
2416        // atom's leave-this-atom-out response residual is `y_resp = target −
2417        // (full − a_k·γ_k)`, the data both that atom's candidates fit (#1202).
2418        let full = self.try_fitted_for_rho(rho)?;
2419        let eligible: Vec<usize> = (0..self.k_atoms())
2420            .filter(|&atom_idx| {
2421                let atom = &self.atoms[atom_idx];
2422                atom.latent_dim == 1
2423                    && atom.basis_evaluator.is_some()
2424                    && atom.homotopy_eta == 1.0
2425                    && self.assignment.coords[atom_idx].latent_dim() == atom.latent_dim
2426            })
2427            .collect();
2428        // Per-atom fitted decoded image at every row (the curved candidate's
2429        // realized curve, which the linear candidate must approximate).
2430        let coords_for = |atom_idx: usize| -> Array1<f64> {
2431            self.assignment.coords[atom_idx]
2432                .as_matrix()
2433                .column(0)
2434                .to_owned()
2435        };
2436        let assign_for = |atom_idx: usize| -> Array1<f64> {
2437            Array1::from_iter((0..n).map(|row| weights[row][atom_idx]))
2438        };
2439        let decoded_for = |atom_idx: usize| -> Array2<f64> {
2440            let mut decoded = Array2::<f64>::zeros((n, p));
2441            let mut buf = vec![0.0_f64; p];
2442            for row in 0..n {
2443                self.atoms[atom_idx].fill_decoded_row(row, &mut buf);
2444                for col in 0..p {
2445                    decoded[[row, col]] = buf[col];
2446                }
2447            }
2448            decoded
2449        };
2450        // The atom's leave-this-atom-out response residual `y_resp = target −
2451        // (full − a_k·γ_k) = (target − full) + a_k·γ_k`. Both the curved and the
2452        // linear candidate are scored against this on common data (#1202).
2453        let target_resid_for = |atom_idx: usize| -> Array2<f64> {
2454            let mut resid = Array2::<f64>::zeros((n, p));
2455            let mut buf = vec![0.0_f64; p];
2456            for row in 0..n {
2457                let a_k = weights[row][atom_idx];
2458                self.atoms[atom_idx].fill_decoded_row(row, &mut buf);
2459                for col in 0..p {
2460                    resid[[row, col]] = target[[row, col]] - full[[row, col]] + a_k * buf[col];
2461                }
2462            }
2463            resid
2464        };
2465        let manifold_for = |atom_idx: usize| -> gam_terms::latent::LatentManifold {
2466            self.assignment.coords[atom_idx].manifold().clone()
2467        };
2468        // #1026 EV-preservation gate denominator: the full target's total
2469        // column-centered variance `SST_full` (the SAME `sst` the reconstruction
2470        // EV is measured against), so the gate vetoes any collapse that would drop
2471        // full-reconstruction EV by more than its tolerance.
2472        let total_centered_variance = {
2473            let mut tss = 0.0_f64;
2474            for col in 0..p {
2475                let mut mean = 0.0_f64;
2476                for row in 0..n {
2477                    mean += target[[row, col]];
2478                }
2479                mean /= n as f64;
2480                for row in 0..n {
2481                    let c = target[[row, col]] - mean;
2482                    tss += c * c;
2483                }
2484            }
2485            tss
2486        };
2487        crate::hybrid_split::build_hybrid_split_report(
2488            &self.atoms,
2489            eligible.into_iter(),
2490            coords_for,
2491            assign_for,
2492            decoded_for,
2493            target_resid_for,
2494            manifold_for,
2495            delta_ev_for,
2496            total_centered_variance,
2497        )
2498    }
2499
2500    pub fn per_atom_loao_explained_variance(
2501        &self,
2502        target: ArrayView2<'_, f64>,
2503        rho: &SaeManifoldRho,
2504    ) -> Result<Vec<Option<f64>>, String> {
2505        let n = self.n_obs();
2506        let p = self.output_dim();
2507        let k_atoms = self.k_atoms();
2508        if target.dim() != (n, p) {
2509            return Err(format!(
2510                "SaeManifoldTerm::per_atom_loao_explained_variance: target {:?} != ({n}, {p})",
2511                target.dim()
2512            ));
2513        }
2514        let full = self.try_fitted_for_rho(rho)?;
2515        let Some(ev_full) = reconstruction_explained_variance(target, full.view()) else {
2516            return Ok(vec![None; k_atoms]);
2517        };
2518        // Cache each row's assignment weights once, then subtract a single
2519        // atom's decoded contribution per LOAO pass instead of reassembling the
2520        // whole dictionary k times.
2521        let mut weights: Vec<Array1<f64>> = Vec::with_capacity(n);
2522        for row in 0..n {
2523            weights.push(self.assignment.try_assignments_row_for_rho(row, rho)?);
2524        }
2525        let mut g_buf = vec![0.0_f64; p];
2526        let mut out = Vec::with_capacity(k_atoms);
2527        for atom_idx in 0..k_atoms {
2528            let mut without = full.clone();
2529            for row in 0..n {
2530                let a_k = weights[row][atom_idx];
2531                if a_k == 0.0 {
2532                    continue;
2533                }
2534                self.atoms[atom_idx].fill_decoded_row(row, &mut g_buf);
2535                let mut without_row = without.row_mut(row);
2536                for out_col in 0..p {
2537                    without_row[out_col] -= a_k * g_buf[out_col];
2538                }
2539            }
2540            out.push(
2541                reconstruction_explained_variance(target, without.view())
2542                    .map(|ev_without| ev_full - ev_without),
2543            );
2544        }
2545        Ok(out)
2546    }
2547
2548    /// #1026 — the LOAD-BEARING collapsed reconstruction: the assembled
2549    /// dictionary output `Σ_k a[i,k]·g_k(coord[i,k])` in which every slot whose
2550    /// hybrid-split verdict selected LINEAR has its curved decoded image replaced
2551    /// by its fitted straight sub-model `b₀ + (t − t̄)·b₁`. This is what makes the
2552    /// verdict *change the reconstruction* instead of merely logging a choice:
2553    /// the linear-collapsed atom no longer pays its `M·p` curved coefficients, it
2554    /// carries a `2·p` straight image whose decoded curve has zero turning.
2555    ///
2556    /// The straight images are the exact weighted-least-squares lines already
2557    /// realized inside [`Self::compute_hybrid_split_report`] (no re-fit, no outer
2558    /// continuation, sidestepping #1051). Returns the curved reconstruction
2559    /// unchanged when no verdict selected linear, or when the report has not been
2560    /// computed yet (`hybrid_split_report == None`).
2561    pub fn hybrid_collapsed_reconstruction(
2562        &self,
2563        rho: &SaeManifoldRho,
2564    ) -> Result<Array2<f64>, String> {
2565        // #1026 — the hybrid collapse is realised by the SINGLE reconstruction
2566        // path ([`Self::try_fitted_with_rho`]) with the collapse flag set: a
2567        // verdict-linear `d = 1` slot decodes its straight sub-model image
2568        // instead of its curved curve. This replaces the dedicated re-collapse
2569        // loop this method used to carry (a parallel layer). The production
2570        // `try_fitted` shares the identical routine at `rho = None`; this entry
2571        // point keeps the rho-keyed collapse for the #1026 EV-dominance reporting
2572        // (`hybrid_collapsed_explained_variance`) and the regression battery.
2573        self.try_fitted_with_rho(Some(rho), true)
2574    }
2575
2576    /// #1026 — the reconstruction explained variance of the hybrid-collapsed
2577    /// dictionary (every verdict-linear slot decoded by its straight sub-model)
2578    /// against `target`. The companion of [`Self::per_atom_loao_explained_variance`]
2579    /// for the dominance claim: because each linear-collapsed slot is the curved
2580    /// family's `Θ → 0` sub-model and is only kept when its evidence beats the
2581    /// curved candidate's parameter price, the collapsed dictionary match-or-beats
2582    /// the all-curved one on EV-per-parameter — the strict-generalization floor
2583    /// the #1026 hybrid argument rests on. `None` when EV is undefined (degenerate
2584    /// target variance).
2585    pub fn hybrid_collapsed_explained_variance(
2586        &self,
2587        target: ArrayView2<'_, f64>,
2588        rho: &SaeManifoldRho,
2589    ) -> Result<Option<f64>, String> {
2590        let n = self.n_obs();
2591        let p = self.output_dim();
2592        if target.dim() != (n, p) {
2593            return Err(format!(
2594                "SaeManifoldTerm::hybrid_collapsed_explained_variance: target {:?} != ({n}, {p})",
2595                target.dim()
2596            ));
2597        }
2598        let collapsed = self.hybrid_collapsed_reconstruction(rho)?;
2599        Ok(reconstruction_explained_variance(target, collapsed.view()))
2600    }
2601
2602    /// #1026 ladder item 2/3 — the AMORTIZED ENCODER, wired from the fitted
2603    /// dictionary. Builds the offline certified [`EncodeAtlas`] over this term's
2604    /// frozen atoms and encodes a target corpus `targets` (`n × p`) through the
2605    /// per-chart distilled Jacobian predictor, with the Kantorovich certificate
2606    /// gating each row and an exact-solve fallback for the rows the amortized
2607    /// predictor cannot certify. Returns one [`EncodeResult`] per atom (the
2608    /// per-atom encoded coordinates + per-row certificate mask), in dictionary
2609    /// order.
2610    ///
2611    /// This is the thread's "encoder + certificate-gated exact fallback"
2612    /// deployment made reachable from a fit: the distilled map approximates
2613    /// inference at one mat-vec/row, and any row whose amortized prediction fails
2614    /// `h ≤ ½` falls back to the certified IFT-warm-start Newton encode
2615    /// ([`EncodeAtlas::certified_encode_row`]); rows that still cannot be
2616    /// certified ride the [`EncodeResult::encode_uncertified_count`] flag for the
2617    /// upstream exact multi-start solve (honesty, never a silent wrong encode).
2618    ///
2619    /// Magic by default: the atlas's worst-case bounds are auto-derived from the
2620    /// fit — `amplitude_bound[k]` is the largest fitted assignment mass `a[i,k]`
2621    /// the encode can produce for atom `k` (the encode recovers `t` from
2622    /// `x ≈ z·γ_k(t)` at amplitude `z = a[i,k]`), and `target_norm_bound` is the
2623    /// largest target row norm — so no caller supplies a knob. Per-row amplitudes
2624    /// are the fitted assignment masses for the same target the dictionary was fit
2625    /// against; an external corpus reuses the per-row masses the assignment
2626    /// produces for it upstream (passed in `amplitudes`, one column per atom).
2627    pub fn amortized_encode_target(
2628        &self,
2629        targets: ArrayView2<'_, f64>,
2630        amplitudes: ArrayView2<'_, f64>,
2631    ) -> Result<Vec<crate::encode::EncodeResult>, String> {
2632        let p = self.output_dim();
2633        let k_atoms = self.k_atoms();
2634        let n = targets.nrows();
2635        if targets.ncols() != p {
2636            return Err(format!(
2637                "SaeManifoldTerm::amortized_encode_target: targets have {} cols but output_dim is {p}",
2638                targets.ncols()
2639            ));
2640        }
2641        if amplitudes.dim() != (n, k_atoms) {
2642            return Err(format!(
2643                "SaeManifoldTerm::amortized_encode_target: amplitudes {:?} must be (n={n}, K={k_atoms})",
2644                amplitudes.dim()
2645            ));
2646        }
2647
2648        // Magic-by-default offline bounds, auto-derived from the fit so no caller
2649        // supplies a knob. `target_norm_bound` is the largest target row L2 norm
2650        // (bounds `‖x‖` over the corpus); `amplitude_bound[k]` is the largest
2651        // fitted assignment mass for atom `k` (bounds `|z_k|`), with a strictly
2652        // positive floor so a near-inactive atom still certifies a finite radius.
2653        let mut target_norm_bound = 0.0_f64;
2654        for row in 0..n {
2655            let norm = targets.row(row).dot(&targets.row(row)).sqrt();
2656            if norm.is_finite() && norm > target_norm_bound {
2657                target_norm_bound = norm;
2658            }
2659        }
2660        let mut amplitude_bound = vec![0.0_f64; k_atoms];
2661        for atom_idx in 0..k_atoms {
2662            let mut bound = 0.0_f64;
2663            for row in 0..n {
2664                let z = amplitudes[[row, atom_idx]].abs();
2665                if z.is_finite() && z > bound {
2666                    bound = z;
2667                }
2668            }
2669            // A strictly positive amplitude floor keeps the offline Lipschitz
2670            // scaling finite for atoms with no active row in this corpus (those
2671            // rows encode to the chart center via the certificate anyway).
2672            amplitude_bound[atom_idx] = bound.max(1.0);
2673        }
2674
2675        let atlas = crate::encode::EncodeAtlas::build(
2676            &self.atoms,
2677            &amplitude_bound,
2678            target_norm_bound,
2679            crate::encode::AtlasConfig::default(),
2680        )?;
2681
2682        // Per-atom amortized encode with a certificate-gated exact-solve fallback:
2683        // a row whose distilled prediction fails `h ≤ ½` is retried through the
2684        // certified IFT-warm-start Newton path; a row that still cannot be
2685        // certified stays flagged for the upstream multi-start solve.
2686        // (The atlas is rho-free; the per-row amplitudes already carry the
2687        // rho-resolved assignment masses the caller produced upstream.)
2688        let mut results = Vec::with_capacity(k_atoms);
2689        for atom_idx in 0..k_atoms {
2690            let atom = &self.atoms[atom_idx];
2691            let amp_col = amplitudes.column(atom_idx).to_owned();
2692            let amortized =
2693                atlas.amortized_encode_batch(atom, atom_idx, targets, amp_col.view())?;
2694            let mut coords = amortized.coords;
2695            let mut certified = amortized.certified;
2696            for row in 0..n {
2697                if certified[row] {
2698                    continue;
2699                }
2700                let (t, cert) =
2701                    atlas.certified_encode_row(atom, atom_idx, targets.row(row), amp_col[row])?;
2702                if cert.certified() {
2703                    coords.row_mut(row).assign(&t);
2704                    certified[row] = true;
2705                }
2706            }
2707            results.push(crate::encode::EncodeResult::from_rows(
2708                coords, certified,
2709            ));
2710        }
2711        Ok(results)
2712    }
2713
2714    /// #1026 — the fitted per-row assignment masses `a[i,k]` (the activation
2715    /// amplitudes `z_k` the amortized encode recovers `t` against), as an
2716    /// `n × K` matrix. These are exactly the masses
2717    /// [`Self::try_fitted_with_rho`] assembles the reconstruction from, so
2718    /// feeding them to [`Self::amortized_encode_target`] re-encodes the SAME
2719    /// inference the dictionary was fit against — the self-consistency the
2720    /// distilled encoder is supervised to approximate.
2721    pub fn fitted_assignment_amplitudes(
2722        &self,
2723        rho: &SaeManifoldRho,
2724    ) -> Result<Array2<f64>, String> {
2725        let n = self.n_obs();
2726        let k_atoms = self.k_atoms();
2727        let mut amplitudes = Array2::<f64>::zeros((n, k_atoms));
2728        for row in 0..n {
2729            let a = self.assignment.try_assignments_row_for_rho(row, rho)?;
2730            for atom_idx in 0..k_atoms {
2731                amplitudes[[row, atom_idx]] = a[atom_idx];
2732            }
2733        }
2734        Ok(amplitudes)
2735    }
2736
2737    /// #1026 — encode the dictionary's own fit-time target with the amortized
2738    /// encoder, deriving the per-row amplitudes from the fitted assignment so the
2739    /// caller supplies neither bounds nor amplitudes (magic by default). The
2740    /// end-to-end "fit → distilled encoder → certificate-gated encode" path.
2741    pub fn amortized_encode_fitted(
2742        &self,
2743        targets: ArrayView2<'_, f64>,
2744        rho: &SaeManifoldRho,
2745    ) -> Result<Vec<crate::encode::EncodeResult>, String> {
2746        let amplitudes = self.fitted_assignment_amplitudes(rho)?;
2747        self.amortized_encode_target(targets, amplitudes.view())
2748    }
2749
2750    /// #1154 — amortized-encoder consistency of the CURRENT dictionary against
2751    /// its own fit-time target. This is the co-training signal of the joint
2752    /// amortized-encoder + REML loop (Design A): the amortized (one-mat-vec)
2753    /// encode is built from the *current* fitted decoder, run on `targets`, and
2754    /// scored on two principled axes —
2755    ///
2756    /// * `recon_consistency` (the bilinear part of the co-training loss): the
2757    ///   mean per-element squared gap between the **amortized** reconstruction
2758    ///   `Σ_k z_k · Φ_k(t̂_k) B_k` (decode the amortized coords) and the
2759    ///   **exact** fitted reconstruction `Σ_k z_k · Φ_k(t_k^*) B_k` the inner
2760    ///   solve converged to. A dictionary whose encode map is well-approximated
2761    ///   to first order by the per-chart IFT predictor scores near zero; a
2762    ///   dictionary the amortized encoder *cannot* invert faithfully (sharp
2763    ///   curvature, poorly-charted regions) scores high. Minimising this jointly
2764    ///   with REML steers the fit toward dictionaries that admit a fast,
2765    ///   faithful amortized encode — the architectural co-adaptation #1154 adds.
2766    /// * `uncertified_fraction`: the share of (row, atom) encodes whose
2767    ///   Kantorovich certificate failed (`h > ½`), i.e. that fell back to the
2768    ///   certified IFT-warm-start Newton. This is the encoder's *certifiable coverage*
2769    ///   of the dictionary; co-training rewards dictionaries the cheap encode
2770    ///   certifies, not just ones it happens to land.
2771    ///
2772    /// The certificate keeps every accepted amortized coord honest (uncertified
2773    /// rows already ride the exact fallback inside `amortized_encode_target`), so
2774    /// this metric never silently trusts a wrong encode — it MEASURES how much of
2775    /// the dictionary the cheap encoder can faithfully and certifiably invert.
2776    pub fn amortized_encoder_consistency(
2777        &self,
2778        targets: ArrayView2<'_, f64>,
2779        rho: &SaeManifoldRho,
2780    ) -> Result<AmortizedEncoderConsistency, String> {
2781        let n = self.n_obs();
2782        let p = self.output_dim();
2783        let k_atoms = self.k_atoms();
2784        if targets.dim() != (n, p) {
2785            return Err(format!(
2786                "SaeManifoldTerm::amortized_encoder_consistency: targets {:?} must be (n={n}, p={p})",
2787                targets.dim()
2788            ));
2789        }
2790        let amplitudes = self.fitted_assignment_amplitudes(rho)?;
2791        let encodes = self.amortized_encode_target(targets, amplitudes.view())?;
2792        // The EXACT fitted reconstruction the inner solve converged to (pure
2793        // curved image, rho-keyed) is the supervision target for the amortized
2794        // reconstruction. Both are n×p ambient, so the comparison is layout-free.
2795        let exact_recon = self.try_fitted_for_rho(rho)?;
2796
2797        // Build the amortized reconstruction Σ_k z_k · Φ_k(t̂_k) B_k by decoding
2798        // each atom's amortized coords through that atom's own basis evaluator.
2799        let mut amortized_recon = Array2::<f64>::zeros((n, p));
2800        let mut uncertified = 0usize;
2801        for atom_idx in 0..k_atoms {
2802            let atom = &self.atoms[atom_idx];
2803            let result = &encodes[atom_idx];
2804            // An atom with no basis evaluator cannot decode an amortized
2805            // reconstruction; every one of its rows is necessarily uncertified
2806            // (the encode flagged them all), so it contributes nothing to the
2807            // amortized recon and its full row-count to the uncertified tally.
2808            // Count it and skip the decode rather than erroring — the consistency
2809            // fold stays a bounded penalty, never a hard abort of the criterion.
2810            let Some(evaluator) = atom.basis_evaluator.as_ref() else {
2811                uncertified += n;
2812                continue;
2813            };
2814            uncertified += result.encode_uncertified_count;
2815            // Decode the amortized coords: Φ_k(t̂) is (n × M_k); B_k is (M_k × p).
2816            let (phi, _jac) = evaluator.evaluate(result.coords.view())?;
2817            let decoded = phi.dot(&atom.decoder_coefficients); // (n × p)
2818            for row in 0..n {
2819                let z = amplitudes[[row, atom_idx]];
2820                if z == 0.0 {
2821                    continue;
2822                }
2823                for col in 0..p {
2824                    amortized_recon[[row, col]] += z * decoded[[row, col]];
2825                }
2826            }
2827        }
2828
2829        let mut sse = 0.0_f64;
2830        for row in 0..n {
2831            for col in 0..p {
2832                let gap = amortized_recon[[row, col]] - exact_recon[[row, col]];
2833                sse += gap * gap;
2834            }
2835        }
2836        let denom = (n.max(1) * p.max(1)) as f64;
2837        let recon_consistency = sse / denom;
2838        let total_encodes = (n * k_atoms).max(1) as f64;
2839        let uncertified_fraction = uncertified as f64 / total_encodes;
2840
2841        Ok(AmortizedEncoderConsistency {
2842            recon_consistency,
2843            uncertified_fraction,
2844            n_uncertified: uncertified,
2845            n_encodes: n * k_atoms,
2846        })
2847    }
2848
2849    /// #1154 — the co-trained REML criterion: the exact REML criterion at `rho`
2850    /// PLUS the amortized-encoder consistency penalty, so the outer optimizer
2851    /// co-adapts the dictionary + smoothing parameters λ TOWARD a dictionary the
2852    /// fast amortized encoder can faithfully and certifiably invert.
2853    ///
2854    /// This is Design A of #1154. The inner solve still converges the `(t, β)`
2855    /// system to stationarity at the engine's current ρ (so the implicit-function
2856    /// REML λ-gradient `dβ̂/dλ = −(H+S_λ)⁻¹(dS_λ/dλ)β̂` stays EXACT — the encoder
2857    /// only warm-starts/co-adapts, it never replaces the stationary point). The
2858    /// added term
2859    ///
2860    /// ```text
2861    ///   J_cotrain(ρ) = REML(ρ)  +  w · ‖x̂_amortized − x̂_exact‖²/(n·p)
2862    ///                            +  w_cert · uncertified_fraction
2863    /// ```
2864    ///
2865    /// folds the post-fit amortized-encode quality into the ranked objective. The
2866    /// weights are auto-scaled to the REML criterion magnitude (magic by default:
2867    /// no caller knob) so the consistency term is a meaningful but non-dominant
2868    /// fraction of the objective regardless of problem scale.
2869    pub fn reml_criterion_cotrained(
2870        &mut self,
2871        target: ArrayView2<'_, f64>,
2872        rho: &SaeManifoldRho,
2873        registry: Option<&AnalyticPenaltyRegistry>,
2874        inner_max_iter: usize,
2875        learning_rate: f64,
2876        ridge_ext_coord: f64,
2877        ridge_beta: f64,
2878    ) -> Result<(f64, SaeManifoldLoss, AmortizedEncoderConsistency), String> {
2879        // #1154: always attempt the amortized warm-start first inside
2880        // `reml_criterion_cotrained` (the encode/warm path for the cotrained
2881        // objective). Good warm-starts from the running dictionary land the
2882        // inner solve closer to the stationary point used for the fold.
2883        // Advisory only (0 or err falls back to cold); telemetry recorded by
2884        // outer objective callers when present.
2885        self.warm_start_latents_from_amortized_encoder(target, rho)
2886            .unwrap_or(0);
2887        let (reml, loss) = self.reml_criterion_with_refine_policy(
2888            target,
2889            rho,
2890            registry,
2891            inner_max_iter,
2892            learning_rate,
2893            ridge_ext_coord,
2894            ridge_beta,
2895            true,
2896        )?;
2897        let consistency = self.amortized_encoder_consistency(target, rho)?;
2898        // Auto-scale the co-training weights to the REML magnitude so the
2899        // consistency penalty is a bounded, scale-free fraction of the objective
2900        // (magic by default: no caller knob). `reml_scale` floors at 1 so a
2901        // near-zero criterion still admits a meaningful consistency contribution.
2902        let cotrained = Self::fold_cotrain_consistency(reml, &consistency);
2903        Ok((cotrained, loss, consistency))
2904    }
2905
2906    /// #1154 — the single source of the co-training fold arithmetic: add the
2907    /// auto-scaled amortized-encoder consistency penalty to an already-computed
2908    /// REML criterion at the converged dictionary. Both the public
2909    /// [`Self::reml_criterion_cotrained`] entry point and the outer-loop value /
2910    /// gradient lanes (`SaeManifoldOuterObjective::fold_cotrain_consistency`)
2911    /// route through THIS function, so the folded objective cannot drift between
2912    /// the criterion and the cascade-ranked cost (the objective↔gradient desync
2913    /// bug class). The weights are auto-scaled to the REML magnitude (`max(|REML|,
2914    /// 1)`) so the penalty is a bounded, scale-free fraction of the objective
2915    /// regardless of problem scale; the fold carries no analytic gradient (under
2916    /// Design A the REML λ-gradient stays the exact implicit-function path).
2917    #[must_use]
2918    pub fn fold_cotrain_consistency(
2919        reml_cost: f64,
2920        consistency: &AmortizedEncoderConsistency,
2921    ) -> f64 {
2922        let reml_scale = reml_cost.abs().max(1.0);
2923        reml_cost
2924            + COTRAIN_RECON_WEIGHT * reml_scale * consistency.recon_consistency
2925            + COTRAIN_CERT_WEIGHT * reml_scale * consistency.uncertified_fraction
2926    }
2927
2928    /// #1154 item 2 — warm-start the inner latent coordinates from the amortized
2929    /// encoder (Design A). Builds the per-chart IFT-Jacobian atlas from the
2930    /// CURRENT dictionary, runs the one-mat-vec amortized encode of `target`
2931    /// against each atom at the rho-resolved assignment masses, and overwrites
2932    /// each atom's stored latent coords with the predicted `t̂` ON THE ROWS THE
2933    /// KANTOROVICH CERTIFICATE ACCEPTS. Uncertified rows are left at their
2934    /// current coords (the previous-iterate start), so the
2935    /// warm-start can only HELP — a row the cheap predictor cannot certify never
2936    /// corrupts the seed. The subsequent inner Newton refines from this seed to
2937    /// the SAME stationary point (the warm-start changes only the basin entry,
2938    /// not the root), so the REML λ-gradient stays exactly the implicit-function
2939    /// path and the criterion is unchanged at convergence — the amortized encoder
2940    /// only accelerates/co-adapts the inner solve, it never replaces the
2941    /// stationary point.
2942    ///
2943    /// Returns the number of (row, atom) coords actually warm-started (the
2944    /// certified-prediction count), for instrumentation / tests. A first-build
2945    /// dictionary with no usable charts simply warm-starts nothing and returns 0
2946    /// (the cold path is byte-for-byte unchanged).
2947    pub fn warm_start_latents_from_amortized_encoder(
2948        &mut self,
2949        target: ArrayView2<'_, f64>,
2950        rho: &SaeManifoldRho,
2951    ) -> Result<usize, String> {
2952        let n = self.n_obs();
2953        let k_atoms = self.k_atoms();
2954        if n == 0 || k_atoms == 0 {
2955            return Ok(0);
2956        }
2957        let amplitudes = self.fitted_assignment_amplitudes(rho)?;
2958        let encodes = self.amortized_encode_target(target, amplitudes.view())?;
2959        let mut warm_started = 0usize;
2960        for atom_idx in 0..k_atoms {
2961            let d = self.atoms[atom_idx].latent_dim;
2962            if d == 0 {
2963                continue;
2964            }
2965            let result = &encodes[atom_idx];
2966            // Start from the atom's CURRENT coords so uncertified rows are left
2967            // exactly as they were; overwrite only the certified predictions.
2968            let mut coords = self.assignment.coords[atom_idx].as_matrix();
2969            if coords.dim() != (n, d) {
2970                return Err(format!(
2971                    "warm_start_latents_from_amortized_encoder: atom {atom_idx} coords {:?} != (n={n}, d={d})",
2972                    coords.dim()
2973                ));
2974            }
2975            for row in 0..n {
2976                if !result.certified[row] {
2977                    continue;
2978                }
2979                for axis in 0..d {
2980                    coords[[row, axis]] = result.coords[[row, axis]];
2981                }
2982                warm_started += 1;
2983            }
2984            // `as_matrix` lays coords out row-major (`[[row, axis]]`), exactly the
2985            // `values[row*d + axis]` order `set_flat` expects, so a plain
2986            // row-major iterator reconstructs the flat vector.
2987            let flat = Array1::from_iter(coords.iter().copied());
2988            self.assignment.coords[atom_idx].set_flat(flat.view());
2989        }
2990        // The basis caches must follow the freshly-seeded coords so the next
2991        // inner solve evaluates Φ at the warm-started t̂, not the stale coords.
2992        self.refresh_basis_from_current_coords()?;
2993        Ok(warm_started)
2994    }
2995
2996    pub fn loss(
2997        &self,
2998        target: ArrayView2<'_, f64>,
2999        rho: &SaeManifoldRho,
3000    ) -> Result<SaeManifoldLoss, String> {
3001        self.loss_scaled(target, rho, 1.0)
3002    }
3003
3004    /// Penalized objective with a `penalty_scale` applied to the β-tier
3005    /// (decoder smoothness) penalty, mirroring
3006    /// [`Self::assemble_arrow_schur_scaled`]. The streaming line search sums
3007    /// per-chunk `loss_scaled(..., n_chunk / N)` so that the global smoothness
3008    /// penalty is counted exactly once across a pass while the per-row data,
3009    /// assignment-prior, and ARD terms sum naturally. `penalty_scale == 1.0`
3010    /// recovers the full-batch objective.
3011    pub fn loss_scaled(
3012        &self,
3013        target: ArrayView2<'_, f64>,
3014        rho: &SaeManifoldRho,
3015        penalty_scale: f64,
3016    ) -> Result<SaeManifoldLoss, String> {
3017        if !(penalty_scale.is_finite() && penalty_scale > 0.0) {
3018            return Err(format!(
3019                "SaeManifoldTerm::loss_scaled: penalty_scale must be finite and positive; got {penalty_scale}"
3020            ));
3021        }
3022        if target.dim() != (self.n_obs(), self.output_dim()) {
3023            return Err(format!(
3024                "SaeManifoldTerm::loss: Z must be ({}, {}); got {:?}",
3025                self.n_obs(),
3026                self.output_dim(),
3027                target.dim()
3028            ));
3029        }
3030        // The likelihood whitens through the RowMetric **only** when the metric
3031        // is a genuinely estimated noise model (`metric.whitens_likelihood()`,
3032        // i.e. `WhitenedStructured` — the #974 residual-covariance seam). For
3033        // Euclidean (default `None`) and for the OutputFisher *gauge* metric the
3034        // reconstruction data-fit stays the isotropic `0.5 * Σ r²`: a gauge /
3035        // output-Fisher inner product must NOT silently replace the
3036        // reconstruction loss with a Fisher pullback (#980). It only drives the
3037        // gauge (see `analytic_penalties::corrected_isometry_penalty`). The
3038        // producer of `WhitenedStructured` is
3039        // `inference::residual_factor::StructuredResidualModel::row_metric`; the
3040        // SAME metric whitens the assembled gradient/Hessian in
3041        // `assemble_arrow_schur` (the single #974 seam), so this value and that
3042        // gradient cannot desync. Without a whitening metric this path is
3043        // bit-for-bit the historical isotropic data-fit.
3044        let whitens = self
3045            .row_metric
3046            .as_ref()
3047            .is_some_and(|metric| metric.whitens_likelihood());
3048        // #991 design honesty weights: the reconstruction channel of row `i`
3049        // is weighted by `w_i` (mean-1 HT inclusion correction). The assembly
3050        // applies the same `w_i` via a `√w_i` scaling of the row residual /
3051        // Jacobian / β load at its single seam, so this value and that
3052        // gradient/Hessian carry the identical per-row factor. `None` ⇒ the
3053        // historical unweighted sum, bit-for-bit.
3054        let row_loss_w = self.row_loss_weights.as_deref();
3055        let n = self.n_obs();
3056        let p = self.output_dim();
3057        let k_atoms = self.k_atoms();
3058        // #1017: the data-fit is the dominant per-line-search-trial cost (it
3059        // re-runs every Armijo halving × every inner Newton iteration × every
3060        // outer ρ evaluation). The old path materialised the whole `n × p`
3061        // fitted matrix (`try_fitted_for_rho`) and then walked it AGAIN to form
3062        // the residual sum — two sequential `n·p` passes plus an `n·p`
3063        // allocation per trial. Fuse the reconstruction and the residual reduce
3064        // into ONE row-parallel pass that never materialises the fitted matrix:
3065        // each row decodes its atoms into per-worker scratch, differences
3066        // against the target, and contributes its scalar `0.5·w·‖r‖²` to a
3067        // chunk-ordered fold (bit-identical run-to-run). Per-worker scratch
3068        // (`map_init`) keeps the only allocations one `g_buf`/`fitted_row` pair
3069        // per rayon thread rather than per row. Stay sequential inside a worker
3070        // (the topology race owns the outer pool) to avoid nested
3071        // oversubscription.
3072        let parallel = n >= SAE_LOSS_PARALLEL_ROW_MIN && rayon::current_thread_index().is_none();
3073        let row_data_fit =
3074            |row: usize,
3075             g_buf: &mut [f64],
3076             fitted_row: &mut [f64],
3077             assign_buf: &mut [f64]|
3078             -> Result<f64, String> {
3079                // #1557 — fill the per-atom assignment row into reused per-worker
3080                // scratch via the `_into` twin instead of heap-allocating a fresh
3081                // `Array1` per row per loss eval. Bit-identical to the allocating
3082                // `try_assignments_row_for_rho` (same arithmetic, same order); this
3083                // loss reruns every Armijo halving × inner Newton iter × outer ρ
3084                // eval, so the per-row K-sized allocation was a hot-path churn.
3085                self.assignment
3086                    .try_assignments_row_for_rho_into(row, rho, assign_buf)?;
3087                let a = &*assign_buf;
3088                for slot in fitted_row.iter_mut() {
3089                    *slot = 0.0;
3090                }
3091                for atom_idx in 0..k_atoms {
3092                    self.atoms[atom_idx].fill_decoded_row(row, g_buf);
3093                    let a_k = a[atom_idx];
3094                    for out_col in 0..p {
3095                        fitted_row[out_col] += a_k * g_buf[out_col];
3096                    }
3097                }
3098                for out_col in 0..p {
3099                    fitted_row[out_col] = target[[row, out_col]] - fitted_row[out_col];
3100                }
3101                let w_row = row_loss_w.map_or(1.0, |w| w[row]);
3102                let mut acc = 0.0_f64;
3103                match self.row_metric.as_ref() {
3104                    Some(metric) if whitens => {
3105                        let resid = ArrayView1::from(&fitted_row[..p]);
3106                        for w in metric.whiten_residual_row(row, resid) {
3107                            acc += 0.5 * w_row * w * w;
3108                        }
3109                    }
3110                    _ => {
3111                        for &r in fitted_row[..p].iter() {
3112                            acc += 0.5 * w_row * r * r;
3113                        }
3114                    }
3115                }
3116                Ok(acc)
3117            };
3118        let data_fit = if parallel {
3119            use rayon::prelude::*;
3120            const CHUNK: usize = 32;
3121            let partials: Vec<Result<f64, String>> = (0..n)
3122                .into_par_iter()
3123                .chunks(CHUNK)
3124                .map_init(
3125                    || (vec![0.0_f64; p], vec![0.0_f64; p], vec![0.0_f64; k_atoms]),
3126                    |(g_buf, fitted_row, assign_buf), idxs| {
3127                        // #1557 — pin any faer GEMM reached from this row-parallel
3128                        // data-fit chunk to `Par::Seq` (no nested Rayon re-fan); the
3129                        // per-row reductions are tiny, so the result is bit-identical.
3130                        with_nested_parallel(|| {
3131                            let mut acc = 0.0_f64;
3132                            for row in idxs {
3133                                acc += row_data_fit(row, g_buf, fitted_row, assign_buf)?;
3134                            }
3135                            Ok(acc)
3136                        })
3137                    },
3138                )
3139                .collect();
3140            let mut total = 0.0_f64;
3141            for partial in partials {
3142                total += partial?;
3143            }
3144            total
3145        } else {
3146            let mut g_buf = vec![0.0_f64; p];
3147            let mut fitted_row = vec![0.0_f64; p];
3148            let mut assign_buf = vec![0.0_f64; k_atoms];
3149            let mut total = 0.0_f64;
3150            for row in 0..n {
3151                total += row_data_fit(row, &mut g_buf, &mut fitted_row, &mut assign_buf)?;
3152            }
3153            total
3154        };
3155        let assignment_sparsity = assignment_prior_value(&self.assignment, rho);
3156        let smoothness = penalty_scale * self.decoder_smoothness_value(&rho.lambda_smooth_vec());
3157        let ard = self.ard_value(rho)?;
3158        Ok(SaeManifoldLoss {
3159            data_fit,
3160            assignment_sparsity,
3161            smoothness,
3162            ard,
3163            evidence_gauge_deflated_directions: 0,
3164        })
3165    }
3166
3167    /// Reconstruction data-fit `0.5·Σ_i w_i·‖whiten(Z_i − R_i)‖²` for an EXPLICIT
3168    /// reconstruction matrix `R` (e.g. the hard top-k–projected `fitted`), using
3169    /// the SAME per-row metric and design-honesty weights as [`Self::loss_scaled`]
3170    /// (the soft-assignment data-fit). The only difference is the residual source:
3171    /// `loss_scaled` decodes the soft assignments on the fly, this consumes a
3172    /// reconstruction the caller already assembled (so the projected loss and the
3173    /// returned projected `fitted` describe one and the same model). The penalty
3174    /// terms (`assignment_sparsity`/`smoothness`/`ard`) are decoder/ρ properties
3175    /// the top-k gate does not change, so the caller keeps them from the soft
3176    /// `loss_scaled` and only swaps this data-fit in — see #1232.
3177    pub fn data_fit_for_reconstruction(
3178        &self,
3179        target: ArrayView2<'_, f64>,
3180        reconstruction: ArrayView2<'_, f64>,
3181    ) -> Result<f64, String> {
3182        let n = self.n_obs();
3183        let p = self.output_dim();
3184        if target.dim() != (n, p) {
3185            return Err(format!(
3186                "SaeManifoldTerm::data_fit_for_reconstruction: Z must be ({n}, {p}); got {:?}",
3187                target.dim()
3188            ));
3189        }
3190        if reconstruction.dim() != (n, p) {
3191            return Err(format!(
3192                "SaeManifoldTerm::data_fit_for_reconstruction: reconstruction must be ({n}, {p}); got {:?}",
3193                reconstruction.dim()
3194            ));
3195        }
3196        let whitens = self
3197            .row_metric
3198            .as_ref()
3199            .is_some_and(|metric| metric.whitens_likelihood());
3200        let row_loss_w = self.row_loss_weights.as_deref();
3201        let mut resid = vec![0.0_f64; p];
3202        let mut total = 0.0_f64;
3203        for row in 0..n {
3204            for out_col in 0..p {
3205                resid[out_col] = target[[row, out_col]] - reconstruction[[row, out_col]];
3206            }
3207            let w_row = row_loss_w.map_or(1.0, |w| w[row]);
3208            match self.row_metric.as_ref() {
3209                Some(metric) if whitens => {
3210                    let r = ArrayView1::from(&resid[..p]);
3211                    for w in metric.whiten_residual_row(row, r) {
3212                        total += 0.5 * w_row * w * w;
3213                    }
3214                }
3215                _ => {
3216                    for &r in resid[..p].iter() {
3217                        total += 0.5 * w_row * r * r;
3218                    }
3219                }
3220            }
3221        }
3222        Ok(total)
3223    }
3224
3225    pub fn analytic_penalty_value_total(
3226        &self,
3227        registry: &AnalyticPenaltyRegistry,
3228        penalty_scale: f64,
3229    ) -> Result<f64, ArrowSchurError> {
3230        if !(penalty_scale.is_finite() && penalty_scale > 0.0) {
3231            return Err(ArrowSchurError::SchurFactorFailed {
3232                reason: format!(
3233                    "SaeManifoldTerm::analytic_penalty_value_total: penalty_scale must be finite \
3234                     and positive; got {penalty_scale}"
3235                ),
3236            });
3237        }
3238        let rho_global = Array1::<f64>::zeros(registry.total_rho_count());
3239        let layout = registry.rho_layout();
3240        let beta = self.flatten_beta();
3241        let mut value = 0.0_f64;
3242        for (penalty, (rho_slice, tier, name)) in registry.penalties.iter().zip(layout.iter()) {
3243            let rho_local = rho_global.slice(s![rho_slice.clone()]);
3244            // Skip the registry `ARDPenalty` here for the same reason it is
3245            // skipped in `add_sae_analytic_penalty_contributions`: the coordinate
3246            // ARD energy is already counted by `loss.ard` (the von-Mises
3247            // `ard_value`), and the registry penalty's legacy Gaussian `½λt²` is
3248            // period-discontinuous. Including it would double-count the energy and
3249            // make this line-search objective jump across the branch cut while the
3250            // assembled gradient (von-Mises only, after the assembly fix) stays
3251            // continuous — i.e. a near-zero step would change the objective by a
3252            // finite amount and Armijo would wrongly reject it.
3253            if matches!(penalty, AnalyticPenaltyKind::Ard(_)) {
3254                continue;
3255            }
3256            match tier {
3257                PenaltyTier::Psi => {
3258                    if let AnalyticPenaltyKind::NuclearNorm(base) = penalty {
3259                        for (per_atom, start, end) in self.live_nuclear_norm_penalties(base) {
3260                            value += penalty_scale
3261                                * per_atom.value(beta.slice(s![start..end]), rho_local);
3262                        }
3263                    } else {
3264                        if !sae_penalty_is_row_block_supported(penalty) {
3265                            return Err(ArrowSchurError::SchurFactorFailed {
3266                                reason: format!(
3267                                    "validate_analytic_penalty_registry should have refused \
3268                                     non-row-block Psi-tier penalty {:?} (registry layout name \
3269                                     {name:?})",
3270                                    penalty.name()
3271                                ),
3272                            });
3273                        }
3274                        for atom_idx in 0..self.k_atoms() {
3275                            let coord = &self.assignment.coords[atom_idx];
3276                            if let AnalyticPenaltyKind::Isometry(iso) = penalty {
3277                                let corrected_kind =
3278                                    self.corrected_isometry_penalty(iso, atom_idx, coord)?;
3279                                value += corrected_kind.value(coord.as_flat().view(), rho_local);
3280                            } else if sae_coord_penalty_is_origin_anchored_magnitude(penalty) {
3281                                // Origin-anchored magnitude shrinkage (SCAD/MCP) is
3282                                // restricted to the Euclidean axes; periodic axes have
3283                                // no chart origin and would make this energy
3284                                // period-discontinuous (issue #795). This must mirror
3285                                // the gradient/curvature assembly in
3286                                // `add_sae_coord_penalty` exactly.
3287                                match sae_coord_penalty_euclidean_restriction(coord) {
3288                                    Some((_axes, compacted)) => {
3289                                        value += penalty.value(compacted.view(), rho_local);
3290                                    }
3291                                    None => {
3292                                        value += penalty.value(coord.as_flat().view(), rho_local);
3293                                    }
3294                                }
3295                            } else {
3296                                value += penalty.value(coord.as_flat().view(), rho_local);
3297                            }
3298                        }
3299                    }
3300                }
3301                PenaltyTier::Beta => {
3302                    if let AnalyticPenaltyKind::DecoderIncoherence(base) = penalty {
3303                        if let Some(per_fit) = self.live_decoder_incoherence_penalty(base) {
3304                            value += penalty_scale * per_fit.value(beta.view(), rho_local);
3305                        }
3306                    } else if let AnalyticPenaltyKind::MechanismSparsity(base) = penalty {
3307                        for (per_atom, start, end) in self.live_mechanism_sparsity_penalties(base) {
3308                            if start < end {
3309                                value += penalty_scale * per_atom.value(beta.view(), rho_local);
3310                            }
3311                        }
3312                    } else {
3313                        value += penalty_scale * penalty.value(beta.view(), rho_local);
3314                    }
3315                }
3316                PenaltyTier::Rho => {}
3317            }
3318        }
3319        Ok(value)
3320    }
3321
3322    /// Energy of the decoder-block analytic penalties that have no native
3323    /// `SaeManifoldLoss` counterpart, evaluated at the current decoder `β` and
3324    /// the converged SAE state. These act on the per-atom decoder coefficient
3325    /// matrices: cross-atom decoder incoherence (#671), mechanism
3326    /// (feature-group) sparsity, and nuclear-norm embedding rank (#672). Each
3327    /// is injected with its live per-atom shape / co-activation before its
3328    /// value is taken, mirroring the assemble path.
3329    ///
3330    /// This is deliberately narrower than [`Self::analytic_penalty_value_total`]:
3331    /// it excludes the Psi-tier coordinate / assignment penalties (ARD,
3332    /// Isometry, ScadMcp, BlockOrthogonality, IBP/softmax assignment sparsity).
3333    /// The SAE already carries its own ARD (`loss.ard`) and assignment sparsity
3334    /// (`loss.assignment_sparsity`) energy, so adding the registry ARD /
3335    /// assignment value on top would double-count, and the gauge-only
3336    /// coordinate penalties are not part of the penalized deviance the
3337    /// REML/Laplace criterion scores. The decoder-block penalties, by contrast,
3338    /// are real penalized-energy terms with no `loss.*` representative: the
3339    /// inner solve minimizes them (they enter `gb`/`hbb`) but they were absent
3340    /// from the criterion scalar `v`. This restores that consistency so the
3341    /// ρ-sweep ranks the same objective the inner solve descends — the #671
3342    /// incoherence lever in particular now shapes model selection, not just the
3343    /// Newton step.
3344    ///
3345    /// NOTE: the coordinate-block penalties with no native `loss.*` twin
3346    /// (`ScadMcp`, `BlockOrthogonality`) carry the same residual inconsistency
3347    /// (scored in the line search via `penalized_objective_total`, absent from
3348    /// the REML scalar). They are left out here because they share a registry
3349    /// dispatch with the always-on `Isometry` gauge, whose inclusion in the
3350    /// topology-comparison criterion is a separate design question (#673:
3351    /// topology evidence is gauge-conditional). Folding the coord-tier energy in
3352    /// is tracked apart from this #671 decoder fix.
3353    pub fn analytic_decoder_penalty_value_total(
3354        &self,
3355        registry: &AnalyticPenaltyRegistry,
3356    ) -> Result<f64, ArrowSchurError> {
3357        // Resolve each penalty's rho slice exactly as `analytic_penalty_value_total`
3358        // does (registry-local rho at zeros), so a learnable decoder-penalty weight
3359        // is honoured rather than indexing into an empty view.
3360        let rho_global = Array1::<f64>::zeros(registry.total_rho_count());
3361        let layout = registry.rho_layout();
3362        let beta = self.flatten_beta();
3363        let mut value = 0.0_f64;
3364        for (penalty, (rho_slice, _tier, _name)) in registry.penalties.iter().zip(layout.iter()) {
3365            let rho_local = rho_global.slice(s![rho_slice.clone()]);
3366            match penalty {
3367                AnalyticPenaltyKind::DecoderIncoherence(base) => {
3368                    if let Some(per_fit) = self.live_decoder_incoherence_penalty(base) {
3369                        value += per_fit.value(beta.view(), rho_local);
3370                    }
3371                }
3372                AnalyticPenaltyKind::MechanismSparsity(base) => {
3373                    for (per_atom, start, end) in self.live_mechanism_sparsity_penalties(base) {
3374                        if start < end {
3375                            value += per_atom.value(beta.view(), rho_local);
3376                        }
3377                    }
3378                }
3379                AnalyticPenaltyKind::NuclearNorm(base) => {
3380                    for (per_atom, start, end) in self.live_nuclear_norm_penalties(base) {
3381                        value += per_atom.value(beta.slice(s![start..end]), rho_local);
3382                    }
3383                }
3384                _ => {}
3385            }
3386        }
3387        Ok(value)
3388    }
3389
3390    /// Energy of the COORDINATE-tier isometry penalty(ies) at the converged
3391    /// SAE state. This is the per-atom `½μ Σ_n ‖J_n^T W_n J_n / gbar − g_ref‖²`
3392    /// summed over atoms, evaluated through `corrected_isometry_penalty` so the
3393    /// live decoder/coordinate caches drive the value exactly as the assemble
3394    /// path does. It has no `SaeManifoldLoss` twin (the loss carries only
3395    /// data-fit / assignment / smoothness / ARD), so the Laplace/REML criterion
3396    /// must add it explicitly to score the same penalized objective the inner
3397    /// solve descends.
3398    pub fn isometry_penalty_value_total(
3399        &self,
3400        registry: &AnalyticPenaltyRegistry,
3401    ) -> Result<f64, ArrowSchurError> {
3402        let rho_global = Array1::<f64>::zeros(registry.total_rho_count());
3403        let layout = registry.rho_layout();
3404        let mut value = 0.0_f64;
3405        for (penalty, (rho_slice, _tier, _name)) in registry.penalties.iter().zip(layout.iter()) {
3406            if let AnalyticPenaltyKind::Isometry(iso) = penalty {
3407                let rho_local = rho_global.slice(s![rho_slice.clone()]);
3408                for atom_idx in 0..self.k_atoms() {
3409                    let coord = &self.assignment.coords[atom_idx];
3410                    let corrected_kind = self.corrected_isometry_penalty(iso, atom_idx, coord)?;
3411                    value += corrected_kind.value(coord.as_flat().view(), rho_local);
3412                }
3413            }
3414        }
3415        Ok(value)
3416    }
3417
3418    /// Whether assembling `registry` will scatter an isometry Gauss-Newton
3419    /// cross-block (`H_tβ`) into the per-row dense `htbeta` slabs.
3420    ///
3421    /// `add_sae_isometry_metric_gn_blocks` writes the coupled cross-block (and
3422    /// flips on `activate_dense_htbeta_supplement`) only when (a) the registry
3423    /// carries an `Isometry` penalty and (b) the atom's chart
3424    /// `preserves_isometry_cross_block_coherence` (flat charts — `Euclidean`,
3425    /// `Circle`, and flat products — keep the full `μ AᵀA` coupling; curved /
3426    /// boundary charts drop it to stay PSD). On the non-frames matrix-free path
3427    /// the data-fit cross-block is carried by the Kronecker row operator and the
3428    /// per-row `htbeta` slab is allocated at zero width (#1406/#1407 anti-leak),
3429    /// so this dense isometry supplement has nowhere to land unless the slab is
3430    /// widened to the full `beta_dim`. This predicate decides exactly that. The
3431    /// effective isometry weight `μ` is NOT consulted here: a near-zero `μ`
3432    /// short-circuits the per-row write, but the slab must still exist so the
3433    /// solver's `htbeta_dense_supplement` read is well-shaped.
3434    pub(crate) fn registry_writes_dense_isometry_cross_block(
3435        &self,
3436        registry: &AnalyticPenaltyRegistry,
3437    ) -> bool {
3438        registry
3439            .penalties
3440            .iter()
3441            .any(|p| matches!(p, AnalyticPenaltyKind::Isometry(_)))
3442            && self
3443                .assignment
3444                .coords
3445                .iter()
3446                .any(|coord| coord.manifold().preserves_isometry_cross_block_coherence())
3447    }
3448
3449    /// Extra analytic-penalty energy that has no native `SaeManifoldLoss`
3450    /// component but is part of the penalized objective ranked by the SAE
3451    /// Laplace/REML criterion.
3452    pub fn reml_extra_penalty_value_total(
3453        &self,
3454        registry: &AnalyticPenaltyRegistry,
3455    ) -> Result<f64, ArrowSchurError> {
3456        Ok(self.analytic_decoder_penalty_value_total(registry)?
3457            + self.isometry_penalty_value_total(registry)?)
3458    }
3459
3460    pub fn penalized_objective_total(
3461        &self,
3462        target: ArrayView2<'_, f64>,
3463        rho: &SaeManifoldRho,
3464        registry: Option<&AnalyticPenaltyRegistry>,
3465        penalty_scale: f64,
3466    ) -> Result<f64, String> {
3467        let mut total = self.loss_scaled(target, rho, penalty_scale)?.total();
3468        if let Some(analytic_registry) = registry {
3469            total += self
3470                .analytic_penalty_value_total(analytic_registry, penalty_scale)
3471                .map_err(|err| format!("SaeManifoldTerm::penalized_objective_total: {err}"))?;
3472        }
3473        // #1026 — decoder-repulsion value, on the SAME frozen gate the assembly
3474        // used, so the line search sees the term the Newton step optimizes. 0
3475        // unless two atoms are near-collinear (the no-op case).
3476        total += self.decoder_repulsion_value(penalty_scale);
3477        // #1026/#1522 — interior-point collapse-prevention barriers, on the SAME
3478        // decoders the assembly's gradient/curvature used, so the line search sees
3479        // exactly the term the inner Newton step optimises (no value/grad desync).
3480        total += self.separation_barrier_value(penalty_scale);
3481        Ok(total)
3482    }
3483
3484    pub(crate) fn decoder_smoothness_value(&self, lambda_smooth: &[f64]) -> f64 {
3485        // Smoothness penalty value is `0.5·λ·Σ_oc B[:,oc]ᵀ S B[:,oc]`. Form the
3486        // `S·B` matrix product once per atom (O(M²·p)) and reduce against `B`
3487        // with a single O(M·p) Hadamard sum, instead of the previous
3488        // four-factor multiply-accumulate inside an `O(M²·p)` triple loop.
3489        // The quadratic form only sees the symmetric part of `S`, so reusing
3490        // the raw (un-symmetrised) `smooth_penalty` here is numerically
3491        // identical to the symmetrised assembly form.
3492        // Per-atom `S_k · B_k` products are independent across atoms, so they ride
3493        // the multi-GPU batched smoothness GEMM (uniform-shape groups tiled across
3494        // every device); `symmetrize = false` because the quadratic form only sees
3495        // the symmetric part of `S` regardless. Exact CPU fallback per atom.
3496        let sb_inputs: Vec<(ArrayView2<'_, f64>, ArrayView2<'_, f64>)> = self
3497            .atoms
3498            .iter()
3499            .map(|atom| (atom.smooth_penalty.view(), atom.decoder_coefficients.view()))
3500            .collect();
3501        let sb_all = batched_smooth_sb(&sb_inputs, false);
3502        let mut acc = 0.0;
3503        for (atom_idx, (atom, sb)) in self.atoms.iter().zip(sb_all.iter()).enumerate() {
3504            acc += 0.5 * lambda_smooth[atom_idx] * (&atom.decoder_coefficients * sb).sum();
3505        }
3506        acc
3507    }
3508
3509    /// Per-atom decoder-smoothness values (#1556): entry `k` is
3510    /// `0.5·λ_smooth[k]·<B_k, S_k B_k>` (sum = [`Self::decoder_smoothness_value`]).
3511    /// This is the explicit `∂loss.smoothness/∂log λ_smooth[k]` gradient entry.
3512    pub(crate) fn decoder_smoothness_value_per_atom(&self, lambda_smooth: &[f64]) -> Vec<f64> {
3513        let sb_inputs: Vec<(ArrayView2<'_, f64>, ArrayView2<'_, f64>)> = self
3514            .atoms
3515            .iter()
3516            .map(|atom| (atom.smooth_penalty.view(), atom.decoder_coefficients.view()))
3517            .collect();
3518        let sb_all = batched_smooth_sb(&sb_inputs, false);
3519        let mut per_atom = vec![0.0_f64; self.atoms.len()];
3520        for (atom_idx, (atom, sb)) in self.atoms.iter().zip(sb_all.iter()).enumerate() {
3521            per_atom[atom_idx] =
3522                0.5 * lambda_smooth[atom_idx] * (&atom.decoder_coefficients * sb).sum();
3523        }
3524        per_atom
3525    }
3526
3527    pub(crate) fn ard_value(&self, rho: &SaeManifoldRho) -> Result<f64, String> {
3528        if rho.log_ard.len() != self.k_atoms() {
3529            return Err(format!(
3530                "ARD rho has {} atoms but term has {}",
3531                rho.log_ard.len(),
3532                self.k_atoms()
3533            ));
3534        }
3535        let n = self.n_obs();
3536        let mut acc = 0.0;
3537        for (atom_idx, coord) in self.assignment.coords.iter().enumerate() {
3538            let d = coord.latent_dim();
3539            if rho.log_ard[atom_idx].is_empty() {
3540                continue;
3541            }
3542            if rho.log_ard[atom_idx].len() != d {
3543                return Err(format!(
3544                    "ARD rho atom {atom_idx} has len {} but atom dim is {d}",
3545                    rho.log_ard[atom_idx].len()
3546                ));
3547            }
3548            // Per-axis periodicity selects the smooth von-Mises energy on
3549            // wrapped (Circle) axes and the Gaussian on Euclidean axes.
3550            let periods = coord.effective_axis_periods();
3551            for axis in 0..d {
3552                let log_alpha = rho.log_ard[atom_idx][axis];
3553                // Clamp the log-precision before exponentiating: a raw
3554                // `exp(log_ard)` overflows to `inf` for `log_ard ≳ 709`, and the
3555                // `inf` precision then poisons the ARD energy / curvature with
3556                // `inf · 0.0 = NaN` (#742, Issue 4).
3557                let alpha = SaeManifoldRho::stable_exp_strength(log_alpha);
3558                let period = periods[axis];
3559                let mut energy = 0.0;
3560                for row in 0..n {
3561                    let v = coord.row(row)[axis];
3562                    energy += ArdAxisPrior::eval(alpha, v, period).value;
3563                }
3564                // Negative-log prior for precision alpha. The data-dependent
3565                // energy is the (Gaussian or von-Mises) coordinate prior; the
3566                // accompanying normaliser is the precision log-partition.
3567                //
3568                // Euclidean axes keep the Gaussian normaliser `-0.5 n log α`.
3569                // Periodic (von-Mises) axes use the EXACT von-Mises precision
3570                // log-partition `n[-η + log I0(η)]`, η = α/κ², κ = 2π/P, rather
3571                // than the Gaussian surrogate: the von-Mises partition function
3572                // is `2π I0(η)` (up to the κ Jacobian), so the per-observation
3573                // normaliser is `-η + log I0(η)` and is exact across the cut.
3574                match period {
3575                    None => {
3576                        acc += energy - 0.5 * (n as f64) * log_alpha;
3577                    }
3578                    Some(p) => {
3579                        let kappa = std::f64::consts::TAU / p;
3580                        let eta = alpha / (kappa * kappa);
3581                        // Overflow-free `log I0(η)`; `bessel_i0(η).ln()` would be
3582                        // `+inf` for `η ≳ 709` (#1113).
3583                        let log_i0 = bessel_i0_log_and_ratio(eta).0;
3584                        acc += energy + (n as f64) * (-eta + log_i0);
3585                    }
3586                }
3587            }
3588        }
3589        Ok(acc)
3590    }
3591
3592    /// Assemble the enlarged `(logits, t)` row-local Arrow-Schur system.
3593    ///
3594    /// Full-batch entry point: a single chunk covering all rows, with the
3595    /// β-tier penalties (decoder smoothness, ARD, analytic β penalties) carrying
3596    /// their full strength. The streaming driver calls
3597    /// [`Self::assemble_arrow_schur_scaled`] directly with a `penalty_scale`
3598    /// equal to the minibatch fraction `n_chunk / N`, so that the sum of the
3599    /// per-chunk β-tier contributions over a full pass reconstructs exactly the
3600    /// single global β penalty (the smoothness/ARD/β terms are functions of `B`
3601    /// and the global coordinates, not of the chunk's rows).
3602    pub fn assemble_arrow_schur(
3603        &mut self,
3604        target: ArrayView2<'_, f64>,
3605        rho: &SaeManifoldRho,
3606        analytic_penalties: Option<&AnalyticPenaltyRegistry>,
3607    ) -> Result<ArrowSchurSystem, String> {
3608        self.assemble_arrow_schur_scaled(target, rho, analytic_penalties, 1.0)
3609    }
3610
3611    /// Assemble the row-local Arrow-Schur system with a `penalty_scale` applied
3612    /// to the β-tier (decoder smoothness, ARD prior, analytic β penalties).
3613    ///
3614    /// `penalty_scale == 1.0` recovers the full-batch assembly. The streaming
3615    /// driver passes the minibatch fraction `n_chunk / N` so that the β-tier
3616    /// reduced-Schur and gradient contributions of the chunks sum to exactly one
3617    /// global copy across a full pass (data-fit, assignment-prior, and per-row
3618    /// coord/logit analytic terms are *not* scaled — they are genuine per-row
3619    /// sums).
3620    pub fn assemble_arrow_schur_scaled(
3621        &mut self,
3622        target: ArrayView2<'_, f64>,
3623        rho: &SaeManifoldRho,
3624        analytic_penalties: Option<&AnalyticPenaltyRegistry>,
3625        penalty_scale: f64,
3626    ) -> Result<ArrowSchurSystem, String> {
3627        self.assemble_arrow_schur_scaled_with_beta_penalty_probe_threshold(
3628            target,
3629            rho,
3630            analytic_penalties,
3631            penalty_scale,
3632            SAE_DENSE_BETA_PENALTY_PROBE_MAX_DIM,
3633        )
3634    }
3635
3636    pub(crate) fn assemble_arrow_schur_scaled_with_beta_penalty_probe_threshold(
3637        &mut self,
3638        target: ArrayView2<'_, f64>,
3639        rho: &SaeManifoldRho,
3640        analytic_penalties: Option<&AnalyticPenaltyRegistry>,
3641        penalty_scale: f64,
3642        dense_beta_penalty_probe_max_dim: usize,
3643    ) -> Result<ArrowSchurSystem, String> {
3644        self.assemble_arrow_schur_inner(
3645            target,
3646            rho,
3647            analytic_penalties,
3648            penalty_scale,
3649            dense_beta_penalty_probe_max_dim,
3650            None,
3651        )
3652    }
3653
3654    /// Innermost assembly entry. `forced_layout` overrides the budget-derived
3655    /// active-set layout so a caller can pin the dense (`Forced(None)`) or a
3656    /// specific compact (`Forced(Some(layout))`) path — used by the
3657    /// compact-vs-dense Riemannian-geometry equality regression test to drive
3658    /// both layouts on identical data. `Computed` is the production path:
3659    /// the layout is derived from the assignment mode + `sparse_active_plan`.
3660    pub(crate) fn assemble_arrow_schur_inner(
3661        &mut self,
3662        target: ArrayView2<'_, f64>,
3663        rho: &SaeManifoldRho,
3664        analytic_penalties: Option<&AnalyticPenaltyRegistry>,
3665        penalty_scale: f64,
3666        dense_beta_penalty_probe_max_dim: usize,
3667        forced_layout: ForcedRowLayout,
3668    ) -> Result<ArrowSchurSystem, String> {
3669        if !(penalty_scale.is_finite() && penalty_scale > 0.0) {
3670            return Err(format!(
3671                "SaeManifoldTerm::assemble_arrow_schur_scaled: penalty_scale must be finite and positive; got {penalty_scale}"
3672            ));
3673        }
3674        if target.dim() != (self.n_obs(), self.output_dim()) {
3675            return Err(format!(
3676                "SaeManifoldTerm::assemble_arrow_schur: Z must be ({}, {}); got {:?}",
3677                self.n_obs(),
3678                self.output_dim(),
3679                target.dim()
3680            ));
3681        }
3682        if rho.log_ard.len() != self.k_atoms() {
3683            return Err(format!(
3684                "SaeManifoldTerm::assemble_arrow_schur: log_ard length {} != K {}",
3685                rho.log_ard.len(),
3686                self.k_atoms()
3687            ));
3688        }
3689        // `lambda_smooth` is indexed per-atom in the smoothness gradient/curvature
3690        // assembly (`lambda_smooth[atom_idx]`); a too-short vector (e.g. a growth
3691        // move that grew `k_atoms()` without extending ρ — #1556) would panic deep
3692        // in the assembly loop with an opaque index-out-of-bounds. Validate it here
3693        // alongside `log_ard` so the contract violation surfaces as a clear Err.
3694        if rho.log_lambda_smooth.len() != self.k_atoms() {
3695            return Err(format!(
3696                "SaeManifoldTerm::assemble_arrow_schur: log_lambda_smooth length {} != K {}",
3697                rho.log_lambda_smooth.len(),
3698                self.k_atoms()
3699            ));
3700        }
3701        for (atom_idx, coord) in self.assignment.coords.iter().enumerate() {
3702            let ard_len = rho.log_ard[atom_idx].len();
3703            let d = coord.latent_dim();
3704            if ard_len != 0 && ard_len != d {
3705                return Err(format!(
3706                    "SaeManifoldTerm::assemble_arrow_schur: log_ard atom {atom_idx} \
3707                     has len {ard_len}; expected 0 (disabled) or atom dim {d}"
3708                ));
3709            }
3710        }
3711        // Reparameterize each atom's roughness Gram into arc length at the
3712        // current decoder/coordinates (issue #673). This is the single
3713        // chokepoint for both the inner Newton assembly and the undamped
3714        // evidence factorization, so freezing the pullback-metric weight here
3715        // (lagged-diffusivity) keeps the smoothness value, gradient, Kronecker
3716        // Hessian, and REML log-det mutually consistent within each assembly
3717        // and makes the converged penalty — hence the topology evidence —
3718        // gauge-invariant. Constant-speed (periodic) atoms are unaffected.
3719        for atom in &mut self.atoms {
3720            atom.refresh_intrinsic_smooth_penalty();
3721        }
3722        // #1026 — freeze the decoder-repulsion collinearity gate at the SAME
3723        // assembly chokepoint as the smoothness Gram, so the repulsion's
3724        // gradient/curvature (assembled below) and its value (read by the
3725        // line-search `penalized_objective_total`) share one frozen gate.
3726        self.refresh_decoder_repulsion_gate();
3727        // #1625 — freeze the SEPARATION barrier's normalized-coactivation `q_jk`
3728        // at the same chokepoint. The barrier weights its decoder-shape repulsion
3729        // by the routing coactivation, but its gradient treats that weight as a
3730        // constant; recomputing it from the trial logits in the line-search value
3731        // desyncs value vs gradient in the logit block and stalls the inner solve
3732        // (#1625). Freezing it here makes value/gradient/curvature consistent.
3733        self.refresh_barrier_coactivation_gate();
3734        let n = self.n_obs();
3735        let p = self.output_dim();
3736        let k_atoms = self.k_atoms();
3737        let assignment_dim = self.assignment.assignment_coord_dim();
3738        let q = self.assignment.row_block_dim();
3739        let beta_dim = self.beta_dim();
3740        let frame_projection = FrameProjection::new(self);
3741        let beta_offsets = frame_projection.beta_offsets.clone();
3742        let coord_offsets = self.assignment.coord_offsets();
3743        // β-tier decoder smoothness is a global (B-only) penalty; under a
3744        // minibatch pass it is scaled by the chunk fraction so the per-chunk
3745        // contributions sum to one global copy.
3746        // Per-atom decoder-smoothness strengths (#1556): atom k's penalty `S_k`
3747        // is scaled by `λ_smooth[k]·penalty_scale`. The minibatch `penalty_scale`
3748        // multiplies every atom uniformly.
3749        let lambda_smooth: Vec<f64> = rho
3750            .lambda_smooth_vec()
3751            .iter()
3752            .map(|&l| l * penalty_scale)
3753            .collect();
3754        let (assignment_grad, assignment_hdiag) =
3755            assignment_prior_grad_hdiag(&self.assignment, rho)?;
3756
3757        // #1038 softmax entropy: the exact per-row Hessian in logits is dense
3758        // (`H_kj = (λ/τ²) a_k[δ_kj(m−L_k−1)+a_j(L_k+L_j+1−2m)]`), not just the
3759        // `assignment_hdiag` diagonal. Build the shared penalty + `scale = λ/τ²`
3760        // once here so the dense row block written into `block.htt` below, the
3761        // criterion's `log|H|`, and the #1006 θ-adjoint all differentiate the
3762        // SAME operator. JumpReLU / IBP keep their (separately exact) diagonal /
3763        // cross-row channels and leave this `None`. The block is gauge-null in
3764        // isolation (`H·𝟙 = 0`); it is only ever summed onto the gauge-breaking
3765        // data-fit row block before the Cholesky factor, never factored alone.
3766        let softmax_dense: Option<(
3767            gam_terms::analytic_penalties::SoftmaxAssignmentSparsityPenalty,
3768            f64,
3769        )> = match self.assignment.mode {
3770            AssignmentMode::Softmax {
3771                temperature,
3772                sparsity,
3773            } if k_atoms > 1 => {
3774                let inv_tau = 1.0 / temperature;
3775                let scale = rho.lambda_sparse() * sparsity * inv_tau * inv_tau;
3776                Some((
3777                    gam_terms::analytic_penalties::SoftmaxAssignmentSparsityPenalty::new(
3778                        k_atoms,
3779                        temperature,
3780                    ),
3781                    scale,
3782                ))
3783            }
3784            _ => None,
3785        };
3786
3787        // Decoder smoothness penalty: build one KroneckerPenaltyOp per atom
3788        // (structure = λ·S_k ⊗ I_p, offset = beta_offsets[k]) instead of
3789        // materialising the dense K×K block.  The gradient is a dense K-vector
3790        // accumulated into `smooth_grad_gb` and written into sys.gb after sys
3791        // is constructed (#296).
3792        let mut smooth_ops: Vec<Arc<dyn BetaPenaltyOp>> = Vec::with_capacity(self.atoms.len());
3793        // #972 / #977 T1: retain each atom's symmetrised `λ S_k` (`M_k × M_k`) so
3794        // the frame transform can rebuild the smooth penalty in the factored
3795        // coordinate space as `λ S_k ⊗ I_{r_k}` (the `tr(C_kᵀ S_k C_k)` form,
3796        // using `U_kᵀU_k = I`). Unused — and not even read — on the full-`B`
3797        // path, so this is a zero-cost capture there.
3798        let mut smooth_scaled_s: Vec<Array2<f64>> = Vec::with_capacity(self.atoms.len());
3799        let mut smooth_grad_gb = vec![0.0_f64; beta_dim];
3800        // #1117 — rank deficiency is handled at the basis layer: any
3801        // rank-deficient atom was reparametrized onto its data-supported subspace
3802        // at fit entry (`reduce_atoms_to_data_supported_rank`), so the β-tier here
3803        // always sees a full-rank design and needs no step-time data-null
3804        // deflation operator. The well-conditioned (full-rank) path is unchanged.
3805        // Per-atom smoothness-gradient GEMMs `½(S_k+S_kᵀ)·B_k` are independent
3806        // across atoms; batch them across ALL GPUs (uniform-shape tiles) and
3807        // scale by `lambda_smooth` below. `symmetrize = true` reproduces the
3808        // per-atom symmetrised `scaled_s/λ` used by the Kronecker op. Exact CPU
3809        // fallback per atom keeps the result bit-for-bit with the all-CPU path.
3810        let sym_sb_inputs: Vec<(ArrayView2<'_, f64>, ArrayView2<'_, f64>)> = self
3811            .atoms
3812            .iter()
3813            .map(|atom| (atom.smooth_penalty.view(), atom.decoder_coefficients.view()))
3814            .collect();
3815        let sym_sb_all = batched_smooth_sb(&sym_sb_inputs, true);
3816        for (atom_idx, atom) in self.atoms.iter().enumerate() {
3817            let m = atom.basis_size();
3818            let off = beta_offsets[atom_idx];
3819            // Symmetrise and scale the smoothness penalty matrix.
3820            let mut scaled_s = Array2::<f64>::zeros((m, m));
3821            for i in 0..m {
3822                for j in 0..m {
3823                    let s_ij = 0.5 * (atom.smooth_penalty[[i, j]] + atom.smooth_penalty[[j, i]]);
3824                    scaled_s[[i, j]] = lambda_smooth[atom_idx] * s_ij;
3825                }
3826            }
3827            // Gradient: g[beta_i] += (λ_k S_k B_k)[i, out_col]. The (m×m)·(m×p)
3828            // GEMM `½(S+Sᵀ)·B_k` was computed in the multi-GPU batch above; here
3829            // we only apply atom k's `lambda_smooth[atom_idx]`.
3830            let sb = &sym_sb_all[atom_idx] * lambda_smooth[atom_idx];
3831            for out_col in 0..p {
3832                for i in 0..m {
3833                    let beta_i = off + i * p + out_col;
3834                    smooth_grad_gb[beta_i] += sb[[i, out_col]];
3835                }
3836            }
3837            // IdentityRightKroneckerPenaltyOp: factor_a = λ·S_k (m×m), factor_b = I_p.
3838            smooth_ops.push(Arc::new(IdentityRightKroneckerPenaltyOp {
3839                factor_a: scaled_s.clone(),
3840                p,
3841                global_offset: off,
3842                k: beta_dim,
3843            }));
3844            // Retain `λ S_k` for the factored rebuild (no-op cost on full-`B`).
3845            smooth_scaled_s.push(scaled_s);
3846        }
3847
3848        // Per-row active-set layout. Engaged for two regimes:
3849        //   * JumpReLU — structural gate plus the smooth prior's
3850        //     machine-precision support: atoms with
3851        //     `(logit - threshold)/tau > -36` enter the compact solve
3852        //     ([`jumprelu_in_optimization_band`]). Strictly gated-off atoms
3853        //     (logit ≤ threshold) carry zero assignment mass so their data-fit
3854        //     reconstruction contribution and data-fit logit JVP are zero, but
3855        //     supported atoms keep value-consistent prior gradient in the row block.
3856        //   * IBP-MAP at large `K` — the dense `(m_total · p)²` data
3857        //     Gram is infeasible, so each row is truncated to its
3858        //     top-`k_active` atoms above a relative magnitude cutoff
3859        //     ([`Self::sparse_active_plan`]). Small-`K` problems return `None`
3860        //     and keep the exact full-support layout.
3861        // The compact row block is sized `q_active = |active| + Σ_{k∈active}
3862        // d_k` instead of the full `q`.
3863        let coord_dims: Vec<usize> = self
3864            .assignment
3865            .coords
3866            .iter()
3867            .map(|c| c.latent_dim())
3868            .collect();
3869        let row_layout: Option<SaeRowLayout> = match forced_layout {
3870            Some(layout) => layout,
3871            None => match self.assignment.mode {
3872                AssignmentMode::JumpReLU {
3873                    threshold,
3874                    temperature,
3875                } => Some(SaeRowLayout::from_jumprelu(
3876                    n,
3877                    k_atoms,
3878                    threshold,
3879                    temperature,
3880                    &self.assignment.logits,
3881                    coord_dims.clone(),
3882                    self.assignment.coord_offsets(),
3883                )),
3884                // #1408/#1409 — Softmax engages the COMPACT top-`k` row layout
3885                // inside the optimization (no longer a post-fit projection).
3886                // The active set is each row's top-`k_active_cap` softmax atoms
3887                // above the relative cutoff; the cap comes from the user's
3888                // `top_k` (`softmax_active_cap`) and/or the in-core memory budget
3889                // ([`Self::softmax_active_plan`]). The full-`K` softmax
3890                // normalization still forms `a` (the gate map); only the dropped
3891                // tail logits, carrying negligible `O(a)` reconstruction mass and
3892                // `O(a²)` curvature, leave the per-row block.
3893                //
3894                // Coherence (the load-bearing correctness invariant): the
3895                // assembly's softmax curvature branch writes the ACTIVE×ACTIVE
3896                // principal sub-block of the Gershgorin Loewner majorizer
3897                // `D = diag(Σ_j|H_kj|)` (#1419; PSD and `D ⪰ H_entropy`) on the
3898                // compact logit slots — NOT the indefinite `assignment_hdiag`
3899                // diagonal. The logdet ρ-trace
3900                // (`assignment_log_strength_hessian_trace`) iterates the row's
3901                // active logit slots and indexes that SAME majorizer by global
3902                // atom, and the θ-adjoint reads its derivative via `jets.vars`
3903                // (global-atom indexed), so value, log|H|, and Γ differentiate
3904                // ONE operator on the compact support. The FFI's after-the-fit
3905                // top-`k` projection is then a no-op at the optimum.
3906                AssignmentMode::Softmax { .. } => match self.softmax_active_plan() {
3907                    Some((k_active_cap, relative_cutoff)) => {
3908                        let mut assignments_all = Vec::with_capacity(n);
3909                        for row in 0..n {
3910                            assignments_all
3911                                .push(self.assignment.try_assignments_row_for_rho(row, rho)?);
3912                        }
3913                        Some(SaeRowLayout::from_dense_weights(
3914                            &assignments_all,
3915                            k_active_cap,
3916                            relative_cutoff,
3917                            coord_dims.clone(),
3918                            self.assignment.coord_offsets(),
3919                        ))
3920                    }
3921                    None => None,
3922                },
3923                AssignmentMode::IBPMap { .. } => {
3924                    match self.sparse_active_plan() {
3925                        Some((k_active_cap, relative_cutoff)) => {
3926                            // Build per-row dense assignments once to derive the
3927                            // active set; the row loop re-derives `assignments`
3928                            // (cheap gate map at the same rho) and reuses these
3929                            // active sets.
3930                            let mut assignments_all = Vec::with_capacity(n);
3931                            for row in 0..n {
3932                                assignments_all
3933                                    .push(self.assignment.try_assignments_row_for_rho(row, rho)?);
3934                            }
3935                            // #1414: pass the RELATIVE cutoff through;
3936                            // `from_dense_weights` applies it per row against that
3937                            // row's own peak `max_k |a_{n,k}|`, matching the
3938                            // documented `sparse_active_plan` contract. A single
3939                            // global threshold (relative_cutoff · whole-dataset
3940                            // peak) wrongly drops every atom of a uniformly-small
3941                            // row when another row peaks high.
3942                            Some(SaeRowLayout::from_dense_weights(
3943                                &assignments_all,
3944                                k_active_cap,
3945                                relative_cutoff,
3946                                coord_dims.clone(),
3947                                self.assignment.coord_offsets(),
3948                            ))
3949                        }
3950                        None => None,
3951                    }
3952                }
3953            },
3954        };
3955        // #974 likelihood-whitening seam. The single per-row decision: when the
3956        // installed `RowMetric` is a genuinely estimated noise model
3957        // (`whitens_likelihood()` — only `WhitenedStructured`), the
3958        // reconstruction data-fit, its t-block Gauss-Newton row block, AND the
3959        // β-tier data-fit gradient are all assembled through the SAME per-row
3960        // metric `M_n = U_n U_nᵀ = Σ_n^{-1}`. There is exactly ONE construction
3961        // site (the `whiten_rows` closure below), so the value the line-search
3962        // sums and the gradient/Hessian the Newton step solves cannot drift apart
3963        // (the objective↔gradient-desync cure). For Euclidean / OutputFisher /
3964        // no-metric the closure is the identity and every downstream loop is
3965        // byte-identical to the historical isotropic path.
3966        let whitens_likelihood = self
3967            .row_metric
3968            .as_ref()
3969            .is_some_and(|metric| metric.whitens_likelihood());
3970        // #972 / #977 T1: engage the FACTORED Grassmann-coordinate β-tier when
3971        // any atom has an active decoder frame. The closed-form factorization
3972        // `Φᵀ(G ⊗ I_p)Φ = G ⊗ (U_iᵀU_j)` is EXACT only for the isotropic
3973        // likelihood; under an active whitening metric (`whitens_likelihood()`,
3974        // only `WhitenedStructured`) the per-row output factor would be
3975        // `U_iᵀ M_n U_j` and does NOT factor out of the basis Gram, so we fall
3976        // back to the full-`B` path there (frames + whitening is out of scope —
3977        // see #974). The common Euclidean / OutputFisher / no-metric case factors
3978        // cleanly. When `frames_engaged` is false, EVERY β-tier object below is
3979        // assembled bit-for-bit as the historical full-`B` path.
3980        let frames_engaged = self.any_frame_active() && !whitens_likelihood;
3981        // #1407: fixed-decoder mode skips the entire β decoder tier (G/gb/htbeta
3982        // operator/hbb/β-penalties); only per-row htt/gt are produced.
3983        let fixed_decoder = self.fixed_decoder_assembly;
3984        let admission_plan = self
3985            .streaming_plan()
3986            .admitted_or_error(self.n_obs(), self.output_dim(), self.k_atoms())
3987            .map_err(|err| format!("SaeManifoldTerm::assemble_arrow_schur: {err}"))?;
3988        // #1407: fixed-decoder builds NO dense β-Hessian (hbb) — force the
3989        // empty-hbb system constructor so no `beta_dim × beta_dim` workspace is
3990        // taken (the early return skips `reclaim_border_hbb_workspace`).
3991        let dense_beta_curvature = !fixed_decoder
3992            && admission_plan.direct_admitted
3993            && !(frames_engaged && beta_dim > dense_beta_penalty_probe_max_dim);
3994        // #1406: the dense per-row cross-block slab `block.htbeta` is only WRITTEN
3995        // (line ~4243) and READ by the solver when `frames_engaged` (the factored
3996        // full-B path, which installs NO matrix-free row operator → the solver's
3997        // `sys_htbeta_apply_row` falls back to the dense slab). On the
3998        // `!frames_engaged` path the cross block is carried entirely by the
3999        // matrix-free Kronecker operator (`set_row_htbeta_operator`, ~line 4491);
4000        // `activate_dense_htbeta_supplement` is never called, so the solver never
4001        // touches `block.htbeta`. Allocating it at `beta_dim = K·M·p` there is the
4002        // ~6 TiB high-K leak (#1405/#1406): allocate ZERO columns instead. Frames
4003        // still use the (much smaller) factored border width.
4004        // #795/#1406/#1407: the non-frames matrix-free path normally holds a
4005        // ZERO-width per-row cross-block slab — the data-fit `H_tβ` is carried by
4006        // the Kronecker row operator (`set_row_htbeta_operator`), and allocating
4007        // the dense slab at `beta_dim = K·M·p` is the high-K memory leak. But an
4008        // ISOMETRY penalty on a coherence-preserving (flat) chart scatters an
4009        // ADDITIONAL Gauss-Newton cross-block into the dense per-row `htbeta`
4010        // slab and flips on `activate_dense_htbeta_supplement` — dropping it would
4011        // leave the Newton system block-diagonal and forfeit the strong `t↔B`
4012        // isometry coupling the circle fit needs to reach KKT stationarity (#795).
4013        // So on the non-frames path widen the slab to `beta_dim` exactly when that
4014        // dense supplement will be written, and keep zero width otherwise.
4015        let dense_isometry_cross_block = !fixed_decoder
4016            && analytic_penalties
4017                .map(|registry| self.registry_writes_dense_isometry_cross_block(registry))
4018                .unwrap_or(false);
4019        let row_htbeta_dim = if fixed_decoder {
4020            // Fixed-decoder mode skips the β tier entirely.
4021            0
4022        } else if frames_engaged {
4023            self.factored_border_dim()
4024        } else if dense_isometry_cross_block {
4025            // Matrix-free data-fit cross-block + dense isometry supplement: the
4026            // supplement is written/read in the full-`B` β coordinate system.
4027            beta_dim
4028        } else {
4029            // Matrix-free path with no dense cross-block supplement.
4030            0
4031        };
4032        // Build the Arrow-Schur system: heterogeneous row dims when a compact
4033        // layout is active, uniform `q` otherwise.
4034        let mut sys = if let Some(ref layout) = row_layout {
4035            let per_row_dims: Vec<usize> = (0..n).map(|row| layout.row_q_active(row)).collect();
4036            if dense_beta_curvature {
4037                let hbb_workspace = self.take_border_hbb_workspace(beta_dim);
4038                ArrowSchurSystem::new_with_per_row_dims_and_hbb_and_htbeta_cols(
4039                    per_row_dims,
4040                    beta_dim,
4041                    hbb_workspace,
4042                    row_htbeta_dim,
4043                )
4044            } else {
4045                self.border_hbb_workspace = Array2::<f64>::zeros((0, 0));
4046                ArrowSchurSystem::new_with_per_row_dims_empty_hbb_and_htbeta_cols(
4047                    per_row_dims,
4048                    beta_dim,
4049                    row_htbeta_dim,
4050                )
4051            }
4052        } else if dense_beta_curvature {
4053            let hbb_workspace = self.take_border_hbb_workspace(beta_dim);
4054            ArrowSchurSystem::new_with_hbb_and_htbeta_cols(
4055                n,
4056                q,
4057                beta_dim,
4058                hbb_workspace,
4059                row_htbeta_dim,
4060            )
4061        } else {
4062            self.border_hbb_workspace = Array2::<f64>::zeros((0, 0));
4063            ArrowSchurSystem::new_with_empty_hbb_and_htbeta_cols(n, q, beta_dim, row_htbeta_dim)
4064        };
4065        // Apply accumulated smoothness-penalty gradients into sys.gb.
4066        for (i, g) in smooth_grad_gb.iter().enumerate() {
4067            sys.gb[i] += g;
4068        }
4069        // `w_dim` is the whitened output dimension: `rank` of the metric factor
4070        // when whitening, else `p` (identity). `error_white` is the whitened
4071        // residual `U_nᵀ r_n ∈ ℝ^{w_dim}` whose squared norm is `r_nᵀ M_n r_n`,
4072        // shared by the value path, the t-block GN, and (lifted back to p-space)
4073        // the β-tier gradient.
4074        let w_dim = match self.row_metric.as_ref() {
4075            Some(metric) if whitens_likelihood => metric.metric_rank(),
4076            _ => p,
4077        };
4078        // Data-fit Gauss-Newton β-Hessian is block-diagonal across the `p`
4079        // output channels and identical in each: with the flat β layout
4080        // `β[μ·p + oc] = B[μ, oc]` (μ enumerating (atom, basis_col)) the GN
4081        // outer product `Jβᵀ Jβ` couples only equal `oc`, with the same
4082        // `(M_total × M_total)` block `G[μ, μ'] = Σ_rows (a_k φ_k[m])(a_{k'} φ_{k'}[m'])`
4083        // for every channel. So `H_data = G ⊗ I_p`. The `μ` index of an `a_phi`
4084        // entry whose global β base is `beta_base` is `beta_base / p` (every
4085        // `beta_offset` and the `basis_col·p` stride are multiples of `p`).
4086        //
4087        // `G` is only non-zero on `(atom_i, atom_j)` pairs that co-occur in
4088        // some row's active set, so we accumulate it as a sparse map of dense
4089        // per-atom-pair `(m_i × m_j)` blocks keyed by `(atom_i, atom_j)` rather
4090        // than as a dense `(m_total × m_total)` matrix. At `K = 100K` with
4091        // per-row active sets of size `k_active ≪ K`, only `O(N · k_active²)`
4092        // pairs are ever touched, so the data Gram (and every matvec /
4093        // diagonal pass over it via `SparseBlockKroneckerPenaltyOp`) tracks the
4094        // active atoms instead of `K²`. In the dense full-support layout the
4095        // map degenerates to every co-occurring pair, reproducing the dense
4096        // Gram exactly. A `BTreeMap` key order keeps the installed op's
4097        // fingerprint deterministic. The `μ`-space offset of atom `k` is
4098        // `beta_offsets[k] / p`.
4099        type SaeGBlocks = std::collections::BTreeMap<(usize, usize), Array2<f64>>;
4100        let m_total: usize = self.atoms.iter().map(|a| a.basis_size()).sum();
4101        let mu_offsets: Vec<usize> = beta_offsets.iter().map(|&off| off / p).collect();
4102        // Stick-breaking prior for IBP-MAP depends only on (k_atoms, alpha_eff)
4103        // which are constant across rows for the current rho; precompute once.
4104        let ibp_prior_vec = match self.assignment.mode {
4105            AssignmentMode::IBPMap { .. } => {
4106                let alpha = self
4107                    .assignment
4108                    .mode
4109                    .resolved_ibp_alpha(rho)
4110                    .ok_or_else(|| "IBP assignment alpha resolution failed".to_string())?;
4111                Some(ordered_geometric_shrinkage_prior(k_atoms, alpha).to_vec())
4112            }
4113            _ => None,
4114        };
4115        let ibp_prior_slice = ibp_prior_vec.as_deref();
4116        // #991 design honesty weights (mean-1 HT inclusion corrections); see
4117        // the seam comment at the per-row residual below.
4118        let row_loss_w = self.row_loss_weights.as_deref();
4119        // Dense full-support index `[0, k_atoms)`, used by the row loop when no
4120        // compact layout is engaged so the active-atom iteration is uniform.
4121        let all_atoms_index: Vec<usize> = (0..k_atoms).collect();
4122        // Per-atom per-axis periodicity, hoisted out of the row loop. Selects
4123        // the smooth von-Mises coordinate prior on wrapped (Circle) axes and
4124        // the Gaussian prior on Euclidean axes; see `ArdAxisPrior`.
4125        let ard_axis_periods: Vec<Vec<Option<f64>>> = self
4126            .assignment
4127            .coords
4128            .iter()
4129            .map(|coord| coord.effective_axis_periods())
4130            .collect();
4131        struct SaeAssemblyRow {
4132            pub(crate) row: usize,
4133            pub(crate) block: ArrowRowBlock,
4134            pub(crate) gb_delta: Vec<(usize, f64)>,
4135            pub(crate) g_blocks: SaeGBlocks,
4136            pub(crate) kron_a_phi: Option<Vec<(usize, f64)>>,
4137            pub(crate) kron_jac: Option<Vec<f64>>,
4138        }
4139
4140        // Per-row scratch reused across all rows a rayon worker processes
4141        // (#1017). The assembly closure is re-run every inner Newton iteration ×
4142        // every outer ρ evaluation; allocating these eight loop-invariant-sized
4143        // buffers (`k_atoms·p`, several `p`, one `q·max(w_dim,p)`) once per
4144        // worker via `map_init` — rather than once per (row × assembly) inside
4145        // the closure — removes the dominant small-allocation traffic the
4146        // eu-stack profile attributed to allocator/barrier spin at the SAE LLM
4147        // shape (p≈5120). Every buffer is fully filled (or `.fill(0.0)`'d) before
4148        // it is read each row, so reuse is bit-identical to the fresh-alloc path;
4149        // `gb_delta`/`g_blocks` are NOT scratch (they move into the returned
4150        // `SaeAssemblyRow`) and stay allocated per row.
4151        struct RowScratch {
4152            pub(crate) decoded: Array2<f64>,
4153            pub(crate) dg_buf: Vec<f64>,
4154            pub(crate) fitted: Array1<f64>,
4155            pub(crate) error: Array1<f64>,
4156            pub(crate) error_white: Vec<f64>,
4157            pub(crate) error_metric: Array1<f64>,
4158            pub(crate) jac_white: Vec<f64>,
4159            pub(crate) decoded_scratch: Vec<f64>,
4160            // #1557 — per-worker scratch for the row assignment vector (filled via
4161            // `_into`, not allocated per row); full `k_atoms`, global-atom indexed.
4162            pub(crate) assignments: Array1<f64>,
4163        }
4164        // #1410: size the per-worker scratch by the COMPACT row dimensions, not
4165        // full `K`/`q`. With a compact layout the assembly only ever touches each
4166        // row's active atoms (≤ `max_active`) and its compact tangent block
4167        // (≤ `max_q_row`); allocating `decoded` at `k_atoms·p` and `jac_white` at
4168        // `q·max(w_dim,p)` was the per-worker `O(K)` blow-up (≈11 GiB/worker at
4169        // K=100k, p=5120 — and `map_init` gives every Rayon worker its own copy).
4170        // Without a layout the dense path needs full `k_atoms`/`q`. `decoded` rows
4171        // are addressed by COMPACT SLOT in the compact branch below (the dense
4172        // branch keeps global-atom rows), so the row count is the max active set.
4173        //
4174        // #1410/#1408/#1409: SOFTMAX now ALSO takes the `Some(layout)` branch
4175        // whenever a `top_k` cap (`set_softmax_active_cap`) or an in-core memory
4176        // breach engages `softmax_active_plan` → `from_dense_weights`, so its
4177        // per-worker `decoded`/`jac_white` scratch is the COMPACT
4178        // `max_active`/`max_q_row` size too — no longer the full `(k_atoms·p)` /
4179        // `(q·max(w_dim,p))` blow-up. JumpReLU / IBP-MAP likewise pay only
4180        // `max_active`. The remaining `None` (full-`K`) branch is the UNCAPPED
4181        // softmax / no-budget-breach case, which genuinely assembles the dense
4182        // entropy block over all `K`; capping it (the compact contract) removes
4183        // the per-worker `O(K)` footprint entirely. (#1410: the residual per-row
4184        // `O(K)` softmax-majorizer scratch — a `row_logits` copy and the full-`K`
4185        // `d`/`H_entropy` blocks — is removed separately; see the active-only
4186        // `active_softmax_gershgorin_majorizer_entry` /
4187        // `softmax_dense_entropy_hessian_entry` helpers below.)
4188        let (decoded_rows, scratch_q) = match row_layout.as_ref() {
4189            Some(layout) => {
4190                let max_active = (0..n)
4191                    .map(|r| layout.active_atoms[r].len())
4192                    .max()
4193                    .unwrap_or(0)
4194                    .max(1);
4195                let max_q_row = (0..n)
4196                    .map(|r| layout.row_q_active(r))
4197                    .max()
4198                    .unwrap_or(q)
4199                    .max(1);
4200                (max_active, max_q_row)
4201            }
4202            None => (k_atoms, q),
4203        };
4204        use rayon::iter::{IntoParallelIterator, ParallelIterator};
4205        // #1033 large-n: fold the per-row assembly results in row-ordered CHUNKS
4206        // rather than collecting all `n` `SaeAssemblyRow`s at once. The previous
4207        // path materialized the FULL `Vec<SaeAssemblyRow>` (every row's htt/gt
4208        // block + per-row `g_blocks` + `kron_a_phi`/`kron_jac`) AND the fold
4209        // destinations simultaneously — a ~2× transient peak over the resident
4210        // system during the fold, the assembly-side OOM cliff at large `n`. By
4211        // collecting one chunk, folding it into `sys.rows`/`g_blocks`/`kron_*`,
4212        // and dropping the chunk's `Vec` before the next chunk, the transient
4213        // intermediate is bounded to `O(chunk_size)` while the resident output is
4214        // unchanged. The fold stays STRICTLY row-ascending (chunk `[c0..c1)` then
4215        // `[c1..c2)`, rows in order within each chunk), so every `+=` into
4216        // `sys.gb`, the `g_blocks` BTreeMap, and the `kron_*` pushes lands in the
4217        // identical order as the single-pass fold — bit-for-bit the same system.
4218        // Chunk width is the admission plan's `chunk_size` (the same value
4219        // `streaming_plan` sizes for the matrix-free window), floored so a tiny
4220        // plan still makes forward progress.
4221        let assembly_chunk_rows = self
4222            .assembly_chunk_override
4223            .unwrap_or(admission_plan.chunk_size)
4224            .clamp(1, n.max(1));
4225        let mut g_blocks: SaeGBlocks = std::collections::BTreeMap::new();
4226        let mut kron_a_phi: Vec<Vec<(usize, f64)>> = Vec::with_capacity(n);
4227        let mut kron_jac: Vec<Vec<f64>> = Vec::with_capacity(n);
4228        let mut chunk_start = 0usize;
4229        while chunk_start < n {
4230            let chunk_end = (chunk_start + assembly_chunk_rows).min(n);
4231            let mut fold_offset_in_chunk = 0usize;
4232            let row_results: Vec<SaeAssemblyRow> = (chunk_start..chunk_end)
4233                .into_par_iter()
4234                .map_init(
4235                    || RowScratch {
4236                        decoded: Array2::<f64>::zeros((decoded_rows, p)),
4237                        dg_buf: vec![0.0_f64; p],
4238                        fitted: Array1::<f64>::zeros(p),
4239                        error: Array1::<f64>::zeros(p),
4240                        error_white: vec![0.0_f64; w_dim],
4241                        error_metric: Array1::<f64>::zeros(p),
4242                        jac_white: vec![0.0_f64; scratch_q * w_dim.max(p)],
4243                        decoded_scratch: vec![0.0_f64; p],
4244                        assignments: Array1::<f64>::zeros(k_atoms),
4245                    },
4246                    |scratch, row| -> Result<SaeAssemblyRow, String> {
4247                        // #1557 — mark this rayon row worker as a nested data-parallel
4248                        // region so any faer GEMM reached transitively from the per-row
4249                        // assembly (frame `Uᵀ` products, the per-row cross-block /
4250                        // Schur-accumulation matmuls, the Riemannian projections) pins to
4251                        // `Par::Seq` via `effective_global_parallelism` instead of
4252                        // re-fanning the global Rayon pool against this outer fan-out
4253                        // (the `spindle` barrier-spin). Serial vs parallel over these tiny
4254                        // per-row blocks is a single small product, so the result is
4255                        // bit-identical. The guard is held for the whole closure body
4256                        // including its `?`/`return` paths.
4257                        with_nested_parallel(|| {
4258                        let RowScratch {
4259                            decoded,
4260                            dg_buf,
4261                            fitted,
4262                            error,
4263                            error_white,
4264                            error_metric,
4265                            jac_white,
4266                            decoded_scratch,
4267                            assignments,
4268                        } = scratch;
4269                        let mut gb_delta: Vec<(usize, f64)> = Vec::new();
4270                        let mut g_blocks: SaeGBlocks = std::collections::BTreeMap::new();
4271                        // #1557 — fill per-worker scratch (bit-identical to alloc path).
4272                        let a_scratch = assignments.as_slice_mut().expect("contiguous scratch");
4273                        self.assignment
4274                            .try_assignments_row_for_rho_into(row, rho, a_scratch)?;
4275                        // Reconstruction uses the row's active support: for the dense
4276                        // full-support layout this is all atoms (exact); for a compact
4277                        // layout the dropped atoms carry negligible `O(a)` reconstruction
4278                        // mass and zero curvature, so excluding them keeps `fitted`,
4279                        // `error`, and the logit-JVP cross term `(decoded[k] − fitted)`
4280                        // mutually consistent with the curvature actually assembled.
4281                        fitted.fill(0.0);
4282                        let row_active_owned: Option<&[usize]> =
4283                            row_layout.as_ref().map(|l| l.active_atoms[row].as_slice());
4284                        match row_active_owned {
4285                            Some(active) => {
4286                                // #1410: `decoded` is a compact (max_active × p) buffer
4287                                // here; index it by the active-set SLOT `j` (the same
4288                                // index the compact tangent block / `coord_starts` use),
4289                                // NOT the global `atom_idx`.
4290                                for (j, &atom_idx) in active.iter().enumerate() {
4291                                    let a_k = assignments[atom_idx];
4292                                    self.atoms[atom_idx]
4293                                        .fill_decoded_row(row, decoded_scratch.as_mut_slice());
4294                                    for out_col in 0..p {
4295                                        decoded[[j, out_col]] = decoded_scratch[out_col];
4296                                        fitted[out_col] += a_k * decoded_scratch[out_col];
4297                                    }
4298                                }
4299                            }
4300                            None => {
4301                                for atom_idx in 0..k_atoms {
4302                                    let a_k = assignments[atom_idx];
4303                                    self.atoms[atom_idx]
4304                                        .fill_decoded_row(row, decoded_scratch.as_mut_slice());
4305                                    for out_col in 0..p {
4306                                        decoded[[atom_idx, out_col]] = decoded_scratch[out_col];
4307                                        fitted[out_col] += a_k * decoded_scratch[out_col];
4308                                    }
4309                                }
4310                            }
4311                        }
4312                        for out_col in 0..p {
4313                            error[out_col] = fitted[out_col] - target[[row, out_col]];
4314                        }
4315                        // #991 design-honesty seam: a per-row scalar weight `w_row` on the
4316                        // reconstruction channel is exactly the metric `w_row · I_p`, so it
4317                        // is realized as a `√w_row` scaling of the THREE row-local data
4318                        // quantities at their construction sites — this residual, the
4319                        // latent Jacobian (below), and the β basis load `a·φ` (below).
4320                        // Every downstream data object then carries exactly one factor of
4321                        // `w_row` (gt, htt, htbeta, the β Gram `G`, and the β gradient),
4322                        // matching the `w_row`-weighted value `loss_scaled` sums; the
4323                        // per-row latent priors (assignment / ARD, added to `gt`/`htt`
4324                        // further down) are deliberately unweighted — see the
4325                        // `row_loss_weights` field docs. `None` ⇒ `sqrt_row_w == 1.0` and
4326                        // no multiply is applied (bit-identical unweighted path).
4327                        let sqrt_row_w = row_loss_w.map_or(1.0, |w| w[row].sqrt());
4328                        if sqrt_row_w != 1.0 {
4329                            for out_col in 0..p {
4330                                error[out_col] *= sqrt_row_w;
4331                            }
4332                        }
4333                        // #974 seam (step 1/2): whiten the per-row residual ONCE.
4334                        //   * not whitening ⇒ `error_white == error` (length p) and
4335                        //     `error_metric == error`; every downstream loop is the
4336                        //     historical isotropic path bit-for-bit.
4337                        //   * whitening ⇒ `error_white = U_nᵀ r_n ∈ ℝ^{w_dim}` (its squared
4338                        //     norm is `r_nᵀ M_n r_n`, the value the data-fit sums) and
4339                        //     `error_metric = U_n (U_nᵀ r_n) = M_n r_n ∈ ℝ^p` (the p-space
4340                        //     metric-applied residual the β-tier gradient contracts).
4341                        match self.row_metric.as_ref() {
4342                            Some(metric) if whitens_likelihood => {
4343                                let wr = metric.whiten_residual_row(row, error.view());
4344                                for (slot, &v) in error_white.iter_mut().zip(wr.iter()) {
4345                                    *slot = v;
4346                                }
4347                                let mr = metric.apply_metric_row(row, error.view());
4348                                for (slot, &v) in error_metric.iter_mut().zip(mr.iter()) {
4349                                    *slot = v;
4350                                }
4351                            }
4352                            _ => {
4353                                for out_col in 0..p {
4354                                    error_white[out_col] = error[out_col];
4355                                    error_metric[out_col] = error[out_col];
4356                                }
4357                            }
4358                        }
4359
4360                        // Determine whether this row uses the compact active-set layout.
4361                        //   * JumpReLU: gated atoms plus the smooth prior's
4362                        //     machine-precision support enter.
4363                        //   * IBP-MAP at large K: only the top-`k_active` atoms.
4364                        //   * Otherwise (small K): the dense uniform-q layout.
4365                        let (q_row, mut local_jac_row) = if let Some(layout) = row_layout.as_ref() {
4366                            let active = &layout.active_atoms[row];
4367                            let starts = &layout.coord_starts[row];
4368                            let q_active = layout.row_q_active(row);
4369                            let mut jac_compact = Array2::<f64>::zeros((q_active, p));
4370                            // Logit JVP rows for active atoms only, using the per-mode
4371                            // assignment sensitivity `da_k/dl_k` contracted into the
4372                            // decoded / fitted-corrected output direction.
4373                            let logits_row = self.assignment.logits.row(row);
4374                            for (j, &k) in active.iter().enumerate() {
4375                                fill_active_atom_logit_jvp(
4376                                    ActiveAtomLogitJvp {
4377                                        mode: self.assignment.mode,
4378                                        k,
4379                                        logit_k: logits_row[k],
4380                                        a_k: assignments[k],
4381                                        // #1410: compact slot `j`, not global atom `k`.
4382                                        decoded_k: decoded.row(j),
4383                                        fitted: fitted.view(),
4384                                        ibp_prior: ibp_prior_slice,
4385                                        compact_index: j,
4386                                        // #1026/#1033: a FIXED logit (ungated, or every
4387                                        // atom under frozen routing) has a constant gate
4388                                        // ⇒ zero logit-JVP.
4389                                        ungated: self.assignment.logit_is_fixed(k),
4390                                    },
4391                                    &mut jac_compact,
4392                                );
4393                            }
4394                            // Coordinate JVP rows for active atoms only.
4395                            for (j, &k) in active.iter().enumerate() {
4396                                let d = self.atoms[k].latent_dim;
4397                                let a_k = assignments[k];
4398                                let coord_start = starts[j];
4399                                for axis in 0..d {
4400                                    self.atoms[k].fill_decoded_derivative_row(
4401                                        row,
4402                                        axis,
4403                                        dg_buf.as_mut_slice(),
4404                                    );
4405                                    for out_col in 0..p {
4406                                        jac_compact[[coord_start + axis, out_col]] =
4407                                            a_k * dg_buf[out_col];
4408                                    }
4409                                }
4410                            }
4411                            (q_active, jac_compact)
4412                        } else {
4413                            // Fresh per-row Jacobian, structurally identical to the
4414                            // JumpReLU branch: every (q × p) element is unconditionally
4415                            // overwritten below (assignment-chart JVP rows + coordinate rows), so the
4416                            // `Array2::zeros` allocation needs no separate `fill(0.0)` and
4417                            // the populated buffer is returned by move without a clone.
4418                            let mut jac_row = Array2::<f64>::zeros((q, p));
4419                            fill_assignment_logit_jvp_rows(
4420                                self.assignment.mode,
4421                                self.assignment.logits.row(row),
4422                                assignments.view(),
4423                                decoded.view(),
4424                                fitted.view(),
4425                                ibp_prior_slice,
4426                                // #1026/#1033: zero logit-JVP rows for FIXED-logit atoms
4427                                // (ungated, and all atoms under frozen routing).
4428                                &self.assignment.fixed_logit_mask(),
4429                                &mut jac_row,
4430                            );
4431                            // Coordinate columns for all atoms.
4432                            for atom_idx in 0..k_atoms {
4433                                let d = self.atoms[atom_idx].latent_dim;
4434                                let off = coord_offsets[atom_idx];
4435                                let a_k = assignments[atom_idx];
4436                                for axis in 0..d {
4437                                    self.atoms[atom_idx].fill_decoded_derivative_row(
4438                                        row,
4439                                        axis,
4440                                        dg_buf.as_mut_slice(),
4441                                    );
4442                                    for out_col in 0..p {
4443                                        jac_row[[off + axis, out_col]] = a_k * dg_buf[out_col];
4444                                    }
4445                                }
4446                            }
4447                            (q, jac_row)
4448                        };
4449
4450                        // #991 design-honesty seam, Jacobian leg: scale the row's latent
4451                        // Jacobian by `√w_row` BEFORE the whitening / Kronecker capture so
4452                        // htt (= J̃J̃ᵀ), the data part of gt (= J̃ẽ, the residual already
4453                        // carries its own √w_row), and the htbeta cross block (J paired
4454                        // with the √w_row-scaled β load below) each carry exactly one
4455                        // factor of `w_row`. No-op on the unweighted path.
4456                        if sqrt_row_w != 1.0 {
4457                            for a in 0..q_row {
4458                                for out_col in 0..p {
4459                                    local_jac_row[[a, out_col]] *= sqrt_row_w;
4460                                }
4461                            }
4462                        }
4463
4464                        // #974 seam (step 2/2): whiten the per-row Jacobian through the SAME
4465                        // metric the residual was whitened by. `jac_white[a*w_dim + k]` holds
4466                        // `J̃[a, k] = Σ_out U_n[out, k] · J_n[a, out]` so the t-block
4467                        // Gauss-Newton row block is `htt = J̃ J̃ᵀ = J_n M_n J_nᵀ` and
4468                        // `gt = J̃ ẽ = J_nᵀ M_n r_n`. When not whitening, `w_dim == p` and the
4469                        // whitened jac equals the raw Jacobian, so htt/gt are byte-identical
4470                        // to the historical isotropic assembly. Because the SAME `error_white`
4471                        // feeds both the value-path data-fit (Σ½ ẽ²) and this gradient
4472                        // (J̃ ẽ), the objective and its t-block gradient share one whitening
4473                        // — they cannot desync.
4474                        if whitens_likelihood {
4475                            if let Some(metric) = self.row_metric.as_ref() {
4476                                for a in 0..q_row {
4477                                    for k in 0..w_dim {
4478                                        let mut acc = 0.0;
4479                                        // U_n[out, k] read through the metric's factor layout.
4480                                        for out_col in 0..p {
4481                                            acc += metric.factor_entry(row, out_col, k)
4482                                                * local_jac_row[[a, out_col]];
4483                                        }
4484                                        jac_white[a * w_dim + k] = acc;
4485                                    }
4486                                }
4487                            }
4488                        } else {
4489                            for a in 0..q_row {
4490                                for out_col in 0..p {
4491                                    jac_white[a * w_dim + out_col] = local_jac_row[[a, out_col]];
4492                                }
4493                            }
4494                        }
4495
4496                        // Build the per-row Arrow-Schur block at the row's active dim.
4497                        let mut block = ArrowRowBlock::new(q_row, row_htbeta_dim);
4498                        for a in 0..q_row {
4499                            let jac_a = &jac_white[a * w_dim..(a + 1) * w_dim];
4500                            let g = jac_a
4501                                .iter()
4502                                .zip(error_white.iter())
4503                                .map(|(&j, &e)| j * e)
4504                                .sum::<f64>();
4505                            block.gt[a] += g;
4506                            for b in 0..q_row {
4507                                let jac_b = &jac_white[b * w_dim..(b + 1) * w_dim];
4508                                let h = jac_a
4509                                    .iter()
4510                                    .zip(jac_b.iter())
4511                                    .map(|(&ja, &jb)| ja * jb)
4512                                    .sum::<f64>();
4513                                block.htt[[a, b]] += h;
4514                            }
4515                        }
4516
4517                        // Assignment prior in logit space.
4518                        // For compact layout: position `j` = active_atoms index.
4519                        // For dense layout: position `atom_idx` directly.
4520                        //
4521                        // H-consistency note (#1006 audit / #1416 update). This
4522                        // `assignment_hdiag` is the assignment channel's raw diagonal
4523                        // curvature, added un-majorized. It is exact for JumpReLU and exact
4524                        // within each IBP row/column diagonal, and stores ONLY the diagonal of
4525                        // two full-Hessian structures — but those off-diagonal structures are
4526                        // now carried elsewhere, not dropped:
4527                        //
4528                        //   * softmax entropy has dense within-row Hessian
4529                        //     H_kj = (λ/τ²) a_k[δ_kj(m-L_k-1) + a_j(L_k+L_j+1-2m)];
4530                        //     this diagonal stores its Gershgorin Loewner majorizer (#1419).
4531                        //   * IBP empirical-π has cross-row rank-one terms per column
4532                        //     H_(i,k),(j,k) = w score_derivative_k z'_ik z'_jk for i != j.
4533                        //     This per-row diagonal stores only the diagonal/self-row part;
4534                        //     the FULL rank-one cross-row block `U D Uᵀ` is now INSTALLED as a
4535                        //     separate Woodbury source by `set_ibp_cross_row_source` (#1038),
4536                        //     so the assembled operator is `H_full = H₀' + U D Uᵀ` on the
4537                        //     NO-SELF base `H₀' = H₀ − Σ_k d_k diag(z'_ik²)` (self term
4538                        //     downdated, see `IbpCrossRowSource::self_term_downdate`). The
4539                        //     scalar `D`-coefficient `d_k = w·s'_k` is
4540                        //     `IbpHessianDiagThirdChannels::cross_row_d` (FD-verified against
4541                        //     ∂²value/∂ℓ_ik∂ℓ_jk in
4542                        //     `ibp_cross_row_woodbury_d_matches_full_off_diagonal_hessian`),
4543                        //     and `z_jac` carries `u_k`'s entries `z'_ik`.
4544                        //
4545                        // The criterion's log|H| and Γ adjoint differentiate this SAME
4546                        // `H_full`: the ρ-trace adds the cross-row off-diagonal in
4547                        // `assignment_log_strength_hessian_trace` (#1416, dense AND compact
4548                        // layouts) and the θ-adjoint adds it in `logdet_theta_adjoint`
4549                        // (#1416/#1641), so value and gradient stay on one operator.
4550                        let assignment_base = row * k_atoms;
4551                        if let Some(layout) = row_layout.as_ref() {
4552                            let active = &layout.active_atoms[row];
4553                            // #1408/#1409 softmax compact curvature: the entropy
4554                            // Hessian diagonal in `assignment_hdiag` is INDEFINITE,
4555                            // so on a compact softmax layout write the Gershgorin
4556                            // Loewner majorizer `D_kk = Σ_j|H_kj|` (#1419) — the same
4557                            // PSD operator the dense softmax branch writes — at each
4558                            // active logit slot. `D` is diagonal, so its active
4559                            // principal sub-block is `diag(D_kk : k ∈ active)`; each
4560                            // `D_kk` is the FULL-`K` abs-row-sum, so it still
4561                            // dominates the active principal sub-block of `H_entropy`
4562                            // (a genuine majorizer on the retained support). The
4563                            // gradient stays the EXACT entropy gradient (it sets the
4564                            // fixed point), so majorizing only conditions the Newton
4565                            // step. JumpReLU/IBP keep their (exact) diagonal.
4566                            //
4567                            // #1410: compute only the active `D_kk` directly from this
4568                            // row's softmax assignments `a` (= `assignments`, already
4569                            // in hand), via `active_softmax_gershgorin_majorizer_entry`.
4570                            // The previous `psd_majorizer_abs_row_sums(&row_logits, ..)`
4571                            // call allocated TWO length-`K` per-row scratch vectors (a
4572                            // fresh `row_logits` copy and the full-`K` returned `d`)
4573                            // only to read `d[k]` for the `≤ top_k` active `k` — an
4574                            // `O(K)` per-row allocation on the path the compact
4575                            // contract keeps `K`-free. The shared `m = Σ_j a_j l_j` is
4576                            // the one irreducible `O(K)` pass, computed once per row.
4577                            let assignments_slice = assignments
4578                                .as_slice()
4579                                .expect("softmax assignments row must be contiguous");
4580                            let majorizer_log_mean: Option<f64> = softmax_dense
4581                                .as_ref()
4582                                .map(|_| softmax_majorizer_log_mean(assignments_slice));
4583                            for (j, &k) in active.iter().enumerate() {
4584                                block.gt[j] += assignment_grad[assignment_base + k];
4585                                match (softmax_dense.as_ref(), majorizer_log_mean) {
4586                                    (Some((_penalty, scale)), Some(m)) => {
4587                                        block.htt[[j, j]] +=
4588                                            active_softmax_gershgorin_majorizer_entry(
4589                                                assignments_slice,
4590                                                k,
4591                                                m,
4592                                                *scale,
4593                                            );
4594                                    }
4595                                    _ => block.htt[[j, j]] += assignment_hdiag[assignment_base + k],
4596                                }
4597                            }
4598                        } else {
4599                            for free_idx in 0..assignment_dim {
4600                                block.gt[free_idx] += assignment_grad[assignment_base + free_idx];
4601                            }
4602                            if let Some((penalty, scale)) = softmax_dense.as_ref() {
4603                                // #1419: write the genuine Gershgorin Loewner majorizer
4604                                // `D = diag(Σ_j|H_kj|)` of the exact entropy Hessian onto the
4605                                // row's logit block in place of the EXACT entropy Hessian. The
4606                                // entropy Hessian is INDEFINITE (concave directions on
4607                                // long-tailed rows), which drove the per-row evidence block
4608                                // non-PD and forced the downstream Faddeev–Popov deflation to
4609                                // flatten data-relevant logit directions (under-identifying the
4610                                // atoms). `D` is a nonnegative diagonal, hence exactly PSD and
4611                                // PD-preserving like the previous Fisher surrogate, so the block
4612                                // stays PD and the deflation no longer fires on the entropy
4613                                // block. Unlike the Fisher metric `G = scale·(diag(a) − a aᵀ)`,
4614                                // which is PSD but NOT a majorizer (`G − H_entropy` can be
4615                                // indefinite — K=2, a=(0.95,0.05): G₁₁=0.0475 < H₁₁=0.0784,
4616                                // #1419), `D` actually satisfies `D ⪰ H_entropy` and `D ⪰ 0`,
4617                                // so it is a true MM/Loewner curvature majorizer. Because the
4618                                // entropy penalty is a FIXED prior whose stationary point is set
4619                                // by its (unchanged) EXACT gradient, replacing its curvature
4620                                // with the majorizer only conditions the Newton step and the
4621                                // Laplace normalizer's curvature operator — it does NOT move the
4622                                // optimum.
4623                                //
4624                                // Softmax uses the REDUCED K−1 free-logit chart (the last
4625                                // reference logit is fixed at 0, `assignment_coord_dim() = K−1`).
4626                                // Holding z_{K-1} fixed, the reduced curvature over the free
4627                                // logits 0..K−1 is exactly the top-left (K−1)×(K−1) submatrix of
4628                                // the full K×K majorizer (the fixed logit contributes no
4629                                // row/column to the free curvature). The criterion's `log|H|`
4630                                // and the #1006 θ-adjoint differentiate this SAME `D` (see the
4631                                // `row_psd_majorizer_logit_derivative` site below), so value and
4632                                // adjoint stay on one exact branch.
4633                                let row_logits: Vec<f64> = (0..k_atoms)
4634                                    .map(|k| self.assignment.logits[[row, k]])
4635                                    .collect();
4636                                let h_dense = penalty.row_psd_majorizer(&row_logits, *scale);
4637                                for ki in 0..assignment_dim {
4638                                    for kj in 0..assignment_dim {
4639                                        block.htt[[ki, kj]] += h_dense[[ki, kj]];
4640                                    }
4641                                }
4642                            } else {
4643                                for free_idx in 0..assignment_dim {
4644                                    block.htt[[free_idx, free_idx]] +=
4645                                        assignment_hdiag[assignment_base + free_idx];
4646                                }
4647                            }
4648                        }
4649
4650                        // ARD on each on-atom coordinate.
4651                        // For compact layout: only active atoms; coord positions use compact starts.
4652                        // For dense layout: all atoms; coord positions use coord_offsets.
4653                        if let Some(layout) = row_layout.as_ref() {
4654                            let active = &layout.active_atoms[row];
4655                            let starts = &layout.coord_starts[row];
4656                            for (j, &k) in active.iter().enumerate() {
4657                                let coord = &self.assignment.coords[k];
4658                                let d = coord.latent_dim();
4659                                if rho.log_ard[k].is_empty() {
4660                                    continue;
4661                                }
4662                                if rho.log_ard[k].len() != d {
4663                                    return Err(format!(
4664                                        "ARD rho atom {k} has len {} but atom dim is {d}",
4665                                        rho.log_ard[k].len()
4666                                    ));
4667                                }
4668                                let row_t = coord.row(row);
4669                                let periods = &ard_axis_periods[k];
4670                                for axis in 0..d {
4671                                    // ARD on coords is a genuine per-row prior (each row
4672                                    // contributes the per-axis prior energy), so it is NOT
4673                                    // minibatch-scaled — the per-chunk row sums already
4674                                    // reconstruct the full coordinate prior across a pass.
4675                                    // The value (`ard_value`/`loss.ard`) and the gradient
4676                                    // both come from the SAME `ArdAxisPrior` energy, so they
4677                                    // stay FD-consistent on periodic axes. The exact
4678                                    // von-Mises curvature `V'' = α·cos(κt)` is INDEFINITE —
4679                                    // it goes negative for |t| past a quarter period — so
4680                                    // writing it raw into the Newton/Schur `htt` diagonal
4681                                    // makes that PSD curvature block indefinite and the Schur
4682                                    // Cholesky (used both for the Newton step and the exact
4683                                    // log-det) fails on a non-PD pivot. Accumulate the PSD
4684                                    // majorizer `max(V'', 0)` instead, exactly as
4685                                    // `add_sae_coord_penalty` does for the registry coord
4686                                    // penalties: the positive part keeps `htt` PSD so the
4687                                    // factorization succeeds, and majorizing the curvature of
4688                                    // a fixed prior only damps the Newton step — it does not
4689                                    // move the stationary point (the gradient, which sets the
4690                                    // fixed point, stays the exact `V'`).
4691                                    let alpha =
4692                                        SaeManifoldRho::stable_exp_strength(rho.log_ard[k][axis]);
4693                                    let prior =
4694                                        ArdAxisPrior::eval(alpha, row_t[axis], periods[axis]);
4695                                    block.gt[starts[j] + axis] += prior.grad;
4696                                    block.htt[[starts[j] + axis, starts[j] + axis]] +=
4697                                        prior.hess.max(0.0);
4698                                }
4699                            }
4700                        } else {
4701                            for atom_idx in 0..k_atoms {
4702                                let coord = &self.assignment.coords[atom_idx];
4703                                let d = coord.latent_dim();
4704                                if rho.log_ard[atom_idx].is_empty() {
4705                                    continue;
4706                                }
4707                                if rho.log_ard[atom_idx].len() != d {
4708                                    return Err(format!(
4709                                        "ARD rho atom {atom_idx} has len {} but atom dim is {d}",
4710                                        rho.log_ard[atom_idx].len()
4711                                    ));
4712                                }
4713                                let off = coord_offsets[atom_idx];
4714                                let row_t = coord.row(row);
4715                                let periods = &ard_axis_periods[atom_idx];
4716                                for axis in 0..d {
4717                                    // PSD-majorize the (possibly negative) von-Mises curvature
4718                                    // into the Newton/Schur `htt` block; see the compact-layout
4719                                    // branch above for why `max(V'', 0)` is required to keep
4720                                    // `htt` PD (the exact `V'' = α·cos κt` is indefinite past a
4721                                    // quarter period and breaks the Schur/log-det Cholesky).
4722                                    let alpha = SaeManifoldRho::stable_exp_strength(
4723                                        rho.log_ard[atom_idx][axis],
4724                                    );
4725                                    let prior =
4726                                        ArdAxisPrior::eval(alpha, row_t[axis], periods[axis]);
4727                                    block.gt[off + axis] += prior.grad;
4728                                    block.htt[[off + axis, off + axis]] += prior.hess.max(0.0);
4729                                }
4730                            }
4731                        }
4732
4733                        // Beta gradient/Hessian — Kronecker form J_β = φᵀ ⊗ I_p.
4734                        //
4735                        // The per-row beta Jacobian is
4736                        //   J_β[out_col, beta_idx] = a_k · phi_k[basis_col]   if out_col == out_col(beta_idx)
4737                        //                            0                         otherwise
4738                        // so the data-fit Gauss-Newton beta-Hessian factors as a rank-`p`
4739                        // sum of outer products. We pre-compute the per-(atom, basis_col)
4740                        // scalar `a_k · phi_k` once and reuse it across the `out_col`
4741                        // and inner `(atom_j, basis_col2)` loops.
4742                        //
4743                        // Full-B rows keep the matrix-free Kronecker path below. Factored
4744                        // rows write the `q_i × Σ M_k r_k` C-space cross slab directly by
4745                        // folding each output-channel contribution through the atom frame,
4746                        // so no `q_i × β_dim` slab is ever materialized.
4747                        //
4748                        // Only the row's active atoms contribute `a_phi` support and data
4749                        // curvature: in a compact layout (JumpReLU gate or large-K
4750                        // top-`k_active` truncation) the inactive atoms carry zero (gated)
4751                        // or sub-cutoff assignment mass and are excluded — this is what
4752                        // keeps both the htbeta support and the `G` accumulation
4753                        // `O(k_active)` rather than `O(K)`. In the dense full-support
4754                        // layout `row_active` spans all atoms.
4755                        let row_active: &[usize] = match row_layout.as_ref() {
4756                            Some(layout) => layout.active_atoms[row].as_slice(),
4757                            None => &all_atoms_index,
4758                        };
4759                        // #1407: in fixed-decoder mode the β tier is not assembled at
4760                        // all — leave gb_delta/g_blocks empty and kron None. htt/gt
4761                        // (built above) are the only outputs the frozen-decoder step
4762                        // consumes.
4763                        let mut a_phi: Vec<(usize, f64)> = Vec::with_capacity(row_active.len() * 4);
4764                        // Per-active-atom weighted basis row `a_k · φ_k[·]`, retained so the
4765                        // data Gram blocks can be accumulated as clean per-atom-pair outer
4766                        // products `(a_k φ_k) (a_{k'} φ_{k'})ᵀ`.
4767                        let mut weighted_phi: Vec<(usize, Vec<f64>)> =
4768                            Vec::with_capacity(row_active.len());
4769                        if !fixed_decoder {
4770                            for &atom_idx in row_active {
4771                                let atom = &self.atoms[atom_idx];
4772                                let atom_beta_off = beta_offsets[atom_idx];
4773                                let m = atom.basis_size();
4774                                let a_k = assignments[atom_idx];
4775                                let mut wphi = Vec::with_capacity(m);
4776                                for basis_col in 0..m {
4777                                    let phi = atom.basis_values[[row, basis_col]];
4778                                    // #991 design-honesty seam, β leg: the `√w_row` here pairs
4779                                    // with the `√w_row` on the residual (β gradient =
4780                                    // `a·φ · M r` ⇒ w_row) and with itself (β Gram `G` and the
4781                                    // htbeta Kronecker capture ⇒ w_row). `1.0` when unweighted.
4782                                    let w = a_k * phi * sqrt_row_w;
4783                                    a_phi.push((atom_beta_off + basis_col * p, w));
4784                                    wphi.push(w);
4785                                }
4786                                weighted_phi.push((atom_idx, wphi));
4787                            }
4788                            // β data-fit gradient `gᵦ += J_βᵀ M_n r_n`. The β-Jacobian is
4789                            // `J_β = φ_nᵀ ⊗ I_p`, so `J_βᵀ M_n r_n = φ_n ⊗ (M_n r_n)` —
4790                            // contract the basis weight `a·φ` against the p-space metric-applied
4791                            // residual `error_metric` (= `M_n r_n`), the SAME whitening the value
4792                            // path and t-block share. When not whitening, `error_metric == error`
4793                            // and this is byte-identical to the historical `J_βᵀ r`.
4794                            for &(beta_base_i, j_beta_i) in a_phi.iter() {
4795                                if j_beta_i == 0.0 {
4796                                    continue;
4797                                }
4798                                for out_col in 0..p {
4799                                    gb_delta.push((
4800                                        beta_base_i + out_col,
4801                                        j_beta_i * error_metric[out_col],
4802                                    ));
4803                                    // No dense hbb write — the sparse `G ⊗ I_p` op installed
4804                                    // after the loop carries the data-fit GN β-Hessian.
4805                                }
4806                            }
4807                            if frames_engaged {
4808                                for &atom_idx in row_active {
4809                                    let atom = &self.atoms[atom_idx];
4810                                    let m = atom.basis_size();
4811                                    let a_k = assignments[atom_idx];
4812                                    for basis_col in 0..m {
4813                                        let phi = atom.basis_values[[row, basis_col]];
4814                                        let w = a_k * phi * sqrt_row_w;
4815                                        if w == 0.0 {
4816                                            continue;
4817                                        }
4818                                        let c_base = frame_projection.border_offsets[atom_idx]
4819                                            + basis_col * frame_projection.ranks[atom_idx];
4820                                        for c in 0..q_row {
4821                                            let mut hrow = block.htbeta.row_mut(c);
4822                                            let hrow_slice = hrow
4823                                                .as_slice_mut()
4824                                                .expect("htbeta row is contiguous");
4825                                            for out_col in 0..p {
4826                                                let value = local_jac_row[[c, out_col]] * w;
4827                                                frame_projection.accumulate_output_project(
4828                                                    atom_idx, c_base, out_col, value, hrow_slice,
4829                                                );
4830                                            }
4831                                        }
4832                                    }
4833                                }
4834                            }
4835                            // Data-fit GN β-Hessian: accumulate the channel-independent block
4836                            // `G[μ_i, μ_j] += (a_k φ_k)[μ_i] (a_{k'} φ_{k'})[μ_j]` into the
4837                            // sparse per-atom-pair map (the `out_col` dimension is carried by
4838                            // `I_p`). Only co-occurring `(atom_i, atom_j)` pairs are touched.
4839                            for ai in 0..weighted_phi.len() {
4840                                let (atom_i, ref wphi_i) = weighted_phi[ai];
4841                                let m_i = wphi_i.len();
4842                                for aj in 0..weighted_phi.len() {
4843                                    let (atom_j, ref wphi_j) = weighted_phi[aj];
4844                                    let m_j = wphi_j.len();
4845                                    let blk = g_blocks
4846                                        .entry((atom_i, atom_j))
4847                                        .or_insert_with(|| Array2::<f64>::zeros((m_i, m_j)));
4848                                    for li in 0..m_i {
4849                                        let wi = wphi_i[li];
4850                                        if wi == 0.0 {
4851                                            continue;
4852                                        }
4853                                        for lj in 0..m_j {
4854                                            blk[[li, lj]] += wi * wphi_j[lj];
4855                                        }
4856                                    }
4857                                }
4858                            }
4859                        } // #1407 end `if !fixed_decoder` β-tier accumulation
4860                        let (kron_a_phi, kron_jac) = if !frames_engaged && !fixed_decoder {
4861                            // Flatten local_jac_row row-major into a plain Vec<f64> (q_row * p entries).
4862                            let mut jac_flat = vec![0.0_f64; q_row * p];
4863                            for c in 0..q_row {
4864                                for j in 0..p {
4865                                    jac_flat[c * p + j] = local_jac_row[[c, j]];
4866                                }
4867                            }
4868                            (Some(a_phi), Some(jac_flat))
4869                        } else {
4870                            (None, None)
4871                        };
4872                        Ok(SaeAssemblyRow {
4873                            row,
4874                            block,
4875                            gb_delta,
4876                            g_blocks,
4877                            kron_a_phi,
4878                            kron_jac,
4879                        })
4880                        }) // #1557 with_nested_parallel
4881                    },
4882                )
4883                .collect::<Result<Vec<_>, String>>()?;
4884
4885            // Fold THIS chunk's rows (ascending) into the global accumulators.
4886            // The parallel collect preserves index order within the chunk and
4887            // chunks are visited in ascending `chunk_start` order, so the overall
4888            // fold order is `0,1,2,…,n-1` — identical to the former single-pass
4889            // fold. The `row == chunk_start + fold_offset_in_chunk` assert pins
4890            // that strict sequential arrival (the invariant the `kron_*`
4891            // row-aligned pushes depend on).
4892            for row_result in row_results.into_iter() {
4893                let row = row_result.row;
4894                assert_eq!(
4895                    row,
4896                    chunk_start + fold_offset_in_chunk,
4897                    "parallel SAE row assembly returned rows out of order"
4898                );
4899                fold_offset_in_chunk += 1;
4900                for (idx, value) in row_result.gb_delta {
4901                    sys.gb[idx] += value;
4902                }
4903                for ((atom_i, atom_j), data) in row_result.g_blocks {
4904                    let m_i = data.nrows();
4905                    let m_j = data.ncols();
4906                    let blk = g_blocks
4907                        .entry((atom_i, atom_j))
4908                        .or_insert_with(|| Array2::<f64>::zeros((m_i, m_j)));
4909                    for li in 0..m_i {
4910                        for lj in 0..m_j {
4911                            blk[[li, lj]] += data[[li, lj]];
4912                        }
4913                    }
4914                }
4915                if !frames_engaged && !fixed_decoder {
4916                    // Rows arrive in ascending order across chunks, so pushing
4917                    // here yields `kron_*[row]` aligned to the row index exactly
4918                    // as the single-pass `push` did.
4919                    kron_a_phi.push(
4920                        row_result
4921                            .kron_a_phi
4922                            .expect("full-B SAE row assembly must return a_phi rows"),
4923                    );
4924                    kron_jac.push(
4925                        row_result
4926                            .kron_jac
4927                            .expect("full-B SAE row assembly must return local Jacobian rows"),
4928                    );
4929                }
4930                sys.rows[row] = row_result.block;
4931            }
4932            chunk_start = chunk_end;
4933        }
4934        // #1407: fixed-decoder early return. The per-row htt/gt are now fully
4935        // assembled (data GN + assignment/ARD prior). Apply only the htt/gt
4936        // Riemannian projection (the decoder/β tier is intentionally absent), then
4937        // return the block-diagonal system. `fixed_decoder_step_from_rows` reads
4938        // only `rows[*].htt`/`gt` + `row_offsets`, so no β-tier object is needed.
4939        if fixed_decoder {
4940            match row_layout.as_ref() {
4941                None => {
4942                    // Dense uniform-q: project htt/gt (and the 0-width htbeta, a
4943                    // no-op) through the ext-coord manifold.
4944                    self.apply_sae_riemannian_geometry(&mut sys);
4945                }
4946                Some(layout) => {
4947                    // Compact heterogeneous-q: project each row's htt/gt at its
4948                    // own ext-coord point, mirroring the full path's compact
4949                    // Riemannian block (htbeta is 0-width here, so skipped).
4950                    if !self.ext_coord_manifold().is_euclidean() {
4951                        for row_idx in 0..n {
4952                            let (manifold_i, point_i) =
4953                                self.compact_row_ext_manifold_and_point(row_idx, layout);
4954                            let t_i = point_i.view();
4955                            let gt_e = sys.rows[row_idx].gt.clone();
4956                            let htt_e = sys.rows[row_idx].htt.clone();
4957                            sys.rows[row_idx].gt =
4958                                manifold_i.project_gradient_to_tangent(t_i, gt_e.view());
4959                            sys.rows[row_idx].htt = manifold_i.riemannian_hessian_matrix(
4960                                t_i,
4961                                gt_e.view(),
4962                                htt_e.view(),
4963                            );
4964                        }
4965                    }
4966                }
4967            }
4968            if let Some(deflation) = self.row_gauge_deflation_for_layout(row_layout.as_ref()) {
4969                sys.set_row_gauge_deflation(deflation);
4970            }
4971            self.last_row_layout = row_layout;
4972            self.last_frames_active = frames_engaged;
4973            return Ok(sys);
4974        }
4975        // Apply Riemannian geometry to the per-row row blocks (htt, gt) and
4976        // also to the per-row Kronecker local Jacobians stored in kron_jac.
4977        // When the SAE ext-coord manifold is non-Euclidean (any atom latent
4978        // on sphere / circle / interval), the local Jacobian rows that map
4979        // into the t-block tangent space must be projected via the per-row
4980        // tangent projector P_i.  This mirrors what
4981        // `apply_riemannian_latent_geometry` does to `row.htbeta`, applied
4982        // here to the (q × p) kron_jac so the Kronecker htbeta_matvec uses
4983        // the Riemannian-projected form.
4984        // Apply Riemannian geometry only for the dense uniform-q layout. Any
4985        // compact active-set layout (JumpReLU gate or large-K softmax/IBP
4986        // truncation) has heterogeneous q_i; the Riemannian projector path
4987        // requires a uniform latent dimension. The sparse plan only engages on
4988        // Euclidean ext-coord manifolds (see `sparse_active_plan`), so skipping
4989        // the projector here is correct — there is nothing to project.
4990        match row_layout.as_ref() {
4991            None => {
4992                let raw_gt_rows: Vec<Array1<f64>> =
4993                    sys.rows.iter().map(|row| row.gt.clone()).collect();
4994                self.apply_sae_riemannian_geometry(&mut sys);
4995                let manifold = self.ext_coord_manifold();
4996                if !frames_engaged && !manifold.is_euclidean() {
4997                    let ext = self.ext_coord_matrix();
4998                    // Project the local Jacobian columns onto the tangent space at
4999                    // each row's ext-coord point. Each column `j` of the row's
5000                    // (q_row × p) Jacobian is an ambient-space vector of length
5001                    // `q_row`; the manifold projector acts on one such column at a
5002                    // time. Working directly on the row-major `jac_flat` storage via
5003                    // a single reusable `col_buf` avoids the two dense (q × p) copies
5004                    // (flatten→Array2, project, unflatten→Vec) that previously fired
5005                    // per row. `t_buf` still holds the row's ext-coord vector.
5006                    let mut t_buf = vec![0.0_f64; q];
5007                    let mut col_buf = Array1::<f64>::zeros(q);
5008                    for row_idx in 0..n {
5009                        let ext_row = ext.row(row_idx);
5010                        for (slot, &v) in t_buf.iter_mut().zip(ext_row.iter()) {
5011                            *slot = v;
5012                        }
5013                        let t_i = ArrayView1::from(t_buf.as_slice());
5014                        let raw_gt = raw_gt_rows[row_idx].view();
5015                        let jac_flat = &mut kron_jac[row_idx];
5016                        let q_row = jac_flat.len() / p;
5017                        for j in 0..p {
5018                            for c in 0..q_row {
5019                                col_buf[c] = jac_flat[c * p + j];
5020                            }
5021                            let projected_col = manifold.project_vector_to_gradient_tangent(
5022                                t_i,
5023                                raw_gt.slice(ndarray::s![..q_row]),
5024                                col_buf.slice(ndarray::s![..q_row]),
5025                            );
5026                            for c in 0..q_row {
5027                                jac_flat[c * p + j] = projected_col[c];
5028                            }
5029                        }
5030                    }
5031                }
5032            }
5033            Some(layout) => {
5034                // Compact active-set layout (#1117 follow-up): the dense
5035                // `ext_coord_manifold()` is keyed to the uniform full-`q` block
5036                // ordering, so it cannot be applied to the heterogeneous compact
5037                // rows directly. Instead we rebuild, PER ROW, the product manifold
5038                // and ext-coord point in that row's compact column order (see
5039                // `compact_row_ext_manifold_and_point`) and apply the SAME three
5040                // per-row Riemannian operations the dense
5041                // `apply_riemannian_latent_geometry` applies — gradient tangent
5042                // projection of `gt`, the Riemannian Hessian correction of `htt`,
5043                // and the column tangent projection of `htbeta` — plus the
5044                // identical Kronecker `kron_jac` column projection. On the shared
5045                // active support this is byte-identical to slicing the dense
5046                // product manifold, so engaging the sparse plan on a non-Euclidean
5047                // ext manifold is now correct (the former
5048                // `is_euclidean()`-only guard in `sparse_active_plan` is lifted).
5049                //
5050                // Euclidean ext manifolds still skip all of this (every
5051                // per-row manifold is a product of Euclidean parts whose
5052                // projector is the identity); we early-out so those rows stay
5053                // byte-for-byte the historical compact path.
5054                if !self.ext_coord_manifold().is_euclidean() {
5055                    for row_idx in 0..n {
5056                        let (manifold_i, point_i) =
5057                            self.compact_row_ext_manifold_and_point(row_idx, layout);
5058                        let t_i = point_i.view();
5059                        // gt / htt / htbeta on the compact ArrowRowBlock, exactly
5060                        // as `apply_riemannian_latent_geometry` does for dense
5061                        // uniform-q rows.
5062                        let gt_e = sys.rows[row_idx].gt.clone();
5063                        let htt_e = sys.rows[row_idx].htt.clone();
5064                        sys.rows[row_idx].gt =
5065                            manifold_i.project_gradient_to_tangent(t_i, gt_e.view());
5066                        sys.rows[row_idx].htt =
5067                            manifold_i.riemannian_hessian_matrix(t_i, gt_e.view(), htt_e.view());
5068                        // #1406: only the frames path holds a real dense `htbeta`
5069                        // slab; the matrix-free path leaves it 0-width (the
5070                        // cross-block geometry is applied to `kron_jac` below), so
5071                        // projecting a zero-column matrix is a no-op we skip.
5072                        if frames_engaged {
5073                            let htbeta_e = sys.rows[row_idx].htbeta.clone();
5074                            sys.rows[row_idx].htbeta = manifold_i
5075                                .project_matrix_columns_to_gradient_tangent(
5076                                    t_i,
5077                                    gt_e.view(),
5078                                    htbeta_e.view(),
5079                                );
5080                        }
5081                        // Kronecker local-Jacobian column projection (full-B path
5082                        // only), using the SAME pre-projection gradient `gt_e` so
5083                        // the cross-block geometry matches the dense branch.
5084                        if !frames_engaged {
5085                            let jac_flat = &mut kron_jac[row_idx];
5086                            let q_row = jac_flat.len() / p;
5087                            let mut col_buf = Array1::<f64>::zeros(q_row);
5088                            for j in 0..p {
5089                                for c in 0..q_row {
5090                                    col_buf[c] = jac_flat[c * p + j];
5091                                }
5092                                let projected_col = manifold_i.project_vector_to_gradient_tangent(
5093                                    t_i,
5094                                    gt_e.view(),
5095                                    col_buf.view(),
5096                                );
5097                                for c in 0..q_row {
5098                                    jac_flat[c * p + j] = projected_col[c];
5099                                }
5100                            }
5101                        }
5102                    }
5103                }
5104            }
5105        }
5106        // Build and install the full-B Kronecker htbeta_matvec.
5107        //
5108        // `SaeKroneckerRows` holds per-row `(a_phi, local_jac)` and implements
5109        // the cross-block operator without ever materialising the dense
5110        // `(q × K·p)` slab.  The cross-block factorises as `H_tβ = L · J_β`,
5111        // where `J_β = φᵀ ⊗ I_p` projects a length-`K` β vector onto the
5112        // `p`-dimensional decoded output space (`apply_jbeta`) and `L_i` is
5113        // the per-row `(q_i × p)` assignment+coordinate Jacobian that lifts
5114        // that p-vector into the row's `q_i`-dim tangent block (`apply_l`).
5115        // Both factors are required: the contract of `set_row_htbeta_operator`
5116        // is `out.len() == d` (= `q_i`), so writing `apply_jbeta`'s p-vector
5117        // output directly into a length-`q_i` buffer overflows whenever
5118        // `p > q_i` (the common case once `p` reflects real feature width).
5119        // Symmetric for the transpose: `H_βt = J_βᵀ · Lᵀ`, so apply `Lᵀ`
5120        // first to map the q_i-vector back to p-space, then scatter through
5121        // the support.
5122        // #1017/#1026: the legacy full-B device PCG assumes `G ⊗ I_p`, while
5123        // framed systems carry `G_ij ⊗ W_ij` with rank-r atom blocks. Feeding a
5124        // framed system to that kernel would silently return the wrong Newton
5125        // step. Framed device PCG therefore needs the dedicated factored kernel.
5126        // #1033 large-n: the per-row support `kron_a_phi` and local Jacobians
5127        // `kron_jac` are consumed by BOTH the host matrix-free row operator
5128        // (`SaeKroneckerRows`) and the solver's `DeviceSaePcgData`. Previously
5129        // each took its own full `O(n·q·p)` / `O(n·k_active)` clone, so the
5130        // always-resident footprint of the CPU non-frames path carried TWO copies
5131        // of the dominant Jacobian slab. Promote each to a single `Arc<[…]>` once
5132        // and hand both consumers a refcount bump (`O(1)`) — the backing
5133        // allocation is shared, halving the resident per-row Jacobian memory.
5134        // Reads are identical (`&arc[row]`, `.len()`), so the assembled system and
5135        // every matvec are bit-for-bit unchanged.
5136        let device_rows = if frames_engaged {
5137            None
5138        } else {
5139            let a_phi_shared: Arc<[Vec<(usize, f64)>]> =
5140                Arc::from(std::mem::take(&mut kron_a_phi).into_boxed_slice());
5141            let jac_shared: Arc<[Vec<f64>]> =
5142                Arc::from(std::mem::take(&mut kron_jac).into_boxed_slice());
5143            Some((a_phi_shared, jac_shared))
5144        };
5145        if !frames_engaged {
5146            let (a_phi_shared, jac_shared) = device_rows
5147                .clone()
5148                .expect("non-frames path always populates device_rows");
5149            let kron = Arc::new(SaeKroneckerRows::new(p, a_phi_shared, jac_shared));
5150            let kron_t = Arc::clone(&kron);
5151            let p_dim = p;
5152            sys.set_row_htbeta_operator(
5153                move |row_idx, x, out| {
5154                    // out = L_i · (J_β · x). Allocate a length-p scratch buffer
5155                    // for the intermediate decoded-output vector; both factors
5156                    // overwrite their output buffers (`apply_jbeta` zeroes
5157                    // before accumulating, `apply_l` writes per-row), so no
5158                    // pre-zeroing of `u_p`/`out` is needed.
5159                    let out_slice = out.as_slice_mut().expect("out is always standard-layout");
5160                    let mut u_p = vec![0.0_f64; p_dim];
5161                    if let Some(xs) = x.as_slice() {
5162                        kron.apply_jbeta(row_idx, xs, &mut u_p);
5163                    } else {
5164                        let x_vec: Vec<f64> = x.iter().copied().collect();
5165                        kron.apply_jbeta(row_idx, &x_vec, &mut u_p);
5166                    }
5167                    kron.apply_l(row_idx, &u_p, out_slice);
5168                },
5169                move |row_idx, v, out| {
5170                    // out += J_βᵀ · (Lᵀ · v). `apply_l_t` accumulates into a
5171                    // zero-initialised length-p buffer to produce the p-vector
5172                    // `Lᵀ v`; `scatter_jbeta_t` then adds φ_i[s] · u_p[j] into
5173                    // the length-K β accumulator at each active `(s, j)`.
5174                    let out_slice = out.as_slice_mut().expect("out is always standard-layout");
5175                    let mut u_p = vec![0.0_f64; p_dim];
5176                    if let Some(vs) = v.as_slice() {
5177                        kron_t.apply_l_t(row_idx, vs, &mut u_p);
5178                    } else {
5179                        let v_vec: Vec<f64> = v.iter().copied().collect();
5180                        kron_t.apply_l_t(row_idx, &v_vec, &mut u_p);
5181                    }
5182                    kron_t.scatter_jbeta_t(row_idx, &u_p, out_slice);
5183                },
5184            );
5185        }
5186        let mut beta_penalty_assembly = SaeBetaPenaltyAssembly::default();
5187        let factored_row_projection = if frames_engaged && analytic_penalties.is_some() {
5188            Some(&frame_projection)
5189        } else {
5190            None
5191        };
5192        if let Some(registry) = analytic_penalties {
5193            // Upfront validation: refuse penalty kinds the SAE row layout
5194            // cannot host, and refuse mixed-d row-block configurations.
5195            // This makes the dispatch loop below total — no runtime
5196            // "unsupported penalty" fallthrough, no K-gating.
5197            self.validate_analytic_penalty_registry(registry)
5198                .map_err(|err| format!("SaeManifoldTerm::assemble_arrow_schur: {err}"))?;
5199            beta_penalty_assembly = self
5200                .add_sae_analytic_penalty_contributions(
5201                    &mut sys,
5202                    registry,
5203                    penalty_scale,
5204                    row_layout.as_ref(),
5205                    dense_beta_curvature,
5206                    factored_row_projection,
5207                )
5208                .map_err(|err| format!("SaeManifoldTerm::assemble_arrow_schur: {err}"))?;
5209        }
5210        // #1026 — decoder repulsion (collinearity-gated, registry-independent):
5211        // accumulate into the full-`B` β-tier here, BEFORE the frame transform,
5212        // so a framed system carries it identically to the analytic β penalties.
5213        // No-op unless two atoms are near-collinear (the frozen gate is `None`).
5214        if self.add_sae_decoder_repulsion(&mut sys, penalty_scale, dense_beta_curvature) {
5215            beta_penalty_assembly.record_curvature(dense_beta_curvature);
5216        }
5217        // #1026/#1522 — interior-point collapse-prevention barriers. The amplitude
5218        // barrier supplies the OUTWARD radial force at the zero-decoder collapse
5219        // point (the principal failure state the threshold repulsion skips), and
5220        // the separation barrier supplies the alignment-divergent separating
5221        // curvature on normalized shapes weighted by coactivation. Both accumulate
5222        // into the full-`B` β-tier here, BEFORE the frame transform, so a framed
5223        // system carries them identically to the analytic β penalties.
5224        // #1610 — on the dense path the barrier's Levenberg majorizer scatters
5225        // onto `sys.hbb`; on the matrix-free / framed production path `sys.hbb` is
5226        // unused, so the barrier hands back a per-atom scalar ridge which we fold
5227        // into `smooth_scaled_s` (the single source for the CPU composite penalty
5228        // op AND the device smooth blocks), restoring the collapse-prevention
5229        // curvature the operator was silently dropping there.
5230        let mut sep_atom_curv = vec![0.0_f64; self.atoms.len()];
5231        if self.add_sae_separation_barrier(
5232            &mut sys,
5233            penalty_scale,
5234            dense_beta_curvature,
5235            &mut sep_atom_curv,
5236        ) {
5237            if dense_beta_curvature {
5238                beta_penalty_assembly.record_curvature(true);
5239            } else {
5240                // Fold the per-atom majorizer `lev_k·I_{M_k}` into the smooth
5241                // penalty factor `λ S_k`. With `⊗ I_p` (full-`B`) or `⊗ I_{r_k}`
5242                // (factored, `U_kᵀU_k = I`) this is exactly the `lev_k·I` block
5243                // diagonal the dense path writes — and it now flows through the
5244                // structured penalty op and the device smooth blocks. No
5245                // `deferred_factored` mark: the curvature is in the smooth op, not
5246                // a deferred dense block, so the device path stays engaged.
5247                for atom_idx in 0..self.atoms.len() {
5248                    let c = sep_atom_curv[atom_idx];
5249                    if c > 0.0 {
5250                        let m = smooth_scaled_s[atom_idx].nrows();
5251                        for i in 0..m {
5252                            smooth_scaled_s[atom_idx][[i, i]] += c;
5253                        }
5254                        smooth_ops[atom_idx] = Arc::new(IdentityRightKroneckerPenaltyOp {
5255                            factor_a: smooth_scaled_s[atom_idx].clone(),
5256                            p,
5257                            global_offset: beta_offsets[atom_idx],
5258                            k: beta_dim,
5259                        });
5260                    }
5261                }
5262            }
5263        }
5264        if frames_engaged {
5265            // ── #972 / #977 T1 — FACTORED β-tier transform ──────────────────
5266            //
5267            // The entire β-tier above was assembled in the full-`B` (p-wide)
5268            // layout: `sys.gb` is `g_B` (length `beta_dim`), `sys.hbb` carries
5269            // any analytic Beta-tier penalty, and `g_blocks` is the
5270            // FRAME-INDEPENDENT basis Gram. We now rebuild the β-tier in the
5271            // factored coordinate space `C` (width `factored_border_dim`), the
5272            // full-`B` system sandwiched by `Φ = blkdiag(I_{M_k} ⊗ U_k)`:
5273            //   * gradient   `g_C = Φᵀ g_B`              (per atom `(g_B U_k)`),
5274            //   * data H      `Φᵀ(G⊗I_p)Φ = G_{ij}⊗(U_iᵀU_j)`,
5275            //   * smooth      `λ S_k ⊗ I_{r_k}`          (since `U_kᵀU_k = I`),
5276            //   * analytic    `Φᵀ hbb Φ`                 (dense, only if written).
5277            // Un-framed atoms ride the `r_k = p, U_k = I_p` identity special case.
5278            let off_c = &frame_projection.border_offsets;
5279            let ranks = &frame_projection.ranks;
5280            let basis_sizes = &frame_projection.basis_sizes;
5281            let border_dim = frame_projection.border_dim();
5282            let gb_c = frame_projection.project_border_vec(sys.gb.view());
5283
5284            // Data β-Hessian: `G_{ij} ⊗ W_{ij}` with `W_{ij} = U_iᵀU_j`. The
5285            // basis Gram `g_blocks` is unchanged; only the output factor is the
5286            // per-pair frame overlap (`I_{r_k}` within a framed atom, `I_p` for
5287            // un-framed).
5288            let mut frame_blocks: Vec<FactoredFrameGBlock> = Vec::with_capacity(g_blocks.len());
5289            for ((atom_i, atom_j), data) in g_blocks.into_iter() {
5290                if data.iter().all(|&v| v == 0.0) {
5291                    continue;
5292                }
5293                // `W_{ij} = U_iᵀ U_j` from the precomputed per-atom frames.
5294                let w = self.frame_cross_factor(atom_i, atom_j);
5295                frame_blocks.push(FactoredFrameGBlock {
5296                    atom_i,
5297                    atom_j,
5298                    g: data,
5299                    w,
5300                });
5301            }
5302            // #1017/#1026 — snapshot the factored data-fit blocks for the
5303            // frames-engaged device PCG BEFORE `FactoredFrameKroneckerOp::new`
5304            // consumes them. Cheap clone (co-occurring blocks only).
5305            let device_frame_blocks = frame_blocks.clone();
5306            let data_op =
5307                FactoredFrameKroneckerOp::new(ranks.clone(), basis_sizes.clone(), frame_blocks)?;
5308
5309            // Smooth penalty in factored space: `λ S_k ⊗ I_{r_k}` at `off_C[k]`.
5310            let mut ops: Vec<Arc<dyn BetaPenaltyOp>> = Vec::with_capacity(self.atoms.len() + 2);
5311            for k in 0..self.atoms.len() {
5312                let r = ranks[k];
5313                ops.push(Arc::new(IdentityRightKroneckerPenaltyOp {
5314                    factor_a: smooth_scaled_s[k].clone(),
5315                    p: r,
5316                    global_offset: off_c[k],
5317                    k: border_dim,
5318                }));
5319            }
5320            ops.push(Arc::new(data_op));
5321            // Analytic Beta-tier penalty: project the dense full-`B` `hbb` block
5322            // `Φᵀ hbb Φ` into the factored space. Only present when a Beta-tier
5323            // penalty actually wrote `hbb` (else `hbb` is all-zero and the dense
5324            // `(border_dim)²` op is skipped entirely, exactly as full-`B`).
5325            if beta_penalty_assembly.dense_written {
5326                let hbb_c =
5327                    self.project_dense_penalty_to_factored(sys.hbb.view(), &frame_projection);
5328                ops.push(Arc::new(DensePenaltyOp(hbb_c)));
5329            } else if beta_penalty_assembly.deferred_factored {
5330                // Registry Beta-tier curvature deferred to factored-space probing.
5331                // The registry may be absent when `deferred_factored` was set ONLY
5332                // by the frozen-gate decoder repulsion (which is
5333                // registry-independent), so start from a zero factored block in
5334                // that case instead of unwrapping.
5335                let mut hbb_c = match analytic_penalties {
5336                    Some(registry) => self.build_factored_beta_penalty_curvature(
5337                        registry,
5338                        penalty_scale,
5339                        &frame_projection,
5340                    ),
5341                    None => Array2::<f64>::zeros((
5342                        frame_projection.border_dim(),
5343                        frame_projection.border_dim(),
5344                    )),
5345                };
5346                // #1610 — the frozen-gate decoder repulsion's PSD majorizer was
5347                // dropped on this matrix-free/framed path (only its gradient was
5348                // applied). Project it into the factored block via the same
5349                // `psd_majorizer_hvp` + frame-projection probe pattern the registry
5350                // DecoderIncoherence uses, so the collapse-prevention curvature
5351                // reaches the operator here too. No-op when no repulsion is active.
5352                self.add_factored_repulsion_curvature(
5353                    &mut hbb_c,
5354                    penalty_scale,
5355                    &frame_projection,
5356                );
5357                ops.push(Arc::new(DensePenaltyOp(hbb_c)));
5358            }
5359
5360            // Re-point the system's β-tier to the factored width. The t-tier
5361            // (per-row `htt`, `gt`) is frame-independent and untouched; row
5362            // cross-block slabs were allocated and assembled directly in
5363            // factored coordinates, so analytic row supplements and data-fit
5364            // cross terms already share shape `(q_i × factored_border_dim)`.
5365            sys.k = border_dim;
5366            sys.gb = gb_c;
5367            self.reclaim_border_hbb_workspace(&mut sys);
5368            // Factored per-atom block ranges for the block-Jacobi Schur
5369            // preconditioner: `[off_C[k] .. off_C[k] + M_k·r_k]`.
5370            let mut block_ranges: Vec<std::ops::Range<usize>> =
5371                Vec::with_capacity(self.atoms.len());
5372            for k in 0..self.atoms.len() {
5373                let start = off_c[k];
5374                block_ranges.push(start..start + basis_sizes[k] * ranks[k]);
5375            }
5376            sys.set_block_offsets(Arc::from(block_ranges.into_boxed_slice()));
5377            sys.set_penalty_op(Arc::new(CompositePenaltyOp { k: border_dim, ops }));
5378            // #1017/#1026 — install the frames-engaged device SAE PCG data. Skipped
5379            // (CPU fallback) when a dense analytic Beta-tier penalty fired (the
5380            // device kernel does not model that extra dense term). Builder:
5381            // `crate::frames::build_framed_device_sae_data`.
5382            let has_dense_beta_penalty =
5383                beta_penalty_assembly.dense_written || beta_penalty_assembly.deferred_factored;
5384            if !has_dense_beta_penalty {
5385                let device = crate::frames::build_framed_device_sae_data(
5386                    crate::frames::FramedDeviceArgs {
5387                        p,
5388                        border_dim,
5389                        border_offsets: off_c.as_slice(),
5390                        ranks: ranks.as_slice(),
5391                        basis_sizes: basis_sizes.as_slice(),
5392                        smooth_scaled_s: &smooth_scaled_s,
5393                        frame_blocks: device_frame_blocks,
5394                        rows: &sys.rows,
5395                    },
5396                );
5397                sys.set_device_sae_pcg_data(device);
5398            }
5399        } else {
5400            let (device_a_phi, device_local_jac) =
5401                device_rows.expect("full-beta SAE PCG rows are cloned before row operator install");
5402            // Wire per-atom β block ranges so the Jacobi preconditioner builds one
5403            // dense Schur sub-block per atom (block-Jacobi) instead of scalar-diagonal
5404            // inversion.  Each atom's decoder coefficients form a natural block:
5405            // `[beta_offsets[k] .. beta_offsets[k] + basis_size[k] * p_out]`.
5406            sys.set_block_offsets(self.beta_block_offsets());
5407            // Install the composite BetaPenaltyOp (#296): smoothness contributions
5408            // via per-atom KroneckerPenaltyOp (avoid dense K×K materialisation), the
5409            // data-fit Gauss-Newton β-Hessian as the structured `G ⊗ I_p`
5410            // SparseBlockKroneckerPenaltyOp (block-sparse over co-occurring
5411            // `(atom, atom')` pairs, block-diagonal across the `p` output channels,
5412            // identical per channel), plus — only when a Beta-tier analytic penalty
5413            // was written — the dense `sys.hbb` residual contribution. When no beta
5414            // penalty fired, `sys.hbb` is all-zero and the dense `(K·p)²` operator
5415            // is skipped entirely. The sparse data op tracks only the active-atom
5416            // couplings, so its storage and matvec cost scale with `k_active`, not
5417            // `K`, at `K = 100K`.
5418            // Convert the per-atom-pair coupling map into `SparseGBlock`s keyed
5419            // by μ-space offsets. Empty blocks (no co-occurrence) are simply
5420            // absent from the map.
5421            let g_sparse_blocks: Vec<SparseGBlock> = g_blocks
5422                .into_iter()
5423                .filter_map(|((atom_i, atom_j), data)| {
5424                    if data.iter().all(|&v| v == 0.0) {
5425                        None
5426                    } else {
5427                        Some(SparseGBlock {
5428                            row_off: mu_offsets[atom_i],
5429                            col_off: mu_offsets[atom_j],
5430                            data,
5431                        })
5432                    }
5433                })
5434                .collect();
5435            let device_smooth_blocks = smooth_scaled_s
5436                .iter()
5437                .enumerate()
5438                .map(|(atom_idx, factor_a)| {
5439                    // #1117 — rank deficiency is removed at the basis layer, so the
5440                    // device PCG smooth block is just `λ S_k ⊗ I_p` (full-rank
5441                    // design); no data-null deflation is folded in here.
5442                    DeviceSaeSmoothBlock {
5443                        global_offset: beta_offsets[atom_idx],
5444                        factor_a: factor_a.clone(),
5445                    }
5446                })
5447                .collect();
5448            sys.set_device_sae_pcg_data(DeviceSaePcgData {
5449                p,
5450                beta_dim,
5451                a_phi: device_a_phi,
5452                local_jac: device_local_jac,
5453                smooth_blocks: device_smooth_blocks,
5454                sparse_g_blocks: g_sparse_blocks.clone(),
5455                frame: None,
5456            });
5457            let mut ops: Vec<Arc<dyn BetaPenaltyOp>> = smooth_ops;
5458            ops.push(Arc::new(SparseBlockKroneckerPenaltyOp {
5459                p,
5460                dim_a: m_total,
5461                k: beta_dim,
5462                blocks: g_sparse_blocks,
5463            }));
5464            if beta_penalty_assembly.dense_written {
5465                ops.push(Arc::new(DensePenaltyOp(sys.hbb.clone())));
5466            }
5467            sys.set_penalty_op(Arc::new(CompositePenaltyOp { k: beta_dim, ops }));
5468            self.reclaim_border_hbb_workspace(&mut sys);
5469        }
5470        if let Some(deflation) = self.row_gauge_deflation_for_layout(row_layout.as_ref()) {
5471            sys.set_row_gauge_deflation(deflation);
5472        }
5473        // #1038 IBP cross-row Woodbury source. The exact IBP Hessian has the
5474        // per-column rank-one cross-row block `H_(i,k),(j,k) = w·s'_k·z'_ik·z'_jk`
5475        // (for ALL `i,j`, including the `i=j` self term) that couples DISTINCT
5476        // latent rows through the shared empirical mass `M_k = Σ_i z_ik`. The
5477        // assembled row-block-diagonal `htt` already carries the `i=j` self term
5478        // `w·s'_k·z'_ik²` — it is the first summand of `assignment_hdiag`'s
5479        // `hessian_diag` value `w·(score_derivative·z_jac² + score·c_ik)` written
5480        // at the logit diagonal above. So the consumer (`solver::arrow_schur`,
5481        // #1038 `IbpCrossRowSource`/`CrossRowWoodbury`) DOWNDATES exactly
5482        // `Σ_k d_k·z'_ik²` (`self_term_downdate`) to recover the NO-SELF base
5483        // `H₀'`, then re-adds the FULL rank-one `U D Uᵀ` via the determinant
5484        // lemma — so value, the evidence log-determinant, and the θ/ρ-adjoint all
5485        // differentiate the SAME `H_full = H₀' + U D Uᵀ`.
5486        //
5487        // The source is built from the SAME `ibp_assignment_third_channels`
5488        // operator the #1006 θ-adjoint consumes:
5489        //   * `d[k] = cross_row_d[k] = w·s'_k = w·score_derivative_k` (the column
5490        //     `D`-coefficient — NOT sign-definite, hence the consumer's
5491        //     indefinite-capacitance LU);
5492        //   * `entries[(i,k)] = (global_t_index, k, z'_ik)` with `z'_ik =
5493        //     z_jac[i·K + k]`. For the DENSE layout (`assignment_coord_dim() = K`,
5494        //     `last_row_layout = None`) atom `k`'s logit slot is local position `k`
5495        //     of row `i`'s block, so `global_t_index = sys.row_offsets[i] + k`. For
5496        //     the COMPACT layout (#1420) only the row's active atoms are
5497        //     coordinates and atom `k` lives at local position `pos` of
5498        //     `active_atoms[row]`, so `global_t_index = sys.row_offsets[i] + pos`.
5499        //     Both pin the `U`-column convention bit-for-bit to the consumer's
5500        //     `ibp_logit_sites`/`row_vars_for_cache_row` slot mapping.
5501        if let Some(channels) = ibp_assignment_third_channels(&self.assignment, rho)? {
5502            let mut entries: Vec<(usize, usize, f64)> = Vec::with_capacity(n * k_atoms);
5503            for row in 0..n {
5504                let start = row * k_atoms;
5505                let g_base = sys.row_offsets[row];
5506                match row_layout.as_ref() {
5507                    // #1420: compact layout — the local logit slot `pos` (not the
5508                    // global atom index `k`) is the t-coordinate. Atom `k`'s logit
5509                    // lives at local position `pos` of `active_atoms[row]`, so emit
5510                    // `(g_base + pos, atom, z_jac[row·K + atom])` for the active set
5511                    // only. Using `g_base + k` would attach atom `k`'s derivative to
5512                    // the wrong slot (and run out of range for compact rows),
5513                    // violating the `IbpCrossRowSource` contract.
5514                    Some(layout) => {
5515                        for (pos, &atom) in layout.active_atoms[row].iter().enumerate() {
5516                            let z_prime = channels.z_jac[start + atom];
5517                            entries.push((g_base + pos, atom, z_prime));
5518                        }
5519                    }
5520                    // Dense layout: atom `k`'s logit slot is local position `k`.
5521                    None => {
5522                        for k in 0..k_atoms {
5523                            let z_prime = channels.z_jac[start + k];
5524                            entries.push((g_base + k, k, z_prime));
5525                        }
5526                    }
5527                }
5528            }
5529            let source = IbpCrossRowSource {
5530                r: k_atoms,
5531                d: channels.cross_row_d.clone(),
5532                entries,
5533            };
5534            sys.set_ibp_cross_row_source(source);
5535        }
5536        // Store the active-set layout for `apply_newton_step`.
5537        self.last_row_layout = row_layout;
5538        // Record whether `delta_beta` from this system is a factored ΔC (needs a
5539        // frame lift) or a full-`B` ΔB. Read by `apply_newton_step_impl`.
5540        self.last_frames_active = frames_engaged;
5541        Ok(sys)
5542    }
5543
5544    /// Project a dense full-`B` Beta-tier penalty Hessian `hbb` (`beta_dim ×
5545    /// beta_dim`, the analytic `∂²P/∂B∂B` block) into the factored coordinate
5546    /// space `Φᵀ hbb Φ` (`border_dim × border_dim`) for the #972 / #977 T1
5547    /// frame transform. `Φ = blkdiag(I_{M_k} ⊗ U_k)` maps C-space → B-space, so
5548    /// the projected block contracts both index legs through the per-atom frames.
5549    ///
5550    /// The projection is done in two passes to stay `O(beta_dim · border_dim +
5551    /// border_dim²)` instead of forming the dense `Φ` explicitly: first
5552    /// `T = hbb · Φ` (right multiply, columns fold `U`), then `Φᵀ · T` (left
5553    /// multiply, rows fold `U`). Analytic Beta-tier penalties are rare and small,
5554    /// so this only fires when one is actually installed.
5555    pub(crate) fn project_dense_penalty_to_factored(
5556        &self,
5557        hbb: ArrayView2<'_, f64>,
5558        projection: &FrameProjection,
5559    ) -> Array2<f64> {
5560        projection.project_block(hbb)
5561    }
5562
5563    pub(crate) fn build_factored_beta_penalty_curvature(
5564        &self,
5565        registry: &AnalyticPenaltyRegistry,
5566        penalty_scale: f64,
5567        projection: &FrameProjection,
5568    ) -> Array2<f64> {
5569        let rho_global = Array1::<f64>::zeros(registry.total_rho_count());
5570        let layout = registry.rho_layout();
5571        let target_beta = self.flatten_beta();
5572        let mut hbb_c = Array2::<f64>::zeros((projection.border_dim(), projection.border_dim()));
5573        for (penalty, (rho_slice, tier, _name)) in registry.penalties.iter().zip(layout.iter()) {
5574            if matches!(penalty, AnalyticPenaltyKind::Ard(_)) {
5575                continue;
5576            }
5577            let rho_local = rho_global.slice(s![rho_slice.clone()]);
5578            match tier {
5579                PenaltyTier::Psi if matches!(penalty, AnalyticPenaltyKind::NuclearNorm(_)) => {
5580                    self.add_factored_beta_penalty_curvature_for_penalty(
5581                        &mut hbb_c,
5582                        penalty,
5583                        target_beta.view(),
5584                        rho_local,
5585                        penalty_scale,
5586                        projection,
5587                    );
5588                }
5589                PenaltyTier::Beta => {
5590                    self.add_factored_beta_penalty_curvature_for_penalty(
5591                        &mut hbb_c,
5592                        penalty,
5593                        target_beta.view(),
5594                        rho_local,
5595                        penalty_scale,
5596                        projection,
5597                    );
5598                }
5599                _ => {}
5600            }
5601        }
5602        hbb_c
5603    }
5604
5605    pub(crate) fn add_factored_beta_penalty_curvature_for_penalty(
5606        &self,
5607        hbb_c: &mut Array2<f64>,
5608        penalty: &AnalyticPenaltyKind,
5609        target_beta: ArrayView1<'_, f64>,
5610        rho_local: ArrayView1<'_, f64>,
5611        penalty_scale: f64,
5612        projection: &FrameProjection,
5613    ) {
5614        let p = self.output_dim();
5615        if let AnalyticPenaltyKind::DecoderIncoherence(base) = penalty {
5616            let Some(per_fit) = self.live_decoder_incoherence_penalty(base) else {
5617                return;
5618            };
5619            let beta_dim = self.beta_dim();
5620            let mut probe = Array1::<f64>::zeros(beta_dim);
5621            for k in 0..self.atoms.len() {
5622                for basis_col in 0..projection.basis_sizes[k] {
5623                    for frame_col in 0..projection.ranks[k] {
5624                        probe.fill(0.0);
5625                        projection.lift_axis_into(&mut probe, k, basis_col, frame_col);
5626                        let col = projection.border_offsets[k]
5627                            + basis_col * projection.ranks[k]
5628                            + frame_col;
5629                        let hv = per_fit.psd_majorizer_hvp(target_beta, rho_local, probe.view());
5630                        projection
5631                            .project_border_vec(hv.view())
5632                            .iter()
5633                            .enumerate()
5634                            .for_each(|(row, &v)| hbb_c[[row, col]] += penalty_scale * v);
5635                    }
5636                }
5637            }
5638            return;
5639        }
5640        if let AnalyticPenaltyKind::MechanismSparsity(base) = penalty {
5641            for (per_atom, start, end) in self.live_mechanism_sparsity_penalties(base) {
5642                let atom_idx = projection
5643                    .beta_offsets
5644                    .iter()
5645                    .position(|&offset| offset == start)
5646                    .expect("live mechanism-sparsity offset must match an SAE atom");
5647                let block_len = end - start;
5648                let mut local_penalty = per_atom.clone();
5649                local_penalty.target = PsiSlice {
5650                    range: 0..block_len,
5651                    latent_dim: Some(projection.basis_sizes[atom_idx]),
5652                };
5653                let block = target_beta.slice(s![start..end]);
5654                let mut probe = Array1::<f64>::zeros(block_len);
5655                for basis_col in 0..projection.basis_sizes[atom_idx] {
5656                    for frame_col in 0..projection.ranks[atom_idx] {
5657                        probe.fill(0.0);
5658                        projection.lift_local_axis_into(&mut probe, atom_idx, basis_col, frame_col);
5659                        let col = projection.border_offsets[atom_idx]
5660                            + basis_col * projection.ranks[atom_idx]
5661                            + frame_col;
5662                        let hv = local_penalty.psd_majorizer_hvp(block, rho_local, probe.view());
5663                        projection.project_local_atom_vec_into(
5664                            atom_idx,
5665                            hv.view(),
5666                            hbb_c.column_mut(col),
5667                            penalty_scale,
5668                        );
5669                    }
5670                }
5671            }
5672            return;
5673        }
5674        if let AnalyticPenaltyKind::NuclearNorm(base) = penalty {
5675            for (per_atom, start, end) in self.live_nuclear_norm_penalties(base) {
5676                let atom_idx = projection
5677                    .beta_offsets
5678                    .iter()
5679                    .position(|&offset| offset == start)
5680                    .expect("live nuclear-norm offset must match an SAE atom");
5681                let block = target_beta.slice(s![start..end]);
5682                let block_len = end - start;
5683                let mut probe = Array1::<f64>::zeros(block_len);
5684                for basis_col in 0..projection.basis_sizes[atom_idx] {
5685                    for frame_col in 0..projection.ranks[atom_idx] {
5686                        probe.fill(0.0);
5687                        projection.lift_local_axis_into(&mut probe, atom_idx, basis_col, frame_col);
5688                        let col = projection.border_offsets[atom_idx]
5689                            + basis_col * projection.ranks[atom_idx]
5690                            + frame_col;
5691                        let hv = per_atom.psd_majorizer_hvp(block, rho_local, probe.view());
5692                        projection.project_local_atom_vec_into(
5693                            atom_idx,
5694                            hv.view(),
5695                            hbb_c.column_mut(col),
5696                            penalty_scale,
5697                        );
5698                    }
5699                }
5700            }
5701            return;
5702        }
5703        let beta_dim = self.beta_dim();
5704        let mut probe = Array1::<f64>::zeros(beta_dim);
5705        for k in 0..self.atoms.len() {
5706            for basis_col in 0..projection.basis_sizes[k] {
5707                for frame_col in 0..projection.ranks[k] {
5708                    probe.fill(0.0);
5709                    projection.lift_axis_into(&mut probe, k, basis_col, frame_col);
5710                    let col =
5711                        projection.border_offsets[k] + basis_col * projection.ranks[k] + frame_col;
5712                    let hv = penalty.psd_majorizer_hvp(target_beta, rho_local, probe.view());
5713                    projection
5714                        .project_border_vec(hv.view())
5715                        .iter()
5716                        .enumerate()
5717                        .for_each(|(row, &v)| hbb_c[[row, col]] += penalty_scale * v);
5718                }
5719            }
5720        }
5721        assert_eq!(p, self.output_dim());
5722    }
5723
5724    /// #1610 — project the frozen-gate decoder-repulsion PSD majorizer into the
5725    /// factored β block `hbb_c`. Mirrors the `DecoderIncoherence` arm of
5726    /// [`Self::add_factored_beta_penalty_curvature_for_penalty`] but sources the
5727    /// penalty from [`Self::live_decoder_repulsion_penalty`] (registry-independent,
5728    /// collinearity-gated), so the repulsion curvature reaches the operator on the
5729    /// matrix-free/framed path where the dense `sys.hbb` write is unused. No-op
5730    /// when no repulsion is active.
5731    pub(crate) fn add_factored_repulsion_curvature(
5732        &self,
5733        hbb_c: &mut Array2<f64>,
5734        penalty_scale: f64,
5735        projection: &FrameProjection,
5736    ) {
5737        let Some(per_fit) = self.live_decoder_repulsion_penalty() else {
5738            return;
5739        };
5740        let beta_dim = self.beta_dim();
5741        let target_beta = self.flatten_beta();
5742        // The repulsion penalty is non-learnable; its strength is already folded
5743        // into the frozen gate (see `live_decoder_repulsion_penalty`), so the rho
5744        // slice is empty/inert.
5745        let rho_local = Array1::<f64>::zeros(0);
5746        let mut probe = Array1::<f64>::zeros(beta_dim);
5747        for k in 0..self.atoms.len() {
5748            for basis_col in 0..projection.basis_sizes[k] {
5749                for frame_col in 0..projection.ranks[k] {
5750                    probe.fill(0.0);
5751                    projection.lift_axis_into(&mut probe, k, basis_col, frame_col);
5752                    let col =
5753                        projection.border_offsets[k] + basis_col * projection.ranks[k] + frame_col;
5754                    let hv =
5755                        per_fit.psd_majorizer_hvp(target_beta.view(), rho_local.view(), probe.view());
5756                    projection
5757                        .project_border_vec(hv.view())
5758                        .iter()
5759                        .enumerate()
5760                        .for_each(|(row, &v)| hbb_c[[row, col]] += penalty_scale * v);
5761                }
5762            }
5763        }
5764    }
5765
5766    pub(crate) fn ext_coord_matrix(&self) -> Array2<f64> {
5767        let n = self.n_obs();
5768        let q = self.assignment.row_block_dim();
5769        let flat = self.assignment.flatten_ext_coords();
5770        let mut out = Array2::<f64>::zeros((n, q));
5771        for row in 0..n {
5772            for col in 0..q {
5773                out[[row, col]] = flat[row * q + col];
5774            }
5775        }
5776        out
5777    }
5778
5779    pub(crate) fn ext_coord_manifold(&self) -> LatentManifold {
5780        let mut parts = Vec::with_capacity(self.assignment.row_block_dim());
5781        for _ in 0..self.assignment.assignment_coord_dim() {
5782            parts.push(LatentManifold::Euclidean);
5783        }
5784        let mut any_constrained = false;
5785        for coord in &self.assignment.coords {
5786            if coord.manifold().is_euclidean() {
5787                for _ in 0..coord.latent_dim() {
5788                    parts.push(LatentManifold::Euclidean);
5789                }
5790            } else {
5791                any_constrained = true;
5792                parts.push(coord.manifold().clone());
5793            }
5794        }
5795        if any_constrained {
5796            LatentManifold::Product(parts)
5797        } else {
5798            LatentManifold::Euclidean
5799        }
5800    }
5801
5802    pub(crate) fn apply_sae_riemannian_geometry(&self, sys: &mut ArrowSchurSystem) {
5803        let manifold = self.ext_coord_manifold();
5804        if manifold.is_euclidean() {
5805            return;
5806        }
5807        let ext = self.ext_coord_matrix();
5808        let latent =
5809            LatentCoordValues::from_matrix_with_manifold(ext.view(), LatentIdMode::None, manifold);
5810        sys.apply_riemannian_latent_geometry(&latent);
5811    }
5812
5813    /// Build the compact-layout ext-coord product manifold and point for one row.
5814    ///
5815    /// The dense `ext_coord_manifold()` is keyed to the full-`q` block ordering
5816    /// `[assignment parts (all Euclidean for IBP-MAP / JumpReLU), then per-atom
5817    /// coord blocks in atom order]`. A compact active-set row instead lays its
5818    /// `q_active` columns out as `[one Euclidean logit slot per active atom,
5819    /// then each active atom's coord block in `active` order]` (see
5820    /// [`SaeRowLayout::from_active_atoms`] / `coord_starts`). To reuse the exact
5821    /// per-row Riemannian projector on the compact block we rebuild a product
5822    /// manifold and the matching ext-coord point in that compact order: the
5823    /// `active.len()` logit slots are `Euclidean` (the assignment channel is
5824    /// always Euclidean for the modes that engage sparsity — `assignment_coord_dim
5825    /// == k_atoms`), and each active atom contributes its own coordinate
5826    /// manifold. On the shared active support this is byte-identical to slicing
5827    /// the dense full-`q` product manifold, so the compact projection matches the
5828    /// dense path exactly — it only drops the inactive atoms' (negligible-mass)
5829    /// coordinate blocks the compact layout already excludes from curvature.
5830    ///
5831    /// Returns `(manifold, t_compact)` where `t_compact` has length `q_active`.
5832    /// The logit-slot entries of `t_compact` are filled from the row logits (the
5833    /// Euclidean projector ignores the point, so any finite value is equivalent;
5834    /// using the true logits keeps the point well-defined and finite).
5835    pub(crate) fn compact_row_ext_manifold_and_point(
5836        &self,
5837        row: usize,
5838        layout: &SaeRowLayout,
5839    ) -> (LatentManifold, Array1<f64>) {
5840        let active = &layout.active_atoms[row];
5841        let q_active = layout.row_q_active(row);
5842        let mut parts: Vec<LatentManifold> = Vec::with_capacity(active.len() + active.len());
5843        let mut point = Array1::<f64>::zeros(q_active);
5844        // Logit slots: one Euclidean part per active atom, in `active` order.
5845        let logits_row = self.assignment.logits.row(row);
5846        for (j, &k) in active.iter().enumerate() {
5847            parts.push(LatentManifold::Euclidean);
5848            point[j] = logits_row[k];
5849        }
5850        // Coordinate blocks: each active atom's coordinate manifold + point, at
5851        // the compact coord start the layout assigned it.
5852        for (j, &k) in active.iter().enumerate() {
5853            let coord = &self.assignment.coords[k];
5854            let d = coord.latent_dim();
5855            let coord_start = layout.coord_starts[row][j];
5856            let manifold_k = coord.manifold();
5857            // A `d`-dim coordinate whose manifold is a product (e.g. a torus =
5858            // Circle×Circle) already carries its `d` parts; a scalar manifold is
5859            // one part. Either way the manifold's ambient width must equal `d`,
5860            // matching the `d` compact columns at `coord_start`.
5861            parts.push(manifold_k.clone());
5862            let coord_point = coord.row(row);
5863            for axis in 0..d {
5864                point[coord_start + axis] = coord_point[axis];
5865            }
5866        }
5867        (LatentManifold::Product(parts), point)
5868    }
5869
5870    /// Numerical rank of a symmetric matrix: the count of eigenvalues
5871    /// exceeding `tol · max_eig`, with `tol = 1e-9` (the conventional
5872    /// relative spectral cutoff used elsewhere in the codebase).
5873    ///
5874    /// Used to count the penalised dimension of each atom's `smooth_penalty`
5875    /// `S_k` so the REML criterion's `−½·p·rank(S)·log λ_smooth` Occam term
5876    /// uses the *effective* penalty rank rather than the ambient basis size
5877    /// (a thin-plate / B-spline penalty has a non-trivial null space).
5878    pub(crate) fn symmetric_rank(s: &Array2<f64>) -> Result<usize, String> {
5879        if s.nrows() != s.ncols() {
5880            return Err(format!(
5881                "SaeManifoldTerm::symmetric_rank: matrix must be square, got {}x{}",
5882                s.nrows(),
5883                s.ncols()
5884            ));
5885        }
5886        let m = s.ncols();
5887        if m == 0 {
5888            return Ok(0);
5889        }
5890        // Symmetrize defensively through the shared ndarray helper. The SAE
5891        // rank cutoff is intentionally local to the SAE evidence contract; only
5892        // the symmetric cleanup is shared with the other construction modules.
5893        let mut sym = s.clone();
5894        gam_linalg::matrix::symmetrize_in_place(&mut sym);
5895        let (evals, _evecs) = sym
5896            .eigh(Side::Lower)
5897            .map_err(|e| format!("SaeManifoldTerm::symmetric_rank: eigh failed: {e}"))?;
5898        let max_eig = evals.iter().fold(0.0_f64, |acc, &v| acc.max(v));
5899        if !(max_eig > 0.0) {
5900            return Ok(0);
5901        }
5902        let tol = SAE_MANIFOLD_SPECTRAL_RANK_CUTOFF * max_eig;
5903        Ok(evals.iter().filter(|&&v| v > tol).count())
5904    }
5905
5906    /// Penalised quasi-Laplace evidence score for the SAE term at a FIXED ρ.
5907    ///
5908    /// #1421: this is NOT a true normalized-prior REML/evidence objective. The
5909    /// assignment priors (softmax entropy, JumpReLU) have NO finite normalizer:
5910    /// for softmax the reference-logit chart sends `P(ℓ)→0` as a free logit →±∞
5911    /// so `∫ e^{−λP} dℓ = ∞`, and JumpReLU's bounded penalty `0<P<λ` keeps
5912    /// `e^{−λP}` bounded below over an unbounded domain, also divergent. There is
5913    /// therefore no ρ-independent assignment-prior normalizer that can be dropped
5914    /// as a constant. The smoothing-penalty `−½log|λS|_+` term IS a genuine
5915    /// (proper-Gaussian) REML normalizer and is kept exactly; the rest is a
5916    /// penalized quasi-Laplace score (Laplace curvature term `½log|H|` around the
5917    /// inner optimum), which the engine minimizes over ρ.
5918    ///
5919    /// Runs the inner `(t, β)` arrow-Schur Newton solve to convergence at the
5920    /// supplied ρ (with NO in-loop ARD update — ρ is owned by the engine),
5921    /// then forms the Laplace/REML cost
5922    ///
5923    /// ```text
5924    /// V(ρ) = ℓ_pen(t̂, β̂; ρ) + ½ log|H(t̂, β̂; ρ)|
5925    ///        − ½ · p · (Σ_k rank S_k) · log λ_smooth
5926    /// ```
5927    ///
5928    /// where `ℓ_pen = loss.total()` is the penalised objective at the inner
5929    /// optimum and `½ log|H|` is the Laplace normaliser. `H` is the joint
5930    /// `(t, β)` Hessian assembled by the arrow-Schur system; its `H_tt` block
5931    /// carries `α = exp(log_ard)` on its diagonal, so as α grows `½ log|H|`
5932    /// rises while the `−½·n·log α` already inside `loss.ard` falls — their
5933    /// balance IS the effective-dof term that the deleted `α = n/‖t‖²` rule
5934    /// dropped, which is why the criterion needs no clamp to stay finite on a
5935    /// collapsing axis.
5936    ///
5937    /// The final `−½·p·rank(S)·log λ_smooth` term is the smoothing-penalty
5938    /// normaliser `−½ log|λ S|_+` restricted to its ρ-dependent part: `S_k` is
5939    /// shared across all `p` decoder output channels (the `⊗ I_p` Kronecker
5940    /// structure), so `log|λ S|_+ = p·rank(S)·log λ + p·log|S|_+`, and the
5941    /// `½ p·log|S|_+` piece is ρ-independent. The ρ-independent additive
5942    /// constants that ARE dropped here (they shift `V` by a constant and do not
5943    /// affect the ρ-argmin) are the `2π` Laplace constant and the base
5944    /// `½ p·log|S|_+` penalty logdet. #1421: NO assignment-prior normalizer is
5945    /// dropped, because none exists (softmax/JumpReLU priors are improper — see
5946    /// the doc on this function): the quasi-Laplace score simply omits a
5947    /// normalizer that is not a finite constant.
5948    ///
5949    /// Returns `(V, loss)` so the engine can both rank ρ and surface the inner
5950    /// loss breakdown.
5951    pub fn reml_criterion(
5952        &mut self,
5953        target: ArrayView2<'_, f64>,
5954        rho: &SaeManifoldRho,
5955        registry: Option<&AnalyticPenaltyRegistry>,
5956        inner_max_iter: usize,
5957        learning_rate: f64,
5958        ridge_ext_coord: f64,
5959        ridge_beta: f64,
5960    ) -> Result<(f64, SaeManifoldLoss), String> {
5961        self.reml_criterion_with_refine_policy(
5962            target,
5963            rho,
5964            registry,
5965            inner_max_iter,
5966            learning_rate,
5967            ridge_ext_coord,
5968            ridge_beta,
5969            true,
5970        )
5971    }
5972
5973    pub(crate) fn reml_criterion_with_refine_policy(
5974        &mut self,
5975        target: ArrayView2<'_, f64>,
5976        rho: &SaeManifoldRho,
5977        registry: Option<&AnalyticPenaltyRegistry>,
5978        inner_max_iter: usize,
5979        learning_rate: f64,
5980        ridge_ext_coord: f64,
5981        ridge_beta: f64,
5982        refine_progress_extension: bool,
5983    ) -> Result<(f64, SaeManifoldLoss), String> {
5984        let plan = self.streaming_plan().admitted_or_error(
5985            self.n_obs(),
5986            self.output_dim(),
5987            self.k_atoms(),
5988        )?;
5989        if plan.streaming {
5990            // #1225: streaming and dense MUST optimize the SAME mathematical
5991            // objective — the full REML criterion `loss.total() + extra_penalty +
5992            // ½ log|H| − Occam`. The streaming branch previously returned only
5993            // `loss.total() + extra_penalty_energy`, dropping the Laplace
5994            // normalizer `½ log|H|` and the Occam term, so large shapes (exactly
5995            // where streaming is needed) were ranked by penalized loss rather than
5996            // REML — and dense vs streaming disagreed on the objective. Route
5997            // through the streaming exact-logdet path, which assembles the same
5998            // chunk-by-chunk-bit-identical `½ log|H|_stream` and the same
5999            // `−Occam`/extra-penalty terms as the dense `reml_criterion_with_cache`
6000            // (different memory strategy, same objective).
6001            self.reml_criterion_streaming_exact(
6002                target,
6003                rho,
6004                registry,
6005                inner_max_iter,
6006                learning_rate,
6007                ridge_ext_coord,
6008                ridge_beta,
6009            )
6010        } else {
6011            let (v, loss, _cache) = self.reml_criterion_with_cache_refine_policy(
6012                target,
6013                rho,
6014                registry,
6015                inner_max_iter,
6016                learning_rate,
6017                ridge_ext_coord,
6018                ridge_beta,
6019                refine_progress_extension,
6020            )?;
6021            Ok((v, loss))
6022        }
6023    }
6024
6025    /// As [`Self::reml_criterion`], but also returns the converged undamped
6026    /// `ArrowFactorCache` so callers (the EFS fixed-point step) can read the
6027    /// selected-inverse traces `(H⁻¹)_tt` / `(H⁻¹)_ββ` without re-factoring.
6028    /// The cache is the single shared O(K³) Direct factor; both the
6029    /// log-determinant criterion and the Fellner-Schall ρ-step consume it.
6030    pub fn reml_criterion_with_cache(
6031        &mut self,
6032        target: ArrayView2<'_, f64>,
6033        rho: &SaeManifoldRho,
6034        registry: Option<&AnalyticPenaltyRegistry>,
6035        inner_max_iter: usize,
6036        learning_rate: f64,
6037        ridge_ext_coord: f64,
6038        ridge_beta: f64,
6039    ) -> Result<(f64, SaeManifoldLoss, ArrowFactorCache), String> {
6040        self.reml_criterion_with_cache_refine_policy(
6041            target,
6042            rho,
6043            registry,
6044            inner_max_iter,
6045            learning_rate,
6046            ridge_ext_coord,
6047            ridge_beta,
6048            true,
6049        )
6050    }
6051
6052    pub(crate) fn reml_criterion_with_cache_refine_policy(
6053        &mut self,
6054        target: ArrayView2<'_, f64>,
6055        rho: &SaeManifoldRho,
6056        registry: Option<&AnalyticPenaltyRegistry>,
6057        inner_max_iter: usize,
6058        learning_rate: f64,
6059        ridge_ext_coord: f64,
6060        ridge_beta: f64,
6061        refine_progress_extension: bool,
6062    ) -> Result<(f64, SaeManifoldLoss, ArrowFactorCache), String> {
6063        let admission_plan = self.streaming_plan().admitted_or_error(
6064            self.n_obs(),
6065            self.output_dim(),
6066            self.k_atoms(),
6067        )?;
6068        if !admission_plan.direct_logdet_admitted() {
6069            return Err(format!(
6070                "SaeManifoldTerm::reml_criterion_with_cache: predicted working set {} bytes exceeds budget {} bytes for dense evidence cache; shape n={},p={},K={}; cost-only streaming route is required",
6071                admission_plan.estimated_direct_peak_bytes,
6072                admission_plan.in_core_budget_bytes,
6073                self.n_obs(),
6074                self.output_dim(),
6075                self.k_atoms()
6076            ));
6077        }
6078        // 1. Run the inner (t, β) Newton solve to convergence at FIXED ρ.
6079        //    `run_joint_fit_arrow_schur` no longer touches ρ.
6080        let mut rho_fixed = rho.clone();
6081        let mut loss = self.run_joint_fit_arrow_schur(
6082            target,
6083            &mut rho_fixed,
6084            registry,
6085            inner_max_iter,
6086            learning_rate,
6087            ridge_ext_coord,
6088            ridge_beta,
6089        )?;
6090
6091        // 2. Drive the inner (t, β) solve to the KKT/step-converged optimum and
6092        //    take one final UNDAMPED factor there to obtain the joint Hessian
6093        //    log-determinant. We force ridge = 0 and the dense `Direct` Schur
6094        //    mode so `arrow_log_det_from_cache` returns the exact
6095        //    `log|H| = Σ_i log|H_tt^(i)| + log|Schur_β|` (it rejects damped
6096        //    factors and InexactPCG caches, which have no dense Schur factor).
6097        //    This is the same evidence convention the main GAM REML path uses.
6098        //    The shared `converge_inner_for_undamped_logdet` driver guarantees
6099        //    the per-row `H_tt^(i)` blocks are PD at the converged optimum so
6100        //    the undamped (`ridge = 0`) factorization succeeds — the streaming
6101        //    log-det path reuses the identical driver so both rank the same
6102        //    converged Laplace optimum and stay bit-identical.
6103        let options = ArrowSolveOptions::direct().with_ill_conditioning_tolerated();
6104        let cache = self.converge_inner_for_undamped_logdet(
6105            target,
6106            rho,
6107            &mut rho_fixed,
6108            registry,
6109            inner_max_iter,
6110            learning_rate,
6111            ridge_ext_coord,
6112            ridge_beta,
6113            &mut loss,
6114            &options,
6115            refine_progress_extension,
6116        )?;
6117        self.record_evidence_gauge_deflation_count(cache.gauge_deflated_directions)?;
6118        loss.evidence_gauge_deflated_directions = cache.gauge_deflated_directions;
6119        let log_det = arrow_log_det_from_cache(&cache).ok_or_else(|| {
6120            // Distinguish a GENUINE infeasibility — a probed ρ where the joint
6121            // Hessian is not PD so the Laplace evidence log-det is undefined —
6122            // from a real factorization defect. The cross-row IBP Woodbury
6123            // capacitance `C = I_R + D·Uᵀ H₀'⁻¹ U` can have det ≤ 0 at a ρ the
6124            // outer optimizer line-searches into (the indefinite basin adjacent
6125            // to the PD region); there the log-det legitimately does not exist.
6126            // That refusal must be RECOVERABLE (the outer BFGS should get +∞ and
6127            // steer back into the PD region), exactly like the "non-PD per-row
6128            // H_tt block" refusal — not a fatal `RemlOptimizationFailed` that
6129            // aborts the whole fit. See `is_recoverable_value_probe_refusal`.
6130            // (The old message claimed "no dense Schur factor", which is false
6131            // here — the Schur factor is present; the Woodbury correction is the
6132            // non-finite term.)
6133            if cache.cross_row_woodbury.is_some()
6134                && !cache.cross_row_woodbury_log_det().is_finite()
6135            {
6136                "SaeManifoldTerm::reml_criterion: cross-row IBP joint Hessian is non-PD at \
6137                 this ρ; evidence Laplace log-det undefined (infeasible ρ probe)"
6138                    .to_string()
6139            } else {
6140                "SaeManifoldTerm::reml_criterion: arrow_log_det_from_cache returned None \
6141                 (undamped joint Hessian log-det unavailable for the Laplace normaliser)"
6142                    .to_string()
6143            }
6144        })?;
6145
6146        // 3. Smoothing-penalty Occam term `−½·Σ_k r_k·rank(S_k)·log λ_smooth`
6147        //    plus the profiled-frame evidence-dimension correction
6148        //    `+½·Σ_k r_k·(p−r_k)·log λ_smooth` (issue #972). On the full-`B` path
6149        //    (`r_k == p`, no frames) this is exactly the historical
6150        //    `½·p·(Σ rank S_k)·log λ_smooth`, so the small-model criterion is
6151        //    unchanged. The single seam is `reml_occam_term`, shared with the
6152        //    streaming path so both rank the identical Laplace dimension count.
6153        let occam = self.reml_occam_term(rho)?;
6154
6155        // Decoder-block analytic-penalty energy (#671/#672). The inner solve
6156        // descended this energy (it enters `gb`/`hbb`) but it had no native
6157        // `loss.*` representative, so the Laplace criterion `v` was scoring a
6158        // different objective than the one minimized. Add the converged
6159        // decoder-penalty value so the ρ-sweep ranks the same penalized
6160        // deviance. Excludes the Psi-tier ARD/assignment penalties already
6161        // accounted for in `loss.total()` (see
6162        // `analytic_decoder_penalty_value_total`).
6163        // Extra analytic-penalty energy (#671/#737). Decoder-block penalties and
6164        // coordinate-tier isometry enter the inner solve but have no `loss.*`
6165        // representative, so the Laplace criterion must add them explicitly to
6166        // rank the same penalized deviance the Newton solve descends.
6167        let extra_penalty_energy = match registry {
6168            Some(reg) => self
6169                .reml_extra_penalty_value_total(reg)
6170                .map_err(|err| format!("SaeManifoldTerm::reml_criterion: {err}"))?,
6171            None => 0.0,
6172        };
6173
6174        let v = loss.total() + extra_penalty_energy + 0.5 * log_det - occam;
6175        Ok((v, loss, cache))
6176    }
6177
6178    /// The #1037 quotient-dimension invariant: a Laplace normalizer `½log|H|` is
6179    /// only comparable across ρ at a COMMON quotient (gauge-deflation) dimension.
6180    /// The first observation pins the expected count; a later match is a no-op.
6181    ///
6182    /// A later observation that DIFFERS is, under the K>1 fit, a LEGITIMATE
6183    /// quotient-dimension event — an atom born, reseeded (the #976 collapse
6184    /// guards), or rank-reduced moves the number of gauge-flat rows. Because a
6185    /// deflated direction is lifted to unit stiffness and contributes the
6186    /// ρ-independent `log 1 = 0` to the evidence, re-anchoring the comparison to
6187    /// the new dimension is exactly evidence-preserving and keeps every future
6188    /// cross-ρ comparison consistent — the principled response, not an abort.
6189    ///
6190    /// The genuine pathology the guard still catches is a count that NEVER
6191    /// STABILIZES: re-anchors are bounded by the per-atom structural-event budget
6192    /// (`k·(reseed_budget+1)+1`), and a runaway quotient dimension past that
6193    /// bound refuses loudly. This supersedes the prior strict-constant guard and
6194    /// its ±1 flicker band (#1117) at root — the band was masking exactly the
6195    /// legitimate K>1 dimension changes this re-anchoring now handles.
6196    pub(crate) fn record_evidence_gauge_deflation_count(
6197        &mut self,
6198        count: usize,
6199    ) -> Result<(), String> {
6200        match self.expected_evidence_gauge_deflated_directions {
6201            Some(expected) if expected == count => Ok(()),
6202            Some(expected) => {
6203                // A change in the gauge-deflation count between two evidence
6204                // factorizations is a legitimate quotient-dimension event under
6205                // the K>1 fit: an atom can be born, reseeded (the #976 collapse
6206                // guards), or rank-reduced across the ρ-walk, and each such event
6207                // moves the number of gauge-flat rows. The #1037 invariant is
6208                // NOT "the count never changes" — it is "two Laplace normalizers
6209                // are only comparable at a COMMON quotient dimension". The
6210                // principled response to a legitimate change is therefore to
6211                // RE-ANCHOR the comparison to the new dimension (so every future
6212                // cross-ρ comparison within the optimization is consistent), not
6213                // to abort the fit. This is exactly evidence-preserving: each
6214                // gauge-deflated direction is lifted to unit stiffness and
6215                // contributes the ρ-independent `log 1 = 0` to `½log|H|`, so the
6216                // converged criterion value is identical whether a given row is
6217                // counted as deflated or not — only the BOOKKEEPING dimension
6218                // must agree across a comparison, and re-anchoring restores that.
6219                //
6220                // The genuine pathology the guard must still catch is a count
6221                // that NEVER STABILIZES — an OSCILLATING quotient dimension that
6222                // re-anchors without converging, signalling a truly ill-posed
6223                // evidence surface. But the deflation count is NOT a discrete
6224                // dictionary-level event count: it is the per-ROW-summed number of
6225                // near-null evidence directions across all N rows (#1217). On real
6226                // K≥2 activations it is an O(N) quantity that drifts SMOOTHLY and
6227                // monotonically as the conditioning improves over the ρ-walk
6228                // (e.g. 171→156→…→113 as smoothing increases) — a benign,
6229                // evidence-neutral change (each deflated direction contributes the
6230                // ρ-independent `log 1 = 0` to `½log|H|`, so re-anchoring never
6231                // moves the criterion value). Charging such a monotone drift
6232                // against a `k`-sized "structural event" budget was wrong: it
6233                // counts threshold crossings of a continuous per-row quantity, not
6234                // atom births/reseeds, so the budget tripped on a perfectly healthy
6235                // converging K=2 fit (#1217 regression from the #1189/#1190
6236                // basin-escape fixes, which shifted which rows sit near the
6237                // deflation floor).
6238                //
6239                // The principled discriminator is DIRECTION REVERSALS: a count
6240                // that drifts one way and settles is benign; a count that bounces
6241                // up and down without settling is the oscillating-quotient
6242                // pathology. We therefore charge the re-anchor budget ONLY on a
6243                // reversal of the change direction, and size the budget by the
6244                // number of distinct dictionary structural events (births/reseeds)
6245                // that can each legitimately flip the drift direction. A monotone
6246                // drift of any length re-anchors freely (it is consistently
6247                // re-anchored and evidence-neutral); a genuinely oscillating count
6248                // exhausts the reversal budget and refuses loudly.
6249                let delta_sign: i8 = if count > expected { 1 } else { -1 };
6250                let is_reversal = self.evidence_gauge_deflation_last_delta_sign != 0
6251                    && delta_sign != self.evidence_gauge_deflation_last_delta_sign;
6252                self.evidence_gauge_deflation_last_delta_sign = delta_sign;
6253                // A reversal alone is NOT the pathology — a BOUNDED flicker of a
6254                // few rows crossing the near-null deflation floor reverses
6255                // direction every step yet is the discretization jitter of a
6256                // continuous evidence spectrum, fully evidence-neutral (each
6257                // deflated direction contributes `log 1 = 0` either way). The
6258                // genuine "quotient dimension not stabilizing" pathology is a
6259                // WIDE-amplitude oscillation: a substantial FRACTION of the
6260                // dimension flipping back and forth. The count is an O(N) per-row
6261                // sum, so the discriminator must be the reversal AMPLITUDE
6262                // relative to the dimension level, not the bare reversal. Charge
6263                // the reversal budget only when a reversal's step exceeds a
6264                // relative jitter band; a converged-but-flickering fit (e.g.
6265                // 150<->147 on N=200, ~2% of the level) re-anchors freely while a
6266                // true runaway (e.g. 9<->2, ~80% of the level) still trips every
6267                // reversal and exhausts the budget. This was the second #795 root
6268                // cause: the single-planted-circle fit's per-row count flickers
6269                // 150<->147 near the deflation floor, so the bare-reversal guard
6270                // refused the simplest possible fit — with the isometry gauge ON
6271                // *or* OFF — long before the gauge magnitude mattered.
6272                let amplitude = expected.abs_diff(count);
6273                let level = expected.max(count);
6274                let jitter_band = (level / 4).max(2);
6275                if is_reversal && amplitude > jitter_band {
6276                    self.evidence_gauge_deflation_reanchors += 1;
6277                }
6278                let reversal_budget = self
6279                    .k_atoms()
6280                    .saturating_mul(
6281                        SAE_ATOM_COLLAPSE_RESEED_BUDGET
6282                            + SAE_DICTIONARY_COCOLLAPSE_RESEED_BUDGET
6283                            + 1,
6284                    )
6285                    .saturating_add(1);
6286                if self.evidence_gauge_deflation_reanchors > reversal_budget {
6287                    return Err(format!(
6288                        "SaeManifoldTerm::reml_criterion: row-gauge evidence deflation count \
6289                         oscillated (reversed direction {} times, last {expected}->{count}) within \
6290                         one optimization, exceeding the {reversal_budget}-reversal budget for {} \
6291                         atoms; the quotient dimension is not stabilizing, refusing to compare \
6292                         Laplace normalizers",
6293                        self.evidence_gauge_deflation_reanchors,
6294                        self.k_atoms()
6295                    ));
6296                }
6297                log::debug!(
6298                    "SaeManifoldTerm::reml_criterion: per-row evidence deflation count changed \
6299                     {expected}->{count} (a benign per-row conditioning drift across the ρ-walk; \
6300                     reversal {}/{reversal_budget}); re-anchoring the Laplace normalizer comparison \
6301                     to the new dimension",
6302                    self.evidence_gauge_deflation_reanchors
6303                );
6304                self.expected_evidence_gauge_deflated_directions = Some(count);
6305                Ok(())
6306            }
6307            None => {
6308                self.expected_evidence_gauge_deflated_directions = Some(count);
6309                Ok(())
6310            }
6311        }
6312    }
6313
6314    pub(crate) fn is_undamped_evidence_row_non_pd(err: &ArrowSchurError) -> bool {
6315        matches!(
6316            err,
6317            ArrowSchurError::PerRowFactorFailed { reason, .. }
6318                if reason.contains("H_tt is non-PD at base ridge")
6319                    && reason.contains("evidence mode preserves the genuine Cholesky")
6320        )
6321    }
6322
6323    /// Drive the inner `(t, β)` Newton solve to the KKT/step-converged optimum
6324    /// and return the final UNDAMPED (`ridge = 0`) joint-Hessian factor cache.
6325    ///
6326    /// The Laplace normaliser `½log|H|` is only the correct REML criterion at
6327    /// the inner optimum `(t̂, β̂)`, so the criterion must refine the inner state
6328    /// until either the KKT gradient or the undamped Newton step meets tolerance
6329    /// before factoring. Crucially, **at the converged optimum the per-row
6330    /// `H_tt^(i)` blocks are PD**, so the undamped (`ridge = 0`) factorization
6331    /// succeeds; an off-optimum iterate (e.g. the initial seed, or a state
6332    /// stopped after only `inner_max_iter` steps) can have an indefinite /
6333    /// rank-deficient per-row block (`p_out = 1` → rank-1 `JᵀJ`, softmax
6334    /// assignment-sparsity negative logit curvature) that surfaces
6335    /// `PerRowFactorFailed` from the undamped `factor_one_row`. Both the dense
6336    /// (`reml_criterion_with_cache`) and the streaming
6337    /// (`reml_criterion_streaming_exact`) evidence paths route through this same
6338    /// driver, so they converge to the identical inner state and their
6339    /// `ridge = 0` log-determinants stay bit-identical (#847).
6340    pub(crate) fn converge_inner_for_undamped_logdet(
6341        &mut self,
6342        target: ArrayView2<'_, f64>,
6343        rho: &SaeManifoldRho,
6344        rho_fixed: &mut SaeManifoldRho,
6345        registry: Option<&AnalyticPenaltyRegistry>,
6346        inner_max_iter: usize,
6347        learning_rate: f64,
6348        ridge_ext_coord: f64,
6349        ridge_beta: f64,
6350        loss: &mut SaeManifoldLoss,
6351        options: &ArrowSolveOptions,
6352        refine_progress_extension: bool,
6353    ) -> Result<ArrowFactorCache, String> {
6354        // `inner_max_iter == 0` is a genuine FREEZE of the inner `(t, β)` state
6355        // — a verbatim warm-start reuse, not a convergence request (gam#577/#579,
6356        // #850). The convergence/refinement loop below MUST NOT run even one
6357        // Newton step in that case (the old `inner_max_iter.max(1)` floor moved
6358        // β off the seed), so we factor exactly once at the frozen iterate and
6359        // return that undamped cache without invoking the stationarity gate.
6360        // The caller has already run `run_joint_fit_arrow_schur(..., 0, ...)`,
6361        // which under the `max_iter == 0` freeze (gam#577/#579, #850) runs ONLY
6362        // the β-neutral basis refresh and returns the loss without touching β —
6363        // it skips the rank-reduction, frame activation, re-seed guards, and the
6364        // #1026 decoder-LSQ polish that would otherwise refit β off the seed — so
6365        // `self` is at the warm-start β here.
6366        if inner_max_iter == 0 {
6367            let sys = self
6368                .assemble_arrow_schur(target, rho, registry)
6369                .map_err(|err| format!("SaeManifoldTerm::reml_criterion: {err}"))?;
6370            let factored = solve_arrow_newton_step_with_options(&sys, 0.0, 0.0, options)
6371                .map_err(|err| format!("SaeManifoldTerm::reml_criterion: {err}"))?;
6372            // The frozen-state Newton step (factored.0, factored.1) is discarded
6373            // — only the undamped factor cache (factored.2) is consumed for the
6374            // log-det / selected-inverse traces; β stays at the warm-start seed.
6375            return Ok(factored.2);
6376        }
6377        let mut total_inner_iter = inner_max_iter;
6378        let accepted_base_refine_iter = inner_max_iter.max(1).saturating_mul(16).max(64);
6379        let value_probe_base_refine_iter = inner_max_iter.max(1).saturating_mul(4).max(16);
6380        let base_refine_iter = if refine_progress_extension {
6381            accepted_base_refine_iter
6382        } else {
6383            value_probe_base_refine_iter
6384        };
6385        let progress_refine_iter = if refine_progress_extension {
6386            inner_max_iter.max(1).saturating_mul(64).max(256)
6387        } else {
6388            base_refine_iter
6389        };
6390        let mut previous_refine_grad_norm: Option<f64> = None;
6391        let mut saw_refine_progress = false;
6392        // #1051 — objective-stagnation convergence. On an ill-conditioned
6393        // penalised bilinear fit (the euclidean / Duchon decoder × latent
6394        // coordinate system on a trivial shape), the inner Newton crawls: each
6395        // refine round lowers the penalised objective by a shrinking amount while
6396        // the KKT gradient and the undamped step stay above their relative
6397        // tolerances (the near-singular Schur amplifies the step in the
6398        // weakly-identified decoder direction). The grad-OR-step gate then never
6399        // fires and the solve is rejected as "did not converge" — the 1e12
6400        // sentinel. A Newton/LM iterate whose objective has stopped decreasing
6401        // to within `√εmach` of its scale IS the numerical inner optimum; ranking
6402        // the Laplace criterion there is correct. We accept that fixed point
6403        // instead of grinding the budget.
6404        let entry_loss_total = loss.total();
6405        let mut previous_loss_total = entry_loss_total;
6406        let mut refine_rounds: usize = 0;
6407        // Consecutive stall rounds: counts how many successive refine rounds
6408        // ended in a stall AND a failed undamped factor.  Once this reaches
6409        // `SAE_MANIFOLD_INNER_OBJECTIVE_STALL_MIN_ROUNDS` the iterate is at
6410        // its numerical fixed point and cannot be improved further; returning
6411        // `Err` here is the same "did not converge" signal that
6412        // `is_recoverable_value_probe_refusal` already handles, so the outer
6413        // BFGS treats it as an INFINITY probe and tries a different ρ instead
6414        // of looping forever burning the extended progress budget.  Without
6415        // this counter the stagnation handler fell through when the undamped
6416        // factor failed and the loop kept extending via `saw_refine_progress`
6417        // from earlier rounds, accumulating minutes of wasted work (#1094).
6418        let mut consecutive_stall_factor_fail: usize = 0;
6419        loop {
6420            let sys = self
6421                .assemble_arrow_schur(target, rho, registry)
6422                .map_err(|err| format!("SaeManifoldTerm::reml_criterion: {err}"))?;
6423            // Evidence-only factorization: the Newton step (Δt, Δβ) is discarded
6424            // and only the factor cache is consumed — the exact undamped log-det
6425            // and the selected-inverse traces. As ρ sweeps to extremes (e.g. a
6426            // wide ARD-α sweep), H_tt is genuinely PD but can be ill-conditioned;
6427            // the standard Direct guard rejects that to protect Newton-step
6428            // accuracy, but the log-det is exact from diag(L) regardless of the
6429            // condition number and the traces only need the (PD) factor. So
6430            // tolerate the ill-conditioning rejection here (a genuine non-PD pivot
6431            // still errors). The cache stays undamped at ridge=0, so
6432            // `arrow_log_det_from_cache` remains exact.
6433            // The exact KKT stationarity residual is the joint gradient
6434            // ‖g‖ = √(Σ_i ‖g_t^(i)‖² + ‖g_β‖²), read straight off the assembled
6435            // system. Unlike the Newton step Δ = H⁻¹g, the gradient is
6436            // factorisation-independent: it is NOT amplified by an inverse, so a
6437            // genuinely stationary but ill-conditioned fit (tiny g, possibly large
6438            // Δ in a flat direction) is correctly recognised as converged. The
6439            // `with_ill_conditioning_tolerated` Direct factor below documents that
6440            // its Δ may be inaccurate in exactly those flat directions, so using Δ
6441            // alone as the convergence gate would falsely reject healthy fits.
6442            let grad_norm_sq: f64 = sys
6443                .rows
6444                .iter()
6445                .map(|row| row.gt.iter().map(|&v| v * v).sum::<f64>())
6446                .sum::<f64>()
6447                + sys.gb.iter().map(|&v| v * v).sum::<f64>();
6448            let grad_norm = grad_norm_sq.sqrt();
6449            // Quotient KKT-gradient (#1117): the raw joint gradient retains a
6450            // persistent small component in the chart-gauge orbit and the
6451            // rank-deficient decoder β-null even at a stationary fit, so the raw
6452            // grad gate never clears on a rank-deficient circle and the inner
6453            // refine loop crawls until the (large) progress budget dies — the
6454            // 2-min stall. Measure the gradient on the SAME identified quotient
6455            // the step gate already uses: a fit whose only remaining gradient
6456            // lives in those flat directions is stationary on the quotient, so
6457            // ranking the Laplace criterion there is correct. The dense per-row
6458            // g_t is laid into the `n·q` coordinate layout the gauge basis spans;
6459            // non-dense/heterogeneous systems fall back to the raw norm.
6460            let quotient_grad_norm = {
6461                let n = self.n_obs();
6462                let q = self.assignment.row_block_dim();
6463                let dense_len = n.saturating_mul(q);
6464                let mut grad_ext_coord = Array1::<f64>::zeros(dense_len);
6465                let mut dense_layout_ok = sys.rows.len() == n;
6466                if dense_layout_ok {
6467                    for (row_idx, row) in sys.rows.iter().enumerate() {
6468                        let base = sys.row_offsets[row_idx];
6469                        let di = sys.row_dims[row_idx];
6470                        if base + di > dense_len || row.gt.len() < di {
6471                            dense_layout_ok = false;
6472                            break;
6473                        }
6474                        for axis in 0..di {
6475                            grad_ext_coord[base + axis] = row.gt[axis];
6476                        }
6477                    }
6478                }
6479                if dense_layout_ok {
6480                    self.quotient_gradient_norm_sq(
6481                        grad_ext_coord.view(),
6482                        sys.gb.view(),
6483                        grad_norm_sq,
6484                        &rho_fixed.lambda_smooth_vec(),
6485                    )
6486                    .map(|v| v.sqrt())
6487                    .unwrap_or(grad_norm)
6488                } else {
6489                    grad_norm
6490                }
6491            };
6492            let iterate_scale = self.inner_iterate_scale();
6493            // Relative parameter-step tolerance for Δ (well-conditioned charts)
6494            // and a scaled KKT-gradient tolerance. Convergence is accepted on
6495            // EITHER a small KKT gradient OR a small undamped Newton step: SAE
6496            // manifold fits contain gauge-like coordinate/decoder directions (the
6497            // circle's rotation gauge, decoder column-space rotations) where the
6498            // shared-block Hessian is near-singular, so the undamped step can stay
6499            // large in that flat direction even at a genuine stationary point; the
6500            // gradient, which is not amplified by the inverse, recognises it. With
6501            // the isometry Gauss-Newton block now a coherent PSD pullback (no
6502            // indefinite Schur pivot), the inner solve reaches true stationarity,
6503            // so the gradient tolerance is a standard relative KKT residual rather
6504            // than the 0.1.154-regression band-aid (3e-3) that masked the
6505            // non-convergence the indefinite curvature caused.
6506            let step_tolerance = SAE_MANIFOLD_INNER_STEP_REL_TOL * iterate_scale;
6507            let grad_tolerance = SAE_MANIFOLD_INNER_GRAD_REL_TOL * iterate_scale;
6508            if !grad_norm_sq.is_finite() {
6509                return Err(format!(
6510                    "SaeManifoldTerm::reml_criterion: undamped inner KKT residual is non-finite \
6511                     at the inner optimum (‖g‖²={grad_norm_sq}); the joint Hessian \
6512                     factorisation is degenerate at this ρ"
6513                ));
6514            }
6515            let (delta_t, delta_beta, cache): (Array1<f64>, Array1<f64>, ArrowFactorCache) =
6516                match solve_arrow_newton_step_with_options(&sys, 0.0, 0.0, options) {
6517                    Ok(factored) => factored,
6518                    Err(err) if Self::is_undamped_evidence_row_non_pd(&err) => {
6519                        if grad_norm <= grad_tolerance || quotient_grad_norm <= grad_tolerance {
6520                            // K>1: the softmax/IBP logit–coordinate Gauss-Newton
6521                            // cross-terms (H_zt = J_z^T J_t, assembled row-locally from
6522                            // the assignment JVP × basis JVP) can make a per-row H_tt
6523                            // indefinite at the TRUE KKT stationary point — when two
6524                            // atoms' decoders specialise in opposite directions the
6525                            // Schur complement of the logit block goes negative even
6526                            // though the priors and the full-joint GN term are PSD.
6527                            //
6528                            // The undamped evidence factor already conditions that
6529                            // block the PRINCIPLED way: `factor_spectral_deflated_
6530                            // evidence_row` discovers the negative/flat eigen-direction
6531                            // and stiffens it to UNIT curvature (eigenvalue → +1), so it
6532                            // contributes a ρ-INDEPENDENT log 1 = 0 to the evidence —
6533                            // the same quotient pseudo-determinant convention the gauge
6534                            // (#1037) and data-null (#1117) deflations use. Reaching
6535                            // THIS arm at stationarity therefore means even the spectral
6536                            // deflation declined (a non-finite block or a failed
6537                            // eigendecomposition): the state is genuinely broken, so we
6538                            // surface the hard refusal and let the outer BFGS treat this
6539                            // ρ as an INFINITY probe (`is_recoverable_value_probe_
6540                            // refusal`). We must NOT ridge-damp here: a `+ridge·I`
6541                            // fallback injects a ρ-dependent ½·log|I + ridge·H_tt⁻¹|
6542                            // bias into the VALUE that the analytic ρ-gradient (built
6543                            // for the undamped Laplace log-det) never sees, desyncing
6544                            // the outer line-search — the multi-atom non-convergence
6545                            // this fix (#1117) removes.
6546                            return Err(format!(
6547                                "SaeManifoldTerm::reml_criterion: stationary undamped \
6548                                 evidence factorization has a non-PD per-row H_tt block \
6549                                 that spectral unit-stiffness deflation could not \
6550                                 condition (‖g‖={grad_norm:.6e}, tol {grad_tolerance:.6e}); \
6551                                 {err}"
6552                            ));
6553                        }
6554                        let refine_limit = Self::refine_iteration_limit(
6555                            total_inner_iter,
6556                            base_refine_iter,
6557                            progress_refine_iter,
6558                            previous_refine_grad_norm,
6559                            grad_norm,
6560                            saw_refine_progress,
6561                        );
6562                        if total_inner_iter >= refine_limit {
6563                            // #1117/#1118 — pre-stationarity genuinely-indefinite
6564                            // non-gauge H_tt under K>1 IBP/softmax row-sharing. The
6565                            // logit × coordinate Gauss-Newton cross term H_zt = J_zᵀJ_t
6566                            // can drive a shared row's H_tt Schur complement NEGATIVE off
6567                            // the gauge orbit; the LM-escalated refinement above cannot
6568                            // always cross the indefinite basin into the PD region within
6569                            // the descent-extended budget.
6570                            //
6571                            // The undamped (ridge=0) evidence factor already conditions
6572                            // that block the PRINCIPLED way: `factor_spectral_deflated_
6573                            // evidence_row` discovers the negative/flat eigen-direction
6574                            // and stiffens it to UNIT curvature (eigenvalue → +1), a
6575                            // ρ-INDEPENDENT log 1 = 0 evidence contribution — so the
6576                            // `Ok(factored)` arm above accepts the indefinite block and
6577                            // returns a finite, monotone-comparable value to the outer
6578                            // BFGS WITHOUT a ρ-dependent bias. Reaching THIS arm means
6579                            // even that spectral deflation declined (a non-finite block
6580                            // or a failed eigendecomposition): the iterate is genuinely
6581                            // broken, so we surface the hard refusal and let the outer
6582                            // BFGS treat this ρ as an INFINITY probe.
6583                            //
6584                            // We must NOT ridge-damp here: a `+ridge·I` evidence
6585                            // fallback injects a ρ-dependent ½·log|I + ridge·H_tt⁻¹|
6586                            // bias into the VALUE that the analytic ρ-gradient (built
6587                            // for the undamped Laplace log-det) never sees, desyncing
6588                            // the outer line-search — the multi-atom non-convergence this
6589                            // fix removes. K=1 (and any already-PD or spectral-deflatable
6590                            // K>1 row) never reaches this branch.
6591                            return Err(format!(
6592                                "SaeManifoldTerm::reml_criterion: undamped evidence \
6593                                 factorization hit a non-PD per-row H_tt block before KKT \
6594                                 stationarity (‖g‖={grad_norm:.6e}, tol {grad_tolerance:.6e}) \
6595                                 and the refinement budget was exhausted after \
6596                                 {total_inner_iter} inner iterations; {err}"
6597                            ));
6598                        }
6599                        let remaining = refine_limit - total_inner_iter;
6600                        let refine_iter = inner_max_iter.max(1).min(remaining);
6601                        saw_refine_progress |=
6602                            Self::refine_round_made_progress(previous_refine_grad_norm, grad_norm);
6603                        previous_refine_grad_norm = Some(grad_norm);
6604                        *loss = self.run_joint_fit_arrow_schur(
6605                            target,
6606                            rho_fixed,
6607                            registry,
6608                            refine_iter,
6609                            learning_rate,
6610                            ridge_ext_coord,
6611                            ridge_beta,
6612                        )?;
6613                        total_inner_iter += refine_iter;
6614                        continue;
6615                    }
6616                    Err(err) => {
6617                        return Err(format!("SaeManifoldTerm::reml_criterion: {err}"));
6618                    }
6619                };
6620            // The Laplace normaliser ½log|H| is only the correct REML criterion at
6621            // the inner optimum (t̂, β̂). Convergence is judged by EITHER a small
6622            // gradient (KKT stationarity) OR a small undamped Newton step; the
6623            // solve is only rejected as non-converged when BOTH are large, i.e.
6624            // the iterate is neither stationary nor about to move negligibly. That
6625            // disjunction is what keeps an ill-conditioned-but-stationary fit
6626            // (small g, large Δ) from being rejected while still refusing to rank
6627            // an off-optimum Laplace criterion that is genuinely mid-flight.
6628            let step_norm_sq: f64 = delta_t.iter().map(|&v| v * v).sum::<f64>()
6629                + delta_beta.iter().map(|&v| v * v).sum::<f64>();
6630            if !step_norm_sq.is_finite() {
6631                return Err(format!(
6632                    "SaeManifoldTerm::reml_criterion: undamped inner residual is non-finite at \
6633                     the inner optimum (‖Δ‖²={step_norm_sq}, ‖g‖²={grad_norm_sq}); the joint \
6634                     Hessian factorisation is degenerate at this ρ"
6635                ));
6636            }
6637            let step_norm = step_norm_sq.sqrt();
6638            let quotient_step_norm_sq = self.quotient_newton_step_norm_sq(
6639                delta_t.view(),
6640                delta_beta.view(),
6641                step_norm_sq,
6642                &rho_fixed.lambda_smooth_vec(),
6643            )?;
6644            let quotient_step_norm = quotient_step_norm_sq.sqrt();
6645            // Converge on ANY of: the raw KKT gradient (well-conditioned fit),
6646            // the QUOTIENT KKT gradient (#1117 — rank-deficient fit whose only
6647            // residual gradient is gauge/null flat-direction crawl), or the
6648            // quotient Newton step. The quotient-gradient disjunct is what lets
6649            // a rank-deficient K=1 circle terminate in budget instead of crawling
6650            // the weakly-identified valley until the refine budget dies.
6651            if grad_norm <= grad_tolerance
6652                || quotient_grad_norm <= grad_tolerance
6653                || quotient_step_norm <= step_tolerance
6654            {
6655                return Ok(cache);
6656            }
6657            let refine_limit = Self::refine_iteration_limit(
6658                total_inner_iter,
6659                base_refine_iter,
6660                progress_refine_iter,
6661                previous_refine_grad_norm,
6662                grad_norm,
6663                saw_refine_progress,
6664            );
6665            if total_inner_iter >= refine_limit {
6666                // Inner solve did not converge in reml_criterion; the returned
6667                // Err below carries the full non-convergence diagnostic
6668                // (gradient / quotient-step norms and tolerances) to the caller.
6669                return Err(format!(
6670                    "SaeManifoldTerm::reml_criterion: inner solve did not converge at fixed ρ; \
6671                     neither the KKT gradient ‖g‖={grad_norm:.6e} (tol {grad_tolerance:.6e}) nor \
6672                     the quotient Newton step ‖Π⊥gauge Δ‖={quotient_step_norm:.6e} \
6673                     (raw ‖Δ‖={step_norm:.6e}, tol {step_tolerance:.6e}) met \
6674                     tolerance after {total_inner_iter} inner iterations. Refusing to rank an \
6675                     off-optimum Laplace criterion."
6676                ));
6677            }
6678            let remaining = refine_limit - total_inner_iter;
6679            let refine_iter = inner_max_iter.max(1).min(remaining);
6680            saw_refine_progress |=
6681                Self::refine_round_made_progress(previous_refine_grad_norm, grad_norm);
6682            previous_refine_grad_norm = Some(grad_norm);
6683            *loss = self.run_joint_fit_arrow_schur(
6684                target,
6685                rho_fixed,
6686                registry,
6687                refine_iter,
6688                learning_rate,
6689                ridge_ext_coord,
6690                ridge_beta,
6691            )?;
6692            total_inner_iter += refine_iter;
6693            refine_rounds += 1;
6694            // #1051 — objective-stagnation fixed point. A whole refine round that
6695            // failed to lower the penalised objective by a meaningful FRACTION of
6696            // the total since-entry reduction means the Newton/LM iterate is at
6697            // its numerical optimum: the remaining KKT residual lives in the
6698            // weakly-identified decoder / gauge directions the near-singular Schur
6699            // cannot resolve. Ranking the Laplace criterion at this fixed point is
6700            // correct (the only further motion is cosmetic flat-valley crawl), so
6701            // accept the current cache instead of refining until the budget dies.
6702            // Requires a few completed refine rounds (so the fraction baseline is
6703            // meaningful) but is NOT gated behind the full refine budget — the
6704            // whole point is to terminate the crawl long before that.
6705            let new_loss_total = loss.total();
6706            // Two stagnation signals, both required: (1) the latest refine round
6707            // contributed a negligible FRACTION of the total objective reduction
6708            // achieved since entry — the fit has captured essentially all the
6709            // achievable improvement and is now crawling cosmetically along the
6710            // weakly-identified valley; (2) the absolute relative decrease is
6711            // itself tiny. The fraction test is scale- and rate-free (it fires
6712            // whether the crawl decays fast or slow), so it recognises the
6713            // over-smoothed / rank-deficient fixed point the bare relative floor
6714            // misses, while still never firing on a fit that is materially
6715            // improving round over round.
6716            let total_improvement = (entry_loss_total - new_loss_total).max(0.0);
6717            let round_improvement = (previous_loss_total - new_loss_total).max(0.0);
6718            let objective_scale = previous_loss_total.abs().max(new_loss_total.abs()) + 1.0;
6719            let relative_decrease = round_improvement / objective_scale;
6720            let captured_fraction = if total_improvement > 0.0 {
6721                round_improvement / total_improvement
6722            } else {
6723                0.0
6724            };
6725            let stalled = new_loss_total.is_finite()
6726                && relative_decrease.is_finite()
6727                && (relative_decrease < SAE_MANIFOLD_INNER_OBJECTIVE_STALL_REL_TOL
6728                    || captured_fraction < SAE_MANIFOLD_INNER_OBJECTIVE_STALL_FRACTION);
6729            previous_loss_total = new_loss_total;
6730            if stalled && refine_rounds >= SAE_MANIFOLD_INNER_OBJECTIVE_STALL_MIN_ROUNDS {
6731                let stationary_sys = self
6732                    .assemble_arrow_schur(target, rho_fixed, registry)
6733                    .map_err(|err| format!("SaeManifoldTerm::reml_criterion: {err}"))?;
6734                if let Ok((_dt, _db, stationary_cache)) =
6735                    solve_arrow_newton_step_with_options(&stationary_sys, 0.0, 0.0, options)
6736                {
6737                    return Ok(stationary_cache);
6738                }
6739                // Stagnated AND the undamped factor still fails: this is the
6740                // numerical fixed point of the inner solve under rank-deficient
6741                // or ill-conditioned geometry (e.g. multi-atom euclidean with
6742                // near-zero initial latent coords, #1094).  The iterate cannot
6743                // be improved further at this ρ.  Treat it as "inner solve did
6744                // not converge" — the same signal `is_recoverable_value_probe_refusal`
6745                // already handles, causing the outer BFGS to return INFINITY for
6746                // this ρ probe and try a different one.  Without this early
6747                // return the stagnation handler fell through and the loop kept
6748                // burning the extended `progress_refine_iter` budget indefinitely.
6749                consecutive_stall_factor_fail += 1;
6750                if consecutive_stall_factor_fail >= SAE_MANIFOLD_INNER_OBJECTIVE_STALL_MIN_ROUNDS {
6751                    return Err(format!(
6752                        "SaeManifoldTerm::reml_criterion: inner solve did not converge at fixed ρ; \
6753                         objective stalled for {consecutive_stall_factor_fail} consecutive refine \
6754                         rounds (‖g‖={grad_norm:.6e}, tol {grad_tolerance:.6e}) and the undamped \
6755                         evidence factorization failed at each stall point — the iterate is at the \
6756                         numerical fixed point under rank-deficient geometry (#{consecutive_stall_factor_fail} \
6757                         stall-factor-fail rounds; refusing to rank an off-optimum Laplace criterion)"
6758                    ));
6759                }
6760            } else {
6761                consecutive_stall_factor_fail = 0;
6762            }
6763        }
6764    }
6765
6766    pub(crate) fn refine_iteration_limit(
6767        total_inner_iter: usize,
6768        base_refine_iter: usize,
6769        progress_refine_iter: usize,
6770        previous_grad_norm: Option<f64>,
6771        grad_norm: f64,
6772        saw_refine_progress: bool,
6773    ) -> usize {
6774        // Flat affine-gauge valleys can keep crawling productively after the
6775        // historical base budget. Extend only when the measured KKT residual has
6776        // shown a real finite round-to-round drop; true stalls end at the base
6777        // work budget (#968/#1029). Value-order probes pass the base budget as
6778        // their progress budget, so this branch cannot make probes expensive.
6779        if total_inner_iter < base_refine_iter {
6780            return base_refine_iter;
6781        }
6782        let making_progress =
6783            saw_refine_progress || Self::refine_round_made_progress(previous_grad_norm, grad_norm);
6784        if making_progress && grad_norm.is_finite() {
6785            progress_refine_iter
6786        } else {
6787            base_refine_iter
6788        }
6789    }
6790
6791    pub(crate) fn refine_round_made_progress(
6792        previous_grad_norm: Option<f64>,
6793        grad_norm: f64,
6794    ) -> bool {
6795        previous_grad_norm
6796            .is_some_and(|prev| prev.is_finite() && grad_norm.is_finite() && grad_norm < prev)
6797    }
6798
6799    pub(crate) fn outer_gradient_arrow_solver<'a>(
6800        &'a self,
6801        cache: &'a ArrowFactorCache,
6802        penalized_gram_scale: &[f64],
6803    ) -> Result<DeflatedArrowSolver<'a>, OuterGradientError> {
6804        let Err(conditioning_err) = Self::outer_gradient_conditioning_error(cache) else {
6805            return Ok(DeflatedArrowSolver::plain(cache));
6806        };
6807        let Some(max_pivot) = arrow_factor_max_pivot(cache) else {
6808            return Err(conditioning_err);
6809        };
6810        if !(max_pivot.is_finite() && max_pivot > 0.0) {
6811            return Err(conditioning_err);
6812        }
6813
6814        // The conditioning gate has already flagged a near-singular joint Hessian
6815        // (`conditioning_err`). Below we attempt to attribute that flatness to the
6816        // closed-form gauge orbit (chart step gauges) plus the penalty-aware
6817        // decoder-null directions and deflate it. When NO such deflatable
6818        // direction can be recovered, the flat subspace is genuinely
6819        // non-identifiable -- a degenerate direction OUTSIDE the gauge orbit -- a
6820        // diagnosis distinct from the raw pivot-ratio conditioning trip. Both
6821        // classes are #1273 FD-eligible, but surfacing the gauge-degenerate case
6822        // as its own [`OuterGradientError::NonIdentifiable`] keeps the diagnostic
6823        // distinction the FD-eligibility contract is built around.
6824        let non_identifiable_err = OuterGradientError::NonIdentifiable {
6825            reason: format!(
6826                "near-singular joint Hessian with no deflatable gauge/decoder-null \
6827                 direction (max pivot {max_pivot:.3e})"
6828            ),
6829        };
6830
6831        let full_len = cache.delta_t_len() + cache.k;
6832        let mut raw_gauges = Vec::new();
6833        for gauge in self
6834            .dense_step_gauge_vectors()
6835            .map_err(OuterGradientError::internal)?
6836        {
6837            if gauge.len() != full_len {
6838                continue;
6839            }
6840            let norm_sq = gauge.iter().map(|v| v * v).sum::<f64>();
6841            if !(norm_sq.is_finite() && norm_sq > 1.0e-24) {
6842                continue;
6843            }
6844            raw_gauges.push(gauge);
6845        }
6846        // #1051/#1273: admit the penalty-aware decoder-β null directions as
6847        // additional deflation candidates. A rank-deficient decoder design
6848        // (e.g. a euclidean-1D line in a p=2 ambient: decoder column rank 1 of
6849        // 3) puts a genuine near-null direction of the joint Hessian in the β
6850        // block, OUTSIDE the closed-form chart gauge orbit. #1273: probing the
6851        // RAW unit-β basis `e_j` produced an INCOMPLETE candidate set — the
6852        // true flat direction is the penalised null of `G_k + λ_smooth·S_k`,
6853        // not an axis-aligned coordinate, so the outer gate rejected trial ρ
6854        // with a pivot ratio (5.3e-16 < 1e-12) that the inner gate (which
6855        // already uses `decoder_beta_null_directions(λ_smooth)`) accepts. Use
6856        // the SAME penalty-aware null directions here, evaluated at the smooth
6857        // scale the Schur factor used, so the outer and inner gates agree.
6858        // These full (n·q + beta_dim)-length vectors drop into the same
6859        // Gram-Schmidt + Rayleigh + Faddeev-Popov path below; the Rayleigh
6860        // floor still keeps only genuinely flat (sub-floor) directions, so a
6861        // well-conditioned decoder is unaffected.
6862        for dir in self
6863            .decoder_beta_null_directions(penalized_gram_scale)
6864            .map_err(OuterGradientError::internal)?
6865        {
6866            if dir.len() == full_len {
6867                raw_gauges.push(dir);
6868            }
6869        }
6870        // #1051/#1273: also admit the decoder COLUMN-SPAN null (an unrealised
6871        // ambient output channel of a rank-deficient decoder), which the
6872        // channel-free basis-null above structurally cannot represent. The
6873        // rank-1-decoder-line geometry (e.g. a 1-D euclidean line in p=2
6874        // ambient: decoder column rank 1 of 2) puts the joint Hessian's
6875        // sub-floor pivot entirely in one output channel; without this
6876        // candidate the outer gate had nothing to deflate it with and rejected
6877        // the trial ρ. The Rayleigh floor below still prunes any candidate that
6878        // is not genuinely flat against the cached Hessian.
6879        for dir in self
6880            .decoder_channel_null_directions()
6881            .map_err(OuterGradientError::internal)?
6882        {
6883            if dir.len() == full_len {
6884                raw_gauges.push(dir);
6885            }
6886        }
6887        if raw_gauges.is_empty() {
6888            return Err(non_identifiable_err);
6889        }
6890
6891        let mut gauge_span: Vec<Array1<f64>> = Vec::new();
6892        for mut gauge in raw_gauges {
6893            for basis in &gauge_span {
6894                let coeff = gauge.dot(basis);
6895                for i in 0..gauge.len() {
6896                    gauge[i] -= coeff * basis[i];
6897                }
6898            }
6899            let norm_sq = gauge.iter().map(|v| v * v).sum::<f64>();
6900            if !(norm_sq.is_finite() && norm_sq > 1.0e-24) {
6901                continue;
6902            }
6903            let inv_norm = norm_sq.sqrt().recip();
6904            for value in gauge.iter_mut() {
6905                *value *= inv_norm;
6906            }
6907            gauge_span.push(gauge);
6908        }
6909        if gauge_span.is_empty() {
6910            return Err(non_identifiable_err);
6911        }
6912
6913        let span_rank = gauge_span.len();
6914        let mut h_span = Array2::<f64>::zeros((span_rank, span_rank));
6915        for col in 0..span_rank {
6916            let h_gauge = match apply_cached_arrow_hessian(
6917                cache,
6918                gauge_span[col].slice(s![..cache.delta_t_len()]),
6919                gauge_span[col].slice(s![cache.delta_t_len()..]),
6920            ) {
6921                Ok(value) => value,
6922                // #1451: a shape/dimension mismatch or non-finite intermediate
6923                // from the Hessian apply is an internal-invariant defect and MUST
6924                // propagate; only a genuine numeric failure on a finite,
6925                // correctly-shaped input keeps the FD-eligible conditioning class.
6926                Err(err) => {
6927                    return Err(OuterGradientError::classify_arrow_solver_error(
6928                        &err,
6929                        conditioning_err.clone(),
6930                    ));
6931                }
6932            };
6933            let h_flat = flatten_arrow_parts(h_gauge.t.view(), h_gauge.beta.view());
6934            for row in 0..span_rank {
6935                h_span[[row, col]] = gauge_span[row].dot(&h_flat);
6936            }
6937        }
6938        for row in 0..span_rank {
6939            for col in 0..row {
6940                let sym = 0.5 * (h_span[[row, col]] + h_span[[col, row]]);
6941                h_span[[row, col]] = sym;
6942                h_span[[col, row]] = sym;
6943            }
6944        }
6945        // #1451: a non-finite entry in the projected gauge Hessian is an
6946        // internal-invariant defect (a NaN/Inf intermediate leaked into the
6947        // span), not a conditioning failure — it MUST propagate rather than be
6948        // masked behind an FD descent. Guard finiteness BEFORE the eigh so only a
6949        // genuine decomposition failure on a finite, correctly-shaped matrix keeps
6950        // the FD-eligible conditioning class.
6951        if !h_span.iter().all(|v| v.is_finite()) {
6952            return Err(OuterGradientError::internal(format!(
6953                "outer_gradient_arrow_solver: non-finite entry in projected gauge \
6954                 Hessian (h_span is {span_rank}x{span_rank})"
6955            )));
6956        }
6957        let (evals, evecs) = h_span
6958            .eigh(Side::Lower)
6959            .map_err(|_| conditioning_err.clone())?;
6960        let strict_gauge_floor = SAE_OUTER_GRADIENT_GAUGE_RAYLEIGH_FACTOR * max_pivot;
6961        let mut orthonormal: Vec<Array1<f64>> = Vec::new();
6962        for eig_idx in 0..evals.len() {
6963            let rayleigh = evals[eig_idx];
6964            if !(rayleigh.is_finite() && rayleigh <= strict_gauge_floor) {
6965                continue;
6966            }
6967            let mut direction = Array1::<f64>::zeros(full_len);
6968            for basis_idx in 0..span_rank {
6969                let coeff = evecs[[basis_idx, eig_idx]];
6970                for row in 0..full_len {
6971                    direction[row] += coeff * gauge_span[basis_idx][row];
6972                }
6973            }
6974            let norm_sq = direction.iter().map(|v| v * v).sum::<f64>();
6975            if !(norm_sq.is_finite() && norm_sq > 1.0e-24) {
6976                continue;
6977            }
6978            let inv_norm = norm_sq.sqrt().recip();
6979            for value in direction.iter_mut() {
6980                *value *= inv_norm;
6981            }
6982            orthonormal.push(direction);
6983        }
6984        if orthonormal.is_empty() {
6985            // #1273/#1440: the conditioning gate has ALREADY certified a
6986            // near-singular joint Hessian (`conditioning_err`), so a genuine flat
6987            // direction exists inside the assembled gauge/decoder-null span even
6988            // when no projected-Hessian eigenvector cleared the strict or the
6989            // `fallback_gauge_floor` Rayleigh band. Rather than declining
6990            // (which historically routed the outer step to a finite-difference
6991            // descent direction — the FD instrument #1440 removes), deflate the
6992            // SMALLEST-Rayleigh eigenvector of the projected gauge Hessian
6993            // UNCONDITIONALLY. That eigenvector is the least-curvature member of
6994            // the validated gauge span (a Faddeev-Popov gauge candidate), so the
6995            // Tikhonov stiffness `max_pivot` in `from_orthonormal_gauges` bounds
6996            // its contribution at the Hessian scale and the components orthogonal
6997            // to it are byte-for-byte the plain analytic inverse solve. This keeps
6998            // the descent direction fully ANALYTIC (a projected/damped gradient),
6999            // never a differenced value path.
7000            let mut best_idx = None;
7001            let mut best_rayleigh = f64::INFINITY;
7002            for eig_idx in 0..evals.len() {
7003                let rayleigh = evals[eig_idx];
7004                if rayleigh.is_finite() && rayleigh < best_rayleigh {
7005                    best_idx = Some(eig_idx);
7006                    best_rayleigh = rayleigh;
7007                }
7008            }
7009            if let Some(eig_idx) = best_idx {
7010                let mut direction = Array1::<f64>::zeros(full_len);
7011                for basis_idx in 0..span_rank {
7012                    let coeff = evecs[[basis_idx, eig_idx]];
7013                    for row in 0..full_len {
7014                        direction[row] += coeff * gauge_span[basis_idx][row];
7015                    }
7016                }
7017                let norm_sq = direction.iter().map(|v| v * v).sum::<f64>();
7018                if norm_sq.is_finite() && norm_sq > 1.0e-24 {
7019                    let inv_norm = norm_sq.sqrt().recip();
7020                    for value in direction.iter_mut() {
7021                        *value *= inv_norm;
7022                    }
7023                    orthonormal.push(direction);
7024                }
7025            }
7026        }
7027        if orthonormal.is_empty() {
7028            return Err(non_identifiable_err);
7029        }
7030
7031        // Quotient-geometry gauge fixing: add stiffness only along the closed-form
7032        // gauge orbit (Faddeev-Popov style). Components orthogonal to that orbit
7033        // are identical to the original inverse solve, while gauge components are
7034        // bounded at the Hessian scale `max_pivot`.
7035        // #1451: a shape/length mismatch or non-finite stiffness/intermediate in
7036        // the deflated-solver assembly is an internal-invariant defect and MUST
7037        // propagate; only a genuine near-singular gauge Woodbury/back-solve keeps
7038        // the FD-eligible conditioning class.
7039        DeflatedArrowSolver::from_orthonormal_gauges(cache, orthonormal, max_pivot)
7040            .map_err(|err| OuterGradientError::classify_arrow_solver_error(&err, conditioning_err))
7041    }
7042
7043    pub(crate) fn outer_gradient_conditioning_error(
7044        cache: &ArrowFactorCache,
7045    ) -> Result<(), OuterGradientError> {
7046        let pivot = arrow_factor_min_pivot(cache);
7047        let Some(min_pivot) = pivot.min_pivot else {
7048            return Err(OuterGradientError::IllConditioned {
7049                reason: "joint Hessian numerically singular (no cached Cholesky pivots)"
7050                    .to_string(),
7051            });
7052        };
7053        let Some(max_pivot) = arrow_factor_max_pivot(cache) else {
7054            return Err(OuterGradientError::IllConditioned {
7055                reason: "joint Hessian numerically singular (no cached Cholesky pivot scale)"
7056                    .to_string(),
7057            });
7058        };
7059        let ratio = min_pivot / max_pivot;
7060        if min_pivot.is_finite()
7061            && max_pivot.is_finite()
7062            && max_pivot > 0.0
7063            && ratio.is_finite()
7064            && ratio >= SAE_OUTER_GRADIENT_PIVOT_RATIO_FLOOR
7065        {
7066            return Ok(());
7067        }
7068        Err(OuterGradientError::IllConditioned {
7069            reason: format!(
7070                "joint Hessian numerically singular (min/max pivot ratio {ratio:.3e} < floor {floor:.3e}; min pivot {min_pivot:.3e}, max pivot {max_pivot:.3e})",
7071                floor = SAE_OUTER_GRADIENT_PIVOT_RATIO_FLOOR,
7072            ),
7073        })
7074    }
7075
7076    /// Smoothing-penalty Occam normalizer `−½ Σ_k r_k·rank(S_k)·log λ_smooth`
7077    /// PLUS the profiled-frame evidence-dimension term `½ Σ_k r_k·(p−r_k)·log
7078    /// λ_smooth` (issue #972).
7079    ///
7080    /// On the full-`B` path every atom's frame rank `r_k == p`, so the first
7081    /// piece reduces to the historical `½ p·(Σ rank S_k)·log λ_smooth` and the
7082    /// Grassmann term is zero — bit-for-bit unchanged. When a frame is active the
7083    /// decoder coordinates `C_k` carry the `⊗ I_{r_k}` Kronecker structure (the
7084    /// smoothing penalty `S_k` now acts on `r_k` channels, not `p`), so the
7085    /// penalty-logdet normalizer uses `r_k·rank(S_k)`; and the `r_k·(p−r_k)`
7086    /// frame degrees of freedom profiled OUT of the border are counted explicitly
7087    /// in the Laplace dimension accounting (evidence honesty) so the criterion
7088    /// cannot buy a free evidence boost by hiding decoder freedom in the frame.
7089    pub(crate) fn reml_occam_term(&self, rho: &SaeManifoldRho) -> Result<f64, String> {
7090        // #1556: λ_smooth is per-atom, so the Occam penalty normalizer and the
7091        // profiled-frame evidence-dimension term are both per-atom sums, each
7092        // atom `k` weighted by its own `log λ_smooth[k]`. With a uniform
7093        // (broadcast) vector this is bit-for-bit the historical global form.
7094        let mut acc = 0.0_f64;
7095        for (atom_idx, atom) in self.atoms.iter().enumerate() {
7096            let rank_s = Self::symmetric_rank(&atom.smooth_penalty)?;
7097            // Penalized decoder dimension: `r_k` coordinate channels carry the
7098            // `S_k` roughness penalty (full-`B` path ⇒ `r_k == p`).
7099            let penalized_channel_dim = atom.border_frame_rank() * rank_s;
7100            // Profiled Grassmann dimensions enter the Laplace evidence dimension
7101            // count with the OPPOSITE sign of the penalty Occam term (they are
7102            // free, unpenalized-by-`S` profiled directions), so `−occam` adds
7103            // `+½ r(p−r) log λ_k` to the criterion `V` — the honesty correction.
7104            let frame_dim = atom.frame_manifold_dimension();
7105            let log_lambda = rho.log_lambda_smooth[atom_idx];
7106            acc += 0.5 * ((penalized_channel_dim as f64) - (frame_dim as f64)) * log_lambda;
7107        }
7108        // `V = … − occam`, so the net occam SUBTRACTS the penalty normalizer and
7109        // ADDS the frame-dimension count after the caller's `− occam`.
7110        Ok(acc)
7111    }
7112
7113    /// Per-atom derivative `∂(occam)/∂log λ_smooth[k]` (#1556): atom `k`'s entry
7114    /// is `½·(r_k·rank(S_k) − frame_dim_k)`, matching the per-atom Occam term in
7115    /// [`Self::reml_occam_term`]. Returns one entry per atom in atom order.
7116    pub(crate) fn reml_occam_log_lambda_smooth_derivative(&self) -> Result<Vec<f64>, String> {
7117        let mut out = Vec::with_capacity(self.atoms.len());
7118        for atom in &self.atoms {
7119            let rank_s = Self::symmetric_rank(&atom.smooth_penalty)?;
7120            let penalized_channel_dim = atom.border_frame_rank() * rank_s;
7121            let frame_dim = atom.frame_manifold_dimension();
7122            out.push(0.5 * ((penalized_channel_dim as f64) - (frame_dim as f64)));
7123        }
7124        Ok(out)
7125    }
7126
7127    pub fn reml_criterion_streaming_exact(
7128        &mut self,
7129        target: ArrayView2<'_, f64>,
7130        rho: &SaeManifoldRho,
7131        registry: Option<&AnalyticPenaltyRegistry>,
7132        inner_max_iter: usize,
7133        learning_rate: f64,
7134        ridge_ext_coord: f64,
7135        ridge_beta: f64,
7136    ) -> Result<(f64, SaeManifoldLoss), String> {
7137        let mut rho_fixed = rho.clone();
7138        let mut loss = self.run_joint_fit_arrow_schur(
7139            target,
7140            &mut rho_fixed,
7141            registry,
7142            inner_max_iter,
7143            learning_rate,
7144            ridge_ext_coord,
7145            ridge_beta,
7146        )?;
7147        // Drive the inner (t, β) state to the SAME KKT/step-converged optimum the
7148        // dense `reml_criterion_with_cache` reaches before factoring. At that
7149        // optimum the per-row `H_tt^(i)` blocks are PD, so the undamped
7150        // (`ridge_t = 0`) streaming factorization in `streaming_exact_arrow_log_det`
7151        // succeeds — without this, a state stopped after only `inner_max_iter`
7152        // steps can leave a rank-deficient / indefinite row block (`p_out = 1` →
7153        // rank-1 `JᵀJ`, softmax negative-logit curvature) that surfaces
7154        // `PerRowFactorFailed` at base ridge 0. Sharing the driver also keeps the
7155        // streaming and dense log-determinants bit-identical (#847).
7156        let options = ArrowSolveOptions::direct().with_ill_conditioning_tolerated();
7157        // The dense factor cache from convergence is surplus here — the streaming
7158        // path recomputes the (bit-identical) log-determinant chunk-by-chunk in
7159        // `streaming_exact_arrow_log_det` to bound peak memory — so it is dropped.
7160        let converged_cache = self.converge_inner_for_undamped_logdet(
7161            target,
7162            rho,
7163            &mut rho_fixed,
7164            registry,
7165            inner_max_iter,
7166            learning_rate,
7167            ridge_ext_coord,
7168            ridge_beta,
7169            &mut loss,
7170            &options,
7171            true,
7172        )?;
7173        drop(converged_cache);
7174        let log_det = self.streaming_exact_arrow_log_det(target, rho, registry)?;
7175        let occam = self.reml_occam_term(rho)?;
7176        // Extra analytic-penalty energy (#671/#737), matching the full-batch
7177        // `reml_criterion_with_cache` path so streaming and dense criteria rank
7178        // the identical penalized objective.
7179        let extra_penalty_energy = match registry {
7180            Some(reg) => self
7181                .reml_extra_penalty_value_total(reg)
7182                .map_err(|err| format!("SaeManifoldTerm::reml_criterion_streaming_exact: {err}"))?,
7183            None => 0.0,
7184        };
7185        Ok((
7186            loss.total() + extra_penalty_energy + 0.5 * log_det - occam,
7187            loss,
7188        ))
7189    }
7190
7191    pub fn streaming_exact_arrow_log_det(
7192        &mut self,
7193        target: ArrayView2<'_, f64>,
7194        rho: &SaeManifoldRho,
7195        registry: Option<&AnalyticPenaltyRegistry>,
7196    ) -> Result<f64, String> {
7197        if target.dim() != (self.n_obs(), self.output_dim()) {
7198            return Err(format!(
7199                "SaeManifoldTerm::streaming_exact_arrow_log_det: target must be ({}, {}); got {:?}",
7200                self.n_obs(),
7201                self.output_dim(),
7202                target.dim()
7203            ));
7204        }
7205        let plan = self.streaming_plan().admitted_or_error(
7206            self.n_obs(),
7207            self.output_dim(),
7208            self.k_atoms(),
7209        )?;
7210        if plan.estimated_dense_schur_bytes > plan.in_core_budget_bytes {
7211            return Err(format!(
7212                "SaeManifoldTerm::streaming_exact_arrow_log_det: predicted dense reduced Schur {} bytes exceeds budget {} bytes; cost-only matrix-free route is required",
7213                plan.estimated_dense_schur_bytes, plan.in_core_budget_bytes
7214            ));
7215        }
7216        let n_total = self.n_obs();
7217        let chunk_size = plan.chunk_size.min(n_total.max(1));
7218        // #972 / #977 T1: the reduced β-Schur is over the FACTORED border when
7219        // frames are active (each chunk inherits the frames via
7220        // `materialize_chunk`, so every `chunk_schur` is `border_dim²`), matching
7221        // the dense path's factored log-det. Full-`B` ⇒ `border_dim == beta_dim`.
7222        let border_dim = if self.frames_active() {
7223            self.factored_border_dim()
7224        } else {
7225            self.beta_dim()
7226        };
7227        let mut schur_acc = Array2::<f64>::zeros((border_dim, border_dim));
7228        let mut log_det_tt = 0.0_f64;
7229        // #1038 cross-row IBP Woodbury accumulators. `M = Uᵀ H₀'⁻¹ U` is
7230        // chunk-additive in `M0 = Σ Uᵢᵀ Aᵢ⁻¹ Uᵢ` and `W = Σ Bᵢᵀ Aᵢ⁻¹ Uᵢ`
7231        // (`A = H₀'` block-diagonal, `U` row-supported), closed against the
7232        // GLOBAL reduced Schur `S = schur_acc` after the loop. `None` for every
7233        // non-IBP (softmax / JumpReLU) term, where the streaming log-det is
7234        // exactly the bare `log_det_tt + log_det_schur` as before.
7235        let mut wood_m0: Option<Array2<f64>> = None;
7236        let mut wood_w: Option<Array2<f64>> = None;
7237        let mut wood_d: Option<Array1<f64>> = None;
7238        let options = ArrowSolveOptions::direct().with_ill_conditioning_tolerated();
7239        let mut start = 0usize;
7240        while start < n_total {
7241            let end = (start + chunk_size).min(n_total);
7242            let penalty_scale = (end - start) as f64 / n_total as f64;
7243            let chunk_logits = self.assignment.logits.slice(s![start..end, ..]).to_owned();
7244            let chunk_coords: Vec<Array2<f64>> = self
7245                .assignment
7246                .coords
7247                .iter()
7248                .map(|coord| coord.as_matrix().slice(s![start..end, ..]).to_owned())
7249                .collect();
7250            let mut chunk = self.materialize_chunk(chunk_logits, chunk_coords)?;
7251            // #1117 — rank deficiency is removed at the basis layer at fit entry
7252            // (`reduce_atoms_to_data_supported_rank`), so each chunk inherits the
7253            // already-reduced full-rank atoms via `materialize_chunk`; there are
7254            // no global deflation projectors to propagate.
7255            // #991: chunk terms inherit the row's design honesty weight slice
7256            // (global mean-1 normalization preserved — NOT re-normalized per
7257            // chunk — so the per-chunk sums reconstruct the global weighted
7258            // objective exactly).
7259            if let Some(w) = self.row_loss_weights.as_deref() {
7260                chunk.row_loss_weights = Some(w[start..end].to_vec());
7261            }
7262            let z_chunk = target.slice(s![start..end, ..]);
7263            let sys = chunk
7264                .assemble_arrow_schur_scaled(z_chunk, rho, registry, penalty_scale)
7265                .map_err(|err| format!("SaeManifoldTerm::streaming_exact_arrow_log_det: {err}"))?;
7266            let mut streaming = StreamingArrowSchur::from_system(&sys, sys.rows.len().max(1));
7267            let (chunk_log_det_tt, chunk_schur, chunk_wood) = streaming
7268                .reduced_schur_log_det_tt_woodbury(0.0, 0.0, &options)
7269                .map_err(|err| format!("SaeManifoldTerm::streaming_exact_arrow_log_det: {err}"))?;
7270            log_det_tt += chunk_log_det_tt;
7271            for row in 0..border_dim {
7272                for col in 0..border_dim {
7273                    schur_acc[[row, col]] += chunk_schur[[row, col]];
7274                }
7275            }
7276            if chunk_wood.is_some() && chunk_size < n_total {
7277                // The cross-row IBP empirical mass `M_k = Σ_i z_ik` couples ALL
7278                // rows, so the per-row `H₀'` diagonal (`score_derivative_k(M_k)`)
7279                // and the column coefficient `d_k = w·s'_k(M_k)` are only exact
7280                // when every row is assembled together — a SINGLE chunk. Under a
7281                // genuine multi-chunk pass each chunk would see a partial mass and
7282                // the Woodbury (and the bare per-row log-det) would be inexact, so
7283                // refuse loudly and route to the dense resident path rather than
7284                // return a silently-wrong evidence. The streaming log-det only
7285                // runs when the dense reduced Schur fits budget, so the single-
7286                // chunk regime is the common case; this guards the rest.
7287                return Err(
7288                    "SaeManifoldTerm::streaming_exact_arrow_log_det: exact cross-row IBP \
7289                     Woodbury evidence requires a single-chunk pass (the empirical mass \
7290                     M_k = Σ_i z_ik couples all rows); this shape needs >1 chunk. Route \
7291                     IBP-active large-n fits through the dense resident \
7292                     ArrowFactorCache::arrow_log_det."
7293                        .to_string(),
7294                );
7295            }
7296            if let Some(cw) = chunk_wood {
7297                wood_m0 = Some(match wood_m0.take() {
7298                    Some(mut acc) => {
7299                        acc += &cw.m0;
7300                        acc
7301                    }
7302                    None => cw.m0,
7303                });
7304                wood_w = Some(match wood_w.take() {
7305                    Some(mut acc) => {
7306                        acc += &cw.w;
7307                        acc
7308                    }
7309                    None => cw.w,
7310                });
7311                // `D = diag(d_k)` is per-atom; identical across chunks for a
7312                // single-chunk evidence pass (the regime the streaming log-det
7313                // runs in — the dense reduced Schur must fit budget here), where
7314                // it equals the global mass-derived `cross_row_d`.
7315                wood_d = Some(cw.d);
7316            }
7317            start = end;
7318        }
7319        let log_det_schur = StreamingArrowSchur::reduced_schur_log_det(&schur_acc, &options)
7320            .map_err(|err| format!("SaeManifoldTerm::streaming_exact_arrow_log_det: {err}"))?;
7321        let mut total = log_det_tt + log_det_schur;
7322        // #1038/#1225: close the exact cross-row IBP Woodbury correction
7323        // `log det(I_R + D Uᵀ H₀'⁻¹ U)` so the streaming evidence equals the
7324        // dense `arrow_log_det_from_cache` (which adds the SAME term). Without
7325        // it the streaming criterion would silently drop the entire cross-row
7326        // coupling and disagree with the dense path by exactly `log|C|`.
7327        if let (Some(m0), Some(w), Some(d)) = (wood_m0, wood_w, wood_d) {
7328            let correction = streaming_cross_row_woodbury_log_det(&schur_acc, &m0, &w, &d)
7329                .map_err(|err| {
7330                    format!("SaeManifoldTerm::streaming_exact_arrow_log_det: {err}")
7331                })?
7332                .ok_or_else(|| {
7333                    "SaeManifoldTerm::reml_criterion: cross-row IBP joint Hessian is non-PD at \
7334                     this ρ; evidence Laplace log-det undefined (infeasible ρ probe)"
7335                        .to_string()
7336                })?;
7337            total += correction;
7338        }
7339        Ok(total)
7340    }
7341
7342    /// Per-atom, per-axis coordinate sum-of-squares `‖t_kj‖² = Σ_i t_{i,k,j}²`.
7343    ///
7344    /// This is the data-fit sufficient statistic for the ARD precision update
7345    /// (the numerator-side `‖t‖²` of the deleted `α = n/‖t‖²` rule). Returned
7346    /// per atom as an `Array1` of length `d_k`.
7347    ///
7348    /// On a *periodic* (Circle) axis the relevant statistic is the von-Mises
7349    /// energy-equivalent `Σ_i 2/α·V(t_i) = Σ_i (2/κ²)(1−cos κ t_i)` (independent
7350    /// of α), so that `½·α·sumsq == Σ_i V(t_i)` matches `ard_value`. This keeps
7351    /// the Mackay/Fellner–Schall fixed point `α ← n / (sumsq + tr H⁻¹)`
7352    /// consistent with the actual periodic prior energy rather than the
7353    /// origin-dependent raw `t²`.
7354    pub(crate) fn ard_coord_sumsq(&self) -> Vec<Array1<f64>> {
7355        let mut out = Vec::with_capacity(self.k_atoms());
7356        for coord in &self.assignment.coords {
7357            let d = coord.latent_dim();
7358            let periods = coord.effective_axis_periods();
7359            let mut sq = Array1::<f64>::zeros(d);
7360            for row in 0..coord.n_obs() {
7361                let t = coord.row(row);
7362                for axis in 0..d {
7363                    // `sq_equiv` is independent of `alpha`; pass 1.0.
7364                    sq[axis] += ArdAxisPrior::eval(1.0, t[axis], periods[axis]).sq_equiv;
7365                }
7366            }
7367            out.push(sq);
7368        }
7369        out
7370    }
7371
7372    /// Per-atom, per-axis posterior-variance trace `tr_kj(H⁻¹) =
7373    /// Σ_i [(H⁻¹)_tt]_{(i,k,j),(i,k,j)}` from the converged factor cache.
7374    ///
7375    /// `cache.latent_block_inverse_diagonal()` returns the diagonal of the
7376    /// latent block `(H⁻¹)_tt` in the cache's compact per-row `delta_t`
7377    /// layout (length `row_offsets[N]`); each per-row block is laid out as
7378    /// `[logit scalars…, then per-active-atom coord axes…]`. This routine
7379    /// sums those diagonal entries over the coord positions belonging to each
7380    /// `(atom k, axis j)` across all observation rows where atom `k` is active.
7381    ///
7382    /// `self.last_row_layout` must be the layout from the *same* assemble that
7383    /// produced `cache`:
7384    /// - `Some(layout)`: compact active-set mode (JumpReLU / large-K
7385    ///   softmax-IBP truncation). For row `i`, atom `k`'s position in the
7386    ///   active list gives its compact coord-block start `coord_starts[i][pos]`;
7387    ///   inactive atoms contribute 0 (the prior dominates there anyway).
7388    /// - `None`: dense full-support layout, uniform row dim
7389    ///   `q = assignment_dim + Σ d_k`; atom `k`'s coord block sits at the
7390    ///   fixed full-row offset `coord_offsets[k]` after the assignment chart.
7391    ///
7392    /// This `tr_kj(H⁻¹)` is exactly the posterior-variance term the deleted
7393    /// `α = n/‖t‖²` rule dropped; the corrected Mackay/Fellner-Schall fixed
7394    /// point is `α_new = n / (‖t_kj‖² + tr_kj(H⁻¹))`.
7395    pub(crate) fn ard_inverse_traces(
7396        &self,
7397        cache: &ArrowFactorCache,
7398    ) -> Result<Vec<Array1<f64>>, ArrowSchurError> {
7399        let inv_diag = cache.latent_block_inverse_diagonal()?;
7400        let n = self.n_obs();
7401        let coord_offsets = self.assignment.coord_offsets();
7402        let mut traces: Vec<Array1<f64>> = self
7403            .assignment
7404            .coords
7405            .iter()
7406            .map(|c| Array1::<f64>::zeros(c.latent_dim()))
7407            .collect();
7408        for row in 0..n {
7409            let row_base = cache.row_offsets[row];
7410            match self.last_row_layout {
7411                Some(ref layout) => {
7412                    let active = &layout.active_atoms[row];
7413                    let starts = &layout.coord_starts[row];
7414                    for (pos, &k) in active.iter().enumerate() {
7415                        let d = self.assignment.coords[k].latent_dim();
7416                        let block_start = starts[pos];
7417                        for axis in 0..d {
7418                            traces[k][axis] += inv_diag[row_base + block_start + axis];
7419                        }
7420                    }
7421                }
7422                None => {
7423                    for k in 0..self.k_atoms() {
7424                        let d = self.assignment.coords[k].latent_dim();
7425                        let block_start = coord_offsets[k];
7426                        for axis in 0..d {
7427                            traces[k][axis] += inv_diag[row_base + block_start + axis];
7428                        }
7429                    }
7430                }
7431            }
7432        }
7433        Ok(traces)
7434    }
7435
7436    pub(crate) fn ard_log_precision_explicit_derivatives(
7437        &self,
7438        rho: &SaeManifoldRho,
7439    ) -> Result<Vec<Array1<f64>>, String> {
7440        if rho.log_ard.len() != self.k_atoms() {
7441            return Err(format!(
7442                "ARD rho has {} atoms but term has {}",
7443                rho.log_ard.len(),
7444                self.k_atoms()
7445            ));
7446        }
7447        let n = self.n_obs() as f64;
7448        let mut out = Vec::with_capacity(self.k_atoms());
7449        for (atom_idx, coord) in self.assignment.coords.iter().enumerate() {
7450            let d = coord.latent_dim();
7451            let mut atom_out = Array1::<f64>::zeros(rho.log_ard[atom_idx].len());
7452            if rho.log_ard[atom_idx].is_empty() {
7453                out.push(atom_out);
7454                continue;
7455            }
7456            if rho.log_ard[atom_idx].len() != d {
7457                return Err(format!(
7458                    "ARD rho atom {atom_idx} has len {} but atom dim is {d}",
7459                    rho.log_ard[atom_idx].len()
7460                ));
7461            }
7462            let periods = coord.effective_axis_periods();
7463            for axis in 0..d {
7464                let alpha = SaeManifoldRho::stable_exp_strength(rho.log_ard[atom_idx][axis]);
7465                let period = periods[axis];
7466                let mut energy_deriv = 0.0_f64;
7467                for row in 0..coord.n_obs() {
7468                    let t = coord.row(row)[axis];
7469                    energy_deriv += ArdAxisPrior::eval(alpha, t, period).value;
7470                }
7471                let normalizer_deriv = match period {
7472                    None => -0.5 * n,
7473                    Some(p) => {
7474                        let kappa = std::f64::consts::TAU / p;
7475                        let eta = alpha / (kappa * kappa);
7476                        // d/d(log α) of `n[-η + log I0(η)]` = `n η (I1/I0 - 1)`.
7477                        // The ratio is computed without forming `e^{η}`, so it
7478                        // stays finite for large `η` instead of the `inf/inf =
7479                        // NaN` that `bessel_i1(η)/bessel_i0(η)` produces (#1113).
7480                        let ratio = bessel_i0_log_and_ratio(eta).1;
7481                        n * eta * (-1.0 + ratio)
7482                    }
7483                };
7484                atom_out[axis] = energy_deriv + normalizer_deriv;
7485            }
7486            out.push(atom_out);
7487        }
7488        Ok(out)
7489    }
7490
7491    pub(crate) fn ard_log_precision_hessian_trace(
7492        &self,
7493        rho: &SaeManifoldRho,
7494        cache: &ArrowFactorCache,
7495        solver: &DeflatedArrowSolver<'_>,
7496    ) -> Result<Vec<Array1<f64>>, ArrowSchurError> {
7497        // RAW selected-inverse diagonal: the per-axis diagonal contraction uses
7498        // the DEFLATED inverse; the full kept-subspace + rotation deflation
7499        // correction `tr(inv_vv·(D − DΦ[D]))` is subtracted per (row, axis)
7500        // afterwards via the Daleckii–Krein helper. Each ARD ρ-component
7501        // `(atom k, axis)` differentiates a SINGLE coordinate-slot diagonal entry,
7502        // so its `D` is the rank-one `hess·e_s e_sᵀ` at that local slot `s`.
7503        let inv_diag = solver
7504            .latent_inverse_diagonal()
7505            .map_err(|err| ArrowSchurError::SchurFactorFailed { reason: err })?;
7506        let n = self.n_obs();
7507        let total_t = cache.delta_t_len();
7508        let coord_offsets = self.assignment.coord_offsets();
7509        let ard_axis_periods: Vec<Vec<Option<f64>>> = self
7510            .assignment
7511            .coords
7512            .iter()
7513            .map(LatentCoordValues::effective_axis_periods)
7514            .collect();
7515        let mut traces: Vec<Array1<f64>> = self
7516            .assignment
7517            .coords
7518            .iter()
7519            .enumerate()
7520            .map(|(k, c)| {
7521                if rho.log_ard[k].is_empty() {
7522                    Array1::<f64>::zeros(0)
7523                } else {
7524                    Array1::<f64>::zeros(c.latent_dim())
7525                }
7526            })
7527            .collect();
7528        for row in 0..n {
7529            let row_base = cache.row_offsets[row];
7530            let q = cache.row_dims[row];
7531            let dirs = cache
7532                .deflated_row_directions
7533                .get(row)
7534                .map(Vec::as_slice)
7535                .unwrap_or(&[]);
7536            let spectrum = cache
7537                .deflation_row_spectra
7538                .get(row)
7539                .and_then(Option::as_ref);
7540            // Per-row selected-inverse t-block, built once (only when deflated).
7541            let inv_vv = if dirs.is_empty() {
7542                None
7543            } else {
7544                let mut m = Array2::<f64>::zeros((q, q));
7545                for col in 0..q {
7546                    let mut rhs_t = Array1::<f64>::zeros(total_t);
7547                    let rhs_beta = Array1::<f64>::zeros(cache.k);
7548                    rhs_t[row_base + col] = 1.0;
7549                    let solved = solver.solve(rhs_t.view(), rhs_beta.view()).map_err(|err| {
7550                        ArrowSchurError::SchurFactorFailed { reason: err }
7551                    })?;
7552                    for r in 0..q {
7553                        m[[r, col]] = solved.t[row_base + r];
7554                    }
7555                }
7556                Some(m)
7557            };
7558            // Correction for one local coordinate slot `s` with curvature `hess`.
7559            let slot_correction = |s: usize, hess: f64| -> f64 {
7560                let Some(iv) = inv_vv.as_ref() else {
7561                    return 0.0;
7562                };
7563                if s >= q || hess == 0.0 {
7564                    return 0.0;
7565                }
7566                let mut d = Array2::<f64>::zeros((q, q));
7567                d[[s, s]] = hess;
7568                Self::deflation_block_correction(iv, &d, dirs, spectrum)
7569            };
7570            match self.last_row_layout {
7571                Some(ref layout) => {
7572                    let active = &layout.active_atoms[row];
7573                    let starts = &layout.coord_starts[row];
7574                    for (pos, &k) in active.iter().enumerate() {
7575                        if rho.log_ard[k].is_empty() {
7576                            continue;
7577                        }
7578                        let coord = &self.assignment.coords[k];
7579                        let d = coord.latent_dim();
7580                        let block_start = starts[pos];
7581                        for axis in 0..d {
7582                            let alpha = SaeManifoldRho::stable_exp_strength(rho.log_ard[k][axis]);
7583                            let t = coord.row(row)[axis];
7584                            let prior = ArdAxisPrior::eval(alpha, t, ard_axis_periods[k][axis]);
7585                            let hess = prior.hess.max(0.0);
7586                            let s = block_start + axis;
7587                            traces[k][axis] += 0.5 * inv_diag[row_base + s] * hess;
7588                            traces[k][axis] -= 0.5 * slot_correction(s, hess);
7589                        }
7590                    }
7591                }
7592                None => {
7593                    for k in 0..self.k_atoms() {
7594                        if rho.log_ard[k].is_empty() {
7595                            continue;
7596                        }
7597                        let coord = &self.assignment.coords[k];
7598                        let d = coord.latent_dim();
7599                        let block_start = coord_offsets[k];
7600                        for axis in 0..d {
7601                            let alpha = SaeManifoldRho::stable_exp_strength(rho.log_ard[k][axis]);
7602                            let t = coord.row(row)[axis];
7603                            let prior = ArdAxisPrior::eval(alpha, t, ard_axis_periods[k][axis]);
7604                            let hess = prior.hess.max(0.0);
7605                            let s = block_start + axis;
7606                            traces[k][axis] += 0.5 * inv_diag[row_base + s] * hess;
7607                            traces[k][axis] -= 0.5 * slot_correction(s, hess);
7608                        }
7609                    }
7610                }
7611            }
7612        }
7613        Ok(traces)
7614    }
7615
7616    /// Per-atom decoder-smoothness penalty quadratic form (#1556): entry `k` is
7617    /// the λ-free `<B_k, ½(S_k+S_kᵀ)·B_k> = Σ_oc B_k[:,oc]ᵀ S_k B_k[:,oc]`, the
7618    /// per-atom denominator of atom `k`'s λ_smooth Fellner-Schall update. The sum
7619    /// over atoms is `βᵀ(⊕_k S_k ⊗ I_p)β`, the un-scaled total penalty energy.
7620    /// `S_k` is symmetrised defensively (as the assembler does); the per-atom
7621    /// `½(S+Sᵀ)·B_k` GEMMs ride the multi-GPU batched smoothness GEMM with an
7622    /// exact per-atom CPU fallback.
7623    pub(crate) fn decoder_smoothness_quadratic_form_per_atom(&self) -> Vec<f64> {
7624        let sb_inputs: Vec<(ArrayView2<'_, f64>, ArrayView2<'_, f64>)> = self
7625            .atoms
7626            .iter()
7627            .map(|atom| (atom.smooth_penalty.view(), atom.decoder_coefficients.view()))
7628            .collect();
7629        let sb_all = batched_smooth_sb(&sb_inputs, true);
7630        let mut per_atom = vec![0.0_f64; self.atoms.len()];
7631        for (atom_idx, (atom, sb)) in self.atoms.iter().zip(sb_all.iter()).enumerate() {
7632            per_atom[atom_idx] = (&atom.decoder_coefficients * sb).sum();
7633        }
7634        per_atom
7635    }
7636
7637    /// Per-atom effective penalized dof of the decoder smoothness penalty
7638    /// (#1556): entry `k` is `tr(S_β⁻¹ · M_k)` with `M_k = (λ_smooth[k]·S_k) ⊗ I`
7639    /// and `S_β⁻¹ = (H⁻¹)_ββ` the Schur-complement inverse, each atom scaled by
7640    /// its OWN `lambda_smooth[atom_idx]`. Built on
7641    /// [`ArrowFactorCache::schur_inverse_apply`]: column `(k,μ,oc)` of `M_k` is
7642    /// `λ_k·S_k[:,μ] ⊗ e_oc` (sparse), so we apply `S_β⁻¹` to that K-vector and
7643    /// read back `result[col]`. The total edf is the sum of the returned vector
7644    /// (a uniform/broadcast λ reproduces the historical global trace).
7645    pub(crate) fn decoder_smoothness_effective_dof_per_atom(
7646        &self,
7647        cache: &ArrowFactorCache,
7648        lambda_smooth: &[f64],
7649    ) -> Result<Vec<f64>, ArrowSchurError> {
7650        let p = self.output_dim();
7651        let frames_active = self.frames_active();
7652        let (offsets, out_dim): (Vec<usize>, Box<dyn Fn(usize) -> usize>) = if frames_active {
7653            let ranks: Vec<usize> = self.atoms.iter().map(|a| a.border_frame_rank()).collect();
7654            (
7655                self.factored_beta_offsets(),
7656                Box::new(move |k: usize| ranks[k]),
7657            )
7658        } else {
7659            (self.beta_offsets(), Box::new(move |_k: usize| p))
7660        };
7661        let k = cache.k;
7662        let mut per_atom = vec![0.0_f64; self.atoms.len()];
7663        let mut m_col = Array1::<f64>::zeros(k);
7664        for (atom_idx, atom) in self.atoms.iter().enumerate() {
7665            let s = &atom.smooth_penalty;
7666            let m = atom.basis_size();
7667            let off = offsets[atom_idx];
7668            let r = out_dim(atom_idx);
7669            let lambda = lambda_smooth[atom_idx];
7670            let mut trace = 0.0_f64;
7671            for mu in 0..m {
7672                for oc in 0..r {
7673                    let col = off + mu * r + oc;
7674                    m_col.fill(0.0);
7675                    for nu in 0..m {
7676                        let s_nu_mu = 0.5 * (s[[nu, mu]] + s[[mu, nu]]);
7677                        m_col[off + nu * r + oc] = lambda * s_nu_mu;
7678                    }
7679                    let z = cache.schur_inverse_apply(m_col.view())?;
7680                    trace += z[col];
7681                }
7682            }
7683            per_atom[atom_idx] = trace;
7684        }
7685        Ok(per_atom)
7686    }
7687
7688    /// Per-atom effective penalized dof via the deflated solver (#1556): entry
7689    /// `k` is `tr((H⁻¹)_ββ · M_k)` for `M_k = (λ_smooth[k]·S_k) ⊗ I`, each atom
7690    /// scaled by its OWN `lambda_smooth[atom_idx]`. The total is the sum.
7691    pub(crate) fn decoder_smoothness_effective_dof_with_solver_per_atom(
7692        &self,
7693        cache: &ArrowFactorCache,
7694        solver: &DeflatedArrowSolver<'_>,
7695        lambda_smooth: &[f64],
7696    ) -> Result<Vec<f64>, String> {
7697        let p = self.output_dim();
7698        // #972 / #977 T1: the cache's β block is the FACTORED border when frames
7699        // are active (`cache.k == factored_border_dim`), so the smoothness edf
7700        // trace `tr((H⁻¹)_ββ · M)` is taken over the same factored layout, with
7701        // `M = ⊕_k (λ_k S_k) ⊗ I_{r_k}` at the factored offsets (the `U_kᵀU_k = I`
7702        // collapse means the per-coordinate-channel penalty is `λ_k S_k`, exactly
7703        // as in the full-`B` `⊗ I_p` case but with `r_k` channels). On the
7704        // full-`B` path `frames_active` is false: `out_dim_k = p`, the offsets
7705        // are `beta_offsets`, and this is bit-for-bit the historical trace.
7706        let frames_active = self.frames_active();
7707        let (offsets, out_dim): (Vec<usize>, Box<dyn Fn(usize) -> usize>) = if frames_active {
7708            let ranks: Vec<usize> = self.atoms.iter().map(|a| a.border_frame_rank()).collect();
7709            (
7710                self.factored_beta_offsets(),
7711                Box::new(move |k: usize| ranks[k]),
7712            )
7713        } else {
7714            (self.beta_offsets(), Box::new(move |_k: usize| p))
7715        };
7716        let k = cache.k;
7717        let mut per_atom = vec![0.0_f64; self.atoms.len()];
7718        let mut m_col = Array1::<f64>::zeros(k);
7719        for (atom_idx, atom) in self.atoms.iter().enumerate() {
7720            let s = &atom.smooth_penalty;
7721            let m = atom.basis_size();
7722            let off = offsets[atom_idx];
7723            let r = out_dim(atom_idx);
7724            let lambda = lambda_smooth[atom_idx];
7725            let mut trace = 0.0_f64;
7726            for mu in 0..m {
7727                for oc in 0..r {
7728                    let col = off + mu * r + oc;
7729                    // M[:,col] = λ_k · S_k[:,mu] ⊗ e_oc (nonzero at off+ν·r+oc).
7730                    m_col.fill(0.0);
7731                    for nu in 0..m {
7732                        let s_nu_mu = 0.5 * (s[[nu, mu]] + s[[mu, nu]]);
7733                        m_col[off + nu * r + oc] = lambda * s_nu_mu;
7734                    }
7735                    let zero_t = Array1::<f64>::zeros(cache.delta_t_len());
7736                    let z = solver.solve(zero_t.view(), m_col.view())?.beta;
7737                    trace += z[col];
7738                }
7739            }
7740            per_atom[atom_idx] = trace;
7741        }
7742        Ok(per_atom)
7743    }
7744
7745    pub(crate) fn assignment_log_strength_hessian_trace(
7746        &self,
7747        rho: &SaeManifoldRho,
7748        cache: &ArrowFactorCache,
7749        solver: &DeflatedArrowSolver<'_>,
7750    ) -> Result<f64, String> {
7751        let k_atoms = self.k_atoms();
7752        // #1038 softmax: `H` carries the DENSE entropy block, and since the
7753        // entropy curvature scales linearly with `λ_sparse = exp(ρ)`,
7754        // `∂H/∂ρ = H_entropy` (the full dense per-row block, not just its
7755        // diagonal). The trace `½ tr(H⁻¹ ∂H/∂ρ)` must therefore contract the
7756        // dense `∂H/∂ρ` against the per-row selected-inverse BLOCK, mirroring the
7757        // dense `log|H|` and θ-adjoint — a diagonal-only contraction would
7758        // desync the ρ-gradient from the criterion. The assembled majorizer
7759        // `D = diag(Σ_j|H_kj|)` is itself DIAGONAL (#1419), so the contraction
7760        // reduces to `½ Σ_slot (H⁻¹)_{slot,slot}·D_atom`. On the dense `None`
7761        // layout the logit slot equals the atom position; on the compact
7762        // softmax top-`k` layout (#1408/#1409) the slots are the row's active
7763        // atoms — the SAME `D_atom` (full-`K` abs-row-sum) the assembly wrote.
7764        if let AssignmentMode::Softmax {
7765            temperature,
7766            sparsity,
7767        } = self.assignment.mode
7768        {
7769            if k_atoms <= 1 {
7770                return Ok(0.0);
7771            }
7772            let inv_tau = 1.0 / temperature;
7773            let scale = rho.lambda_sparse() * sparsity * inv_tau * inv_tau;
7774            let penalty = gam_terms::analytic_penalties::SoftmaxAssignmentSparsityPenalty::new(
7775                k_atoms,
7776                temperature,
7777            );
7778            // Softmax uses the reduced K−1 free-logit chart on the dense layout
7779            // (last reference logit fixed); the compact layout carries one slot
7780            // per active atom. The diagonal selected inverse gives each slot's
7781            // (H⁻¹)_{slot,slot}.
7782            let assignment_dim = self.assignment.assignment_coord_dim();
7783            // Kept-subspace inverse diagonal: the deflated inverse assigns
7784            // `1/λ̃ = 1` to each per-row UNIT-stiffness direction `vᵢ`, so a raw
7785            // diagonal `D` contraction would spuriously add `½ Σ_i vᵢᵀ D vᵢ` (a
7786            // ρ-independent direction must add 0). `latent_inverse_diagonal_kept`
7787            // removes that per-row deflated diagonal centrally.
7788            let inv_diag = solver
7789                .latent_inverse_diagonal_kept()
7790                .map_err(|err| format!("assignment_log_strength_hessian_trace: {err}"))?;
7791            let mut trace = 0.0_f64;
7792            for row in 0..self.n_obs() {
7793                let row_base = cache.row_offsets[row];
7794                // ∂(scale·D)/∂ρ = scale·D (linear in λ_sparse = eᵖ) — the SAME
7795                // operator the assembly and θ-adjoint differentiate.
7796                match self.last_row_layout {
7797                    Some(ref layout) => {
7798                        // #1410: the compact adjoint reads `D_kk` only for this
7799                        // row's `≤ top_k` active atoms, so compute those entries
7800                        // directly from the softmax row `a` via the active-only
7801                        // Gershgorin helper — no full-`K` `row_logits` copy and no
7802                        // full-`K` `d` vector. `a` itself is the irreducible `O(K)`
7803                        // softmax normalisation, computed once per row and shared
7804                        // across the row's active slots.
7805                        let a = crate::assignment::softmax_row(
7806                            self.assignment.logits.row(row),
7807                            temperature,
7808                        );
7809                        let a = a.as_slice().expect("softmax row must be contiguous");
7810                        let m = softmax_majorizer_log_mean(a);
7811                        for (pos, &atom) in layout.active_atoms[row].iter().enumerate() {
7812                            let d_atom =
7813                                active_softmax_gershgorin_majorizer_entry(a, atom, m, scale);
7814                            trace += inv_diag[row_base + pos] * d_atom;
7815                        }
7816                    }
7817                    None => {
7818                        // Dense layout genuinely contracts every free logit slot's
7819                        // `D_kk`, so the full-`K` `d` is intrinsic here; keep the
7820                        // single-source dense majorizer call.
7821                        let row_logits: Vec<f64> = (0..k_atoms)
7822                            .map(|k| self.assignment.logits[[row, k]])
7823                            .collect();
7824                        let d = penalty.psd_majorizer_abs_row_sums(&row_logits, scale);
7825                        let q = cache.row_dims[row];
7826                        let logit_dim = assignment_dim.min(q);
7827                        for atom in 0..logit_dim {
7828                            trace += inv_diag[row_base + atom] * d[atom];
7829                        }
7830                    }
7831                }
7832            }
7833            return Ok(0.5 * trace);
7834        }
7835        let hdiag = assignment_prior_log_strength_hdiag(&self.assignment, rho)?;
7836        if hdiag.is_empty() {
7837            return Ok(0.0);
7838        }
7839        // RAW selected-inverse diagonal: the per-row diagonal contraction uses the
7840        // DEFLATED inverse; the full kept-subspace + β-Schur/rotation deflation
7841        // correction `tr(inv_vv·(D − DΦ[D]))` is subtracted per row afterwards
7842        // (`deflation_block_correction`), exactly as the data trace does. The
7843        // cross-row off-diagonal pass below contracts only DISTINCT rows `i ≠ j`,
7844        // off any single-row `vᵢ`'s support, so it needs no deflation correction.
7845        let inv_diag = solver
7846            .latent_inverse_diagonal()
7847            .map_err(|err| format!("assignment_log_strength_hessian_trace: {err}"))?;
7848        let assignment_dim = self.assignment.assignment_coord_dim();
7849        let total_t = cache.delta_t_len();
7850        // #932 FRONT C: row-local Takahashi selected inverse on the plain arrow
7851        // for the per-row deflation correction below (the diagonal trace already
7852        // uses the cheap `latent_inverse_diagonal`); gauge / cross-row Woodbury
7853        // fall back to the per-row full-system `solve` loop.
7854        let fast_selected = solver.plain_selected_inverse_available();
7855        let selected_beta_inv = if fast_selected && cache.k > 0 {
7856            solver
7857                .beta_inv()
7858                .map_err(|err| format!("assignment_log_strength_hessian_trace: {err}"))?
7859        } else {
7860            Array2::<f64>::zeros((0, 0))
7861        };
7862        // #1416 cross-row IBP source: the per-row block that the deflation
7863        // factorizes is the NO-SELF base `H₀'` — the rank-one self curvature
7864        // `d_k·J_ik²` is DOWNDATED from each logit diagonal and re-applied through
7865        // the Woodbury carrier. The full-`H` diagonal contraction below still uses
7866        // the full `hdiag` (which carries that self term), but the per-row
7867        // DEFLATION correction must use `(∂H₀'/∂ρ)_tt`, i.e. `hdiag` MINUS the
7868        // downdated self term — otherwise the Daleckii–Krein correction
7869        // mis-attributes the (un-deflated) Woodbury self curvature's derivative to
7870        // the deflated subspace. For non-IBP modes there is no Woodbury source and
7871        // the self term is `0` (the deflated block IS the full block).
7872        // #1416 (compact-layout completion): the IBP cross-row Woodbury source is
7873        // installed for BOTH the dense and the compact (#1420 top-`k`) layouts (see
7874        // `set_ibp_cross_row_source`, which emits `(g_base + pos, atom, z'_ik)` for
7875        // the active set under a compact layout), so the deflated base `H₀'` is the
7876        // no-self block in BOTH layouts. The self-curvature downdate below must
7877        // therefore run regardless of layout — gating it to the dense path (the
7878        // pre-fix bug) left the compact deflation correction differentiating the
7879        // un-downdated full block. For non-IBP modes `ibp_assignment_third_channels`
7880        // returns `None`, there is no Woodbury source, and `self_curv` is
7881        // identically 0 (the deflated block IS the full block).
7882        let cross_channels = ibp_assignment_third_channels(&self.assignment, rho)?;
7883        let learnable_alpha = matches!(
7884            self.assignment.mode,
7885            AssignmentMode::IBPMap {
7886                learnable_alpha: true,
7887                ..
7888            }
7889        );
7890        let self_curv = |row: usize, atom: usize| -> f64 {
7891            let Some(ch) = cross_channels.as_ref() else {
7892                return 0.0;
7893            };
7894            let d_k = if learnable_alpha {
7895                ch.cross_row_d_logalpha[atom]
7896            } else {
7897                ch.cross_row_d[atom]
7898            };
7899            let j = ch.z_jac[row * k_atoms + atom];
7900            d_k * j * j
7901        };
7902        let mut trace = 0.0_f64;
7903        for row in 0..self.n_obs() {
7904            let row_base = cache.row_offsets[row];
7905            let assignment_base = row * k_atoms;
7906            let q = cache.row_dims[row];
7907            // Per-row diagonal `(∂H₀'/∂ρ)_tt` for the deflation correction: the
7908            // assignment prior curves only the logit/assignment slots (coordinate
7909            // slots are 0 — ARD handles those), MINUS the downdated cross-row self
7910            // curvature. The full-`H` trace contraction keeps the full `hdiag`.
7911            let mut d_diag = Array1::<f64>::zeros(q);
7912            match self.last_row_layout {
7913                Some(ref layout) => {
7914                    for (pos, &atom) in layout.active_atoms[row].iter().enumerate() {
7915                        let d_slot = hdiag[assignment_base + atom];
7916                        trace += inv_diag[row_base + pos] * d_slot;
7917                        if pos < q {
7918                            d_diag[pos] = d_slot - self_curv(row, atom);
7919                        }
7920                    }
7921                }
7922                None => {
7923                    for free_idx in 0..assignment_dim {
7924                        let d_slot = hdiag[assignment_base + free_idx];
7925                        trace += inv_diag[row_base + free_idx] * d_slot;
7926                        if free_idx < q {
7927                            d_diag[free_idx] = d_slot - self_curv(row, free_idx);
7928                        }
7929                    }
7930                }
7931            }
7932            let dirs = cache
7933                .deflated_row_directions
7934                .get(row)
7935                .map(Vec::as_slice)
7936                .unwrap_or(&[]);
7937            if !dirs.is_empty() {
7938                let inv_vv = if fast_selected {
7939                    let (inv_vv, _inv_vbeta) = solver
7940                        .selected_inverse_row_blocks(row, &selected_beta_inv)
7941                        .map_err(|err| {
7942                            format!("assignment_log_strength_hessian_trace: selected inverse: {err}")
7943                        })?;
7944                    inv_vv
7945                } else {
7946                    let mut inv_vv = Array2::<f64>::zeros((q, q));
7947                    for col in 0..q {
7948                        let mut rhs_t = Array1::<f64>::zeros(total_t);
7949                        let rhs_beta = Array1::<f64>::zeros(cache.k);
7950                        rhs_t[row_base + col] = 1.0;
7951                        let solved = solver.solve(rhs_t.view(), rhs_beta.view()).map_err(|err| {
7952                            format!(
7953                                "assignment_log_strength_hessian_trace: selected inverse: {err}"
7954                            )
7955                        })?;
7956                        for r in 0..q {
7957                            inv_vv[[r, col]] = solved.t[row_base + r];
7958                        }
7959                    }
7960                    inv_vv
7961                };
7962                let mut d_mat = Array2::<f64>::zeros((q, q));
7963                for s in 0..q {
7964                    d_mat[[s, s]] = d_diag[s];
7965                }
7966                let spectrum = cache
7967                    .deflation_row_spectra
7968                    .get(row)
7969                    .and_then(Option::as_ref);
7970                trace -= Self::deflation_block_correction(&inv_vv, &d_mat, dirs, spectrum);
7971            }
7972        }
7973        // #1416: the IBP prior Hessian is `H_p = d·J Jᵀ + diag(s, c)`, where the
7974        // rank-one `d·J Jᵀ` couples EVERY row pair `(i, j)` in a column `k`
7975        // through the shared empirical mass `M_k`. The assembled `H` carries the
7976        // full `H_full = H₀' + U D Uᵀ` (Woodbury, `set_ibp_cross_row_source`), and
7977        // for fixed alpha the entire IBP prior scales with `λ = eᵖ`, so
7978        // `∂H_p/∂ρ = H_p`. The diagonal loop above already captures the `i = j`
7979        // self terms (the `d·J_ik²` summand lives in `hdiag`); this pass adds the
7980        // omitted off-diagonal `½·d_k·Σ_{i≠j}(H⁻¹)_{ik,jk}·J_ik·J_jk`. Only IBP
7981        // has the cross-row rank-one source; for other diagonal modes
7982        // `ibp_assignment_third_channels` returns `None` and the trace stays the
7983        // pure diagonal contraction.
7984        //
7985        // #1416 (compact completion): this pass is LAYOUT-AGNOSTIC. Under the dense
7986        // layout atom `k`'s logit slot is local position `k`
7987        // (`row_offsets[i] + k`); under the compact (#1420 top-`k`) layout only the
7988        // row's active atoms carry coordinates and atom `k` lives at local position
7989        // `pos` of `active_atoms[row]` (`row_offsets[i] + pos`). The Woodbury source
7990        // and the θ-adjoint already use this active-slot mapping, so gating the
7991        // cross-row pass to the dense layout (the pre-fix bug) dropped the
7992        // off-diagonal term from `∂log|H|/∂ρ` whenever the budget/`top_k` engaged
7993        // the compact layout. We build per-column active sites `(row, t_index)` once
7994        // — exactly the θ-adjoint `col_sites` construction — then contract the
7995        // off-diagonal `i ≠ j` remainder with one solve per active site.
7996        if let Some(channels) = cross_channels.as_ref() {
7997            let n = self.n_obs();
7998            let total_t = cache.delta_t_len();
7999            // This trace is ½ ∂log|H|/∂ρ. For FIXED-α IBP the whole prior
8000            // scales with λ=eᵖ so ∂H_p/∂ρ = H_p and the rank-one coefficient
8001            // is the VALUE `cross_row_d[k] = w·s'_k`. For LEARNABLE-α this trace
8002            // is ½ ∂log|H|/∂logα, and the rank-one block's logα-derivative is
8003            // `∂d_k/∂logα = w·∂s'_k/∂logα` (`cross_row_d_logalpha[k]`) — the same
8004            // α-derivative the DIAGONAL channel (`hessian_diag_log_alpha_derivative`)
8005            // already uses. Using the value `s'_k` here (the pre-fix bug) made the
8006            // off-diagonal inconsistent with the diagonal and the α-gradient wrong.
8007            // (`learnable_alpha` is the same flag the self-curvature downdate uses.)
8008            // Per-column active sites `(row, global t-index)`. Layout-agnostic.
8009            let mut col_sites: Vec<Vec<(usize, usize)>> = vec![Vec::new(); k_atoms];
8010            match self.last_row_layout {
8011                Some(ref layout) => {
8012                    for row in 0..n {
8013                        let base = cache.row_offsets[row];
8014                        for (pos, &atom) in layout.active_atoms[row].iter().enumerate() {
8015                            col_sites[atom].push((row, base + pos));
8016                        }
8017                    }
8018                }
8019                None => {
8020                    for row in 0..n {
8021                        let base = cache.row_offsets[row];
8022                        for k in 0..k_atoms {
8023                            col_sites[k].push((row, base + k));
8024                        }
8025                    }
8026                }
8027            }
8028            let mut cross = 0.0_f64;
8029            for k in 0..k_atoms {
8030                let d_k = if learnable_alpha {
8031                    channels.cross_row_d_logalpha[k]
8032                } else {
8033                    channels.cross_row_d[k]
8034                };
8035                if d_k == 0.0 || col_sites[k].len() < 2 {
8036                    continue;
8037                }
8038                for &(i, t_i) in &col_sites[k] {
8039                    let j_ik = channels.z_jac[i * k_atoms + k];
8040                    if j_ik == 0.0 {
8041                        continue;
8042                    }
8043                    // (H⁻¹) column at row `i`'s active logit-`k` slot.
8044                    let mut rhs_t = Array1::<f64>::zeros(total_t);
8045                    let rhs_beta = Array1::<f64>::zeros(cache.k);
8046                    rhs_t[t_i] = 1.0;
8047                    let solved = solver.solve(rhs_t.view(), rhs_beta.view()).map_err(|err| {
8048                        format!("assignment_log_strength_hessian_trace: {err}")
8049                    })?;
8050                    for &(j, t_j) in &col_sites[k] {
8051                        if j == i {
8052                            continue;
8053                        }
8054                        let j_jk = channels.z_jac[j * k_atoms + k];
8055                        if j_jk == 0.0 {
8056                            continue;
8057                        }
8058                        cross += d_k * solved.t[t_j] * j_ik * j_jk;
8059                    }
8060                }
8061            }
8062            trace += cross;
8063        }
8064        Ok(0.5 * trace)
8065    }
8066
8067    pub(crate) fn learnable_ibp_forward_alpha_data_derivative(
8068        &self,
8069        rho: &SaeManifoldRho,
8070        target: ArrayView2<'_, f64>,
8071    ) -> Result<f64, String> {
8072        let AssignmentMode::IBPMap {
8073            temperature: _,
8074            learnable_alpha: true,
8075            ..
8076        } = self.assignment.mode
8077        else {
8078            return Ok(0.0);
8079        };
8080        let alpha = self
8081            .assignment
8082            .mode
8083            .resolved_ibp_alpha(rho)
8084            .ok_or_else(|| "learnable IBP alpha resolution failed".to_string())?;
8085        let k_atoms = self.k_atoms();
8086        let prior = ordered_geometric_shrinkage_prior(k_atoms, alpha);
8087        let mut dprior = Array1::<f64>::zeros(k_atoms);
8088        for k in 0..k_atoms {
8089            // dπ_k/dρ for π_k = (α/(α+1))^(k+1) (#614 consistent stick-breaking
8090            // prior mean): dπ_k/dα = π_k·(k+1)/(α(α+1)), and with α = α₀·exp(ρ)
8091            // the log-α chain factor α cancels the 1/α ⇒ dπ_k/dρ = π_k·(k+1)/(α+1).
8092            dprior[k] = prior[k] * (k + 1) as f64 / (alpha + 1.0);
8093        }
8094        let n = self.n_obs();
8095        let p = self.output_dim();
8096        let row_loss_w = self.row_loss_weights.as_deref();
8097        let whitens = self
8098            .row_metric
8099            .as_ref()
8100            .is_some_and(|metric| metric.whitens_likelihood());
8101        let mut decoded = vec![0.0_f64; p];
8102        let mut fitted = Array1::<f64>::zeros(p);
8103        let mut f_rho = Array1::<f64>::zeros(p);
8104        let mut residual = Array1::<f64>::zeros(p);
8105        // #1557 — reuse one K-sized scratch row across all N rows (alias-free).
8106        let mut assignments = vec![0.0_f64; k_atoms];
8107        let mut total = 0.0_f64;
8108        for row in 0..n {
8109            self.assignment
8110                .try_assignments_row_for_rho_into(row, rho, &mut assignments)?;
8111            fitted.fill(0.0);
8112            f_rho.fill(0.0);
8113            for k in 0..k_atoms {
8114                self.atoms[k].fill_decoded_row(row, &mut decoded);
8115                // Ungated (#1026 background-tier) atoms have a force-fixed unit
8116                // gate (`has_ungated` override), so their mass `a_k ≡ 1` is
8117                // α-INDEPENDENT (∂a_k/∂logα = 0). The π_k(α) chain below applies
8118                // ONLY to gated atoms, whose mass is `a_k = σ(ℓ/τ)·π_k(α)`. (NB:
8119                // frozen routing is NOT ungated — there the gate is a fixed σ(ℓ/τ)
8120                // but `a_k` still varies with α through `π_k`, so it must NOT be
8121                // skipped.)
8122                let da_rho = if self.assignment.ungated.get(k).copied().unwrap_or(false) {
8123                    0.0
8124                } else {
8125                    (assignments[k] / prior[k]) * dprior[k]
8126                };
8127                for out_col in 0..p {
8128                    fitted[out_col] += assignments[k] * decoded[out_col];
8129                    f_rho[out_col] += da_rho * decoded[out_col];
8130                }
8131            }
8132            for out_col in 0..p {
8133                residual[out_col] = fitted[out_col] - target[[row, out_col]];
8134            }
8135            let residual_metric = match self.row_metric.as_ref() {
8136                Some(metric) if whitens => metric.apply_metric_row(row, residual.view()),
8137                _ => residual.to_vec(),
8138            };
8139            let row_weight = row_loss_w.map_or(1.0, |w| w[row]);
8140            let mut row_dot = 0.0_f64;
8141            for out_col in 0..p {
8142                row_dot += residual_metric[out_col] * f_rho[out_col];
8143            }
8144            total += row_weight * row_dot;
8145        }
8146        Ok(total)
8147    }
8148
8149    /// Per-row spectral-deflation correction `tr((H⁻¹)_tt · (D − DΦ[D]))` for one
8150    /// evidence ρ-component, to be SUBTRACTED from the raw-derivative trace
8151    /// `tr((H⁻¹)_tt · D)` the trace otherwise accumulates.
8152    ///
8153    /// The criterion VALUE re-deflates each per-row `H_tt` at every ρ, so the
8154    /// correct evidence gradient contracts `(H⁻¹)_tt` against the deflation-map
8155    /// derivative `DΦ[D]`, not the raw `D = (∂H_raw/∂ρ)_tt`. By Daleckii–Krein,
8156    /// in the row's RAW eigenbasis `U`,
8157    ///   `DΦ[D] = U (F ∘ (Uᵀ D U)) Uᵀ`,  `F_{ml} = (λ̃ₘ − λ̃ₗ)/(λₘ − λₗ)`
8158    /// (raw `λ` in the denominator, conditioned `λ̃` in the numerator; the
8159    /// diagonal / degenerate entry is `f'(λₘ) = 1` for an unclamped kept
8160    /// direction and `0` otherwise). Hence `D − DΦ[D] = U ((1−F) ∘ (Uᵀ D U)) Uᵀ`,
8161    /// whose kept×kept block is `0`, deflated×deflated block is the full `M`, and
8162    /// kept(m)×deflated(i) block carries the ROTATION coefficient
8163    /// `(1−λᵢ)/(λₘ−λᵢ)`. Contracting against the FULL deflated selected-inverse
8164    /// t-block `inv_vv` (which carries the β-Schur back-substitution) captures
8165    /// both the within-row kept-subspace term and the deferred β-Schur/rotation
8166    /// coupling in one pass, matching the re-deflating fixed-state FD oracle.
8167    ///
8168    /// `spectrum = Some` (spectral deflation): exact Daleckii–Krein. `None` with a
8169    /// non-empty `dirs` (gauge-only deflation, ρ-independent structural null):
8170    /// fall back to the within-row kept-subspace term `Σᵢ vᵢᵀ D vᵢ`.
8171    /// `inv_vv` is assumed symmetric (selected inverse of a symmetric PD system).
8172    fn deflation_block_correction(
8173        inv_vv: &Array2<f64>,
8174        d_mat: &Array2<f64>,
8175        dirs: &[Array1<f64>],
8176        spectrum: Option<&RowDeflationSpectrum>,
8177    ) -> f64 {
8178        let q = inv_vv.nrows();
8179        let Some(spec) = spectrum else {
8180            // Gauge-only deflation: ρ-independent structural null → within-row term.
8181            let mut acc = 0.0_f64;
8182            for v in dirs {
8183                for a in 0..q {
8184                    let va = if a < v.len() { v[a] } else { 0.0 };
8185                    if va == 0.0 {
8186                        continue;
8187                    }
8188                    for b in 0..q {
8189                        let vb = if b < v.len() { v[b] } else { 0.0 };
8190                        acc += va * vb * d_mat[[a, b]];
8191                    }
8192                }
8193            }
8194            return acc;
8195        };
8196        let u = &spec.evecs;
8197        if u.nrows() != q || u.ncols() != q {
8198            return 0.0;
8199        }
8200        let raw = &spec.raw_evals;
8201        let cond = &spec.cond_evals;
8202        // M = Uᵀ D U, W = Uᵀ inv_vv U (both q×q, symmetric).
8203        let m = u.t().dot(d_mat).dot(u);
8204        let w = u.t().dot(inv_vv).dot(u);
8205        // correction = Σ_{m,l} W[m,l]·M[m,l]·(1 − F[m,l]).
8206        let mut acc = 0.0_f64;
8207        let eps = 1.0e-12;
8208        for a in 0..q {
8209            for b in 0..q {
8210                let denom = raw[a] - raw[b];
8211                let f1 = if denom.abs() > eps {
8212                    (cond[a] - cond[b]) / denom
8213                } else if cond[a] == raw[a] {
8214                    1.0
8215                } else {
8216                    0.0
8217                };
8218                acc += w[[a, b]] * m[[a, b]] * (1.0 - f1);
8219            }
8220        }
8221        acc
8222    }
8223
8224    /// #1417: exact `½ tr(H⁻¹ ∂H_data/∂logα)` for LEARNABLE IBP alpha.
8225    ///
8226    /// The forward assignment is `a_ik = σ(ℓ_ik/τ)·π_k(α)` with the #614
8227    /// consistent stick-breaking mean `π_k(α) = (α/(α+1))^(k+1)`, so
8228    /// `∂logπ_k/∂logα = (k+1)/(α+1)`. EVERY data-Jacobian column for atom `k` —
8229    /// the logit-JVP row (carries one `π_k`), the coordinate rows (carry one
8230    /// `a_k`), and the β-leg (`a_k·φ`) — carries exactly ONE `a_k`/`π_k` factor
8231    /// (`σ(ℓ/τ)` is α-independent). Hence each Jacobian column scales as
8232    /// `∂J_·k/∂logα = ((k+1)/(α+1))·J_·k`, and the data Hessian block for the
8233    /// atom pair `(k_a, k_b)` scales as
8234    ///   ∂H_data[a,b]/∂logα = (((k_a+1) + (k_b+1))/(α+1))·H_data[a,b].
8235    /// Therefore the exact data-block contribution to the α-logdet trace is
8236    ///   ½ tr(H⁻¹ ∂H_data/∂logα)
8237    ///     = ½/(α+1) · Σ_{a,b} ((k_a+1) + (k_b+1))·(H⁻¹)_{ba}·H_data[a,b],
8238    /// over the full joint `(t, β)` index set. `H_data[a,b]` is the data-fit
8239    /// Gauss-Newton block built from the SAME `row_jets_for_logdet` first-jets the
8240    /// θ-adjoint uses (`H_tt = ⟨J_a,J_b⟩`, `H_tβ = ⟨J_a,J_β⟩`, `H_ββ = ⟨J_β,J_β'⟩`),
8241    /// and `(H⁻¹)` is contracted through the same per-row selected-inverse blocks.
8242    /// This closes the learnable-α gradient: combined with the prior-Hessian
8243    /// trace (`assignment_log_strength_hessian_trace`) the full
8244    /// `½ tr(H⁻¹ ∂H/∂logα)` is now assembled. For FIXED alpha (and non-IBP modes)
8245    /// this is identically zero.
8246    pub(crate) fn learnable_ibp_data_logdet_alpha_trace(
8247        &self,
8248        rho: &SaeManifoldRho,
8249        cache: &ArrowFactorCache,
8250        solver: &DeflatedArrowSolver<'_>,
8251    ) -> Result<f64, String> {
8252        let AssignmentMode::IBPMap {
8253            learnable_alpha: true,
8254            ..
8255        } = self.assignment.mode
8256        else {
8257            return Ok(0.0);
8258        };
8259        let alpha = self
8260            .assignment
8261            .mode
8262            .resolved_ibp_alpha(rho)
8263            .ok_or_else(|| "learnable IBP alpha resolution failed".to_string())?;
8264        let inv_alpha1 = 1.0 / (alpha + 1.0);
8265        let n = self.n_obs();
8266        let total_t = cache.delta_t_len();
8267        let second_jets = self.atom_second_jets()?;
8268        let border = self.border_channels_for_cache(cache)?;
8269
8270        // β-tier selected inverse `(H⁻¹)_ββ` (shared across rows). #932 FRONT C:
8271        // on the plain bordered arrow this is the cached dense `S⁻¹` formed once
8272        // (no `K` full-system solves); when a gauge / #1038 cross-row Woodbury is
8273        // active the row-local Takahashi blocks are NOT valid, so we fall back to
8274        // the per-β-coordinate `solve` loop (bit-identical, just O(n) per call).
8275        let fast_selected = solver.plain_selected_inverse_available();
8276        let beta_inv = if cache.k == 0 {
8277            Array2::<f64>::zeros((0, 0))
8278        } else if fast_selected {
8279            solver.beta_inv().map_err(|err| {
8280                format!("learnable_ibp_data_logdet_alpha_trace: beta inverse: {err}")
8281            })?
8282        } else {
8283            let mut beta_inv = Array2::<f64>::zeros((cache.k, cache.k));
8284            let rhs_t = Array1::<f64>::zeros(total_t);
8285            for col in 0..cache.k {
8286                let mut rhs_beta = Array1::<f64>::zeros(cache.k);
8287                rhs_beta[col] = 1.0;
8288                let solved = solver.solve(rhs_t.view(), rhs_beta.view()).map_err(|err| {
8289                    format!("learnable_ibp_data_logdet_alpha_trace: beta inverse: {err}")
8290                })?;
8291                for r in 0..cache.k {
8292                    beta_inv[[r, col]] = solved.beta[r];
8293                }
8294            }
8295            beta_inv
8296        };
8297        // Atom index of each β border channel (the `k_b` weight for the β leg).
8298        let border_atom: Vec<usize> = border.iter().map(|c| c.atom).collect();
8299
8300        let mut trace = 0.0_f64;
8301        // #1557 — reuse one K-sized scratch row across all N rows (alias-free).
8302        let mut assignments = Array1::<f64>::zeros(self.k_atoms());
8303        // #932 SIMD: jets are built in aligned 4-row SIMD batches through a
8304        // bounded (≤4-row) look-ahead window; unaligned / non-softmax / remainder
8305        // rows fall back to the scalar per-row path (bit-identical either way).
8306        let mut jet_window: std::collections::VecDeque<SaeRowJets> =
8307            std::collections::VecDeque::new();
8308        let mut jet_window_next = 0usize;
8309        for row in 0..n {
8310            let q = cache.row_dims[row];
8311            let base = cache.row_offsets[row];
8312            let a_scratch = assignments.as_slice_mut().expect("contiguous scratch");
8313            self.assignment
8314                .try_assignments_row_for_rho_into(row, rho, a_scratch)?;
8315            if jet_window.is_empty() {
8316                jet_window_next = self.refill_jet_window(
8317                    rho,
8318                    jet_window_next,
8319                    cache,
8320                    &second_jets,
8321                    &border,
8322                    &mut jet_window,
8323                )?;
8324            }
8325            let jets = jet_window.pop_front().expect("jet window must be non-empty");
8326            // Atom index (k-weight) of each local t-var.
8327            let var_atom: Vec<usize> = jets
8328                .vars
8329                .iter()
8330                .map(|v| match *v {
8331                    SaeLocalRowVar::Logit { atom } => atom,
8332                    SaeLocalRowVar::Coord { atom, .. } => atom,
8333                })
8334                .collect();
8335
8336            // Per-row selected inverse blocks `(H⁻¹)_tt` (q×q) and `(H⁻¹)_tβ`.
8337            // #932 FRONT C: row-local Takahashi (O(q·(q+K))) on the plain arrow;
8338            // per-row full-system `solve` loop (O(n·q)) under gauge / cross-row
8339            // Woodbury where the row-local blocks are not valid.
8340            let (inv_vv, inv_vbeta) = if fast_selected {
8341                solver
8342                    .selected_inverse_row_blocks(row, &beta_inv)
8343                    .map_err(|err| {
8344                        format!("learnable_ibp_data_logdet_alpha_trace: selected inverse: {err}")
8345                    })?
8346            } else {
8347                let mut inv_vv = Array2::<f64>::zeros((q, q));
8348                let mut inv_vbeta = Array2::<f64>::zeros((q, cache.k));
8349                for col in 0..q {
8350                    let mut rhs_t = Array1::<f64>::zeros(total_t);
8351                    let rhs_beta = Array1::<f64>::zeros(cache.k);
8352                    rhs_t[base + col] = 1.0;
8353                    let solved = solver.solve(rhs_t.view(), rhs_beta.view()).map_err(|err| {
8354                        format!("learnable_ibp_data_logdet_alpha_trace: selected inverse: {err}")
8355                    })?;
8356                    for r in 0..q {
8357                        inv_vv[[r, col]] = solved.t[base + r];
8358                    }
8359                    for b in 0..cache.k {
8360                        inv_vbeta[[col, b]] = solved.beta[b];
8361                    }
8362                }
8363                (inv_vv, inv_vbeta)
8364            };
8365
8366            // #1026 — UNGATED (background-tier) atoms have a force-fixed unit gate,
8367            // so their mass `a_k ≡ 1` is α-INDEPENDENT: every data-Jacobian column
8368            // for an ungated atom carries `a_k = 1`, NOT `π_k(α)`, so its α-exponent
8369            // is `e_k = 0`, not `k+1`. Gated atoms keep `e_k = k+1`. (The prior trace
8370            // handles ungated separately by zeroing the fixed-logit `z_jac`.)
8371            let kfac = |atom: usize| -> f64 {
8372                if self.assignment.ungated.get(atom).copied().unwrap_or(false) {
8373                    0.0
8374                } else {
8375                    (atom + 1) as f64
8376                }
8377            };
8378            // t–t block: Σ_{a,b} (e_a + e_b)·(H⁻¹)_{ba}·⟨J_a, J_b⟩, where the
8379            // per-atom log-prior exponent is e_k = k+1 for the #614 consistent
8380            // stick-breaking mean π_k = (α/(α+1))^(k+1) (dlogπ_k/dlogα = (k+1)·inv_alpha1).
8381            for a in 0..q {
8382                for b in 0..q {
8383                    let h_ab = sae_dot(&jets.first[a], &jets.first[b]);
8384                    if h_ab == 0.0 {
8385                        continue;
8386                    }
8387                    let kw = kfac(var_atom[a]) + kfac(var_atom[b]);
8388                    trace += kw * inv_vv[[b, a]] * h_ab;
8389                }
8390            }
8391            // Deflation correction (kept-subspace restriction + β-Schur/rotation).
8392            // `inv_vv` is the DEFLATED selected inverse, so the t–t contraction
8393            // above contracts the RAW derivative `D` where the re-deflating
8394            // criterion uses the deflation-map derivative `DΦ[D]`. Subtract the
8395            // exact over-count `tr(inv_vv·(D − DΦ[D]))` via the Daleckii–Krein
8396            // helper, with `D_{ab} = kw_ab·⟨J_a, J_b⟩` the SAME t–t operator the
8397            // trace contracts. The t–β/β–β blocks are not deflated, so only the
8398            // t–t contraction is corrected.
8399            let dirs = cache
8400                .deflated_row_directions
8401                .get(row)
8402                .map(Vec::as_slice)
8403                .unwrap_or(&[]);
8404            if !dirs.is_empty() {
8405                let mut d_mat = Array2::<f64>::zeros((q, q));
8406                for a in 0..q {
8407                    for b in 0..q {
8408                        let h_ab = sae_dot(&jets.first[a], &jets.first[b]);
8409                        if h_ab == 0.0 {
8410                            continue;
8411                        }
8412                        d_mat[[a, b]] = (kfac(var_atom[a]) + kfac(var_atom[b])) * h_ab;
8413                    }
8414                }
8415                let spectrum = cache
8416                    .deflation_row_spectra
8417                    .get(row)
8418                    .and_then(Option::as_ref);
8419                trace -= Self::deflation_block_correction(&inv_vv, &d_mat, dirs, spectrum);
8420            }
8421            // t–β and β–t blocks: appear symmetrically, contract once with the
8422            // factor 2 (H, H⁻¹ symmetric; `(H⁻¹)_βt = (H⁻¹)_tβᵀ`).
8423            for a in 0..q {
8424                for (beta_pos, channel) in border.iter().enumerate() {
8425                    let h_ab = sae_dot(&jets.first[a], &jets.beta[beta_pos]);
8426                    if h_ab == 0.0 {
8427                        continue;
8428                    }
8429                    let kw = kfac(var_atom[a]) + kfac(border_atom[beta_pos]);
8430                    trace += 2.0 * kw * inv_vbeta[[a, channel.index]] * h_ab;
8431                }
8432            }
8433            // β–β block: Σ_{β,β'} (k_β + k_β')·(H⁻¹)_{β'β}·⟨J_β, J_β'⟩.
8434            for (beta_i, channel_i) in border.iter().enumerate() {
8435                for (beta_j, channel_j) in border.iter().enumerate() {
8436                    let h_ab = sae_dot(&jets.beta[beta_i], &jets.beta[beta_j]);
8437                    if h_ab == 0.0 {
8438                        continue;
8439                    }
8440                    let kw = kfac(border_atom[beta_i]) + kfac(border_atom[beta_j]);
8441                    trace += kw * beta_inv[[channel_i.index, channel_j.index]] * h_ab;
8442                }
8443            }
8444        }
8445        Ok(0.5 * inv_alpha1 * trace)
8446    }
8447
8448    pub(crate) fn add_learnable_ibp_forward_alpha_data_rhs(
8449        &self,
8450        rho: &SaeManifoldRho,
8451        target: ArrayView2<'_, f64>,
8452        cache: &ArrowFactorCache,
8453        t: &mut Array1<f64>,
8454        beta: &mut Array1<f64>,
8455    ) -> Result<(), String> {
8456        let AssignmentMode::IBPMap {
8457            temperature,
8458            learnable_alpha: true,
8459            ..
8460        } = self.assignment.mode
8461        else {
8462            return Ok(());
8463        };
8464        let alpha = self
8465            .assignment
8466            .mode
8467            .resolved_ibp_alpha(rho)
8468            .ok_or_else(|| "learnable IBP alpha resolution failed".to_string())?;
8469        let k_atoms = self.k_atoms();
8470        let p = self.output_dim();
8471        let prior = ordered_geometric_shrinkage_prior(k_atoms, alpha);
8472        let mut dprior = Array1::<f64>::zeros(k_atoms);
8473        for k in 0..k_atoms {
8474            // dπ_k/dρ for π_k = (α/(α+1))^(k+1) (#614 consistent stick-breaking
8475            // prior mean): dπ_k/dα = π_k·(k+1)/(α(α+1)), and with α = α₀·exp(ρ)
8476            // the log-α chain factor α cancels the 1/α ⇒ dπ_k/dρ = π_k·(k+1)/(α+1).
8477            dprior[k] = prior[k] * (k + 1) as f64 / (alpha + 1.0);
8478        }
8479        let inv_tau = 1.0 / temperature;
8480        let row_loss_w = self.row_loss_weights.as_deref();
8481        let whitens = self
8482            .row_metric
8483            .as_ref()
8484            .is_some_and(|metric| metric.whitens_likelihood());
8485        let border = self.border_channels_for_cache(cache)?;
8486        let mut decoded_rows = vec![vec![0.0_f64; p]; k_atoms];
8487        let mut decoded_deriv = vec![0.0_f64; p];
8488        let mut fitted = Array1::<f64>::zeros(p);
8489        let mut f_rho = Array1::<f64>::zeros(p);
8490        let mut residual = Array1::<f64>::zeros(p);
8491        // #1557 — reuse one K-sized scratch row across all N rows (alias-free).
8492        let mut assignments = vec![0.0_f64; k_atoms];
8493        for row in 0..self.n_obs() {
8494            self.assignment
8495                .try_assignments_row_for_rho_into(row, rho, &mut assignments)?;
8496            fitted.fill(0.0);
8497            f_rho.fill(0.0);
8498            for k in 0..k_atoms {
8499                self.atoms[k].fill_decoded_row(row, &mut decoded_rows[k]);
8500                // Ungated (#1026 background-tier) atoms have a force-fixed unit
8501                // gate (`has_ungated` override), so their mass `a_k ≡ 1` is
8502                // α-INDEPENDENT (∂a_k/∂logα = 0). The π_k(α) chain below applies
8503                // ONLY to gated atoms, whose mass is `a_k = σ(ℓ/τ)·π_k(α)`. (NB:
8504                // frozen routing is NOT ungated — there the gate is a fixed σ(ℓ/τ)
8505                // but `a_k` still varies with α through `π_k`, so it must NOT be
8506                // skipped.)
8507                let da_rho = if self.assignment.ungated.get(k).copied().unwrap_or(false) {
8508                    0.0
8509                } else {
8510                    (assignments[k] / prior[k]) * dprior[k]
8511                };
8512                for out_col in 0..p {
8513                    fitted[out_col] += assignments[k] * decoded_rows[k][out_col];
8514                    f_rho[out_col] += da_rho * decoded_rows[k][out_col];
8515                }
8516            }
8517            for out_col in 0..p {
8518                residual[out_col] = fitted[out_col] - target[[row, out_col]];
8519            }
8520            let residual_metric = match self.row_metric.as_ref() {
8521                Some(metric) if whitens => metric.apply_metric_row(row, residual.view()),
8522                _ => residual.to_vec(),
8523            };
8524            let f_metric = match self.row_metric.as_ref() {
8525                Some(metric) if whitens => metric.apply_metric_row(row, f_rho.view()),
8526                _ => f_rho.to_vec(),
8527            };
8528            let row_weight = row_loss_w.map_or(1.0, |w| w[row]);
8529            let row_vars = self.row_vars_for_cache_row(row, cache)?;
8530            let row_base = cache.row_offsets[row];
8531            for (pos, var) in row_vars.iter().enumerate() {
8532                let mut contribution = 0.0_f64;
8533                match *var {
8534                    SaeLocalRowVar::Logit { atom } => {
8535                        let sigma = assignments[atom] / prior[atom];
8536                        let sigma_jac = sigma * (1.0 - sigma) * inv_tau;
8537                        let da_dl = sigma_jac * prior[atom];
8538                        let d_da_rho_dl = sigma_jac * dprior[atom];
8539                        for out_col in 0..p {
8540                            contribution += da_dl * decoded_rows[atom][out_col] * f_metric[out_col];
8541                            contribution += d_da_rho_dl
8542                                * decoded_rows[atom][out_col]
8543                                * residual_metric[out_col];
8544                        }
8545                    }
8546                    SaeLocalRowVar::Coord { atom, axis } => {
8547                        let sigma = assignments[atom] / prior[atom];
8548                        let da_rho = sigma * dprior[atom];
8549                        self.atoms[atom].fill_decoded_derivative_row(row, axis, &mut decoded_deriv);
8550                        for out_col in 0..p {
8551                            contribution +=
8552                                assignments[atom] * decoded_deriv[out_col] * f_metric[out_col];
8553                            contribution +=
8554                                da_rho * decoded_deriv[out_col] * residual_metric[out_col];
8555                        }
8556                    }
8557                }
8558                t[row_base + pos] += row_weight * contribution;
8559            }
8560            for channel in &border {
8561                let phi = self.atoms[channel.atom].basis_values[[row, channel.basis_col]];
8562                let sigma = assignments[channel.atom] / prior[channel.atom];
8563                let da_rho = sigma * dprior[channel.atom];
8564                let mut contribution = 0.0_f64;
8565                for out_col in 0..p {
8566                    let output = channel.output[out_col];
8567                    contribution += assignments[channel.atom] * phi * output * f_metric[out_col];
8568                    contribution += da_rho * phi * output * residual_metric[out_col];
8569                }
8570                beta[channel.index] += row_weight * contribution;
8571            }
8572        }
8573        Ok(())
8574    }
8575
8576    pub(crate) fn border_channels_for_cache(
8577        &self,
8578        cache: &ArrowFactorCache,
8579    ) -> Result<Vec<SaeBorderChannel>, String> {
8580        let p = self.output_dim();
8581        let frames_active = self.last_frames_active && cache.k == self.factored_border_dim();
8582        let offsets = if frames_active {
8583            self.factored_beta_offsets()
8584        } else {
8585            self.beta_offsets()
8586        };
8587        let mut channels = Vec::with_capacity(cache.k);
8588        for (atom_idx, atom) in self.atoms.iter().enumerate() {
8589            let m = atom.basis_size();
8590            let frame = if frames_active {
8591                self.frame_output_matrix(atom_idx)
8592            } else {
8593                Array2::<f64>::eye(p)
8594            };
8595            let r = frame.ncols();
8596            for basis_col in 0..m {
8597                for channel in 0..r {
8598                    let mut output = vec![0.0_f64; p];
8599                    for out_col in 0..p {
8600                        output[out_col] = frame[[out_col, channel]];
8601                    }
8602                    channels.push(SaeBorderChannel {
8603                        atom: atom_idx,
8604                        basis_col,
8605                        index: offsets[atom_idx] + basis_col * r + channel,
8606                        output,
8607                    });
8608                }
8609            }
8610        }
8611        if channels.len() != cache.k {
8612            return Err(format!(
8613                "border channel layout has {} entries but cache border has {}",
8614                channels.len(),
8615                cache.k
8616            ));
8617        }
8618        Ok(channels)
8619    }
8620
8621    pub(crate) fn row_vars_for_cache_row(
8622        &self,
8623        row: usize,
8624        cache: &ArrowFactorCache,
8625    ) -> Result<Vec<SaeLocalRowVar>, String> {
8626        let q_row = cache.row_dims[row];
8627        let mut vars: Vec<Option<SaeLocalRowVar>> = vec![None; q_row];
8628        match self.last_row_layout {
8629            Some(ref layout) => {
8630                for (pos, &atom) in layout.active_atoms[row].iter().enumerate() {
8631                    vars[pos] = Some(SaeLocalRowVar::Logit { atom });
8632                    let start = layout.coord_starts[row][pos];
8633                    let d = self.assignment.coords[atom].latent_dim();
8634                    for axis in 0..d {
8635                        vars[start + axis] = Some(SaeLocalRowVar::Coord { atom, axis });
8636                    }
8637                }
8638            }
8639            None => {
8640                let assignment_dim = self.assignment.assignment_coord_dim();
8641                let coord_offsets = self.assignment.coord_offsets();
8642                for atom in 0..assignment_dim {
8643                    vars[atom] = Some(SaeLocalRowVar::Logit { atom });
8644                }
8645                for atom in 0..self.k_atoms() {
8646                    let start = coord_offsets[atom];
8647                    let d = self.assignment.coords[atom].latent_dim();
8648                    for axis in 0..d {
8649                        vars[start + axis] = Some(SaeLocalRowVar::Coord { atom, axis });
8650                    }
8651                }
8652            }
8653        }
8654        vars.into_iter()
8655            .enumerate()
8656            .map(|(idx, v)| {
8657                v.ok_or_else(|| {
8658                    format!("row_vars_for_cache_row: row {row} position {idx} was not mapped")
8659                })
8660            })
8661            .collect()
8662    }
8663
8664    pub(crate) fn atom_second_jets(&self) -> Result<Vec<Array4<f64>>, String> {
8665        let mut out = Vec::with_capacity(self.k_atoms());
8666        for (atom_idx, atom) in self.atoms.iter().enumerate() {
8667            let coords = self.assignment.coords[atom_idx].as_matrix();
8668            let jet = if let Some(second) = atom.basis_second_jet.as_ref() {
8669                second.second_jet(coords.view())?
8670            } else {
8671                let evaluator = atom.basis_evaluator.as_ref().ok_or_else(|| {
8672                    format!(
8673                        "logdet_theta_adjoint: atom '{}' has no basis evaluator for second jets",
8674                        atom.name
8675                    )
8676                })?;
8677                evaluator
8678                    .second_jet_dyn(coords.view())
8679                    .ok_or_else(|| {
8680                        format!(
8681                            "logdet_theta_adjoint: atom '{}' basis does not expose analytic second jets",
8682                            atom.name
8683                        )
8684                    })??
8685            };
8686            let expected = (
8687                atom.n_obs(),
8688                atom.basis_size(),
8689                atom.latent_dim,
8690                atom.latent_dim,
8691            );
8692            if jet.dim() != expected {
8693                return Err(format!(
8694                    "logdet_theta_adjoint: atom '{}' second jet shape {:?}, expected {:?}",
8695                    atom.name,
8696                    jet.dim(),
8697                    expected
8698                ));
8699            }
8700            out.push(jet);
8701        }
8702        Ok(out)
8703    }
8704
8705    // [#780 line-count gate] The per-row jet / reconstruction-channel cluster
8706    // (`reconstruction_row_program_for_logdet`, the const-generic
8707    // reconstruction / β-border channel fills and their dynamic dispatchers,
8708    // `row_jets_for_logdet`, `row_jets_for_logdet_batch4`, `batch4_assemble`,
8709    // and `refill_jet_window`) lives in the sibling
8710    // `construction_row_jet_logdet_channels.rs` file, inlined via `include!`
8711    // below at module scope as a second `impl SaeManifoldTerm` block. Splitting
8712    // it out keeps this tracked file under the 10k limit; `include!` preserves
8713    // the identical module scope and private-field access.
8714
8715    pub(crate) fn assignment_prior_hdiag_derivative_entry(
8716        &self,
8717        rho: &SaeManifoldRho,
8718        row: usize,
8719        diag_atom: usize,
8720        wrt: SaeLocalRowVar,
8721        ibp_channels: Option<&IbpHessianDiagThirdChannels>,
8722    ) -> f64 {
8723        let SaeLocalRowVar::Logit { atom: wrt_atom } = wrt else {
8724            return 0.0;
8725        };
8726        match self.assignment.mode {
8727            AssignmentMode::Softmax { .. } => {
8728                // #1038: the softmax entropy Hessian is now stored DENSE in
8729                // `block.htt` and its full θ-derivative `∂H_{k,j}/∂z_w` (diagonal
8730                // AND off-diagonal) is added inline in `logdet_theta_adjoint` from
8731                // the shared `row_dense_hessian_logit_derivative`. Returning the
8732                // diagonal contribution here too would double-count, so this
8733                // primitive is silent for softmax — the dense path is the single
8734                // source for value, logdet, and adjoint.
8735                0.0
8736            }
8737            AssignmentMode::JumpReLU {
8738                temperature,
8739                threshold,
8740            } => {
8741                if diag_atom != wrt_atom {
8742                    return 0.0;
8743                }
8744                let logit = self.assignment.logits[[row, diag_atom]];
8745                if !crate::assignment::jumprelu_in_optimization_band(
8746                    logit,
8747                    threshold,
8748                    temperature,
8749                ) {
8750                    return 0.0;
8751                }
8752                let inv_tau = 1.0 / temperature;
8753                let activation =
8754                    gam_linalg::utils::stable_logistic((logit - threshold) * inv_tau);
8755                let slope = activation * (1.0 - activation);
8756                // #1415: P(ℓ)=λσ((ℓ−θ)/τ); P''(ℓ)=(λ/τ²)s(1−2a) so the third
8757                // derivative is P'''(ℓ)=(λ/τ³)·s·(1−6a+6a²), because
8758                // d/dℓ[s(1−2a)] = (1/τ)s[(1−2a)²−2s] = (1/τ)s(1−6a+6a²).
8759                rho.lambda_sparse()
8760                    * slope
8761                    * (1.0 - 6.0 * activation + 6.0 * activation * activation)
8762                    * inv_tau
8763                    * inv_tau
8764                    * inv_tau
8765            }
8766            AssignmentMode::IBPMap { .. } => {
8767                // The assembled `htt` diagonal consumes
8768                // `IBPAssignmentPenalty::hessian_diag`, whose logit derivative
8769                // splits into a row-local direct-`z` channel and a global
8770                // empirical-`M_k` channel (π_k couples every row in column k).
8771                // This same-row primitive returns only the LOCAL direct-`z`
8772                // channel — and only on the matching logit (`diag_atom == w`),
8773                // since H_ik depends on no other row's z explicitly. The global
8774                // M_k channel is accumulated column-wise in
8775                // `logdet_theta_adjoint` (it needs the per-row selected-inverse
8776                // diagonals), so adding it here would double-count.
8777                if diag_atom != wrt_atom {
8778                    return 0.0;
8779                }
8780                match ibp_channels {
8781                    Some(ch) => ch.local_logit_third[row * ch.k_max + diag_atom],
8782                    None => 0.0,
8783                }
8784            }
8785        }
8786    }
8787
8788    pub(crate) fn ard_majorized_hessian_derivative(
8789        &self,
8790        rho: &SaeManifoldRho,
8791        row: usize,
8792        atom: usize,
8793        axis: usize,
8794    ) -> f64 {
8795        if rho.log_ard[atom].is_empty() {
8796            return 0.0;
8797        }
8798        let alpha = SaeManifoldRho::stable_exp_strength(rho.log_ard[atom][axis]);
8799        let periods = self.assignment.coords[atom].effective_axis_periods();
8800        let t = self.assignment.coords[atom].row(row)[axis];
8801        let prior = ArdAxisPrior::eval(alpha, t, periods[axis]);
8802        if prior.hess <= 0.0 {
8803            return 0.0;
8804        }
8805        match periods[axis] {
8806            None => 0.0,
8807            Some(period) => {
8808                let kappa = std::f64::consts::TAU / period;
8809                -alpha * kappa * (kappa * t).sin()
8810            }
8811        }
8812    }
8813
8814    pub fn outer_rho_gradient_ift_rhs(
8815        &self,
8816        rho: &SaeManifoldRho,
8817        target: ArrayView2<'_, f64>,
8818        j: usize,
8819        cache: &ArrowFactorCache,
8820    ) -> Result<SaeArrowVector, String> {
8821        let n_params = rho.to_flat().len();
8822        if j >= n_params {
8823            return Err(format!(
8824                "outer_rho_gradient_ift_rhs: coordinate {j} outside rho dim {n_params}"
8825            ));
8826        }
8827        let mut t = Array1::<f64>::zeros(cache.delta_t_len());
8828        let mut beta = Array1::<f64>::zeros(cache.k);
8829        if j == 0 {
8830            let assignment_grad =
8831                assignment_prior_log_strength_target_mixed(&self.assignment, rho)?;
8832            let k_atoms = self.k_atoms();
8833            let assignment_dim = self.assignment.assignment_coord_dim();
8834            for row in 0..self.n_obs() {
8835                let base = cache.row_offsets[row];
8836                let assignment_base = row * k_atoms;
8837                match self.last_row_layout {
8838                    Some(ref layout) => {
8839                        for (pos, &atom) in layout.active_atoms[row].iter().enumerate() {
8840                            t[base + pos] = assignment_grad[assignment_base + atom];
8841                        }
8842                    }
8843                    None => {
8844                        for free_idx in 0..assignment_dim {
8845                            t[base + free_idx] = assignment_grad[assignment_base + free_idx];
8846                        }
8847                    }
8848                }
8849            }
8850            self.add_learnable_ibp_forward_alpha_data_rhs(rho, target, cache, &mut t, &mut beta)?;
8851        } else if (1..=rho.log_lambda_smooth.len()).contains(&j) {
8852            // #1556: coordinate `j ∈ 1..=K` is the per-atom smoothness strength
8853            // `log λ_smooth[j-1]`. `∂(penalty)/∂log λ_k = λ_k·S_k C_k` touches ONLY
8854            // atom `k = j-1`'s decoder block; every other atom's RHS is zero.
8855            let target_atom = j - 1;
8856            let lambda = rho.lambda_smooth_for(target_atom);
8857            let frames_active = self.last_frames_active && cache.k == self.factored_border_dim();
8858            let offsets = if frames_active {
8859                self.factored_beta_offsets()
8860            } else {
8861                self.beta_offsets()
8862            };
8863            let atom = &self.atoms[target_atom];
8864            let m = atom.basis_size();
8865            let coeffs = if frames_active {
8866                match &atom.decoder_frame {
8867                    Some(frame) => frame.project_decoder(atom.decoder_coefficients.view())?,
8868                    None => atom.decoder_coefficients.clone(),
8869                }
8870            } else {
8871                atom.decoder_coefficients.clone()
8872            };
8873            let r = coeffs.ncols();
8874            let off = offsets[target_atom];
8875            for mu in 0..m {
8876                for channel in 0..r {
8877                    let mut acc = 0.0_f64;
8878                    for nu in 0..m {
8879                        let s_sym =
8880                            0.5 * (atom.smooth_penalty[[mu, nu]] + atom.smooth_penalty[[nu, mu]]);
8881                        acc += s_sym * coeffs[[nu, channel]];
8882                    }
8883                    beta[off + mu * r + channel] = lambda * acc;
8884                }
8885            }
8886        } else {
8887            let mut cursor = 1 + rho.log_lambda_smooth.len();
8888            for atom in 0..rho.log_ard.len() {
8889                for axis in 0..rho.log_ard[atom].len() {
8890                    if cursor == j {
8891                        let alpha = SaeManifoldRho::stable_exp_strength(rho.log_ard[atom][axis]);
8892                        let periods = self.assignment.coords[atom].effective_axis_periods();
8893                        for row in 0..self.n_obs() {
8894                            let row_t = self.assignment.coords[atom].row(row);
8895                            let prior = ArdAxisPrior::eval(alpha, row_t[axis], periods[axis]);
8896                            let Some(pos) = sae_coord_penalty_offset(
8897                                self.last_row_layout.as_ref(),
8898                                self.assignment.coord_offsets()[atom] + axis,
8899                                row,
8900                                atom,
8901                            ) else {
8902                                continue;
8903                            };
8904                            t[cache.row_offsets[row] + pos] = prior.grad;
8905                        }
8906                        return Ok(SaeArrowVector { t, beta });
8907                    }
8908                    cursor += 1;
8909                }
8910            }
8911        }
8912        Ok(SaeArrowVector { t, beta })
8913    }
8914
8915    pub(crate) fn logdet_theta_adjoint(
8916        &self,
8917        rho: &SaeManifoldRho,
8918        cache: &ArrowFactorCache,
8919        solver: &DeflatedArrowSolver<'_>,
8920    ) -> Result<SaeArrowVector, String> {
8921        // Γ_a = tr(H⁻¹ ∂H/∂θ_a) over the inner variables θ (#1006). `H` here is
8922        // the SAME object the evidence factor builds — Gauss-Newton data
8923        // curvature plus the prior majorizers / `hessian_diag` diagonals the
8924        // Newton/Schur Cholesky factorizes — so each block's θ-derivative channel
8925        // is differentiated on the criterion's own branch (no value/gradient
8926        // desync). The IBP-MAP assignment prior is the one block whose
8927        // `hessian_diag` couples every row in a column through the plug-in
8928        // empirical mass `M_k = Σ_i z_ik`; its logit derivative therefore has a
8929        // row-local channel (handled inline via
8930        // `assignment_prior_hdiag_derivative_entry`) and a cross-row channel
8931        // (accumulated column-wise after the row loop, below).
8932        let n = self.n_obs();
8933        let total_t = cache.delta_t_len();
8934        let mut gamma_t = Array1::<f64>::zeros(total_t);
8935        let mut gamma_beta = Array1::<f64>::zeros(cache.k);
8936        let second_jets = self.atom_second_jets()?;
8937        let border = self.border_channels_for_cache(cache)?;
8938        // #932 FRONT C: plain-arrow `(H⁻¹)_ββ = S⁻¹` formed once from the cached
8939        // Schur factor; gauge / #1038 cross-row Woodbury fall back to the per-β
8940        // `solve` loop where the row-local Takahashi blocks are not valid.
8941        let fast_selected = solver.plain_selected_inverse_available();
8942        let beta_inv = if cache.k == 0 {
8943            Array2::<f64>::zeros((0, 0))
8944        } else if fast_selected {
8945            solver
8946                .beta_inv()
8947                .map_err(|err| format!("logdet_theta_adjoint: beta selected inverse: {err}"))?
8948        } else {
8949            let mut beta_inv = Array2::<f64>::zeros((cache.k, cache.k));
8950            let rhs_t = Array1::<f64>::zeros(total_t);
8951            for col in 0..cache.k {
8952                let mut rhs_beta = Array1::<f64>::zeros(cache.k);
8953                rhs_beta[col] = 1.0;
8954                let solved = solver.solve(rhs_t.view(), rhs_beta.view()).map_err(|err| {
8955                    format!("logdet_theta_adjoint: beta selected inverse solve: {err}")
8956                })?;
8957                for row in 0..cache.k {
8958                    beta_inv[[row, col]] = solved.beta[row];
8959                }
8960            }
8961            beta_inv
8962        };
8963        // IBP `hessian_diag` logit third-derivative channels (#1006). The full
8964        // IBP Hessian also has per-column cross-row rank-one terms
8965        // `H_(i,k),(j,k) = d_k·J_ik·J_jk`; these ARE carried in `H` via the #1038
8966        // Woodbury source (`IbpCrossRowSource`, construction.rs:4710-4752), the
8967        // ρ-trace differentiates them (#1416,
8968        // `assignment_log_strength_hessian_trace`), AND this θ-adjoint now
8969        // differentiates them exactly too: the empirical-`M_k` channel below
8970        // contracts the shared-mass coupling of the DIAGONAL curvature, and the
8971        // cross-row Woodbury pass (further below, using `cross_row_dd` and
8972        // `logit_curvature`) contracts the `∂/∂ℓ_w (d_k·J_ik·J_jk)` rank-one
8973        // derivative — so value, logdet, ρ-trace, and θ-adjoint all differentiate
8974        // the one operator `H = H₀ + Σ_k d_k u_k u_kᵀ`.
8975        let ibp_channels = ibp_assignment_third_channels(&self.assignment, rho)?;
8976        let k_atoms = self.k_atoms();
8977        // #1038 softmax entropy: the dense per-row entropy Hessian written into
8978        // `block.htt` has off-diagonal logit terms whose θ-derivative the adjoint
8979        // must contract too (not just the diagonal). Build the SAME penalty +
8980        // `scale = λ/τ²` the assembly uses so value/logdet/adjoint differentiate
8981        // one operator. `None` for non-softmax modes (their diagonal/cross-row
8982        // channels are handled by `assignment_prior_hdiag_derivative_entry` and
8983        // the IBP column pass).
8984        let softmax_dense_adjoint: Option<(
8985            gam_terms::analytic_penalties::SoftmaxAssignmentSparsityPenalty,
8986            f64,
8987        )> = match self.assignment.mode {
8988            AssignmentMode::Softmax {
8989                temperature,
8990                sparsity,
8991            } if k_atoms > 1 => {
8992                let inv_tau = 1.0 / temperature;
8993                let scale = rho.lambda_sparse() * sparsity * inv_tau * inv_tau;
8994                Some((
8995                    gam_terms::analytic_penalties::SoftmaxAssignmentSparsityPenalty::new(
8996                        k_atoms,
8997                        temperature,
8998                    ),
8999                    scale,
9000                ))
9001            }
9002            _ => None,
9003        };
9004        // Per active logit position: (row i, column k, global t-index,
9005        // (H⁻¹)_ik,ik) — the inputs to the IBP cross-row empirical-`M_k` channel.
9006        let mut ibp_logit_sites: Vec<(usize, usize, usize, f64)> = Vec::new();
9007
9008        // #1557 — reuse one K-sized scratch row across all N rows (alias-free).
9009        let mut assignments = Array1::<f64>::zeros(self.k_atoms());
9010        // #932 SIMD: jets are built in aligned 4-row SIMD batches through a
9011        // bounded (≤4-row) look-ahead window; unaligned / non-softmax / remainder
9012        // rows fall back to the scalar per-row path (bit-identical either way).
9013        let mut jet_window: std::collections::VecDeque<SaeRowJets> =
9014            std::collections::VecDeque::new();
9015        let mut jet_window_next = 0usize;
9016        for row in 0..n {
9017            let q = cache.row_dims[row];
9018            let base = cache.row_offsets[row];
9019            let a_scratch = assignments.as_slice_mut().expect("contiguous scratch");
9020            self.assignment
9021                .try_assignments_row_for_rho_into(row, rho, a_scratch)?;
9022            if jet_window.is_empty() {
9023                jet_window_next = self.refill_jet_window(
9024                    rho,
9025                    jet_window_next,
9026                    cache,
9027                    &second_jets,
9028                    &border,
9029                    &mut jet_window,
9030                )?;
9031            }
9032            let jets = jet_window.pop_front().expect("jet window must be non-empty");
9033
9034            // #932 FRONT C: row-local Takahashi on the plain arrow; per-row
9035            // full-system `solve` loop under gauge / cross-row Woodbury.
9036            let (inv_vv, inv_vbeta) = if fast_selected {
9037                solver
9038                    .selected_inverse_row_blocks(row, &beta_inv)
9039                    .map_err(|err| {
9040                        format!("logdet_theta_adjoint: selected inverse: {err}")
9041                    })?
9042            } else {
9043                let mut inv_vv = Array2::<f64>::zeros((q, q));
9044                let mut inv_vbeta = Array2::<f64>::zeros((q, cache.k));
9045                for col in 0..q {
9046                    let mut rhs_t = Array1::<f64>::zeros(total_t);
9047                    let rhs_beta = Array1::<f64>::zeros(cache.k);
9048                    rhs_t[base + col] = 1.0;
9049                    let solved = solver.solve(rhs_t.view(), rhs_beta.view()).map_err(|err| {
9050                        format!("logdet_theta_adjoint: selected inverse solve: {err}")
9051                    })?;
9052                    for r in 0..q {
9053                        inv_vv[[r, col]] = solved.t[base + r];
9054                    }
9055                    for b in 0..cache.k {
9056                        inv_vbeta[[col, b]] = solved.beta[b];
9057                    }
9058                }
9059                (inv_vv, inv_vbeta)
9060            };
9061
9062            // Record each active logit's column, global t-index, and
9063            // selected-inverse diagonal (H⁻¹)_ik,ik for the IBP cross-row pass.
9064            if ibp_channels.is_some() {
9065                for (pos, var) in jets.vars.iter().enumerate() {
9066                    if let SaeLocalRowVar::Logit { atom } = *var {
9067                        ibp_logit_sites.push((row, atom, base + pos, inv_vv[[pos, pos]]));
9068                    }
9069                }
9070            }
9071
9072            // #1419: when `w` is a logit and the assignment is softmax, the per-row
9073            // Gershgorin majorizer `D = diag(Σ_j|H_kj|)` is what the assembly wrote
9074            // into `htt` (the genuine Loewner majorizer that replaces the indefinite
9075            // exact entropy Hessian). Its full θ-derivative `∂D_{k,k}/∂z_w` (diagonal;
9076            // `∂D_kk/∂z_w = Σ_j sign(H_kj)·∂H_kj/∂z_w`) is the SAME operator the
9077            // assembly and logdet now differentiate, so value and adjoint stay on ONE
9078            // exact branch. Compute it once per logit `w` and add it at every logit
9079            // pair `(a,b)` below. The diagonal softmax case is therefore handled here,
9080            // NOT in `assignment_prior_hdiag_derivative_entry` (which returns 0 for
9081            // softmax to avoid double-counting).
9082            // #1410: the softmax majorizer θ-derivative `∂D_kk/∂z_w` is DIAGONAL
9083            // (`D` is diagonal), and the compact adjoint reads it only for this
9084            // row's `≤ top_k` active atoms. Compute the needed diagonal entry
9085            // directly from the softmax row `a` (= `assignments`, in hand) via
9086            // `active_softmax_majorizer_logit_derivative_entry`, instead of the old
9087            // per-(row, logit) full `K×K` `row_psd_majorizer_logit_derivative`
9088            // allocation. `m = Σ_j a_j l_j` is shared across all `(w, k)` pairs of
9089            // the row, so compute it once. `inv_tau` carries the softmax `∂a/∂z`
9090            // convention.
9091            let softmax_adjoint_row: Option<(&[f64], f64, f64, f64)> =
9092                match (softmax_dense_adjoint.as_ref(), self.assignment.mode) {
9093                    (Some((_penalty, scale)), AssignmentMode::Softmax { temperature, .. }) => {
9094                        let a = assignments
9095                            .as_slice()
9096                            .expect("softmax assignments row must be contiguous");
9097                        let m = softmax_majorizer_log_mean(a);
9098                        Some((a, m, *scale, 1.0 / temperature))
9099                    }
9100                    _ => None,
9101                };
9102            // Per-row UNIT-stiffness deflated directions: the selected inverse
9103            // `inv_vv` is the DEFLATED inverse (it assigns `1/λ̃ = 1` to each
9104            // `vᵢ`), so every `inv_vv`-weighted t–t contraction of `∂H/∂θ_w`
9105            // below spuriously contracts the RAW derivative where the re-deflating
9106            // criterion uses the deflation-map derivative `DΦ`. The kept-subspace Γ
9107            // subtracts `tr(inv_vv·(D − DΦ[D]))` over the t–t block via the same
9108            // Daleckii–Krein helper the ρ-traces use (the t–β / β–β blocks are not
9109            // deflated). `θ` enters only the per-row block (no cross-row Woodbury
9110            // self-downdate on the θ path), so the raw t–t derivative `D` is used
9111            // directly.
9112            let defl_dirs = cache
9113                .deflated_row_directions
9114                .get(row)
9115                .map(Vec::as_slice)
9116                .unwrap_or(&[]);
9117            let defl_spectrum = cache
9118                .deflation_row_spectra
9119                .get(row)
9120                .and_then(Option::as_ref);
9121            for w in 0..q {
9122                let mut gamma = 0.0_f64;
9123                // The active logit `w` differentiates against; `None` unless this
9124                // slot is a softmax logit on the softmax path.
9125                let softmax_d_dw: Option<(&[f64], f64, f64, f64, usize)> =
9126                    match (softmax_adjoint_row, jets.vars[w]) {
9127                        (Some((a, m, scale, inv_tau)), SaeLocalRowVar::Logit { atom: atom_w }) => {
9128                            Some((a, m, scale, inv_tau, atom_w))
9129                        }
9130                        _ => None,
9131                    };
9132                let mut dh_mat = Array2::<f64>::zeros((q, q));
9133                for a in 0..q {
9134                    for b in 0..q {
9135                        let mut dh = sae_dot(&jets.second[a][w], &jets.first[b])
9136                            + sae_dot(&jets.first[a], &jets.second[b][w]);
9137                        // `∂D/∂z_w` is diagonal, so it contributes only when the two
9138                        // logit slots are the SAME atom (`atom_a == atom_b`).
9139                        if let (
9140                            Some((a_soft, m, scale, inv_tau, _atom_w)),
9141                            SaeLocalRowVar::Logit { atom: atom_a },
9142                            SaeLocalRowVar::Logit { atom: atom_b },
9143                        ) = (softmax_d_dw, jets.vars[a], jets.vars[b])
9144                        {
9145                            if atom_a == atom_b {
9146                                dh += active_softmax_majorizer_logit_derivative_entry(
9147                                    a_soft, atom_a, _atom_w, m, scale, inv_tau,
9148                                );
9149                            }
9150                        }
9151                        if a == b {
9152                            dh += match jets.vars[a] {
9153                                SaeLocalRowVar::Logit { atom } => self
9154                                    .assignment_prior_hdiag_derivative_entry(
9155                                        rho,
9156                                        row,
9157                                        atom,
9158                                        jets.vars[w],
9159                                        ibp_channels.as_ref(),
9160                                    ),
9161                                SaeLocalRowVar::Coord { atom, axis } if a == w => {
9162                                    self.ard_majorized_hessian_derivative(rho, row, atom, axis)
9163                                }
9164                                _ => 0.0,
9165                            };
9166                        }
9167                        dh_mat[[a, b]] = dh;
9168                        gamma += inv_vv[[b, a]] * dh;
9169                    }
9170                }
9171                if !defl_dirs.is_empty() {
9172                    gamma -= Self::deflation_block_correction(
9173                        &inv_vv, &dh_mat, defl_dirs, defl_spectrum,
9174                    );
9175                }
9176                for a in 0..q {
9177                    for (beta_pos, channel) in border.iter().enumerate() {
9178                        let dh = sae_dot(&jets.second[a][w], &jets.beta[beta_pos])
9179                            + sae_dot(&jets.first[a], &jets.beta_deriv[w][beta_pos]);
9180                        gamma += 2.0 * inv_vbeta[[a, channel.index]] * dh;
9181                    }
9182                }
9183                for (beta_i, channel_i) in border.iter().enumerate() {
9184                    for (beta_j, channel_j) in border.iter().enumerate() {
9185                        let dh = sae_dot(&jets.beta_deriv[w][beta_i], &jets.beta[beta_j])
9186                            + sae_dot(&jets.beta[beta_i], &jets.beta_deriv[w][beta_j]);
9187                        gamma += beta_inv[[channel_i.index, channel_j.index]] * dh;
9188                    }
9189                }
9190                gamma_t[base + w] = gamma;
9191            }
9192
9193            for (w_beta_pos, w_channel) in border.iter().enumerate() {
9194                let mut gamma = 0.0_f64;
9195                let mut dh_mat = Array2::<f64>::zeros((q, q));
9196                for a in 0..q {
9197                    for b in 0..q {
9198                        let dh = sae_dot(&jets.beta_l_deriv[a][w_beta_pos], &jets.first[b])
9199                            + sae_dot(&jets.first[a], &jets.beta_l_deriv[b][w_beta_pos]);
9200                        dh_mat[[a, b]] = dh;
9201                        gamma += inv_vv[[b, a]] * dh;
9202                    }
9203                }
9204                if !defl_dirs.is_empty() {
9205                    gamma -= Self::deflation_block_correction(
9206                        &inv_vv, &dh_mat, defl_dirs, defl_spectrum,
9207                    );
9208                }
9209                for a in 0..q {
9210                    for (beta_pos, channel) in border.iter().enumerate() {
9211                        let dh = sae_dot(&jets.beta_l_deriv[a][w_beta_pos], &jets.beta[beta_pos]);
9212                        gamma += 2.0 * inv_vbeta[[a, channel.index]] * dh;
9213                    }
9214                }
9215                gamma_beta[w_channel.index] += gamma;
9216            }
9217        }
9218
9219        // IBP cross-row empirical-`M_k` channel of Γ (#1006). The assembled
9220        // diagonal H_ik consumes `hessian_diag`, whose dependence on the column
9221        // mass M_k = Σ_i z_ik couples every row in a column. Differentiating
9222        // tr(H⁻¹ ∂H/∂ℓ_wk) on that shared branch:
9223        //   Γ_wk += [ Σ_i (H⁻¹)_ik,ik · ∂_M H_ik ] · J_wk = C_k · J_wk,
9224        // where ∂_M H_ik = `m_channel[i*K+k]` and J_wk = `z_jac[w*K+k]`. The
9225        // row-local direct-`z` channel was already added inline above, so this
9226        // pass adds only the cross-row remainder (it spans `w ≠ i` and the
9227        // self-row M_k self-coupling, which the row-local primitive deliberately
9228        // omits to avoid double-counting).
9229        if let Some(channels) = ibp_channels.as_ref() {
9230            let mut col_coeff = vec![0.0_f64; k_atoms];
9231            for &(row, atom, _t_index, inv_diag) in &ibp_logit_sites {
9232                col_coeff[atom] += inv_diag * channels.m_channel[row * k_atoms + atom];
9233            }
9234            for &(row, atom, t_index, _inv_diag) in &ibp_logit_sites {
9235                gamma_t[t_index] += col_coeff[atom] * channels.z_jac[row * k_atoms + atom];
9236            }
9237
9238            // #1416 / #1641: the EXACT cross-row Woodbury derivative of Γ. The
9239            // assembled `H` carries the per-column rank-one block
9240            // `W_k = d_k·u_k u_kᵀ` with `u_k` the J-weighted column indicator
9241            // (`u_k[slot(i,k)] = J_ik`) and `d_k = w·s'_k` (`cross_row_d[k]`). Both
9242            // `d_k` (through `M_k`) and the `u_k` entries (through `ℓ_ik`) depend on
9243            // the logits, so
9244            //   ∂W_k/∂ℓ_wk = dd_k·J_wk·u_k u_kᵀ
9245            //               + d_k·c_wk·(e_w u_kᵀ + u_k e_wᵀ),
9246            // where `dd_k = ∂d_k/∂M_k = w·s''_k` (`cross_row_dd[k]`),
9247            // `c_wk = ∂J_wk/∂ℓ_wk` (`logit_curvature`), and `e_w` is the unit
9248            // vector at row `w`'s logit-`k` slot.
9249            //
9250            // The θ-adjoint contracts the FULL trace `Γ_wk = tr(H⁻¹ ∂H/∂ℓ_wk)`
9251            // (NOT the `½ tr` the ρ-trace uses — `fixed_state_logdet` differentiates
9252            // the full `log|H|`, and the per-row blocks above contract `inv_vv·dh`
9253            // with no ½). Critically, the `i=j` self curvature `w·s'_k·J_ik²` of the
9254            // rank-one block lives on the assembled `htt` DIAGONAL `H_ik`, so its
9255            // derivative is ALREADY differentiated by the row-local
9256            // `local_logit_third` channel (direct-z, `i=w`) and the `m_channel`
9257            // column pass (via `M_k`) above. This Woodbury pass must therefore add
9258            // ONLY the off-diagonal `i≠j` remainder — otherwise the self term is
9259            // double-counted (the #1641 defect: the pre-fix pass summed the full
9260            // `u_k u_kᵀ` including `i=j`, AND carried the ρ-trace ½, AND dropped the
9261            // factor 2 on the symmetric `e_w u_kᵀ + u_k e_wᵀ` term). Excluding `i=j`
9262            // is also why this pass needs no deflation correction: it contracts only
9263            // DISTINCT rows, off any single-row `vᵢ`'s support (matching the
9264            // #1416 ρ-trace cross-row pass).
9265            //
9266            // Contracting `tr(H⁻¹ ∂W_k/∂ℓ_wk)` over `i≠j` only:
9267            //   Γ_wk += dd_k·J_wk·( u_kᵀ H⁻¹ u_k − Σ_i P_ii·J_ik² )       (term A)
9268            //         + 2·d_k·c_wk·( (H⁻¹ u_k)_{slot(w,k)} − P_ww·J_wk )  (term B),
9269            // where `P_ii = (H⁻¹)_{slot(i,k),slot(i,k)}` is the selected-inverse
9270            // diagonal recorded in `ibp_logit_sites`. The subtracted self pieces are
9271            // exactly the `i=j` terms the diagonal channels own. Both `u_kᵀ H⁻¹ u_k`
9272            // and `(H⁻¹ u_k)` come from ONE solve per column, `x_k = H⁻¹ u_k` — so
9273            // the adjoint differentiates the SAME `H = H₀ + Σ_k W_k` the
9274            // value/logdet use, closing the one-operator contract on the rank-one
9275            // block too.
9276            //
9277            // Group the column sites once (the layout is mode-agnostic: dense or
9278            // compact, `ibp_logit_sites` already carries each active logit's
9279            // global t-index AND its selected-inverse diagonal `G_ii`), then per
9280            // column build `u_k`, solve, and distribute the OFF-DIAGONAL remainder.
9281            //
9282            // #1416 FIX: the diagonal (`i = w`) parts of term A and term B are
9283            // ALREADY supplied — `diag(term A) = dd_k·J_w·Σ_i G_ii·J_i²` by the
9284            // `m_channel` column pass above (whose `m_channel = w·(s''·J² + s'·c)`
9285            // carries the `s''·J²` self piece), and `diag(term B) = 2·d_k·c_w·G_ww·J_w`
9286            // by the inline `local_logit_third` self channel (whose
9287            // `s'·2J·∂_z J` piece is exactly that). So this pass must add ONLY the
9288            // cross-row off-diagonal remainder; double-counting the diagonal here
9289            // (the pre-fix `0.5·dd·J·uᵀGu + d·c·x_w` form, which is neither the
9290            // full nor the off-diagonal value) desynced the θ-adjoint from the FD
9291            // of `log|H|`. The exact `tr(H⁻¹ ∂W_k/∂ℓ_wk)` is
9292            //   Γ_wk += dd_k·J_wk·(uᵀ G u − Σ_i G_ii·J_ik²)   (term A, off-diagonal)
9293            //         + 2·d_k·c_wk·((G u)_w − G_ww·J_wk)        (term B, off-diagonal),
9294            // with `uᵀGu = Σ_i J_ik·(Gu)_i`, `(Gu) = x_k = H⁻¹ u_k` from one solve,
9295            // and `G_ii` the per-site selected-inverse diagonal.
9296            let total_t = cache.delta_t_len();
9297            let mut col_sites: Vec<Vec<(usize, usize, f64)>> = vec![Vec::new(); k_atoms];
9298            for &(row, atom, t_index, inv_diag) in &ibp_logit_sites {
9299                col_sites[atom].push((row, t_index, inv_diag));
9300            }
9301            for atom in 0..k_atoms {
9302                let d_k = channels.cross_row_d[atom];
9303                let dd_k = channels.cross_row_dd[atom];
9304                if col_sites[atom].is_empty() || (d_k == 0.0 && dd_k == 0.0) {
9305                    continue;
9306                }
9307                // u_k as a full t-RHS: J at each active logit-k slot.
9308                let mut rhs_t = Array1::<f64>::zeros(total_t);
9309                let rhs_beta = Array1::<f64>::zeros(cache.k);
9310                for &(row, t_index, _g) in &col_sites[atom] {
9311                    rhs_t[t_index] = channels.z_jac[row * k_atoms + atom];
9312                }
9313                let x_k = solver.solve(rhs_t.view(), rhs_beta.view()).map_err(|err| {
9314                    format!("logdet_theta_adjoint: IBP cross-row Woodbury solve: {err}")
9315                })?;
9316                // (JᵀH⁻¹J)_k = u_kᵀ x_k, and the diagonal `Σ_i G_ii·J_ik²` that the
9317                // `m_channel` pass already counted (subtract it from term A so this
9318                // pass holds only the off-diagonal `i ≠ j` remainder).
9319                let mut jt_hinv_j = 0.0_f64;
9320                let mut diag_jt_g_j = 0.0_f64;
9321                for &(row, t_index, g_ii) in &col_sites[atom] {
9322                    let j = channels.z_jac[row * k_atoms + atom];
9323                    jt_hinv_j += j * x_k.t[t_index];
9324                    diag_jt_g_j += g_ii * j * j;
9325                }
9326                let off_diag_a = jt_hinv_j - diag_jt_g_j;
9327                for &(row, t_index, g_ii) in &col_sites[atom] {
9328                    let j_wk = channels.z_jac[row * k_atoms + atom];
9329                    let c_wk = channels.logit_curvature[row * k_atoms + atom];
9330                    // term A (off-diagonal) + term B (off-diagonal); the inline /
9331                    // `m_channel` passes already added the diagonal parts.
9332                    let off_diag_b = x_k.t[t_index] - g_ii * j_wk;
9333                    gamma_t[t_index] += dd_k * j_wk * off_diag_a + 2.0 * d_k * c_wk * off_diag_b;
9334                }
9335            }
9336        }
9337
9338        Ok(SaeArrowVector {
9339            t: gamma_t,
9340            beta: gamma_beta,
9341        })
9342    }
9343
9344    /// #1418: apply the EXACT stationarity-Jacobian correction `ΔC·v = (A − B)·v`
9345    /// to a joint `(t, β)` vector, matrix-free and per row.
9346    ///
9347    /// `A = ∇²_θθ L` is the true inner-fit Hessian; `B` is the assembled
9348    /// evidence/Newton operator the solver factors. They differ ONLY by the three
9349    /// curvature substitutions the assembly makes for stability:
9350    ///   1. data: `B` uses Gauss-Newton `J̃J̃ᵀ`, dropping the residual curvature
9351    ///      `R[a,b] = Σ_out r_out·∂²f_out/∂θ_a∂θ_b` (t–t via `jets.second`, t–β via
9352    ///      `jets.beta_deriv`; the decoder is linear in β so the β–β block is 0);
9353    ///   2. softmax: `B` uses the Gershgorin majorizer `D = diag(Σ_j|H_kj|)`,
9354    ///      dropping `H_entropy − D` (#1419);
9355    ///   3. periodic ARD: `B` uses `max(V'',0)`, dropping the negative part
9356    ///      `min(V'',0)` (the indefinite tail past a quarter period).
9357    /// `ΔC` is the sum of exactly these three deltas, each built from the SAME
9358    /// jets / penalty curvatures the assembly and the θ-adjoint use, so
9359    /// `A = B + ΔC` is the one true Hessian. Exact on BOTH the isotropic and the
9360    /// whitened-metric paths: the data fit is `½ r_nᵀ M_n r_n`, so the residual
9361    /// curvature is `Σ_out (M_n r_n)_out·∂²f_out/∂θ_a∂θ_b` — contract the
9362    /// metric-applied √w-scaled residual `error_metric = √w·M_n r_n` (the SAME
9363    /// quantity the assembly's β-tier gradient uses) against the RAW second jets
9364    /// `jets.second`/`jets.beta_deriv` (the same raw-jet convention the whole
9365    /// θ-adjoint and the Gauss-Newton `htt = J̃J̃ᵀ = J M Jᵀ` assembly use). On the
9366    /// isotropic path `M_n = I` so `error_metric = √w·r` and `J M Jᵀ = JJᵀ`,
9367    /// recovering the plain case. The softmax / ARD deltas are logit/coord-space
9368    /// prior curvatures and carry no output metric, so they are path-independent.
9369    fn apply_exact_hessian_minus_b(
9370        &self,
9371        rho: &SaeManifoldRho,
9372        target: ArrayView2<'_, f64>,
9373        cache: &ArrowFactorCache,
9374        v: &SaeArrowVector,
9375    ) -> Result<SaeArrowVector, String> {
9376        let p = self.output_dim();
9377        let n = self.n_obs();
9378        let k_atoms = self.k_atoms();
9379        let total_t = cache.delta_t_len();
9380        let second_jets = self.atom_second_jets()?;
9381        let border = self.border_channels_for_cache(cache)?;
9382        let row_loss_w = self.row_loss_weights.as_deref();
9383        let ard_axis_periods: Vec<Vec<Option<f64>>> = self
9384            .assignment
9385            .coords
9386            .iter()
9387            .map(|coord| coord.effective_axis_periods())
9388            .collect();
9389
9390        // Optional softmax exact-entropy-minus-majorizer delta operator (#1419).
9391        let softmax_delta: Option<(
9392            gam_terms::analytic_penalties::SoftmaxAssignmentSparsityPenalty,
9393            f64,
9394        )> = match self.assignment.mode {
9395            AssignmentMode::Softmax {
9396                temperature,
9397                sparsity,
9398            } if k_atoms > 1 => {
9399                let inv_tau = 1.0 / temperature;
9400                let scale = rho.lambda_sparse() * sparsity * inv_tau * inv_tau;
9401                Some((
9402                    gam_terms::analytic_penalties::SoftmaxAssignmentSparsityPenalty::new(
9403                        k_atoms,
9404                        temperature,
9405                    ),
9406                    scale,
9407                ))
9408            }
9409            _ => None,
9410        };
9411
9412        let mut out = SaeArrowVector {
9413            t: Array1::<f64>::zeros(total_t),
9414            beta: Array1::<f64>::zeros(cache.k),
9415        };
9416        let whitens = self
9417            .row_metric
9418            .as_ref()
9419            .is_some_and(|metric| metric.whitens_likelihood());
9420        let mut decoded = vec![0.0_f64; p];
9421        let mut fitted = Array1::<f64>::zeros(p);
9422        let mut error = Array1::<f64>::zeros(p);
9423        // #1557 — reuse one K-sized scratch row across all N rows (alias-free).
9424        let mut assignments = Array1::<f64>::zeros(self.k_atoms());
9425        // #932 SIMD: jets are built in aligned 4-row SIMD batches through a
9426        // bounded (≤4-row) look-ahead window; unaligned / non-softmax / remainder
9427        // rows fall back to the scalar per-row path (bit-identical either way).
9428        let mut jet_window: std::collections::VecDeque<SaeRowJets> =
9429            std::collections::VecDeque::new();
9430        let mut jet_window_next = 0usize;
9431        for row in 0..n {
9432            let q = cache.row_dims[row];
9433            let base = cache.row_offsets[row];
9434            let a_scratch = assignments.as_slice_mut().expect("contiguous scratch");
9435            self.assignment
9436                .try_assignments_row_for_rho_into(row, rho, a_scratch)?;
9437            if jet_window.is_empty() {
9438                jet_window_next = self.refill_jet_window(
9439                    rho,
9440                    jet_window_next,
9441                    cache,
9442                    &second_jets,
9443                    &border,
9444                    &mut jet_window,
9445                )?;
9446            }
9447            let jets = jet_window.pop_front().expect("jet window must be non-empty");
9448            let sqrt_row_w = row_loss_w.map_or(1.0, |w| w[row].sqrt());
9449
9450            // √w-scaled metric-applied per-row residual `error_metric = √w·M_n r_n`
9451            // (the SAME object the assembly's β-tier gradient contracts). The
9452            // data-fit `½ r_nᵀ M_n r_n` has residual curvature `Σ (M_n r_n)·∂²f`,
9453            // so this is exactly the residual contracted against the raw `∂²f`
9454            // jets. `M_n = I` on the isotropic path ⇒ `error_metric = √w·r`.
9455            fitted.fill(0.0);
9456            for k in 0..k_atoms {
9457                self.atoms[k].fill_decoded_row(row, &mut decoded);
9458                let a_k = assignments[k];
9459                for out_col in 0..p {
9460                    fitted[out_col] += a_k * decoded[out_col];
9461                }
9462            }
9463            for out_col in 0..p {
9464                error[out_col] = sqrt_row_w * (fitted[out_col] - target[[row, out_col]]);
9465            }
9466            let error_metric: Vec<f64> = match self.row_metric.as_ref() {
9467                Some(metric) if whitens => metric.apply_metric_row(row, error.view()),
9468                _ => error.to_vec(),
9469            };
9470
9471            // Local t-slice of `v` for this row.
9472            let v_t: Vec<f64> = (0..q).map(|c| v.t[base + c]).collect();
9473
9474            // (1a) residual curvature, t–t: ΔC_tt[a,b] = ⟨r, ∂²f_ab⟩.
9475            for a in 0..q {
9476                let mut acc = 0.0_f64;
9477                for b in 0..q {
9478                    let r_ab = sae_dot(&error_metric, &jets.second[a][b]);
9479                    acc += r_ab * v_t[b];
9480                }
9481                out.t[base + a] += acc;
9482            }
9483            // (1b) residual curvature, t–β and β–t: ΔC_tβ[a,β] = ⟨r, ∂²f_aβ⟩.
9484            //      `jets.beta_deriv[a][β]` = ∂(∂f/∂β_β)/∂θ_a (the mixed second jet).
9485            for a in 0..q {
9486                for (beta_pos, channel) in border.iter().enumerate() {
9487                    let r_ab = sae_dot(&error_metric, &jets.beta_deriv[a][beta_pos]);
9488                    // t row picks up β leg of v; β row picks up t leg of v.
9489                    out.t[base + a] += r_ab * v.beta[channel.index];
9490                    out.beta[channel.index] += r_ab * v_t[a];
9491                }
9492            }
9493
9494            // (2) softmax: ΔC_logit = (H_entropy − D) over the free logits, where
9495            // `D = diag(Σ_j|H_kj|)` is the Gershgorin majorizer the assembled `B`
9496            // wrote into the logit block (#1419). Adding `H_entropy − D` recovers the
9497            // EXACT entropy curvature `A = B + ΔC`, so the solver's exact-Hessian
9498            // correction differentiates the SAME operator the assembly installed.
9499            if let Some((_penalty, scale)) = softmax_delta.as_ref() {
9500                let assignment_dim = self.assignment.assignment_coord_dim();
9501                // #1410: the correction only contracts the ACTIVE logit slots
9502                // (`jets.vars` carries the row's `≤ top_k` active atoms on the
9503                // compact layout), so build only the active sub-block of
9504                // `ΔC = H_entropy − D` ENTRY-WISE rather than materialising the
9505                // full `K×K` `row_dense_hessian` / `row_psd_majorizer` matrices per
9506                // row (an `O(K²)`-per-row allocation that defeated the compact
9507                // contract at the LLM shape). `D` is diagonal, so it subtracts only
9508                // on `ka == kb`; the off-diagonal `H_entropy` entries come from the
9509                // shared `(a, l, m)` algebra. The softmax row `a_soft` is the one
9510                // irreducible `O(K)` term, computed once per row.
9511                // #1557 — reuse this iteration's `assignments` (bit-identical).
9512                let a_soft = assignments
9513                    .as_slice()
9514                    .expect("softmax assignments row must be contiguous");
9515                let m = softmax_majorizer_log_mean(a_soft);
9516                for (a, va) in jets.vars.iter().enumerate() {
9517                    let SaeLocalRowVar::Logit { atom: ka } = *va else {
9518                        continue;
9519                    };
9520                    if ka >= assignment_dim {
9521                        continue;
9522                    }
9523                    let mut acc = 0.0_f64;
9524                    for (b, vb) in jets.vars.iter().enumerate() {
9525                        let SaeLocalRowVar::Logit { atom: kb } = *vb else {
9526                            continue;
9527                        };
9528                        if kb >= assignment_dim {
9529                            continue;
9530                        }
9531                        let h_entropy =
9532                            softmax_dense_entropy_hessian_entry(a_soft, ka, kb, m, *scale);
9533                        // `D` is the diagonal Gershgorin majorizer (#1419), so it
9534                        // contributes only on the diagonal `ka == kb`.
9535                        let delta = if ka == kb {
9536                            h_entropy
9537                                - active_softmax_gershgorin_majorizer_entry(a_soft, ka, m, *scale)
9538                        } else {
9539                            h_entropy
9540                        };
9541                        acc += delta * v_t[b];
9542                    }
9543                    out.t[base + a] += acc;
9544                }
9545            }
9546
9547            // (3) periodic ARD: ΔC_coord = (V'' − max(V'',0)) = min(V'',0), diagonal.
9548            for (a, va) in jets.vars.iter().enumerate() {
9549                let SaeLocalRowVar::Coord { atom, axis } = *va else {
9550                    continue;
9551                };
9552                if rho.log_ard[atom].is_empty() {
9553                    continue;
9554                }
9555                let alpha = SaeManifoldRho::stable_exp_strength(rho.log_ard[atom][axis]);
9556                let t_val = self.assignment.coords[atom].row(row)[axis];
9557                let prior = ArdAxisPrior::eval(alpha, t_val, ard_axis_periods[atom][axis]);
9558                let neg = prior.hess.min(0.0);
9559                if neg != 0.0 {
9560                    out.t[base + a] += neg * v_t[a];
9561                }
9562            }
9563        }
9564        Ok(out)
9565    }
9566
9567    /// #1418: matrix-free apply of the EXACT stationarity Jacobian `A = ∇²_θθ L`:
9568    /// `A v = B v + ΔC v`, the assembled arrow Hessian apply
9569    /// ([`apply_cached_arrow_hessian`]) plus the matrix-free dropped-curvature
9570    /// correction `ΔC = A − B` ([`Self::apply_exact_hessian_minus_b`]).
9571    fn apply_exact_hessian(
9572        &self,
9573        rho: &SaeManifoldRho,
9574        target: ArrayView2<'_, f64>,
9575        cache: &ArrowFactorCache,
9576        v: &SaeArrowVector,
9577    ) -> Result<SaeArrowVector, String> {
9578        let b_v = apply_cached_arrow_hessian(cache, v.t.view(), v.beta.view())?;
9579        let dc_v = self.apply_exact_hessian_minus_b(rho, target, cache, v)?;
9580        Ok(SaeArrowVector {
9581            t: &b_v.t + &dc_v.t,
9582            beta: &b_v.beta + &dc_v.beta,
9583        })
9584    }
9585
9586    /// #1418: solve `A x = rhs` for the EXACT stationarity Jacobian `A = ∇²_θθ L`
9587    /// via `B`-preconditioned CG ([`solve_b_preconditioned_cg`]) with the
9588    /// matrix-free `A v = B v + ΔC v` apply ([`Self::apply_exact_hessian`]). The
9589    /// IFT step `θ̂_ρ = −A⁻¹ g_ρ` must invert the EXACT `A`, not the surrogate `B`;
9590    /// CG converges for any `ρ(B⁻¹ΔC)`, where the earlier Neumann series diverged
9591    /// once the dropped curvature `ΔC = ⟨r, ∂²f⟩` grew (large unmodellable residual).
9592    fn solve_exact_stationarity(
9593        &self,
9594        rho: &SaeManifoldRho,
9595        target: ArrayView2<'_, f64>,
9596        cache: &ArrowFactorCache,
9597        solver: &DeflatedArrowSolver<'_>,
9598        rhs: &SaeArrowVector,
9599    ) -> Result<SaeArrowVector, String> {
9600        solve_b_preconditioned_cg(solver, rhs, |v| {
9601            self.apply_exact_hessian(rho, target, cache, v)
9602        })
9603    }
9604
9605    /// Analytic SAE REML outer-ρ gradient components at the already converged
9606    /// inner state represented by `loss` and `cache`.
9607    ///
9608    /// The returned gradient is the assembled analytic outer derivative:
9609    /// explicit penalty terms, direct logdet traces, Occam terms, and the #1006
9610    /// implicit-state third-order correction.
9611    pub(crate) fn analytic_outer_rho_gradient_components(
9612        &self,
9613        target: ArrayView2<'_, f64>,
9614        rho: &SaeManifoldRho,
9615        loss: &SaeManifoldLoss,
9616        cache: &ArrowFactorCache,
9617        solver: &DeflatedArrowSolver<'_>,
9618    ) -> Result<SaeOuterRhoGradientComponents, OuterGradientError> {
9619        let n_params = rho.to_flat().len();
9620        let mut explicit = Array1::<f64>::zeros(n_params);
9621        let mut logdet_trace = Array1::<f64>::zeros(n_params);
9622        let mut occam = Array1::<f64>::zeros(n_params);
9623        let mut third_order_correction = Array1::<f64>::zeros(n_params);
9624
9625        explicit[0] = assignment_prior_log_strength_derivative(&self.assignment, rho)
9626            + self
9627                .learnable_ibp_forward_alpha_data_derivative(rho, target)
9628                .map_err(OuterGradientError::internal)?;
9629        // #1417: the FULL `½ tr(H⁻¹ ∂H/∂logα)` for the assignment coordinate.
9630        // For LEARNABLE IBP alpha the forward assignments `a_ik = σ(ℓ/τ)·π_k(α)`
9631        // carry an explicit α-dependence (`∂logπ_k/∂logα = k/(α+1)`), so BOTH the
9632        // assignment-prior Hessian AND the data Gauss-Newton blocks
9633        // `H_ββ`, `H_tβ`, `H_tt` depend on logα. We assemble both traces:
9634        //   • prior:  `assignment_log_strength_hessian_trace`,
9635        //   • data:   `learnable_ibp_data_logdet_alpha_trace` (#1417), using the
9636        //             exact `(k_a+k_b)/(α+1)` block-scaling identity.
9637        // For FIXED alpha (and non-IBP modes) the data term is identically zero,
9638        // so the fixed-alpha gradient is unchanged and exact.
9639        logdet_trace[0] = self
9640            .assignment_log_strength_hessian_trace(rho, cache, solver)
9641            .map_err(OuterGradientError::internal)?
9642            + self
9643                .learnable_ibp_data_logdet_alpha_trace(rho, cache, solver)
9644                .map_err(OuterGradientError::internal)?;
9645
9646        // #1556: λ_smooth is per-atom, so the smoothness gradient block occupies
9647        // flat indices `1..1+K` (one per atom), not a single index 1. Each atom
9648        // `k` carries its own explicit penalty-energy derivative, log|H| trace,
9649        // and Occam-normalizer derivative.
9650        let k_smooth = rho.log_lambda_smooth.len();
9651        let lambda_smooth_vec = rho.lambda_smooth_vec();
9652        // Explicit `∂loss.smoothness/∂log λ_k = 0.5·λ_k·<B_k, S_k B_k>` (the
9653        // per-atom split). Its sum is the λ-scaled penalty energy; renormalize to
9654        // `loss.smoothness` so the total matches the criterion's reported energy
9655        // bit-for-bit (folding in any minibatch `penalty_scale` baked into it).
9656        let mut smooth_explicit = self.decoder_smoothness_value_per_atom(&lambda_smooth_vec);
9657        let smooth_explicit_sum: f64 = smooth_explicit.iter().sum();
9658        if smooth_explicit_sum.abs() > 0.0 {
9659            let renorm = loss.smoothness / smooth_explicit_sum;
9660            for v in smooth_explicit.iter_mut() {
9661                *v *= renorm;
9662            }
9663        }
9664        let smooth_logdet = self
9665            .decoder_smoothness_effective_dof_with_solver_per_atom(
9666                cache,
9667                solver,
9668                &lambda_smooth_vec,
9669            )
9670            .map_err(|err| OuterGradientError::InternalInvariant {
9671                reason: format!("analytic_outer_rho_gradient_components: {err}"),
9672            })?;
9673        let smooth_occam = self
9674            .reml_occam_log_lambda_smooth_derivative()
9675            .map_err(OuterGradientError::internal)?;
9676        for atom_idx in 0..k_smooth {
9677            explicit[1 + atom_idx] = smooth_explicit[atom_idx];
9678            logdet_trace[1 + atom_idx] = 0.5 * smooth_logdet[atom_idx];
9679            occam[1 + atom_idx] = -smooth_occam[atom_idx];
9680        }
9681
9682        let ard_explicit = self
9683            .ard_log_precision_explicit_derivatives(rho)
9684            .map_err(OuterGradientError::internal)?;
9685        let ard_trace = self
9686            .ard_log_precision_hessian_trace(rho, cache, solver)
9687            .map_err(|err| OuterGradientError::InternalInvariant {
9688                reason: format!("analytic_outer_rho_gradient_components: {err}"),
9689            })?;
9690        let mut cursor = 1 + k_smooth;
9691        for k in 0..rho.log_ard.len() {
9692            for axis in 0..rho.log_ard[k].len() {
9693                explicit[cursor] = ard_explicit[k][axis];
9694                logdet_trace[cursor] = ard_trace[k][axis];
9695                cursor += 1;
9696            }
9697        }
9698
9699        let gamma = self
9700            .logdet_theta_adjoint(rho, cache, solver)
9701            .map_err(OuterGradientError::internal)?;
9702        // #1418: the implicit-function correction is `−½·Γᵀ·θ̂_ρ` with
9703        // `θ̂_ρ = −A⁻¹ g_ρ`, where `A = ∇²_θθ L` is the EXACT stationarity
9704        // Jacobian of the inner fit — data residual curvature, exact softmax
9705        // entropy Hessian, exact periodic ARD curvature. The matrix the `solver`
9706        // factors is `B` (Gauss-Newton data curvature, softmax Fisher metric,
9707        // `max(V'',0)` ARD majorizers): the `½log|B|` Laplace term is consistent
9708        // with `Γ = ½tr(B⁻¹ ∂B/∂θ)`, but the implicit step is governed by `A`.
9709        // `solve_exact_stationarity` applies the TRUE `A⁻¹` via a B⁻¹-
9710        // preconditioned Neumann fixed point (`A = B + ΔC`,
9711        // `ΔC = apply_exact_hessian_minus_b`), so the correction is no longer
9712        // biased by `(B⁻¹ − A⁻¹)`.
9713        for coord in 0..n_params {
9714            let rhs = self
9715                .outer_rho_gradient_ift_rhs(rho, target, coord, cache)
9716                .map_err(OuterGradientError::internal)?;
9717            let solved = self
9718                .solve_exact_stationarity(rho, target, cache, solver, &rhs)
9719                .map_err(OuterGradientError::internal)?;
9720            let mut dot = 0.0_f64;
9721            for idx in 0..gamma.t.len() {
9722                dot += gamma.t[idx] * solved.t[idx];
9723            }
9724            for idx in 0..gamma.beta.len() {
9725                dot += gamma.beta[idx] * solved.beta[idx];
9726            }
9727            third_order_correction[coord] = -0.5 * dot;
9728        }
9729
9730        Ok(SaeOuterRhoGradientComponents {
9731            explicit,
9732            logdet_trace,
9733            occam,
9734            third_order_correction,
9735        })
9736    }
9737
9738    /// Public analytic outer-ρ gradient at a converged inner state, constructing
9739    /// the deflated arrow solver from the supplied cache. Use this seam from
9740    /// integration tests and external consumers that have a converged
9741    /// `(loss, cache)` from [`Self::reml_criterion_with_cache`] but no access to
9742    /// the crate-private `DeflatedArrowSolver`.
9743    pub fn analytic_outer_rho_gradient_at_converged(
9744        &self,
9745        target: ArrayView2<'_, f64>,
9746        rho: &SaeManifoldRho,
9747        loss: &SaeManifoldLoss,
9748        cache: &ArrowFactorCache,
9749    ) -> Result<SaeOuterRhoGradientComponents, String> {
9750        let solver = self.outer_gradient_arrow_solver(cache, &rho.lambda_smooth_vec())?;
9751        self.analytic_outer_rho_gradient_components(target, rho, loss, cache, &solver)
9752            .map_err(|e| e.to_string())
9753    }
9754
9755    /// Compose the SAE LAML criterion as a sum of atoms (#931 SAE pilot).
9756    ///
9757    /// This is the single seam that establishes value↔gradient coherence for
9758    /// the SAE objective: it runs the inner solve once via
9759    /// [`Self::reml_criterion_with_cache`], reads the value decomposition
9760    /// (`loss.total() + extra_penalty_energy`, `log|H|`, `occam`) and the
9761    /// matching gradient channels (`SaeOuterRhoGradientComponents`) from the
9762    /// SAME converged cache, and hands them to [`SaeCriterion::assemble`]. The
9763    /// returned criterion's [`SaeCriterion::value`] and
9764    /// [`SaeCriterion::gradient`] are then projections of one factorization —
9765    /// the outer optimizer can no longer evaluate a value path and a gradient
9766    /// path that disagree (the #752/#748/#901 desync class). The
9767    /// implicit-stationarity envelope correction (#1006's Γ term) is its own
9768    /// named atom, so the channel the desync class keeps dropping is visible
9769    /// rather than a silent zero.
9770    pub fn criterion_as_atoms(
9771        &mut self,
9772        target: ArrayView2<'_, f64>,
9773        rho: &SaeManifoldRho,
9774        registry: Option<&AnalyticPenaltyRegistry>,
9775        inner_max_iter: usize,
9776        learning_rate: f64,
9777        ridge_ext_coord: f64,
9778        ridge_beta: f64,
9779    ) -> Result<SaeCriterion, String> {
9780        let (_v, loss, cache) = self.reml_criterion_with_cache(
9781            target,
9782            rho,
9783            registry,
9784            inner_max_iter,
9785            learning_rate,
9786            ridge_ext_coord,
9787            ridge_beta,
9788        )?;
9789        let log_det = arrow_log_det_from_cache(&cache).ok_or_else(|| {
9790            "criterion_as_atoms: arrow_log_det_from_cache returned None".to_string()
9791        })?;
9792        let occam = self.reml_occam_term(rho)?;
9793        let extra_penalty_energy = match registry {
9794            Some(reg) => self
9795                .reml_extra_penalty_value_total(reg)
9796                .map_err(|err| format!("SaeManifoldTerm::criterion_as_atoms: {err}"))?,
9797            None => 0.0,
9798        };
9799        let data_fit_priors_value = loss.total() + extra_penalty_energy;
9800
9801        let solver = self.outer_gradient_arrow_solver(&cache, &rho.lambda_smooth_vec())?;
9802        let components =
9803            self.analytic_outer_rho_gradient_components(target, rho, &loss, &cache, &solver)?;
9804        Ok(SaeCriterion::assemble(
9805            data_fit_priors_value,
9806            log_det,
9807            occam,
9808            components.explicit,
9809            components.logdet_trace,
9810            components.occam,
9811            components.third_order_correction,
9812        ))
9813    }
9814
9815    // [#780 line-count gate] reconstruction_dispersion + assemble_shape_uncertainty
9816    // + complete_born_atom_shape_bands + shape_uncertainty_without_decoder_covariance
9817    // (the contiguous trailing methods of this impl block) were split into the
9818    // sibling construction_reconstruction.rs (declared in mod.rs); callers reach
9819    // them bare via use super::*.
9820}
9821
9822// [#780 line-count gate] Per-row jet / reconstruction-channel assembly for the
9823// streaming-exact arrow log-det lives in a sibling file as a second
9824// `impl SaeManifoldTerm` block, inlined here so it keeps the SAME module scope
9825// and private-field access. Keeps this tracked file under the 10k limit.
9826include!("construction_row_jet_logdet_channels.rs");
9827
9828// [#780 line-count gate] `term_from_padded_blocks_with_mode` (the padded-FFI
9829// term builder) was split into the sibling `construction_padded_blocks.rs`
9830// module (declared and re-exported from `mod.rs`), keeping this tracked file
9831// under the 10k limit. Callers still reach it bare through `use super::*`.
9832
9833// [#780 line-count gate] `refresh_isometry_caches_from_atom` and
9834// `refresh_isometry_caches_from_term` were split into the sibling
9835// `construction_cache_refresh.rs` module (declared and re-exported from
9836// `mod.rs`), keeping this tracked file under the 10k limit. Callers still reach
9837// both functions bare through `use super::*`.
9838
9839// [#780 line-count gate] The `#[cfg(test)]` modules below the production code
9840// are mechanically split into a sibling `*_tests` file and inlined via
9841// `include!` (the sanctioned cohesive-module decomposition — see build.rs
9842// file_stem_is_exempt_test_module). Keeps this tracked file under the 10k limit.
9843include!("construction_tests.rs");